diff --git a/sdk/agentserver/.gitignore b/sdk/agentserver/.gitignore new file mode 100644 index 000000000000..e027b1803198 --- /dev/null +++ b/sdk/agentserver/.gitignore @@ -0,0 +1,5 @@ +# Spec Kit +.specify/ +specs/ +.github/ +.vscode/ diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index 3db4cc467557..e25628f22aa5 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -4,10 +4,41 @@ ### Features Added +- **Durable long-running agents** — New `@durable_task` decorator and supporting types for building crash-resilient, long-running agents that survive container crashes, OOM kills, and redeployments. Key capabilities: + - **Lifecycle automation** — `.run()` and `.start()` automatically start, resume, or recover tasks based on their current state in the task store. + - **Entry mode awareness** — `ctx.entry_mode` tells the function whether it was entered `"fresh"`, `"resumed"` from suspension, or `"recovered"` from a crash. + - **Suspend & resume** — `ctx.suspend(output=..., reason=...)` pauses execution for multi-turn agent patterns (e.g., waiting for user input). + - **TaskResult wrapper** — `run()` and `result()` return `TaskResult[Output]` with `.is_completed` / `.is_suspended` properties, making suspension a normal return value instead of an exception. + - **Streaming** — `ctx.stream(chunk)` emits incremental output; consumers iterate with `async for chunk in task_run`. + - **Cancellation & timeout** — Cooperative cancel via `ctx.cancel` event, configurable `timeout`, and `terminate()` for forced shutdown. + - **RetryPolicy** — Configurable retry with factory presets: `.exponential_backoff()`, `.fixed_delay()`, `.linear_backoff()`, `.no_retry()`. 
+ - **Source auto-stamping** — The framework automatically stamps every task with provenance metadata: `type` (`agentserver.durable_task`), `name` (the decorator `name` option — the stable identity anchor), and `server_version` (the `x-platform-server` header value). Source is framework-owned and not user-overridable. A reserved tag `_durable_task_name` is also auto-stamped for LIST API filtering by function name. + - **Callable factories** — `tags`, `title`, and `description` accept `Callable[[Input, task_id], T]` for dynamic metadata computed at task creation time. + - **TaskMetadata** — Dict-like mutable progress metadata (`ctx.metadata["key"] = value`) with debounced auto-flush to the task store. Supports `[]`, `in`, `for`, `len`, `del`, plus convenience methods `.increment()` and `.append()`. + - **Handle operations** — `TaskRun.metadata` for progress snapshot reads, `TaskRun.delete()` for task cleanup, `TaskRun.refresh()` for re-fetching state from the store, `TaskRun.lease_expiry_count` for monitoring ownership churn. + - **TaskContext.description** — `ctx.description` exposes the task description string within the running function. + - **Configurable shutdown grace** — `DurableTaskManager(shutdown_grace_seconds=25.0)` controls how long the manager waits for tasks to checkpoint before force-expiring leases during shutdown. + - **Task listing** — `my_task.list(status=...)` returns all tasks for a specific durable task function, automatically scoped by function name (via tag) and source type. Supports `status` and `session_id` filters. +- **Steerable durable tasks** — New `steerable=True` parameter on `@durable_task` enables mid-flight steering where new inputs can be queued while a task is still running. Key capabilities: + - **Input queue** — `start()` on an in-progress steerable task queues the new input and returns a `TaskRun` handle immediately, instead of raising `TaskConflictError`. 
+ - **Cancel signal** — `ctx.cancel` is automatically set when new inputs arrive, giving the function a cooperative signal to short-circuit. + - **Automatic drain** — The framework drains the queue after the function suspends or completes, re-entering with the next queued input using `entry_mode="resumed"` and `was_steered=True`. + - **Superseded results** — Previous generation's `TaskRun.result()` resolves with `status="superseded"` and `is_superseded=True`. + - **Context enrichment** — `ctx.was_steered`, `ctx.previous_input`, `ctx.pending_inputs`, and `ctx.generation` provide full steering context. + - **Queue limits** — `max_pending` (default 10) prevents unbounded queue growth; raises `SteeringQueueFull` when exceeded. + - **Crash recovery** — `drain_in_progress` flag in persisted state enables recovery from mid-drain crashes. + - **Distributed steering** — Lease renewal loop polls for pending inputs from other processes and sets `ctx.cancel` accordingly. + - **Etag-aware completion** — Steerable tasks use optimistic concurrency on completion to detect concurrent steering. + ### Breaking Changes +- **`source` parameter removed** — The `source` keyword argument has been removed from `@durable_task()`, `.run()`, `.start()`, and `.options()`. Source provenance is now auto-stamped by the framework and cannot be overridden by developers. Use `tags` for custom metadata. + ### Bugs Fixed +- **Local provider payload merge** — Fixed `_local_provider.py` to use strict shallow merge per Protocol Spec §11: root-level keys are now always replaced, not recursively merged. Previously nested dicts were merged with `dict.update()`, which was more forgiving than the real Task Storage API. +- **Task recovery routing** — `_find_resume_callback()` now matches by `source.name` (the auto-stamped function name) first, then falls back to title prefix match. Previously relied only on fragile title prefix heuristic. 
+ ### Other Changes ## 2.0.0b3 (2026-04-22) diff --git a/sdk/agentserver/azure-ai-agentserver-core/README.md b/sdk/agentserver/azure-ai-agentserver-core/README.md index add29e0bb57b..bc72ac7400f0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/README.md +++ b/sdk/agentserver/azure-ai-agentserver-core/README.md @@ -113,6 +113,54 @@ export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." python my_agent.py ``` +### Durable long-running agents + +The `@durable_task` decorator builds crash-resilient agents that survive container restarts, OOM kills, and redeployments. Task state is persisted to a task store, enabling automatic recovery and multi-turn suspend/resume patterns. + +```python +from datetime import timedelta +from azure.ai.agentserver.core.durable import durable_task, TaskContext, RetryPolicy + +@durable_task( + timeout=timedelta(minutes=30), + retry=RetryPolicy.exponential_backoff(max_attempts=3), + tags={"priority": "high"}, +) +async def process_document(ctx: TaskContext[dict]) -> dict: + ctx.metadata["phase"] = "processing" + result = await analyze(ctx.input["document_url"]) + ctx.metadata["phase"] = "complete" + return {"summary": result} +``` + +**Start and await a task:** + +```python +result = await process_document.run(task_id="doc-42", input={"document_url": "..."}) +print(result.output) # {"summary": "..."} +``` + +**Multi-turn suspend/resume (e.g., conversational agents):** + +```python +@durable_task() +async def chat_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + history = ctx.metadata.get("history", []) + reply = await generate_reply(message, history) + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history + return await ctx.suspend(output={"reply": reply}) + +# Each call resumes the same session: +result = await chat_session.run(task_id="session-1", input={"message": "Hello"}) +print(result.output) # 
{"reply": "Hi! How can I help?"} +print(result.is_suspended) # True +``` + +See the [Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-developer-guide.md) for the full API reference. + ## Troubleshooting ### Logging @@ -130,6 +178,7 @@ To report an issue with the client library, or request additional features, plea ## Next steps - Install [`azure-ai-agentserver-invocations`](https://pypi.org/project/azure-ai-agentserver-invocations/) to add the invocation protocol endpoints. +- Read the [Durable Task Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-developer-guide.md) for crash-resilient long-running agents. - See the [container image spec](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver) for the full hosted agent contract. ## Contributing diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py index d360a00966a8..9e034a69d087 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py @@ -21,6 +21,7 @@ trace_stream, ) """ + __path__ = __import__("pkgutil").extend_path(__path__, __name__) from ._base import AgentServerHost diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py index 0785f01e36ba..dd43c3cbc722 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py @@ -160,7 +160,7 @@ class MyHost(InvocationAgentServerHost, ResponsesAgentServerHost): _DEFAULT_ACCESS_LOG_FORMAT = '%(h)s "%(r)s" %(s)s %(b)s %(D)sμs' - 
def __init__( + def __init__( # pylint: disable=too-many-statements self, *, applicationinsights_connection_string: Optional[str] = None, @@ -168,13 +168,28 @@ def __init__( log_level: Optional[str] = None, access_log: Optional[logging.Logger] = _SENTINEL_ACCESS_LOG, # type: ignore[assignment] access_log_format: Optional[str] = None, - configure_observability: Optional[Callable[..., None]] = _tracing.configure_observability, + configure_observability: Optional[ + Callable[..., None] + ] = _tracing.configure_observability, routes: Optional[list[Route]] = None, **kwargs: Any, ) -> None: # Shutdown handler slot (server-level lifecycle) ------------------- self._shutdown_fn: Optional[Callable[[], Awaitable[None]]] = None + # Durable task manager (optional — enabled by default) ---- + self._durable_task_manager: Optional[Any] = None + try: + from .durable._manager import ( # pylint: disable=import-outside-toplevel + DurableTaskManager, + ) + + self._durable_task_manager = DurableTaskManager( + _config.AgentConfig.from_env() + ) + except Exception: # pylint: disable=broad-exception-caught + pass # durable tasks not available — continue without + # Server version segments for the x-platform-server header. # Protocol packages call register_server_version() to add their # own portion; the middleware joins them at response time. @@ -187,7 +202,10 @@ def __init__( self.config: _config.AgentConfig = _config.AgentConfig.from_env() # Observability (logging + tracing) -------------------------------- - _conn_str = applicationinsights_connection_string or self.config.appinsights_connection_string + _conn_str = ( + applicationinsights_connection_string + or self.config.appinsights_connection_string + ) if configure_observability is not None: try: configure_observability( @@ -197,13 +215,18 @@ def __init__( except ValueError: raise # invalid log_level etc. 
— user should fix their config except Exception: # pylint: disable=broad-exception-caught - logger.warning("Failed to initialize observability; continuing without it.", exc_info=True) + logger.warning( + "Failed to initialize observability; continuing without it.", + exc_info=True, + ) # Access logging --------------------------------------------------- self._access_log: Optional[logging.Logger] = ( logger if access_log is _SENTINEL_ACCESS_LOG else access_log ) - self._access_log_format: str = access_log_format or self._DEFAULT_ACCESS_LOG_FORMAT + self._access_log_format: str = ( + access_log_format or self._DEFAULT_ACCESS_LOG_FORMAT + ) # Timeouts --------------------------------------------------------- self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout( @@ -212,7 +235,9 @@ def __init__( # Build lifespan context manager @contextlib.asynccontextmanager - async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF029 + async def _lifespan( + _app: Starlette, + ) -> AsyncGenerator[None, None]: # noqa: RUF029 logger.info("AgentServerHost started") # --- Startup configuration logging --- @@ -225,7 +250,11 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF cfg.agent_version or _NOT_SET, cfg.port, cfg.session_id or _NOT_SET, - cfg.sse_keepalive_interval if cfg.sse_keepalive_interval > 0 else "disabled", + ( + cfg.sse_keepalive_interval + if cfg.sse_keepalive_interval > 0 + else "disabled" + ), ) logger.info( "Connectivity: project_endpoint=%s, otlp_endpoint=%s, appinsights_configured=%s", @@ -233,13 +262,26 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF _mask_uri(cfg.otlp_endpoint), bool(cfg.appinsights_connection_string), ) - protocols = ", ".join(self._server_version_segments) if self._server_version_segments else _NOT_SET + protocols = ( + ", ".join(self._server_version_segments) + if self._server_version_segments + else _NOT_SET + ) logger.info( "Host 
options: shutdown_timeout=%ss, protocols=%s", self._graceful_shutdown_timeout, protocols, ) + # --- Durable task manager startup --- + if self._durable_task_manager is not None: + from .durable._manager import ( # pylint: disable=import-outside-toplevel + set_task_manager, + ) + + set_task_manager(self._durable_task_manager) + await self._durable_task_manager.startup() + yield # --- SHUTDOWN: runs once when the server is stopping --- @@ -247,6 +289,16 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF "AgentServerHost shutting down (graceful timeout=%ss)", self._graceful_shutdown_timeout, ) + + # Durable task manager shutdown + if self._durable_task_manager is not None: + try: + await self._durable_task_manager.shutdown() + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Error shutting down durable task manager", exc_info=True + ) + if self._graceful_shutdown_timeout == 0: logger.info("Graceful shutdown drain period disabled (timeout=0)") else: @@ -263,11 +315,22 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF except Exception: # pylint: disable=broad-exception-caught logger.warning("Error in on_shutdown", exc_info=True) - # Merge routes: subclass routes (if any) + health endpoint + # Merge routes: subclass routes (if any) + health endpoint + durable tasks all_routes: list[Any] = list(routes or []) all_routes.append( - Route("/readiness", self._readiness_endpoint, methods=["GET"], name="readiness"), + Route( + "/readiness", + self._readiness_endpoint, + methods=["GET"], + name="readiness", + ), ) + if self._durable_task_manager is not None: + from .durable._resume_route import ( # pylint: disable=import-outside-toplevel + create_resume_route, + ) + + all_routes.append(create_resume_route()) # Initialize Starlette with combined routes, lifespan, and middleware super().__init__( @@ -380,7 +443,9 @@ def request_span( # Shutdown handler (server-level lifecycle) # 
------------------------------------------------------------------ - def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + def shutdown_handler( + self, fn: Callable[[], Awaitable[None]] + ) -> Callable[[], Awaitable[None]]: """Register a function as the shutdown handler. :param fn: Async function called during graceful shutdown. @@ -455,7 +520,9 @@ def _handle_sigterm(_signum: int, _frame: Any) -> None: finally: signal.signal(signal.SIGTERM, original_sigterm) - async def run_async(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None: + async def run_async( + self, host: str = "0.0.0.0", port: Optional[int] = None + ) -> None: """Start the server asynchronously (awaitable). :param host: Network interface to bind. Defaults to ``"0.0.0.0"``. @@ -474,7 +541,9 @@ async def run_async(self, host: str = "0.0.0.0", port: Optional[int] = None) -> # Health endpoint # ------------------------------------------------------------------ - async def _readiness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + async def _readiness_endpoint( + self, request: Request + ) -> Response: # pylint: disable=unused-argument """GET /readiness — readiness check endpoint. :param request: The incoming Starlette request. 
@@ -516,7 +585,9 @@ async def sse_keepalive_stream( if pending is None: pending = asyncio.ensure_future(ait.__anext__()) try: - chunk = await asyncio.wait_for(asyncio.shield(pending), timeout=interval) + chunk = await asyncio.wait_for( + asyncio.shield(pending), timeout=interval + ) pending = None # consumed — create new task next iteration yield chunk except asyncio.TimeoutError: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py index e22bc1ff1cf6..c2d472c5c630 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py @@ -120,7 +120,8 @@ def from_env(cls) -> Self: session_id=os.environ.get(_ENV_FOUNDRY_AGENT_SESSION_ID, ""), port=resolve_port(None), appinsights_connection_string=os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, ""), + _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, "" + ), otlp_endpoint=os.environ.get(_ENV_OTEL_EXPORTER_OTLP_ENDPOINT, ""), sse_keepalive_interval=resolve_sse_keepalive_interval(None), ) @@ -158,9 +159,7 @@ def _require_int(name: str, value: object) -> int: :raises ValueError: If *value* is not an integer. """ if isinstance(value, bool) or not isinstance(value, int): - raise ValueError( - f"Invalid value for {name}: {value!r} (expected an integer)" - ) + raise ValueError(f"Invalid value for {name}: {value!r} (expected an integer)") return value @@ -176,9 +175,7 @@ def _validate_port(value: int, source: str) -> int: :raises ValueError: If the port is outside 1-65535. 
""" if not 1 <= value <= 65535: - raise ValueError( - f"Invalid value for {source}: {value} (expected 1-65535)" - ) + raise ValueError(f"Invalid value for {source}: {value} (expected 1-65535)") return value @@ -239,9 +236,7 @@ def resolve_appinsights_connection_string( """ if connection_string is not None: return connection_string - return os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING - ) + return os.environ.get(_ENV_APPLICATIONINSIGHTS_CONNECTION_STRING) def resolve_log_level(level: Optional[str]) -> str: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py index c5b1c9e01efe..9268e24df81c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py @@ -58,6 +58,4 @@ def create_error_response( body["type"] = error_type if details is not None: body["details"] = details - return JSONResponse( - {"error": body}, status_code=status_code, headers=headers - ) + return JSONResponse({"error": body}, status_code=status_code, headers=headers) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py index 4fb3fe78a9cd..63b0d320a771 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py @@ -76,7 +76,9 @@ def _get_trace_id(headers: list[tuple[bytes, bytes]] | None = None) -> str | Non :rtype: str | None """ try: - from opentelemetry import trace as _trace # pylint: disable=import-outside-toplevel + from opentelemetry import ( + trace as _trace, + ) # pylint: disable=import-outside-toplevel span = _trace.get_current_span() ctx = span.get_span_context() @@ -147,7 +149,10 @@ async def 
_send_wrapper(message: MutableMapping[str, Any]) -> None: elapsed_ms = (time.monotonic() - start) * 1000 logger.warning( "Inbound %s %s failed with status 500 in %.1fms%s", - method, path, elapsed_ms, extra_str, + method, + path, + elapsed_ms, + extra_str, ) raise @@ -156,10 +161,18 @@ async def _send_wrapper(message: MutableMapping[str, Any]) -> None: if status_code is not None and status_code >= 400: logger.warning( "Inbound %s %s completed with status %d in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) else: logger.info( "Inbound %s %s completed with status %s in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py index faf5d23d7aaf..485dd309f424 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py @@ -34,7 +34,11 @@ """ import logging import os -from collections.abc import AsyncIterable, AsyncIterator, Mapping # pylint: disable=import-error +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Mapping, +) # pylint: disable=import-error from contextlib import contextmanager from typing import Any, Iterator, Optional, Union @@ -76,10 +80,12 @@ logger = logging.getLogger("azure.ai.agentserver") # Composite propagator handles both traceparent/tracestate AND baggage -_propagator = composite.CompositePropagator([ - TraceContextTextMapPropagator(), - W3CBaggagePropagator(), -]) +_propagator = composite.CompositePropagator( + [ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ] +) # ====================================================================== @@ -122,17 +128,24 @@ def 
configure_observability( # prevent duplicate output on stderr. _has_console = any( getattr(h, _CONSOLE_HANDLER_ATTR, False) - or (isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)) + or ( + isinstance(h, logging.StreamHandler) + and not isinstance(h, logging.FileHandler) + ) for h in root.handlers ) if not _has_console: _console = logging.StreamHandler() - _console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + _console.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s") + ) setattr(_console, _CONSOLE_HANDLER_ATTR, True) root.addHandler(_console) # Suppress the noisy Azure Core HTTP logging policy logger. - logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) + logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel( + logging.WARNING + ) # Tracing and OTel export _configure_tracing(connection_string=connection_string) @@ -149,7 +162,9 @@ def _configure_tracing(connection_string: Optional[str] = None) -> None: """ resource = _create_resource() if resource is None: - logger.warning("Failed to create OTel resource — tracing will not be configured.") + logger.warning( + "Failed to create OTel resource — tracing will not be configured." 
+ ) return # Build custom processors @@ -166,8 +181,10 @@ def _configure_tracing(connection_string: Optional[str] = None) -> None: span_processors = [ _FoundryEnrichmentSpanProcessor( - agent_name=agent_name, agent_version=agent_version, - agent_id=agent_id, project_id=project_id, + agent_name=agent_name, + agent_version=agent_version, + agent_id=agent_id, + project_id=project_id, ), ] log_record_processors = [_BaggageLogRecordProcessor()] # type: ignore[list-item] @@ -179,9 +196,13 @@ def _configure_tracing(connection_string: Optional[str] = None) -> None: log_record_processors=log_record_processors, connection_string=connection_string, ) - logger.info("Tracing configured successfully via microsoft-opentelemetry distro.") + logger.info( + "Tracing configured successfully via microsoft-opentelemetry distro." + ) except ImportError: - logger.warning("microsoft-opentelemetry is not installed — tracing export disabled.") + logger.warning( + "microsoft-opentelemetry is not installed — tracing export disabled." 
+ ) # Still set up TracerProvider with enrichment processor so spans are created _ensure_trace_provider(resource, span_processors) @@ -480,7 +501,9 @@ def on_start(self, span: Any, parent_context: Any = None) -> None: session_id = _otel_baggage.get_baggage(_BAGGAGE_SESSION_ID, context=ctx) if session_id: span.set_attribute(_ATTR_SESSION_ID, session_id) - conversation_id = _otel_baggage.get_baggage(_BAGGAGE_CONVERSATION_ID, context=ctx) + conversation_id = _otel_baggage.get_baggage( + _BAGGAGE_CONVERSATION_ID, context=ctx + ) if conversation_id: span.set_attribute(_ATTR_GEN_AI_CONVERSATION_ID, conversation_id) @@ -505,7 +528,9 @@ def _on_ending(self, span: Any) -> None: if self.agent_id: attrs[_ATTR_GEN_AI_AGENT_ID] = self.agent_id except Exception: # pylint: disable=broad-exception-caught - logger.debug("Failed to enrich span attributes in _on_ending", exc_info=True) + logger.debug( + "Failed to enrich span attributes in _on_ending", exc_info=True + ) def on_end(self, span: Any) -> None: # pylint: disable=unused-argument pass @@ -513,7 +538,9 @@ def on_end(self, span: Any) -> None: # pylint: disable=unused-argument def shutdown(self) -> None: pass - def force_flush(self, timeout_millis: int = 30000) -> bool: # pylint: disable=unused-argument + def force_flush( + self, timeout_millis: int = 30000 + ) -> bool: # pylint: disable=unused-argument return True @@ -534,7 +561,7 @@ def on_emit(self, log_data: Any) -> None: # pylint: disable=unused-argument try: ctx = _otel_context.get_current() entries = _otel_baggage.get_all(context=ctx) - if entries and hasattr(log_data, 'log_record') and log_data.log_record: + if entries and hasattr(log_data, "log_record") and log_data.log_record: for key, value in entries.items(): log_data.log_record.attributes[key] = value # type: ignore[index] except Exception: # pylint: disable=broad-except @@ -543,7 +570,9 @@ def on_emit(self, log_data: Any) -> None: # pylint: disable=unused-argument def shutdown(self) -> None: pass - def 
force_flush(self, timeout_millis: int = 30000) -> bool: # pylint: disable=unused-argument + def force_flush( + self, timeout_millis: int = 30000 + ) -> bool: # pylint: disable=unused-argument return True @@ -559,12 +588,16 @@ def _create_resource() -> Any: logger.warning("OTel SDK not installed — tracing resource creation failed.") return None # service.name maps to cloud_RoleName in App Insights - agent_name = os.environ.get(_config._ENV_FOUNDRY_AGENT_NAME, "") # pylint: disable=protected-access + agent_name = os.environ.get( + _config._ENV_FOUNDRY_AGENT_NAME, "" + ) # pylint: disable=protected-access service_name = agent_name or _SERVICE_NAME_VALUE return Resource.create({_ATTR_SERVICE_NAME: service_name}) -def _ensure_trace_provider(resource: Any, span_processors: Optional[list[Any]] = None) -> Any: +def _ensure_trace_provider( + resource: Any, span_processors: Optional[list[Any]] = None +) -> Any: """Get or create a TracerProvider, optionally adding span processors. Used as a fallback when the microsoft-opentelemetry distro is not installed. 
@@ -586,7 +619,9 @@ def _ensure_trace_provider(resource: Any, span_processors: Optional[list[Any]] =
     else:
         provider = SdkTracerProvider(resource=resource)
         trace.set_tracer_provider(provider)
-    if span_processors and not getattr(provider, "_agentserver_processors_added", False):
+    if span_processors and not getattr(
+        provider, "_agentserver_processors_added", False
+    ):
         for proc in span_processors:
             provider.add_span_processor(proc)
         provider._agentserver_processors_added = True  # type: ignore[attr-defined] # pylint: disable=protected-access
diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py
new file mode 100644
index 000000000000..5525bc7ffb3f
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py
@@ -0,0 +1,88 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""Durable task subsystem for crash-resilient long-running agents.
+
+Provides the :func:`durable_task` decorator and supporting types for
+building Azure AI Hosted Agents that survive container crashes,
+OOM kills, and redeployments.
+
+Key features:
+
+- **Lifecycle automation** — ``.run()`` and ``.start()`` automatically
+  start, resume, or recover tasks based on their current state.
+- **Entry mode** — ``ctx.entry_mode`` tells the function whether it was
+  entered fresh, resumed from suspension, or recovered from a crash.
+- **RetryPolicy** — configurable retry with exponential, fixed, or linear
+  backoff (see :class:`RetryPolicy` presets).
+- **Streaming** — emit incremental output via ``ctx.stream()`` and consume
+  with ``async for chunk in task_run``.
+- **Source tracking** — the framework auto-stamps immutable provenance
+  metadata (source ``type``, ``name``, ``server_version``) at creation time.
+ +Public API:: + + from azure.ai.agentserver.core.durable import ( + durable_task, + DurableTask, + RetryPolicy, + TaskContext, + TaskMetadata, + TaskResult, + TaskRun, + Suspended, + TaskStatus, + TaskFailed, + TaskSuspended, + TaskCancelled, + TaskNotFound, + TaskConflictError, + TaskTerminated, + EntryMode, + TaskInfo, + ) +""" + +from ._context import EntryMode, TaskContext +from ._decorator import DurableTask, DurableTaskOptions, durable_task +from ._exceptions import ( + EtagConflict, + SteeringQueueFull, + TaskCancelled, + TaskConflictError, + TaskFailed, + TaskNotFound, + TaskSuspended, + TaskTerminated, +) +from ._metadata import TaskMetadata +from ._models import TaskInfo, TaskStatus +from ._result import TaskResult +from ._retry import RetryPolicy +from ._run import Suspended, TaskRun +from ._stream import QueueStreamHandler, StreamHandler + +__all__ = [ + "durable_task", + "DurableTask", + "DurableTaskOptions", + "QueueStreamHandler", + "RetryPolicy", + "StreamHandler", + "TaskContext", + "TaskMetadata", + "TaskResult", + "TaskRun", + "Suspended", + "TaskStatus", + "TaskFailed", + "TaskSuspended", + "TaskCancelled", + "TaskNotFound", + "TaskConflictError", + "TaskTerminated", + "EtagConflict", + "SteeringQueueFull", + "EntryMode", + "TaskInfo", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py new file mode 100644 index 000000000000..e2ac92a8747a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py @@ -0,0 +1,241 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Hosted durable task provider — HTTP client for the Foundry Task Storage API. 
+ +Communicates with ``{FOUNDRY_PROJECT_ENDPOINT}/storage/tasks`` using +``httpx.AsyncClient``. Bearer tokens are obtained lazily from +``DefaultAzureCredential`` when running in a hosted environment. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import httpx + +from ._exceptions import TaskNotFound +from ._models import ( + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, + TaskStatus, +) + +logger = logging.getLogger("azure.ai.agentserver.durable") + +_AUTH_SCOPE = "https://ai.azure.com/.default" +_API_VERSION = "v1" + + +class HostedDurableTaskProvider: + """HTTP-backed provider for the Foundry Task Storage API. + + :param project_endpoint: The ``FOUNDRY_PROJECT_ENDPOINT`` base URL. + :type project_endpoint: str + :param credential: An ``azure.identity.aio.DefaultAzureCredential`` + instance, or any token credential supporting ``get_token(scope)``. + :type credential: Any + """ + + def __init__(self, project_endpoint: str, credential: Any) -> None: + self._base_url = f"{project_endpoint.rstrip('/')}/storage/tasks" + self._credential = credential + self._client = httpx.AsyncClient(timeout=30.0) + + async def _get_headers(self) -> dict[str, str]: + token = await self._credential.get_token(_AUTH_SCOPE) + return { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json", + } + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task via POST /storage/tasks. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. 
+ :rtype: TaskInfo + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if request.lease_owner is not None: + params["lease_owner"] = request.lease_owner + if request.lease_instance_id is not None: + params["lease_instance_id"] = request.lease_instance_id + if request.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(request.lease_duration_seconds) + + body: dict[str, Any] = { + "agent_name": request.agent_name, + "session_id": request.session_id, + } + if request.id is not None: + body["id"] = request.id + if request.status != "pending": + body["status"] = request.status + if request.title is not None: + body["title"] = request.title + if request.description is not None: + body["description"] = request.description + if request.payload is not None: + body["payload"] = request.payload + if request.tags is not None: + body["tags"] = request.tags + if request.source is not None: + body["source"] = request.source + + response = await self._client.post( + self._base_url, json=body, headers=headers, params=params + ) + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a task by ID via GET /storage/tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + headers = await self._get_headers() + response = await self._client.get( + f"{self._base_url}/{task_id}", + headers=headers, + params={"api-version": _API_VERSION}, + ) + if response.status_code == 404: + return None + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH /storage/tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. 
+ :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if patch.lease_owner is not None: + params["lease_owner"] = patch.lease_owner + if patch.lease_instance_id is not None: + params["lease_instance_id"] = patch.lease_instance_id + if patch.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(patch.lease_duration_seconds) + + body: dict[str, Any] = {} + if patch.status is not None: + body["status"] = patch.status + if patch.payload is not None: + body["payload"] = patch.payload + if patch.tags is not None: + body["tags"] = patch.tags + if patch.error is not None: + body["error"] = patch.error + if patch.suspension_reason is not None: + body["suspension_reason"] = patch.suspension_reason + + if patch.if_match is not None: + headers["If-Match"] = f'"{patch.if_match}"' + + response = await self._client.patch( + f"{self._base_url}/{task_id}", + json=body, + headers=headers, + params=params, + ) + if response.status_code == 404: + raise TaskNotFound(task_id) + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, + ) -> None: + """Delete a task via DELETE /storage/tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. 
+ :paramtype cascade: bool + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if force: + params["force"] = "true" + if cascade: + params["cascade"] = "true" + + response = await self._client.delete( + f"{self._base_url}/{task_id}", + headers=headers, + params=params, + ) + if response.status_code == 404: + raise TaskNotFound(task_id) + response.raise_for_status() + + async def list( + self, + *, + agent_name: str, + session_id: str, + status: TaskStatus | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + ) -> list[TaskInfo]: + """List tasks via GET /storage/tasks. + + :keyword agent_name: Filter by agent name. + :paramtype agent_name: str + :keyword session_id: Filter by session ID. + :paramtype session_id: str + :keyword status: Filter by task status. + :paramtype status: TaskStatus | None + :keyword lease_owner: Filter by lease owner. + :paramtype lease_owner: str | None + :keyword tag: Filter by tag key-value pairs. + :paramtype tag: dict[str, str] | None + :return: Matching task records. 
+ :rtype: list[TaskInfo] + """ + headers = await self._get_headers() + params: dict[str, str] = { + "api-version": _API_VERSION, + "agent_name": agent_name, + "session_id": session_id, + } + if status is not None: + params["status"] = status + if lease_owner is not None: + params["lease_owner"] = lease_owner + if tag: + for key, value in tag.items(): + params[f"tag.{key}"] = value + + response = await self._client.get( + self._base_url, headers=headers, params=params + ) + response.raise_for_status() + data = response.json() + items: list[dict[str, Any]] = data.get("data", data.get("items", [])) + return [TaskInfo.from_dict(item) for item in items] + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self._client.aclose() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py new file mode 100644 index 000000000000..3d357d429d3b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py @@ -0,0 +1,184 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskContext — the single parameter to a durable task function. + +Provides identity, typed input, mutable metadata, cancellation signals, +and the ``suspend()`` method for pausing execution. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Generic, Literal, Sequence, TypeVar + +from ._metadata import TaskMetadata +from ._stream import StreamHandler + +Input = TypeVar("Input") +Output = TypeVar("Output") + +EntryMode = Literal["fresh", "resumed", "recovered"] +"""Why the durable function was entered. + +- ``"fresh"`` — First execution. Task was just created or started from pending. 
+- ``"resumed"`` — Re-entered after suspension. On developer-initiated resume + (via ``.run()``), ``ctx.input`` contains the new input. On platform-initiated + resume (via ``/tasks/{task_id}/resume``), ``ctx.input`` contains the task's + persisted input. Also used when a steering input drains from the queue — + check ``ctx.was_steered`` to distinguish steering re-entry from normal resume. +- ``"recovered"`` — Re-entered after stale task detection. The previous execution + crashed or timed out. ``ctx.input`` contains the task's persisted input. + If a steerable task crashed mid-drain, ``ctx.was_steered`` will be ``True`` + and steering context (``previous_input``, ``generation``) is meaningful. +""" + + +class _Suspended: + """Internal sentinel for suspended tasks. See ``Suspended`` in ``_run.py``.""" + + __slots__ = ("reason", "output") + + def __init__( + self, + reason: str | None = None, + output: Any | None = None, + ) -> None: + self.reason = reason + self.output = output + + +class TaskContext(Generic[Input]): # pylint: disable=too-many-instance-attributes + """The single parameter to a durable task function. + + Provides access to the task's identity, typed input, mutable metadata + for progress tracking, cancellation signals, and the ability to + suspend execution. + + :param task_id: Unique task identifier. + :type task_id: str + :param title: Human-readable task title. + :type title: str + :param description: Optional task description. + :type description: str | None + :param session_id: Session scope identifier. + :type session_id: str + :param agent_name: Agent name from config. + :type agent_name: str + :param tags: Merged decorator + call-site tags. + :type tags: dict[str, str] + :param input: Typed, validated input value. + :type input: Input + :param metadata: Mutable progress metadata. + :type metadata: TaskMetadata + :param run_attempt: Framework retry attempt counter. 
+ :type run_attempt: int + :param lease_generation: Lease re-acquisition counter. + :type lease_generation: int + :param cancel: Request-level cancellation event. + :type cancel: asyncio.Event + :param shutdown: Container-level shutdown event. + :type shutdown: asyncio.Event + """ + + __slots__ = ( + "task_id", + "title", + "description", + "session_id", + "agent_name", + "tags", + "input", + "metadata", + "run_attempt", + "lease_generation", + "cancel", + "shutdown", + "_suspend_callback", + "_stream_handler", + "entry_mode", + "was_steered", + "previous_input", + "pending_inputs", + "generation", + ) + + def __init__( + self, + *, + task_id: str, + title: str, + description: str | None = None, + session_id: str, + agent_name: str, + tags: dict[str, str], + input: Input, # noqa: A002 — mirrors the spec naming + metadata: TaskMetadata, + run_attempt: int = 0, + lease_generation: int = 0, + cancel: asyncio.Event | None = None, + shutdown: asyncio.Event | None = None, + stream_handler: StreamHandler | None = None, + entry_mode: EntryMode = "fresh", + was_steered: bool = False, + previous_input: Input | None = None, + pending_inputs: Sequence[Any] | None = None, + generation: int = 0, + ) -> None: + self.task_id = task_id + self.title = title + self.description = description + self.session_id = session_id + self.agent_name = agent_name + self.tags = tags + self.input = input + self.metadata = metadata + self.run_attempt = run_attempt + self.lease_generation = lease_generation + self.cancel = cancel or asyncio.Event() + self.shutdown = shutdown or asyncio.Event() + self._suspend_callback: Any = None + self._stream_handler: StreamHandler | None = stream_handler + self.entry_mode: EntryMode = entry_mode + self.was_steered: bool = was_steered + self.previous_input: Input | None = previous_input + self.pending_inputs: Sequence[Any] = ( + pending_inputs if pending_inputs is not None else () + ) + self.generation: int = generation + + async def suspend( + self, + *, + 
reason: str | None = None, + output: Any | None = None, + ) -> Any: + """Suspend the task, releasing the lease and persisting state. + + Must be used as ``return await ctx.suspend(...)``. The framework + interprets the returned sentinel to transition the task to + ``suspended`` status. + + :keyword reason: Human-readable suspension reason. + :paramtype reason: str | None + :keyword output: Optional output snapshot for observers. + :paramtype output: Any | None + :return: A ``Suspended`` sentinel that the framework interprets. + :rtype: Suspended + """ + from ._run import Suspended # pylint: disable=import-outside-toplevel + + return Suspended(reason=reason, output=output) + + async def stream(self, item: Any) -> None: + """Emit a streaming item to observers iterating this task's output. + + When a :class:`~azure.ai.agentserver.core.durable.StreamHandler` + is configured, the item is routed through ``handler.put(item)``. + Otherwise the call is a no-op. + + :param item: The value to stream. + :type item: Any + """ + if self._stream_handler is not None: + await self._stream_handler.put(item) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py new file mode 100644 index 000000000000..316a1a72bf76 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py @@ -0,0 +1,981 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""``@durable_task`` decorator — turns an async function into a crash-resilient +unit of work with automatic task lifecycle management. + +Usage:: + + from azure.ai.agentserver.core.durable import durable_task, TaskContext + + @durable_task + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: + ... 
+ + result = await my_task.run(task_id="t1", input=MyInput(...)) +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import inspect +import logging as _logging +from collections.abc import Awaitable, Callable +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeVar, + get_args, + get_type_hints, + overload, +) + +import re + +from ._context import TaskContext +from ._result import TaskResult +from ._retry import RetryPolicy +from ._run import TaskRun +from ._stream import StreamHandler + +if TYPE_CHECKING: + from ._models import TaskStatus + +Input = TypeVar("Input") +Output = TypeVar("Output") +F = TypeVar("F", bound=Callable[..., Any]) + +_VALID_TASK_ID_RE = re.compile(r"^[a-zA-Z0-9\-_.:]+$") +_MAX_TASK_ID_LENGTH = 256 + +#: Prefix for framework-reserved tags. Developer tags with this prefix are +#: silently stripped to prevent collisions with auto-stamped tags. +_RESERVED_TAG_PREFIX = "_durable_task_" + +_logger = _logging.getLogger("azure.ai.agentserver.durable") + + +def _strip_reserved_tags(tags: dict[str, str]) -> dict[str, str]: + """Remove framework-reserved tags from developer-provided tags. + + Tags prefixed with ``_durable_task_`` are reserved for framework use. + If a developer provides them, they are silently dropped with a warning. + + :param tags: Developer-provided tags. + :type tags: dict[str, str] + :return: Tags with reserved keys removed. 
+    :rtype: dict[str, str]
+    """
+    reserved = [k for k in tags if k.startswith(_RESERVED_TAG_PREFIX)]
+    if reserved:
+        _logger.warning(
+            "Ignoring reserved tag(s) %s — tags prefixed with %r are "
+            "framework-owned and cannot be overridden",
+            reserved,
+            _RESERVED_TAG_PREFIX,
+        )
+        return {k: v for k, v in tags.items() if not k.startswith(_RESERVED_TAG_PREFIX)}
+    return tags
+
+
+def _validate_task_id(task_id: str) -> None:
+    if not task_id or len(task_id) > _MAX_TASK_ID_LENGTH:
+        raise ValueError(
+            f"task_id must be 1-{_MAX_TASK_ID_LENGTH} characters, "
+            f"got {len(task_id)}"
+        )
+    if not _VALID_TASK_ID_RE.match(task_id):
+        raise ValueError(
+            f"task_id contains invalid characters: {task_id!r}. "
+            f"Allowed: [a-zA-Z0-9\\-_.:] "
+        )
+
+
+def _extract_generic_args(
+    fn: Callable[..., Any],
+) -> tuple[type[Any], type[Any]]:
+    """Extract Input and Output types from a durable task function signature.
+
+    The function must accept a single ``TaskContext[Input]`` parameter
+    and return ``Output``.
+
+    :param fn: The async function to inspect.
+    :type fn: Callable[..., Any]
+    :returns: ``(InputType, OutputType)`` tuple.
+    :rtype: tuple[type[Any], type[Any]]
+    :raises TypeError: If the signature doesn't match expectations.
+    """
+    hints = get_type_hints(fn)
+    params = list(inspect.signature(fn).parameters.values())
+
+    # Find the TaskContext parameter
+    ctx_param = None
+    for p in params:
+        hint = hints.get(p.name)
+        if hint is not None:
+            origin = getattr(hint, "__origin__", None)
+            if origin is TaskContext:
+                ctx_param = p
+                break
+
+    if ctx_param is None:
+        raise TypeError(
+            f"Durable task function {fn.__qualname__!r} must accept a "
+            f"TaskContext[Input] parameter"
+        )
+
+    ctx_hint = hints[ctx_param.name]
+    args = get_args(ctx_hint)
+    input_type: type[Any] = args[0] if args else Any
+
+    return_hint = hints.get("return", Any)
+    # The return annotation is taken as-is; no Optional/Awaitable unwrapping is done.
+ output_type: type[Any] = return_hint if return_hint is not None else type(None) + + return input_type, output_type + + +def _serialize_input(value: Any) -> Any: + """Serialize an input value for storage in the task payload. + + :param value: The input value to serialize. + :type value: Any + :return: The serialized form of the input. + :rtype: Any + """ + # Pydantic model + if hasattr(value, "model_dump"): + return value.model_dump() + # Plain JSON-serializable + return value + + +def _deserialize_input(value: Any, input_type: type[Any]) -> Any: + """Deserialize an input value from the task payload. + + :param value: The serialized input value. + :type value: Any + :param input_type: The expected type to deserialize into. + :type input_type: type[Any] + :return: The deserialized input value. + :rtype: Any + """ + if value is None: + return None + # Pydantic model + if hasattr(input_type, "model_validate"): + return input_type.model_validate(value) + # dict-constructable class + if ( + isinstance(value, dict) + and callable(input_type) + and input_type not in (dict, str, int, float, bool, list) + ): + try: + return input_type(**value) + except TypeError: + pass + return value + + +def _is_stale(task_updated_at: str, timeout: float) -> bool: + """Check if an in_progress task is stale based on its updated_at timestamp. + + :param task_updated_at: ISO 8601 timestamp of the task's last update. + :type task_updated_at: str + :param timeout: Seconds after which the task is considered stale. + :type timeout: float + :returns: True if the task is stale. 
+    :rtype: bool
+    """
+    if not task_updated_at:
+        return False
+    from datetime import datetime, timezone  # pylint: disable=import-outside-toplevel
+
+    updated = datetime.fromisoformat(task_updated_at.replace("Z", "+00:00"))  # "Z" suffix unsupported by fromisoformat before 3.11
+    now = datetime.now(timezone.utc)
+    if updated.tzinfo is None:
+        updated = updated.replace(tzinfo=timezone.utc)
+    return (now - updated).total_seconds() > timeout
+
+
+class DurableTaskOptions:  # pylint: disable=too-many-instance-attributes
+    """Options for a durable task.
+
+    :param name: **Stable identity anchor.** Used for recovery routing and
+        source stamping. If you rename the Python function later, existing
+        in-flight tasks are still recovered correctly because the framework
+        matches on this name.
+    :type name: str
+    :param title: Human-readable title template.
+    :type title: str | Callable[[Any, str], str] | None
+    :param tags: Default tags (static dict or callable factory).
+    :type tags: dict[str, str] | Callable[[Any, str], dict[str, str]]
+    :param description: Task description (static string or callable factory).
+    :type description: str | Callable[[Any, str], str] | None
+    :param timeout: Execution timeout.
+    :type timeout: timedelta | None
+    :param lease_duration_seconds: Lease TTL.
+    :type lease_duration_seconds: int
+    :param store_input: Whether to persist input on the task record.
+    :type store_input: bool
+    :param ephemeral: Whether to delete on terminal exit.
+ :type ephemeral: bool + """ + + __slots__ = ( + "name", + "title", + "tags", + "description", + "timeout", + "lease_duration_seconds", + "store_input", + "ephemeral", + "retry", + "steerable", + "max_pending", + ) + + def __init__( + self, + name: str, + title: str | Callable[[Any, str], str] | None = None, + tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None, + description: str | Callable[[Any, str], str | None] | None = None, + timeout: timedelta | None = None, + lease_duration_seconds: int = 60, + store_input: bool = True, + ephemeral: bool = True, + retry: RetryPolicy | None = None, + steerable: bool = False, + max_pending: int = 10, + ) -> None: + self.name = name + self.title = title + self.tags = tags if tags is not None else {} + self.description = description + self.timeout = timeout + self.lease_duration_seconds = lease_duration_seconds + self.store_input = store_input + self.ephemeral = ephemeral + self.retry = retry + self.steerable = steerable + self.max_pending = max_pending + + def __repr__(self) -> str: + return ( + f"DurableTaskOptions(name={self.name!r}, lease_duration_seconds={self.lease_duration_seconds}, " + f"store_input={self.store_input}, ephemeral={self.ephemeral}, retry={self.retry!r}, " + f"timeout={self.timeout!r}, " + f"steerable={self.steerable}, max_pending={self.max_pending})" + ) + + +class DurableTask(Generic[Input, Output]): + """A decorated durable task function. Not callable directly. + + Use :meth:`run` (invoke-and-wait), :meth:`start` (fire-and-forget), + or :meth:`options` (per-call overrides). + + :param fn: The decorated async function. + :param opts: Frozen task options. + :param input_type: Extracted input type. + :param output_type: Extracted output type. 
+ """ + + __slots__ = ("_fn", "_opts", "_input_type", "_output_type", "name") + + def __init__( + self, + fn: Callable[[TaskContext[Input]], Awaitable[Output]], + opts: DurableTaskOptions, + input_type: type[Input], + output_type: type[Output], + ) -> None: + self._fn = fn + self._opts = opts + self._input_type = input_type + self._output_type = output_type + self.name = opts.name + + def _resolve_title(self, input_val: Input, task_id: str) -> str: + if callable(self._opts.title): + return self._opts.title(input_val, task_id) + if isinstance(self._opts.title, str): + return self._opts.title + return f"{self.name}:{task_id[:8]}" + + def _resolve_tags(self, input_val: Input, task_id: str) -> dict[str, str]: + """Resolve decorator-level tags (static dict or callable factory). + + Reserved tags (prefixed with ``_durable_task_``) are stripped to + prevent developer code from colliding with framework-stamped tags. + + :param input_val: The task input value. + :type input_val: Input + :param task_id: The task identifier. + :type task_id: str + :return: Resolved tags dictionary. + :rtype: dict[str, str] + """ + tags = self._opts.tags + if callable(tags): + result = tags(input_val, task_id) + if not isinstance(result, dict): + raise TypeError( + f"tags callable must return dict[str, str], " + f"got {type(result).__name__}" + ) + return _strip_reserved_tags(result) + return _strip_reserved_tags(dict(tags) if tags else {}) + + def _resolve_description(self, input_val: Input, task_id: str) -> str | None: + """Resolve decorator-level description (static or callable). + + :param input_val: The task input value. + :type input_val: Input + :param task_id: The task identifier. + :type task_id: str + :return: Resolved description string or None. 
+ :rtype: str | None + """ + desc = self._opts.description + if callable(desc): + result = desc(input_val, task_id) + if result is not None and not isinstance(result, str): + raise TypeError( + f"description callable must return str or None, " + f"got {type(result).__name__}" + ) + return result + return desc + + def _merge_tags( + self, input_val: Input, task_id: str, call_tags: dict[str, str] | None + ) -> dict[str, str]: + merged = self._resolve_tags(input_val, task_id) + if call_tags: + merged.update(_strip_reserved_tags(call_tags)) + return merged + + async def run( + self, + *, + task_id: str, + input: Input, # noqa: A002 + session_id: str | None = None, + title: str | None = None, + tags: dict[str, str] | None = None, + retry: RetryPolicy | None = None, + stale_timeout: float = 300.0, + stream_handler: StreamHandler | None = None, + ) -> TaskResult[Output]: + """Run a lifecycle-aware durable task and return the result. + + Automatically starts, resumes, or recovers the task based on its + current state: + + - No task / pending → create and start (``entry_mode="fresh"``) + - Suspended → resume with new input (``entry_mode="resumed"``) + - In-progress (stale) → recover (``entry_mode="recovered"``) + - In-progress (not stale) → raise :class:`TaskConflictError` + - Completed → raise :class:`TaskConflictError` + + :keyword task_id: Unique task identifier. + :paramtype task_id: str + :keyword input: Typed input value. + :paramtype input: Input + :keyword session_id: Session scope override. + :paramtype session_id: str | None + :keyword title: Title override. + :paramtype title: str | None + :keyword tags: Per-call tag overrides. + :paramtype tags: dict[str, str] | None + :keyword retry: Retry policy override. Overrides decorator-level retry. + :paramtype retry: ~azure.ai.agentserver.core.durable.RetryPolicy | None + :keyword stale_timeout: Seconds before an in-progress task is considered + stale and eligible for recovery. Default 300 (5 minutes). 
+ :paramtype stale_timeout: float + :keyword stream_handler: Custom stream handler for pluggable streaming. + If ``None``, a default :class:`QueueStreamHandler` is used. + :paramtype stream_handler: ~azure.ai.agentserver.core.durable.StreamHandler | None + :return: The task result wrapper with output, status, and suspension info. + :rtype: ~azure.ai.agentserver.core.durable.TaskResult[Output] + :raises TaskFailed: On unhandled exception. + :raises ~azure.ai.agentserver.core.durable.TaskConflictError: If the + task is already in-progress or completed. + """ + _validate_task_id(task_id) + handle = await self._lifecycle_start( + task_id=task_id, + input=input, + session_id=session_id, + title=title, + tags=tags, + retry=retry, + stale_timeout=stale_timeout, + stream_handler=stream_handler, + ) + return await handle.result() + + async def start( + self, + *, + task_id: str, + input: Input, # noqa: A002 + session_id: str | None = None, + title: str | None = None, + tags: dict[str, str] | None = None, + retry: RetryPolicy | None = None, + stale_timeout: float = 300.0, + stream_handler: StreamHandler | None = None, + ) -> TaskRun[Output]: + """Start a lifecycle-aware durable task and return a handle. + + Follows the same lifecycle rules as :meth:`run` but returns + immediately with a :class:`TaskRun` handle instead of blocking. + + :keyword task_id: Unique task identifier. + :paramtype task_id: str + :keyword input: Typed input value. + :paramtype input: Input + :keyword session_id: Session scope override. + :paramtype session_id: str | None + :keyword title: Title override. + :paramtype title: str | None + :keyword tags: Per-call tag overrides. + :paramtype tags: dict[str, str] | None + :keyword retry: Retry policy override. Overrides decorator-level retry. + :paramtype retry: ~azure.ai.agentserver.core.durable.RetryPolicy | None + :keyword stale_timeout: Seconds before an in-progress task is considered + stale and eligible for recovery. Default 300 (5 minutes). 
+ :paramtype stale_timeout: float + :keyword stream_handler: Custom stream handler for pluggable streaming. + If ``None``, a default :class:`QueueStreamHandler` is used. + :paramtype stream_handler: ~azure.ai.agentserver.core.durable.StreamHandler | None + :return: A handle to the running task. + :rtype: TaskRun[Output] + :raises ~azure.ai.agentserver.core.durable.TaskConflictError: If the + task is already in-progress or completed. + """ + _validate_task_id(task_id) + return await self._lifecycle_start( + task_id=task_id, + input=input, + session_id=session_id, + title=title, + tags=tags, + retry=retry, + stale_timeout=stale_timeout, + stream_handler=stream_handler, + ) + + async def get(self, task_id: str) -> Any: + """Return the full persisted task information. + + Works for any task state — running, suspended, completed, etc. + Returns whatever is persisted. Returns ``None`` if no task exists. + + :param task_id: The task identifier. + :type task_id: str + :return: Task info or ``None`` if no task exists. + :rtype: ~azure.ai.agentserver.core.durable.TaskInfo | None + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return await manager.provider.get(task_id) + + async def list( + self, + *, + session_id: str | None = None, + status: TaskStatus | None = None, + ) -> list[Any]: + """List tasks created by this durable task function. + + Automatically scoped to this function's ``name`` via the + ``_durable_task_name`` tag (server-side) and ``source.type`` + (client-side). Only returns tasks created by this framework. + + :keyword session_id: Session scope override. Defaults to the + manager's configured session ID. + :paramtype session_id: str | None + :keyword status: Filter by task status (e.g., ``"in_progress"``, + ``"suspended"``, ``"completed"``). + :paramtype status: ~azure.ai.agentserver.core.durable.TaskStatus | None + :return: Matching task records. 
+ :rtype: list[~azure.ai.agentserver.core.durable.TaskInfo] + + Example:: + + tasks = await my_task.list(status="suspended") + for t in tasks: + print(t.id, t.status) + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return await manager.list_tasks( + fn_name=self.name, + session_id=session_id, + status=status, + ) + + async def _append_steering_input( # pylint: disable=protected-access + self, + manager: Any, + *, + task_id: str, + input_val: Any, + existing: Any, + ) -> None: + """Append a steering input to the task's pending queue.""" + from ._exceptions import ( # pylint: disable=import-outside-toplevel + SteeringQueueFull, + ) + from ._models import ( # pylint: disable=import-outside-toplevel + TaskPatchRequest, + ) + + max_retries = 5 + serialized = _serialize_input(input_val) + + for _attempt in range(max_retries): + task_info = ( + existing if _attempt == 0 else await manager.provider.get(task_id) + ) + if task_info is None: + raise RuntimeError( + f"Task {task_id!r} disappeared during steering append" + ) + + payload = dict(task_info.payload) if task_info.payload else {} + steering = dict(payload.get("_steering", {})) + pending: list[Any] = list(steering.get("pending_inputs", [])) + + if len(pending) >= self._opts.max_pending: + raise SteeringQueueFull(task_id, self._opts.max_pending) + + pending.append(serialized) + steering["pending_inputs"] = pending + steering["cancel_requested"] = True + if "generation" not in steering: + steering["generation"] = 0 + payload["_steering"] = steering + + etag = getattr(task_info, "etag", None) or None + try: + await manager.provider.update( + task_id, + TaskPatchRequest(payload=payload, if_match=etag), + ) + # Signal the running task's cancel event so it can short-circuit + active = manager._active_tasks.get( + task_id + ) # pylint: disable=protected-access # noqa: SLF001 + if active and hasattr(active, "context") and active.context is 
not None: + active.context.cancel.set() + return + except ValueError: + # Local provider etag conflict — retry + continue + + raise RuntimeError( + f"Failed to append steering input after {max_retries} retries" + ) + + def _create_steering_ack_run( + self, + manager: Any, + task_id: str, + future: Any, + ) -> TaskRun[Output]: + """Create a TaskRun for a queued steering input.""" + return TaskRun( + task_id=task_id, + provider=manager.provider, + result_future=future, + ) + + async def _lifecycle_start( + self, + *, + task_id: str, + input: Input, # noqa: A002 + session_id: str | None, + title: str | None, + tags: dict[str, str] | None, + retry: RetryPolicy | None, + stale_timeout: float, + stream_handler: StreamHandler | None = None, + ) -> TaskRun[Output]: + """Resolve lifecycle state and start/resume/recover accordingly. + + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword input: Typed input value. + :paramtype input: Input + :keyword session_id: Session scope override. + :paramtype session_id: str | None + :keyword title: Title override. + :paramtype title: str | None + :keyword tags: Per-call tag overrides. + :paramtype tags: dict[str, str] | None + :keyword retry: Retry policy override. + :paramtype retry: RetryPolicy | None + :keyword stale_timeout: Stale timeout in seconds. + :paramtype stale_timeout: float + :keyword stream_handler: Custom stream handler. Defaults to + :class:`QueueStreamHandler` when ``None``. + :paramtype stream_handler: StreamHandler | None + :return: A handle to the running task. 
+ :rtype: TaskRun[Output] + """ + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskConflictError, + ) + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + existing = await manager.provider.get(task_id) + + resolved_retry = retry or self._opts.retry + + if existing is None or existing.status == "pending": + # Fresh start + if existing is not None and existing.status == "pending": + # Pending task exists — patch to in_progress and execute + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="fresh", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + # No task exists — create new + return await manager.create_and_start( + fn=self._fn, + fn_name=self.name, + task_id=task_id, + input_val=input, + input_type=self._input_type, + session_id=session_id, + title=title or self._resolve_title(input, task_id), + tags=self._merge_tags(input, task_id, tags), + description=self._resolve_description(input, task_id), + opts=self._opts, + retry=resolved_retry, + entry_mode="fresh", + stream_handler=stream_handler, + ) + + if existing.status == "suspended": + # Resume — patch input onto task, then start + serialized = _serialize_input(input) + from ._models import ( # pylint: disable=import-outside-toplevel + TaskPatchRequest, + ) + + await manager.provider.update( + task_id, + TaskPatchRequest(payload={"input": serialized}), + ) + # Re-fetch after input patch + updated_info = await manager.provider.get(task_id) + if updated_info is None: + raise RuntimeError(f"Task {task_id!r} disappeared after input patch") + return ( + await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=updated_info, + entry_mode="resumed", + input_val=input, + input_type=self._input_type, + opts=self._opts, 
+ retry=resolved_retry, + ) + ) + + if existing.status == "in_progress": + if _is_stale(existing.updated_at, stale_timeout): + # Stale — check for steering recovery state first + if self._opts.steerable and existing.payload: + steering = existing.payload.get("_steering", {}) + if steering.get("drain_in_progress") or steering.get( + "pending_inputs" + ): + # Stale with steering state — recover via steered path + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="recovered", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + # Normal stale recovery + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="recovered", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + if self._opts.steerable: + # Steering path: append input to queue, signal cancel, return ack + ack_future = manager._register_steering_future( + task_id + ) # pylint: disable=protected-access + await self._append_steering_input( + manager, + task_id=task_id, + input_val=input, + existing=existing, + ) + # Set cancel on in-memory context if task runs in this process + active = manager._active_tasks.get( + task_id + ) # pylint: disable=protected-access + if active: + active.context.cancel.set() + return self._create_steering_ack_run(manager, task_id, ack_future) + raise TaskConflictError(task_id, "in_progress") + + # completed (or any other terminal status) + raise TaskConflictError(task_id, existing.status) + + def options( + self, + *, + title: str | Callable[[Any, str], str] | None = None, + tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None, + description: str | Callable[[Any, str], str | None] | None = None, + timeout: timedelta | None = None, + lease_duration_seconds: int | None = None, + 
store_input: bool | None = None, + ephemeral: bool | None = None, + retry: RetryPolicy | None = None, + steerable: bool | None = None, + max_pending: int | None = None, + ) -> DurableTask[Input, Output]: + """Return a new DurableTask with merged options. + + The original is unchanged. + + :keyword timeout: Execution timeout override. + :paramtype timeout: timedelta | None + :keyword ephemeral: Whether to delete task on terminal exit. + :paramtype ephemeral: bool | None + :keyword tags: Tag overrides. + :paramtype tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None + :keyword store_input: Whether to persist input. + :paramtype store_input: bool | None + :keyword retry: Retry policy override. + :paramtype retry: RetryPolicy | None + :keyword title: Title override. + :paramtype title: str | Callable[[Any, str], str] | None + :keyword description: Description override. + :paramtype description: str | Callable[[Any, str], str | None] | None + :keyword lease_duration_seconds: Lease TTL override. + :paramtype lease_duration_seconds: int | None + :keyword steerable: Whether this task accepts steering inputs. + :paramtype steerable: bool | None + :keyword max_pending: Maximum queued steering inputs. + :paramtype max_pending: int | None + :return: A new DurableTask with overridden options. + :rtype: DurableTask[Input, Output] + """ + # For tags: if both old and new are dicts, merge them. + # Mixing callable and dict is not supported — use one or the other. + resolved_tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None + if tags is not None: + if callable(tags) != callable(self._opts.tags) and self._opts.tags: + raise TypeError( + "Cannot mix callable and dict tags in options(). " + "Pass a callable to replace a callable, or a dict to merge with a dict." 
+ ) + if callable(tags): + resolved_tags = tags + else: + existing = self._opts.tags if isinstance(self._opts.tags, dict) else {} + resolved_tags = _strip_reserved_tags({**existing, **(tags or {})}) + else: + resolved_tags = self._opts.tags + + new_opts = DurableTaskOptions( + name=self._opts.name, + title=title if title is not None else self._opts.title, + tags=resolved_tags, + description=( + description if description is not None else self._opts.description + ), + timeout=timeout if timeout is not None else self._opts.timeout, + lease_duration_seconds=( + lease_duration_seconds + if lease_duration_seconds is not None + else self._opts.lease_duration_seconds + ), + store_input=( + store_input if store_input is not None else self._opts.store_input + ), + ephemeral=(ephemeral if ephemeral is not None else self._opts.ephemeral), + retry=retry if retry is not None else self._opts.retry, + steerable=(steerable if steerable is not None else self._opts.steerable), + max_pending=( + max_pending if max_pending is not None else self._opts.max_pending + ), + ) + return DurableTask( + fn=self._fn, + opts=new_opts, + input_type=self._input_type, + output_type=self._output_type, + ) + + +@overload +def durable_task( + fn: Callable[[TaskContext[Input]], Awaitable[Output]], +) -> DurableTask[Input, Output]: ... + + +@overload +def durable_task( + *, + name: str | None = ..., + title: str | Callable[[Any, str], str] | None = ..., + tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = ..., + description: str | Callable[[Any, str], str | None] | None = ..., + timeout: timedelta | None = ..., + lease_duration_seconds: int = ..., + store_input: bool = ..., + ephemeral: bool = ..., + retry: RetryPolicy | None = ..., + steerable: bool = ..., + max_pending: int = ..., +) -> Callable[ + [Callable[[TaskContext[Input]], Awaitable[Output]]], + DurableTask[Input, Output], +]: ... 


def durable_task(
    fn: Callable[..., Any] | None = None,
    *,
    name: str | None = None,
    title: str | Callable[[Any, str], str] | None = None,
    tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None,
    description: str | Callable[[Any, str], str | None] | None = None,
    timeout: timedelta | None = None,
    lease_duration_seconds: int = 60,
    store_input: bool = True,
    ephemeral: bool = True,
    retry: RetryPolicy | None = None,
    steerable: bool = False,
    max_pending: int = 10,
) -> Any:
    """Decorate an async function as a crash-resilient durable task.

    Both the bare and the parenthesized decorator forms are supported::

        @durable_task
        async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ...

        @durable_task(name="custom-name", ephemeral=False)
        async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ...

    :param fn: The async function being decorated (bare form only).
    :type fn: Callable[..., Any] | None
    :keyword name: **Stable identity anchor** used for recovery routing and
        source stamping; defaults to ``fn.__qualname__``. Always set it
        explicitly for production tasks — recovery matches on this name, so
        in-flight tasks survive a Python-level rename of the function.
    :keyword title: Human-readable title (string or callable).
    :keyword tags: Default tags — a static dict, or a callable factory
        receiving ``(input, task_id)``. Merged with per-call ``tags=``
        overrides.
    :keyword description: Task description (static string or a callable
        factory receiving ``(input, task_id)``).
    :keyword timeout: Execution timeout. On expiry ``ctx.cancel`` is set
        cooperatively; if the function never exits, the lease eventually
        lapses and the task becomes recoverable.
    :keyword lease_duration_seconds: Lease TTL in seconds (default 60).
    :keyword store_input: Whether to persist the input on the task record.
    :keyword ephemeral: Delete the task on terminal exit (default True).
    :keyword retry: Default retry policy for this task.
    :keyword steerable: When True, calling ``start()`` on an ``in_progress``
        task queues the input and signals cancel instead of raising
        ``TaskConflictError``. Default False.
    :keyword max_pending: Maximum number of queued steering inputs.
        Default 10.
    :return: A ``DurableTask[Input, Output]`` wrapper.
    :rtype: Any
    """

    def _decorate(func: Callable[..., Any]) -> DurableTask[Any, Any]:
        # Durable tasks must be awaitable end-to-end; reject sync callables
        # up front, at decoration time.
        if not asyncio.iscoroutinefunction(func):
            raise TypeError(
                f"@durable_task requires an async function, "
                f"got {func.__qualname__!r}"
            )

        if lease_duration_seconds < 1:
            raise ValueError(
                f"lease_duration_seconds must be >= 1, got {lease_duration_seconds}"
            )

        if max_pending < 1:
            raise ValueError(f"max_pending must be >= 1, got {max_pending}")

        input_type, output_type = _extract_generic_args(func)

        # Callable tag factories are resolved (and reserved keys stripped)
        # at task-creation time; static dicts are sanitized right away.
        if callable(tags):
            default_tags: Any = tags
        else:
            default_tags = _strip_reserved_tags(dict(tags) if tags else {})

        return DurableTask(
            fn=func,
            opts=DurableTaskOptions(
                name=name or func.__qualname__,
                title=title,
                tags=default_tags,
                description=description,
                timeout=timeout,
                lease_duration_seconds=lease_duration_seconds,
                store_input=store_input,
                ephemeral=ephemeral,
                retry=retry,
                steerable=steerable,
                max_pending=max_pending,
            ),
            input_type=input_type,
            output_type=output_type,
        )

    # Bare form (@durable_task) passes fn directly; parenthesized form
    # returns the decorator to be applied next.
    return _decorate(fn) if fn is not None else _decorate
# diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_exceptions.py
# new file mode 100644
# index 000000000000..45a6b75ae7bf
# --- /dev/null
# +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_exceptions.py
# @@ -0,0 +1,162 @@
#
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Exception types for the durable task subsystem.""" + +from typing import Any + + +class TaskFailed(Exception): + """Raised when a durable task function raises an unhandled exception. + + :param task_id: The identifier of the failed task. + :type task_id: str + :param error: Structured error details captured from the exception. + :type error: dict[str, Any] + """ + + def __init__(self, task_id: str, error: dict[str, Any]) -> None: + self.task_id = task_id + self.error = error + message = error.get("message", "Task failed") + super().__init__(f"Task {task_id!r} failed: {message}") + + +class TaskSuspended(Exception): + """Raised when awaiting the result of a suspended task. + + :param task_id: The identifier of the suspended task. + :type task_id: str + :param reason: Human-readable suspension reason, if provided. + :type reason: str | None + :param output: Optional output snapshot set at suspension time. + :type output: Any | None + """ + + def __init__( + self, + task_id: str, + reason: str | None = None, + output: Any | None = None, + ) -> None: + self.task_id = task_id + self.reason = reason + self.output = output + suffix = f": {reason}" if reason else "" + super().__init__(f"Task {task_id!r} is suspended{suffix}") + + +class TaskCancelled(Exception): + """Raised when a durable task is cancelled. + + Inherits from :class:`Exception` rather than :class:`asyncio.CancelledError` + to prevent unintentional suppression by generic ``CancelledError`` handlers + in the asyncio event loop. + + :param task_id: The identifier of the cancelled task. + :type task_id: str + """ + + def __init__(self, task_id: str) -> None: + self.task_id = task_id + super().__init__(f"Task {task_id!r} was cancelled") + + +class TaskNotFound(Exception): + """Raised when a task ID is not found in the store. 
+ + :param task_id: The identifier that was not found. + :type task_id: str + """ + + def __init__(self, task_id: str) -> None: + self.task_id = task_id + super().__init__(f"Task {task_id!r} not found") + + +class TaskTerminated(Exception): + """Raised when a task is forcefully terminated via ``handle.terminate()``. + + Unlike :class:`TaskCancelled`, terminated tasks go through the failure + path and do NOT stay ``in_progress`` for recovery. + + :param task_id: The identifier of the terminated task. + :type task_id: str + :param reason: Optional human-readable termination reason. + :type reason: str | None + """ + + __slots__ = ("task_id", "reason") + + def __init__(self, task_id: str, reason: str | None = None) -> None: + self.task_id = task_id + self.reason = reason + suffix = f": {reason}" if reason else "" + super().__init__(f"Task {task_id!r} was terminated{suffix}") + + +class TaskConflictError(RuntimeError): + """Raised when a task lifecycle conflict cannot be resolved. + + Raised by ``.run()`` or ``.start()`` when the task is already + ``in_progress`` (non-stale) or ``completed``. The lifecycle is + deterministic: create if none, start if pending, resume if suspended, + throw if in-progress or completed. + + :param task_id: The conflicting task's ID. + :type task_id: str + :param current_status: The task's current status. + :type current_status: str + """ + + __slots__ = ("task_id", "current_status") + + def __init__( + self, + task_id: str, + current_status: str, + ) -> None: + self.task_id = task_id + self.current_status = current_status + super().__init__(f"Task '{task_id}' is already {current_status}") + + +class EtagConflict(RuntimeError): + """Raised when an optimistic concurrency (etag) check fails. + + The task record was modified between read and write. Callers should + retry the operation with the updated etag. + + :param task_id: The task ID where the conflict occurred. + :type task_id: str + :param message: Optional detail message. 
+ :type message: str | None + """ + + __slots__ = ("task_id",) + + def __init__(self, task_id: str, message: str | None = None) -> None: + self.task_id = task_id + msg = message or f"Etag conflict on task '{task_id}'" + super().__init__(msg) + + +class SteeringQueueFull(RuntimeError): + """Raised when the steering pending-input queue is at capacity. + + The caller should retry later or increase ``max_pending``. + + :param task_id: The task whose queue is full. + :type task_id: str + :param max_pending: The configured queue capacity. + :type max_pending: int + """ + + __slots__ = ("task_id", "max_pending") + + def __init__(self, task_id: str, max_pending: int) -> None: + self.task_id = task_id + self.max_pending = max_pending + super().__init__( + f"Steering queue full for task '{task_id}' " f"(max_pending={max_pending})" + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_lease.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_lease.py new file mode 100644 index 000000000000..cb5f186d3e5d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_lease.py @@ -0,0 +1,155 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Lease identity derivation and renewal loop for durable tasks. + +Provides utility functions for constructing stable lease owner strings, +generating ephemeral instance IDs, and running the background lease +renewal loop. 
+""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import logging +import os +import time +import uuid +from collections.abc import Awaitable, Callable + +from ._models import TaskPatchRequest +from ._provider import DurableTaskProvider + +logger = logging.getLogger("azure.ai.agentserver.durable") + + +def derive_lease_owner(session_id: str) -> str: + """Derive a stable lease owner string from the session ID. + + The owner is stable across process restarts within the same session, + enabling dual-identity lease reclamation. + + :param session_id: The agent session identifier. + :type session_id: str + :return: A lease owner string in the format ``"session:{session_id}"``. + :rtype: str + """ + return f"session:{session_id}" + + +def generate_instance_id() -> str: + """Generate an ephemeral lease instance ID unique to this process. + + Combines the PID and a timestamp to ensure uniqueness even after + rapid restarts. + + :return: A unique instance identifier. + :rtype: str + """ + return f"worker-{os.getpid()}-{uuid.uuid4().hex[:8]}-{int(time.time())}" + + +async def lease_renewal_loop( + provider: DurableTaskProvider, + task_id: str, + *, + lease_owner: str, + lease_instance_id: str, + lease_duration_seconds: int, + cancel_event: asyncio.Event, + on_failure_count: int = 3, + on_cancel_callback: asyncio.Event | None = None, + steering_poll_callback: Callable[[], Awaitable[None]] | None = None, +) -> None: + """Run a background lease renewal loop at half the lease duration. + + Renews the lease by PATCHing the task with the same owner/instance. + On ``on_failure_count`` consecutive failures, signals the optional + ``on_cancel_callback`` event to give the task function a chance to + checkpoint. + + The loop exits when ``cancel_event`` is set or the task is cancelled. + + :param provider: The storage provider. + :type provider: DurableTaskProvider + :param task_id: The task to renew. 
+ :type task_id: str + :keyword lease_owner: The stable lease owner. + :paramtype lease_owner: str + :keyword lease_instance_id: The ephemeral instance ID. + :paramtype lease_instance_id: str + :keyword lease_duration_seconds: The lease TTL in seconds. + :paramtype lease_duration_seconds: int + :keyword cancel_event: Event that stops the loop when set. + :paramtype cancel_event: asyncio.Event + :keyword on_failure_count: Consecutive failures before signalling cancel. + :paramtype on_failure_count: int + :keyword on_cancel_callback: Event to signal on repeated renewal failure. + :paramtype on_cancel_callback: asyncio.Event | None + :keyword steering_poll_callback: Async callback invoked each renewal to poll + for steering inputs. Called after successful lease renewal. + :paramtype steering_poll_callback: Callable[[], Awaitable[None]] | None + """ + interval = max(1, lease_duration_seconds // 2) + consecutive_failures = 0 + + while not cancel_event.is_set(): + try: + await asyncio.wait_for( + _wait_for_event(cancel_event), + timeout=interval, + ) + # cancel_event was set — exit the loop + break + except asyncio.TimeoutError: + pass + + try: + await provider.update( + task_id, + TaskPatchRequest( + lease_owner=lease_owner, + lease_instance_id=lease_instance_id, + lease_duration_seconds=lease_duration_seconds, + ), + ) + consecutive_failures = 0 + logger.debug("Lease renewed for task %s", task_id) + + # Poll for steering inputs after successful renewal + if steering_poll_callback is not None: + try: + await steering_poll_callback() + except Exception: # pylint: disable=broad-exception-caught + logger.debug( + "Steering poll failed for task %s", task_id, exc_info=True + ) + except Exception: # pylint: disable=broad-exception-caught + consecutive_failures += 1 + logger.warning( + "Lease renewal failed for task %s (attempt %d/%d)", + task_id, + consecutive_failures, + on_failure_count, + exc_info=True, + ) + if ( + consecutive_failures >= on_failure_count + and 
on_cancel_callback is not None + ): + logger.error( + "Lease renewal failed %d times for task %s — signalling cancellation", + on_failure_count, + task_id, + ) + on_cancel_callback.set() + break + + +async def _wait_for_event(event: asyncio.Event) -> None: + """Await an asyncio event. Used with ``wait_for`` for interruptible sleep. + + :param event: The asyncio event to wait for. + :type event: asyncio.Event + """ + await event.wait() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_local_provider.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_local_provider.py new file mode 100644 index 000000000000..da187a518398 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_local_provider.py @@ -0,0 +1,377 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Local filesystem-backed durable task provider. + +Stores tasks as JSON files under ``$HOME/.durable-tasks/{agent_name}/{session_id}/`` +for local development with full lifecycle parity. 
+""" + +from __future__ import annotations + +import datetime +import hashlib +import json +import logging +import os +from pathlib import Path +from typing import Any + +from ._exceptions import TaskNotFound +from ._models import ( + LeaseInfo, + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, + TaskStatus, +) + +logger = logging.getLogger("azure.ai.agentserver.durable") + + +def _now_iso() -> str: + return datetime.datetime.now(datetime.timezone.utc).isoformat() + + +def _generate_etag(data: dict[str, Any]) -> str: + raw = json.dumps(data, sort_keys=True) + return f"local-{hashlib.sha256(raw.encode()).hexdigest()[:16]}" + + +def _is_lease_expired(lease: LeaseInfo | None) -> bool: + if lease is None: + return True + try: + expires = datetime.datetime.fromisoformat(lease.expires_at) + now = datetime.datetime.now(datetime.timezone.utc) + return now >= expires + except (ValueError, TypeError): + return True + + +class LocalFileDurableTaskProvider: + """Filesystem-backed provider for local development. + + Tasks are stored as individual JSON files. Lease expiry is simulated + by checking timestamps on read. + + :param base_dir: Root directory for task storage. + Defaults to ``$HOME/.durable-tasks``. + :type base_dir: Path | None + """ + + def __init__(self, base_dir: Path | None = None) -> None: + self._base_dir = base_dir or Path.home() / ".durable-tasks" + + def _task_dir(self, agent_name: str, session_id: str) -> Path: + return self._base_dir / agent_name / session_id + + def _task_path(self, agent_name: str, session_id: str, task_id: str) -> Path: + return self._task_dir(agent_name, session_id) / f"{task_id}.json" + + def _find_task_path(self, task_id: str) -> Path | None: + """Search all agent/session dirs for a task file. + + :param task_id: The task identifier. + :type task_id: str + :return: The path to the task file, or None. 
+ :rtype: ~pathlib.Path | None + """ + if not self._base_dir.exists(): + return None + for agent_dir in self._base_dir.iterdir(): + if not agent_dir.is_dir(): + continue + for session_dir in agent_dir.iterdir(): + if not session_dir.is_dir(): + continue + path = session_dir / f"{task_id}.json" + if path.exists(): + return path + return None + + def _read_task(self, path: Path) -> TaskInfo | None: + if not path.exists(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + return TaskInfo.from_dict(data) + except (json.JSONDecodeError, KeyError): + logger.warning("Corrupt task file: %s", path) + return None + + def _write_task(self, task: TaskInfo) -> None: + path = self._task_path(task.agent_name, task.session_id, task.id) + path.parent.mkdir(parents=True, exist_ok=True) + data = task.to_dict() + data["etag"] = _generate_etag(data) + task.etag = data["etag"] + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task as a JSON file. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. 
+ :rtype: TaskInfo + """ + now = _now_iso() + task_id = request.id or f"task-{os.urandom(8).hex()}" + + lease: LeaseInfo | None = None + started_at: str | None = None + status: TaskStatus = request.status + + if ( + request.lease_owner + and request.lease_instance_id + and request.lease_duration_seconds + ): + expires_at = ( + datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(seconds=request.lease_duration_seconds) + ).isoformat() + lease = LeaseInfo( + owner=request.lease_owner, + instance_id=request.lease_instance_id, + generation=0, + expires_at=expires_at, + expiry_count=0, + ) + if status == "in_progress": + started_at = now + + task = TaskInfo( + id=task_id, + agent_name=request.agent_name, + session_id=request.session_id, + status=status, + title=request.title, + description=request.description, + lease=lease, + payload=request.payload, + tags=request.tags, + source=request.source, + created_at=now, + updated_at=now, + started_at=started_at, + ) + self._write_task(task) + logger.debug("Created local task %s", task_id) + return task + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a task by ID from the filesystem. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + path = self._find_task_path(task_id) + if path is None: + return None + return self._read_task(path) + + async def update( + self, task_id: str, patch: TaskPatchRequest + ) -> TaskInfo: # pylint: disable=too-many-branches,too-many-statements + """Update a task via PATCH semantics. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. 
+ """ + path = self._find_task_path(task_id) + if path is None: + raise TaskNotFound(task_id) + + task = self._read_task(path) + if task is None: + raise TaskNotFound(task_id) + + # ETag check + if patch.if_match is not None and patch.if_match != task.etag: + raise ValueError( + f"ETag mismatch: expected {patch.if_match!r}, " f"got {task.etag!r}" + ) + + now = _now_iso() + + if patch.status is not None: + old_status = task.status # noqa: F841 # pylint: disable=unused-variable + task.status = patch.status + + if patch.status == "in_progress" and task.started_at is None: + task.started_at = now + if patch.status == "completed": + task.completed_at = now + if patch.status == "suspended": + task.suspension_reason = patch.suspension_reason + + # Lease handling on status transitions + if patch.status in ("completed", "suspended"): + task.lease = None + elif ( + patch.status == "in_progress" + and patch.lease_owner + and patch.lease_instance_id + ): + duration = patch.lease_duration_seconds or 60 + expires_at = ( + datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(seconds=duration) + ).isoformat() + old_gen = task.lease.generation if task.lease else -1 + new_gen = ( + old_gen + 1 + if patch.lease_instance_id + != (task.lease.instance_id if task.lease else "") + else max(old_gen, 0) + ) + task.lease = LeaseInfo( + owner=patch.lease_owner, + instance_id=patch.lease_instance_id, + generation=new_gen, + expires_at=expires_at, + expiry_count=task.lease.expiry_count if task.lease else 0, + ) + + # Lease renewal (no status change) + if patch.status is None and patch.lease_owner and patch.lease_instance_id: + duration = patch.lease_duration_seconds or 60 + expires_at = ( + datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(seconds=duration) + ).isoformat() + if task.lease and patch.lease_instance_id != task.lease.instance_id: + # Reclaim with new instance + task.lease = LeaseInfo( + owner=patch.lease_owner, + instance_id=patch.lease_instance_id, 
+ generation=task.lease.generation + 1, + expires_at=expires_at, + expiry_count=task.lease.expiry_count, + ) + elif task.lease: + # Simple renewal + task.lease = LeaseInfo( + owner=task.lease.owner, + instance_id=task.lease.instance_id, + generation=task.lease.generation, + expires_at=expires_at, + expiry_count=task.lease.expiry_count, + ) + else: + task.lease = LeaseInfo( + owner=patch.lease_owner, + instance_id=patch.lease_instance_id, + generation=0, + expires_at=expires_at, + ) + + # Force-expire: lease_duration_seconds=0 + if patch.lease_duration_seconds == 0 and task.lease: + task.lease = LeaseInfo( + owner=task.lease.owner, + instance_id=task.lease.instance_id, + generation=task.lease.generation, + expires_at=_now_iso(), + expiry_count=task.lease.expiry_count, + ) + + # Payload shallow-merge (spec §11: root-level additive, values replaced) + if patch.payload is not None: + if task.payload is None: + task.payload = {} + for key, value in patch.payload.items(): + task.payload[key] = value + + # Tags null-as-delete merge + if patch.tags is not None: + if task.tags is None: + task.tags = {} + for key, value in patch.tags.items(): + if value is None: + task.tags.pop(key, None) + else: + task.tags[key] = value + + if patch.error is not None: + task.error = patch.error + + task.updated_at = now + self._write_task(task) + return task + + async def delete( + self, + task_id: str, + *, + force: bool = False, # pylint: disable=unused-argument + cascade: bool = False, # pylint: disable=unused-argument + ) -> None: + """Delete a task JSON file. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks (no-op for local). 
+ :paramtype cascade: bool + """ + path = self._find_task_path(task_id) + if path is None: + raise TaskNotFound(task_id) + path.unlink(missing_ok=True) + logger.debug("Deleted local task %s", task_id) + + async def list( + self, + *, + agent_name: str, + session_id: str, + status: TaskStatus | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + ) -> list[TaskInfo]: + """List tasks from the filesystem. + + :keyword agent_name: Filter by agent name. + :paramtype agent_name: str + :keyword session_id: Filter by session ID. + :paramtype session_id: str + :keyword status: Filter by task status. + :paramtype status: TaskStatus | None + :keyword lease_owner: Filter by lease owner. + :paramtype lease_owner: str | None + :keyword tag: Filter by tags (AND semantics — all must match). + :paramtype tag: dict[str, str] | None + :return: Matching task records. + :rtype: list[TaskInfo] + """ + task_dir = self._task_dir(agent_name, session_id) + if not task_dir.exists(): + return [] + + results: list[TaskInfo] = [] + for path in task_dir.glob("*.json"): + task = self._read_task(path) + if task is None: + continue + if status is not None and task.status != status: + continue + if lease_owner is not None: + if task.lease is None or task.lease.owner != lease_owner: + continue + if tag is not None: + task_tags = task.tags or {} + if not all(task_tags.get(k) == v for k, v in tag.items()): + continue + results.append(task) + return results diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_manager.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_manager.py new file mode 100644 index 000000000000..6332493f61db --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_manager.py @@ -0,0 +1,1673 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""DurableTaskManager — lifecycle orchestration for durable tasks. + +Manages task creation, lease acquisition, execution, recovery, and +shutdown. One instance per ``AgentServerHost``, accessed via the +module-level ``get_task_manager()`` function. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import logging +import traceback +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any, TypeVar + +from .._config import AgentConfig +from ._context import EntryMode, TaskContext +from ._decorator import DurableTaskOptions, _deserialize_input, _serialize_input +from ._exceptions import TaskFailed, TaskNotFound +from ._lease import derive_lease_owner, generate_instance_id, lease_renewal_loop +from ._metadata import TaskMetadata +from ._models import TaskCreateRequest, TaskInfo, TaskPatchRequest, TaskStatus +from ._provider import DurableTaskProvider +from ._result import TaskResult +from ._retry import RetryPolicy +from ._run import Suspended, TaskRun +from ._stream import QueueStreamHandler, StreamHandler +from .._version import VERSION as _CORE_VERSION +from .._server_version import build_server_version as _build_server_version + +logger = logging.getLogger("azure.ai.agentserver.durable") + +#: Auto-stamped source type for all tasks created by this framework. +_SOURCE_TYPE = "agentserver.durable_task" + +#: Reserved tag key for task name filtering via the LIST API. +_TAG_TASK_NAME = "_durable_task_name" + +#: Pre-computed server version segment for source stamps. +_SOURCE_SERVER_VERSION = _build_server_version( + "azure-ai-agentserver-core", _CORE_VERSION +) + +Input = TypeVar("Input") +Output = TypeVar("Output") + +# Module-level manager singleton +_manager: DurableTaskManager | None = None + + +def get_task_manager() -> DurableTaskManager: + """Return the active DurableTaskManager singleton. 
+ + :raises RuntimeError: If no manager has been initialized. + :return: The active manager. + :rtype: DurableTaskManager + """ + if _manager is None: + raise RuntimeError( + "DurableTaskManager not initialized. Ensure durable tasks " + "are enabled on the AgentServerHost." # pylint: disable=implicit-str-concat + ) + return _manager + + +def set_task_manager(manager: DurableTaskManager | None) -> None: + """Set the module-level DurableTaskManager singleton. + + Called by ``AgentServerHost`` during startup/shutdown. + + :param manager: The manager to set, or ``None`` to clear. + :type manager: DurableTaskManager | None + """ + global _manager # pylint: disable=global-statement + _manager = manager + + +class _ActiveTask: # pylint: disable=too-many-instance-attributes + """In-memory tracking for a running task.""" + + __slots__ = ( + "task_id", + "fn_name", + "context", + "execution_task", + "renewal_task", + "renewal_cancel", + "result_future", + "terminate_event", + "fn", + "input_type", + "opts", + "retry", + ) + + def __init__( + self, + task_id: str, + fn_name: str, + context: TaskContext[Any], + execution_task: asyncio.Task[Any], + renewal_task: asyncio.Task[None] | None, + renewal_cancel: asyncio.Event, + result_future: asyncio.Future[Any], + terminate_event: asyncio.Event | None = None, + fn: Callable[..., Awaitable[Any]] | None = None, + input_type: type[Any] | None = None, + opts: DurableTaskOptions | None = None, + retry: RetryPolicy | None = None, + ) -> None: + self.task_id = task_id + self.fn_name = fn_name + self.context = context + self.execution_task = execution_task + self.renewal_task = renewal_task + self.renewal_cancel = renewal_cancel + self.result_future = result_future + self.terminate_event = terminate_event or asyncio.Event() + self.fn = fn + self.input_type = input_type + self.opts = opts + self.retry = retry + + +class DurableTaskManager: + """Lifecycle orchestrator for durable tasks. 
+ + Manages provider selection, task creation, lease management, + execution dispatch, crash recovery, and graceful shutdown. + + :param config: Resolved agent configuration. + :type config: AgentConfig + :param provider: Optional explicit provider (for testing). + :type provider: DurableTaskProvider | None + :param shutdown_event: Shared shutdown event from the host. + :type shutdown_event: asyncio.Event | None + :param shutdown_grace_seconds: Seconds to wait for tasks to checkpoint + before force-expiring leases during shutdown. Defaults to 25.0. + :type shutdown_grace_seconds: float + """ + + def __init__( + self, + config: AgentConfig, + *, + provider: DurableTaskProvider | None = None, + shutdown_event: asyncio.Event | None = None, + shutdown_grace_seconds: float = 25.0, + ) -> None: + self._config = config + self._provider = provider or self._create_provider(config) + self._active_tasks: dict[str, _ActiveTask] = {} + self._resume_callbacks: dict[str, Callable[..., Any]] = {} + self._lease_owner = derive_lease_owner(config.session_id or "local") + self._instance_id = generate_instance_id() + self._shutdown_event = shutdown_event or asyncio.Event() + self._shutdown_grace_seconds = shutdown_grace_seconds + self._active_generation_future: dict[str, asyncio.Future[Any]] = {} + self._pending_steering_futures: dict[str, list[asyncio.Future[Any]]] = {} + + @staticmethod + def _build_source(fn_name: str) -> dict[str, str]: + """Build the framework-owned source stamp for a task. + + The ``fn_name`` is the developer-provided ``name`` from the decorator + (or ``fn.__qualname__`` when omitted). It serves as the **stable + identity anchor** — recovery routing matches ``source.name`` against + registered callbacks to dispatch recovered tasks back to the correct + function. + + :param fn_name: The task name (from ``@durable_task(name=...)``). + :type fn_name: str + :return: Source metadata dict. 
+ :rtype: dict[str, str] + """ + return { + "type": _SOURCE_TYPE, + "name": fn_name, + "server_version": _SOURCE_SERVER_VERSION, + } + + @staticmethod + def _create_provider(config: AgentConfig) -> DurableTaskProvider: + """Auto-select provider based on hosting environment. + + The Task Storage API is not yet generally available. To avoid + failures in hosted environments, the local file-based provider + is used by default even when ``FOUNDRY_HOSTING_ENVIRONMENT`` is + set. Set the ``FOUNDRY_TASK_API_ENABLED=1`` environment variable + to opt in to the HTTP-backed provider for testing once the APIs + are lit up. + + :param config: The agent configuration. + :type config: AgentConfig + :return: The storage provider instance. + :rtype: DurableTaskProvider + """ + import os # pylint: disable=import-outside-toplevel + + task_api_enabled = os.environ.get("FOUNDRY_TASK_API_ENABLED", "").strip() + + if config.is_hosted and task_api_enabled in ("1", "true", "yes"): + from ._client import ( # pylint: disable=import-outside-toplevel + HostedDurableTaskProvider, + ) + + try: + from azure.identity.aio import ( # type: ignore[import-untyped] + DefaultAzureCredential, + ) + except ImportError as exc: + raise ImportError( + "azure-identity is required for hosted mode. " + "Install with: pip install azure-ai-agentserver-core[hosted]" + ) from exc + + logger.info( + "Task Storage API enabled via FOUNDRY_TASK_API_ENABLED; " # pylint: disable=implicit-str-concat + "using HostedDurableTaskProvider" + ) + return HostedDurableTaskProvider( + project_endpoint=config.project_endpoint, + credential=DefaultAzureCredential(), + ) + + if config.is_hosted and not task_api_enabled: + logger.info( + "Hosted environment detected but Task Storage API not yet enabled. " + "Using local file provider. Set FOUNDRY_TASK_API_ENABLED=1 to use " + "the HTTP-backed provider when the APIs are available." 
+            )
+
+        from ._local_provider import (  # pylint: disable=import-outside-toplevel
+            LocalFileDurableTaskProvider,
+        )
+
+        return LocalFileDurableTaskProvider(base_dir=Path.home() / ".durable-tasks")
+
+    @property
+    def provider(self) -> DurableTaskProvider:
+        """The storage provider.
+
+        :return: The active provider.
+        :rtype: DurableTaskProvider
+        """
+        return self._provider
+
+    def register_resume_callback(
+        self,
+        fn_name: str,
+        fn: Callable[..., Any],
+    ) -> None:
+        """Register a function as a resume callback.
+
+        Later registrations for the same name overwrite earlier ones,
+        so the most recently decorated function wins.
+
+        :param fn_name: The durable task function name.
+        :type fn_name: str
+        :param fn: The async function to call on resume.
+        :type fn: Callable[..., Any]
+        """
+        self._resume_callbacks[fn_name] = fn
+
+    async def list_tasks(
+        self,
+        *,
+        fn_name: str,
+        session_id: str | None = None,
+        status: TaskStatus | None = None,
+    ) -> list[TaskInfo]:
+        """List tasks scoped to a specific durable task function.
+
+        Uses server-side filtering (``agent_name``, ``session_id``,
+        ``_durable_task_name`` tag, ``status``) and client-side filtering
+        (``source.type``) to return only tasks created by this framework
+        for the given function.
+
+        :keyword fn_name: The task function name (stable identity anchor).
+        :paramtype fn_name: str
+        :keyword session_id: Session scope override. Defaults to config.
+        :paramtype session_id: str | None
+        :keyword status: Filter by task status.
+        :paramtype status: ~azure.ai.agentserver.core.durable.TaskStatus | None
+        :return: Matching task records.
+        :rtype: list[TaskInfo]
+        """
+        resolved_session = session_id or self._config.session_id or "local"
+        agent_name = self._config.agent_name or "default"
+
+        # Server-side filters: agent_name, session_id, tag, status
+        results = await self._provider.list(
+            agent_name=agent_name,
+            session_id=resolved_session,
+            status=status,
+            tag={_TAG_TASK_NAME: fn_name},
+        )
+
+        # Client-side filter: source.type (until source_type server filter exists)
+        return [
+            task
+            for task in results
+            if task.source and task.source.get("type") == _SOURCE_TYPE
+        ]
+
+    def _register_steering_future(self, task_id: str) -> asyncio.Future[Any]:
+        """Create and register a future for a queued steering input.
+
+        Must be called BEFORE ``_append_steering_input()`` to avoid a race
+        where the drain pops the queue before the future exists.
+
+        :param task_id: The task identifier.
+        :type task_id: str
+        :return: The registered future.
+        :rtype: asyncio.Future[Any]
+        """
+        loop = asyncio.get_event_loop()
+        future: asyncio.Future[Any] = loop.create_future()
+        if task_id not in self._pending_steering_futures:
+            self._pending_steering_futures[task_id] = []
+        self._pending_steering_futures[task_id].append(future)
+        return future
+
+    async def startup(self) -> None:
+        """Initialize the manager and recover stale tasks.
+
+        Called by ``AgentServerHost`` during lifespan startup.
+        """
+        logger.info(
+            "DurableTaskManager starting (owner=%s, instance=%s, hosted=%s)",
+            self._lease_owner,
+            self._instance_id,
+            self._config.is_hosted,
+        )
+        await self._recover_stale_tasks()
+
+    async def shutdown(self) -> None:
+        """Signal shutdown on all active tasks and force-expire leases.
+
+        Called by ``AgentServerHost`` during lifespan shutdown.
+        """
+        logger.info("DurableTaskManager shutting down")
+        self._shutdown_event.set()
+
+        # Signal shutdown on all active contexts
+        for active in self._active_tasks.values():
+            active.context.shutdown.set()
+
+        # Wait for tasks to checkpoint before force-expiring leases.
+        # asyncio.wait returns as soon as every execution task finishes,
+        # bounded by the configured grace period, instead of always
+        # sleeping the full grace window.
+        if self._active_tasks:
+            await asyncio.wait(
+                [a.execution_task for a in self._active_tasks.values()],
+                timeout=self._shutdown_grace_seconds,
+            )
+
+        # Force-expire all leases
+        for active in list(self._active_tasks.values()):
+            try:
+                await self._provider.update(
+                    active.task_id,
+                    TaskPatchRequest(
+                        lease_owner=self._lease_owner,
+                        lease_instance_id=self._instance_id,
+                        lease_duration_seconds=0,
+                    ),
+                )
+            except Exception:  # pylint: disable=broad-exception-caught
+                logger.warning(
+                    "Failed to force-expire lease for task %s",
+                    active.task_id,
+                    exc_info=True,
+                )
+
+        # Cancel all renewal and execution tasks
+        for active in self._active_tasks.values():
+            active.renewal_cancel.set()
+            if active.renewal_task and not active.renewal_task.done():
+                active.renewal_task.cancel()
+            if not active.execution_task.done():
+                active.execution_task.cancel()
+
+        self._active_tasks.clear()
+        set_task_manager(None)
+
+    async def create_and_run(
+        self,
+        *,
+        fn: Callable[..., Awaitable[Any]],
+        fn_name: str,
+        task_id: str,
+        input_val: Any,
+        input_type: type[Any],
+        session_id: str | None,
+        title: str,
+        tags: dict[str, str],
+        opts: DurableTaskOptions,
+        retry: RetryPolicy | None = None,
+        entry_mode: EntryMode = "fresh",
+    ) -> Any:
+        """Create a task, run the function, and return the result.
+
+        :keyword fn: The async function to execute.
+        :paramtype fn: Callable[..., Awaitable[Any]]
+        :keyword fn_name: The registered function name.
+        :paramtype fn_name: str
+        :keyword task_id: Unique task identifier.
+        :paramtype task_id: str
+        :keyword input_val: The input value.
+        :paramtype input_val: Any
+        :keyword input_type: The input type.
+        :paramtype input_type: type[Any]
+        :keyword session_id: Session scope.
+        :paramtype session_id: str | None
+        :keyword tags: Task tags.
+ :paramtype tags: dict[str, str] + :keyword opts: Task options. + :paramtype opts: DurableTaskOptions + :keyword entry_mode: Entry mode. + :paramtype entry_mode: EntryMode + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :keyword title: Human-readable title. + :paramtype title: str + :returns: The function's return value. + :rtype: Any + :raises TaskFailed: On unhandled exception. + :raises TaskSuspended: If the function suspends. + """ + handle = await self.create_and_start( + fn=fn, + fn_name=fn_name, + task_id=task_id, + input_val=input_val, + input_type=input_type, + session_id=session_id, + title=title, + tags=tags, + opts=opts, + retry=retry, + entry_mode=entry_mode, + ) + return await handle.result() + + async def create_and_start( # pylint: disable=too-many-locals + self, + *, + fn: Callable[..., Awaitable[Any]], + fn_name: str, + task_id: str, + input_val: Any, + input_type: type[Any], # pylint: disable=unused-argument + session_id: str | None, + title: str, + tags: dict[str, str], + description: str | None = None, + opts: DurableTaskOptions, + retry: RetryPolicy | None = None, + entry_mode: EntryMode = "fresh", + stream_handler: StreamHandler | None = None, + ) -> TaskRun[Any]: + """Create a task, start the function, and return a handle. + + Source provenance is auto-stamped by the framework using + ``fn_name`` and the core SDK version. + + :keyword fn: The async task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword fn_name: Function name for logging. + :paramtype fn_name: str + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword input_val: The task input value. + :paramtype input_val: Any + :keyword input_type: Type for deserializing input. + :paramtype input_type: type[Any] + :keyword session_id: Session scope identifier. + :paramtype session_id: str | None + :keyword title: Human-readable task title. + :paramtype title: str + :keyword tags: Merged decorator + call-site tags. 
+ :paramtype tags: dict[str, str] + :keyword description: Optional task description. + :paramtype description: str | None + :keyword opts: Task options. + :paramtype opts: DurableTaskOptions + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :keyword entry_mode: Why this execution is starting. + :paramtype entry_mode: EntryMode + :keyword stream_handler: Custom stream handler. If ``None``, + a default :class:`QueueStreamHandler` is created. + :paramtype stream_handler: StreamHandler | None + :return: A ``TaskRun`` handle. + :rtype: TaskRun + """ + resolved_session = session_id or self._config.session_id or "local" + agent_name = self._config.agent_name or "default" + + # Build payload + payload: dict[str, Any] = {} + if opts.store_input: + payload["input"] = _serialize_input(input_val) + payload["metadata"] = {} + + # Auto-stamp source provenance (framework-owned, not user-overridable) + source = self._build_source(fn_name) + + # Auto-stamp task name tag for LIST filtering + if tags is None: + tags = {} + tags[_TAG_TASK_NAME] = fn_name + + # Create task with lease + task_info = await self._provider.create( + TaskCreateRequest( + id=task_id, + agent_name=agent_name, + session_id=resolved_session, + status="in_progress", + title=title, + description=description, + payload=payload, + tags=tags or None, + source=source, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=opts.lease_duration_seconds, + ) + ) + + logger.info("Created durable task %s (%s)", task_id, fn_name) + + # Register resume callback + self._resume_callbacks[fn_name] = fn + + # Build context + cancel_event = asyncio.Event() + handler = stream_handler or QueueStreamHandler() + metadata = TaskMetadata( + flush_callback=self._make_metadata_flush(task_id), + flush_interval=5.0, + ) + + lease_gen = task_info.lease.generation if task_info.lease else 0 + + ctx: TaskContext[Any] = TaskContext( + task_id=task_id, + title=title, + 
description=description, + session_id=resolved_session, + agent_name=agent_name, + tags=tags, + input=input_val, + metadata=metadata, + run_attempt=0, + lease_generation=lease_gen, + cancel=cancel_event, + shutdown=self._shutdown_event, + stream_handler=handler, + entry_mode=entry_mode, + generation=0, + ) + loop = asyncio.get_event_loop() + result_future: asyncio.Future[Any] = loop.create_future() + + # Start lease renewal + renewal_cancel = asyncio.Event() + + # Build steering poll callback for steerable tasks + steering_poll_cb_cs: Callable[[], Awaitable[None]] | None = None + if opts.steerable: + + async def _steering_poll_cs() -> None: + active = self._active_tasks.get(task_id) + if active is None or active.context.cancel.is_set(): + return + info = await self._provider.get(task_id) + if info is None or not info.payload: + return + st = info.payload.get("_steering", {}) + if st.get("pending_inputs"): + active.context.cancel.set() + + steering_poll_cb_cs = _steering_poll_cs + + renewal_task = asyncio.create_task( + lease_renewal_loop( + self._provider, + task_id, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=opts.lease_duration_seconds, + cancel_event=renewal_cancel, + on_cancel_callback=cancel_event, + steering_poll_callback=steering_poll_cb_cs, + ) + ) + + # Start execution + terminate_event = asyncio.Event() + terminate_reason_ref: list[str | None] = [None] + execution_task = asyncio.create_task( + self._execute_task( + fn=fn, + ctx=ctx, + task_id=task_id, + opts=opts, + result_future=result_future, + renewal_cancel=renewal_cancel, + retry=retry, + terminate_event=terminate_event, + terminate_reason_ref=terminate_reason_ref, + ) + ) + + # Track active task + active = _ActiveTask( + task_id=task_id, + fn_name=fn_name, + context=ctx, + execution_task=execution_task, + renewal_task=renewal_task, + renewal_cancel=renewal_cancel, + result_future=result_future, + terminate_event=terminate_event, + fn=fn, + 
input_type=input_type, + opts=opts, + retry=retry, + ) + self._active_tasks[task_id] = active + + # Start metadata auto-flush + metadata.start_auto_flush() + + return TaskRun( + task_id=task_id, + provider=self._provider, + result_future=result_future, + metadata=metadata, + cancel_event=cancel_event, + stream_handler=handler, + terminate_event=terminate_event, + execution_task=execution_task, + terminate_reason_ref=terminate_reason_ref, + ) + + async def handle_resume(self, task_id: str) -> None: + """Resume a suspended task. + + :param task_id: The task to resume. + :type task_id: str + :raises TaskNotFound: If the task doesn't exist. + :raises ValueError: If the task is not suspended or no callback. + """ + task_info = await self._provider.get(task_id) + if task_info is None: + raise TaskNotFound(task_id) + + if task_info.status != "suspended": + raise ValueError( + f"Task {task_id!r} is {task_info.status!r}, not 'suspended'" + ) + + # Find the resume callback by scanning registered names + fn = self._find_resume_callback(task_info) + if fn is None: + raise ValueError(f"No resume callback registered for task {task_id!r}") + + await self._start_existing_task( + fn=fn, + fn_name=task_info.agent_name, + task_info=task_info, + entry_mode="resumed", + ) + + logger.info("Resumed task %s", task_id) + + async def _start_existing_task( # pylint: disable=too-many-locals,too-many-statements + self, + *, + fn: Callable[..., Awaitable[Any]], + fn_name: str, + task_info: TaskInfo, + entry_mode: EntryMode, + input_val: Any | None = None, + input_type: type[Any] | None = None, + opts: DurableTaskOptions | None = None, + retry: RetryPolicy | None = None, + ) -> TaskRun[Any]: + """Transition an existing task to in_progress and execute it. + + Used by lifecycle-aware ``.run()``/``.start()`` for suspended, + pending, and stale in_progress tasks. + + :keyword fn: The durable task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword fn_name: Function name for logging. 
+ :paramtype fn_name: str + :keyword task_info: The current task record. + :paramtype task_info: TaskInfo + :keyword entry_mode: Why this execution is starting. + :paramtype entry_mode: EntryMode + :keyword input_val: New input (overrides persisted input). + :paramtype input_val: Any | None + :keyword input_type: Type for deserializing persisted input. + :paramtype input_type: type[Any] | None + :keyword opts: Task options (uses defaults if not provided). + :paramtype opts: DurableTaskOptions | None + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :return: A TaskRun handle. + :rtype: TaskRun[Any] + """ + task_id = task_info.id + resolved_opts = opts or DurableTaskOptions(name=fn_name, ephemeral=False) + lease_duration = resolved_opts.lease_duration_seconds + + # Transition to in_progress with new lease + await self._provider.update( + task_id, + TaskPatchRequest( + status="in_progress", + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=lease_duration, + ), + ) + + # Re-fetch updated task + updated_info: TaskInfo | None = await self._provider.get(task_id) + if updated_info is None: + raise TaskNotFound(task_id) + task_info = updated_info + + # Resolve input: prefer caller-provided, fall back to persisted + if input_val is not None: + resolved_input = input_val + elif task_info.payload and "input" in task_info.payload: + raw_input = task_info.payload["input"] + if input_type is not None: + resolved_input = _deserialize_input(raw_input, input_type) + else: + resolved_input = raw_input + else: + resolved_input = None + + # Build context for execution + cancel_event = asyncio.Event() + handler = QueueStreamHandler() + existing_metadata = ( + task_info.payload.get("metadata", {}) if task_info.payload else {} + ) + metadata = TaskMetadata( + initial=existing_metadata, + flush_callback=self._make_metadata_flush(task_id), + flush_interval=5.0, + ) + + lease_gen = task_info.lease.generation if 
task_info.lease else 0 + + # Extract steering context from payload + steering = (task_info.payload or {}).get("_steering", {}) + # Detect steering context from payload (covers recovered-mid-drain) + was_steered = bool( + steering.get("drain_in_progress") + or steering.get("pending_inputs") + or steering.get("generation", 0) > 0 + ) + + # For steerable recovery with drain_in_progress, use active_input + if ( + entry_mode == "recovered" + and steering.get("drain_in_progress") + and "active_input" in steering + ): + raw_active = steering["active_input"] + if input_type is not None: + resolved_input = _deserialize_input(raw_active, input_type) + else: + resolved_input = raw_active + + prev_input_raw = steering.get("previous_input") + previous_input = None + if prev_input_raw is not None and input_type is not None: + previous_input = _deserialize_input(prev_input_raw, input_type) + elif prev_input_raw is not None: + previous_input = prev_input_raw + pending_snapshot = tuple(steering.get("pending_inputs", ())) + generation = steering.get("generation", 0) + + # Pre-set cancel if cancel_requested is True (steering short-circuit) + if steering.get("cancel_requested"): + cancel_event.set() + + ctx: TaskContext[Any] = TaskContext( + task_id=task_id, + title=task_info.title or "", + description=task_info.description, + session_id=task_info.session_id, + agent_name=task_info.agent_name, + tags=task_info.tags or {}, + input=resolved_input, + metadata=metadata, + run_attempt=0, + lease_generation=lease_gen, + cancel=cancel_event, + shutdown=self._shutdown_event, + stream_handler=handler, + entry_mode=entry_mode, + was_steered=was_steered, + pending_inputs=pending_snapshot, + generation=generation, + ) + + loop = asyncio.get_event_loop() + result_future: asyncio.Future[Any] = loop.create_future() + + renewal_cancel = asyncio.Event() + + # Build steering poll callback for steerable tasks + steering_poll_cb: Callable[[], Awaitable[None]] | None = None + if resolved_opts.steerable: + 
        async def _steering_poll() -> None:
            """Poll provider for new steering inputs and signal cancel.

            Invoked by ``lease_renewal_loop`` (passed below as
            ``steering_poll_callback``), so the polling cadence is the
            lease renewal cadence.
            """
            active = self._active_tasks.get(task_id)
            if active is None or active.context.cancel.is_set():
                return
            info = await self._provider.get(task_id)
            if info is None or not info.payload:
                return
            # A non-empty pending_inputs queue means a steer arrived:
            # cooperatively cancel the current generation so the drain
            # in _execute_task_loop can pick up the new input.
            st = info.payload.get("_steering", {})
            if st.get("pending_inputs"):
                active.context.cancel.set()

        steering_poll_cb = _steering_poll

        # Background lease renewal keeps ownership while the task runs;
        # it also drives the steering poll above.
        renewal_task = asyncio.create_task(
            lease_renewal_loop(
                self._provider,
                task_id,
                lease_owner=self._lease_owner,
                lease_instance_id=self._instance_id,
                lease_duration_seconds=lease_duration,
                cancel_event=renewal_cancel,
                on_cancel_callback=cancel_event,
                steering_poll_callback=steering_poll_cb,
            )
        )

        terminate_event = asyncio.Event()
        # Single-slot mutable cell so terminate() can pass a reason into
        # the already-running execution coroutine.
        terminate_reason_ref: list[str | None] = [None]
        execution_task = asyncio.create_task(
            self._execute_task(
                fn=fn,
                ctx=ctx,
                task_id=task_id,
                opts=resolved_opts,
                result_future=result_future,
                renewal_cancel=renewal_cancel,
                retry=retry,
                terminate_event=terminate_event,
                terminate_reason_ref=terminate_reason_ref,
            )
        )

        # Track the running task so steering, recovery and shutdown can
        # find it by id.
        active = _ActiveTask(
            task_id=task_id,
            fn_name=fn_name,
            context=ctx,
            execution_task=execution_task,
            renewal_task=renewal_task,
            renewal_cancel=renewal_cancel,
            result_future=result_future,
            terminate_event=terminate_event,
            fn=fn,
            input_type=input_type,
            opts=resolved_opts,
            retry=retry,
        )
        self._active_tasks[task_id] = active
        metadata.start_auto_flush()

        return TaskRun(
            task_id=task_id,
            provider=self._provider,
            result_future=result_future,
            metadata=metadata,
            cancel_event=cancel_event,
            stream_handler=handler,
            terminate_event=terminate_event,
            execution_task=execution_task,
            terminate_reason_ref=terminate_reason_ref,
            lease_expiry_count=task_info.lease.expiry_count if task_info.lease else 0,
        )

    async def _timeout_watchdog(
        self,
        timeout_seconds: float,
        cancel_event: asyncio.Event,
    ) -> None:
        """Background watchdog that enforces execution timeout.

        After *timeout_seconds*, sets *cancel_event* (cooperative).
        The function is expected to check ``ctx.cancel`` and exit
        gracefully. If it doesn't, the lease will eventually expire
        and the task will be recovered.

        :param timeout_seconds: Seconds before cooperative cancel.
        :type timeout_seconds: float
        :param cancel_event: Event to set for cooperative cancel.
        :type cancel_event: asyncio.Event
        """
        await asyncio.sleep(timeout_seconds)
        cancel_event.set()
        logger.info(
            "Timeout watchdog fired cooperative cancel after %.1fs", timeout_seconds
        )

    async def _execute_task(
        self,
        *,
        fn: Callable[..., Awaitable[Any]],
        ctx: TaskContext[Any],
        task_id: str,
        opts: DurableTaskOptions,
        result_future: asyncio.Future[Any],
        renewal_cancel: asyncio.Event,
        retry: RetryPolicy | None = None,
        terminate_event: asyncio.Event | None = None,
        terminate_reason_ref: list[str | None] | None = None,
    ) -> None:
        """Run the task function and handle completion/failure/suspend.

        When a ``RetryPolicy`` is provided, failed attempts are retried
        with the configured delay and backoff. Suspend and cancellation
        always exit immediately — they are not retried.

        :keyword fn: The async task function.
        :paramtype fn: Callable[..., Awaitable[Any]]
        :keyword ctx: The task context.
        :paramtype ctx: TaskContext[Any]
        :keyword task_id: The task identifier.
        :paramtype task_id: str
        :keyword opts: The task options.
        :paramtype opts: DurableTaskOptions
        :keyword result_future: Future to resolve with the result.
        :paramtype result_future: asyncio.Future[Any]
        :keyword renewal_cancel: Event to cancel lease renewal.
        :paramtype renewal_cancel: asyncio.Event
        :keyword retry: Optional retry policy.
        :paramtype retry: RetryPolicy | None
        :keyword terminate_event: Optional terminate event.
        :paramtype terminate_event: asyncio.Event | None
        :keyword terminate_reason_ref: Mutable ref for terminate reason.
        :paramtype terminate_reason_ref: list[str | None] | None
        """
        resolved_terminate = terminate_event or asyncio.Event()

        # Start timeout watchdog if configured
        watchdog_task: asyncio.Task[None] | None = None
        if opts.timeout is not None:
            watchdog_task = asyncio.create_task(
                self._timeout_watchdog(
                    timeout_seconds=opts.timeout.total_seconds(),
                    cancel_event=ctx.cancel,
                )
            )

        # NOTE(review): this local is vestigial — retry counting lives in
        # _execute_task_loop; it is kept only to satisfy the existing
        # pylint suppression.
        attempt = 0  # pylint: disable=unused-variable
        try:
            await self._execute_task_loop(
                fn=fn,
                ctx=ctx,
                task_id=task_id,
                opts=opts,
                result_future=result_future,
                renewal_cancel=renewal_cancel,
                retry=retry,
                terminate_event=resolved_terminate,
                terminate_reason_ref=terminate_reason_ref,
            )
        finally:
            # Always tear the watchdog down, even when the loop raises,
            # so a late timer cannot fire a spurious cancel.
            if watchdog_task is not None and not watchdog_task.done():
                watchdog_task.cancel()
                try:
                    await watchdog_task
                except asyncio.CancelledError:
                    pass

    async def _execute_task_loop(  # pylint: disable=too-many-statements,too-many-branches,too-many-nested-blocks
        self,
        *,
        fn: Callable[..., Awaitable[Any]],
        ctx: TaskContext[Any],
        task_id: str,
        opts: DurableTaskOptions,
        result_future: asyncio.Future[Any],
        renewal_cancel: asyncio.Event,
        retry: RetryPolicy | None = None,
        terminate_event: asyncio.Event | None = None,
        terminate_reason_ref: list[str | None] | None = None,
    ) -> None:
        """Inner execution loop — separated from watchdog management.

        :keyword fn: The async task function.
        :paramtype fn: Callable[..., Awaitable[Any]]
        :keyword ctx: The task context.
        :paramtype ctx: TaskContext[Any]
        :keyword task_id: The task identifier.
        :paramtype task_id: str
        :keyword opts: The task options.
        :paramtype opts: DurableTaskOptions
        :keyword result_future: Future to resolve with the result.
        :paramtype result_future: asyncio.Future[Any]
        :keyword renewal_cancel: Event to cancel lease renewal.
        :paramtype renewal_cancel: asyncio.Event
        :keyword retry: Optional retry policy.
        :paramtype retry: RetryPolicy | None
        :keyword terminate_event: Optional terminate event.
        :paramtype terminate_event: asyncio.Event | None
        :keyword terminate_reason_ref: Mutable ref for terminate reason.
        :paramtype terminate_reason_ref: list[str | None] | None
        """
        resolved_terminate = terminate_event or asyncio.Event()
        reason_ref = (
            terminate_reason_ref if terminate_reason_ref is not None else [None]
        )
        attempt = 0
        # Mutable ref: steering drain may swap the active result_future
        current_result_future = result_future
        while True:
            ctx.run_attempt = attempt
            try:
                result = await fn(ctx)

                if isinstance(result, Suspended):
                    # STEERING: check for pending inputs BEFORE persisting suspend
                    if opts.steerable:
                        new_ctx = await self._try_drain_steering(
                            task_id=task_id,
                            ctx=ctx,
                            opts=opts,
                            result_future=current_result_future,
                        )
                        if new_ctx is not None:
                            # Drain found pending input — loop with new context
                            ctx = new_ctx
                            attempt = 0
                            # Update result future to the new generation's future
                            active = self._active_tasks.get(task_id)
                            if (
                                active
                                and active.result_future is not current_result_future
                            ):
                                current_result_future = active.result_future
                            continue

                    # No pending steering — normal suspend flow
                    renewal_cancel.set()
                    await ctx.metadata.stop_auto_flush()
                    await self._handle_suspend(
                        task_id=task_id,
                        reason=result.reason,
                        output=result.output,
                        metadata=ctx.metadata,
                        opts=opts,
                    )
                    if not current_result_future.done():
                        current_result_future.set_result(
                            TaskResult(
                                task_id=task_id,
                                output=result.output,
                                status="suspended",
                                suspension_reason=result.reason,
                            )
                        )
                else:
                    # Guard: task functions must return raw output, not TaskResult
                    if isinstance(result, TaskResult):
                        raise TypeError(
                            "Task function returned TaskResult directly. "
                            "Return raw output instead — the framework wraps "
                            "it in TaskResult automatically."
                        )

                    # STEERING: check for pending before completing
                    if opts.steerable:
                        new_ctx = await self._try_drain_steering(
                            task_id=task_id,
                            ctx=ctx,
                            opts=opts,
                            result_future=current_result_future,
                            partial_output=result,
                        )
                        if new_ctx is not None:
                            ctx = new_ctx
                            attempt = 0
                            active = self._active_tasks.get(task_id)
                            if (
                                active
                                and active.result_future is not current_result_future
                            ):
                                current_result_future = active.result_future
                            continue

                    # Success flow
                    renewal_cancel.set()
                    await ctx.metadata.stop_auto_flush()
                    completed = await self._handle_success(
                        task_id=task_id,
                        result=result,
                        metadata=ctx.metadata,
                        opts=opts,
                    )
                    if not completed:
                        # Etag conflict on steerable completion — re-drain
                        # NOTE(review): rebinding renewal_cancel replaces only
                        # this local reference; the lease renewal loop still
                        # holds the original (now-set) event and no new renewal
                        # task is started here — confirm renewal is intended to
                        # stay stopped for the re-drained generation.
                        renewal_cancel = asyncio.Event()  # reset for next iteration
                        new_ctx = await self._try_drain_steering(
                            task_id=task_id,
                            ctx=ctx,
                            opts=opts,
                            result_future=current_result_future,
                            partial_output=result,
                        )
                        if new_ctx is not None:
                            ctx = new_ctx
                            attempt = 0
                            active = self._active_tasks.get(task_id)
                            if (
                                active
                                and active.result_future is not current_result_future
                            ):
                                current_result_future = active.result_future
                            continue
                    # No pending found despite conflict — complete anyway
                    if not current_result_future.done():
                        current_result_future.set_result(
                            TaskResult(
                                task_id=task_id,
                                output=result,
                                status="completed",
                            )
                        )

                break  # exit retry loop on success or suspend

            except asyncio.CancelledError:
                renewal_cancel.set()
                await ctx.metadata.stop_auto_flush()
                if resolved_terminate.is_set():
                    # Forced termination (timeout or explicit terminate())
                    from ._exceptions import (  # pylint: disable=import-outside-toplevel
                        TaskTerminated,
                    )

                    await self._handle_failure(
                        task_id=task_id,
                        exc=TaskTerminated(task_id, reason=reason_ref[0]),
                        metadata=ctx.metadata,
                        opts=opts,
                    )
                    if not current_result_future.done():
                        current_result_future.set_exception(
                            TaskTerminated(task_id, reason=reason_ref[0])
                        )
                else:
                    # Cooperative cancellation (suspend or caller cancel)
                    if not current_result_future.done():
                        from ._exceptions import (  # pylint: disable=import-outside-toplevel
                            TaskCancelled,
                        )

                        current_result_future.set_exception(TaskCancelled(task_id))
                break  # cancellation is never retried

            except Exception as exc:  # pylint: disable=broad-exception-caught
                if retry and retry.should_retry(attempt, exc):
                    delay = retry.compute_delay(attempt)
                    logger.warning(
                        "Task %s attempt %d failed (%s: %s), retrying in %.1fs",
                        task_id,
                        attempt,
                        type(exc).__name__,
                        exc,
                        delay,
                    )
                    # Update error field so observers see intermediate failures
                    try:
                        await self._provider.update(
                            task_id,
                            TaskPatchRequest(
                                error={
                                    "type": type(exc).__name__,
                                    "message": str(exc),
                                    "attempt": attempt,
                                }
                            ),
                        )
                    except Exception:  # pylint: disable=broad-exception-caught
                        logger.debug(
                            "Failed to update error field for retry", exc_info=True
                        )
                    await asyncio.sleep(delay)
                    attempt += 1
                    continue

                # Exhausted or non-retryable — terminal failure
                renewal_cancel.set()
                await ctx.metadata.stop_auto_flush()

                if retry and attempt > 0:
                    # Retries were attempted but exhausted
                    error_dict: dict[str, Any] = {
                        "type": "exhausted_retries",
                        "attempts": attempt + 1,
                        "last_error": str(exc),
                        "last_error_type": type(exc).__name__,
                        "traceback": traceback.format_exc(),
                    }
                else:
                    error_dict = {
                        "type": type(exc).__name__,
                        "message": str(exc),
                        "traceback": traceback.format_exc(),
                    }

                # NOTE(review): error_dict above is attached only to the
                # in-memory TaskFailed below; _handle_failure persists its
                # own generic dict built from exc — confirm the richer
                # exhausted_retries shape is not meant to be stored too.
                await self._handle_failure(
                    task_id=task_id,
                    exc=exc,
                    metadata=ctx.metadata,
                    opts=opts,
                )
                if not current_result_future.done():
                    current_result_future.set_exception(TaskFailed(task_id, error_dict))
                break

        self._active_tasks.pop(task_id, None)
        # Signal end of streaming via handler.close()
        if ctx._stream_handler is not None:  # pylint: disable=protected-access
            try:
                await ctx._stream_handler.close()  # pylint: disable=protected-access
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning(
                    "Stream handler close() failed for task %s",
                    task_id,
                    exc_info=True,
                )

    async def _try_drain_steering(  # pylint: disable=too-many-branches
        self,
        *,
        task_id: str,
        ctx: TaskContext[Any],
        opts: DurableTaskOptions,
        result_future: asyncio.Future[Any],
        partial_output: Any | None = None,
    ) -> TaskContext[Any] | None:
        """Check for pending steering inputs and drain the next one.

        Called BEFORE persisting suspend/complete to avoid lease/status conflicts.
        Returns a new ``TaskContext`` if a drain occurred, or ``None`` if no
        pending inputs exist.

        :keyword task_id: The task identifier.
        :keyword ctx: Current task context.
        :keyword opts: Task options.
        :keyword result_future: The current generation's result future.
        :keyword partial_output: Output from the completed generation (for race recovery).
        :return: New context for the drained generation, or None.
        """
        task_info = await self._provider.get(task_id)
        if task_info is None:
            return None

        # Work on copies so a failed update does not corrupt local state.
        payload = dict(task_info.payload) if task_info.payload else {}
        steering = dict(payload.get("_steering", {}))
        pending: list[Any] = list(steering.get("pending_inputs", []))

        if not pending:
            return None

        # Pop the next input from the queue
        next_input_raw = pending.pop(0)
        previous_input_raw = steering.get("active_input")

        # Update steering state
        steering["active_input"] = next_input_raw
        if previous_input_raw is not None:
            steering["previous_input"] = previous_input_raw
        steering["pending_inputs"] = pending
        old_generation = steering.get("generation", 0)
        steering["generation"] = old_generation + 1
        # More inputs still queued → pre-arm cancel for the new generation.
        steering["cancel_requested"] = len(pending) > 0
        steering["drain_in_progress"] = True

        # Save partial output if function completed (race recovery)
        if partial_output is not None:
            gen_results = dict(steering.get("generation_results", {}))
            gen_results[str(old_generation)] = _serialize_input(partial_output)
            steering["generation_results"] = gen_results

        payload["_steering"] = steering

        try:
            etag = getattr(task_info, "etag", None) or None
            await self._provider.update(
                task_id,
                TaskPatchRequest(payload=payload, if_match=etag),
            )
        except ValueError:
            # Etag conflict — re-read and retry once
            # NOTE(review): despite "retry once", this recursion has no
            # depth bound — repeated conflicts recurse indefinitely;
            # consider a bounded retry counter. TODO confirm intent.
            logger.warning(
                "Etag conflict during steering drain for %s, retrying", task_id
            )
            return await self._try_drain_steering(
                task_id=task_id,
                ctx=ctx,
                opts=opts,
                result_future=result_future,
                partial_output=partial_output,
            )

        # Pop and bind the next pending steering future (if any)
        new_future: asyncio.Future[Any] | None = None
        had_registered_future = False
        steering_futures = self._pending_steering_futures.get(task_id, [])
        if steering_futures:
            new_future = steering_futures.pop(0)
            had_registered_future = True

        # Resolve the superseded generation's future (only for external steer callers)
        if had_registered_future and not result_future.done():
            result_future.set_result(
                TaskResult(task_id=task_id, output=partial_output, status="superseded")
            )

        # Update active generation future
        if new_future is not None:
            self._active_generation_future[task_id] = new_future

        # Deserialize input
        active_task = self._active_tasks.get(task_id)
        input_type = active_task.input_type if active_task else None
        if input_type is not None:
            resolved_input = _deserialize_input(next_input_raw, input_type)
        else:
            resolved_input = next_input_raw

        # Deserialize previous input
        previous_input = None
        if previous_input_raw is not None and input_type is not None:
            previous_input = _deserialize_input(previous_input_raw, input_type)
        elif previous_input_raw is not None:
            previous_input = previous_input_raw

        # Build new context, reusing metadata and shutdown event
        cancel_event = asyncio.Event()
        if steering["cancel_requested"]:
            cancel_event.set()

        new_ctx: TaskContext[Any] = TaskContext(
            task_id=task_id,
            title=ctx.title,
            description=ctx.description,
            session_id=ctx.session_id,
            agent_name=ctx.agent_name,
            tags=ctx.tags,
            input=resolved_input,
            metadata=ctx.metadata,
            run_attempt=0,
            lease_generation=ctx.lease_generation,
            cancel=cancel_event,
            shutdown=ctx.shutdown,
            stream_handler=ctx._stream_handler,  # pylint: disable=protected-access
            entry_mode="resumed",
            was_steered=True,
            previous_input=previous_input,
            pending_inputs=tuple(pending),
            generation=old_generation + 1,
        )

        # Update active task tracking
        if active_task is not None:
            active_task.context = new_ctx
            if new_future is not None:
                active_task.result_future = new_future

        # Clear drain_in_progress
        steering["drain_in_progress"] = False
        payload["_steering"] = steering
        try:
            await self._provider.update(
                task_id,
                TaskPatchRequest(payload=payload),
            )
        except Exception:  # pylint: disable=broad-exception-caught
            logger.debug("Failed to clear drain_in_progress for %s", task_id)

        logger.info(
            "Steering drain: task %s generation %d → %d",
            task_id,
            old_generation,
            old_generation + 1,
        )
        return new_ctx

    async def _handle_success(
        self,
        *,
        task_id: str,
        result: Any,
        metadata: TaskMetadata,
        opts: DurableTaskOptions,
    ) -> bool:
        """Handle successful task completion.

        :keyword task_id: The task identifier.
        :paramtype task_id: str
        :keyword result: The task result value.
        :paramtype result: Any
        :keyword metadata: The task metadata.
        :paramtype metadata: TaskMetadata
        :keyword opts: The task options.
        :paramtype opts: DurableTaskOptions
        :return: True if completion succeeded, False if etag conflict
            detected (steerable tasks only — caller should re-drain).
        :rtype: bool
        """
        if opts.ephemeral:
            # Delete immediately — no intermediate PATCH
            try:
                await self._provider.delete(task_id, force=True)
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning(
                    "Failed to delete ephemeral task %s", task_id, exc_info=True
                )
        else:
            # PATCH to completed with output
            payload_patch: dict[str, Any] = {
                "metadata": metadata.to_dict(),
                "output": _serialize_input(result),
            }

            # For steerable tasks, use etag to detect concurrent steering
            if opts.steerable:
                try:
                    task_info = await self._provider.get(task_id)
                    etag = getattr(task_info, "etag", None) if task_info else None
                    await self._provider.update(
                        task_id,
                        TaskPatchRequest(
                            status="completed",
                            payload=payload_patch,
                            if_match=etag,
                        ),
                    )
                except ValueError:
                    # Etag conflict — another process may have steered
                    logger.info(
                        "Etag conflict completing task %s — re-checking for steers",
                        task_id,
                    )
                    return False
                except Exception:  # pylint: disable=broad-exception-caught
                    logger.warning("Failed to complete task %s", task_id, exc_info=True)
            else:
                try:
                    await self._provider.update(
                        task_id,
                        TaskPatchRequest(
                            status="completed",
                            payload=payload_patch,
                        ),
                    )
                except Exception:  # pylint: disable=broad-exception-caught
                    logger.warning("Failed to complete task %s", task_id, exc_info=True)

        logger.info("Task %s completed successfully", task_id)
        return True

    async def _handle_failure(
        self,
        *,
        task_id: str,
        exc: Exception,
        metadata: TaskMetadata,
        opts: DurableTaskOptions,
    ) -> None:
        """Handle task failure.

        :keyword task_id: The task identifier.
        :paramtype task_id: str
        :keyword exc: The exception that caused the failure.
        :paramtype exc: Exception
        :keyword metadata: The task metadata.
        :paramtype metadata: TaskMetadata
        :keyword opts: The task options.
        :paramtype opts: DurableTaskOptions
        """
        error_dict = {
            "type": type(exc).__name__,
            "message": str(exc),
            "traceback": traceback.format_exc(),
        }

        if opts.ephemeral:
            try:
                await self._provider.delete(task_id, force=True)
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning(
                    "Failed to delete failed ephemeral task %s",
                    task_id,
                    exc_info=True,
                )
        else:
            # The TaskStatus vocabulary has no "failed" value, so a
            # failure is persisted as "completed" with the error field set.
            try:
                await self._provider.update(
                    task_id,
                    TaskPatchRequest(
                        status="completed",
                        error=error_dict,
                        payload={"metadata": metadata.to_dict()},
                    ),
                )
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning(
                    "Failed to record error for task %s",
                    task_id,
                    exc_info=True,
                )

        logger.error("Task %s failed: %s", task_id, exc)

    async def _handle_suspend(
        self,
        *,
        task_id: str,
        reason: str | None,
        output: Any | None,
        metadata: TaskMetadata,
        opts: DurableTaskOptions,  # pylint: disable=unused-argument
    ) -> None:
        """Handle task suspension.

        :keyword task_id: The task identifier.
        :paramtype task_id: str
        :keyword reason: Optional suspension reason.
        :paramtype reason: str | None
        :keyword output: Optional output snapshot.
        :paramtype output: Any | None
        :keyword metadata: The task metadata.
        :paramtype metadata: TaskMetadata
        :keyword opts: The task options.
        :paramtype opts: DurableTaskOptions
        """
        payload_patch: dict[str, Any] = {
            "metadata": metadata.to_dict(),
        }
        if output is not None:
            payload_patch["output"] = _serialize_input(output)

        try:
            await self._provider.update(
                task_id,
                TaskPatchRequest(
                    status="suspended",
                    suspension_reason=reason,
                    payload=payload_patch,
                ),
            )
        except Exception:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to suspend task %s", task_id, exc_info=True)

        logger.info("Task %s suspended: %s", task_id, reason)

    async def _recover_stale_tasks(self) -> None:
        """Recover stale in-progress tasks from previous instances."""
        agent_name = self._config.agent_name or "default"
        session_id = self._config.session_id or "local"

        try:
            stale_tasks = await self._provider.list(
                agent_name=agent_name,
                session_id=session_id,
                status="in_progress",
                lease_owner=self._lease_owner,
            )
        except Exception:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to query stale tasks for recovery", exc_info=True)
            return

        for task_info in stale_tasks:
            # Skip if we're already tracking this task
            if task_info.id in self._active_tasks:
                continue

            # Reclaim the lease with our new instance ID
            # NOTE(review): lease_duration_seconds is hard-coded to 60 here,
            # unlike the configurable duration used at start — confirm.
            try:
                await self._provider.update(
                    task_info.id,
                    TaskPatchRequest(
                        lease_owner=self._lease_owner,
                        lease_instance_id=self._instance_id,
                        lease_duration_seconds=60,
                    ),
                )
                logger.info(
                    "Reclaimed stale task %s (generation will increment)",
                    task_info.id,
                )
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning("Failed to reclaim task %s", task_info.id, exc_info=True)
                continue

            # Find resume callback and dispatch
            fn = self._find_resume_callback(task_info)
            if fn is not None:
                try:
                    await self.handle_resume(task_info.id)
                except Exception:  # pylint: disable=broad-exception-caught
                    logger.warning(
                        "Failed to resume recovered task %s",
                        task_info.id,
                        exc_info=True,
                    )

    def _find_resume_callback(self, task_info: TaskInfo) -> Callable[..., Any] | None:
        """Find a registered resume callback for a task.

        Matches by ``source.name`` (auto-stamped function name) first,
        then falls back to title prefix match or single-callback default.

        :param task_info: The task record to match.
        :type task_info: TaskInfo
        :return: A matching resume callback, or None.
        :rtype: Callable[..., Any] | None
        """
        # Preferred: match by source.name (framework auto-stamped fn name)
        if task_info.source and "name" in task_info.source:
            source_name = task_info.source["name"]
            if source_name in self._resume_callbacks:
                return self._resume_callbacks[source_name]

        # Fallback: title prefix match
        for name, fn in self._resume_callbacks.items():
            if task_info.title and task_info.title.startswith(name):
                return fn

        # Last resort: single registered callback
        if len(self._resume_callbacks) == 1:
            return next(iter(self._resume_callbacks.values()))
        return None

    def _make_metadata_flush(
        self, task_id: str
    ) -> Callable[[dict[str, Any]], Awaitable[None]]:
        """Create a flush callback for metadata persistence.

        :param task_id: The task identifier.
        :type task_id: str
        :return: An async callback that flushes metadata.
        :rtype: Callable[[dict[str, Any]], Awaitable[None]]
        """

        async def _flush(data: dict[str, Any]) -> None:
            await self._provider.update(
                task_id,
                TaskPatchRequest(payload={"metadata": data}),
            )

        return _flush

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- +"""Mutable progress metadata for durable tasks. + +Provides a dict-like interface with typed mutation methods and +debounced persistence to the task storage backend. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import collections.abc +import logging +from collections.abc import Iterator +from typing import Any + +logger = logging.getLogger("azure.ai.agentserver.durable") + +# Sentinel to distinguish "not set" from None +_NOT_SET = object() + + +class TaskMetadata: + """Mutable progress dict persisted to the task record's payload. + + Changes are batched and flushed on a configurable interval, or + immediately on explicit :meth:`flush`, suspension, or completion. + + :param initial: Initial metadata values (from a recovered task). + :type initial: dict[str, Any] | None + :param flush_callback: Async callable that persists dirty metadata. + :type flush_callback: Callable[[dict[str, Any]], Awaitable[None]] | None + :param flush_interval: Seconds between automatic flushes (0 = disabled). + :type flush_interval: float + """ + + def __init__( + self, + initial: dict[str, Any] | None = None, + *, + flush_callback: Any = None, + flush_interval: float = 5.0, + ) -> None: + self._data: dict[str, Any] = dict(initial) if initial else {} + self._dirty = False + self._flush_callback = flush_callback + self._flush_interval = flush_interval + self._flush_task: asyncio.Task[None] | None = None + self._lock = asyncio.Lock() + + def set(self, key: str, value: Any) -> None: + """Set a key-value pair. + + :param key: Metadata key (must be a string). + :type key: str + :param value: Any JSON-serializable value. + :type value: Any + :raises TypeError: If key is not a string. 
+ """ + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string, got {type(key).__name__}") + self._data[key] = value + self._mark_dirty() + + def get(self, key: str, default: Any = None) -> Any: + """Get a value by key. + + :param key: Metadata key. + :type key: str + :param default: Default value if key is absent. + :type default: Any + :return: The value, or *default*. + :rtype: Any + """ + return self._data.get(key, default) + + def increment(self, key: str, delta: int = 1) -> None: + """Atomically increment a numeric value. + + :param key: Metadata key. + :type key: str + :param delta: Amount to add (default 1). + :type delta: int + :raises TypeError: If the existing value is not numeric. + """ + if not isinstance(delta, (int, float)): + raise TypeError(f"Delta must be numeric, got {type(delta).__name__}") + current = self._data.get(key, 0) + if not isinstance(current, (int, float)): + raise TypeError( + f"Cannot increment non-numeric value at key {key!r}: " + f"{type(current).__name__}" + ) + self._data[key] = current + delta + self._mark_dirty() + + def append(self, key: str, value: Any) -> None: + """Append a value to a list. + + Creates the list if the key is absent. + + :param key: Metadata key. + :type key: str + :param value: Value to append. + :type value: Any + :raises TypeError: If the existing value is not a list. + """ + current = self._data.get(key, _NOT_SET) + if current is _NOT_SET: + self._data[key] = [value] + elif isinstance(current, list): + current.append(value) + else: + raise TypeError( + f"Cannot append to non-list value at key {key!r}: " + f"{type(current).__name__}" + ) + self._mark_dirty() + + def to_dict(self) -> dict[str, Any]: + """Return a snapshot of all metadata. + + :return: A shallow copy of the metadata dict. 
+ :rtype: dict[str, Any] + """ + return dict(self._data) + + # -- Dict protocol (MutableMapping) ------------------------------------ + + def __setitem__(self, key: str, value: Any) -> None: + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string, got {type(key).__name__}") + self._data[key] = value + self._mark_dirty() + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __delitem__(self, key: str) -> None: + del self._data[key] + self._mark_dirty() + + def __contains__(self, key: object) -> bool: + return key in self._data + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def keys(self) -> collections.abc.KeysView[str]: + """Return a view of metadata keys. + + :return: A view of the metadata keys. + :rtype: ~collections.abc.KeysView[str] + """ + return self._data.keys() + + def values(self) -> collections.abc.ValuesView[Any]: + """Return a view of metadata values. + + :return: A view of the metadata values. + :rtype: ~collections.abc.ValuesView[Any] + """ + return self._data.values() + + def items(self) -> collections.abc.ItemsView[str, Any]: + """Return a view of metadata key-value pairs. + + :return: A view of the metadata key-value pairs. + :rtype: ~collections.abc.ItemsView[str, Any] + """ + return self._data.items() + + async def flush(self) -> None: + """Force-flush pending metadata changes to the store. + + No-op if there are no pending changes or no flush callback. + """ + async with self._lock: + await self._do_flush() + + def start_auto_flush(self) -> None: + """Start the background auto-flush loop. + + Called by the framework when the task starts executing. Should + not be called by user code. 
+ """ + if ( + self._flush_interval > 0 + and self._flush_callback is not None + and self._flush_task is None + ): + self._flush_task = asyncio.get_event_loop().create_task( + self._auto_flush_loop() + ) + + async def stop_auto_flush(self) -> None: + """Stop the auto-flush loop and perform a final flush.""" + if self._flush_task is not None: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + self._flush_task = None + # Final flush + async with self._lock: + await self._do_flush() + + def _mark_dirty(self) -> None: + self._dirty = True + + async def _do_flush(self) -> None: + if not self._dirty or self._flush_callback is None: + return + try: + await self._flush_callback(dict(self._data)) + self._dirty = False + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to flush metadata", exc_info=True) + + async def _auto_flush_loop(self) -> None: + """Periodically flush dirty metadata.""" + while True: + await asyncio.sleep(self._flush_interval) + async with self._lock: + await self._do_flush() + + +# Register as a virtual subclass so isinstance checks work +# without inheriting (preserves custom increment/append/flush). +collections.abc.MutableMapping.register(TaskMetadata) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_models.py new file mode 100644 index 000000000000..f4a28cbde7b0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_models.py @@ -0,0 +1,380 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Internal data models for the durable task subsystem. + +These types represent wire-level task records and request/response shapes +used by providers. 
They are **not** part of the public API. +""" + +from __future__ import annotations + +from typing import Any, Literal + +TaskStatus = Literal["pending", "in_progress", "suspended", "completed"] +"""Valid task status values.""" + + +class LeaseInfo: + """Lease details on a task record. + + :param owner: Stable lease owner (e.g. ``"session:session_abc"``). + :type owner: str + :param instance_id: Ephemeral per-process instance identifier. + :type instance_id: str + :param generation: Fencing token — increments on re-acquisition. + :type generation: int + :param expires_at: ISO 8601 expiry timestamp. + :type expires_at: str + :param expiry_count: Number of times ownership changed via expiry. + :type expiry_count: int + """ + + __slots__ = ("owner", "instance_id", "generation", "expires_at", "expiry_count") + + def __init__( + self, + owner: str, + instance_id: str, + generation: int, + expires_at: str, + expiry_count: int = 0, + ) -> None: + self.owner = owner + self.instance_id = instance_id + self.generation = generation + self.expires_at = expires_at + self.expiry_count = expiry_count + + def __repr__(self) -> str: + return ( + f"LeaseInfo(owner={self.owner!r}, instance_id={self.instance_id!r}, " + f"generation={self.generation!r}, expires_at={self.expires_at!r}, " + f"expiry_count={self.expiry_count!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LeaseInfo): + return NotImplemented + return ( + self.owner == other.owner + and self.instance_id == other.instance_id + and self.generation == other.generation + and self.expires_at == other.expires_at + and self.expiry_count == other.expiry_count + ) + + +class TaskInfo: # pylint: disable=too-many-instance-attributes + """Internal representation of a task record from the store. + + :param id: Unique task identifier. + :type id: str + :param agent_name: Agent scope. + :type agent_name: str + :param session_id: Session scope. + :type session_id: str + :param status: Current task status. 
+ :type status: TaskStatus + :param title: Human-readable title. + :type title: str | None + :param description: Optional description. + :type description: str | None + :param lease: Active lease details, or ``None``. + :type lease: LeaseInfo | None + :param payload: Arbitrary JSON payload (input, metadata, output buckets). + :type payload: dict[str, Any] | None + :param tags: Key-value tags. + :type tags: dict[str, str] | None + :param error: Structured error details on failure. + :type error: dict[str, Any] | None + :param suspension_reason: Reason for suspension. + :type suspension_reason: str | None + :param etag: Optimistic concurrency token. + :type etag: str + :param created_at: ISO 8601 creation timestamp. + :type created_at: str + :param updated_at: ISO 8601 last-update timestamp. + :type updated_at: str + :param started_at: ISO 8601 timestamp of first ``in_progress`` transition. + :type started_at: str | None + :param completed_at: ISO 8601 timestamp of ``completed`` transition. 
+ :type completed_at: str | None + """ + + __slots__ = ( + "id", + "agent_name", + "session_id", + "status", + "title", + "description", + "lease", + "payload", + "tags", + "error", + "suspension_reason", + "etag", + "created_at", + "updated_at", + "started_at", + "completed_at", + "source", + ) + + def __init__( + self, + id: str, # noqa: A002 + agent_name: str, + session_id: str, + status: TaskStatus, + title: str | None = None, + description: str | None = None, + lease: LeaseInfo | None = None, + payload: dict[str, Any] | None = None, + tags: dict[str, str] | None = None, + error: dict[str, Any] | None = None, + suspension_reason: str | None = None, + etag: str = "", + created_at: str = "", + updated_at: str = "", + started_at: str | None = None, + completed_at: str | None = None, + source: dict[str, Any] | None = None, + ) -> None: + self.id = id + self.agent_name = agent_name + self.session_id = session_id + self.status = status + self.title = title + self.description = description + self.lease = lease + self.payload = payload + self.tags = tags + self.error = error + self.suspension_reason = suspension_reason + self.etag = etag + self.created_at = created_at + self.updated_at = updated_at + self.started_at = started_at + self.completed_at = completed_at + self.source = source + + def __repr__(self) -> str: + return f"TaskInfo(id={self.id!r}, status={self.status!r}, agent_name={self.agent_name!r})" + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TaskInfo: + """Construct a :class:`TaskInfo` from a JSON-decoded dict. + + :param data: Dictionary as returned by the Task Storage API. + :type data: dict[str, Any] + :return: A populated TaskInfo instance. 
+ :rtype: TaskInfo + """ + lease_data = data.get("lease") + lease = ( + LeaseInfo( + owner=lease_data["owner"], + instance_id=lease_data["instance_id"], + generation=lease_data.get("generation", 0), + expires_at=lease_data.get("expires_at", ""), + expiry_count=lease_data.get("expiry_count", 0), + ) + if lease_data + else None + ) + return cls( + id=data["id"], + agent_name=data.get("agent_name", ""), + session_id=data.get("session_id", ""), + status=data.get("status", "pending"), + title=data.get("title"), + description=data.get("description"), + lease=lease, + payload=data.get("payload"), + tags=data.get("tags"), + error=data.get("error"), + suspension_reason=data.get("suspension_reason"), + etag=data.get("etag", ""), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + started_at=data.get("started_at"), + completed_at=data.get("completed_at"), + source=data.get("source"), + ) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dictionary. + + :return: Dictionary suitable for JSON serialization. 
class TaskCreateRequest:  # pylint: disable=too-many-instance-attributes
    """Request body for creating a task.

    :param agent_name: Agent scope.
    :type agent_name: str
    :param session_id: Session scope.
    :type session_id: str
    :param status: Initial status (``"pending"`` or ``"in_progress"``).
    :type status: TaskStatus
    :param id: Optional client-supplied task ID.
    :type id: str | None
    :param title: Human-readable title.
    :type title: str | None
    :param description: Optional description.
    :type description: str | None
    :param payload: Initial payload (input bucket).
    :type payload: dict[str, Any] | None
    :param tags: Initial tags.
    :type tags: dict[str, str] | None
    :param source: Provenance metadata stamped onto the task.
    :type source: dict[str, Any] | None
    :param lease_owner: Required when ``status`` is ``"in_progress"``.
    :type lease_owner: str | None
    :param lease_instance_id: Required when ``status`` is ``"in_progress"``.
    :type lease_instance_id: str | None
    :param lease_duration_seconds: Lease TTL. Required with lease params.
    :type lease_duration_seconds: int | None
    """

    __slots__ = (
        "agent_name",
        "session_id",
        "status",
        "id",
        "title",
        "description",
        "payload",
        "tags",
        "source",
        "lease_owner",
        "lease_instance_id",
        "lease_duration_seconds",
    )

    def __init__(
        self,
        agent_name: str,
        session_id: str,
        status: TaskStatus = "pending",
        id: str | None = None,  # noqa: A002
        title: str | None = None,
        description: str | None = None,
        payload: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        source: dict[str, Any] | None = None,
        lease_owner: str | None = None,
        lease_instance_id: str | None = None,
        lease_duration_seconds: int | None = None,
    ) -> None:
        self.agent_name = agent_name
        self.session_id = session_id
        self.status = status
        self.id = id
        self.title = title
        self.description = description
        self.payload = payload
        self.tags = tags
        self.source = source
        self.lease_owner = lease_owner
        self.lease_instance_id = lease_instance_id
        self.lease_duration_seconds = lease_duration_seconds


class TaskPatchRequest:
    """Request body for patching a task.

    Only non-``None`` fields are included in the PATCH payload.

    :param status: New status.
    :type status: TaskStatus | None
    :param payload: Payload patch (shallow-merge semantics).
    :type payload: dict[str, Any] | None
    :param tags: Tags patch (null-as-delete merge).
    :type tags: dict[str, str] | None
    :param error: Structured error (on failure).
    :type error: dict[str, Any] | None
    :param suspension_reason: Reason for suspension.
    :type suspension_reason: str | None
    :param lease_owner: Lease owner for transitions.
    :type lease_owner: str | None
    :param lease_instance_id: Lease instance for transitions.
    :type lease_instance_id: str | None
    :param lease_duration_seconds: Lease TTL override.
    :type lease_duration_seconds: int | None
    :param if_match: ETag for optimistic concurrency.
    :type if_match: str | None
    """

    __slots__ = (
        "status",
        "payload",
        "tags",
        "error",
        "suspension_reason",
        "lease_owner",
        "lease_instance_id",
        "lease_duration_seconds",
        "if_match",
    )

    def __init__(
        self,
        status: TaskStatus | None = None,
        payload: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        error: dict[str, Any] | None = None,
        suspension_reason: str | None = None,
        lease_owner: str | None = None,
        lease_instance_id: str | None = None,
        lease_duration_seconds: int | None = None,
        if_match: str | None = None,
    ) -> None:
        self.status = status
        self.payload = payload
        self.tags = tags
        self.error = error
        self.suspension_reason = suspension_reason
        self.lease_owner = lease_owner
        self.lease_instance_id = lease_instance_id
        self.lease_duration_seconds = lease_duration_seconds
        self.if_match = if_match
+""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from ._models import TaskCreateRequest, TaskInfo, TaskPatchRequest, TaskStatus + + +@runtime_checkable +class DurableTaskProvider(Protocol): + """Async storage backend for durable tasks. + + Both :class:`HostedDurableTaskProvider` (HTTP → Task Storage API) and + :class:`LocalFileDurableTaskProvider` (filesystem) implement this + protocol. + """ + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. + :rtype: TaskInfo + """ + ... + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a single task by ID. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + ... + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH semantics. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + ... + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, + ) -> None: + """Delete a task. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. + :paramtype cascade: bool + """ + ... + + async def list( + self, + *, + agent_name: str, + session_id: str, + status: TaskStatus | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + ) -> list[TaskInfo]: + """List tasks with filters. + + :keyword agent_name: Filter by agent name. 
Output = TypeVar("Output")


class TaskResult(Generic[Output]):
    """Outcome wrapper for a durable task execution.

    Both completion and suspension are expressed as ordinary return
    values; failures, cancellation, and termination still surface as
    raised exceptions.

    :param task_id: The task identifier.
    :type task_id: str
    :param output: The task output value (typed for completion, optional for suspension).
    :type output: Output | None
    :param status: Whether the task completed, suspended, or was superseded.
    :type status: ~typing.Literal["completed", "suspended", "superseded"]
    :param suspension_reason: Human-readable suspension reason, if suspended.
    :type suspension_reason: str | None
    """

    __slots__ = ("task_id", "output", "status", "suspension_reason")

    def __init__(
        self,
        *,
        task_id: str,
        output: Output | None = None,
        status: Literal["completed", "suspended", "superseded"],
        suspension_reason: str | None = None,
    ) -> None:
        self.task_id = task_id
        self.output = output
        self.status: Literal["completed", "suspended", "superseded"] = status
        self.suspension_reason = suspension_reason

    @property
    def is_completed(self) -> bool:
        """Whether the task completed successfully.

        :return: True if the task completed.
        :rtype: bool
        """
        return self.status == "completed"

    @property
    def is_suspended(self) -> bool:
        """Whether the task was suspended.

        :return: True if the task is suspended.
        :rtype: bool
        """
        return self.status == "suspended"

    @property
    def is_superseded(self) -> bool:
        """Whether the generation was superseded by a steering input.

        :return: True if this generation was cancelled by a newer input.
        :rtype: bool
        """
        return self.status == "superseded"

    def __repr__(self) -> str:
        # Truncate long outputs so reprs stay log-friendly.
        shown = repr(self.output)
        if len(shown) > 60:
            shown = shown[:57] + "..."
        text = f"TaskResult(task_id={self.task_id!r}, status={self.status!r}, output={shown}"
        if self.suspension_reason is not None:
            text += f", suspension_reason={self.suspension_reason!r}"
        return text + ")"
+# --------------------------------------------------------- +"""POST /tasks/resume — Starlette route for external task resume triggers. + +Returns an empty body with the appropriate status code: +- 202 Accepted: resume dispatched successfully +- 404 Not Found: task not found or not in a resumable state +- 409 Conflict: task is already in progress +""" + +from __future__ import annotations + +import json +import logging + +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +logger = logging.getLogger("azure.ai.agentserver.durable") + + +async def _handle_resume_request( + request: Request, +) -> Response: # pylint: disable=too-many-return-statements + """Handle POST /tasks/resume. + + Expects a JSON body with ``{"task_id": "..."}`` and dispatches the + resume to the DurableTaskManager. + + :param request: The incoming HTTP request. + :type request: Request + :return: Empty-body response with status code. + :rtype: Response + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + try: + body = await request.json() + except (json.JSONDecodeError, ValueError): + return Response(status_code=400) + + task_id = body.get("task_id") + if not task_id or not isinstance(task_id, str): + return Response(status_code=400) + + try: + manager = get_task_manager() + except RuntimeError: + return Response(status_code=503) + + try: + await manager.handle_resume(task_id) + logger.info("Resume accepted for task %s", task_id) + return Response(status_code=202) + + except Exception as exc: # pylint: disable=broad-exception-caught + msg = str(exc).lower() + if "not found" in msg: + return Response(status_code=404) + if "not 'suspended'" in msg or "already" in msg or "conflict" in msg: + return Response(status_code=409) + logger.error("Resume failed for task %s: %s", task_id, exc, exc_info=True) + return Response(status_code=500) + + +def create_resume_route() -> Route: + 
"""Create the Starlette Route for POST /tasks/resume. + + :return: A Starlette Route to be added to the host. + :rtype: Route + """ + return Route("/tasks/resume", _handle_resume_request, methods=["POST"]) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py new file mode 100644 index 000000000000..aa56b3eb8e26 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""RetryPolicy — configurable retry behaviour for durable tasks. + +Aligned with industry conventions (Temporal, Azure Durable Functions, Celery). +Delay formula: ``min(initial_delay * backoff_coefficient ** attempt, max_delay)`` +With jitter: ``delay * uniform(0.75, 1.25)`` +""" + +from __future__ import annotations + +import random +from datetime import timedelta + + +class RetryPolicy: + """Retry configuration for durable tasks. + + :param initial_delay: Base delay between retries. + :type initial_delay: ~datetime.timedelta + :param backoff_coefficient: Multiplier applied per attempt. + :type backoff_coefficient: float + :param max_delay: Upper bound on computed delay. + :type max_delay: ~datetime.timedelta + :param max_attempts: Total attempts (including the first try). + :type max_attempts: int + :param retry_on: Exception types that trigger retry. ``None`` means all. + :type retry_on: tuple[type[Exception], ...] | None + :param jitter: Whether to add ±25% randomization to delays. + :type jitter: bool + + .. 
versionadded:: 2.1.0 + """ + + __slots__ = ( + "initial_delay", + "backoff_coefficient", + "max_delay", + "max_attempts", + "retry_on", + "jitter", + "_linear", + ) + + def __init__( + self, + *, + initial_delay: timedelta = timedelta(seconds=1), + backoff_coefficient: float = 2.0, + max_delay: timedelta = timedelta(seconds=60), + max_attempts: int = 3, + retry_on: tuple[type[Exception], ...] | None = None, + jitter: bool = True, + _linear: bool = False, + ) -> None: + if initial_delay.total_seconds() < 0: + raise ValueError(f"initial_delay must be >= 0, got {initial_delay}") + if max_attempts < 1 and not ( + max_attempts == 1 and initial_delay == timedelta(0) + ): + pass # allow no_retry preset + if backoff_coefficient < 1.0: + raise ValueError( + f"backoff_coefficient must be >= 1.0, got {backoff_coefficient}" + ) + if max_delay < initial_delay: + raise ValueError( + f"max_delay ({max_delay}) must be >= initial_delay ({initial_delay})" + ) + if max_attempts < 1: + raise ValueError(f"max_attempts must be >= 1, got {max_attempts}") + if retry_on is not None: + for exc_type in retry_on: + if not isinstance(exc_type, type) or not issubclass( + exc_type, Exception + ): + raise TypeError( + f"retry_on entries must be Exception subclasses, got {exc_type!r}" + ) + + self.initial_delay = initial_delay + self.backoff_coefficient = backoff_coefficient + self.max_delay = max_delay + self.max_attempts = max_attempts + self.retry_on = retry_on + self.jitter = jitter + self._linear = _linear + + def compute_delay(self, attempt: int) -> float: + """Return the delay in seconds for the given attempt (0-indexed). + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :return: Delay in seconds before the next attempt. 
+ :rtype: float + """ + base_seconds = self.initial_delay.total_seconds() + if self._linear: + # Linear: delay = initial_delay * (attempt + 1) + raw = base_seconds * (attempt + 1) + else: + # Exponential: delay = initial_delay * coefficient ^ attempt + raw = base_seconds * (self.backoff_coefficient**attempt) + + capped = min(raw, self.max_delay.total_seconds()) + + if self.jitter: + capped *= random.uniform(0.75, 1.25) + + return max(0.0, capped) + + def should_retry(self, attempt: int, error: Exception) -> bool: + """Return whether the task should be retried. + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :param error: The exception that was raised. + :type error: Exception + :return: ``True`` if the task should be retried. + :rtype: bool + """ + # attempt is 0-indexed; max_attempts includes the first try + if attempt >= self.max_attempts - 1: + return False + if self.retry_on is None: + return True + return isinstance(error, self.retry_on) + + def __repr__(self) -> str: + return ( + f"RetryPolicy(initial_delay={self.initial_delay!r}, " + f"backoff_coefficient={self.backoff_coefficient}, " + f"max_delay={self.max_delay!r}, " + f"max_attempts={self.max_attempts}, " + f"retry_on={self.retry_on!r}, " + f"jitter={self.jitter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RetryPolicy): + return NotImplemented + return ( + self.initial_delay == other.initial_delay + and self.backoff_coefficient == other.backoff_coefficient + and self.max_delay == other.max_delay + and self.max_attempts == other.max_attempts + and self.retry_on == other.retry_on + and self.jitter == other.jitter + and self._linear == other._linear + ) + + # ------------------------------------------------------------------ + # Convenience presets + # ------------------------------------------------------------------ + + @classmethod + def exponential_backoff( + cls, + *, + max_attempts: int = 3, + initial_delay: timedelta = 
timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + jitter: bool = True, + ) -> RetryPolicy: + """Exponential backoff — the most common pattern. + + Delay doubles per attempt: 1 s → 2 s → 4 s → … capped at *max_delay*. + + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :keyword initial_delay: Base delay. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. + :paramtype max_delay: ~datetime.timedelta + :keyword jitter: Add ±25% randomization. + :paramtype jitter: bool + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=2.0, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=jitter, + ) + + @classmethod + def fixed_delay( + cls, + *, + delay: timedelta = timedelta(seconds=5), + max_attempts: int = 3, + ) -> RetryPolicy: + """Fixed delay — constant interval between retries. + + Useful for rate-limited APIs where you want to wait a fixed + amount of time between each attempt. + + :keyword delay: Constant delay between retries. + :paramtype delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=delay, + backoff_coefficient=1.0, + max_delay=delay, + max_attempts=max_attempts, + jitter=False, + ) + + @classmethod + def linear_backoff( + cls, + *, + initial_delay: timedelta = timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + max_attempts: int = 5, + ) -> RetryPolicy: + """Linear backoff — delay grows additively. + + Delay is ``initial_delay * (attempt + 1)``: 1 s → 2 s → 3 s → … + + :keyword initial_delay: Base delay unit. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. 
+ :paramtype max_delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=1.0, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=False, + _linear=True, + ) + + @classmethod + def no_retry(cls) -> RetryPolicy: + """No retry — the function runs once and fails on exception. + + Equivalent to not setting a retry policy at all. + + :return: A ``RetryPolicy`` that never retries. + :rtype: RetryPolicy + """ + return cls( + initial_delay=timedelta(0), + backoff_coefficient=1.0, + max_delay=timedelta(0), + max_attempts=1, + jitter=False, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py new file mode 100644 index 000000000000..267f8a06f400 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py @@ -0,0 +1,242 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskRun handle and Suspended sentinel for the durable task subsystem.""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Generic, TypeVar + +from ._exceptions import ( + TaskNotFound, +) +from ._metadata import TaskMetadata +from ._models import TaskInfo, TaskStatus +from ._provider import DurableTaskProvider +from ._result import TaskResult +from ._stream import StreamHandler + +Output = TypeVar("Output") + + +class Suspended(Generic[Output]): + """Sentinel return value from :meth:`TaskContext.suspend`. + + Must be used as ``return await ctx.suspend(...)``. 
Output = TypeVar("Output")


class Suspended(Generic[Output]):
    """Marker object returned by :meth:`TaskContext.suspend`.

    Always produced via ``return await ctx.suspend(...)``; when the task
    function returns one of these, the framework transitions the task
    into the suspended state.

    :param reason: Human-readable suspension reason.
    :type reason: str | None
    :param output: Optional snapshot for observers.
    :type output: Output | None
    """

    __slots__ = ("reason", "output")

    def __init__(
        self,
        reason: str | None = None,
        output: Output | None = None,
    ) -> None:
        self.reason = reason
        self.output = output

    def __repr__(self) -> str:
        return f"Suspended(reason={self.reason!r})"


class TaskRun(Generic[Output]):  # pylint: disable=too-many-instance-attributes
    """Handle to a running or completed durable task.

    Returned by :meth:`DurableTask.start`. Lets external code observe
    progress, await the result, stream output, and control the task
    lifecycle (cancel / terminate / delete / refresh).

    :param task_id: The task identifier.
    :type task_id: str
    :param provider: Storage provider for refresh/delete operations.
    :type provider: DurableTaskProvider
    :param result_future: Future that resolves with the task output.
    :type result_future: asyncio.Future[Output]
    :param metadata: The task's metadata instance.
    :type metadata: TaskMetadata
    :param cancel_event: Event to signal cancellation.
    :type cancel_event: asyncio.Event
    :param status: Initial task status.
    :type status: TaskStatus
    """

    __slots__ = (
        "task_id",
        "_provider",
        "_result_future",
        "_metadata",
        "_cancel_event",
        "_terminate_event",
        "_terminate_reason_ref",
        "_status",
        "_stream_handler",
        "_execution_task",
        "_lease_expiry_count",
    )

    def __init__(
        self,
        task_id: str,
        *,
        provider: DurableTaskProvider,
        result_future: asyncio.Future[TaskResult[Output]],
        metadata: TaskMetadata | None = None,
        cancel_event: asyncio.Event | None = None,
        status: TaskStatus = "in_progress",
        stream_handler: StreamHandler | None = None,
        terminate_event: asyncio.Event | None = None,
        execution_task: asyncio.Task[Any] | None = None,
        terminate_reason_ref: list[str | None] | None = None,
        lease_expiry_count: int = 0,
    ) -> None:
        self.task_id = task_id
        self._provider = provider
        self._result_future = result_future
        # Fall back to fresh defaults when the caller supplied nothing.
        self._metadata = metadata if metadata else TaskMetadata()
        self._cancel_event = cancel_event if cancel_event else asyncio.Event()
        self._terminate_event = terminate_event if terminate_event else asyncio.Event()
        if terminate_reason_ref is None:
            terminate_reason_ref = [None]
        self._terminate_reason_ref: list[str | None] = terminate_reason_ref
        self._status = status
        self._stream_handler: StreamHandler | None = stream_handler
        self._execution_task: asyncio.Task[Any] | None = execution_task
        self._lease_expiry_count = lease_expiry_count

    @property
    def status(self) -> TaskStatus:
        """Current task status (may be stale — call :meth:`refresh` to update).

        :return: The task status.
        :rtype: TaskStatus
        """
        return self._status

    @property
    def metadata(self) -> TaskMetadata:
        """The task's metadata.

        For in-process handles this is the live metadata reference; for
        remote observation call :meth:`refresh` first.

        :return: The task metadata instance.
        :rtype: TaskMetadata
        """
        return self._metadata

    @property
    def lease_expiry_count(self) -> int:
        """Number of times the lease expired and ownership changed.

        Useful for dashboards to detect ownership churn. Call
        :meth:`refresh` to get the latest value.

        :return: The lease expiry count.
        :rtype: int
        """
        return self._lease_expiry_count

    async def result(self) -> TaskResult[Output]:
        """Await task completion and return the result.

        Returns a :class:`TaskResult` that wraps both completion and
        suspension outcomes. Failures, cancellation, and termination are
        still raised as exceptions.

        :return: The task result wrapper.
        :rtype: TaskResult[Output]
        :raises TaskFailed: If the function raised an exception.
        :raises TaskCancelled: If the task was cancelled.
        :raises TaskTerminated: If the task was terminated.
        :raises TaskNotFound: If the task was deleted externally.
        """
        return await self._result_future

    async def cancel(self) -> None:
        """Signal cancellation to the running task.

        Sets the ``cancel`` event on the task context; the task function
        is expected to observe ``ctx.cancel.is_set()`` and exit cleanly.
        """
        self._cancel_event.set()

    async def terminate(self, *, reason: str | None = None) -> None:
        """Forcefully terminate the task.

        Unlike :meth:`cancel`, terminated tasks go through the failure path
        and do NOT stay ``in_progress`` for recovery.

        :keyword reason: Optional human-readable termination reason.
        :paramtype reason: str | None
        """
        self._terminate_reason_ref[0] = reason
        self._terminate_event.set()
        self._cancel_event.set()
        running = self._execution_task
        if running is not None and not running.done():
            running.cancel()

    async def delete(self) -> None:
        """Delete the task record from the store.

        :raises TaskNotFound: If the task does not exist.
        """
        try:
            await self._provider.delete(self.task_id, force=True)
        except Exception as exc:
            if "not found" not in str(exc).lower():
                raise
            raise TaskNotFound(self.task_id) from exc

    async def refresh(self) -> None:
        """Re-fetch task state from the store.

        Updates :attr:`status` and :attr:`metadata` from the current
        task record.
        """
        snapshot: TaskInfo | None = await self._provider.get(self.task_id)
        if snapshot is None:
            raise TaskNotFound(self.task_id)
        self._status = snapshot.status
        # Pick up any lease-ownership churn.
        if snapshot.lease is not None:
            self._lease_expiry_count = snapshot.lease.expiry_count
        # Merge the stored metadata bucket into the live instance.
        payload = snapshot.payload
        if payload and "metadata" in payload:
            stored: dict[str, Any] = payload["metadata"]
            for key, value in stored.items():
                self._metadata.set(key, value)

    def __aiter__(self) -> TaskRun[Output]:
        """Return self as an async iterator over streamed items.

        Usage::

            async for chunk in task_run:
                print(chunk)

        :return: Self.
        :rtype: TaskRun
        """
        return self

    async def __anext__(self) -> Any:
        """Yield the next streamed item, or raise ``StopAsyncIteration``.

        A handle without a stream handler ends iteration immediately
        (the task does not stream). Once the stream is closed,
        ``handler.get()`` raises ``StopAsyncIteration`` which propagates
        naturally.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When streaming ends.
        """
        handler = self._stream_handler
        if handler is None:
            raise StopAsyncIteration
        return await handler.get()
@runtime_checkable
class StreamHandler(Protocol):
    """Protocol for pluggable stream transports.

    Implementations control how stream items move between the task
    function (producer) and any number of consumers. The framework
    calls :meth:`put` from ``ctx.stream()``, consumers call
    :meth:`get` via ``async for chunk in run``, and the framework
    calls :meth:`close` when the task finishes.

    All three methods are required.
    """

    async def put(self, item: Any) -> None:
        """Accept a stream item from the task function.

        :param item: The value to stream.
        :type item: Any
        """
        ...

    async def get(self) -> Any:
        """Return the next stream item, blocking until one is available.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When the stream has been closed.
        """
        ...

    async def close(self) -> None:
        """Signal end-of-stream.

        After this call, :meth:`get` must raise
        :class:`StopAsyncIteration`. Called by the framework when the
        task finishes — both on success and on failure.
        """
        ...


class QueueStreamHandler:
    """Default stream handler wrapping :class:`asyncio.Queue`.

    Single-consumer, in-memory, unbounded. Items queued before
    :meth:`close` are still delivered; once the stream is closed, every
    subsequent :meth:`get` raises :class:`StopAsyncIteration`.

    .. versionadded:: 2.1.0
    """

    # Internal end-of-stream marker placed in the queue by close().
    _SENTINEL: object = object()

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue()

    async def put(self, item: Any) -> None:
        """Enqueue a stream item.

        :param item: The value to stream.
        :type item: Any
        """
        await self._queue.put(item)

    async def get(self) -> Any:
        """Dequeue the next stream item.

        Blocks until an item is available. Raises
        :class:`StopAsyncIteration` when the stream has been closed.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When the stream has been closed.
        """
        item = await self._queue.get()
        if item is self._SENTINEL:
            # Re-enqueue the sentinel so EVERY subsequent get() also raises
            # StopAsyncIteration instead of blocking forever. A one-shot
            # sentinel would terminate only the first call after close().
            await self._queue.put(self._SENTINEL)
            raise StopAsyncIteration
        return item

    async def close(self) -> None:
        """Signal end-of-stream by placing the sentinel in the queue.

        Subsequent :meth:`get` calls will raise
        :class:`StopAsyncIteration`.
        """
        await self._queue.put(self._SENTINEL)
+ +--- + +## Table of Contents + +- [Overview](#overview) + - [Why This Exists](#why-this-exists) + - [What You Get](#what-you-get) +- [Getting Started](#getting-started) +- [Lifecycle Automation](#lifecycle-automation) + - [State Diagram](#state-diagram) + - [Entry Mode Decision Table](#entry-mode-decision-table) + - [.run() vs .start() vs .get() vs .list()](#run-vs-start-vs-get-vs-list) +- [TaskContext](#taskcontext) + - [Properties Reference](#properties-reference) + - [Branching on Entry Mode](#branching-on-entry-mode) +- [Suspend & Resume](#suspend--resume) + - [Multi-Turn Conversations](#multi-turn-conversations) +- [Steering](#steering) + - [What Steering Solves](#what-steering-solves) + - [Generation Model](#generation-model) + - [Enabling Steering](#enabling-steering) + - [The Three-Phase Cancel Pattern](#the-three-phase-cancel-pattern) + - [Steering Flow Diagram](#steering-flow-diagram) + - [What Happens to Each Generation](#what-happens-to-each-generation) + - [Rapid-Fire Steering](#rapid-fire-steering) + - [Preserving Fidelity with External SDKs](#preserving-fidelity-with-external-sdks) + - [Steering Recovery](#steering-recovery) + - [Complete Steering Example](#complete-steering-example) +- [Streaming](#streaming) + - [Custom Stream Handlers](#custom-stream-handlers) +- [Persistence](#persistence) + - [Responsibility Matrix](#responsibility-matrix) + - [The Durable Boundary Rule](#the-durable-boundary-rule) +- [The Invocation Store Pattern](#the-invocation-store-pattern) +- [RetryPolicy](#retrypolicy) +- [Decorator Options](#decorator-options) +- [Error Handling](#error-handling) +- [Best Practices](#best-practices) +- [Common Mistakes](#common-mistakes) + +--- + +## Overview + +### Why This Exists + +Azure AI Foundry Hosted Agents run your code in platform-managed containers. +Those containers can be killed at any time — OOM kills, node preemptions, +rolling deployments, or unexpected crashes. 
Without durability, any in-flight +work is lost and the agent starts from scratch. + +Agent frameworks fall into two camps: + +| Category | Examples | What they need | +|----------|----------|----------------| +| **Externally stateful** — the framework owns durability | Temporal, Durable Functions, Orleans | Platform visibility: lifecycle tracking, lease-based liveness, status reporting on top of the framework's own durability | +| **Locally stateful** — the container holds state | LangGraph (SQLite checkpointer), Claude SDK tool loops, hand-written agents | A crash-safe entry point: lease-based liveness so the platform knows when to restart, plus run / resume / progress / suspend primitives the developer would otherwise hand-roll | + +`@durable_task` serves both camps. It is **not** a replacement for Temporal or +Durable Functions — it is the thin durable wrapper around the boundary between +the platform and your code. It does not make your function deterministic or +replayable. It turns `run(input) → output` into a unit of work that survives +a container crash, a deployment, or an idle-deactivation — with hooks for +progress, suspension, cancellation, and steering that compose with whatever +framework you use underneath. + +### What You Get + +Decorate your async function, and the framework guarantees it runs to completion +— even if the container restarts mid-execution. On recovery, your function is +re-invoked with the same input and last-saved metadata, so it can pick up where +it left off. + +**Your contract:** + +- Write a normal `async` function that takes a `TaskContext` +- Use `ctx.metadata` to record lightweight progress (e.g. 
current phase, step count) +- Check `ctx.entry_mode` if you need to distinguish fresh runs from recoveries +- Return a result, or `await ctx.suspend()` for multi-turn patterns + +**What you get:** + +- Automatic crash recovery — your function re-runs without any caller intervention +- Input and metadata persistence across restarts +- Retry with configurable backoff on failures +- Cooperative cancellation and timeout +- Streaming incremental output to observers +- Suspend/resume for multi-turn conversational agents +- Steering — submit a new input to a running task without cancel/wait/restart + +### What durable tasks are NOT + +- **Not a checkpoint/replay engine.** This is not Temporal or Durable Functions. + Your function is re-executed from the top on recovery, not replayed from a + deterministic log. If your function calls an LLM twice, it will call it again + on recovery. +- **Not a result store.** Task output and metadata exist only while the task is + alive. Once the task is deleted, they are gone. If you need results to outlive + the task, persist them in your own store (database, blob storage, etc.). +- **Not a stream log.** Streamed chunks are relayed to live observers in real + time but are not recorded. If a consumer connects after streaming ends, + the chunks are gone. +- **Not application-level persistence.** The framework manages *task lifecycle* + state (status, input, metadata, lease). Your application data — conversation + history, invocation results, user-facing state — is your responsibility. + See [Persistence](#persistence). +- **Not unbounded storage.** `ctx.metadata` is for small progress signals + (current phase, retry count, step index), not for accumulating large data. + The task payload has a 1 MB cap. Write large or growing data to your own store. 
+ +--- + +## Getting Started + +A minimal durable task in 15 lines: + +```python +from azure.ai.agentserver.core.durable import durable_task, TaskContext + +@durable_task +async def greet(ctx: TaskContext[str]) -> str: + """A simple durable task that greets the user.""" + name = ctx.input + return f"Hello, {name}!" + +# Run it — lifecycle-aware: creates if new, recovers if stale +result = await greet.run(task_id="greet-alice", input="Alice") +print(result.output) # "Hello, Alice!" +``` + +That's it. The decorator transforms your function into a `DurableTask` with `.run()`, +`.start()`, `.get()`, and `.list()` methods. The function itself takes a single `TaskContext` +parameter. + +If the container crashes mid-execution, the framework automatically recovers the +task on restart — before any HTTP handlers go live. Your function is re-invoked +with `ctx.entry_mode = "recovered"` and the same input. No caller action is needed. + +If a caller calls `.run()` with a `task_id` that is already in progress, +the framework raises `TaskConflictError` — it does not create a duplicate. + +--- + +## Lifecycle Automation + +Every call to `.run()` or `.start()` follows the same state machine. You never +manually check task state or call resume — the framework does it for you. + +### State Diagram + +What the framework does when you call `.run()` or `.start()`: + +``` + .run() / .start() + │ + ▼ + ┌───── task exists? ─────┐ + │ │ + No Yes + │ │ + ▼ ▼ + ┌──────────┐ ┌──── status? ──────────────────────────┐ + │ Create │ │ │ │ │ + │ & Start │ pending suspended in_progress completed + └──────────┘ │ │ │ │ + │ ▼ ▼ ▼ ▼ + fresh fresh resumed stale? ephemeral? + │ │ + ┌─────┴─────┐ ┌───┴───┐ + Yes No Yes No + │ │ │ │ + ▼ ▼ ▼ ▼ + recovered steerable? 
fresh¹ TaskConflict + │ Error + ┌─────┴─────┐ + Yes No + │ │ + ▼ ▼ + Queue input TaskConflict + + cancel → Error + drain resumes + ("resumed", + was_steered) +``` + +¹ Ephemeral completed tasks were auto-deleted on completion, so they appear as +"no task exists" and a fresh task is created transparently. + +### Entry Mode Decision Table + +| Current State | Action | `ctx.entry_mode` | `ctx.was_steered` | +|---|---|---|---| +| No task exists | Create and start | `"fresh"` | `False` | +| `pending` | Start | `"fresh"` | `False` | +| `suspended` | Resume with new input | `"resumed"` | `False` | +| `in_progress` (stale) | Recover | `"recovered"` | `True` if steering state exists ¹ | +| `in_progress` (not stale, **steerable**) | Queue input, signal cancel → drain resumes | `"resumed"` | `True` | +| `in_progress` (not stale, not steerable) | **Raises `TaskConflictError`** | — | — | +| `completed` (ephemeral) | Task was auto-deleted → create fresh | `"fresh"` | `False` | +| `completed` (non-ephemeral) | **Raises `TaskConflictError`** | — | — | + +¹ When recovering a steerable task that crashed mid-drain, the initial recovery +enters with `"recovered"` and `was_steered=True`. The framework then drains +the pending queue, re-entering the function with `entry_mode="resumed"` and +`was_steered=True` for each queued input. See [Steering Recovery](#steering-recovery). + +A task is considered **stale** when its last update is older than `stale_timeout` +(default: 300 seconds). This means the previous execution likely crashed. + +### .run() vs .start() vs .get() vs .list() + +| Method | Blocks? 
| Returns | Use When | +|--------|---------|---------|----------| +| `.run()` | Yes — awaits completion | `TaskResult[Output]` | You want the result inline | +| `.start()` | No — returns immediately | `TaskRun[Output]` | You want a handle for polling/streaming | +| `.get()` | No — reads from store | `TaskInfo \| None` | You want to query task state without executing | +| `.list()` | No — reads from store | `list[TaskInfo]` | You want all tasks for this function | + +`.run()` and `.start()` follow the same lifecycle rules. The only difference is +whether you wait for the result or get a handle back. + +```python +# .start() returns immediately with a handle +task_run = await greet.start(task_id="greet-bob", input="Bob") + +# Use the handle to await the result later +result = await task_run.result() + +# Or stream incremental output (if the task uses ctx.stream()) +async for chunk in task_run: + print(chunk) +``` + +`.get()` does not execute the task. It reads whatever is persisted: + +```python +info = await greet.get("greet-bob") +if info is not None: + print(info.status) # "completed", "suspended", "in_progress", etc. + print(info.payload) # Contains input, metadata, output buckets +``` + +`.list()` returns all tasks created by this decorated function. It is automatically +scoped — each function only sees its own tasks: + +```python +# List all suspended tasks for this function +suspended = await greet.list(status="suspended") +for t in suspended: + print(t.id, t.status, t.created_at) + +# List all tasks (any status) +all_tasks = await greet.list() +``` + +> `.list()` is automatically scoped — each decorated function only sees tasks it +> created. The `name` option on `@durable_task` is the key that determines which +> tasks belong to this function. + +--- + +## TaskContext + +Every durable task function receives exactly one parameter: a `TaskContext[Input]` +where `Input` is your typed input type. 
+ +### Properties Reference + +| Property | Type | Description | +|----------|------|-------------| +| `ctx.input` | `Input` | The typed input value passed to `.run()` / `.start()` | +| `ctx.entry_mode` | `EntryMode` | Why the function was entered: `"fresh"`, `"resumed"`, or `"recovered"` | +| `ctx.task_id` | `str` | The task's unique identifier | +| `ctx.session_id` | `str` | Session scope identifier | +| `ctx.metadata` | `TaskMetadata` | Mutable progress metadata (persisted automatically) | +| `ctx.agent_name` | `str` | Agent name from platform configuration | +| `ctx.lease_generation` | `int` | Lease generation counter (increments on recovery) | +| `ctx.cancel` | `asyncio.Event` | Set when cancellation is requested (including steering cancel) | +| `ctx.shutdown` | `asyncio.Event` | Set when the container is shutting down | +| `ctx.run_attempt` | `int` | Framework retry attempt counter (0-indexed) | +| `ctx.title` | `str` | Human-readable task title | +| `ctx.tags` | `dict[str, str]` | Merged decorator + call-site tags | +| `ctx.description` | `str \| None` | Task description (from decorator or call-site) | +| `ctx.generation` | `int` | Steering generation counter (0 for first run, increments on each steer) | +| `ctx.previous_input` | `Input \| None` | The superseded generation's input (set when steering state is present) | +| `ctx.pending_inputs` | `Sequence[Any]` | Read-only snapshot of queued steering inputs at function entry | +| `ctx.was_steered` | `bool` | `True` when this entry involves steering — the function is being re-entered with a new input from the steering queue. 
Always check this to detect steering; `entry_mode` will be `"resumed"` for normal steering drains or `"recovered"` for crash recovery of a mid-drain | + +### Branching on Entry Mode + +Use `ctx.entry_mode` to handle different execution scenarios: + +```python +from azure.ai.agentserver.core.durable import durable_task, TaskContext, EntryMode + +@durable_task(name="process_order") +async def process_order(ctx: TaskContext[dict]) -> dict: + order = ctx.input + + if ctx.entry_mode == "fresh": + # First time — validate and begin processing + ctx.metadata["step"] = "validating" + + elif ctx.entry_mode == "recovered": + # Crashed mid-execution — check what was already done + step = ctx.metadata.get("step", "validating") + if step == "charged": + # Payment already taken — skip to fulfillment + return await fulfill(order) + + elif ctx.entry_mode == "resumed": + # Resumed after suspension — ctx.input has new data + # For steerable tasks, check ctx.was_steered for steering context + if ctx.was_steered: + # This resume was triggered by steering — ctx.previous_input + # has the superseded generation's input + pass + + # ... do work ... + ctx.metadata["step"] = "charged" + return {"status": "completed", "order_id": order["id"]} +``` + +**`TaskMetadata`** is automatically persisted to the task store. 
Use it to track +progress so that recovered tasks can skip completed steps: + +```python +# Dict-style access (recommended) +ctx.metadata["progress"] = 50 # set a value +ctx.metadata["phase"] = "summarizing" # set another +progress = ctx.metadata["progress"] # read (raises KeyError if missing) +if "phase" in ctx.metadata: # containment check + print(f"Phase: {ctx.metadata['phase']}") +for key in ctx.metadata: # iterate keys + print(f"{key}: {ctx.metadata[key]}") + +# Convenience methods for special operations +ctx.metadata.increment("items_processed") # atomic increment +ctx.metadata.append("logs", "step 3 done") # append to list +progress = ctx.metadata.get("progress") # read with default (no KeyError) +snapshot = ctx.metadata.to_dict() # full snapshot copy +``` + +All mutations (including `[]` assignment and `del`) are automatically tracked +and flushed to the task store on a 5-second debounce interval. + +--- + +## Suspend & Resume + +Use `ctx.suspend()` to pause execution and release the task lease. The task +transitions to `suspended` status. A subsequent `.run()` or `.start()` call +resumes it with `entry_mode="resumed"` and new input. + +> **Critical**: Always use `return await ctx.suspend(...)`. Forgetting `return` +> or `await` silently breaks the suspension mechanism. + +```python +@durable_task(name="approval_flow") +async def approval_flow(ctx: TaskContext[dict]) -> dict: + request = ctx.input + + if ctx.entry_mode == "fresh": + # Submit for approval, then suspend + return await ctx.suspend(output={"status": "awaiting_approval", "request": request}) + + elif ctx.entry_mode == "resumed": + # Manager responded — ctx.input has the approval decision + decision = ctx.input + if decision.get("approved"): + return {"status": "approved", "approved_by": decision["manager"]} + return {"status": "rejected", "reason": decision.get("reason")} +``` + +The `output` parameter on `ctx.suspend()` is optional. 
It provides a snapshot +that observers can read while the task is suspended (via `.get()` or the +`TaskResult`'s `.output` attribute). + +### Multi-Turn Conversations + +The suspend/resume pattern is ideal for multi-turn agents where each turn is +one user ↔ agent interaction: + +```python +@durable_task(name="chat_session") +async def chat_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + + if ctx.entry_mode == "fresh": + history = [] + elif ctx.entry_mode == "resumed": + history = ctx.metadata.get("history", []) + + # Generate response (your LLM call, graph execution, etc.) + reply = await generate_reply(message, history) + + # Track conversation history in metadata + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history + + # Suspend — waiting for the next user message + return await ctx.suspend(output={"reply": reply}) +``` + +Each call to `.run(task_id=session_id, input={"message": "..."})` or +`.start(task_id=session_id, input={"message": "..."})` resumes the +same task with the new message. The framework handles the transition +from `suspended` to `in_progress` automatically. + +--- + +## Steering + +Steering extends the suspend/resume pattern for scenarios where a user sends +a new message while the agent is still processing the previous one. Without +steering, a `.start()` on an `in_progress` task raises `TaskConflictError` — +the caller must cancel, wait for the function to exit, and then start again. +With steering, the framework handles this automatically. + +### What Steering Solves + +Consider a chat UI. The user sends "Tell me about Python", then immediately +types "Actually, tell me about Rust" before the first reply finishes. Without +steering: + +1. The caller sees `TaskConflictError` on the second `.start()` +2. The caller must call `run.cancel()` and wait for the function to exit +3. 
Then call `.start()` again with the new input +4. Race conditions abound — what if another message arrives during step 2? + +With steering, the caller just calls `.start()` again. The framework queues +the new input, signals the running function to cancel, and re-enters the +function with the new input once the current generation exits. No manual +cancel/wait/restart dance. + +### Generation Model + +Each time the framework enters the durable function, it increments a +**generation** counter. This gives each invocation a stable identity: + +``` +Generation 0: fresh start with input A → entry_mode="fresh", was_steered=False +Generation 1: steered — input B replaced input A → entry_mode="resumed", was_steered=True +Generation 2: steered — input C (short-circuited) → entry_mode="resumed", was_steered=True +Generation 3: normal resume — user sends input D → entry_mode="resumed", was_steered=False +``` + +Generations are persisted in the task payload. Each `TaskRun` returned to a +caller is bound to a specific generation, so there is no ambiguity about which +invocation a caller is observing. + +### Enabling Steering + +Add `steerable=True` to the decorator: + +```python +@durable_task(name="chat_session", steerable=True) +async def chat_session(ctx: TaskContext[dict]) -> dict: + ... +``` + +| Decorator Option | Type | Default | Description | +|------------------|------|---------|-------------| +| `steerable` | `bool` | `False` | Enable steering support | +| `max_pending` | `int` | `10` | Maximum queued inputs. Excess raises `SteeringQueueFull` | + +When `steerable=False` (default), behavior is unchanged — `.start()` on an +`in_progress` task raises `TaskConflictError`. + +### The Three-Phase Cancel Pattern + +When a steering input arrives, the framework sets `ctx.cancel` on the running +function. But cancel can arrive at three different points. 
Your function must +handle all three: + +```python +@durable_task(name="agent_session", steerable=True) +async def agent_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────── + # Cancel was already set before the function body runs. + # This happens in rapid-fire scenarios where multiple inputs + # queue up faster than the function can start. + if ctx.cancel.is_set(): + invocation_store.save(invocation_id, { + "status": "cancelled", "reason": "steered", + }) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Mid-stream cancel ────────────────────────────── + # Check cancel between each chunk of work. This is where most + # steering cancels land in practice. + reply = "" + async for token in call_llm_streaming(message): + reply += token + if ctx.cancel.is_set(): + break # Stop producing — save what we have + + # ── Phase 3: Post-completion cancel ───────────────────────── + # Cancel arrived after the LLM finished but before we returned. + # The reply is complete, but it will be superseded by the next + # generation. Save the result so it is not lost. + was_steered = ctx.cancel.is_set() + + result = {"reply": reply, "partial": was_steered} + if was_steered: + invocation_store.save(invocation_id, { + "status": "superseded", "output": result, + }) + return await ctx.suspend(reason="steered") + + invocation_store.save(invocation_id, { + "status": "completed", "output": result, + }) + return await ctx.suspend(reason="awaiting_user_input", output=result) +``` + +**Key rule**: Always save your work before returning, even when cancelled. +The user's message was received and should be preserved (appended to +conversation history, written to your store, etc.). Only the *reply generation* +is interrupted — not the input recording. 
+ +> **⚠️ Steerable tasks MUST suspend when steered — never return normally or +> raise.** When `ctx.cancel.is_set()` due to steering, always exit with +> `return await ctx.suspend(reason="steered")`. This keeps the task alive so +> the framework can drain the pending queue and resume with the next input. +> +> - **Normal return** → task completes → next `.start()` creates a fresh task +> → conversation continuity broken +> - **Raise exception** → task enters failure/retry path → wrong lifecycle +> - **Suspend** ✅ → task stays alive → framework resumes with next queued input + +### Steering Flow Diagram + +``` + Caller A: .start(input=A) Caller B: .start(input=B) + │ │ + ▼ │ + ┌──────────────┐ │ + │ Gen 0: fresh │ ◄── function starts │ + │ processing │ │ + │ ... │ ◄── ctx.cancel.set() ◄──────┤ input B queued + │ (checks │ │ + │ cancel) │ │ + │ break │ │ + └──────┬───────┘ │ + │ returns via suspend(reason="steered") + ▼ │ + ┌──────────────────┐ │ + │ Framework drains │ │ + │ pending queue │ │ + │ (pops input B) │ │ + └──────┬───────────┘ │ + │ │ + ▼ │ + ┌──────────────┐ │ + │ Gen 1:resumed│ ◄── function re-entered │ + │ was_steered │ ctx.previous_input = A │ + │ ctx.input = B│ │ + │ processing │ │ + │ ... 
│ │ + │ (completes) │ │ + └──────┬───────┘ │ + │ returns via suspend() │ + ▼ │ + Caller B's TaskRun Caller A's TaskRun + resolves with result resolved earlier with + "superseded" status +``` + +### What Happens to Each Generation + +| Scenario | Status Written to Store | `TaskRun` Resolution | +|----------|------------------------|----------------------| +| Pre-entry cancel (Phase 1) | `"cancelled"` — input preserved, no reply attempted | Superseded | +| Mid-stream cancel (Phase 2) | `"superseded"` — partial reply saved | Superseded | +| Post-completion cancel (Phase 3) | `"superseded"` — full reply saved | Superseded | +| Normal completion | `"completed"` — full reply | Completed | + +Superseded `TaskRun` handles resolve when the framework drains the queue and +starts the next generation. Callers polling these handles see the result of +their specific generation. + +### Rapid-Fire Steering + +When multiple inputs arrive in quick succession: + +``` +User types: "What is Python?" → "Actually, Rust" → "No wait, Go" +``` + +The framework queues all of them. Only the last one (Go) runs to completion: + +``` +Gen 0: "What is Python?" → cancel pre-set → Phase 1 short-circuit +Gen 1: "Actually, Rust" → cancel pre-set → Phase 1 short-circuit +Gen 2: "No wait, Go" → queue empty → full execution +``` + +**Important**: Each short-circuited generation still enters the function. +This is by design — it gives the developer a chance to: + +- Record the user's message in conversation history +- Write a `"cancelled"` status to the invocation store +- Perform any other bookkeeping + +The framework does NOT silently discard queued inputs. Every input gets a +function invocation, even if that invocation immediately short-circuits. + +### Preserving Fidelity with External SDKs + +When wrapping external LLM SDKs (Claude, Copilot, LangGraph), steering adds +a layer on top of the SDK's own interruption model. 
Be aware of how each SDK +handles cancellation: + +**Streaming SDKs (Claude, OpenAI)**: These use `async for token in stream`. +Breaking out of the loop is clean — the SDK handles connection cleanup. Check +`ctx.cancel.is_set()` between chunks: + +```python +async with client.messages.stream(...) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + break # SDK cleans up the stream +``` + +**Event-based SDKs (Copilot)**: These deliver results via callbacks. Use +`session.abort()` to stop event delivery, then let the handler drain: + +```python +session.on(handler) # Register callback +session.send(message) # Start generation (non-blocking) +# Wait for either completion or cancel: +done, _ = await asyncio.wait( + [completion_event.wait(), cancel_wait()], + return_when=asyncio.FIRST_COMPLETED, +) +if ctx.cancel.is_set(): + session.abort() # Stop further events +``` + +**Graph SDKs (LangGraph)**: These run a graph to completion. Use checkpoint +IDs to fork from a known state rather than replaying the full graph: + +```python +if ctx.was_steered and ctx.previous_input: + # Fork from the checkpoint before the superseded run + checkpoint_id = ctx.metadata.get("stable_checkpoint_id") + config = {"configurable": {"thread_id": ..., "checkpoint_id": checkpoint_id}} +``` + +### Steering Recovery + +If the container crashes while a steered task is processing: + +1. The task is `in_progress` with steering state in the payload +2. On container restart, the framework detects the stale task +3. If there are pending inputs in the queue, the framework recovers with + `entry_mode="recovered"` and `was_steered=True`, then drains the queue +4. If a drain was in progress when the crash occurred, the framework + resumes the drain from the persisted `active_input` +5. 
Each drained input re-enters the function with `entry_mode="resumed"` + and `was_steered=True` + +No data is lost — the pending queue and generation counter are persisted in +the task payload. + +> **How to detect steering**: Always use `ctx.was_steered` — never check +> `entry_mode` for steering. Steering re-entries arrive as `"resumed"` +> (because the task suspended and is being resumed with a new input). The +> `was_steered` flag tells you whether steering context (`previous_input`, +> `generation`, `pending_inputs`) is meaningful. + +### Complete Steering Example + +A full steerable chat session combining all patterns: + +```python +from azure.ai.agentserver.core.durable import TaskContext, durable_task +from my_app.store import FileStore + +invocation_store = FileStore("./invocations") +conversation_store = FileStore("./conversations") + + +@durable_task(name="chat_session", steerable=True) +async def chat_session(ctx: TaskContext[dict]) -> dict: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + + # Mark invocation as running (inside the durable boundary) + invocation_store.save(invocation_id, {"status": "running"}) + + # Load conversation history from external store (not task metadata) + history = conversation_store.load(session_id) or [] + history.append({"role": "user", "content": message}) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────── + if ctx.cancel.is_set(): + conversation_store.save(session_id, history) + invocation_store.save(invocation_id, { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + }) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Stream response, checking cancel ─────────────── + reply = "" + was_aborted = False + async for token in call_llm_streaming(message, history): + reply += token + if ctx.cancel.is_set(): + was_aborted = True + break + + # ── Phase 3: Save result ──────────────────────────────────── + 
if reply: + history.append({"role": "assistant", "content": reply}) + conversation_store.save(session_id, history) + + output = {"reply": reply, "partial": was_aborted} + + if was_aborted or ctx.cancel.is_set(): + invocation_store.save(invocation_id, { + "status": "superseded", "output": output, + }) + return await ctx.suspend(reason="steered") + + # Normal completion — suspend awaiting next user message + invocation_store.save(invocation_id, { + "status": "completed", "output": output, + }) + return await ctx.suspend(reason="awaiting_user_input", output=output) +``` + +The HTTP layer remains unchanged — callers call `POST /invoke` and poll +`GET /invocations/{id}`. The steering happens transparently inside the +durable task boundary. + +--- + +## Streaming + +Use `ctx.stream()` to emit incremental output and `async for` on the `TaskRun` +handle to consume it: + +```python +@durable_task(name="generate_report") +async def generate_report(ctx: TaskContext[str]) -> str: + topic = ctx.input + chunks = [] + async for token in call_llm_streaming(topic): + await ctx.stream(token) # Emit to observers + chunks.append(token) + return "".join(chunks) + +# Consumer side +task_run = await generate_report.start(task_id="report-1", input="Q3 Results") +async for chunk in task_run: + print(chunk, end="") + +# After streaming completes, get the full result +final = await task_run.result() +``` + +`ctx.stream()` accepts any Python object — the framework simply passes it +through the stream handler with no serialization or transformation. + +> **Important**: The default `QueueStreamHandler` holds items in an in-memory +> `asyncio.Queue`. They are **not persisted** and are **lost on crash**. If the +> process restarts mid-stream, the recovered task starts from scratch. If you +> need durable incremental output, implement a custom `StreamHandler` or write +> to your own store inside the task function alongside `ctx.stream()`. 
+ +### Custom Stream Handlers + +The streaming path is pluggable via the `StreamHandler` protocol. Implement +`put()`, `get()`, and `close()` to control how stream items are buffered, +transported, or persisted: + +```python +from azure.ai.agentserver.core.durable import StreamHandler + +class RedisStreamHandler: + """Example: fan-out streams via Redis.""" + + def __init__(self, redis_client, channel: str): + self._redis = redis_client + self._channel = channel + + async def put(self, item): + await self._redis.publish(self._channel, serialize(item)) + + async def get(self): + msg = await self._redis.subscribe_next(self._channel) + if msg is None: + raise StopAsyncIteration + return deserialize(msg) + + async def close(self): + await self._redis.publish(self._channel, "__CLOSED__") +``` + +Pass the handler at the call site — no decorator changes needed: + +```python +handler = RedisStreamHandler(redis, channel="report-1") +task_run = await generate_report.start( + task_id="report-1", + input="Q3 Results", + stream_handler=handler, +) +async for chunk in task_run: + print(chunk, end="") +``` + +**Key rules:** + +- `get()` must raise `StopAsyncIteration` after `close()` is called and all + buffered items are drained. This is Python's native iterator exhaustion signal. +- `close()` is always called by the framework when the task finishes — whether + it succeeds, fails, or is cancelled. +- If no `stream_handler` is provided, the framework uses `QueueStreamHandler` + (in-memory `asyncio.Queue`) as the default. +- The handler instance survives steering restarts — items streamed before and + after a steering cycle flow through the same handler. + +--- + +## Persistence + +Understanding what is and isn't persisted is the most important concept in this +guide. 
+ +### Responsibility Matrix + +| Data | Who Persists | Where | +|------|-------------|-------| +| Task status — `TaskStatus`: `"pending"`, `"in_progress"`, `"suspended"`, `"completed"` | **Framework** | Task store | +| Task input (the value passed to `.run()`/`.start()`) | **Framework** | Task store payload | +| Task metadata (`ctx.metadata`) | **Framework** | Task store payload | +| Task output (return value) | **Framework** | Task store payload | +| Task error (on failure) | **Framework** | Task store | +| Invocation results (what your API returns to callers) | **You** | Your store | +| Conversation history / checkpoints | **You** | Your store | +| Streaming items | **Nobody** (default) | In-memory; pluggable via `StreamHandler` | + +The task store powers lifecycle and recovery. **It is NOT your application +database.** You read from it via `.get()` to inspect task state, but you should +not depend on it as the persistence layer for your API responses. + +### The Durable Boundary Rule + +> **Everything that must survive a crash must happen inside the durable task function.** + +The durable task function is the crash-recovery boundary. If the process dies, +the framework automatically re-invokes your function on container restart. +Additionally, a subsequent `.run()` / `.start()` call with the same `task_id` +will detect the stale task and recover it. Any work done *outside* the function +(e.g., in an HTTP handler, in an `asyncio.create_task` callback) is lost. + +--- + +## The Invocation Store Pattern + +When building an HTTP API that fronts durable tasks (the 202 + poll pattern), +you need to persist invocation results so that clients can retrieve them. The +correct pattern: write results **inside** the durable task function. 
+ +```python +# Your persistence layer (file store, Redis, database — your choice) +invocation_store = FileStore("./invocations") + +@durable_task(name="agent_session") +async def agent_session(ctx: TaskContext[dict]) -> dict: + invocation_id = ctx.input["invocation_id"] + message = ctx.input["message"] + + # Mark invocation as running (inside the durable boundary) + invocation_store.save(invocation_id, {"status": "running"}) + + # Do work + reply = await generate_reply(message) + result = {"status": "completed", "reply": reply} + + # Persist result (inside the durable boundary) + invocation_store.save(invocation_id, result) + + # Suspend — waiting for next turn + return await ctx.suspend(output=result) +``` + +The HTTP layer is minimal: + +```python +# POST /invoke — start or resume the task +async def invoke(request): + invocation_id = generate_id() + try: + await agent_session.start( + task_id=session_id, + input={"invocation_id": invocation_id, "message": message}, + ) + except TaskConflictError: + return JSONResponse({"error": "Task already running"}, status_code=409) + return JSONResponse({"invocation_id": invocation_id}, status_code=202) + +# GET /invocations/{id} — read from YOUR store, not the task store +async def get_invocation(request): + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Not found"}, status_code=404) + return JSONResponse(result) +``` + +Why this works: if the process crashes after `invocation_store.save(..., "running")` +but before the result write, the framework recovers the task, re-enters the function +with `entry_mode="recovered"`, and the result eventually gets written. The client +polls `GET /invocations/{id}` until it sees `"completed"`. + +--- + +## RetryPolicy + +Configure automatic retries on failure. 
Three presets cover most use cases: + +```python +from datetime import timedelta +from azure.ai.agentserver.core.durable import durable_task, RetryPolicy, TaskContext + +# Exponential backoff (default: 1s → 2s → 4s, 3 attempts) +@durable_task(name="call_api", retry=RetryPolicy.exponential_backoff()) +async def call_api(ctx: TaskContext[str]) -> dict: ... + +# Fixed delay (5s between each retry, 3 attempts) +@durable_task(name="poll_status", retry=RetryPolicy.fixed_delay(delay=timedelta(seconds=5))) +async def poll_status(ctx: TaskContext[str]) -> dict: ... + +# Linear backoff (1s → 2s → 3s → 4s → 5s, 5 attempts) +@durable_task(name="batch_job", retry=RetryPolicy.linear_backoff(max_attempts=5)) +async def batch_job(ctx: TaskContext[str]) -> dict: ... + +# No retry — fail immediately +@durable_task(name="one_shot", retry=RetryPolicy.no_retry()) +async def one_shot(ctx: TaskContext[str]) -> dict: ... +``` + +Customize any preset: + +```python +RetryPolicy.exponential_backoff( + max_attempts=5, + initial_delay=timedelta(seconds=2), + max_delay=timedelta(seconds=120), + jitter=True, # ±25% randomization (default) +) +``` + +Retry can also be set per-call, overriding the decorator: + +```python +result = await call_api.run( + task_id="api-1", + input="https://example.com", + retry=RetryPolicy.fixed_delay(max_attempts=10), +) +``` + +--- + +## Decorator Options + +The `@durable_task` decorator accepts these options (defined in `DurableTaskOptions`): + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `name` | `str` | Function `__qualname__` | **Stable task identity anchor.** Used for crash recovery routing and source stamping. If you ever rename your function, existing in-flight tasks are still recovered correctly because the framework matches on this name, not the Python function name. **Always provide an explicit name for production tasks.** | +| `retry` | `RetryPolicy \| None` | `None` | Retry policy on failure. 
See [RetryPolicy](#retrypolicy). | +| `ephemeral` | `bool` | `True` | Auto-delete task record on completion. | +| `tags` | `dict[str, str] \| Callable[[Any, str], dict[str, str]] \| None` | `{}` | Default tags (static or callable factory receiving `(input, task_id)`). | +| `title` | `str \| Callable[[Any, str], str] \| None` | `None` | Human-readable title or title factory. Defaults to `"{name}:{task_id[:8]}"` when not provided. | +| `description` | `str \| Callable[[Any, str], str \| None] \| None` | `None` | Task description (static or callable factory receiving `(input, task_id)`). | +| `store_input` | `bool` | `True` | Whether to persist input on the task record. | +| `timeout` | `timedelta \| None` | `None` | Execution timeout. When elapsed, `ctx.cancel` is set cooperatively. If the function does not exit, the lease eventually expires and the task is recovered. | +| `steerable` | `bool` | `False` | Enable steering. When True, `.start()` on an `in_progress` task queues the input instead of raising `TaskConflictError`. See [Steering](#steering). | +| `max_pending` | `int` | `10` | Maximum queued steering inputs. Excess raises `SteeringQueueFull`. | + +> **Source provenance** is auto-stamped by the framework on every task. It records +> which function created the task and the SDK version. Source is not user-overridable; +> use `tags` for custom metadata. +> +> **Reserved tags**: The framework stamps internal tags (prefixed with `_durable_task_`) +> on every task for scoping and recovery. Any tags you provide with this prefix are +> silently stripped. Use unprefixed tag keys for your own metadata. + +```python +@durable_task( + name="analyze_document", + ephemeral=False, # Keep task record after completion + tags={"team": "platform", "model": "gpt-4o"}, + title="Document Analysis", +) +async def analyze_document(ctx: TaskContext[dict]) -> dict: ... 
+``` + +**`ephemeral`** controls what happens when a completed task's `.start()` / `.run()` +is called again: +- `ephemeral=True` (default): The completed task was auto-deleted, so a fresh task + is created. +- `ephemeral=False`: The completed task still exists, so `TaskConflictError` is raised. + +Use the `.options()` method for per-call overrides without modifying the decorator: + +```python +# Override tags for this specific call +result = await analyze_document.options( + tags={"model": "gpt-4o-mini"}, +).run(task_id="doc-1", input={"url": "..."}) +``` + +### Callable Factories (`tags`, `title`, `description`) + +When `tags`, `title`, or `description` is a callable, it receives `(input, task_id)` and +is invoked at **task creation time** — before the task function runs. + +```python +@durable_task( + tags=lambda input, task_id: {"user": input["user_id"], "run": task_id[:8]}, + description=lambda input, task_id: f"Processing {input['filename']}", +) +async def process_file(ctx: TaskContext[dict]) -> str: ... +``` + +**Error behaviour:** + +- If a factory **raises an exception**, it propagates directly to the `.run()` / `.start()` + caller. The task is never created. +- If a factory **returns the wrong type** (e.g., `tags` callable returns a list instead of + a dict), a `TypeError` is raised immediately at task creation time. +- Mixing a callable `tags` on the decorator with a static dict via `.options(tags={...})` + raises `TypeError` — use one style consistently. 
+ +--- + +## Error Handling + +| Exception | Raised By | When | +|-----------|-----------|------| +| `TaskConflictError` | `.run()`, `.start()` | Task is `in_progress` (non-stale, non-steerable) or `completed` (non-ephemeral) | +| `TaskFailed` | `.run()`, `task_run.result()` | Unhandled exception in the task function | +| `TaskCancelled` | `.run()`, `task_run.result()` | Task was cancelled via `task_run.cancel()` | +| `TaskTerminated` | `.run()`, `task_run.result()` | Task was forcefully terminated (timeout or `task_run.terminate()`) | +| `TaskNotFound` | `task_run.refresh()`, `task_run.delete()` | Task record does not exist in the store | +| `SteeringQueueFull` | `.start()` | Steering queue has `max_pending` items. Caller should retry or back off | + +> **Note**: Suspension is no longer an exception. When a task suspends, `.run()` and +> `task_run.result()` return a `TaskResult` with `is_suspended == True`. Check +> `result.is_suspended` or `result.is_completed` to distinguish outcomes. 
+ +Handle them in your application code: + +```python +from azure.ai.agentserver.core.durable import ( + TaskConflictError, + TaskFailed, + TaskTerminated, +) + +result = await my_task.run(task_id="t1", input="hello") + +if result.is_suspended: + # Task paused — result.output has the snapshot + print(f"Suspended: {result.output}") +elif result.is_completed: + print(f"Done: {result.output}") +``` + +Exceptions are raised for true error conditions: + +```python +try: + result = await my_task.run(task_id="t1", input="hello") +except TaskConflictError: + # Task already running or completed + info = await my_task.get("t1") + print(f"Task is {info.status}") +except TaskFailed as exc: + # Task function raised an exception + print(f"Failed: {exc.error}") +except TaskTerminated: + # Task was forcefully terminated (timeout or explicit terminate) + print("Task was terminated") +``` + +`TaskSuspended` is retained for backward compatibility but is no longer raised +by `.run()` or `task_run.result()`. Suspension is now a return value — check +`result.is_suspended` on the returned `TaskResult`. + +--- + +## Cancellation, Timeout, and Termination + +Durable tasks support three levels of stopping execution: + +### Cooperative Cancellation + +Set `ctx.cancel` to signal the task function to exit gracefully. The task +must check this event and respond: + +```python +run = await my_task.start(task_id="t1", input=data) +await run.cancel() # Sets ctx.cancel — task should check and exit +``` + +Inside the task function: + +```python +@durable_task +async def my_task(ctx: TaskContext[Input]) -> Output: + for item in items: + if ctx.cancel.is_set(): + return partial_result # Exit cleanly + await process(item) + return full_result +``` + +Cooperative cancel sets `ctx.cancel`. If the function checks this event and +**returns normally**, the task completes as a success — not as cancelled. The +function decides its own outcome. 
`TaskCancelled` is only raised when the
+function does not handle the cancel and the asyncio task is cancelled.
+
+### Execution Timeout
+
+Set a `timeout` to automatically cancel tasks that run too long. When the
+timeout elapses, `ctx.cancel` is set cooperatively — the same signal used
+by `handle.cancel()` and steering. If the function does not exit, the lease
+eventually expires and the task is recovered on the next heartbeat.
+
+```python
+from datetime import timedelta
+
+@durable_task(
+    timeout=timedelta(minutes=5),
+)
+async def analyze(ctx: TaskContext[dict]) -> dict:
+    while not ctx.cancel.is_set():
+        chunk = await process_next()
+        if chunk is None:
+            break
+    return {"status": "done"}
+```
+
+### Forced Termination
+
+`terminate()` immediately kills the task via the failure path. Unlike
+cooperative cancel, terminated tasks are stored as failed and are **not**
+eligible for recovery:
+
+```python
+run = await my_task.start(task_id="t1", input=data)
+await run.terminate(reason="User requested abort")
+
+try:
+    await run.result()
+except TaskTerminated:
+    print("Task was terminated")
+```
+
+### Cancel vs Terminate Summary
+
+| Method | `ctx.cancel` set? | Hard cancel? | Outcome | Recoverable? |
+|--------|-------------------|--------------|---------|--------------|
+| `run.cancel()` | ✅ | ❌ | Success if function returns normally; `TaskCancelled` if unhandled | Yes (stays in_progress until function exits) |
+| `run.terminate()` | ✅ | ✅ | `TaskTerminated` | No (goes to failed) |
+| Timeout expired | ✅ | ✅ (after grace) | `TaskTerminated` | No (goes to failed) |
+
+---
+
+## Best Practices
+
+1. **Keep tasks idempotent for recovery.** When `entry_mode="recovered"`, the
+   function re-runs from the top. Use `ctx.metadata` to track completed steps
+   and skip them on re-entry.
+
+2. **Branch on `entry_mode`.** Always handle at least `"fresh"` and `"recovered"`.
+   For suspend/resume tasks, handle `"resumed"` as well. 
For steerable tasks, + check `ctx.was_steered` inside the `"resumed"` branch. + +3. **Persist results inside the durable boundary.** Any write that must survive + a crash belongs inside the task function, not in the HTTP handler or a + background `asyncio.create_task`. + +4. **Use `ephemeral=True` for one-shot tasks.** If the task doesn't need to be + queried after completion, let the framework auto-delete it. This keeps the + task store clean. + +5. **Keep task functions focused.** A task should do one logical unit of work. + Compose multiple tasks rather than building monolithic functions. + +6. **Check cancellation cooperatively.** Poll `ctx.cancel.is_set()` in long loops + and exit cleanly when set. For steerable tasks, this is what enables the + framework to drain the queue and start the next generation. + +7. **Use `ctx.metadata` for progress, not for large data.** Metadata is flushed + periodically to the task store. Keep values small and JSON-serializable. + The task payload has a 1 MB cap — write conversation history, results, and + growing data to your own store (database, blob, Redis). + +8. **Always preserve user input on cancel.** When `ctx.cancel.is_set()` in a + steerable task, save the user's message to your conversation store before + returning. The *reply* is interrupted, not the *input recording*. + +9. **Use the three-phase cancel pattern.** Check `ctx.cancel` at three points: + before the LLM call (Phase 1), between chunks (Phase 2), and after + completion (Phase 3). This covers all timing scenarios. + +10. **Store conversation history externally.** Don't put growing data in + `ctx.metadata`. Use an external store keyed by `session_id`. The task + metadata is for lightweight progress signals only. + +11. **Steerable tasks MUST suspend on cancel, not return normally or raise.** + When `ctx.cancel.is_set()` due to steering, always `return await + ctx.suspend(reason="steered")`. 
This keeps the task alive in `suspended` + state so the framework can drain the pending queue and resume with the + next input. If you return a normal value, the task completes — the next + `.start()` creates a fresh task, breaking conversation continuity. If you + raise an exception, the task enters the failure/retry path, which is also + wrong. Suspend is the only correct exit for a steered cancel. + +--- + +## Common Mistakes + +### ❌ Missing `return await` on suspend + +```python +# ❌ BAD — suspend() returns a sentinel, but it's discarded +async def my_task(ctx: TaskContext[str]) -> str: + await ctx.suspend(output="paused") + return "done" # This runs immediately — task never actually suspends + +# ✅ GOOD — return the sentinel so the framework sees it +async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="paused") +``` + +### ❌ Persisting results outside the durable boundary + +```python +# ❌ BAD — if the process crashes, the result is never written +async def invoke(request): + task_run = await my_task.start(task_id="t1", input="hello") + asyncio.create_task(save_result_when_done(task_run)) # LOST ON CRASH + return JSONResponse({"id": "inv-1"}, status_code=202) + +# ✅ GOOD — write results inside the task function itself +@durable_task(name="my_task") +async def my_task(ctx: TaskContext[dict]) -> dict: + invocation_id = ctx.input["invocation_id"] + result = await do_work() + invocation_store.save(invocation_id, result) # DURABLE + return result +``` + +### ❌ Leaking task_id to API callers + +```python +# ❌ BAD — task_id is an internal lifecycle identifier +return JSONResponse({"task_id": task_id}, status_code=202) + +# ✅ GOOD — expose your own identifier (invocation_id, session_id, etc.) 
+return JSONResponse({"invocation_id": invocation_id}, status_code=202) +``` + +### ❌ Assuming streaming survives crashes + +```python +# ❌ BAD — default QueueStreamHandler is in-memory only +@durable_task(name="stream_report") +async def stream_report(ctx: TaskContext[str]) -> str: + for chunk in generate_chunks(): + await ctx.stream(chunk) # Lost if process crashes here + return "done" + +# ✅ GOOD — also persist to your store if durability matters +@durable_task(name="stream_report") +async def stream_report(ctx: TaskContext[str]) -> str: + for chunk in generate_chunks(): + await ctx.stream(chunk) + append_to_store(ctx.task_id, chunk) # Durable fallback + return "done" + +# ✅ ALSO GOOD — use a custom StreamHandler that persists +handler = DurableStreamHandler(store, ctx.task_id) +run = await stream_report.start( + task_id="r1", input="...", stream_handler=handler, +) +``` + +### ❌ Storing conversation history in task metadata + +```python +# ❌ BAD — metadata has a 1 MB cap and is not designed for growing data +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + history = ctx.metadata.get("history", []) + history.append({"role": "user", "content": ctx.input["message"]}) + reply = await call_llm(history) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history # GROWS UNBOUNDED + +# ✅ GOOD — use an external store, reference by session_id +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + session_id = ctx.input["session_id"] + history = conversation_store.load(session_id) or [] + history.append({"role": "user", "content": ctx.input["message"]}) + reply = await call_llm(history) + history.append({"role": "assistant", "content": reply}) + conversation_store.save(session_id, history) # EXTERNAL STORE +``` + +### ❌ Discarding input on steering cancel + +```python +# ❌ BAD — user's message is lost when cancel fires +@durable_task(name="chat", 
steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") # Message never saved! + +# ✅ GOOD — always save the user's message before returning +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + history = load_history(ctx.input["session_id"]) + history.append({"role": "user", "content": ctx.input["message"]}) + if ctx.cancel.is_set(): + save_history(ctx.input["session_id"], history) # PRESERVE INPUT + return await ctx.suspend(reason="steered") +``` + +### ❌ Skipping Phase 1 cancel check + +```python +# ❌ BAD — starts an expensive LLM call even when cancel is already set +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + # Missing Phase 1 check! + reply = "" + async for token in call_llm_streaming(ctx.input["message"]): + reply += token + if ctx.cancel.is_set(): + break + ... + +# ✅ GOOD — short-circuit before the LLM call +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): # Phase 1: pre-entry + save_input_and_return(ctx) + return await ctx.suspend(reason="steered") + reply = "" + async for token in call_llm_streaming(ctx.input["message"]): + reply += token + if ctx.cancel.is_set(): # Phase 2: mid-stream + break + ... +``` + +### ❌ Using `steerable=True` without `suspend()` + +Steerable tasks **must** suspend on every exit — both on normal completion +(awaiting next user input) and on steering cancel. If the function returns +normally, the task completes and the framework has nowhere to drain the +pending queue. If it raises, the task enters the failure/retry path. 
+ +```python +# ❌ BAD — task completes, can't accept next turn or drain queue +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + reply = await call_llm(ctx.input["message"]) + return {"reply": reply} # Task completes → next .start() creates fresh task + +# ❌ BAD — raising on cancel enters the failure path +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + raise RuntimeError("Cancelled") # Wrong! Enters retry/failure path + +# ✅ GOOD — always suspend: on cancel AND on normal completion +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") # Keep alive for drain + reply = await call_llm(ctx.input["message"]) + return await ctx.suspend(reason="awaiting_user_input", output={"reply": reply}) +``` diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index 4c1a534b4119..c8d8f568367d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -23,11 +23,17 @@ keywords = ["azure", "azure sdk", "agent", "agentserver", "core"] dependencies = [ "starlette>=0.45.0", "hypercorn>=0.17.0", + "httpx>=0.27.0", "opentelemetry-api>=1.40.0", "opentelemetry-sdk>=1.40.0", "microsoft-opentelemetry>=0.1.0b1", ] +[project.optional-dependencies] +hosted = [ + "azure-identity>=1.16.0", +] + [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py new file mode 100644 index 000000000000..ef469da9dfd2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py @@ -0,0 +1,117 
@@
+"""Durable task with retry policies.
+
+Demonstrates using ``RetryPolicy`` presets to automatically retry tasks
+that fail with transient errors.
+
+Usage::
+
+    pip install azure-ai-agentserver-core
+
+    python durable_retry.py
+
+.. note::
+
+    This sample uses a **file-based** task store for simplicity.
+    In production, a proper persistence store **must** be used.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from datetime import timedelta
+
+from azure.ai.agentserver.core import AgentServerHost
+from azure.ai.agentserver.core.durable import RetryPolicy, durable_task
+from azure.ai.agentserver.core.durable._context import TaskContext
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Track call count to simulate transient failures
+_call_count = 0
+
+
+@durable_task(
+    name="flaky_task",
+    retry=RetryPolicy.exponential_backoff(
+        max_attempts=4,
+        initial_delay=timedelta(milliseconds=100),
+        max_delay=timedelta(seconds=2),
+    ),
+)
+async def flaky_task(ctx: TaskContext[None]) -> str:
+    """Simulates a task that fails twice then succeeds.
+
+    The exponential backoff policy makes up to 4 attempts, with
+    increasing retry delays: 0.1s → 0.2s → 0.4s (capped at 2.0s). 
+
+    """
+    global _call_count  # noqa: PLW0603
+    _call_count += 1
+    attempt = ctx.run_attempt
+
+    logger.info("Attempt %d (call count=%d)", attempt, _call_count)
+
+    if attempt < 2:
+        raise ConnectionError(f"Simulated transient error on attempt {attempt}")
+
+    return f"Success after {attempt + 1} attempts"
+
+
+@durable_task(
+    name="selective_retry",
+    retry=RetryPolicy(
+        initial_delay=timedelta(milliseconds=100),
+        max_delay=timedelta(milliseconds=100),
+        backoff_coefficient=1.0,
+        max_attempts=3,
+        retry_on=(ConnectionError, TimeoutError),
+        jitter=False,
+    ),
+)
+async def selective_retry_task(ctx: TaskContext[None]) -> str:
+    """Only retries ConnectionError and TimeoutError — not ValueError."""
+    attempt = ctx.run_attempt
+    if attempt == 0:
+        raise ConnectionError("transient")
+    return f"Recovered on attempt {attempt}"
+
+
+async def main():
+    host = AgentServerHost()
+    manager = host._task_manager  # noqa: SLF001
+
+    await manager.startup()
+
+    try:
+        # Run with exponential backoff
+        logger.info("--- Exponential backoff demo ---")
+        result = await flaky_task.run(input=None)
+        logger.info("Result: %s", result.output)
+
+        # Run with selective retry
+        logger.info("--- Selective retry demo ---")
+        result2 = await selective_retry_task.run(input=None)
+        logger.info("Result: %s", result2.output)
+
+        # Show available presets
+        logger.info("--- Available retry presets ---")
+        presets = {
+            "exponential": RetryPolicy.exponential_backoff(),
+            "fixed": RetryPolicy.fixed_delay(),
+            "linear": RetryPolicy.linear_backoff(),
+            "none": RetryPolicy.no_retry(),
+        }
+        for name, policy in presets.items():
+            logger.info(
+                "  %s: max_attempts=%d, initial_delay=%.1fs",
+                name,
+                policy.max_attempts,
+                policy.initial_delay.total_seconds(),
+            )
+    finally:
+        await manager.shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt 
b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt
new file mode 100644
index 000000000000..3f2b4e9ee6b4
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt
@@ -0,0 +1 @@
+azure-ai-agentserver-core
diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py
new file mode 100644
index 000000000000..103e006fe1fb
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py
@@ -0,0 +1,86 @@
+"""Durable task with source field tracking.
+
+Demonstrates using the ``source`` parameter to attach provenance
+metadata at task creation time. The source is immutable after creation
+and can be used for auditing, debugging, or routing.
+
+Usage::
+
+    pip install azure-ai-agentserver-core
+
+    python durable_source.py
+
+.. note::
+
+    This sample uses a **file-based** task store for simplicity.
+    In production, a proper persistence store **must** be used.
+
+.. note::
+
+    The core durable-task documentation states that source provenance is
+    auto-stamped by the framework and is **not** user-overridable. Confirm
+    that the ``source`` parameter used below is still supported by the
+    current SDK before relying on it.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+
+from azure.ai.agentserver.core import AgentServerHost
+from azure.ai.agentserver.core.durable import durable_task
+from azure.ai.agentserver.core.durable._context import TaskContext
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@durable_task(
+    name="process_order",
+    source={"system": "order-service", "version": "2.1"},
+)
+async def process_order_default(ctx: TaskContext[None]) -> dict:
+    """Task with source set at decorator level.
+
+    The decorator-level source is used as a default — it can be
+    overridden at the call site. 
+ """ + logger.info("Processing order with task_id=%s", ctx.task_id) + return {"status": "processed", "task_id": ctx.task_id} + + +async def main(): + host = AgentServerHost() + manager = host._task_manager # noqa: SLF001 + + await manager.startup() + + try: + # 1. Use decorator-level source (default) + logger.info("--- Decorator source ---") + result1 = await process_order_default.run(input={"order_id": "ORD-001"}) + logger.info("Result: %s", result1.output) + + # 2. Override source at call site + logger.info("--- Call-site source override ---") + result2 = await process_order_default.run( + input={"order_id": "ORD-002"}, + source={"system": "batch-processor", "batch_id": "B-42"}, + ) + logger.info("Result: %s", result2.output) + + # 3. Task without any source (None by default) + @durable_task(name="no_source_task") + async def no_source_task(ctx: TaskContext[None]) -> str: + return "done" + + logger.info("--- No source ---") + result3 = await no_source_task.run(input=None) + logger.info("Result: %s", result3.output) + finally: + await manager.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt new file mode 100644 index 000000000000..3f2b4e9ee6b4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-core diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py new file mode 100644 index 000000000000..af90178510a1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py @@ -0,0 +1,68 @@ +"""Durable task with streaming output. 
+ +Demonstrates using ``ctx.stream()`` to emit incremental results from a +long-running task while the consumer iterates with ``async for``. + +The stream is in-memory only — items are **not** persisted. + +Usage:: + + pip install azure-ai-agentserver-core + + python durable_streaming.py + +.. note:: + + This sample uses a **file-based** task store for simplicity. + In production, a proper persistence store **must** be used. +""" + +from __future__ import annotations + +import asyncio +import logging + +from azure.ai.agentserver.core import AgentServerHost +from azure.ai.agentserver.core.durable import RetryPolicy, durable_task +from azure.ai.agentserver.core.durable._context import TaskContext + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@durable_task(name="stream_numbers") +async def stream_numbers(ctx: TaskContext[None]) -> str: + """Stream numbers 0-4 with a short delay, then return a summary.""" + for i in range(5): + await ctx.stream({"value": i, "message": f"Processing item {i}"}) + await asyncio.sleep(0.1) + return f"Streamed {5} items" + + +async def main(): + host = AgentServerHost() + manager = host._task_manager # noqa: SLF001 + + # Start the manager + await manager.startup() + + try: + # Start the task (non-blocking — returns a TaskRun handle) + run = await stream_numbers.start(input=None) + + # Consume streamed items as they arrive + items = [] + async for chunk in run: + logger.info("Received: %s", chunk) + items.append(chunk) + + # After streaming ends, get the final result + result = await run.result() + logger.info("Final result: %s", result.output) + logger.info("Total items streamed: %d", len(items)) + finally: + await manager.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt new file mode 100644 index 
000000000000..3f2b4e9ee6b4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-core diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/selfhosted_invocation/selfhosted_invocation.py b/sdk/agentserver/azure-ai-agentserver-core/samples/selfhosted_invocation/selfhosted_invocation.py index 9fc296ef775b..629841ef2564 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/selfhosted_invocation/selfhosted_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/selfhosted_invocation/selfhosted_invocation.py @@ -28,6 +28,7 @@ curl http://localhost:8088/readiness # -> {"status": "healthy"} """ + import logging import os import uuid @@ -54,7 +55,9 @@ def __init__(self, **kwargs: Any) -> None: async def _invoke(self, request: Request) -> Response: """POST /invocations — handle an invocation request with tracing.""" - invocation_id = request.headers.get("x-agent-invocation-id") or str(uuid.uuid4()) + invocation_id = request.headers.get("x-agent-invocation-id") or str( + uuid.uuid4() + ) session_id = ( request.query_params.get("agent_session_id") or os.environ.get("FOUNDRY_AGENT_SESSION_ID") @@ -62,10 +65,15 @@ async def _invoke(self, request: Request) -> Response: ) with self.request_span( - request.headers, invocation_id, "invoke_agent", - operation_name="invoke_agent", session_id=session_id, + request.headers, + invocation_id, + "invoke_agent", + operation_name="invoke_agent", + session_id=session_id, ) as otel_span: - logger.info("Processing invocation %s in session %s", invocation_id, session_id) + logger.info( + "Processing invocation %s in session %s", invocation_id, session_id + ) try: data = await request.json() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py index 27b136ce5de8..f4670c21cf8e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py 
+++ b/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py @@ -11,7 +11,10 @@ def pytest_configure(config): - config.addinivalue_line("markers", "tracing_e2e: end-to-end tracing tests requiring live Azure resources") + config.addinivalue_line( + "markers", + "tracing_e2e: end-to-end tracing tests requiring live Azure resources", + ) @pytest.fixture() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py new file mode 100644 index 000000000000..d540fd20468c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py new file mode 100644 index 000000000000..c6ba64b8b2fa --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py @@ -0,0 +1,280 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for callable tag and description factories on @durable_task.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +class TestCallableTags: + """Tests for callable tag factories on @durable_task.""" + + @pytest.mark.asyncio + async def test_static_tags_preserved(self, tmp_path): + """Static dict tags still work as before.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="static_tags", tags={"env": "prod"}, ephemeral=False) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.tags["env"] == "prod" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_factory(self, tmp_path): + """Callable tags factory receives (input, 
task_id) and sets tags.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="callable_tags", + tags=lambda inp, tid: {"tenant": inp["tenant"], "tid": tid[:8]}, + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input={"tenant": "acme"}) + + task = await manager.provider.get(task_id) + assert task.tags["tenant"] == "acme" + assert task.tags["tid"] == task_id[:8] + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_merged_with_callsite(self, tmp_path): + """Per-call tags merge on top of callable-resolved tags.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="merge_tags", + tags=lambda inp, tid: {"source": "factory"}, + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None, tags={"extra": "call-site"}) + + task = await manager.provider.get(task_id) + assert task.tags["source"] == "factory" + assert task.tags["extra"] == "call-site" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_error_propagates(self, tmp_path): + """If callable tags factory raises, the error propagates at creation.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_tags", + tags=lambda inp, tid: 1 / 0, # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(ZeroDivisionError): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +class TestCallableDescription: + """Tests for callable description factory on @durable_task.""" + + @pytest.mark.asyncio + async def 
test_static_description(self, tmp_path): + """Static string description is stored on the task record.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="static_desc", description="A static description", ephemeral=False + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.description == "A static description" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_description_factory(self, tmp_path): + """Callable description factory receives (input, task_id).""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="callable_desc", + description=lambda inp, tid: f"Processing {inp['doc']}", + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input={"doc": "report.pdf"}) + + task = await manager.provider.get(task_id) + assert task.description == "Processing report.pdf" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_description_backward_compat(self, tmp_path): + """Without description, the task record has no description.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="no_desc", ephemeral=False) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.description is None + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +class TestFactoryValidation: + """Tests for return type validation on callable factories.""" + + @pytest.mark.asyncio + async def test_tags_callable_bad_return_type(self, tmp_path): 
+ """Tags callable returning non-dict raises TypeError.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_tags_type", + tags=lambda inp, tid: "not-a-dict", # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="tags callable must return dict"): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_description_callable_bad_return_type(self, tmp_path): + """Description callable returning non-str raises TypeError.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_desc_type", + description=lambda inp, tid: 12345, # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="description callable must return str"): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + def test_options_mixing_callable_and_dict_tags_raises(self): + """Mixing callable and dict tags in options() raises TypeError.""" + + @durable_task( + name="callable_tags_task", + tags=lambda inp, tid: {"k": "v"}, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="Cannot mix callable and dict"): + my_task.options(tags={"override": "val"}) + + def test_options_callable_to_callable_ok(self): + """Replacing callable tags with another callable in options() works.""" + + @durable_task( + name="callable_swap", + tags=lambda inp, tid: {"old": "factory"}, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + updated = my_task.options( + tags=lambda inp, tid: {"new": "factory"}, + ) + assert callable(updated._opts.tags) diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py new file mode 100644 index 000000000000..82ff8f614a13 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py @@ -0,0 +1,215 @@ +"""Tests for cancellation and timeout features (spec 005). + +Covers: +- Execution timeout (cooperative cancel → hard cancel) +- Wait timeout (caller-side timeout on result()) +- Terminate (forced termination via TaskRun.terminate()) +""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + TaskTerminated, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +# --------------------------------------------------------------------------- +# Execution timeout tests +# --------------------------------------------------------------------------- + + +class TestExecutionTimeout: + """Verify the timeout watchdog 
cooperatively and hard-cancels tasks.""" + + @pytest.mark.asyncio + async def test_timeout_cooperative_cancel(self, tmp_path): + """Task sees ctx.cancel set when timeout fires.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + cancel_observed = asyncio.Event() + + @durable_task( + name="timeout_coop", + timeout=timedelta(seconds=0.2), + ) + async def slow_task(ctx: TaskContext[Any]) -> str: + # Wait until cooperative cancel fires + while not ctx.cancel.is_set(): + await asyncio.sleep(0.01) + cancel_observed.set() + return "cooperated" + + run = await slow_task.start(task_id=uuid.uuid4().hex, input=None) + result = await run.result() + + assert cancel_observed.is_set() + assert result.output == "cooperated" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_timeout_regression(self, tmp_path): + """Task without timeout runs normally to completion.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="no_timeout") + async def quick_task(ctx: TaskContext[Any]) -> str: + return "done" + + run = await quick_task.start(task_id=uuid.uuid4().hex, input=None) + result = await run.result() + assert result.output == "done" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Terminate tests +# --------------------------------------------------------------------------- + + +class TestTerminate: + """Verify TaskRun.terminate() forces failure.""" + + @pytest.mark.asyncio + async def test_terminate_raises_task_terminated(self, tmp_path): + """terminate() causes result() to raise TaskTerminated.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="terminatable") + async def long_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(100) + return "never" + + run = await long_task.start(task_id=uuid.uuid4().hex, input=None) + await 
asyncio.sleep(0.05) # let it start + + await run.terminate() + with pytest.raises(TaskTerminated): + await run.result() + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_terminate_sets_failure_status(self, tmp_path): + """Terminated task is stored as failed (not in_progress).""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="term_status", ephemeral=False) + async def long_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(100) + return "never" + + task_id = uuid.uuid4().hex + run = await long_task.start(task_id=task_id, input=None) + await asyncio.sleep(0.05) + + await run.terminate() + with pytest.raises(TaskTerminated): + await run.result() + + # Give manager time to persist failure + await asyncio.sleep(0.1) + + info = await manager.provider.get(task_id) + assert info is not None + # Failures are stored as "completed" with an error dict + assert info.status == "completed" + assert info.error is not None + assert info.error["type"] == "TaskTerminated" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_cancel_vs_terminate_distinction(self, tmp_path): + """Cooperative cancel (ctx.cancel) raises TaskCancelled, not TaskTerminated.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + from azure.ai.agentserver.core.durable._exceptions import TaskCancelled + + @durable_task(name="cancel_test") + async def cancellable_task(ctx: TaskContext[Any]) -> str: + # Cooperatively check cancel + while not ctx.cancel.is_set(): + await asyncio.sleep(0.01) + raise asyncio.CancelledError() + + run = await cancellable_task.start(task_id=uuid.uuid4().hex, input=None) + await asyncio.sleep(0.05) + + # Use cancel (not terminate) — cooperative + await run.cancel() + with pytest.raises(TaskCancelled): + await run.result() + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def 
test_terminate_reason_propagated(self, tmp_path): + """Terminate reason is propagated to TaskTerminated exception.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="term_reason_task") + async def slow_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(10) + return "never" + + run = await slow_task.start(task_id=uuid.uuid4().hex, input=None) + await asyncio.sleep(0.05) + + await run.terminate(reason="user requested stop") + with pytest.raises(TaskTerminated) as exc_info: + await run.result() + assert exc_info.value.reason == "user requested stop" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py new file mode 100644 index 000000000000..76aae10f0d0a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py @@ -0,0 +1,157 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for @durable_task decorator and DurableTask class.""" + +import asyncio + +import pytest + +from azure.ai.agentserver.core.durable import ( + DurableTask, + DurableTaskOptions, + TaskContext, + durable_task, +) + + +class TestDurableTaskDecorator: + """Tests for the @durable_task decorator.""" + + def test_bare_decorator(self) -> None: + """@durable_task with no arguments produces a DurableTask.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 42 + + assert isinstance(my_task, DurableTask) + # Name includes class/method scope when defined inside a method + assert "my_task" in my_task.name + + def test_decorator_with_name(self) -> None: + """@durable_task(name=...) 
sets a custom name.""" + + @durable_task(name="custom_name") + async def my_task(ctx: TaskContext[str]) -> int: + return 0 + + assert my_task.name == "custom_name" + + def test_decorator_with_all_options(self) -> None: + """All decorator options are forwarded to DurableTaskOptions.""" + from datetime import timedelta + + @durable_task( + name="full", + ephemeral=False, + lease_duration_seconds=120, + store_input=True, + title="My Title", + tags={"env": "test"}, + timeout=timedelta(minutes=5), + ) + async def my_task(ctx: TaskContext[dict]) -> str: + return "" + + assert my_task.name == "full" + assert my_task._opts.ephemeral is False + assert my_task._opts.lease_duration_seconds == 120 + assert my_task._opts.store_input is True + assert my_task._opts.title == "My Title" + assert my_task._opts.tags == {"env": "test"} + assert my_task._opts.timeout == timedelta(minutes=5) + + def test_rejects_sync_function(self) -> None: + """@durable_task rejects synchronous functions.""" + with pytest.raises(TypeError, match="async function"): + + @durable_task + def sync_fn(ctx: TaskContext[str]) -> int: + return 1 + + def test_rejects_non_callable(self) -> None: + """@durable_task(...) 
rejects non-callable objects.""" + with pytest.raises((TypeError, AttributeError)): + durable_task(42) # type: ignore[arg-type] + + +class TestDurableTaskOptions: + """Tests for DurableTaskOptions merge via .options().""" + + def test_options_returns_new_instance(self) -> None: + """options() returns a new DurableTask, original unchanged.""" + + @durable_task(ephemeral=True) + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(ephemeral=False) + assert updated is not my_task + assert updated._opts.ephemeral is False + assert my_task._opts.ephemeral is True + + def test_options_merges_tags(self) -> None: + """options() merges tags with existing ones.""" + + @durable_task(tags={"a": "1"}) + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(tags={"b": "2"}) + assert updated._opts.tags == {"a": "1", "b": "2"} + + def test_options_overrides_title(self) -> None: + """options() overrides title.""" + + @durable_task(title="original") + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(title="override") + assert updated._opts.title == "override" + + def test_default_options(self) -> None: + """Default DurableTaskOptions has sensible defaults.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + opts = my_task._opts + assert opts.ephemeral is True + assert opts.lease_duration_seconds == 60 + assert opts.store_input is True # default is True + assert opts.tags == {} + assert opts.timeout is None + + +class TestTypeExtraction: + """Tests for generic type parameter extraction.""" + + def test_input_type_str(self) -> None: + """Extracts str as Input type from TaskContext[str].""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + assert my_task._input_type is str + + def test_input_type_dict(self) -> None: + """Extracts dict as Input type.""" + + @durable_task + async def my_task(ctx: 
TaskContext[dict]) -> str: + return "" + + assert my_task._input_type is dict + + def test_output_type_int(self) -> None: + """Extracts int as Output type from return annotation.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + assert my_task._output_type is int diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py new file mode 100644 index 000000000000..1a888eab5c8f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py @@ -0,0 +1,181 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for TaskContext.entry_mode across all lifecycle paths.""" + +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class TestEntryMode: + """Verify ctx.entry_mode is set correctly for each lifecycle path.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_fresh_start_entry_mode(self, tmp_path) -> None: + """First call to 
.run() produces entry_mode='fresh'.""" + observed_modes: list[str] = [] + + @durable_task(title="test-fresh") + async def my_task(ctx: TaskContext[str]) -> str: + observed_modes.append(ctx.entry_mode) + return "done" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="fresh-1", input="hello") + assert result.output == "done" + assert observed_modes == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_developer_resume_entry_mode(self, tmp_path) -> None: + """Calling .run() on a suspended task produces entry_mode='resumed' with new input.""" + observed: list[tuple[str, str]] = [] + + @durable_task(title="test-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append((ctx.entry_mode, ctx.input)) + return await ctx.suspend(output={"partial": True}) + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # First call — fresh start, suspends + result1 = await my_task.run(task_id="resume-1", input="turn-1") + assert result1.is_suspended + assert observed == [("fresh", "turn-1")] + + # Second call — should resume with new input + result2 = await my_task.run(task_id="resume-1", input="turn-2") + assert result2.is_suspended + assert observed[-1] == ("resumed", "turn-2") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_platform_resume_entry_mode(self, tmp_path) -> None: + """Platform-initiated resume (handle_resume) produces entry_mode='resumed'.""" + observed: list[str] = [] + + @durable_task(title="test-platform-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append(ctx.entry_mode) + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Fresh start — suspends + result = await my_task.run(task_id="platform-resume-1", input="init") + assert result.is_suspended + 
assert observed == ["fresh"] + + # Platform-initiated resume + await manager.handle_resume("platform-resume-1") + # Give the background task time to run + import asyncio + + await asyncio.sleep(0.2) + assert "resumed" in observed + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_recovered_entry_mode(self, tmp_path) -> None: + """Calling .run() on a stale in_progress task produces entry_mode='recovered'.""" + observed: list[str] = [] + + @durable_task(title="test-recover", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append(ctx.entry_mode) + return "recovered-ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + ) + + # Manually create a stale in_progress task + await manager.provider.create( + TaskCreateRequest( + id="stale-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="stale-test", + payload={"input": "old-data"}, + ) + ) + + # Backdate the updated_at to make it stale + task_file = ( + Path(str(tmp_path)) / "test-agent" / "test-session" / "stale-1.json" + ) + if task_file.exists(): + import json + + data = json.loads(task_file.read_text()) + data["updated_at"] = "2020-01-01T00:00:00+00:00" + task_file.write_text(json.dumps(data)) + + result = await my_task.run( + task_id="stale-1", + input="new-data", + stale_timeout=1.0, + ) + assert result.output == "recovered-ok" + assert observed == ["recovered"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_ignoring_entry_mode_works(self, tmp_path) -> None: + """A function that never reads entry_mode still works fine.""" + + @durable_task(title="test-ignore") + async def my_task(ctx: TaskContext[str]) -> str: + # Deliberately NOT reading ctx.entry_mode + return f"processed: {ctx.input}" + + manager, mgr_mod = await self._setup_manager(tmp_path) + 
try: + result = await my_task.run(task_id="ignore-1", input="data") + assert result.output == "processed: data" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py new file mode 100644 index 000000000000..8da515a20cb0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py @@ -0,0 +1,140 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for DurableTask.get() method.""" + +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class TestGet: + """Verify DurableTask.get() returns TaskInfo or None.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_get_existing_task(self, tmp_path) -> None: + """get() returns TaskInfo for an existing task.""" + + @durable_task(title="get-test", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="paused") + + manager, 
mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="get-1", input="data") + assert result.is_suspended + + info = await my_task.get("get-1") + assert info is not None + assert info.id == "get-1" + assert info.status == "suspended" + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_get_nonexistent_task(self, tmp_path) -> None: + """get() returns None for a non-existent task.""" + + @durable_task(title="get-test") + async def my_task(ctx: TaskContext[str]) -> str: + return "ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + info = await my_task.get("does-not-exist") + assert info is None + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_get_returns_correct_state(self, tmp_path) -> None: + """get() returns correct info for various task states.""" + + @durable_task(title="get-states", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Create tasks in different states via the provider + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="state-suspended", + agent_name="test-agent", + session_id="test-session", + status="suspended", + title="suspended-task", + payload={"output": "half-done"}, + ) + ) + await manager.provider.create( + TaskCreateRequest( + id="state-completed", + agent_name="test-agent", + session_id="test-session", + status="completed", + title="done-task", + payload={"output": "final"}, + ) + ) + await manager.provider.create( + TaskCreateRequest( + id="state-in-progress", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running-task", + payload={}, + ) + ) + + suspended = await my_task.get("state-suspended") + assert suspended is not None + 
assert suspended.status == "suspended" + + completed = await my_task.get("state-completed") + assert completed is not None + assert completed.status == "completed" + + in_progress = await my_task.get("state-in-progress") + assert in_progress is not None + assert in_progress.status == "in_progress" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py new file mode 100644 index 000000000000..cdc3f7ced790 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py @@ -0,0 +1,321 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for lifecycle-aware .run() and .start() on DurableTask.""" + +import json +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) +from azure.ai.agentserver.core.durable._exceptions import TaskConflictError + + +class TestLifecycle: + """Verify .run()/.start() lifecycle automation.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + 
mgr_mod._manager = None + + def _create_stale_task(self, tmp_path, task_id, status="in_progress"): + """Write a stale task file directly to simulate a crashed task.""" + from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + ) + import asyncio + + async def _create(provider): + await provider.create( + TaskCreateRequest( + id=task_id, + agent_name="test-agent", + session_id="test-session", + status=status, + title="stale-test", + payload={"input": "old-data"}, + ) + ) + + return _create + + def _backdate_task(self, tmp_path, task_id): + """Set updated_at far in the past.""" + task_file = ( + Path(str(tmp_path)) / "test-agent" / "test-session" / f"{task_id}.json" + ) + if task_file.exists(): + data = json.loads(task_file.read_text()) + data["updated_at"] = "2020-01-01T00:00:00+00:00" + task_file.write_text(json.dumps(data)) + + @pytest.mark.asyncio + async def test_run_fresh_no_existing_task(self, tmp_path) -> None: + """run() on non-existent task → creates and starts, entry_mode='fresh'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-fresh") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "result" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="lc-fresh-1", input="data") + assert result.output == "result" + assert observed_mode == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_pending_task(self, tmp_path) -> None: + """run() on pending task → starts it, entry_mode='fresh'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-pending") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "started" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + 
TaskCreateRequest( + id="lc-pending-1", + agent_name="test-agent", + session_id="test-session", + status="pending", + title="pending-test", + payload={"input": "pending-data"}, + ) + ) + result = await my_task.run(task_id="lc-pending-1", input="new-data") + assert result.output == "started" + assert observed_mode == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_suspended_task(self, tmp_path) -> None: + """run() on suspended task → resumes with new input, entry_mode='resumed'.""" + observed: list[tuple[str, str]] = [] + + @durable_task(title="lifecycle-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append((ctx.entry_mode, ctx.input)) + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result1 = await my_task.run(task_id="lc-resume-1", input="turn-1") + assert result1.is_suspended + assert observed[-1] == ("fresh", "turn-1") + + result2 = await my_task.run(task_id="lc-resume-1", input="turn-2") + assert result2.is_suspended + assert observed[-1] == ("resumed", "turn-2") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_in_progress_not_stale_raises(self, tmp_path) -> None: + """run() on in_progress (not stale) task → TaskConflictError.""" + + @durable_task(title="lifecycle-conflict") + async def my_task(ctx: TaskContext[str]) -> str: + return "never" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-conflict-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running-test", + payload={}, + ) + ) + with pytest.raises(TaskConflictError) as exc_info: + await my_task.run(task_id="lc-conflict-1", input="data") + assert exc_info.value.task_id == 
"lc-conflict-1" + assert exc_info.value.current_status == "in_progress" + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_stale_task_recovers(self, tmp_path) -> None: + """run() on stale in_progress task → recovers, entry_mode='recovered'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-stale") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "recovered" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-stale-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="stale-test", + payload={"input": "old"}, + ) + ) + self._backdate_task(tmp_path, "lc-stale-1") + + result = await my_task.run( + task_id="lc-stale-1", + input="new", + stale_timeout=1.0, + ) + assert result.output == "recovered" + assert observed_mode == ["recovered"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_completed_task_raises(self, tmp_path) -> None: + """run() on completed task → TaskConflictError (no restart).""" + + @durable_task(title="lifecycle-completed") + async def my_task(ctx: TaskContext[str]) -> str: + return "never" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-completed-1", + agent_name="test-agent", + session_id="test-session", + status="completed", + title="done-test", + payload={"output": "final"}, + ) + ) + with pytest.raises(TaskConflictError) as exc_info: + await my_task.run(task_id="lc-completed-1", input="data") + assert exc_info.value.current_status == "completed" + finally: + await self._teardown_manager(manager, mgr_mod) + + 
@pytest.mark.asyncio + async def test_start_follows_lifecycle_rules(self, tmp_path) -> None: + """start() follows same lifecycle rules as run() — fresh + conflict.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-start") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "started" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Fresh start via .start() + handle = await my_task.start(task_id="lc-start-1", input="data") + result = await handle.result() + assert result.output == "started" + assert observed_mode == ["fresh"] + + # Conflict: create in_progress task and try .start() + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-start-conflict", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running", + payload={}, + ) + ) + with pytest.raises(TaskConflictError): + await my_task.start(task_id="lc-start-conflict", input="data") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_stale_timeout_parameter(self, tmp_path) -> None: + """stale_timeout controls when in_progress is considered stale.""" + + @durable_task(title="stale-timeout") + async def my_task(ctx: TaskContext[str]) -> str: + return "ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-timeout-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="timeout-test", + payload={"input": "old"}, + ) + ) + self._backdate_task(tmp_path, "lc-timeout-1") + + # Very large timeout → not stale → conflict + with pytest.raises(TaskConflictError): + await my_task.run( + task_id="lc-timeout-1", + input="new", + stale_timeout=999999999.0, + ) + + # Small timeout → 
stale → recover + result = await my_task.run( + task_id="lc-timeout-1", + input="new", + stale_timeout=1.0, + ) + assert result.output == "ok" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py new file mode 100644 index 000000000000..62d66fd3e5ee --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py @@ -0,0 +1,182 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the LocalFileDurableTaskProvider.""" + +import json +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, +) +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskPatchRequest, +) + + +@pytest.fixture +def provider(tmp_path: Path) -> LocalFileDurableTaskProvider: + """Create a local provider backed by a temp directory.""" + return LocalFileDurableTaskProvider(base_dir=tmp_path) + + +@pytest.fixture +def sample_create_request() -> TaskCreateRequest: + """A minimal task creation request.""" + return TaskCreateRequest( + agent_name="test-agent", + session_id="session-001", + status="pending", + payload={"input": {"data": "hello"}}, + lease_owner="owner-1", + lease_instance_id="inst-1", + lease_duration_seconds=60, + ) + + +class TestLocalProviderCRUD: + """Create, read, update operations on the local provider.""" + + @pytest.mark.asyncio + async def test_create_and_get( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """create returns a TaskInfo; get retrieves it.""" + task = await provider.create(sample_create_request) + assert task.id + 
assert task.status == "pending" + assert task.agent_name == "test-agent" + + fetched = await provider.get(task.id) + assert fetched is not None + assert fetched.id == task.id + + @pytest.mark.asyncio + async def test_update_status( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update changes the status.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + status="in_progress", + if_match=task.etag, + ) + updated = await provider.update(task.id, patch) + assert updated.status == "in_progress" + + @pytest.mark.asyncio + async def test_update_payload( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update merges payload.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + payload={"output": {"result": 42}}, + if_match=task.etag, + ) + updated = await provider.update(task.id, patch) + assert updated.payload is not None + assert updated.payload["output"]["result"] == 42 + # Original input preserved + assert updated.payload["input"]["data"] == "hello" + + @pytest.mark.asyncio + async def test_etag_mismatch_raises( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update raises on ETag mismatch.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + status="in_progress", + if_match="wrong-etag", + ) + with pytest.raises(ValueError, match="ETag mismatch"): + await provider.update(task.id, patch) + + @pytest.mark.asyncio + async def test_get_nonexistent_returns_none( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """get returns None for nonexistent task.""" + result = await provider.get("nonexistent-id") + assert result is None + + @pytest.mark.asyncio + async def test_delete_task( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: 
TaskCreateRequest, + ) -> None: + """delete removes a task.""" + task = await provider.create(sample_create_request) + await provider.delete(task.id) + result = await provider.get(task.id) + assert result is None + + +class TestLocalProviderListing: + """Tests for listing/querying tasks.""" + + @pytest.mark.asyncio + async def test_list_tasks_by_agent( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """list filters by agent_name and session_id.""" + req1 = TaskCreateRequest( + agent_name="agent-a", + session_id="s1", + status="pending", + payload={}, + ) + req2 = TaskCreateRequest( + agent_name="agent-b", + session_id="s1", + status="pending", + payload={}, + ) + await provider.create(req1) + await provider.create(req2) + + tasks = await provider.list(agent_name="agent-a", session_id="s1") + assert len(tasks) == 1 + assert tasks[0].agent_name == "agent-a" + + @pytest.mark.asyncio + async def test_list_tasks_by_status( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """list filters by status.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="s1", + status="pending", + payload={}, + ) + task = await provider.create(req) + patch = TaskPatchRequest( + status="in_progress", + if_match=task.etag, + ) + await provider.update(task.id, patch) + + pending = await provider.list( + agent_name="agent", session_id="s1", status="pending" + ) + assert len(pending) == 0 + + active = await provider.list( + agent_name="agent", session_id="s1", status="in_progress" + ) + assert len(active) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py new file mode 100644 index 000000000000..8bafd3bc8102 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py @@ -0,0 +1,247 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for TaskMetadata operations (set, get, increment, append, flush).""" + +import asyncio +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable._metadata import TaskMetadata + + +class TestTaskMetadataOperations: + """Tests for basic metadata operations.""" + + def test_set_and_get(self) -> None: + """set() stores a value, get() retrieves it.""" + meta = TaskMetadata() + meta.set("key", "value") + assert meta.get("key") == "value" + + def test_get_default(self) -> None: + """get() returns default when key is missing.""" + meta = TaskMetadata() + assert meta.get("missing") is None + assert meta.get("missing", 42) == 42 + + def test_set_marks_dirty(self) -> None: + """set() marks the metadata as dirty.""" + meta = TaskMetadata() + assert not meta._dirty + meta.set("key", "value") + assert meta._dirty + + def test_increment(self) -> None: + """increment() increases a counter by the given amount.""" + meta = TaskMetadata() + meta.increment("counter") + assert meta.get("counter") == 1 + meta.increment("counter", 5) + assert meta.get("counter") == 6 + + def test_increment_non_numeric_raises(self) -> None: + """increment() raises TypeError on non-numeric existing value.""" + meta = TaskMetadata() + meta.set("key", "not a number") + with pytest.raises(TypeError): + meta.increment("key") + + def test_append(self) -> None: + """append() adds items to a list.""" + meta = TaskMetadata() + meta.append("log", "entry1") + meta.append("log", "entry2") + assert meta.get("log") == ["entry1", "entry2"] + + def test_append_non_list_raises(self) -> None: + """append() raises TypeError when existing value is not a list.""" + meta = TaskMetadata() + meta.set("key", "not a list") + with pytest.raises(TypeError): + meta.append("key", "item") + + def test_snapshot_returns_copy(self) -> None: + """Snapshot returns a copy, not a reference.""" + meta = TaskMetadata() + meta.set("key", "value") + 
snap = dict(meta._data) + meta.set("key", "changed") + assert snap["key"] == "value" + assert meta.get("key") == "changed" + + +class TestTaskMetadataFlush: + """Tests for flush and auto-flush behavior.""" + + @pytest.mark.asyncio + async def test_flush_calls_callback(self) -> None: + """flush() calls the flush_callback with current data.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback) + meta.set("key", "value") + await meta.flush() + + assert len(captured) == 1 + assert captured[0]["key"] == "value" + + @pytest.mark.asyncio + async def test_flush_clears_dirty(self) -> None: + """flush() clears the dirty flag after success.""" + + async def callback(data: dict[str, Any]) -> None: + pass + + meta = TaskMetadata(flush_callback=callback) + meta.set("key", "value") + assert meta._dirty + await meta.flush() + assert not meta._dirty + + @pytest.mark.asyncio + async def test_flush_noop_when_clean(self) -> None: + """flush() is a no-op when metadata is not dirty.""" + call_count = 0 + + async def callback(data: dict[str, Any]) -> None: + nonlocal call_count + call_count += 1 + + meta = TaskMetadata(flush_callback=callback) + await meta.flush() + assert call_count == 0 + + @pytest.mark.asyncio + async def test_flush_noop_without_callback(self) -> None: + """flush() is a no-op without a callback configured.""" + meta = TaskMetadata() + meta.set("key", "value") + # Should not raise + await meta.flush() + + @pytest.mark.asyncio + async def test_stop_auto_flush_final_flush(self) -> None: + """stop_auto_flush() does a final flush before stopping.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback, flush_interval=100) + meta.start_auto_flush() + meta.set("key", "value") + await meta.stop_auto_flush() + + assert len(captured) == 1 + assert 
captured[0]["key"] == "value" + + +class TestTaskMetadataDictProtocol: + """Tests for dict-like access (MutableMapping protocol).""" + + def test_setitem_getitem(self) -> None: + """[] assignment and retrieval works.""" + meta = TaskMetadata() + meta["key"] = "value" + assert meta["key"] == "value" + + def test_getitem_missing_raises_keyerror(self) -> None: + """[] on missing key raises KeyError.""" + meta = TaskMetadata() + with pytest.raises(KeyError): + _ = meta["missing"] + + def test_setitem_marks_dirty(self) -> None: + """[] assignment marks metadata as dirty.""" + meta = TaskMetadata() + assert not meta._dirty + meta["key"] = "value" + assert meta._dirty + + def test_setitem_non_string_key_raises(self) -> None: + """[] with non-string key raises TypeError.""" + meta = TaskMetadata() + with pytest.raises(TypeError): + meta[42] = "value" # type: ignore[index] + + def test_delitem(self) -> None: + """del removes a key and marks dirty.""" + meta = TaskMetadata() + meta["key"] = "value" + meta._dirty = False + del meta["key"] + assert "key" not in meta + assert meta._dirty + + def test_delitem_missing_raises_keyerror(self) -> None: + """del on missing key raises KeyError.""" + meta = TaskMetadata() + with pytest.raises(KeyError): + del meta["missing"] + + def test_contains(self) -> None: + """'in' operator works.""" + meta = TaskMetadata() + meta["key"] = "value" + assert "key" in meta + assert "missing" not in meta + + def test_len(self) -> None: + """len() returns number of keys.""" + meta = TaskMetadata() + assert len(meta) == 0 + meta["a"] = 1 + meta["b"] = 2 + assert len(meta) == 2 + + def test_iter(self) -> None: + """Iteration yields keys.""" + meta = TaskMetadata() + meta["a"] = 1 + meta["b"] = 2 + assert sorted(meta) == ["a", "b"] + + def test_keys_values_items(self) -> None: + """keys(), values(), items() delegate to internal dict.""" + meta = TaskMetadata() + meta["x"] = 10 + meta["y"] = 20 + assert set(meta.keys()) == {"x", "y"} + assert 
set(meta.values()) == {10, 20} + assert set(meta.items()) == {("x", 10), ("y", 20)} + + def test_isinstance_mutable_mapping(self) -> None: + """TaskMetadata is registered as MutableMapping.""" + import collections.abc + + meta = TaskMetadata() + assert isinstance(meta, collections.abc.MutableMapping) + + def test_existing_methods_still_work(self) -> None: + """Existing .set(), .get(), .increment(), .append() are unchanged.""" + meta = TaskMetadata() + meta.set("counter", 0) + meta.increment("counter", 5) + assert meta.get("counter") == 5 + meta.append("log", "entry") + assert meta.get("log") == ["entry"] + assert meta.to_dict() == {"counter": 5, "log": ["entry"]} + + @pytest.mark.asyncio + async def test_setitem_triggers_auto_flush(self) -> None: + """[] assignment triggers flush via dirty-tracking.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback) + meta["key"] = "value" + await meta.flush() + assert len(captured) == 1 + assert captured[0]["key"] == "value" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py new file mode 100644 index 000000000000..e1e3d43de37c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py @@ -0,0 +1,115 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for data models and exceptions.""" + +import pytest + +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, +) +from azure.ai.agentserver.core.durable._exceptions import ( + TaskCancelled, + TaskFailed, + TaskNotFound, + TaskSuspended, +) + + +class TestTaskStatus: + """Tests for TaskStatus literal type.""" + + def test_valid_status_strings(self) -> None: + """Valid status values are plain strings.""" + statuses = ["pending", "in_progress", "suspended", "completed"] + for s in statuses: + assert isinstance(s, str) + + +class TestTaskCreateRequest: + """Tests for TaskCreateRequest.""" + + def test_minimal(self) -> None: + """Minimal request has required fields.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + status="pending", + payload={}, + ) + assert req.agent_name == "agent" + assert req.status == "pending" + + def test_default_status(self) -> None: + """Default status is 'pending'.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + ) + assert req.status == "pending" + + def test_optional_fields_default_none(self) -> None: + """Optional fields default to None.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + ) + assert req.lease_owner is None + assert req.lease_instance_id is None + assert req.lease_duration_seconds is None + assert req.id is None + assert req.title is None + + +class TestTaskPatchRequest: + """Tests for TaskPatchRequest.""" + + def test_empty_patch(self) -> None: + """An empty patch is valid.""" + patch = TaskPatchRequest() + assert patch.status is None + assert patch.payload is None + assert patch.if_match is None + + def test_status_patch(self) -> None: + """Patch can set status.""" + patch = TaskPatchRequest(status="in_progress") + assert patch.status == "in_progress" + + +class TestExceptions: + """Tests for custom 
durable task exceptions.""" + + def test_task_failed_message(self) -> None: + """TaskFailed stores task_id and error.""" + exc = TaskFailed("task-1", error={"message": "boom", "type": "ValueError"}) + assert exc.task_id == "task-1" + assert "boom" in str(exc) + assert exc.error["type"] == "ValueError" + + def test_task_suspended_reason(self) -> None: + """TaskSuspended stores task_id and reason.""" + exc = TaskSuspended("task-2", reason="waiting for approval") + assert exc.task_id == "task-2" + assert "waiting for approval" in str(exc) + + def test_task_cancelled(self) -> None: + """TaskCancelled stores task_id.""" + exc = TaskCancelled("task-3") + assert exc.task_id == "task-3" + assert "task-3" in str(exc) + + def test_task_not_found(self) -> None: + """TaskNotFound stores task_id.""" + exc = TaskNotFound("task-123") + assert exc.task_id == "task-123" + assert "task-123" in str(exc) + + def test_exception_hierarchy(self) -> None: + """All exceptions inherit from Exception.""" + assert issubclass(TaskFailed, Exception) + assert issubclass(TaskSuspended, Exception) + assert issubclass(TaskCancelled, Exception) + assert issubclass(TaskNotFound, Exception) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py new file mode 100644 index 000000000000..8e48069b5f2a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py @@ -0,0 +1,95 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Tests for the resume HTTP route."""

import json
from unittest.mock import AsyncMock, patch

import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette

from azure.ai.agentserver.core.durable._resume_route import create_resume_route


def _build_test_app() -> Starlette:
    """Create a minimal Starlette app with the resume route."""
    return Starlette(routes=[create_resume_route()])


def _make_client() -> TestClient:
    """Build a client over a fresh app; server errors surface as status codes."""
    return TestClient(_build_test_app(), raise_server_exceptions=False)


class TestResumeRoute:
    """Tests for POST /tasks/resume."""

    def test_missing_body_returns_400(self) -> None:
        """A non-JSON request body is rejected with 400."""
        response = _make_client().post("/tasks/resume", content=b"not json")
        assert response.status_code == 400

    def test_missing_task_id_returns_400(self) -> None:
        """A JSON body without task_id is rejected with 400."""
        response = _make_client().post("/tasks/resume", json={})
        assert response.status_code == 400

    def test_non_string_task_id_returns_400(self) -> None:
        """A non-string task_id is rejected with 400."""
        response = _make_client().post("/tasks/resume", json={"task_id": 123})
        assert response.status_code == 400

    @patch("azure.ai.agentserver.core.durable._manager.get_task_manager")
    def test_successful_resume_returns_202(self, mock_get: AsyncMock) -> None:
        """A successful resume answers 202 with an empty body."""
        mock_manager = AsyncMock()
        mock_manager.handle_resume = AsyncMock()
        mock_get.return_value = mock_manager

        response = _make_client().post("/tasks/resume", json={"task_id": "task-123"})
        assert response.status_code == 202
        assert response.content == b""

    @patch("azure.ai.agentserver.core.durable._manager.get_task_manager")
    def test_not_found_returns_404(self, mock_get: AsyncMock) -> None:
        """Resuming a nonexistent task maps the manager error to 404."""
        mock_manager = AsyncMock()
        mock_manager.handle_resume = AsyncMock(
            side_effect=ValueError("Task 'xyz' not found")
        )
        mock_get.return_value = mock_manager

        response = _make_client().post("/tasks/resume", json={"task_id": "xyz"})
        assert response.status_code == 404

    @patch("azure.ai.agentserver.core.durable._manager.get_task_manager")
    def test_conflict_returns_409(self, mock_get: AsyncMock) -> None:
        """Resuming a task that is not 'suspended' maps to 409."""
        mock_manager = AsyncMock()
        mock_manager.handle_resume = AsyncMock(
            side_effect=ValueError("Task is 'in_progress', not 'suspended'")
        )
        mock_get.return_value = mock_manager

        response = _make_client().post("/tasks/resume", json={"task_id": "task-123"})
        assert response.status_code == 409

    @patch(
        "azure.ai.agentserver.core.durable._manager.get_task_manager",
        side_effect=RuntimeError("No manager"),
    )
    def test_no_manager_returns_503(self, mock_get: AsyncMock) -> None:
        """When no manager is configured, the route answers 503."""
        response = _make_client().post("/tasks/resume", json={"task_id": "task-123"})
        assert response.status_code == 503
+"""Tests for RetryPolicy — construction, delay computation, presets, and integration.""" + +from __future__ import annotations + +import asyncio +from datetime import timedelta +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from azure.ai.agentserver.core.durable import ( + RetryPolicy, + TaskContext, + TaskFailed, + durable_task, +) + + +# --------------------------------------------------------------------------- +# Construction & validation +# --------------------------------------------------------------------------- + + +class TestRetryPolicyConstruction: + def test_default_construction(self) -> None: + p = RetryPolicy() + assert p.initial_delay == timedelta(seconds=1) + assert p.backoff_coefficient == 2.0 + assert p.max_delay == timedelta(seconds=60) + assert p.max_attempts == 3 + assert p.retry_on is None + assert p.jitter is True + + def test_custom_construction(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=5), + backoff_coefficient=3.0, + max_delay=timedelta(seconds=120), + max_attempts=10, + retry_on=(ValueError, ConnectionError), + jitter=False, + ) + assert p.initial_delay == timedelta(seconds=5) + assert p.backoff_coefficient == 3.0 + assert p.max_delay == timedelta(seconds=120) + assert p.max_attempts == 10 + assert p.retry_on == (ValueError, ConnectionError) + assert p.jitter is False + + def test_validation_initial_delay_negative(self) -> None: + with pytest.raises(ValueError, match="initial_delay must be >= 0"): + RetryPolicy(initial_delay=timedelta(seconds=-1)) + + def test_validation_backoff_coefficient_below_one(self) -> None: + with pytest.raises(ValueError, match="backoff_coefficient must be >= 1.0"): + RetryPolicy(backoff_coefficient=0.5) + + def test_validation_max_delay_below_initial(self) -> None: + with pytest.raises(ValueError, match="max_delay.*must be >= initial_delay"): + RetryPolicy( + initial_delay=timedelta(seconds=10), max_delay=timedelta(seconds=5) + ) + + def 
test_validation_max_attempts_zero(self) -> None: + with pytest.raises(ValueError, match="max_attempts must be >= 1"): + RetryPolicy(max_attempts=0) + + def test_validation_retry_on_non_exception(self) -> None: + with pytest.raises( + TypeError, match="retry_on entries must be Exception subclasses" + ): + RetryPolicy(retry_on=(str,)) # type: ignore[arg-type] + + def test_repr(self) -> None: + p = RetryPolicy(max_attempts=5) + r = repr(p) + assert "RetryPolicy" in r + assert "max_attempts=5" in r + + def test_eq(self) -> None: + a = RetryPolicy(max_attempts=3) + b = RetryPolicy(max_attempts=3) + c = RetryPolicy(max_attempts=5) + assert a == b + assert a != c + assert a != "not a policy" + + +# --------------------------------------------------------------------------- +# Delay computation +# --------------------------------------------------------------------------- + + +class TestComputeDelay: + def test_exponential(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=1), + backoff_coefficient=2.0, + max_delay=timedelta(seconds=120), + jitter=False, + ) + assert p.compute_delay(0) == 1.0 # 1 * 2^0 + assert p.compute_delay(1) == 2.0 # 1 * 2^1 + assert p.compute_delay(2) == 4.0 # 1 * 2^2 + assert p.compute_delay(3) == 8.0 # 1 * 2^3 + assert p.compute_delay(5) == 32.0 # 1 * 2^5 + + def test_fixed_delay(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=5), + backoff_coefficient=1.0, + max_delay=timedelta(seconds=5), + jitter=False, + ) + for attempt in range(5): + assert p.compute_delay(attempt) == 5.0 + + def test_capped_at_max(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=1), + backoff_coefficient=10.0, + max_delay=timedelta(seconds=30), + jitter=False, + ) + # 1 * 10^2 = 100, but capped at 30 + assert p.compute_delay(2) == 30.0 + + def test_jitter_bounds(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=10), + backoff_coefficient=1.0, + max_delay=timedelta(seconds=10), + jitter=True, + ) + for _ 
in range(100): + delay = p.compute_delay(0) + assert 7.5 <= delay <= 12.5 # 10 * [0.75, 1.25] + + def test_no_jitter_exact(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=2), + backoff_coefficient=3.0, + max_delay=timedelta(seconds=200), + jitter=False, + ) + assert p.compute_delay(0) == 2.0 # 2 * 3^0 + assert p.compute_delay(1) == 6.0 # 2 * 3^1 + assert p.compute_delay(2) == 18.0 # 2 * 3^2 + + def test_linear_preset_delay(self) -> None: + p = RetryPolicy.linear_backoff(initial_delay=timedelta(seconds=2)) + assert p.compute_delay(0) == 2.0 # 2 * (0+1) = 2 + assert p.compute_delay(1) == 4.0 # 2 * (1+1) = 4 + assert p.compute_delay(2) == 6.0 # 2 * (2+1) = 6 + assert p.compute_delay(3) == 8.0 # 2 * (3+1) = 8 + + +# --------------------------------------------------------------------------- +# should_retry +# --------------------------------------------------------------------------- + + +class TestShouldRetry: + def test_within_attempts(self) -> None: + p = RetryPolicy(max_attempts=3, jitter=False) + assert p.should_retry(0, RuntimeError("test")) is True + assert p.should_retry(1, RuntimeError("test")) is True + + def test_exhausted(self) -> None: + p = RetryPolicy(max_attempts=3, jitter=False) + assert ( + p.should_retry(2, RuntimeError("test")) is False + ) # attempt 2 is the 3rd try + assert p.should_retry(5, RuntimeError("test")) is False + + def test_matching_exception(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(ValueError,), jitter=False) + assert p.should_retry(0, ValueError("bad")) is True + + def test_non_matching_exception(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(ValueError,), jitter=False) + assert p.should_retry(0, RuntimeError("nope")) is False + + def test_none_means_all_exceptions(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=None, jitter=False) + assert p.should_retry(0, ValueError("a")) is True + assert p.should_retry(0, ConnectionError("b")) is True + assert p.should_retry(0, 
RuntimeError("c")) is True + + def test_subclass_matching(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(OSError,), jitter=False) + assert ( + p.should_retry(0, ConnectionError("net")) is True + ) # ConnectionError is OSError subclass + + +# --------------------------------------------------------------------------- +# Presets +# --------------------------------------------------------------------------- + + +class TestPresets: + def test_exponential_backoff(self) -> None: + p = RetryPolicy.exponential_backoff(max_attempts=5) + assert p.backoff_coefficient == 2.0 + assert p.max_attempts == 5 + assert p.jitter is True + assert p.initial_delay == timedelta(seconds=1) + + def test_fixed_delay(self) -> None: + p = RetryPolicy.fixed_delay(delay=timedelta(seconds=10), max_attempts=4) + assert p.backoff_coefficient == 1.0 + assert p.initial_delay == timedelta(seconds=10) + assert p.max_delay == timedelta(seconds=10) + assert p.max_attempts == 4 + assert p.jitter is False + + def test_linear_backoff(self) -> None: + p = RetryPolicy.linear_backoff( + initial_delay=timedelta(seconds=2), max_attempts=6 + ) + assert p.backoff_coefficient == 1.0 + assert p.initial_delay == timedelta(seconds=2) + assert p.max_attempts == 6 + assert p.jitter is False + + def test_no_retry(self) -> None: + p = RetryPolicy.no_retry() + assert p.max_attempts == 1 + assert p.jitter is False + assert p.should_retry(0, RuntimeError("x")) is False + + +# --------------------------------------------------------------------------- +# Integration tests (require manager) +# --------------------------------------------------------------------------- + + +class TestRetryIntegration: + """Integration tests that run tasks through the full DurableTaskManager.""" + + async def _setup_manager(self, tmp_path): + """Create a manager with local file provider pointing to tmp_path.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from 
azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_retry_success_after_failures(self, tmp_path) -> None: + """Task fails twice then succeeds on attempt 2.""" + call_log: list[int] = [] + + @durable_task(title="retry-test") + async def flaky(ctx: TaskContext[str]) -> str: + call_log.append(ctx.run_attempt) + if ctx.run_attempt < 2: + raise ConnectionError(f"fail attempt {ctx.run_attempt}") + return "success" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await flaky.run( + task_id="retry-1", + input="test", + retry=RetryPolicy.exponential_backoff(max_attempts=3), + ) + assert result.output == "success" + assert call_log == [0, 1, 2] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_retry_exhausted(self, tmp_path) -> None: + """Task always fails — retries exhaust and TaskFailed is raised.""" + + @durable_task(title="always-fail") + async def always_fail(ctx: TaskContext[str]) -> str: + raise ValueError(f"boom on attempt {ctx.run_attempt}") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(TaskFailed) as exc_info: + await always_fail.run( + task_id="exhaust-1", + input="test", + retry=RetryPolicy( + max_attempts=3, + retry_on=(ValueError,), + 
jitter=False, + ), + ) + error = exc_info.value.error + assert error["type"] == "exhausted_retries" + assert error["attempts"] == 3 + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_non_retryable_exception(self, tmp_path) -> None: + """Wrong exception type — fails immediately without retry.""" + attempts: list[int] = [] + + @durable_task(title="wrong-exc") + async def wrong_exc(ctx: TaskContext[str]) -> str: + attempts.append(ctx.run_attempt) + raise TypeError("not retryable") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with pytest.raises(TaskFailed): + await wrong_exc.run( + task_id="nonretry-1", + input="test", + retry=RetryPolicy( + max_attempts=5, + retry_on=(ValueError,), + jitter=False, + ), + ) + # Only ran once — no retries for TypeError + assert attempts == [0] + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py new file mode 100644 index 000000000000..6d8aa0c5fb09 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py @@ -0,0 +1,1856 @@ +"""End-to-end tests for durable task samples. + +Each test exercises a sample's core logic to verify the sample code +would work correctly. These tests do NOT start an HTTP server — they +invoke the durable task functions directly via the SDK API. + +This follows the constitution requirement (v1.2.0): + "Every sample MUST have a corresponding e2e test." 
+""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta +from pathlib import Path +from typing import Any +from typing_extensions import TypedDict + +import pytest + +from azure.ai.agentserver.core.durable import ( + RetryPolicy, + TaskContext, + TaskConflictError, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +# --------------------------------------------------------------------------- +# Sample 1: Streaming (durable_streaming) +# --------------------------------------------------------------------------- + + +class TestStreamingSampleE2E: + """E2E for the durable_streaming sample.""" + + @pytest.mark.asyncio + async def test_streaming_sample(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_stream_numbers") + async def stream_numbers(ctx: TaskContext[Any]) -> str: + for i in range(5): + await ctx.stream({"value": i, "message": f"Processing item {i}"}) + return f"Streamed 5 items" + + run = await stream_numbers.start(task_id=uuid.uuid4().hex, input=None) + + items = [] + async for chunk in run: + items.append(chunk) 
+ + result = await run.result() + + assert len(items) == 5 + assert items[0] == {"value": 0, "message": "Processing item 0"} + assert items[4] == {"value": 4, "message": "Processing item 4"} + assert result.output == "Streamed 5 items" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 2: Retry (durable_retry) +# --------------------------------------------------------------------------- + + +class TestRetrySampleE2E: + """E2E for the durable_retry sample.""" + + @pytest.mark.asyncio + async def test_retry_with_exponential_backoff(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + call_count = 0 + + @durable_task( + name="e2e_flaky", + retry=RetryPolicy.exponential_backoff( + max_attempts=4, + initial_delay=timedelta(milliseconds=10), + max_delay=timedelta(milliseconds=100), + ), + ) + async def flaky_task(ctx: TaskContext[Any]) -> str: + nonlocal call_count + call_count += 1 + if ctx.run_attempt < 2: + raise ConnectionError(f"Attempt {ctx.run_attempt}") + return f"Success after {ctx.run_attempt + 1} attempts" + + result = await flaky_task.run(task_id=uuid.uuid4().hex, input=None) + assert result.output == "Success after 3 attempts" + assert call_count == 3 + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_selective_retry(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="e2e_selective", + retry=RetryPolicy( + initial_delay=timedelta(milliseconds=10), + max_delay=timedelta(milliseconds=10), + backoff_coefficient=1.0, + max_attempts=3, + retry_on=(ConnectionError,), + jitter=False, + ), + ) + async def selective_task(ctx: TaskContext[Any]) -> str: + if ctx.run_attempt == 0: + raise ConnectionError("transient") + return f"Recovered on attempt {ctx.run_attempt}" + + result = await 
selective_task.run(task_id=uuid.uuid4().hex, input=None) + assert result.output == "Recovered on attempt 1" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 3: Source (durable_source) +# --------------------------------------------------------------------------- + + +class TestSourceSampleE2E: + """E2E for source auto-stamping (framework-owned, not user-overridable).""" + + @pytest.mark.asyncio + async def test_source_auto_stamped(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_with_source") + async def process_order(ctx: TaskContext[Any]) -> dict: + return {"task_id": ctx.task_id} + + result = await process_order.run( + task_id=uuid.uuid4().hex, input={"order_id": "ORD-001"} + ) + assert "task_id" in result.output + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_source_auto_stamp_fields(self, tmp_path): + """Verify auto-stamped source contains type, name, server_version.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + task_id = uuid.uuid4().hex + + @durable_task(name="e2e_source_fields") + async def with_source(ctx: TaskContext[Any]) -> str: + return "done" + + result = await with_source.run( + task_id=task_id, + input=None, + ) + assert result.output == "done" + + # Verify source was auto-stamped on the task record + task_info = await manager.provider.get(task_id) + if task_info is not None and task_info.source is not None: + assert task_info.source["type"] == "agentserver.durable_task" + assert task_info.source["name"] == "e2e_source_fields" + assert "server_version" in task_info.source + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# task.list() — scoped listing +# 
# ---------------------------------------------------------------------------
# task.list() — scoped listing
# ---------------------------------------------------------------------------


class TestListE2E:
    """E2E for ``DurableTask.list()`` — per-function scoped task listing."""

    @pytest.mark.asyncio
    async def test_list_returns_only_this_tasks_records(self, tmp_path):
        """list() scoped by function name — other tasks excluded."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_list_alpha", ephemeral=False)
            async def alpha(ctx: TaskContext[Any]) -> str:
                return "alpha_done"

            @durable_task(name="e2e_list_beta", ephemeral=False)
            async def beta(ctx: TaskContext[Any]) -> str:
                return "beta_done"

            # Create tasks for both functions.
            a1 = await alpha.run(task_id="alpha-1", input=None)
            a2 = await alpha.run(task_id="alpha-2", input=None)
            b1 = await beta.run(task_id="beta-1", input=None)
            assert a1.output == "alpha_done"
            assert a2.output == "alpha_done"
            assert b1.output == "beta_done"

            # list() on alpha is scoped to alpha's tasks only.
            alpha_ids = {t.id for t in await alpha.list()}
            assert "alpha-1" in alpha_ids
            assert "alpha-2" in alpha_ids
            assert "beta-1" not in alpha_ids

            # list() on beta is scoped to beta's tasks only.
            beta_ids = {t.id for t in await beta.list()}
            assert "beta-1" in beta_ids
            assert "alpha-1" not in beta_ids
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_list_with_status_filter(self, tmp_path):
        """list(status=...) filters by task status."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_list_status", ephemeral=False)
            async def suspendable(ctx: TaskContext[Any]) -> str:
                if ctx.entry_mode == "fresh":
                    return await ctx.suspend(reason="waiting")
                return "resumed"

            # Create a suspended task.
            handle = await suspendable.start(task_id="status-1", input=None)
            result = await handle.result()
            assert result.is_suspended

            @durable_task(name="e2e_list_status", ephemeral=False)
            async def completer(ctx: TaskContext[Any]) -> str:
                return "done"

            # Create a completed task (different id, same name).
            result2 = await completer.run(task_id="status-2", input=None)
            assert result2.output == "done"

            # Filter by "suspended" — only the suspended task appears.
            suspended_ids = {t.id for t in await suspendable.list(status="suspended")}
            assert "status-1" in suspended_ids
            assert "status-2" not in suspended_ids

            # Filter by "completed" — only the completed task appears.
            completed_ids = {t.id for t in await suspendable.list(status="completed")}
            assert "status-2" in completed_ids
            assert "status-1" not in completed_ids
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_list_empty_when_no_tasks(self, tmp_path):
        """list() returns empty when no tasks exist for this function."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_list_empty")
            async def no_tasks(ctx: TaskContext[Any]) -> str:
                return "never called"

            assert await no_tasks.list() == []
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_list_auto_stamped_tag(self, tmp_path):
        """Verify _durable_task_name tag is auto-stamped on created tasks."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            task_id = uuid.uuid4().hex

            @durable_task(name="e2e_tag_stamp", ephemeral=False)
            async def stamped(ctx: TaskContext[Any]) -> str:
                return "done"

            await stamped.run(task_id=task_id, input=None)

            # Check the raw task record for the reserved tag.
            task_info = await manager.provider.get(task_id)
            assert task_info is not None
            assert task_info.tags is not None
            assert task_info.tags.get("_durable_task_name") == "e2e_tag_stamp"
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_reserved_tag_cannot_be_overridden(self, tmp_path):
        """Developer-provided _durable_task_ tags are stripped; framework wins."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            task_id = uuid.uuid4().hex

            @durable_task(
                name="e2e_reserved_tag",
                ephemeral=False,
                tags={
                    "_durable_task_name": "evil_override",
                    "_durable_task_custom": "should_be_stripped",
                    "user_tag": "kept",
                },
            )
            async def protected(ctx: TaskContext[Any]) -> str:
                return "done"

            await protected.run(task_id=task_id, input=None)

            task_info = await manager.provider.get(task_id)
            assert task_info is not None
            assert task_info.tags is not None
            # Framework-stamped tag wins over the attempted override.
            assert task_info.tags["_durable_task_name"] == "e2e_reserved_tag"
            # Other reserved-prefix tags are stripped entirely.
            assert "_durable_task_custom" not in task_info.tags
            # Ordinary user tags survive untouched.
            assert task_info.tags["user_tag"] == "kept"
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_reserved_tag_stripped_from_callsite(self, tmp_path):
        """Call-site tags with reserved prefix are stripped."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            task_id = uuid.uuid4().hex

            @durable_task(name="e2e_callsite_tag", ephemeral=False)
            async def callsite(ctx: TaskContext[Any]) -> str:
                return "done"

            await callsite.run(
                task_id=task_id,
                input=None,
                tags={"_durable_task_name": "evil", "safe_tag": "ok"},
            )

            task_info = await manager.provider.get(task_id)
            assert task_info is not None
            assert task_info.tags is not None
            assert task_info.tags["_durable_task_name"] == "e2e_callsite_tag"
            assert task_info.tags["safe_tag"] == "ok"
        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)


# ---------------------------------------------------------------------------
# Sample 4: Multi-turn durable session (durable_multiturn)
# ---------------------------------------------------------------------------


class TestMultiturnSampleE2E:
    """E2E for the durable_multiturn sample — suspend/resume per turn."""

    @pytest.mark.asyncio
    async def test_multiturn_suspend_resume(self, tmp_path):
        """Full suspend → update-input → resume cycle across 2 turns."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        checkpoint_dir = tmp_path / "checkpoints"
        checkpoint_dir.mkdir()

        try:
            # Simple file checkpoint store (mirrors the sample's pattern).
            import json as _json

            def _save(sid, state):
                (checkpoint_dir / f"{sid}.json").write_text(_json.dumps(state))

            def _load(sid):
                p = checkpoint_dir / f"{sid}.json"
                if p.exists():
                    return _json.loads(p.read_text())
                return {"history": [], "turn_count": 0}

            @durable_task(name="e2e_session_workflow")
            async def session_workflow(ctx: TaskContext[Any]) -> dict:
                session_id = ctx.input["session_id"]
                message = ctx.input["message"]

                state = _load(session_id)

                # Explicit end of session.
                if message == "done":
                    return {"turn": state["turn_count"], "finished": True}

                state["history"].append({"role": "user", "content": message})
                state["turn_count"] += 1

                await ctx.stream({"status": "thinking", "turn": state["turn_count"]})

                reply = f"Reply #{state['turn_count']}: {message}"
                state["history"].append({"role": "assistant", "content": reply})
                _save(session_id, state)

                return await ctx.suspend(
                    reason="awaiting_user_input",
                    output={"reply": reply, "turn": state["turn_count"]},
                )

            task_id = "e2e-session-001"

            # --- Turn 1: start ---
            run1 = await session_workflow.start(
                task_id=task_id,
                input={"session_id": "s1", "message": "Hello"},
            )
            # Drain the stream.
            chunks = []
            async for chunk in run1:
                chunks.append(chunk)
            assert len(chunks) == 1
            assert chunks[0]["status"] == "thinking"

            # result() yields a TaskResult flagged as suspended.
            result1 = await run1.result()
            assert result1.is_suspended
            assert result1.output["reply"] == "Reply #1: Hello"
            assert result1.output["turn"] == 1

            # The store also reflects the suspended state.
            task = await manager._provider.get(task_id)
            assert task is not None
            assert task.status == "suspended"

            # The checkpoint file was written for turn 1.
            assert (checkpoint_dir / "s1.json").exists()
            state_after_turn1 = _json.loads((checkpoint_dir / "s1.json").read_text())
            assert state_after_turn1["turn_count"] == 1
            assert len(state_after_turn1["history"]) == 2

            # --- Turn 2: update input → resume ---
            from azure.ai.agentserver.core.durable._models import TaskPatchRequest

            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={"input": {"session_id": "s1", "message": "Continue"}},
                ),
            )
            await manager.handle_resume(task_id)

            # Poll until the task suspends again.
            for _ in range(100):
                await asyncio.sleep(0.02)
                task = await manager._provider.get(task_id)
                if task and task.status == "suspended":
                    break
            assert task.status == "suspended"
            assert task.payload["output"]["turn"] == 2
            assert "Continue" in task.payload["output"]["reply"]

            # Checkpoint advanced to turn 2.
            state_after_turn2 = _json.loads((checkpoint_dir / "s1.json").read_text())
            assert state_after_turn2["turn_count"] == 2
            assert len(state_after_turn2["history"]) == 4  # 2 user + 2 assistant

            # --- Turn 3: end session ---
            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={"input": {"session_id": "s1", "message": "done"}},
                ),
            )
            await manager.handle_resume(task_id)

            # Poll until completion.
            for _ in range(100):
                await asyncio.sleep(0.02)
                task = await manager._provider.get(task_id)
                if task and task.status == "completed":
                    break
            assert task.status == "completed"
            assert task.payload["output"]["finished"] is True

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)


# ---------------------------------------------------------------------------
# Sample 5: LangGraph multi-turn (durable_langgraph)
# ---------------------------------------------------------------------------


langgraph = pytest.importorskip("langgraph", reason="langgraph not installed")

# LangGraph needs real Annotated types at runtime (not stringified by
# ``from __future__ import annotations``). We build the graph state and
# nodes in a helper module-style block so type hints resolve correctly.

import typing  # noqa: E402

from langchain_core.messages import AIMessage as _AI, HumanMessage as _HM  # noqa: E402
from langgraph.checkpoint.sqlite import SqliteSaver as _SqliteSaver  # noqa: E402
from langgraph.graph import (
    END as _END,
    START as _START,
    StateGraph as _SG,
)  # noqa: E402
from langgraph.types import Command as _Cmd, interrupt as _interrupt  # noqa: E402


def _lg_add_messages(left: list, right: list) -> list:
    """State reducer: append new messages to the accumulated list."""
    return left + right


# Use a typing.get_type_hints-compatible class (no __future__ annotations).
_LGConvState = TypedDict(
    "_LGConvState",
    {
        "messages": typing.Annotated[list, _lg_add_messages],
        "is_complete": bool,
    },
)


def _lg_process_input(state: dict) -> dict:
    """Graph node: reply to the latest human message."""
    messages = state["messages"]
    user_msgs = [m for m in messages if isinstance(m, _HM)]
    turn = len(user_msgs)
    last = user_msgs[-1].content if user_msgs else ""
    return {"messages": [_AI(content=f"Reply #{turn}: {last}")]}


def _lg_wait_for_user(state: dict) -> dict:
    """Graph node: interrupt for the next user input; 'done' ends the session."""
    user_input: str = _interrupt({"prompt": "Next?"})
    if user_input.strip().lower() == "done":
        return {"is_complete": True}
    return {"messages": [_HM(content=user_input)], "is_complete": False}


def _lg_should_continue(state: dict) -> str:
    """Conditional edge: loop until the session is flagged complete."""
    return "end" if state.get("is_complete") else "continue"
def _build_lg_graph(checkpointer):
    """Assemble the two-node interrupt/resume conversation graph."""
    builder = _SG(_LGConvState)
    builder.add_node("process_input", _lg_process_input)
    builder.add_node("wait_for_user", _lg_wait_for_user)
    builder.add_edge(_START, "process_input")
    builder.add_edge("process_input", "wait_for_user")
    builder.add_conditional_edges(
        "wait_for_user",
        _lg_should_continue,
        {"continue": "process_input", "end": _END},
    )
    return builder.compile(checkpointer=checkpointer)


class TestLangGraphSampleE2E:
    """E2E for the durable_langgraph sample — LangGraph interrupt/resume."""

    @pytest.mark.asyncio
    async def test_langgraph_multiturn_interrupt_resume(self, tmp_path):
        """Full LangGraph interrupt → durable suspend → resume cycle."""
        from azure.ai.agentserver.core.durable._models import TaskPatchRequest

        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)

        # SqliteSaver with a temp file — mirrors the sample's persistent setup.
        import sqlite3

        db_path = tmp_path / "langgraph_checkpoints.db"
        conn = sqlite3.connect(str(db_path), check_same_thread=False)
        checkpointer = _SqliteSaver(conn)
        checkpointer.setup()
        graph = _build_lg_graph(checkpointer)

        try:

            @durable_task(name="e2e_langgraph_session")
            async def lg_session(ctx: TaskContext[Any]) -> dict:
                session_id = ctx.input["session_id"]
                message = ctx.input["message"]
                thread_config = {"configurable": {"thread_id": session_id}}

                state = await asyncio.to_thread(graph.get_state, thread_config)

                if state.next:
                    # Pending interrupt — feed the message back as the answer.
                    await asyncio.to_thread(
                        graph.invoke, _Cmd(resume=message), thread_config
                    )
                else:
                    # Fresh thread — start the conversation.
                    await asyncio.to_thread(
                        graph.invoke,
                        {"messages": [_HM(content=message)], "is_complete": False},
                        thread_config,
                    )

                state = await asyncio.to_thread(graph.get_state, thread_config)

                if state.next:
                    # Interrupted again — surface the reply and suspend durably.
                    msgs = state.values.get("messages", [])
                    ai_msgs = [m for m in msgs if isinstance(m, _AI)]
                    user_msgs = [m for m in msgs if isinstance(m, _HM)]
                    return await ctx.suspend(
                        reason="awaiting_user_input",
                        output={
                            "reply": ai_msgs[-1].content if ai_msgs else "",
                            "turn": len(user_msgs),
                        },
                    )

                msgs = state.values.get("messages", [])
                user_count = len([m for m in msgs if isinstance(m, _HM)])
                return {"finished": True, "turn_count": user_count}

            task_id = "e2e-lg-session-001"

            # --- Turn 1: start ---
            run1 = await lg_session.start(
                task_id=task_id,
                input={"session_id": "lg-s1", "message": "Hello"},
            )

            result1 = await run1.result()
            assert result1.is_suspended
            assert result1.output["reply"] == "Reply #1: Hello"
            assert result1.output["turn"] == 1

            task = await manager._provider.get(task_id)
            assert task.status == "suspended"

            # --- Turn 2: resume with new input ---
            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={
                        "input": {"session_id": "lg-s1", "message": "Tell me more"}
                    },
                ),
            )
            await manager.handle_resume(task_id)

            for _ in range(100):
                await asyncio.sleep(0.02)
                task = await manager._provider.get(task_id)
                if task and task.status == "suspended":
                    break
            assert task.status == "suspended"
            assert task.payload["output"]["turn"] == 2
            assert "Tell me more" in task.payload["output"]["reply"]

            # --- Turn 3: end session ---
            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={"input": {"session_id": "lg-s1", "message": "done"}},
                ),
            )
            await manager.handle_resume(task_id)

            for _ in range(100):
                await asyncio.sleep(0.02)
                task = await manager._provider.get(task_id)
                if task and task.status == "completed":
                    break
            assert task.status == "completed"
            assert task.payload["output"]["finished"] is True
            assert task.payload["output"]["turn_count"] == 2

        finally:
            conn.close()
            await _ManagerFixture.teardown(manager, mgr_mod)
class TestLifecycleE2E:
    """E2E for lifecycle-aware .start() and .get() — spec 003."""

    @pytest.mark.asyncio
    async def test_start_resume_via_lifecycle(self, tmp_path):
        """Calling .start() on a suspended task auto-resumes it."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        checkpoint_dir = tmp_path / "checkpoints"
        checkpoint_dir.mkdir()

        try:
            import json as _json

            # Minimal file-backed checkpoint store, keyed by session id.
            def _save(sid, state):
                (checkpoint_dir / f"{sid}.json").write_text(_json.dumps(state))

            def _load(sid):
                p = checkpoint_dir / f"{sid}.json"
                if p.exists():
                    return _json.loads(p.read_text())
                return {"history": [], "turn_count": 0}

            # Records ctx.entry_mode for each (re)entry of the task function.
            entry_modes: list[str] = []

            @durable_task(name="e2e_lifecycle_session")
            async def lifecycle_session(ctx: TaskContext[Any]) -> dict:
                entry_modes.append(ctx.entry_mode)
                session_id = ctx.input["session_id"]
                message = ctx.input["message"]
                state = _load(session_id)

                # "done" terminates the session without another suspend.
                if message == "done":
                    return {"turn": state["turn_count"], "finished": True}

                state["history"].append({"role": "user", "content": message})
                state["turn_count"] += 1
                reply = f"Reply #{state['turn_count']}: {message}"
                state["history"].append({"role": "assistant", "content": reply})
                _save(session_id, state)

                return await ctx.suspend(
                    reason="awaiting_user_input",
                    output={"reply": reply, "turn": state["turn_count"]},
                )

            task_id = "e2e-lifecycle-001"

            # Turn 1: fresh start
            run1 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "Hello"},
            )
            result1 = await run1.result()
            assert result1.is_suspended

            # Verify .get() returns suspended task
            info = await lifecycle_session.get(task_id)
            assert info is not None
            assert info.status == "suspended"

            # Turn 2: auto-resume via .start()
            run2 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "Continue"},
            )
            result2 = await run2.result()
            assert result2.is_suspended
            assert result2.output["turn"] == 2

            # Turn 3: end session via .start()
            run3 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "done"},
            )
            result3 = await run3.result()
            assert result3.output["finished"] is True

            # Verify entry modes: fresh, resumed, resumed
            assert entry_modes == ["fresh", "resumed", "resumed"]

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_start_on_completed_raises_conflict(self, tmp_path):
        """.start() on a completed non-ephemeral task raises TaskConflictError."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_completes_fast", ephemeral=False)
            async def completes_fast(ctx: TaskContext[Any]) -> str:
                return "done"

            task_id = "e2e-completed-conflict"

            await completes_fast.run(task_id=task_id, input=None)

            # Non-ephemeral completed task must refuse a second start.
            with pytest.raises(TaskConflictError):
                await completes_fast.start(task_id=task_id, input=None)

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_crash_recovery_via_lifecycle(self, tmp_path):
        """Stale in_progress task is recovered with entry_mode='recovered'."""
        # NOTE(review): despite the docstring, this test never backdates a
        # task nor asserts entry_mode == "recovered" — only the fresh path
        # is verified. The TaskPatchRequest import below is unused, and the
        # second task (task_id2) is started but its outcome is never
        # asserted. The recovery scenario still needs real coverage.
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            entry_modes: list[str] = []

            @durable_task(name="e2e_recoverable")
            async def recoverable_task(ctx: TaskContext[Any]) -> str:
                entry_modes.append(ctx.entry_mode)
                return f"entry={ctx.entry_mode}"

            task_id = "e2e-crash-recovery"

            # Create a task and manually set it to in_progress with old timestamp
            await recoverable_task.start(task_id=task_id, input="first")
            # Wait for it to run
            for _ in range(50):
                await asyncio.sleep(0.02)
                info = await recoverable_task.get(task_id)
                if info and info.status == "completed":
                    break

            # Now backdating: create another task with in_progress status
            task_id2 = "e2e-crash-recovery-2"
            from azure.ai.agentserver.core.durable._models import TaskPatchRequest

            # Start fresh then simulate a crash by backdating
            await recoverable_task.start(task_id=task_id2, input="crash-sim")
            for _ in range(50):
                await asyncio.sleep(0.02)
                info = await recoverable_task.get(task_id2)
                if info and info.status == "completed":
                    break

            # Verify first run was fresh
            assert entry_modes[0] == "fresh"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_get_returns_none_for_missing(self, tmp_path):
        """.get() returns None for a nonexistent task."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_get_missing")
            async def some_task(ctx: TaskContext[Any]) -> str:
                return "ok"

            info = await some_task.get("nonexistent-task-id")
            assert info is None

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)
{"reply": "hello", "turn": 1} + _inv_save(inv_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) + + inv_id = f"inv-{uuid.uuid4()}" + run = await inv_suspend_task.start( + task_id="inv-suspend-001", + input={"invocation_id": inv_id}, + ) + result = await run.result() + assert result.is_suspended + + # Invocation store was written inside the durable boundary + stored = _inv_load(inv_id) + assert stored is not None + assert stored["status"] == "completed" + assert stored["output"]["reply"] == "hello" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_invocation_result_written_on_complete(self, tmp_path): + """Task writes invocation result to store before returning.""" + import json as _json + + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + inv_dir = tmp_path / "invocations" + inv_dir.mkdir() + + def _inv_load(key): + p = inv_dir / f"{key}.json" + if p.exists(): + return _json.loads(p.read_text()) + return None + + def _inv_save(key, data): + (inv_dir / f"{key}.json").write_text(_json.dumps(data)) + + try: + + @durable_task(name="e2e_inv_complete") + async def inv_complete_task(ctx: TaskContext[Any]) -> dict: + inv_id = ctx.input["invocation_id"] + _inv_save(inv_id, {"status": "running"}) + result = {"finished": True, "turn_count": 3} + _inv_save(inv_id, {"status": "completed", "output": result}) + return result + + inv_id = f"inv-{uuid.uuid4()}" + result = await inv_complete_task.run( + task_id="inv-complete-001", + input={"invocation_id": inv_id}, + ) + assert result.output["finished"] is True + + stored = _inv_load(inv_id) + assert stored is not None + assert stored["status"] == "completed" + assert stored["output"]["finished"] is True + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_invocation_stored_on_conflict(self, tmp_path): + """Conflict means invocation never 
existed — nothing in the store.""" + import json as _json + + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + inv_dir = tmp_path / "invocations" + inv_dir.mkdir() + + def _inv_load(key): + p = inv_dir / f"{key}.json" + if p.exists(): + return _json.loads(p.read_text()) + return None + + def _inv_save(key, data): + (inv_dir / f"{key}.json").write_text(_json.dumps(data)) + + try: + + @durable_task(name="e2e_inv_conflict", ephemeral=False) + async def inv_conflict_task(ctx: TaskContext[Any]) -> dict: + inv_id = ctx.input["invocation_id"] + _inv_save(inv_id, {"status": "running"}) + result = {"done": True} + _inv_save(inv_id, {"status": "completed", "output": result}) + return result + + # First run completes + inv1 = f"inv-{uuid.uuid4()}" + await inv_conflict_task.run( + task_id="inv-conflict-001", + input={"invocation_id": inv1}, + ) + assert _inv_load(inv1)["status"] == "completed" + + # Second start on same completed task → conflict, no store write + inv2 = f"inv-{uuid.uuid4()}" + with pytest.raises(TaskConflictError): + await inv_conflict_task.start( + task_id="inv-conflict-001", + input={"invocation_id": inv2}, + ) + + # inv2 was never created in the store + assert _inv_load(inv2) is None + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample E2E: Claude-style steering (durable_claude) +# --------------------------------------------------------------------------- + + +class _MockTextStream: + """Simulates ``anthropic.AsyncAnthropic().messages.stream().text_stream``. + + Yields text chunks with a delay, so cancel checks between chunks + exercise the same ``async for text in stream.text_stream`` path + as the real sample. 
+ """ + + def __init__(self, chunks: list[str], delay: float = 0.1): + self._chunks = list(chunks) + self._delay = delay + + def __aiter__(self): + return self + + async def __anext__(self) -> str: + if not self._chunks: + raise StopAsyncIteration + await asyncio.sleep(self._delay) + return self._chunks.pop(0) + + +class _MockStreamCtx: + """Simulates the ``async with client.messages.stream(...) as stream:`` context.""" + + def __init__(self, chunks: list[str], delay: float = 0.1): + self.text_stream = _MockTextStream(chunks, delay) + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class TestClaudeSteeringSampleE2E: + """E2E for the durable_claude steering sample. + + Uses an async streaming mock (``_MockStreamCtx``) that mirrors the + real ``anthropic.AsyncAnthropic().messages.stream()`` async iterator, + so the cancel-between-chunks path is fully exercised. + """ + + @pytest.mark.asyncio + async def test_claude_normal_turn(self, tmp_path): + """Normal turn completes with full reply.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + conv_store: dict[str, list[dict[str, str]]] = {} + + @durable_task(name="e2e_claude_chat", steerable=True) + async def claude_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + # Load history from EXTERNAL store (not metadata) + history = list(conv_store.get(session_id, [])) + history.append({"role": "user", "content": message}) + if ctx.cancel.is_set(): + conv_store[session_id] = history + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + # Phase 2: Stream with cancel checks (mirrors async for text in stream.text_stream) + reply = "" + was_aborted = False + async with 
_MockStreamCtx([f"Echo: ", message]) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + user_turns = len([m for m in history if m["role"] == "user"]) + output = { + "invocation_id": invocation_id, + "reply": reply, + "turn": user_turns, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", output=output) + + run = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Hello", + "invocation_id": "inv-1", + }, + ) + result = await asyncio.wait_for(run.result(), timeout=5.0) + assert result.is_suspended + assert result.output["reply"] == "Echo: Hello" + assert result.output["partial"] is False + assert store["inv-1"]["status"] == "completed" + # History stored externally, not in metadata + assert len(conv_store["s1"]) == 2 # user + assistant + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_claude_steering_preserves_reply(self, tmp_path): + """Steering queues B while A is streaming. 
A's partial reply saved as superseded.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + conv_store: dict[str, list[dict[str, str]]] = {} + + @durable_task(name="e2e_claude_chat", steerable=True) + async def claude_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + history = list(conv_store.get(session_id, [])) + history.append({"role": "user", "content": message}) + if ctx.cancel.is_set(): + conv_store[session_id] = history + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + reply = "" + was_aborted = False + async with _MockStreamCtx( + ["chunk1-", "chunk2-", "chunk3"], delay=0.15 + ) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", output=output) + + run_a = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Hello", + "invocation_id": "inv-a", + }, + ) + await asyncio.sleep(0.05) + + store["inv-b"] = {"status": "queued"} + run_b = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Nevermind", + "invocation_id": "inv-b", + }, + ) + + assert store["inv-b"]["status"] == "queued" + + result_a = await 
asyncio.wait_for(run_a.result(), timeout=5.0) + assert result_a.is_superseded + + result_b = await asyncio.wait_for(run_b.result(), timeout=5.0) + assert result_b.is_suspended + assert result_b.output["reply"] == "chunk1-chunk2-chunk3" + + assert store["inv-a"]["status"] == "superseded" + assert "output" in store["inv-a"] + assert len(store["inv-a"]["output"]["reply"]) > 0 + assert store["inv-b"]["status"] == "completed" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_claude_rapid_fire_preserves_intermediate_messages(self, tmp_path): + """Rapid-fire: A→B→C. B is short-circuited but its message is preserved in external store.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + conv_store: dict[str, list[dict[str, str]]] = {} + + @durable_task(name="e2e_claude_chat", steerable=True) + async def claude_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + history = list(conv_store.get(session_id, [])) + history.append({"role": "user", "content": message}) + if ctx.cancel.is_set(): + conv_store[session_id] = history + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + reply = "" + was_aborted = False + async with _MockStreamCtx([f"Reply to {message}"], delay=0.3) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return 
await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", output=output) + + run_a = await claude_chat.start( + task_id="claude-rf", + input={"session_id": "s1", "message": "A", "invocation_id": "rf-a"}, + ) + await asyncio.sleep(0.05) + + run_b = await claude_chat.start( + task_id="claude-rf", + input={"session_id": "s1", "message": "B", "invocation_id": "rf-b"}, + ) + run_c = await claude_chat.start( + task_id="claude-rf", + input={"session_id": "s1", "message": "C", "invocation_id": "rf-c"}, + ) + + result_c = await asyncio.wait_for(run_c.result(), timeout=5.0) + assert result_c.output["reply"] == "Reply to C" + + # B was short-circuited but message preserved in external store + assert store["rf-b"]["message_preserved"] is True + assert store["rf-b"]["status"] == "cancelled" + # All user messages should be in external history + user_msgs = [m["content"] for m in conv_store["s1"] if m["role"] == "user"] + assert "B" in user_msgs # B's message was NOT lost + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample E2E: Copilot-style steering (durable_copilot) +# --------------------------------------------------------------------------- + + +class _MockCopilotSession: + """Simulates a Copilot SDK session with event-based send + abort. + + Mirrors the real pattern: ``session.on(handler)`` registers an event + listener, ``session.send(msg)`` fires ``AssistantMessageData`` events + then ``IdleData``, and ``session.abort()`` stops further events. 
+ """ + + def __init__(self, reply_chunks: list[str], delay: float = 0.1): + self._chunks = reply_chunks + self._delay = delay + self._handler: Any = None + self._aborted = False + self._idle_event = asyncio.Event() + + def on(self, handler: Any) -> None: + self._handler = handler + + async def send(self, message: str) -> None: + """Deliver reply chunks as events, then fire idle.""" + asyncio.get_event_loop().create_task(self._deliver_events()) + + async def _deliver_events(self) -> None: + for chunk in self._chunks: + if self._aborted: + break + await asyncio.sleep(self._delay) + if self._aborted: + break + if self._handler: + # Simulate AssistantMessageData event + event = type("E", (), {"data": type("D", (), {"content": chunk})()})() + self._handler(event) + if not self._aborted and self._handler: + # Simulate IdleData event + idle_data = type("IdleData", (), {})() + event = type("E", (), {"data": idle_data})() + self._handler(event) + self._idle_event.set() + + async def abort(self) -> None: + self._aborted = True + + +class TestCopilotSteeringSampleE2E: + """E2E for the durable_copilot steering sample. + + Uses ``_MockCopilotSession`` that mirrors the real Copilot SDK + event-based pattern: ``session.on(handler)`` → ``session.send()`` + → events fire → ``session.abort()`` on cancel. 
+ """ + + @pytest.mark.asyncio + async def test_copilot_normal_turn(self, tmp_path): + """Normal turn completes with full reply via event-based send.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_copilot_chat", steerable=True) + async def copilot_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + + # Event-based send (mirrors session.on + session.send) + session = _MockCopilotSession([f"Echo: {message}"]) + reply_parts: list[str] = [] + idle_event = asyncio.Event() + + def on_event(event: Any) -> None: + if hasattr(event.data, "content"): + reply_parts.append(event.data.content or "") + elif type(event.data).__name__ == "IdleData": + idle_event.set() + + session.on(on_event) + await session.send(message) + + # Wait for idle or cancel + cancel_task = asyncio.create_task(ctx.cancel.wait()) + idle_task = asyncio.create_task(idle_event.wait()) + was_aborted = False + try: + done, pending = await asyncio.wait( + {cancel_task, idle_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if cancel_task in done and idle_task not in done: + was_aborted = True + await session.abort() + finally: + for t in (cancel_task, idle_task): + if not t.done(): + t.cancel() + + reply = "".join(reply_parts) + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", 
output=output) + + run = await copilot_chat.start( + task_id="copilot-s1", + input={ + "session_id": "s1", + "message": "Explain decorators", + "invocation_id": "inv-1", + }, + ) + result = await asyncio.wait_for(run.result(), timeout=5.0) + assert result.is_suspended + assert result.output["reply"] == "Echo: Explain decorators" + assert result.output["partial"] is False + assert store["inv-1"]["status"] == "completed" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_copilot_steering_preserves_reply(self, tmp_path): + """Steering queues B while A is streaming. A's partial reply saved as superseded.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_copilot_chat", steerable=True) + async def copilot_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + + session = _MockCopilotSession(["part1-", "part2-", "part3"], delay=0.15) + reply_parts: list[str] = [] + idle_event = asyncio.Event() + + def on_event(event: Any) -> None: + if hasattr(event.data, "content"): + reply_parts.append(event.data.content or "") + elif type(event.data).__name__ == "IdleData": + idle_event.set() + + session.on(on_event) + await session.send(message) + + cancel_task = asyncio.create_task(ctx.cancel.wait()) + idle_task = asyncio.create_task(idle_event.wait()) + was_aborted = False + try: + done, pending = await asyncio.wait( + {cancel_task, idle_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if cancel_task in done and idle_task not in done: + was_aborted = True + await session.abort() + finally: + for t in (cancel_task, 
class TestLangGraphSteeringSampleE2E:
    """E2E for the durable_langgraph sample's steering path.

    Exercises the framework steering lifecycle (queued → cancel → drain →
    re-enter) using a simplified LangGraph-like pattern with checkpointing
    and invocation store writes.
    """

    @pytest.mark.asyncio
    async def test_langgraph_steering_cancels_and_resumes(self, tmp_path):
        """Steer while A is running → A cancelled → B processes from checkpoint."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            # Per-invocation status records, keyed by invocation_id.
            store: dict[str, dict[str, Any]] = {}
            # NOTE(review): `checkpoints` is populated below but never
            # asserted — either assert on it or drop it.
            checkpoints: list[str] = []

            @durable_task(name="e2e_lg_session", steerable=True)
            async def lg_session(ctx: TaskContext[dict]) -> dict[str, Any]:
                message = ctx.input["message"]
                invocation_id = ctx.input["invocation_id"]
                store[invocation_id] = {"status": "running"}

                # Cooperative cancel check before any work begins.
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                # Simulate multi-step graph processing
                await asyncio.sleep(0.1)  # Step 1: analyze
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                await asyncio.sleep(0.1)  # Step 2: generate
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                reply = f"[graph] Processed: {message}"

                # Save checkpoint
                cp_id = f"cp-{ctx.generation}"
                checkpoints.append(cp_id)
                ctx.metadata.set("stable_checkpoint_id", cp_id)

                output = {"invocation_id": invocation_id, "reply": reply}
                store[invocation_id] = {"status": "completed", "output": output}
                return await ctx.suspend(reason="awaiting_user_input", output=output)

            run_a = await lg_session.start(
                task_id="lg-s1",
                input={
                    "session_id": "s1",
                    "message": "Plan a trip",
                    "invocation_id": "lg-a",
                },
            )
            # Let A get past its initial cancel check before steering.
            await asyncio.sleep(0.05)

            # Steer while A is running
            store["lg-b"] = {"status": "queued"}
            run_b = await lg_session.start(
                task_id="lg-s1",
                input={
                    "session_id": "s1",
                    "message": "Go to Paris",
                    "invocation_id": "lg-b",
                },
            )
            # B is queued (not run) while A is still in flight.
            assert store["lg-b"]["status"] == "queued"

            result_a = await asyncio.wait_for(run_a.result(), timeout=5.0)
            assert result_a.is_superseded

            result_b = await asyncio.wait_for(run_b.result(), timeout=5.0)
            assert result_b.is_suspended
            assert result_b.output["reply"] == "[graph] Processed: Go to Paris"

            assert store["lg-a"]["status"] == "cancelled"
            assert store["lg-b"]["status"] == "completed"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_langgraph_multi_turn_then_steer(self, tmp_path):
        """Normal turn 1 → resume turn 2 → steer during turn 2 with turn 3."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            store: dict[str, dict[str, Any]] = {}

            @durable_task(name="e2e_lg_session", steerable=True, ephemeral=False)
            async def lg_session(ctx: TaskContext[dict]) -> dict[str, Any]:
                message = ctx.input["message"]
                invocation_id = ctx.input["invocation_id"]
                store[invocation_id] = {"status": "running"}

                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                # Long enough (0.3s) for the turn-3 steer to land mid-run.
                await asyncio.sleep(0.3)  # Simulated processing

                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                reply = f"[graph] {message} (gen={ctx.generation})"
                output = {"invocation_id": invocation_id, "reply": reply}
                store[invocation_id] = {"status": "completed", "output": output}
                return await ctx.suspend(reason="awaiting_user_input", output=output)

            # Turn 1: normal
            run1 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn1", "invocation_id": "mt-1"},
            )
            result1 = await asyncio.wait_for(run1.result(), timeout=5.0)
            assert result1.is_suspended
            assert store["mt-1"]["status"] == "completed"

            # Turn 2: resume
            run2 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn2", "invocation_id": "mt-2"},
            )
            await asyncio.sleep(0.05)

            # Turn 3: steer while turn 2 is running
            store["mt-3"] = {"status": "queued"}
            run3 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn3", "invocation_id": "mt-3"},
            )
            assert store["mt-3"]["status"] == "queued"

            result2 = await asyncio.wait_for(run2.result(), timeout=5.0)
            assert result2.is_superseded

            result3 = await asyncio.wait_for(run3.result(), timeout=5.0)
            assert result3.is_suspended
            assert "Turn3" in result3.output["reply"]
            assert store["mt-2"]["status"] == "cancelled"
            assert store["mt-3"]["status"] == "completed"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)
# Then three text deltas + assert chunks[1] == {"type": "text_delta", "delta": "Hello"} + assert chunks[2] == {"type": "text_delta", "delta": " "} + assert chunks[3] == {"type": "text_delta", "delta": "world"} + assert len(chunks) == 4 + assert result.output["reply"] == "Hello world" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_steering_produces_superseded_stream(self, tmp_path): + """When steering cancels a running task, the stream ends after cancel.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_sse_steer", steerable=True) + async def sse_steer(ctx: TaskContext[dict]) -> dict[str, Any]: + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + await ctx.stream({"type": "lifecycle", "status": "running"}) + + if ctx.cancel.is_set(): + store[invocation_id] = {"status": "cancelled", "reason": "steered"} + return await ctx.suspend(reason="steered") + + # Simulate slow generation that gets interrupted + reply = "" + for token in ["Slow", " ", "reply", " ", "here"]: + reply += token + await ctx.stream({"type": "text_delta", "delta": token}) + await asyncio.sleep(0.05) + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "superseded", + "partial_reply": reply, + } + return await ctx.suspend(reason="steered") + + store[invocation_id] = {"status": "completed", "reply": reply} + return await ctx.suspend( + reason="awaiting_user_input", + output={"invocation_id": invocation_id, "reply": reply}, + ) + + # Start turn 1 + run1 = await sse_steer.start( + task_id="sse-steer-1", + input={"invocation_id": "inv-s1"}, + ) + + # Collect some chunks from turn 1 + chunks1: list[dict[str, Any]] = [] + async for chunk in run1: + chunks1.append(chunk) + if len(chunks1) >= 2: + # Steer with turn 2 while turn 1 is streaming + await sse_steer.start( + task_id="sse-steer-1", + 
input={"invocation_id": "inv-s2"}, + ) + break + + # Turn 1 should have been superseded + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + assert store["inv-s1"]["status"] in ("superseded", "cancelled") + + # First chunk was lifecycle:running + assert chunks1[0] == {"type": "lifecycle", "status": "running"} + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_stream_with_invocation_store_snapshots(self, tmp_path): + """Dual-write: ctx.stream() for live SSE + store for GET snapshots.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_sse_snapshot") + async def sse_snapshot(ctx: TaskContext[dict]) -> dict[str, Any]: + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + await ctx.stream({"type": "lifecycle", "status": "running"}) + + reply = "" + for token in ["A", "B", "C"]: + reply += token + await ctx.stream({"type": "text_delta", "delta": token}) + store[invocation_id] = {"status": "streaming", "text": reply} + + store[invocation_id] = { + "status": "completed", + "reply": reply, + } + return {"invocation_id": invocation_id, "reply": reply} + + run = await sse_snapshot.start( + task_id="sse-snap-1", + input={"invocation_id": "inv-snap-1"}, + ) + + chunks: list[dict[str, Any]] = [] + async for chunk in run: + chunks.append(chunk) + + result = await asyncio.wait_for(run.result(), timeout=5.0) + + # Stream had lifecycle + 3 deltas + assert len(chunks) == 4 + assert chunks[0]["type"] == "lifecycle" + + # Store has final snapshot + assert store["inv-snap-1"]["status"] == "completed" + assert store["inv-snap-1"]["reply"] == "ABC" + assert result.output["reply"] == "ABC" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py 
b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py new file mode 100644 index 000000000000..6faed9e06f38 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py @@ -0,0 +1,140 @@ +"""Tests for source field support on TaskInfo and TaskCreateRequest.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskInfo, +) + + +class TestTaskInfoSource: + """Source field on TaskInfo.""" + + def test_default_none(self): + info = TaskInfo(id="t1", agent_name="a", session_id="s", status="pending") + assert info.source is None + + def test_set_at_construction(self): + src = {"type": "user", "origin": "cli"} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + assert info.source == src + + def test_to_dict_includes_source(self): + src = {"type": "api", "request_id": "r1"} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + d = info.to_dict() + assert d["source"] == src + + def test_to_dict_omits_none_source(self): + info = TaskInfo(id="t1", agent_name="a", session_id="s", status="pending") + d = info.to_dict() + assert "source" not in d + + def test_from_dict_with_source(self): + data = { + "id": "t1", + "agent_name": "a", + "session_id": "s", + "status": "pending", + "source": {"type": "workflow", "step": 3}, + } + info = TaskInfo.from_dict(data) + assert info.source == {"type": "workflow", "step": 3} + + def test_from_dict_without_source(self): + data = {"id": "t1", "agent_name": "a", "session_id": "s", "status": "pending"} + info = TaskInfo.from_dict(data) + assert info.source is None + + def test_round_trip(self): + src = {"origin": "test", "nested": {"a": 1}} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + restored = TaskInfo.from_dict(info.to_dict()) + 
assert restored.source == src + + +class TestTaskCreateRequestSource: + """Source field on TaskCreateRequest.""" + + def test_default_none(self): + req = TaskCreateRequest(agent_name="a", session_id="s") + assert req.source is None + + def test_set_at_construction(self): + src = {"type": "decorator"} + req = TaskCreateRequest(agent_name="a", session_id="s", source=src) + assert req.source == src + + +class TestSourceLocalProvider: + """Source persisted via LocalFileDurableTaskProvider.""" + + @pytest.mark.asyncio + async def test_source_persisted_and_retrieved(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + src = {"type": "test", "run_id": "abc123"} + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + source=src, + ) + created = await provider.create(req) + assert created.source == src + + # Re-read from disk + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source == src + + @pytest.mark.asyncio + async def test_source_none_not_persisted(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + req = TaskCreateRequest(agent_name="agent", session_id="test-session") + created = await provider.create(req) + assert created.source is None + + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source is None + + @pytest.mark.asyncio + async def test_source_immutable_after_create(self, tmp_path): + """Source must not be changeable via PATCH — TaskPatchRequest has no source field.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._models import TaskPatchRequest + + provider = 
LocalFileDurableTaskProvider(Path(str(tmp_path))) + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + source={"type": "original"}, + ) + created = await provider.create(req) + + # Patch does not touch source + await provider.update(created.id, TaskPatchRequest(tags={"k": "v"})) + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source == {"type": "original"} diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py new file mode 100644 index 000000000000..0f930ac863e3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py @@ -0,0 +1,679 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for steerable durable tasks — steering, drain, context, and recovery.""" + +import asyncio +import json +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + TaskResult, + durable_task, + EntryMode, + EtagConflict, + SteeringQueueFull, + TaskConflictError, +) + + +class TestSteering: + """Core steering functionality: append, drain, short-circuit.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await 
manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + # ------------------------------------------------------------------ + # US1: Basic steering + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steerable_start_on_in_progress_queues_input(self, tmp_path): + """start() on in_progress steerable task appends input, not raises.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work with small delay + await asyncio.sleep(0.5) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + # Start first invocation + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay for A to enter function body + await asyncio.sleep(0.1) + + # Steer while in progress — should NOT raise + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + # run2 should be a TaskRun (ack), not raise TaskConflictError + assert run2.task_id == "t1" + + # Verify queue has the input + task_info = await manager.provider.get("t1") + steering = task_info.payload.get("_steering", {}) + assert len(steering["pending_inputs"]) >= 1 + assert steering["cancel_requested"] is True + + # run1 should be superseded (A was cancelled) + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + + # run2 should complete (B runs after drain) + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + assert result2.output == {"msg": "B"} + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_non_steerable_raises_conflict(self, tmp_path): + """start() on 
in_progress non-steerable task still raises.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="regular") + async def regular(ctx: TaskContext[dict]) -> dict: + await gate.wait() + return {"msg": "done"} + + run1 = await regular.start(task_id="t1", input={"msg": "A"}) + + with pytest.raises(TaskConflictError): + await regular.start(task_id="t1", input={"msg": "B"}) + + gate.set() + await asyncio.wait_for(run1.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_steering_queue_full(self, tmp_path): + """start() raises SteeringQueueFull when queue is at capacity.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="chat", steerable=True, max_pending=2) + async def chat(ctx: TaskContext[dict]) -> dict: + await gate.wait() + return {"msg": "done"} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Fill the queue + await chat.start(task_id="t1", input={"msg": "B"}) + await chat.start(task_id="t1", input={"msg": "C"}) + + # Queue is full — should raise + with pytest.raises(SteeringQueueFull) as exc_info: + await chat.start(task_id="t1", input={"msg": "D"}) + + assert exc_info.value.max_pending == 2 + + gate.set() + await asyncio.wait_for(run1.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_superseded_result_status(self, tmp_path): + """Superseded generation's TaskRun resolves with status=superseded.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + # Always check cancel and suspend if set + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work — gives time for cancel signal + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + 
return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay to ensure task is running + await asyncio.sleep(0.1) + + # Steer + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + # run1 should be superseded + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + + # run2 should complete + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + assert result2.output == {"msg": "B"} + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # US2: Rapid-fire short-circuit + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_rapid_fire_only_last_completes(self, tmp_path): + """3 rapid-fire steers: only the last gen runs to completion.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + entries: list[tuple[str, bool]] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + entries.append((ctx.input.get("msg", "?"), ctx.cancel.is_set())) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay for A to start + await asyncio.sleep(0.05) + + # Rapid-fire B, C, D + run_b = await chat.start(task_id="t1", input={"msg": "B"}) + run_c = await chat.start(task_id="t1", input={"msg": "C"}) + run_d = await chat.start(task_id="t1", input={"msg": "D"}) + + # D should be the one that completes + result_d = await asyncio.wait_for(run_d.result(), timeout=5.0) + assert result_d.is_completed + assert result_d.output == {"msg": "D"} + + # B and C should be superseded + result_b = await asyncio.wait_for(run_b.result(), timeout=5.0) + assert result_b.is_superseded + 
+ result_c = await asyncio.wait_for(run_c.result(), timeout=5.0) + assert result_c.is_superseded + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_cancel_pre_set_when_queue_has_items(self, tmp_path): + """ctx.cancel is pre-set at function entry when queue has items.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + cancel_states: list[bool] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + cancel_states.append(ctx.cancel.is_set()) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.05) + + # Queue B and C + run_b = await chat.start(task_id="t1", input={"msg": "B"}) + run_c = await chat.start(task_id="t1", input={"msg": "C"}) + + result_c = await asyncio.wait_for(run_c.result(), timeout=5.0) + assert result_c.is_completed + + # A: cancel set by steering signal + # B: cancel pre-set (C still queued) + # C: cancel NOT set (nothing queued after C) + # cancel_states should have at least 3 entries + assert len(cancel_states) >= 3 + # The last one (C) should be False + assert cancel_states[-1] is False + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # US3: Context enrichment + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steered_context_fields(self, tmp_path): + """Steered generation has was_steered=True, previous_input set.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + contexts: list[dict[str, Any]] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + contexts.append( + { + "entry_mode": ctx.entry_mode, + "was_steered": ctx.was_steered, + "previous_input": ctx.previous_input, 
+ "generation": ctx.generation, + "msg": ctx.input.get("msg", "?"), + } + ) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work — gives time for steering signal + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.1) + + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + + # First entry: fresh, not steered + assert contexts[0]["entry_mode"] == "fresh" + assert contexts[0]["was_steered"] is False + assert contexts[0]["generation"] == 0 + + # Second entry: steered (entry_mode="resumed" with was_steered=True) + steered = [c for c in contexts if c["was_steered"] is True] + assert len(steered) >= 1 + assert steered[0]["entry_mode"] == "resumed" + assert steered[0]["generation"] > 0 + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_entry_mode_steered(self, tmp_path): + """Steered generations enter with entry_mode='resumed' and was_steered=True.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + modes: list[str] = [] + steered_flags: list[bool] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + modes.append(ctx.entry_mode) + steered_flags.append(ctx.was_steered) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": "done"} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.1) + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + await asyncio.wait_for(run2.result(), timeout=5.0) + + assert "fresh" in modes + assert "resumed" in modes + # The steered generation 
should have was_steered=True + assert True in steered_flags + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # TaskResult.is_superseded + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_task_result_is_superseded(self): + """TaskResult with status=superseded has is_superseded=True.""" + result = TaskResult(task_id="t1", status="superseded") + assert result.is_superseded is True + assert result.is_completed is False + assert result.is_suspended is False + assert result.output is None + + @pytest.mark.asyncio + async def test_task_result_completed_not_superseded(self): + """TaskResult with status=completed has is_superseded=False.""" + result = TaskResult(task_id="t1", status="completed", output=42) + assert result.is_superseded is False + assert result.is_completed is True + + # ------------------------------------------------------------------ + # Options passthrough + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steerable_via_options(self, tmp_path): + """steerable can be set via .options().""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="chat") + async def chat(ctx: TaskContext[dict]) -> dict: + await gate.wait() + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": "done"} + + steerable_chat = chat.options(steerable=True) + + run1 = await steerable_chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.05) + + # This should work because steerable=True via options + run2 = await steerable_chat.start(task_id="t1", input={"msg": "B"}) + assert run2.task_id == "t1" + + gate.set() + await asyncio.wait_for(run2.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + # 
------------------------------------------------------------------ + # DurableTaskOptions validation + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_max_pending_validation(self): + """max_pending < 1 raises ValueError at decoration time.""" + with pytest.raises(ValueError, match="max_pending"): + + @durable_task(name="bad", max_pending=0) + async def bad(ctx: TaskContext[dict]) -> dict: + return {} + + # ------------------------------------------------------------------ + # Exceptions + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_etag_conflict_exception(self): + """EtagConflict has task_id attribute.""" + exc = EtagConflict("t1", "test message") + assert exc.task_id == "t1" + assert "test message" in str(exc) + + @pytest.mark.asyncio + async def test_steering_queue_full_exception(self): + """SteeringQueueFull has task_id and max_pending attributes.""" + exc = SteeringQueueFull("t1", 10) + assert exc.task_id == "t1" + assert exc.max_pending == 10 + assert "10" in str(exc) + + # ------------------------------------------------------------------ + # Steering with function that completes (not suspends) + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steering_function_ignores_cancel_completes(self, tmp_path): + """If function ignores cancel and returns, steering still works via drain.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + call_count = 0 + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + nonlocal call_count + call_count += 1 + # Intentionally ignores ctx.cancel + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Wait for A to complete + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert 
result1.is_completed + + # For non-ephemeral completed tasks, steerable or not, raises conflict + with pytest.raises(TaskConflictError): + await chat.start(task_id="t1", input={"msg": "B"}) + + finally: + await self._teardown_manager(manager, mgr_mod) + + +class TestSteeringRecovery: + """Crash recovery for steerable tasks.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_recovery_with_drain_in_progress(self, tmp_path): + """Recovery after crash mid-drain uses active_input from steering state.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + from azure.ai.agentserver.core.durable._models import ( + TaskPatchRequest, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + + # Phase 1: Create a task and simulate crash mid-drain + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await 
manager.startup() + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.wait_for(run1.result(), timeout=5.0) + + # Simulate crash state: task is in_progress with drain_in_progress + # Reset status to in_progress and inject steering state + await provider.update( + "t1", + TaskPatchRequest( + status="in_progress", + payload={ + "_steering": { + "generation": 1, + "active_input": {"msg": "B"}, + "previous_input": {"msg": "A"}, + "pending_inputs": [], + "cancel_requested": False, + "drain_in_progress": True, + }, + }, + ), + ) + + await manager.shutdown() + mgr_mod._manager = None + + # Phase 2: Recover — new manager picks up the crashed task + manager2 = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager2 + await manager2.startup() + + inputs_seen: list[dict] = [] + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat2(ctx: TaskContext[dict]) -> dict: + inputs_seen.append(dict(ctx.input)) + return {"msg": ctx.input.get("msg", "?")} + + # Start with recovery input (doesn't matter — active_input overrides) + run2 = await chat2.start( + task_id="t1", input={"msg": "recovery"}, stale_timeout=0.0 + ) + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + + # Should have used active_input "B", not the recovery caller input + assert result2.output == {"msg": "B"} + assert inputs_seen[-1] == {"msg": "B"} + + await manager2.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_recovery_with_pending_inputs(self, tmp_path): + """Recovery with pending inputs drains them after function completes.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + from 
azure.ai.agentserver.core.durable._models import ( + TaskPatchRequest, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + + # Phase 1: Create a task normally, then simulate crash with pending + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat_setup(ctx: TaskContext[dict]) -> dict: + # Long sleep — we'll kill the manager before it completes + await asyncio.sleep(10) + return {"msg": "never"} + + run1 = await chat_setup.start(task_id="t2", input={"msg": "X"}) + await asyncio.sleep(0.1) # let it start + + # Force shutdown (simulates crash) + await manager.shutdown() + mgr_mod._manager = None + + # Patch the task to simulate crash-with-pending state + await provider.update( + "t2", + TaskPatchRequest( + status="in_progress", + payload={ + "input": {"msg": "X"}, + "_steering": { + "generation": 0, + "active_input": {"msg": "X"}, + "pending_inputs": [{"msg": "Y"}, {"msg": "Z"}], + "cancel_requested": True, + "drain_in_progress": False, + }, + }, + ), + ) + + # Phase 2: New manager recovers the task + manager2 = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager2 + await manager2.startup() + + inputs_seen: list[str] = [] + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + inputs_seen.append(ctx.input.get("msg", "?")) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run2 = await chat.start( + task_id="t2", input={"msg": "recover"}, stale_timeout=0.0 + ) + result = await asyncio.wait_for(run2.result(), timeout=5.0) + + # Should 
have drained through X (cancel set) → Y (cancel set) → Z (complete) + assert result.output == {"msg": "Z"} + assert "Z" in inputs_seen + + await manager2.shutdown() + mgr_mod._manager = None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py new file mode 100644 index 000000000000..ca77256e2913 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py @@ -0,0 +1,181 @@ +"""Tests for streaming support (ctx.stream + async-for on TaskRun).""" + +from __future__ import annotations + +import asyncio + +import pytest + +from azure.ai.agentserver.core.durable._context import TaskContext +from azure.ai.agentserver.core.durable._metadata import TaskMetadata +from azure.ai.agentserver.core.durable._run import TaskRun +from azure.ai.agentserver.core.durable._stream import QueueStreamHandler + + +def _make_ctx(stream_handler=None, **overrides): + defaults = dict( + task_id="t1", + title="test", + session_id="s1", + agent_name="a1", + tags={}, + input=None, + metadata=TaskMetadata(), + stream_handler=stream_handler, + ) + defaults.update(overrides) + return TaskContext(**defaults) + + +def _make_run(stream_handler=None, result_future=None, **overrides): + loop = asyncio.get_event_loop() + if result_future is None: + result_future = loop.create_future() + defaults = dict( + task_id="t1", + provider=None, + result_future=result_future, + metadata=TaskMetadata(), + cancel_event=asyncio.Event(), + stream_handler=stream_handler, + ) + defaults.update(overrides) + return TaskRun(**defaults) + + +class TestContextStream: + """ctx.stream() puts items via the handler.""" + + @pytest.mark.asyncio + async def test_stream_puts_item(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + await ctx.stream("hello") + assert await handler.get() == "hello" + + @pytest.mark.asyncio + async def 
test_stream_multiple_items(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + await ctx.stream(1) + await ctx.stream(2) + await ctx.stream(3) + assert await handler.get() == 1 + assert await handler.get() == 2 + assert await handler.get() == 3 + + @pytest.mark.asyncio + async def test_stream_no_handler_noop(self): + ctx = _make_ctx(stream_handler=None) + # Should not raise + await ctx.stream("ignored") + + @pytest.mark.asyncio + async def test_stream_various_types(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + items = ["text", 42, {"key": "val"}, [1, 2], None, True] + for item in items: + await ctx.stream(item) + collected = [await handler.get() for _ in range(len(items))] + assert collected == items + + +class TestTaskRunAsyncIter: + """TaskRun.__aiter__ / __anext__ consume via the stream handler.""" + + @pytest.mark.asyncio + async def test_iterate_items(self): + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.put("a") + await handler.put("b") + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["a", "b"] + + @pytest.mark.asyncio + async def test_empty_stream(self): + """close() immediately → no items.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == [] + + @pytest.mark.asyncio + async def test_no_handler_stops_immediately(self): + run = _make_run(stream_handler=None) + collected = [] + async for item in run: + collected.append(item) + assert collected == [] + + @pytest.mark.asyncio + async def test_stream_and_result(self): + """Stream items, then also await result().""" + handler = QueueStreamHandler() + loop = asyncio.get_event_loop() + fut: asyncio.Future = loop.create_future() + run = _make_run(stream_handler=handler, result_future=fut) + + await 
handler.put("chunk1") + await handler.put("chunk2") + await handler.close() + fut.set_result("final") + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["chunk1", "chunk2"] + result = await run.result() + assert result == "final" # Unit test uses raw future, not manager pipeline + + @pytest.mark.asyncio + async def test_concurrent_producer_consumer(self): + """Producer streams while consumer iterates.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + + async def produce(): + for i in range(5): + await handler.put(i) + await asyncio.sleep(0.01) + await handler.close() + + collected = [] + + async def consume(): + async for item in run: + collected.append(item) + + await asyncio.gather(produce(), consume()) + assert collected == [0, 1, 2, 3, 4] + + +class TestStreamingErrorCases: + """Streaming under error/suspend/cancel conditions.""" + + @pytest.mark.asyncio + async def test_close_terminates_iteration(self): + """close() terminates iteration cleanly.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.put("partial") + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["partial"] + + @pytest.mark.asyncio + async def test_aiter_returns_self(self): + run = _make_run(stream_handler=QueueStreamHandler()) + assert run.__aiter__() is run diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py new file mode 100644 index 000000000000..960311ebb6dc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py @@ -0,0 +1,130 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for the TaskResult wrapper class.""" + +import pytest + +from azure.ai.agentserver.core.durable import TaskResult + + +class TestTaskResult: + """Tests for TaskResult construction and properties.""" + + def test_completed_result(self): + """A completed result has is_completed=True, is_suspended=False.""" + r = TaskResult(task_id="t1", output="hello", status="completed") + assert r.is_completed + assert not r.is_suspended + assert r.output == "hello" + assert r.task_id == "t1" + assert r.suspension_reason is None + + def test_suspended_result(self): + """A suspended result has is_suspended=True, is_completed=False.""" + r = TaskResult( + task_id="t2", + output={"turn": 1}, + status="suspended", + suspension_reason="awaiting_user", + ) + assert r.is_suspended + assert not r.is_completed + assert r.output == {"turn": 1} + assert r.suspension_reason == "awaiting_user" + + def test_suspended_without_output(self): + """A suspended result can have no output (output=None).""" + r = TaskResult(task_id="t3", status="suspended") + assert r.is_suspended + assert r.output is None + assert r.suspension_reason is None + + def test_completed_with_none_output(self): + """A completed result can return None explicitly.""" + r = TaskResult(task_id="t4", output=None, status="completed") + assert r.is_completed + assert r.output is None + + def test_completed_with_complex_output(self): + """TaskResult works with dict outputs.""" + data = {"items": [1, 2, 3], "total": 3} + r = TaskResult(task_id="t5", output=data, status="completed") + assert r.output["items"] == [1, 2, 3] + assert r.output["total"] == 3 + + def test_repr_completed(self): + """__repr__ shows status and output for completed results.""" + r = TaskResult(task_id="t6", output="done", status="completed") + s = repr(r) + assert "t6" in s + assert "completed" in s + assert "done" in s + assert "suspension_reason" not in s + + def 
test_repr_suspended(self): + """__repr__ includes suspension_reason when present.""" + r = TaskResult( + task_id="t7", output=None, status="suspended", suspension_reason="waiting" + ) + s = repr(r) + assert "suspended" in s + assert "waiting" in s + + def test_repr_truncates_long_output(self): + """__repr__ truncates output longer than 60 chars.""" + long_val = "x" * 100 + r = TaskResult(task_id="t8", output=long_val, status="completed") + s = repr(r) + assert "..." in s + assert len(s) < 200 + + +class TestNestedTaskResultGuard: + """Test that returning TaskResult from a task function raises TypeError.""" + + @pytest.mark.asyncio + async def test_returning_taskresult_raises_typeerror(self, tmp_path): + """Task function that returns TaskResult directly gets TypeError.""" + import uuid + from pathlib import Path + from azure.ai.agentserver.core.durable import TaskContext, durable_task + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import DurableTaskManager + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test", + "session_id": "test", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + + try: + from typing import Any + from azure.ai.agentserver.core.durable import TaskContext + + @durable_task(name="bad_return") + async def bad_task(ctx: TaskContext[Any]) -> Any: + return TaskResult( + task_id=ctx.task_id, output="data", status="completed" + ) + + from azure.ai.agentserver.core.durable._exceptions import TaskFailed + + with pytest.raises(TaskFailed): + await bad_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await manager.shutdown() + mgr_mod._manager = None diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py index be194e6ec0fd..f1de9e022dea 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py @@ -10,25 +10,33 @@ class TestAgentConfigIsHosted: """Tests for AgentConfig.is_hosted snapshotting behavior.""" - def test_is_hosted_false_when_env_var_absent(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_false_when_env_var_absent( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is False when FOUNDRY_HOSTING_ENVIRONMENT is not set.""" monkeypatch.delenv("FOUNDRY_HOSTING_ENVIRONMENT", raising=False) config = AgentConfig.from_env() assert config.is_hosted is False - def test_is_hosted_false_when_env_var_empty(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_false_when_env_var_empty( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is False when FOUNDRY_HOSTING_ENVIRONMENT is set to an empty string.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "") config = AgentConfig.from_env() assert config.is_hosted is False - def test_is_hosted_true_when_env_var_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_true_when_env_var_set( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is True when FOUNDRY_HOSTING_ENVIRONMENT is set to a non-empty value.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "production") config = AgentConfig.from_env() assert config.is_hosted is True - def test_is_hosted_snapshotted_at_creation(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_snapshotted_at_creation( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted reflects the env var value at creation time, not at access time.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "production") config = AgentConfig.from_env() diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py index 7c538c0ddc31..c15bccfd85b0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py @@ -11,7 +11,10 @@ import pytest from azure.ai.agentserver.core import AgentServerHost -from azure.ai.agentserver.core._config import resolve_graceful_shutdown_timeout, _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT +from azure.ai.agentserver.core._config import ( + resolve_graceful_shutdown_timeout, + _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT, +) # ------------------------------------------------------------------ # @@ -26,7 +29,10 @@ def test_explicit_wins(self) -> None: assert resolve_graceful_shutdown_timeout(10) == 10 def test_default(self) -> None: - assert resolve_graceful_shutdown_timeout(None) == _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + assert ( + resolve_graceful_shutdown_timeout(None) + == _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + ) def test_non_int_explicit_raises(self) -> None: with pytest.raises(ValueError, match="expected an integer"): @@ -193,7 +199,10 @@ async def send(message): await agent(scope, receive, send) # The error should be logged - assert any("on_shutdown" in r.message.lower() or "error" in r.message.lower() for r in caplog.records) + assert any( + "on_shutdown" in r.message.lower() or "error" in r.message.lower() + for r in caplog.records + ) # Server should still complete shutdown assert any(m["type"] == "lifespan.shutdown.complete" for m in sent_messages) @@ -204,7 +213,9 @@ async def send(message): @pytest.mark.asyncio -async def test_slow_shutdown_cancelled_with_warning(caplog: pytest.LogCaptureFixture) -> None: +async def test_slow_shutdown_cancelled_with_warning( + caplog: pytest.LogCaptureFixture, +) -> None: """A shutdown handler exceeding the timeout is cancelled and a warning is logged.""" agent = 
AgentServerHost(graceful_shutdown_timeout=1) @@ -230,7 +241,10 @@ async def send(message): with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): await agent(scope, receive, send) - assert any("did not complete" in r.message.lower() or "timeout" in r.message.lower() for r in caplog.records) + assert any( + "did not complete" in r.message.lower() or "timeout" in r.message.lower() + for r in caplog.records + ) assert any(m["type"] == "lifespan.shutdown.complete" for m in sent_messages) @@ -341,7 +355,9 @@ def fake_asyncio_run(coroutine): finally: signal.signal(signal.SIGTERM, original) - def test_sigterm_handler_logs_and_re_raises(self, caplog: pytest.LogCaptureFixture) -> None: + def test_sigterm_handler_logs_and_re_raises( + self, caplog: pytest.LogCaptureFixture + ) -> None: """The installed SIGTERM handler logs then re-raises via os.kill.""" original = signal.getsignal(signal.SIGTERM) try: diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py index a95e4980d530..9b2c05287882 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py @@ -16,4 +16,5 @@ def test_log_level_preserved_across_imports() -> None: lib_logger = logging.getLogger("azure.ai.agentserver") lib_logger.setLevel(logging.ERROR) from azure.ai.agentserver.core import _base # noqa: F401 + assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py index 85e28c1bf15e..4ea165ee3f84 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py @@ -12,7 +12,6 @@ from azure.ai.agentserver.core._config import resolve_port - # ------------------------------------------------------------------ # 
# Port resolution # ------------------------------------------------------------------ # diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py index d6af2accb52c..1802f52b0644 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py @@ -20,7 +20,9 @@ def test_none_like_empty_returns_not_set(self) -> None: assert _mask_uri(" ") == _NOT_SET def test_https_uri_strips_path_and_query(self) -> None: - result = _mask_uri("https://myproject.azure.com/subscriptions/abc?api-version=2024") + result = _mask_uri( + "https://myproject.azure.com/subscriptions/abc?api-version=2024" + ) assert result == "https://myproject.azure.com" def test_http_uri_with_port(self) -> None: @@ -68,7 +70,9 @@ def _clean_env(self, monkeypatch: pytest.MonkeyPatch) -> None: @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_platform_environment(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_startup_logs_platform_environment( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits platform environment log line.""" from azure.ai.agentserver.core import AgentServerHost @@ -78,7 +82,9 @@ async def test_startup_logs_platform_environment(self, caplog: pytest.LogCapture async with app.router.lifespan_context(app): pass - platform_logs = [r for r in caplog.records if "Platform environment" in r.message] + platform_logs = [ + r for r in caplog.records if "Platform environment" in r.message + ] assert len(platform_logs) == 1 msg = platform_logs[0].message assert "is_hosted=False" in msg @@ -86,7 +92,9 @@ async def test_startup_logs_platform_environment(self, caplog: pytest.LogCapture @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_connectivity(self, caplog: pytest.LogCaptureFixture) 
-> None: + async def test_startup_logs_connectivity( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits connectivity log line with masked URIs.""" from azure.ai.agentserver.core import AgentServerHost @@ -104,7 +112,9 @@ async def test_startup_logs_connectivity(self, caplog: pytest.LogCaptureFixture) @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_host_options(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_startup_logs_host_options( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits host options log line.""" from azure.ai.agentserver.core import AgentServerHost @@ -125,7 +135,9 @@ async def test_startup_masks_project_endpoint( self, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: """Project endpoint URI is masked to scheme://host only.""" - monkeypatch.setenv("FOUNDRY_PROJECT_ENDPOINT", "https://myproject.azure.com/sub/123?key=secret") + monkeypatch.setenv( + "FOUNDRY_PROJECT_ENDPOINT", "https://myproject.azure.com/sub/123?key=secret" + ) monkeypatch.delenv("FOUNDRY_HOSTING_ENVIRONMENT", raising=False) monkeypatch.delenv("PORT", raising=False) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py index 2b3531b552d1..c1eb8a81ae3f 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py @@ -7,7 +7,11 @@ from opentelemetry import baggage as _otel_baggage, context as _otel_context from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult +from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, +) from opentelemetry.sdk.resources import Resource from azure.ai.agentserver.core import AgentServerHost @@ -34,6 +38,8 
@@ def shutdown(self): def force_flush(self, timeout_millis=30000): return True + + # ------------------------------------------------------------------ # # Tracing enabled / disabled # ------------------------------------------------------------------ # @@ -53,14 +59,24 @@ def test_observability_always_called(self) -> None: mock_configure.assert_called_once() def test_observability_receives_appinsights_env_var(self) -> None: - with mock.patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): + with mock.patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): mock_configure = mock.MagicMock() AgentServerHost(configure_observability=mock_configure) mock_configure.assert_called_once() - assert mock_configure.call_args[1]["connection_string"] == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + assert ( + mock_configure.call_args[1]["connection_string"] + == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) def test_observability_receives_otlp_env_var(self) -> None: - with mock.patch.dict(os.environ, {"OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"}): + with mock.patch.dict( + os.environ, {"OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"} + ): mock_configure = mock.MagicMock() AgentServerHost(configure_observability=mock_configure) mock_configure.assert_called_once() @@ -78,7 +94,12 @@ def test_observability_receives_constructor_connection_string(self) -> None: def test_observability_disabled_when_none(self) -> None: """Passing configure_observability=None disables all SDK-managed observability.""" - with mock.patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): + with mock.patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + 
): # Should not raise even with App Insights configured AgentServerHost(configure_observability=None) @@ -92,14 +113,19 @@ class TestAppInsightsConnectionString: """Tests for resolve_appinsights_connection_string().""" def test_explicit_wins(self) -> None: - assert resolve_appinsights_connection_string("InstrumentationKey=abc") == "InstrumentationKey=abc" + assert ( + resolve_appinsights_connection_string("InstrumentationKey=abc") + == "InstrumentationKey=abc" + ) def test_env_var(self) -> None: with mock.patch.dict( os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=env"}, ): - assert resolve_appinsights_connection_string(None) == "InstrumentationKey=env" + assert ( + resolve_appinsights_connection_string(None) == "InstrumentationKey=env" + ) def test_none_when_unset(self) -> None: env = os.environ.copy() @@ -112,7 +138,9 @@ def test_explicit_overrides_env_var(self) -> None: os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=env"}, ): - result = resolve_appinsights_connection_string("InstrumentationKey=explicit") + result = resolve_appinsights_connection_string( + "InstrumentationKey=explicit" + ) assert result == "InstrumentationKey=explicit" @@ -125,18 +153,29 @@ class TestSetupDistroExport: """Verify _configure_tracing calls the distro with the right args.""" def test_distro_called_when_conn_str_provided(self) -> None: - with mock.patch("azure.ai.agentserver.core._tracing._setup_distro_export") as mock_distro: + with mock.patch( + "azure.ai.agentserver.core._tracing._setup_distro_export" + ) as mock_distro: from azure.ai.agentserver.core import _tracing - _tracing._configure_tracing(connection_string="InstrumentationKey=00000000-0000-0000-0000-000000000000") + + _tracing._configure_tracing( + connection_string="InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) mock_distro.assert_called_once() kwargs = mock_distro.call_args[1] - assert kwargs["connection_string"] == 
"InstrumentationKey=00000000-0000-0000-0000-000000000000" + assert ( + kwargs["connection_string"] + == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) assert len(kwargs["span_processors"]) >= 1 assert len(kwargs["log_record_processors"]) >= 1 def test_distro_called_without_conn_str(self) -> None: - with mock.patch("azure.ai.agentserver.core._tracing._setup_distro_export") as mock_distro: + with mock.patch( + "azure.ai.agentserver.core._tracing._setup_distro_export" + ) as mock_distro: from azure.ai.agentserver.core import _tracing + _tracing._configure_tracing(connection_string=None) mock_distro.assert_called_once() kwargs = mock_distro.call_args[1] @@ -187,8 +226,10 @@ def _create_provider(processor): def test_agent_attrs_present_on_exported_span(self) -> None: proc = _FoundryEnrichmentSpanProcessor( - agent_name="my-agent", agent_version="1.0", - agent_id="my-agent:1.0", project_id="proj-123", + agent_name="my-agent", + agent_version="1.0", + agent_id="my-agent:1.0", + project_id="proj-123", ) provider, collector = self._create_provider(proc) tracer = provider.get_tracer("test") @@ -205,8 +246,10 @@ def test_agent_attrs_present_on_exported_span(self) -> None: def test_agent_attrs_survive_framework_overwrite(self) -> None: """A framework setting agent attrs mid-span must not win.""" proc = _FoundryEnrichmentSpanProcessor( - agent_name="my-agent", agent_version="1.0", - agent_id="my-agent:1.0", project_id="proj-123", + agent_name="my-agent", + agent_version="1.0", + agent_id="my-agent:1.0", + project_id="proj-123", ) provider, collector = self._create_provider(proc) tracer = provider.get_tracer("test") @@ -221,8 +264,10 @@ def test_agent_attrs_survive_framework_overwrite(self) -> None: def test_none_fields_are_skipped(self) -> None: proc = _FoundryEnrichmentSpanProcessor( - agent_name=None, agent_version=None, - agent_id=None, project_id=None, + agent_name=None, + agent_version=None, + agent_id=None, + project_id=None, ) provider, collector = 
self._create_provider(proc) tracer = provider.get_tracer("test") @@ -239,7 +284,9 @@ def test_none_fields_are_skipped(self) -> None: def test_no_crash_when_span_lacks_attributes(self) -> None: """If the SDK changes internals, _on_ending must not raise.""" proc = _FoundryEnrichmentSpanProcessor( - agent_name="a", agent_version="1", agent_id="a:1", + agent_name="a", + agent_version="1", + agent_id="a:1", ) fake_span = object() # no _attributes at all proc._on_ending(fake_span) # should not raise @@ -253,7 +300,8 @@ def test_session_id_from_baggage(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) with tracer.start_as_current_span("span", context=ctx): pass @@ -269,7 +317,8 @@ def test_conversation_id_from_baggage(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-123", + "azure.ai.agentserver.conversation_id", + "conv-123", ) with tracer.start_as_current_span("span", context=ctx): pass @@ -285,10 +334,13 @@ def test_both_session_and_conversation_set_independently(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-123", context=ctx, + "azure.ai.agentserver.conversation_id", + "conv-123", + context=ctx, ) with tracer.start_as_current_span("span", context=ctx): pass @@ -317,10 +369,13 @@ def test_baggage_ids_propagate_to_child_spans(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-789", context=ctx, + 
"azure.ai.agentserver.conversation_id", + "conv-789", + context=ctx, ) token = _otel_context.attach(ctx) try: @@ -364,6 +419,3 @@ def test_agent_version_default_empty(self) -> None: env.pop("FOUNDRY_AGENT_VERSION", None) with mock.patch.dict(os.environ, env, clear=True): assert resolve_agent_version() == "" - - - diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing_e2e.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing_e2e.py index d1c428e2bfa3..b2638d214a3e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing_e2e.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing_e2e.py @@ -52,7 +52,9 @@ def _flush_provider(): provider.force_flush() -def _poll_appinsights(logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT): +def _poll_appinsights( + logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT +): """Poll Application Insights until the KQL query returns ≥1 row or timeout. Returns the list of rows from the first table, or an empty list on timeout. @@ -74,6 +76,7 @@ def _poll_appinsights(logs_client, resource_id, query, *, timeout=_APPINSIGHTS_P # Minimal echo app factories using core's AgentServerHost + request_span() # --------------------------------------------------------------------------- + def _make_echo_app(): """Create an AgentServerHost with a POST /echo route that creates a traced span. 
@@ -104,6 +107,7 @@ async def stream_handler(request: Request) -> StreamingResponse: req_id = str(uuid.uuid4()) request_ids.append(req_id) with app.request_span(dict(request.headers), req_id, "invoke_agent"): + async def generate(): for chunk in [b"chunk1\n", b"chunk2\n", b"chunk3\n"]: yield chunk @@ -151,7 +155,9 @@ async def fail_handler(request: Request) -> Response: req_id = str(uuid.uuid4()) request_ids.append(req_id) try: - with app.request_span(dict(request.headers), req_id, "invoke_agent") as span: + with app.request_span( + dict(request.headers), req_id, "invoke_agent" + ) as span: raise ValueError("e2e error test") except ValueError: span.set_status(trace.StatusCode.ERROR, "e2e error test") @@ -169,6 +175,7 @@ async def fail_handler(request: Request) -> Response: # E2E: Verify spans are ingested into Application Insights # --------------------------------------------------------------------------- + class TestAppInsightsIngestionE2E: """Query Application Insights ``requests`` table to confirm spans were actually ingested, correlating via gen_ai.response.id.""" @@ -219,9 +226,9 @@ def test_streaming_span_in_appinsights( "| take 1" ) rows = _poll_appinsights(logs_query_client, appinsights_resource_id, query) - assert len(rows) > 0, ( - f"Streaming span with response_id={req_id} not found in App Insights" - ) + assert ( + len(rows) > 0 + ), f"Streaming span with response_id={req_id} not found in App Insights" def test_error_span_in_appinsights( self, @@ -243,9 +250,9 @@ def test_error_span_in_appinsights( "| take 1" ) rows = _poll_appinsights(logs_query_client, appinsights_resource_id, query) - assert len(rows) > 0, ( - f"Error span with response_id={req_id} not found in App Insights" - ) + assert ( + len(rows) > 0 + ), f"Error span with response_id={req_id} not found in App Insights" def test_genai_attributes_in_appinsights( self, @@ -305,14 +312,16 @@ def test_span_parenting_in_appinsights( "| project id, name, operation_Id, operation_ParentId " "| take 1" 
) - child_rows = _poll_appinsights(logs_query_client, appinsights_resource_id, child_query) + child_rows = _poll_appinsights( + logs_query_client, appinsights_resource_id, child_query + ) assert len(child_rows) > 0, ( f"Child framework_child span (id={child_span_id}) not found in " f"dependencies table after {_APPINSIGHTS_POLL_TIMEOUT}s" ) - operation_id = child_rows[0][2] # operation_Id column - child_parent_id = child_rows[0][3] # operation_ParentId column + operation_id = child_rows[0][2] # operation_Id column + child_parent_id = child_rows[0][3] # operation_ParentId column # Step 2: Find the parent span in the requests table using the child's operation_ParentId. parent_query = ( @@ -322,12 +331,14 @@ def test_span_parenting_in_appinsights( "| project id, name, operation_Id " "| take 1" ) - parent_rows = _poll_appinsights(logs_query_client, appinsights_resource_id, parent_query) + parent_rows = _poll_appinsights( + logs_query_client, appinsights_resource_id, parent_query + ) assert len(parent_rows) > 0, ( f"Parent span (id={child_parent_id}) referenced by child's " f"operation_ParentId not found in requests table" ) - assert parent_rows[0][1] == "invoke_agent", ( - f"Expected parent span name 'invoke_agent', got '{parent_rows[0][1]}'" - ) + assert ( + parent_rows[0][1] == "invoke_agent" + ), f"Expected parent span name 'invoke_agent', got '{parent_rows[0][1]}'" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md index 6ead3c39d58d..40c49f54d5ff 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md @@ -4,12 +4,16 @@ ### Features Added +- **Durable invocation samples** — Added `durable_langgraph` and `durable_multiturn` sample applications demonstrating crash-resilient long-running agents using `@durable_task` with the invocations protocol. 
+ ### Breaking Changes ### Bugs Fixed ### Other Changes +- Bumped minimum `azure-ai-agentserver-core` dependency to `>=2.0.0b4`. + ## 1.0.0b3 (2026-04-22) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py index bf3120974fa0..aab98236653f 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py @@ -40,8 +40,12 @@ # Context variables for structured logging — concurrency-safe alternative to logger filters. -_invocation_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("invocation_id", default="") -_session_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("session_id", default="") +_invocation_id_var: contextvars.ContextVar[str] = contextvars.ContextVar( + "invocation_id", default="" +) +_session_id_var: contextvars.ContextVar[str] = contextvars.ContextVar( + "session_id", default="" +) class _InvocationLogFilter(logging.Filter): @@ -254,12 +258,16 @@ async def _dispatch_invoke(self, request: Request) -> Response: async def _dispatch_get_invocation(self, request: Request) -> Response: if self._get_invocation_fn is not None: return await self._get_invocation_fn(request) - return create_error_response("not_found", "get_invocation not implemented", status_code=404) + return create_error_response( + "not_found", "get_invocation not implemented", status_code=404 + ) async def _dispatch_cancel_invocation(self, request: Request) -> Response: if self._cancel_invocation_fn is not None: return await self._cancel_invocation_fn(request) - return create_error_response("not_found", "cancel_invocation not implemented", status_code=404) + return create_error_response( + "not_found", "cancel_invocation not implemented", status_code=404 
+ ) def get_openapi_spec(self) -> Optional[dict[str, Any]]: """Return the stored OpenAPI spec, or None.""" @@ -277,7 +285,9 @@ def _safe_set_attrs(span: Any, attrs: dict[str, str]) -> None: for key, value in attrs.items(): span.set_attribute(key, value) except Exception: # pylint: disable=broad-exception-caught - logger.debug("Failed to set span attributes: %s", list(attrs.keys()), exc_info=True) + logger.debug( + "Failed to set span attributes: %s", list(attrs.keys()), exc_info=True + ) # ------------------------------------------------------------------ # Streaming response helpers @@ -330,49 +340,67 @@ async def _iter_with_context(): # type: ignore[return-value] # Endpoint handlers # ------------------------------------------------------------------ - async def _get_openapi_spec_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + async def _get_openapi_spec_endpoint( + self, request: Request + ) -> Response: # pylint: disable=unused-argument spec = self.get_openapi_spec() if spec is None: - return create_error_response("not_found", "No OpenAPI spec registered", status_code=404) + return create_error_response( + "not_found", "No OpenAPI spec registered", status_code=404 + ) return JSONResponse(spec) async def _create_invocation_endpoint(self, request: Request) -> Response: generated_id = str(uuid.uuid4()) - raw_invocation_id = request.headers.get(InvocationConstants.INVOCATION_ID_HEADER) or "" + raw_invocation_id = ( + request.headers.get(InvocationConstants.INVOCATION_ID_HEADER) or "" + ) invocation_id = _sanitize_id(raw_invocation_id, generated_id) request.state.invocation_id = invocation_id # Session ID: query param overrides env var / generated UUID raw_session_id = ( - request.query_params.get("agent_session_id") - or self.config.session_id - or "" + request.query_params.get("agent_session_id") or self.config.session_id or "" ) session_id = _sanitize_id(raw_session_id, str(uuid.uuid4())) request.state.session_id = session_id # 
Platform isolation headers — expose to handlers - request.state.user_isolation_key = request.headers.get("x-agent-user-isolation-key", "") - request.state.chat_isolation_key = request.headers.get("x-agent-chat-isolation-key", "") + request.state.user_isolation_key = request.headers.get( + "x-agent-user-isolation-key", "" + ) + request.state.chat_isolation_key = request.headers.get( + "x-agent-chat-isolation-key", "" + ) with self.request_span( - request.headers, invocation_id, "invoke_agent", - operation_name="invoke_agent", session_id=session_id, + request.headers, + invocation_id, + "invoke_agent", + operation_name="invoke_agent", + session_id=session_id, end_on_exit=False, ) as otel_span: - self._safe_set_attrs(otel_span, { - InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id, - InvocationConstants.ATTR_SPAN_SESSION_ID: session_id, - }) + self._safe_set_attrs( + otel_span, + { + InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id, + InvocationConstants.ATTR_SPAN_SESSION_ID: session_id, + }, + ) # Propagate invocation/session IDs as W3C baggage so downstream # services receive them automatically via the baggage header. 
ctx = _otel_context.get_current() ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.invocation_id", invocation_id, context=ctx, + "azure.ai.agentserver.invocation_id", + invocation_id, + context=ctx, ) ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", session_id, context=ctx, + "azure.ai.agentserver.session_id", + session_id, + context=ctx, ) baggage_token = _otel_context.attach(ctx) @@ -382,13 +410,18 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: session_token = _session_id_var.set(session_id) try: response = await self._dispatch_invoke(request) - response.headers[InvocationConstants.INVOCATION_ID_HEADER] = invocation_id + response.headers[InvocationConstants.INVOCATION_ID_HEADER] = ( + invocation_id + ) response.headers[InvocationConstants.SESSION_ID_HEADER] = session_id except NotImplementedError as exc: - self._safe_set_attrs(otel_span, { - InvocationConstants.ATTR_SPAN_ERROR_CODE: "not_implemented", - InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), - }) + self._safe_set_attrs( + otel_span, + { + InvocationConstants.ATTR_SPAN_ERROR_CODE: "not_implemented", + InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), + }, + ) end_span(otel_span, exc=exc) logger.error("Invocation %s failed: %s", invocation_id, exc) return create_error_response( @@ -401,12 +434,20 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: }, ) except Exception as exc: # pylint: disable=broad-exception-caught - self._safe_set_attrs(otel_span, { - InvocationConstants.ATTR_SPAN_ERROR_CODE: "internal_error", - InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), - }) + self._safe_set_attrs( + otel_span, + { + InvocationConstants.ATTR_SPAN_ERROR_CODE: "internal_error", + InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), + }, + ) end_span(otel_span, exc=exc) - logger.error("Error processing invocation %s: %s", invocation_id, exc, exc_info=True) + logger.error( + "Error processing invocation %s: %s", 
+ invocation_id, + exc, + exc_info=True, + ) return create_error_response( "internal_error", "Internal server error", @@ -444,27 +485,44 @@ async def _traced_invocation_endpoint( session_id = _sanitize_id(raw_session_id, "") if raw_session_id else "" with self.request_span( - request.headers, invocation_id, span_operation, - operation_name=span_operation, session_id=session_id, + request.headers, + invocation_id, + span_operation, + operation_name=span_operation, + session_id=session_id, ) as _otel_span: - self._safe_set_attrs(_otel_span, { - InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id, - InvocationConstants.ATTR_SPAN_SESSION_ID: session_id, - }) + self._safe_set_attrs( + _otel_span, + { + InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id, + InvocationConstants.ATTR_SPAN_SESSION_ID: session_id, + }, + ) _ensure_log_filter() inv_token = _invocation_id_var.set(invocation_id) session_token = _session_id_var.set(session_id) try: response = await dispatch(request) - response.headers[InvocationConstants.INVOCATION_ID_HEADER] = invocation_id + response.headers[InvocationConstants.INVOCATION_ID_HEADER] = ( + invocation_id + ) return response except Exception as exc: # pylint: disable=broad-exception-caught - self._safe_set_attrs(_otel_span, { - InvocationConstants.ATTR_SPAN_ERROR_CODE: "internal_error", - InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), - }) + self._safe_set_attrs( + _otel_span, + { + InvocationConstants.ATTR_SPAN_ERROR_CODE: "internal_error", + InvocationConstants.ATTR_SPAN_ERROR_MESSAGE: str(exc), + }, + ) record_error(_otel_span, exc) - logger.error("Error in %s %s: %s", span_operation, invocation_id, exc, exc_info=True) + logger.error( + "Error in %s %s: %s", + span_operation, + invocation_id, + exc, + exc_info=True, + ) return create_error_response( "internal_error", "Internal server error", diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml 
b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml index 7657fdf1df67..b70d8ea30022 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ keywords = ["azure", "azure sdk", "agent", "agentserver", "invocations"] dependencies = [ - "azure-ai-agentserver-core>=2.0.0b3", + "azure-ai-agentserver-core>=2.0.0b4", ] [dependency-groups] diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py index cde877039960..227bda4ca2f5 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py @@ -38,6 +38,7 @@ curl -X POST http://localhost:8088/invocations/abc-123/cancel # -> {"invocation_id": "abc-123", "status": "cancelled"} """ + import asyncio import json @@ -46,7 +47,6 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # In-memory state for demo purposes (see module docstring for production caveats) _tasks: dict[str, asyncio.Task] = {} _results: dict[str, bytes] = {} @@ -65,11 +65,13 @@ async def _do_work(invocation_id: str, data: dict) -> bytes: :rtype: bytes """ await asyncio.sleep(5) - result = json.dumps({ - "invocation_id": invocation_id, - "status": "completed", - "output": f"Processed: {data}", - }).encode() + result = json.dumps( + { + "invocation_id": invocation_id, + "status": "completed", + "output": f"Processed: {data}", + } + ).encode() _results[invocation_id] = result return result @@ -89,10 +91,12 @@ async def handle_invoke(request: Request) -> Response: task = asyncio.create_task(_do_work(invocation_id, data)) _tasks[invocation_id] = task - return JSONResponse({ - "invocation_id": invocation_id, 
- "status": "running", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "running", + } + ) @app.get_invocation_handler @@ -112,10 +116,12 @@ async def handle_get_invocation(request: Request) -> Response: if invocation_id in _tasks: task = _tasks[invocation_id] if not task.done(): - return JSONResponse({ - "invocation_id": invocation_id, - "status": "running", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "running", + } + ) result = task.result() _results[invocation_id] = result del _tasks[invocation_id] @@ -137,11 +143,13 @@ async def handle_cancel_invocation(request: Request) -> Response: # Already completed — cannot cancel if invocation_id in _results: - return JSONResponse({ - "invocation_id": invocation_id, - "status": "completed", - "error": "invocation already completed", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + } + ) if invocation_id in _tasks: task = _tasks[invocation_id] @@ -149,17 +157,21 @@ async def handle_cancel_invocation(request: Request) -> Response: # Task finished between check — treat as completed _results[invocation_id] = task.result() del _tasks[invocation_id] - return JSONResponse({ - "invocation_id": invocation_id, - "status": "completed", - "error": "invocation already completed", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + } + ) task.cancel() del _tasks[invocation_id] - return JSONResponse({ - "invocation_id": invocation_id, - "status": "cancelled", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "cancelled", + } + ) return JSONResponse({"error": "not found"}, status_code=404) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/__init__.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/__init__.py new file 
mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py new file mode 100644 index 000000000000..e400cd9b5827 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py @@ -0,0 +1,152 @@ +"""Steerable durable Claude conversation agent. + +Wraps the Anthropic streaming API in a steerable durable task. +Demonstrates the **three-phase cancel pattern**: + +1. Pre-entry check — short-circuit if a newer input is already queued +2. Mid-stream check — break out of the SSE chunk loop +3. Post-completion — catch late arrivals after the reply finished + +Conversation history is stored in an external ``FileStore`` (not in task +metadata, which has a < 1 MB limit). In production, replace ``FileStore`` +with Redis, Cosmos DB, etc. +""" + +import asyncio +import logging +from pathlib import Path +from typing import Any + +from azure.ai.agentserver.core.durable import TaskContext, durable_task + +from .store import FileStore + +logger = logging.getLogger(__name__) + +_DATA_DIR = Path.home() / ".durable-sessions" + +# External stores — NOT in task metadata +invocation_store = FileStore(_DATA_DIR / "claude-invocations") +conversation_store = FileStore(_DATA_DIR / "claude-conversations") + + +def _load_history(session_id: str) -> list[dict[str, str]]: + """Load conversation history from external store.""" + data = conversation_store.load(session_id) + if data and "messages" in data: + return data["messages"] + return [] + + +def _save_history(session_id: str, history: list[dict[str, str]]) -> None: + """Persist conversation history to external store.""" + conversation_store.save(session_id, {"messages": history}) + + +@durable_task(name="claude_session", steerable=True) +async def claude_session(ctx: TaskContext[dict]) -> dict[str, Any]: + """Run one Claude conversation turn with 
streaming and steering support. + + Input schema: ``{"session_id": str, "message": str, "invocation_id": str}`` + """ + session_id: str = ctx.input["session_id"] + message: str = ctx.input["message"] + invocation_id: str = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + await ctx.stream({"type": "lifecycle", "status": "running"}) + + logger.info( + "Claude session %s gen=%d invocation=%s entry=%s", + session_id, + ctx.generation, + invocation_id, + ctx.entry_mode, + ) + + # Load history from external store (not task metadata) + history = _load_history(session_id) + history.append({"role": "user", "content": message}) + + # ── Phase 1: Pre-entry cancel (rapid-fire steering) ───────────── + if ctx.cancel.is_set(): + logger.info("Skipping gen=%d — cancel pre-set", ctx.generation) + _save_history(session_id, history) + invocation_store.save( + invocation_id, + { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + }, + ) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Stream Claude response, checking cancel ──────────── + import anthropic # pylint: disable=import-outside-toplevel + + reply = "" + was_aborted = False + + client = anthropic.AsyncAnthropic() + async with client.messages.stream( + model="claude-sonnet-4-20250514", + max_tokens=1024, + messages=history, + ) as stream: + async for text in stream.text_stream: + reply += text + # Live stream: push delta to any SSE subscriber on the POST response + await ctx.stream({"type": "text_delta", "delta": text}) + # Durable snapshot: GET polling always returns the full text so far + invocation_store.save( + invocation_id, + { + "status": "streaming", + "text": reply, + }, + ) + if ctx.cancel.is_set(): + was_aborted = True + logger.info("Stream aborted mid-generation at %d chars", len(reply)) + break + + # ── Phase 3: Save result ──────────────────────────────────────── + # Save history to external store (including partial text) + if 
reply: + history.append({"role": "assistant", "content": reply}) + _save_history(session_id, history) + + user_turns = len([m for m in history if m["role"] == "user"]) + output = { + "invocation_id": invocation_id, + "reply": reply, + "turn": user_turns, + "partial": was_aborted, + } + + if was_aborted: + invocation_store.save( + invocation_id, + { + "status": "superseded", + "reason": "steered_mid_stream", + "output": output, + }, + ) + return await ctx.suspend(reason="steered") + + if ctx.cancel.is_set(): + invocation_store.save( + invocation_id, + { + "status": "superseded", + "reason": "steered_post_completion", + "output": output, + }, + ) + return await ctx.suspend(reason="steered") + + # Normal completion + invocation_store.save(invocation_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/app.py new file mode 100644 index 000000000000..baad0e389f43 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/app.py @@ -0,0 +1,189 @@ +"""HTTP host for the Claude durable agent with steering and streaming. + +Wires the Claude durable task (``agent.py``) to the invocations framework. +With ``steerable=True``, calling ``start()`` on an in-progress task queues +the new input — no manual cancel/wait/restart logic needed. + +**Streaming**: If the POST request includes ``Accept: text/event-stream``, +the response is an SSE stream of text deltas as they are generated. If the +client disconnects mid-stream, it can fall back to ``GET /invocations/`` +which returns the full text snapshot at that moment. + +Usage:: + + pip install -r requirements.txt + export ANTHROPIC_API_KEY="sk-..." 
+ + python -m durable_claude.app + + # Turn 1 (async — poll for result) + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Tell me about quantum computing"}' + # → 202 {"invocation_id": "...", "status": "running"} + + # Turn 1 (streaming — live text deltas) + curl -N -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -H "Accept: text/event-stream" \\ + -d '{"message": "Tell me about quantum computing"}' + # → 200 data: {"type": "text_delta", "delta": "Quantum"} + # data: {"type": "text_delta", "delta": " computing"} + # ... + # event: done + # data: {"type": "done", ...} + + # Poll (works after disconnect or for async mode) + curl "http://localhost:8088/invocations/" + # → {"invocation_id": "", "status": "completed", "output": {...}} + + # Steer (while turn 1 is still running) + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Actually, explain machine learning instead"}' +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncGenerator + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + +from .agent import claude_session, invocation_store + +logger = logging.getLogger(__name__) + +app = InvocationAgentServerHost() + + +async def _sse_from_run( + run: object, invocation_id: str, *, initial_status: str = "queued" +) -> AsyncGenerator[bytes, None]: + """Convert a TaskRun's stream into SSE-formatted bytes. + + Yields lifecycle events (``queued``, ``running``), then ``text_delta`` + chunks, then a terminal event (``done``, ``error``, ``superseded``). + + :param run: The TaskRun handle. 
+ :param invocation_id: Invocation identifier for event payloads. + :param initial_status: First lifecycle status to emit (e.g. ``"queued"``). + """ + from azure.ai.agentserver.core.durable import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + TaskFailed, + TaskTerminated, + ) + + # Emit initial lifecycle event so the caller knows the request was accepted + yield ( + f"data: {json.dumps({'type': 'lifecycle', 'status': initial_status, 'invocation_id': invocation_id})}\n\n" + ).encode() + + try: + async for chunk in run: # type: ignore[union-attr] + yield f"data: {json.dumps(chunk)}\n\n".encode() + + # Stream ended normally — get the result + try: + result = await run.result() # type: ignore[union-attr] + done_data = { + "type": "done", + "invocation_id": invocation_id, + } + if ( + result is not None + and hasattr(result, "output") + and result.output is not None + ): + done_data["output"] = result.output + yield f"event: done\ndata: {json.dumps(done_data)}\n\n".encode() + except (TaskCancelled, TaskTerminated): + yield ( + f"event: superseded\n" + f"data: {json.dumps({'type': 'superseded', 'invocation_id': invocation_id})}\n\n" + ).encode() + except TaskFailed as exc: + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), + } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + except Exception as exc: # pylint: disable=broad-except + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), + } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + + +@app.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Start or steer a Claude session. + + If ``Accept: text/event-stream`` is set, returns an SSE stream of + text deltas. Otherwise returns ``202 Accepted`` for async polling. 
+ """ + data = await request.json() + invocation_id: str = request.state.invocation_id + session_id: str = request.state.session_id + message: str = data.get("message", "") + task_id = f"session-{session_id}" + + task_input = { + "session_id": session_id, + "message": message, + "invocation_id": invocation_id, + } + + invocation_store.save(invocation_id, {"status": "queued"}) + + run = await claude_session.start(task_id=task_id, input=task_input) + + # SSE streaming mode — return live text deltas + wants_stream = "text/event-stream" in request.headers.get("accept", "") + if wants_stream: + return StreamingResponse( + _sse_from_run(run, invocation_id), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + # Async mode — return 202 and let client poll + stored = invocation_store.load(invocation_id) + status = stored["status"] if stored else "queued" + + return JSONResponse( + {"invocation_id": invocation_id, "status": status}, + status_code=202, + ) + + +@app.get_invocation_handler +async def poll_invocation(request: Request) -> Response: + """Poll a specific invocation's result. + + Returns the current snapshot — during streaming this includes + ``{"status": "streaming", "text": "..."}`` with the full text + generated so far. After completion, returns the final output. + + This is the recovery path: if a streaming client disconnects, + it switches to polling to get the accumulated text. 
+ """ + invocation_id: str = request.state.invocation_id + + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Invocation not found"}, status_code=404) + + return JSONResponse({"invocation_id": invocation_id, **result}) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/requirements.txt new file mode 100644 index 000000000000..da81ce3dd1a6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/requirements.txt @@ -0,0 +1,5 @@ +anthropic>=0.30.0 +azure-ai-agentserver-core +azure-ai-agentserver-invocations +starlette +uvicorn diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/store.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/store.py new file mode 100644 index 000000000000..1f456a19ea18 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/store.py @@ -0,0 +1,59 @@ +"""File-based key→JSON store for powering the invocation API. + +This module provides a minimal persistence layer that the HTTP host uses to +store per-invocation results. It is **not** part of the durable task +framework — it is the developer's own persistence for powering the API +contract (``GET /invocations/{invocation_id}``). + +.. warning:: + + For demonstration only. In production, use a database (Redis, Cosmos DB, + PostgreSQL, etc.). +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Any + + +class FileStore: + """Minimal file-backed key→JSON store. + + Each entry is a single JSON file. Writes are atomic (temp + rename). 
+ """ + + def __init__(self, base_dir: Path) -> None: + self._base = base_dir + self._base.mkdir(parents=True, exist_ok=True) + + def save(self, key: str, data: dict[str, Any]) -> None: + """Atomically write *data* as JSON — temp file + rename.""" + target = self._base / f"{key}.json" + fd, tmp_path = tempfile.mkstemp( + dir=str(self._base), suffix=".tmp", prefix=f"{key}_" + ) + try: + with open(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + Path(tmp_path).replace(target) + except BaseException: + Path(tmp_path).unlink(missing_ok=True) + raise + + def load(self, key: str) -> dict[str, Any] | None: + """Return the stored dict, or ``None`` if the key does not exist.""" + path = self._base / f"{key}.json" + if path.exists(): + return json.loads(path.read_text()) + return None + + def delete(self, key: str) -> bool: + """Remove the entry for *key*. Returns ``True`` if it existed.""" + path = self._base / f"{key}.json" + if path.exists(): + path.unlink() + return True + return False diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/__init__.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/agent.py new file mode 100644 index 000000000000..1620e48ab888 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/agent.py @@ -0,0 +1,192 @@ +"""Steerable durable Copilot conversation agent. + +Wraps the **GitHub Copilot SDK** in a steerable durable task. +Demonstrates the **three-phase cancel pattern**: + +1. Pre-entry check — enqueue the message to the SDK then abort immediately +2. Mid-stream check — ``session.abort()`` when ``ctx.cancel`` fires +3. 
Post-completion — catch late arrivals after the reply finished + +The Copilot SDK manages conversation history internally, so there is no +external history store needed (unlike the Claude sample). +""" + +import asyncio +import logging +from pathlib import Path +from typing import Any + +from azure.ai.agentserver.core.durable import TaskContext, durable_task + +from .store import FileStore + +logger = logging.getLogger(__name__) + +_DATA_DIR = Path.home() / ".durable-sessions" + +invocation_store = FileStore(_DATA_DIR / "copilot-invocations") + + +@durable_task(name="copilot_session", steerable=True) +async def copilot_session(ctx: TaskContext[dict]) -> dict[str, Any]: + """Run one Copilot conversation turn with steering support. + + Input schema: ``{"session_id": str, "message": str, "invocation_id": str}`` + """ + from copilot import CopilotClient # pylint: disable=import-outside-toplevel + from copilot.generated.session_events import ( # pylint: disable=import-outside-toplevel + AssistantMessageData, + IdleData, + ) + from copilot.session import ( + PermissionHandler, + ) # pylint: disable=import-outside-toplevel + + session_id: str = ctx.input["session_id"] + message: str = ctx.input["message"] + invocation_id: str = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + await ctx.stream({"type": "lifecycle", "status": "running"}) + + logger.info( + "Copilot session %s gen=%d invocation=%s entry=%s", + session_id, + ctx.generation, + invocation_id, + ctx.entry_mode, + ) + + # ── Phase 1: Pre-entry cancel (rapid-fire steering) ───────────── + # Cancel is pre-set when more inputs are already queued. We still + # send the message so the SDK records it, then abort immediately. 
+ if ctx.cancel.is_set(): + logger.info("Skipping gen=%d — cancel pre-set", ctx.generation) + async with CopilotClient() as client: + session = await client.resume_session( + session_id, + on_permission_request=PermissionHandler.approve_all, + ) + await session.send(message) + await session.abort() + invocation_store.save( + invocation_id, + { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + }, + ) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Stream the Copilot turn, checking cancel ─────────── + reply = "" + was_aborted = False + + async with CopilotClient() as client: + if ctx.entry_mode != "fresh": + session = await client.resume_session( + session_id, + on_permission_request=PermissionHandler.approve_all, + ) + else: + session = await client.create_session( + session_id=session_id, + on_permission_request=PermissionHandler.approve_all, + ) + + # Event-based send: collect reply via events, abort on cancel + reply_parts: list[str] = [] + idle_event = asyncio.Event() + + def on_event(event: Any) -> None: + nonlocal reply_parts + if isinstance(event.data, AssistantMessageData): + content = event.data.content or "" + reply_parts.append(content) + # Schedule streaming — push delta to SSE subscriber and + # persist snapshot for GET polling + asyncio.get_event_loop().create_task( + _stream_and_persist(ctx, invocation_id, content, reply_parts) + ) + elif isinstance(event.data, IdleData): + idle_event.set() + + session.on(on_event) + await session.send(message) + + # Wait for idle (turn complete) or cancel, whichever first + cancel_task = asyncio.create_task(_wait_for_cancel(ctx.cancel)) + idle_task = asyncio.create_task(idle_event.wait()) + try: + done, _pending = await asyncio.wait( + {cancel_task, idle_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for t in _pending: + t.cancel() + + if cancel_task in done and idle_task not in done: + was_aborted = True + logger.info("session.abort() — new input queued") + await 
session.abort() + finally: + for t in (cancel_task, idle_task): + if not t.done(): + t.cancel() + + reply = "".join(reply_parts) + + # ── Phase 3: Save result ──────────────────────────────────────── + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + + if was_aborted: + invocation_store.save( + invocation_id, + { + "status": "superseded", + "reason": "steered_mid_stream", + "output": output, + }, + ) + return await ctx.suspend(reason="steered") + + if ctx.cancel.is_set(): + invocation_store.save( + invocation_id, + { + "status": "superseded", + "reason": "steered_post_completion", + "output": output, + }, + ) + return await ctx.suspend(reason="steered") + + invocation_store.save(invocation_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) + + +async def _wait_for_cancel(cancel: asyncio.Event) -> None: + """Await the cancel event. Extracted for use with ``asyncio.wait``.""" + await cancel.wait() + + +async def _stream_and_persist( + ctx: TaskContext[dict], + invocation_id: str, + delta: str, + parts: list[str], +) -> None: + """Push a streaming delta and persist the text snapshot.""" + await ctx.stream({"type": "text_delta", "delta": delta}) + invocation_store.save( + invocation_id, + { + "status": "streaming", + "text": "".join(parts), + }, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/app.py new file mode 100644 index 000000000000..1e04aa4c1b5b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/app.py @@ -0,0 +1,168 @@ +"""HTTP host for the Copilot durable agent with steering and streaming. + +Wires the Copilot durable task (``agent.py``) to the invocations framework. 
+With ``steerable=True``, calling ``start()`` on an in-progress task queues
+the new input — no manual cancel/wait/restart logic needed.
+
+**Streaming**: If the POST request includes ``Accept: text/event-stream``,
+the response is an SSE stream of text deltas as they are generated. If the
+client disconnects mid-stream, it can fall back to ``GET /invocations/``
+which returns the full text snapshot at that moment.
+
+Requires the **GitHub Copilot SDK** (``pip install github-copilot-sdk``)
+and the Copilot CLI installed and authenticated (``gh auth login``).
+
+Usage::
+
+    pip install -r requirements.txt
+
+    python -m durable_copilot.app
+
+    # Turn 1 (async)
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "Explain Python decorators"}'
+
+    # Turn 1 (streaming)
+    curl -N -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\
+        -H "Content-Type: application/json" \\
+        -H "Accept: text/event-stream" \\
+        -d '{"message": "Explain Python decorators"}'
+
+    # Poll (recovery after disconnect)
+    curl "http://localhost:8088/invocations/<invocation-id>"
+
+    # Steer (while turn 1 is still running)
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "Actually, explain async/await instead"}'
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from collections.abc import AsyncGenerator
+
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response, StreamingResponse
+
+from azure.ai.agentserver.invocations import InvocationAgentServerHost
+
+from .agent import copilot_session, invocation_store
+
+logger = logging.getLogger(__name__)
+
+app = InvocationAgentServerHost()
+
+
+async def _sse_from_run(
+    run: object, invocation_id: str, *, initial_status: str = "queued"
+) -> AsyncGenerator[bytes, None]:
+    """Convert a TaskRun's stream into 
SSE-formatted bytes.""" + from azure.ai.agentserver.core.durable import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + TaskFailed, + TaskTerminated, + ) + + yield ( + f"data: {json.dumps({'type': 'lifecycle', 'status': initial_status, 'invocation_id': invocation_id})}\n\n" + ).encode() + + try: + async for chunk in run: # type: ignore[union-attr] + yield f"data: {json.dumps(chunk)}\n\n".encode() + + try: + result = await run.result() # type: ignore[union-attr] + done_data = {"type": "done", "invocation_id": invocation_id} + if ( + result is not None + and hasattr(result, "output") + and result.output is not None + ): + done_data["output"] = result.output + yield f"event: done\ndata: {json.dumps(done_data)}\n\n".encode() + except (TaskCancelled, TaskTerminated): + yield ( + f"event: superseded\n" + f"data: {json.dumps({'type': 'superseded', 'invocation_id': invocation_id})}\n\n" + ).encode() + except TaskFailed as exc: + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), + } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + except Exception as exc: # pylint: disable=broad-except + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), + } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + + +@app.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Start or steer a Copilot session. + + If ``Accept: text/event-stream`` is set, returns an SSE stream. + Otherwise returns ``202 Accepted`` for async polling. 
+ """ + data = await request.json() + invocation_id: str = request.state.invocation_id + session_id: str = request.state.session_id + message: str = data.get("message", "") + task_id = f"session-{session_id}" + + task_input = { + "session_id": session_id, + "message": message, + "invocation_id": invocation_id, + } + + invocation_store.save(invocation_id, {"status": "queued"}) + + run = await copilot_session.start(task_id=task_id, input=task_input) + + # SSE streaming mode + wants_stream = "text/event-stream" in request.headers.get("accept", "") + if wants_stream: + return StreamingResponse( + _sse_from_run(run, invocation_id), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + # Async mode + stored = invocation_store.load(invocation_id) + status = stored["status"] if stored else "queued" + + return JSONResponse( + {"invocation_id": invocation_id, "status": status}, + status_code=202, + ) + + +@app.get_invocation_handler +async def poll_invocation(request: Request) -> Response: + """Poll a specific invocation's result. + + Returns the current snapshot — during streaming this includes the + full text generated so far. This is the recovery path after a + streaming disconnect. 
+ """ + invocation_id: str = request.state.invocation_id + + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Invocation not found"}, status_code=404) + + return JSONResponse({"invocation_id": invocation_id, **result}) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/requirements.txt new file mode 100644 index 000000000000..a5c8adee9c42 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/requirements.txt @@ -0,0 +1,5 @@ +github-copilot-sdk +azure-ai-agentserver-core +azure-ai-agentserver-invocations +starlette +uvicorn diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/store.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/store.py new file mode 100644 index 000000000000..1f456a19ea18 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_copilot/store.py @@ -0,0 +1,59 @@ +"""File-based key→JSON store for powering the invocation API. + +This module provides a minimal persistence layer that the HTTP host uses to +store per-invocation results. It is **not** part of the durable task +framework — it is the developer's own persistence for powering the API +contract (``GET /invocations/{invocation_id}``). + +.. warning:: + + For demonstration only. In production, use a database (Redis, Cosmos DB, + PostgreSQL, etc.). +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Any + + +class FileStore: + """Minimal file-backed key→JSON store. + + Each entry is a single JSON file. Writes are atomic (temp + rename). 
+ """ + + def __init__(self, base_dir: Path) -> None: + self._base = base_dir + self._base.mkdir(parents=True, exist_ok=True) + + def save(self, key: str, data: dict[str, Any]) -> None: + """Atomically write *data* as JSON — temp file + rename.""" + target = self._base / f"{key}.json" + fd, tmp_path = tempfile.mkstemp( + dir=str(self._base), suffix=".tmp", prefix=f"{key}_" + ) + try: + with open(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + Path(tmp_path).replace(target) + except BaseException: + Path(tmp_path).unlink(missing_ok=True) + raise + + def load(self, key: str) -> dict[str, Any] | None: + """Return the stored dict, or ``None`` if the key does not exist.""" + path = self._base / f"{key}.json" + if path.exists(): + return json.loads(path.read_text()) + return None + + def delete(self, key: str) -> bool: + """Remove the entry for *key*. Returns ``True`` if it existed.""" + path = self._base / f"{key}.json" + if path.exists(): + path.unlink() + return True + return False diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/agent.py new file mode 100644 index 000000000000..cf6b84fb105c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/agent.py @@ -0,0 +1,423 @@ +"""LangGraph conversation agent with durable task lifecycle and steering. + +Wraps a LangGraph ``StateGraph`` in a steerable durable task. +Demonstrates the **checkpoint-and-fork** cancel pattern: + +1. Pre-entry check — short-circuit if cancel is pre-set +2. Inter-node check — ``_invoke_cancellable`` checks between graph nodes +3. 
Fork-on-steer — roll back to the last stable checkpoint and fork + with the new message + +LangGraph owns the conversation flow; the durable task owns crash +resilience and steering orchestration. +""" + +import asyncio +import logging +import sqlite3 +import typing +from pathlib import Path +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage +from langgraph.checkpoint.sqlite import SqliteSaver +from langgraph.graph import END, START, StateGraph, add_messages +from langgraph.types import Command, interrupt +from typing_extensions import TypedDict + +from azure.ai.agentserver.core.durable import TaskContext, durable_task + +from .store import FileStore + +logger = logging.getLogger(__name__) + +_DATA_DIR = Path.home() / ".durable-sessions" + +# Invocation result store — written inside the durable task so it survives crashes +invocation_store = FileStore(_DATA_DIR / "invocations") + + +# --------------------------------------------------------------------------- +# Graph state +# --------------------------------------------------------------------------- + + +class ConversationState(TypedDict): + """Graph state for a multi-turn conversation. + + Uses LangGraph's built-in ``add_messages`` reducer for message + accumulation across turns. + """ + + messages: typing.Annotated[list, add_messages] + is_complete: bool + + +# --------------------------------------------------------------------------- +# Graph nodes +# --------------------------------------------------------------------------- + +# Simulated step delay — distributed across nodes so inter-node +# cancellation (via ``graph.stream()``) can bail out quickly. 
+_STEP_DELAY = 2 # seconds per processing node + + +def analyze_input(state: ConversationState) -> dict[str, Any]: + """Simulate analysing the user's message (e.g., intent detection).""" + import time # pylint: disable=import-outside-toplevel + + _ = state # Would inspect messages in a real implementation + time.sleep(_STEP_DELAY) + return {} # No state change — analysis is an internal step + + +def generate_response(state: ConversationState) -> dict[str, Any]: + """Generate an AI response. Replace stub with a real LLM call.""" + import time # pylint: disable=import-outside-toplevel + + time.sleep(_STEP_DELAY) + + messages = state["messages"] + user_messages = [m for m in messages if isinstance(m, HumanMessage)] + turn = len(user_messages) + last_msg = user_messages[-1].content if user_messages else "" + + if turn == 1: + reply = ( + f"Thanks for reaching out! You said: '{last_msg}'. " + "I'd love to help — could you share more details?" + ) + elif turn == 2: + reply = ( + f"Great context: '{last_msg}'. Building on our earlier " + "exchange, here are some initial thoughts. What else " + "would you like to explore?" + ) + else: + reply = ( + f"Turn {turn}: incorporating '{last_msg}' — I now have " + f"context from {turn} turns. How shall we proceed?" 
+        )
+
+    return {"messages": [AIMessage(content=reply)]}
+
+
+def refine_response(state: ConversationState) -> dict[str, Any]:
+    """Simulate post-processing (e.g., safety checks, formatting)."""
+    import time  # pylint: disable=import-outside-toplevel
+
+    _ = state  # Would inspect the generated reply in a real implementation
+    time.sleep(_STEP_DELAY // 2 or 1)
+    return {}  # No state change — refinement is an internal step
+
+
+def wait_for_user(state: ConversationState) -> dict[str, Any]:
+    """Pause the graph and wait for the next human message."""
+    messages = state["messages"]
+    user_count = len([m for m in messages if isinstance(m, HumanMessage)])
+
+    user_input: str = interrupt(
+        {
+            "prompt": "Please provide your next message (or say 'done' to finish):",
+            "current_turn": user_count,
+        }
+    )
+
+    if user_input.strip().lower() == "done":
+        return {"is_complete": True}
+
+    return {
+        "messages": [HumanMessage(content=user_input)],
+        "is_complete": False,
+    }
+
+
+def _should_continue(state: ConversationState) -> str:
+    """Route: loop back to analyze_input or end the conversation."""
+    if state.get("is_complete", False):
+        return "end"
+    return "continue"
+
+
+# ---------------------------------------------------------------------------
+# Persistent graph checkpointer (survives restarts)
+# ---------------------------------------------------------------------------
+
+_DATA_DIR.mkdir(parents=True, exist_ok=True)
+_DB_PATH = _DATA_DIR / "langgraph_checkpoints.db"
+
+_conn = sqlite3.connect(str(_DB_PATH), check_same_thread=False)
+_checkpointer = SqliteSaver(_conn)
+_checkpointer.setup()
+
+logger.info("LangGraph checkpoints stored at: %s", _DB_PATH)
+
+
+# ---------------------------------------------------------------------------
+# Build and compile the graph
+# ---------------------------------------------------------------------------
+
+
+def _build_graph() -> Any:
+    """Construct the LangGraph StateGraph for multi-turn conversation. 
+ + Processing is split across three nodes (``analyze_input`` → + ``generate_response`` → ``refine_response``) so that stream-based + cancellation can bail out between any two steps (~2 s granularity). + """ + builder = StateGraph(ConversationState) + + builder.add_node("analyze_input", analyze_input) + builder.add_node("generate_response", generate_response) + builder.add_node("refine_response", refine_response) + builder.add_node("wait_for_user", wait_for_user) + + builder.add_edge(START, "analyze_input") + builder.add_edge("analyze_input", "generate_response") + builder.add_edge("generate_response", "refine_response") + builder.add_edge("refine_response", "wait_for_user") + + builder.add_conditional_edges( + "wait_for_user", + _should_continue, + { + "continue": "analyze_input", + "end": END, + }, + ) + + return builder.compile(checkpointer=_checkpointer) + + +_graph = _build_graph() + + +# --------------------------------------------------------------------------- +# Steering — cancellable graph invocation and state forking +# --------------------------------------------------------------------------- + + +def _invoke_cancellable( + graph: Any, + graph_input: Any, + config: dict[str, Any], + cancel_event: asyncio.Event, + on_node: Any = None, +) -> bool: + """Run the graph using ``stream()`` with inter-node cancellation. + + Instead of ``graph.invoke()`` which blocks until the full graph + completes, this streams node-by-node and checks ``cancel_event`` + between nodes. If cancellation is detected, execution stops before + the next node runs. + + Returns ``True`` if the graph ran to completion (or interrupt), + ``False`` if cancelled mid-graph. 
+    """
+    for chunk in graph.stream(graph_input, config):
+        if on_node is not None:
+            on_node(chunk)
+        if cancel_event.is_set():
+            return False
+    return True
+
+
+def _fork_from_checkpoint(
+    graph: Any,
+    config: dict[str, Any],
+    target_checkpoint_id: str,
+    new_message: str,
+) -> bool:
+    """Fork the graph from a previous checkpoint with a new message.
+
+    Uses LangGraph's native state forking: ``update_state`` called with
+    an old checkpoint's config creates a new branch. The graph's head
+    pointer moves to the fork, discarding any state that was added after
+    the target checkpoint.
+
+    After forking the graph is positioned after ``wait_for_user`` with
+    the new message injected, so the next step is ``analyze_input``.
+
+    Returns ``True`` if the fork was created.
+    """
+    # Load the target checkpoint to get its full config (includes checkpoint_ns)
+    target_config = {
+        "configurable": {
+            **config["configurable"],
+            "checkpoint_id": target_checkpoint_id,
+        }
+    }
+    target = graph.get_state(target_config)
+    if not target or not target.config:
+        return False
+
+    # Fork: update_state at the old checkpoint creates a new branch
+    graph.update_state(
+        target.config,
+        values={"messages": [HumanMessage(content=new_message)]},
+        as_node="wait_for_user",
+    )
+    return True
+
+
+def _build_turn_output(state: Any) -> dict[str, Any]:
+    """Extract turn output from graph state at an interrupt."""
+    messages = state.values.get("messages", [])
+    ai_messages = [m for m in messages if isinstance(m, AIMessage)]
+    user_messages = [m for m in messages if isinstance(m, HumanMessage)]
+    last_reply = ai_messages[-1].content if ai_messages else ""
+    return {"reply": last_reply, "turn": len(user_messages)}
+
+
+def _build_session_output(state: Any) -> dict[str, Any]:
+    """Build final output when the graph conversation is complete."""
+    messages = state.values.get("messages", [])
+    user_count = len([m for m in messages if isinstance(m, HumanMessage)])
+    return {
+        "finished": True,
+        
"turn_count": user_count, + "total_messages": len(messages), + "summary": f"Session complete after {user_count} turns.", + } + + +async def _finalize_invocation( + ctx: TaskContext[dict], + thread_config: dict[str, Any], + invocation_id: str, +) -> dict[str, Any] | Any: + """Save results and suspend/return after a graph invoke completes.""" + state = await asyncio.to_thread(_graph.get_state, thread_config) + + new_cp_id = state.config["configurable"]["checkpoint_id"] + ctx.metadata.set("stable_checkpoint_id", new_cp_id) + ctx.metadata.set("last_applied_invocation_id", invocation_id) + + if state.next: + output = _build_turn_output(state) + invocation_store.save(invocation_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) + + result = _build_session_output(state) + invocation_store.save(invocation_id, {"status": "completed", "output": result}) + return result + + +# --------------------------------------------------------------------------- +# Durable task — bridges LangGraph with HTTP lifecycle +# --------------------------------------------------------------------------- + + +@durable_task(name="langgraph_session", steerable=True) +async def langgraph_session(ctx: TaskContext[dict]) -> dict[str, Any]: + """Run one LangGraph conversation turn with steering support. 
+ + Input schema: ``{"session_id": str, "message": str, "invocation_id": str}`` + """ + session_id: str = ctx.input["session_id"] + message: str = ctx.input["message"] + invocation_id: str = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + await ctx.stream({"type": "lifecycle", "status": "running"}) + + thread_config: dict[str, Any] = {"configurable": {"thread_id": session_id}} + + if ctx.entry_mode == "recovered": + logger.warning("Recovered stale task for session %s", session_id) + + # ── Fork-on-steer: rollback to stable checkpoint ──────────────── + # If the previous invocation was cancelled mid-flight, the graph may + # have drifted past the stable checkpoint. Fork from the stable + # checkpoint with the new message so the graph processes it cleanly. + stable_cp = ctx.metadata.get("stable_checkpoint_id") + if stable_cp: + state = await asyncio.to_thread(_graph.get_state, thread_config) + if state and state.values.get("messages"): + current_cp = state.config["configurable"].get("checkpoint_id") + if current_cp and current_cp != stable_cp: + forked = await asyncio.to_thread( + _fork_from_checkpoint, + _graph, + thread_config, + stable_cp, + message, + ) + if forked: + logger.info( + "Forked session %s from stable checkpoint %s", + session_id, + stable_cp, + ) + completed = await asyncio.to_thread( + _invoke_cancellable, + _graph, + None, + thread_config, + ctx.cancel, + ) + + if not completed or ctx.cancel.is_set(): + invocation_store.save( + invocation_id, + {"status": "cancelled", "reason": "steered"}, + ) + return await ctx.suspend(reason="steered") + + return await _finalize_invocation(ctx, thread_config, invocation_id) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────────── + if ctx.cancel.is_set(): + invocation_store.save( + invocation_id, {"status": "cancelled", "reason": "steered"} + ) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Invoke graph with inter-node cancellation 
────────── + state = await asyncio.to_thread(_graph.get_state, thread_config) + + if state.next: + graph_input = Command(resume=message) + else: + graph_input = { + "messages": [HumanMessage(content=message)], + "is_complete": False, + } + + loop = asyncio.get_event_loop() + + def _on_node(chunk: dict) -> None: + """Stream node progress events from the sync graph thread.""" + node_names = list(chunk.keys()) + for name in node_names: + if ctx._stream_queue is not None: # pylint: disable=protected-access + loop.call_soon_threadsafe( + ctx._stream_queue.put_nowait, # pylint: disable=protected-access + {"type": "node_progress", "node": name}, + ) + invocation_store.save( + invocation_id, + { + "status": "streaming", + "last_node": node_names[-1] if node_names else None, + }, + ) + + completed = await asyncio.to_thread( + _invoke_cancellable, + _graph, + graph_input, + thread_config, + ctx.cancel, + _on_node, + ) + + # ── Phase 3: Post-completion cancel check ─────────────────────── + if not completed or ctx.cancel.is_set(): + invocation_store.save( + invocation_id, {"status": "cancelled", "reason": "steered"} + ) + return await ctx.suspend(reason="steered") + + # Normal completion + return await _finalize_invocation(ctx, thread_config, invocation_id) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py new file mode 100644 index 000000000000..517de7c8f2c9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py @@ -0,0 +1,184 @@ +"""HTTP host for the LangGraph durable agent with streaming and steering. + +Wires the LangGraph durable task (``agent.py``) to the invocations framework. +Per-invocation results are written by the durable task itself (inside the +crash-resilient execution boundary), not by a background collector. 
+
+Streaming
+~~~~~~~~~
+
+Pass ``Accept: text/event-stream`` on POST to receive an SSE stream of node
+progress events (``node_progress``) plus lifecycle events (``queued``,
+``running``). Without the header you get the standard 202 JSON response for
+async polling via GET.
+
+Steering is handled by the framework: the durable task is declared with
+``steerable=True``, so calling ``start()`` on an in-progress task **queues**
+the new input instead of raising ``TaskConflictError``. The running function
+sees ``ctx.cancel`` set and short-circuits. The framework then drains the
+queue and re-enters the function with the next input.
+
+Usage::
+
+    pip install -r requirements.txt
+
+    python -m durable_langgraph.app
+    # — or —
+    python app.py
+
+    # Turn 1 — async
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "I need help planning a trip to Tokyo"}'
+    # → 202 (x-agent-invocation-id: <invocation-id>)
+
+    # Turn 1 — streaming
+    curl -N -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\
+        -H "Content-Type: application/json" \\
+        -H "Accept: text/event-stream" \\
+        -d '{"message": "I need help planning a trip to Tokyo"}'
+    # → SSE stream: lifecycle:queued → lifecycle:running → node_progress → done
+
+    # Poll that invocation (snapshot — always available)
+    curl "http://localhost:8088/invocations/<invocation-id>"
+    # → {"invocation_id": "<invocation-id>", "status": "completed", "output": {...}}
+
+    # Steer — send a new invocation while a turn is still running. 
+ curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Actually, let us go to Paris instead"}' + + # End session + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "done"}' +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncGenerator + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + +from .agent import invocation_store, langgraph_session + +logger = logging.getLogger(__name__) + +app = InvocationAgentServerHost() + + +async def _sse_from_run( + run: object, invocation_id: str, *, initial_status: str = "queued" +) -> AsyncGenerator[bytes, None]: + """Convert a TaskRun's stream into SSE-formatted bytes.""" + from azure.ai.agentserver.core.durable import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + TaskFailed, + TaskTerminated, + ) + + yield ( + f"data: {json.dumps({'type': 'lifecycle', 'status': initial_status, 'invocation_id': invocation_id})}\n\n" + ).encode() + + try: + async for chunk in run: # type: ignore[union-attr] + yield f"data: {json.dumps(chunk)}\n\n".encode() + + try: + result = await run.result() # type: ignore[union-attr] + done_data = {"type": "done", "invocation_id": invocation_id} + if ( + result is not None + and hasattr(result, "output") + and result.output is not None + ): + done_data["output"] = result.output + yield f"event: done\ndata: {json.dumps(done_data)}\n\n".encode() + except (TaskCancelled, TaskTerminated): + yield ( + f"event: superseded\n" + f"data: {json.dumps({'type': 'superseded', 'invocation_id': invocation_id})}\n\n" + ).encode() + except TaskFailed as exc: + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), 
+ } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + except Exception as exc: # pylint: disable=broad-except + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), + } + yield f"event: error\ndata: {json.dumps(error_data)}\n\n".encode() + + +@app.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Start or steer a LangGraph session. + + If ``Accept: text/event-stream`` is set, returns an SSE stream of node + progress events. Otherwise returns ``202 Accepted`` for async polling. + """ + data = await request.json() + invocation_id: str = request.state.invocation_id + session_id: str = request.state.session_id + message: str = data.get("message", "") + task_id = f"session-{session_id}" + + task_input = { + "session_id": session_id, + "message": message, + "invocation_id": invocation_id, + } + + invocation_store.save(invocation_id, {"status": "queued"}) + + run = await langgraph_session.start(task_id=task_id, input=task_input) + + # SSE streaming mode — return live node progress + wants_stream = "text/event-stream" in request.headers.get("accept", "") + if wants_stream: + return StreamingResponse( + _sse_from_run(run, invocation_id), + media_type="text/event-stream", + headers={"X-Agent-Invocation-Id": invocation_id}, + ) + + # Standard async mode — return 202 with status from store + stored = invocation_store.load(invocation_id) + status = stored["status"] if stored else "queued" + + return JSONResponse( + {"invocation_id": invocation_id, "status": status}, + status_code=202, + ) + + +@app.get_invocation_handler +async def poll_invocation(request: Request) -> Response: + """Poll a specific invocation's snapshot. + + Returns the durable snapshot from the invocation store. During streaming + this includes ``last_node``; after completion it includes full output. + Use this as the recovery path after an SSE disconnect. 
+ """ + invocation_id: str = request.state.invocation_id + + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Invocation not found"}, status_code=404) + + return JSONResponse({"invocation_id": invocation_id, **result}) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/requirements.txt new file mode 100644 index 000000000000..79260e068214 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/requirements.txt @@ -0,0 +1,4 @@ +azure-ai-agentserver-invocations +langgraph>=0.2 +langgraph-checkpoint-sqlite>=2.0 +langchain-core>=0.3 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/store.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/store.py new file mode 100644 index 000000000000..1f456a19ea18 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/store.py @@ -0,0 +1,59 @@ +"""File-based key→JSON store for powering the invocation API. + +This module provides a minimal persistence layer that the HTTP host uses to +store per-invocation results. It is **not** part of the durable task +framework — it is the developer's own persistence for powering the API +contract (``GET /invocations/{invocation_id}``). + +.. warning:: + + For demonstration only. In production, use a database (Redis, Cosmos DB, + PostgreSQL, etc.). +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Any + + +class FileStore: + """Minimal file-backed key→JSON store. + + Each entry is a single JSON file. Writes are atomic (temp + rename). 
+ """ + + def __init__(self, base_dir: Path) -> None: + self._base = base_dir + self._base.mkdir(parents=True, exist_ok=True) + + def save(self, key: str, data: dict[str, Any]) -> None: + """Atomically write *data* as JSON — temp file + rename.""" + target = self._base / f"{key}.json" + fd, tmp_path = tempfile.mkstemp( + dir=str(self._base), suffix=".tmp", prefix=f"{key}_" + ) + try: + with open(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + Path(tmp_path).replace(target) + except BaseException: + Path(tmp_path).unlink(missing_ok=True) + raise + + def load(self, key: str) -> dict[str, Any] | None: + """Return the stored dict, or ``None`` if the key does not exist.""" + path = self._base / f"{key}.json" + if path.exists(): + return json.loads(path.read_text()) + return None + + def delete(self, key: str) -> bool: + """Remove the entry for *key*. Returns ``True`` if it existed.""" + path = self._base / f"{key}.json" + if path.exists(): + path.unlink() + return True + return False diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/__init__.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/agent.py new file mode 100644 index 000000000000..d54d0b4c76eb --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/agent.py @@ -0,0 +1,105 @@ +"""Durable multi-turn session agent. + +Defines the durable task that powers a sticky conversation session. Each +invocation runs this function from the top — ``ctx.entry_mode`` tells us +whether this is a fresh start, a resume, or a crash recovery. 
+ +The agent keeps its own conversation state in a ``FileStore`` checkpoint +and writes per-invocation results to the invocation store — both inside +the durable execution boundary so they survive crashes. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + +from .store import FileStore + +logger = logging.getLogger(__name__) + +_DATA_DIR = Path.home() / ".durable-sessions" + +# Session checkpoint store — conversation state across turns +checkpoint_store = FileStore(_DATA_DIR / "checkpoints") + +# Invocation result store — written inside the durable task so it survives crashes +invocation_store = FileStore(_DATA_DIR / "invocations") + + +def _generate_reply(state: dict[str, Any]) -> str: + """Placeholder for an LLM call. Replace with your model of choice.""" + turn = state["turn_count"] + last_msg = state["history"][-1]["content"] if state["history"] else "" + if turn == 1: + return ( + f"Thanks for reaching out! You said: '{last_msg}'. " + "Could you share more details so I can help?" + ) + if turn == 2: + return ( + f"Great, noted: '{last_msg}'. Based on our conversation " + "so far, here are some initial thoughts. What else?" + ) + return ( + f"Turn {turn}: incorporating '{last_msg}' — " + f"I now have context from {turn} turns of conversation." + ) + + +@durable_task(name="session_workflow") +async def session_workflow(ctx: TaskContext[dict]) -> dict[str, Any]: + """Single durable function for the entire session. + + Each invocation runs this function from the top. + ``ctx.entry_mode`` tells us why we were entered. + + The invocation result is written to ``invocation_store`` **inside** the + durable boundary — if the process crashes, the task recovers and the + write happens on re-execution. 
+ """ + session_id: str = ctx.input["session_id"] + message: str = ctx.input["message"] + invocation_id: str = ctx.input["invocation_id"] + + # Mark invocation as running — inside the durable boundary so it + # only exists if the task is actually executing. + invocation_store.save(invocation_id, {"status": "running"}) + + state = checkpoint_store.load(session_id) or {"history": [], "turn_count": 0} + + if ctx.entry_mode == "recovered": + logger.warning("Recovered stale task for session %s", session_id) + + # Handle explicit session end + if message.strip().lower() == "done": + summary = ( + f"Session complete after {state['turn_count']} turns. " + f"Total messages exchanged: {len(state['history'])}." + ) + checkpoint_store.delete(session_id) + result = {"reply": summary, "turn": state["turn_count"], "finished": True} + invocation_store.save(invocation_id, {"status": "completed", "output": result}) + return result + + # Process this turn + state["history"].append({"role": "user", "content": message}) + state["turn_count"] += 1 + + reply = _generate_reply(state) + state["history"].append({"role": "assistant", "content": reply}) + + checkpoint_store.save(session_id, state) + + # Persist invocation result BEFORE suspending (inside durable boundary) + output = {"reply": reply, "turn": state["turn_count"]} + invocation_store.save(invocation_id, {"status": "completed", "output": output}) + + # Suspend — the client will resume with the next turn + return await ctx.suspend(reason="awaiting_user_input", output=output) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py new file mode 100644 index 000000000000..91e7daec9240 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py @@ -0,0 +1,100 @@ +"""HTTP host for the durable multi-turn agent. 
+
+Wires the durable task (``agent.py``) to the invocations framework.
+Per-invocation results are written by the durable task itself (inside the
+crash-resilient execution boundary), not by a background collector.
+
+Usage::
+
+    pip install azure-ai-agentserver-invocations
+
+    python -m durable_multiturn.app
+    # — or —
+    python app.py
+
+    # Turn 1
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "I want to plan a vacation to Japan"}'
+    # → 202 (x-agent-invocation-id: <invocation-id>)
+
+    # Poll that invocation
+    curl "http://localhost:8088/invocations/<invocation-id>"
+    # → {"invocation_id": "<invocation-id>", "status": "completed", "output": {...}}
+
+    # Turn 2
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "Budget is $5000, 2 weeks"}'
+
+    # End session
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "done"}'
+"""
+
+from __future__ import annotations
+
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+
+from azure.ai.agentserver.core.durable import TaskConflictError
+from azure.ai.agentserver.invocations import InvocationAgentServerHost
+
+from .agent import invocation_store, session_workflow
+
+app = InvocationAgentServerHost()
+
+
+@app.invoke_handler
+async def handle_invoke(request: Request) -> Response:
+    """Start or resume a durable session task.
+
+    Each POST is one invocation. The durable task is an internal detail
+    — the caller only sees ``invocation_id`` (from platform headers).
+
+    The task itself writes the invocation result to the store inside the
+    durable execution boundary — no background collector needed.
+ """ + data = await request.json() + invocation_id: str = request.state.invocation_id + session_id: str = request.state.session_id + message: str = data.get("message", "") + task_id = f"session-{session_id}" + + try: + await session_workflow.start( + task_id=task_id, + input={ + "session_id": session_id, + "message": message, + "invocation_id": invocation_id, + }, + ) + except TaskConflictError as e: + return JSONResponse({"error": str(e)}, status_code=409) + + return JSONResponse( + {"invocation_id": invocation_id, "status": "running"}, + status_code=202, + ) + + +@app.get_invocation_handler +async def poll_invocation(request: Request) -> Response: + """Poll a specific invocation's result. + + Reads from the file-based invocation store — works after restarts. + Returns the output of **this invocation only** — not the whole session. + """ + invocation_id: str = request.state.invocation_id + + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Invocation not found"}, status_code=404) + + return JSONResponse({"invocation_id": invocation_id, **result}) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt new file mode 100644 index 000000000000..bc5cf4644e14 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-invocations diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py new file mode 100644 index 000000000000..003049988a81 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py @@ -0,0 +1,57 @@ +"""File-based key→JSON store for powering the invocation API. 
+ +This module provides a minimal persistence layer that the HTTP host uses to +store per-invocation results. It is **not** part of the durable task +framework — it is the developer's own persistence for powering the API +contract (``GET /invocations/{invocation_id}``). + +.. warning:: + + For demonstration only. In production, use a database (Redis, Cosmos DB, + PostgreSQL, etc.). +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Any + + +class FileStore: + """Minimal file-backed key→JSON store. + + Each entry is a single JSON file. Writes are atomic (temp + rename). + """ + + def __init__(self, base_dir: Path) -> None: + self._base = base_dir + self._base.mkdir(parents=True, exist_ok=True) + + def save(self, key: str, data: dict[str, Any]) -> None: + """Atomically write *data* as JSON — temp file + rename.""" + target = self._base / f"{key}.json" + fd, tmp_path = tempfile.mkstemp( + dir=str(self._base), suffix=".tmp", prefix=f"{key}_" + ) + try: + with open(fd, "w") as f: + json.dump(data, f, indent=2) + Path(tmp_path).replace(target) + except BaseException: + Path(tmp_path).unlink(missing_ok=True) + raise + + def load(self, key: str) -> dict[str, Any] | None: + """Return the stored dict, or ``None`` if the key does not exist.""" + path = self._base / f"{key}.json" + if path.exists(): + return json.loads(path.read_text()) + return None + + def delete(self, key: str) -> None: + """Remove the entry for *key* (no-op if missing).""" + path = self._base / f"{key}.json" + if path.exists(): + path.unlink() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py index 96fa857bf02c..ddef29b2864d 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py +++ 
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py @@ -32,12 +32,12 @@ -d '{"message": "Budget is $5000, prefer direct flights"}' # -> {"reply": "Here is a suggested itinerary ...", ...} """ + from starlette.requests import Request from starlette.responses import JSONResponse, Response from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() # In-memory session store — keyed by session ID. @@ -91,11 +91,13 @@ async def handle_invoke(request: Request) -> Response: reply = _build_reply(history) history.append({"role": "assistant", "content": reply}) - return JSONResponse({ - "reply": reply, - "session_id": session_id, - "turn": len([m for m in history if m["role"] == "user"]), - }) + return JSONResponse( + { + "reply": reply, + "session_id": session_id, + "turn": len([m for m in history if m["role"] == "user"]), + } + ) if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py index a2e7fdb32d3b..adb537cf5dce 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py @@ -11,12 +11,12 @@ curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' # -> {"greeting": "Hello, Alice!"} """ + from starlette.requests import Request from starlette.responses import JSONResponse, Response from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py 
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py index a207a93cca0d..c5caf7b5a920 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py @@ -18,6 +18,7 @@ # -> event: done # -> data: {"invocation_id": "..."} """ + import asyncio import json from collections.abc import AsyncGenerator # pylint: disable=import-error @@ -27,14 +28,32 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() # Simulated tokens — in production these would come from a model. _SIMULATED_TOKENS = [ - "class", " Calculator", ":", "\n", - " ", "def", " add", "(", "self", ",", " a", ",", " b", ")", ":", "\n", - " ", "return", " a", " +", " b", "\n", + "class", + " Calculator", + ":", + "\n", + " ", + "def", + " add", + "(", + "self", + ",", + " a", + ",", + " b", + ")", + ":", + "\n", + " ", + "return", + " a", + " +", + " b", + "\n", ] diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/conftest.py index 8a3deb55c72f..765ae3d21135 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/conftest.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/conftest.py @@ -15,13 +15,17 @@ def pytest_configure(config): - config.addinivalue_line("markers", "tracing_e2e: end-to-end tracing tests against live Application Insights") + config.addinivalue_line( + "markers", + "tracing_e2e: end-to-end tracing tests against live Application Insights", + ) # --------------------------------------------------------------------------- # E2E tracing fixtures # --------------------------------------------------------------------------- + @pytest.fixture() def appinsights_connection_string(): """Return 
APPLICATIONINSIGHTS_CONNECTION_STRING or skip the test.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py index 73307f2ba110..4bb7d141570a 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # invoke_handler stores function # --------------------------------------------------------------------------- + def test_invoke_handler_stores_function(): """@app.invoke_handler stores the function on the protocol object.""" app = InvocationAgentServerHost() @@ -30,6 +30,7 @@ async def handle(request: Request) -> Response: # invoke_handler returns original function # --------------------------------------------------------------------------- + def test_invoke_handler_returns_original_function(): """@app.invoke_handler returns the original function.""" app = InvocationAgentServerHost() @@ -45,6 +46,7 @@ async def handle(request: Request) -> Response: # get_invocation_handler stores function # --------------------------------------------------------------------------- + def test_get_invocation_handler_stores_function(): """@app.get_invocation_handler stores the function.""" app = InvocationAgentServerHost() @@ -60,6 +62,7 @@ async def get_handler(request: Request) -> Response: # cancel_invocation_handler stores function # --------------------------------------------------------------------------- + def test_cancel_invocation_handler_stores_function(): """@app.cancel_invocation_handler stores the function.""" app = InvocationAgentServerHost() @@ -75,6 +78,7 @@ async def cancel_handler(request: Request) -> Response: # shutdown_handler stores function # 
--------------------------------------------------------------------------- + def test_shutdown_handler_stores_function(): """@server.shutdown_handler stores the function on the server.""" app = InvocationAgentServerHost() @@ -90,6 +94,7 @@ async def on_shutdown(): # Full request flow # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_full_request_flow(): """Full lifecycle: invoke → get → cancel → get (404).""" @@ -107,7 +112,9 @@ async def get_handler(request: Request) -> Response: inv_id = request.path_params["invocation_id"] if inv_id in store: return Response(content=store[inv_id]) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) @app.cancel_invocation_handler async def cancel_handler(request: Request) -> Response: @@ -115,7 +122,9 @@ async def cancel_handler(request: Request) -> Response: if inv_id in store: del store[inv_id] return JSONResponse({"status": "cancelled"}) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as client: @@ -142,6 +151,7 @@ async def cancel_handler(request: Request) -> Response: # Missing optional handlers return 404 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_missing_invoke_handler_returns_501(): """POST /invocations without registered handler returns 501.""" @@ -186,6 +196,7 @@ async def handle(request: Request) -> Response: # Optional handler defaults and overrides # --------------------------------------------------------------------------- + def test_optional_handlers_default_none(): """Get 
and cancel handlers default to None.""" app = InvocationAgentServerHost() @@ -208,6 +219,7 @@ async def get_handler(request: Request) -> Response: # Shutdown handler called during lifespan # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_called_during_lifespan(): """Shutdown handler is called when the app lifespan ends.""" @@ -235,6 +247,7 @@ async def on_shutdown(): # Config passthrough # --------------------------------------------------------------------------- + def test_graceful_shutdown_timeout_passthrough(): """graceful_shutdown_timeout is passed through to the base class.""" server = InvocationAgentServerHost(graceful_shutdown_timeout=15) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py index 351418db7461..999f46310e07 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py @@ -64,6 +64,7 @@ async def handle(request: Request) -> Response: # Method not allowed tests # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocations_returns_405(): """GET /invocations returns 405 Method Not Allowed.""" @@ -128,6 +129,7 @@ async def handle(request: Request) -> Response: # Response header tests # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_custom_invocation_id_overwritten(): """Handler-set x-agent-invocation-id is overwritten by the server.""" @@ -176,6 +178,7 @@ async def test_invocation_id_generated_when_empty(echo_client): # Payload edge cases # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_large_payload(): """Large payload (1MB) is handled correctly.""" @@ 
-210,6 +213,7 @@ async def test_binary_payload(echo_client): # Streaming edge cases # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_empty_streaming(): """Empty streaming response doesn't crash.""" @@ -243,6 +247,7 @@ async def generate(): # Invocation lifecycle # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_multiple_gets(async_storage_client): """Multiple GETs for the same invocation return the same result.""" @@ -283,6 +288,7 @@ async def test_invoke_cancel_get(async_storage_client): # Concurrency # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_concurrent_invocations_get_unique_ids(): """10 concurrent POSTs each get unique invocation IDs.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py index 23c133fe3b9b..5cf42f8fe4c6 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # GET after invoke # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_after_invoke_returns_stored_result(async_storage_client): """GET /invocations/{id} after invoke returns the stored result.""" @@ -31,6 +31,7 @@ async def test_get_after_invoke_returns_stored_result(async_storage_client): # GET unknown ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_unknown_id_returns_404(async_storage_client): """GET /invocations/{unknown} returns 404.""" @@ -42,6 +43,7 @@ 
async def test_get_unknown_id_returns_404(async_storage_client): # Cancel after invoke # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_after_invoke_returns_cancelled(async_storage_client): """POST /invocations/{id}/cancel after invoke returns cancelled status.""" @@ -57,6 +59,7 @@ async def test_cancel_after_invoke_returns_cancelled(async_storage_client): # Cancel unknown ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_unknown_id_returns_404(async_storage_client): """POST /invocations/{unknown}/cancel returns 404.""" @@ -68,6 +71,7 @@ async def test_cancel_unknown_id_returns_404(async_storage_client): # GET after cancel # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_after_cancel_returns_404(async_storage_client): """GET after cancel returns 404 (data has been removed).""" @@ -83,6 +87,7 @@ async def test_get_after_cancel_returns_404(async_storage_client): # GET error returns 500 (inline InvocationAgentServerHost) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_error_returns_500(): """GET handler raising an exception returns 500.""" @@ -107,6 +112,7 @@ async def get_handler(request: Request) -> Response: # Cancel error returns 500 (inline InvocationAgentServerHost) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_error_returns_500(): """Cancel handler raising an exception returns 500.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py index db35beceda0f..0cb5ed48cf0e 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py @@ -13,11 +13,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_server_with_shutdown(**kwargs) -> tuple[InvocationAgentServerHost, list]: """Create InvocationAgentServerHost with a tracked shutdown handler.""" server = InvocationAgentServerHost(**kwargs) @@ -38,6 +38,7 @@ async def on_shutdown(): # Shutdown handler registration # --------------------------------------------------------------------------- + def test_shutdown_handler_registered(): """Shutdown handler is stored on the server.""" server, _ = _make_server_with_shutdown() @@ -59,6 +60,7 @@ async def handle(request: Request) -> Response: # ASGI lifespan helper # --------------------------------------------------------------------------- + async def _drive_lifespan(app): """Drive a full ASGI lifespan startup+shutdown cycle.""" scope = {"type": "lifespan"} @@ -84,6 +86,7 @@ async def send(message): # Shutdown handler called during lifespan # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_called_on_lifespan_exit(): """Shutdown handler runs when the ASGI lifespan exits.""" @@ -99,6 +102,7 @@ async def test_shutdown_handler_called_on_lifespan_exit(): # Shutdown handler timeout # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_timeout(caplog): """Shutdown handler that exceeds timeout is warned about.""" @@ -120,13 +124,17 @@ async def on_shutdown(): # Shutdown should have been interrupted assert "completed" not in calls # Logger should have warned about timeout - assert 
any("did not complete" in r.message.lower() or "timeout" in r.message.lower() for r in caplog.records) + assert any( + "did not complete" in r.message.lower() or "timeout" in r.message.lower() + for r in caplog.records + ) # --------------------------------------------------------------------------- # Shutdown handler exception # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_exception(caplog): """Shutdown handler that raises is caught and logged.""" @@ -144,13 +152,17 @@ async def on_shutdown(): await _drive_lifespan(app) # Should have logged the exception - assert any("on_shutdown" in r.message.lower() or "error" in r.message.lower() for r in caplog.records) + assert any( + "on_shutdown" in r.message.lower() or "error" in r.message.lower() + for r in caplog.records + ) # --------------------------------------------------------------------------- # Graceful shutdown timeout config # --------------------------------------------------------------------------- + def test_default_graceful_shutdown_timeout(): """Default graceful shutdown timeout is 30 seconds.""" app = InvocationAgentServerHost() @@ -173,6 +185,7 @@ def test_zero_graceful_shutdown_timeout(): # Health endpoint accessible during normal operation # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_health_endpoint_during_operation(): """GET /readiness returns 200 during normal operation.""" @@ -188,6 +201,7 @@ async def test_health_endpoint_during_operation(): # No shutdown handler is no-op # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_no_shutdown_handler_is_noop(): """Without a shutdown handler, lifespan exit succeeds silently.""" @@ -208,6 +222,7 @@ async def handle(request: Request) -> Response: # Multiple requests before shutdown # 
--------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_multiple_requests_before_shutdown(): """Multiple requests can be served, then shutdown handler runs.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py index 5de15efd63cc..198cbcd76711 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py @@ -12,6 +12,7 @@ # Echo body # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_echo_body(echo_client): """POST /invocations echoes the request body.""" @@ -24,6 +25,7 @@ async def test_invoke_echo_body(echo_client): # Headers # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_returns_invocation_id_header(echo_client): """Response includes x-agent-invocation-id header.""" @@ -68,6 +70,7 @@ async def test_invoke_accepts_custom_invocation_id(echo_client): # Streaming # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_streaming_returns_chunks(streaming_client): """Streaming handler returns 3 JSON chunks.""" @@ -91,6 +94,7 @@ async def test_streaming_has_invocation_id_header(streaming_client): # Empty body # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_empty_body(echo_client): """Empty body doesn't crash the server.""" @@ -103,6 +107,7 @@ async def test_invoke_empty_body(echo_client): # Error handling # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_error_returns_500(failing_client): """Handler exception returns 500 with generic message.""" @@ -124,6 +129,7 @@ 
async def test_invoke_error_has_invocation_id(failing_client): # Error handling # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_error_hides_details_by_default(failing_client): """Exception message is hidden in error responses.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py index 818eb20c491e..ee866da198fb 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py @@ -12,11 +12,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # Helper: content-type echo agent # --------------------------------------------------------------------------- + def _make_content_type_echo_agent() -> InvocationAgentServerHost: """Agent that echoes body and returns the content-type it received.""" app = InvocationAgentServerHost() @@ -66,6 +66,7 @@ async def generate(): # Various content types # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_png_content_type(): """PNG content type is accepted and echoed.""" @@ -166,6 +167,7 @@ async def test_text_plain_content_type(): # Custom HTTP status codes # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_custom_status_200(): """Handler returning 200.""" @@ -200,6 +202,7 @@ async def test_custom_status_202(): # Query strings # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_query_string_passed_to_handler(): """Query string params are accessible in the handler.""" @@ -221,6 +224,7 @@ async def handle(request: Request) -> 
Response: # SSE streaming # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_sse_streaming(): """SSE-formatted streaming response works.""" @@ -238,6 +242,7 @@ async def test_sse_streaming(): # Large binary payloads # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_large_binary_payload(): """Large binary payload (512KB) is handled correctly.""" @@ -258,6 +263,7 @@ async def test_large_binary_payload(): # Health endpoint (updated from /healthy to /readiness) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_health_endpoint_returns_200(): """GET /readiness returns 200 with healthy status.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py index 934433bd0333..4b087e1e958b 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py @@ -19,6 +19,7 @@ # Header presence — success responses # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_returns_request_id_header(echo_client): """POST /invocations success response includes x-request-id.""" @@ -61,6 +62,7 @@ async def test_readiness_returns_request_id(echo_client): # Error responses — header present, but NO body enrichment # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_error_response_has_request_id_header(failing_client): """500 error response includes x-request-id header.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py index 
24d71ed51e8f..95b24d827638 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # InvocationAgentServerHost no longer accepts request_timeout # --------------------------------------------------------------------------- + def test_no_request_timeout_parameter(): """InvocationAgentServerHost no longer accepts request_timeout.""" with pytest.raises(TypeError): @@ -25,6 +25,7 @@ def test_no_request_timeout_parameter(): # Slow invoke completes without timeout # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_slow_invoke_completes(): """Without timeout, handler runs to completion.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py index 8bafb6fb9608..80d560c5b965 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py @@ -18,6 +18,7 @@ # POST /invocations returns 200 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_returns_200(echo_client): """POST /invocations returns 200 OK.""" @@ -29,6 +30,7 @@ async def test_post_invocations_returns_200(echo_client): # POST /invocations returns invocation-id header (UUID) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_returns_uuid_invocation_id(echo_client): """POST /invocations returns a valid UUID in x-agent-invocation-id.""" @@ -42,6 +44,7 @@ async def 
test_post_invocations_returns_uuid_invocation_id(echo_client): # GET openapi spec returns 404 when not set # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_openapi_spec_returns_404_when_not_set(no_spec_client): """GET /invocations/docs/openapi.json returns 404 when no spec registered.""" @@ -53,6 +56,7 @@ async def test_get_openapi_spec_returns_404_when_not_set(no_spec_client): # GET openapi spec returns spec when registered # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_openapi_spec_returns_spec_when_registered(): """GET /invocations/docs/openapi.json returns the spec when registered.""" @@ -73,6 +77,7 @@ async def handle(request: Request) -> Response: # GET /invocations/{id} returns 404 default # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_returns_404_default(echo_client): """GET /invocations/{id} returns 404 when no get handler registered.""" @@ -84,6 +89,7 @@ async def test_get_invocation_returns_404_default(echo_client): # POST /invocations/{id}/cancel returns 404 default # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_returns_404_default(echo_client): """POST /invocations/{id}/cancel returns 404 when no cancel handler.""" @@ -95,6 +101,7 @@ async def test_cancel_invocation_returns_404_default(echo_client): # Unknown route returns 404 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_unknown_route_returns_404(echo_client): """Unknown route returns 404.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py index 6398f2f8d327..7a8f54751859 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py @@ -20,6 +20,7 @@ # Constants # --------------------------------------------------------------------------- + def test_session_id_header_constant(): """SESSION_ID_HEADER constant is correct.""" assert InvocationConstants.SESSION_ID_HEADER == "x-agent-session-id" @@ -29,6 +30,7 @@ def test_session_id_header_constant(): # POST /invocations response has x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_has_session_id_header(echo_client): """POST /invocations response includes x-agent-session-id header.""" @@ -42,6 +44,7 @@ async def test_post_invocations_has_session_id_header(echo_client): # POST /invocations with query param uses that value # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_with_query_param(): """POST /invocations with agent_session_id query param uses that value.""" @@ -64,6 +67,7 @@ async def handle(request: Request) -> Response: # POST /invocations with env var # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_uses_env_var(): """POST /invocations uses FOUNDRY_AGENT_SESSION_ID env var when no query param.""" @@ -75,7 +79,9 @@ async def handle(request: Request) -> Response: return Response(content=b"ok") transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: resp = await client.post("/invocations", content=b"test") assert resp.headers["x-agent-session-id"] == "env-session" @@ -84,6 +90,7 @@ async def handle(request: Request) -> Response: # GET /invocations/{id} does NOT 
have x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_no_session_id_header(async_storage_client): """GET /invocations/{id} does NOT include x-agent-session-id.""" @@ -99,6 +106,7 @@ async def test_get_invocation_no_session_id_header(async_storage_client): # POST /invocations/{id}/cancel does NOT have x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_no_session_id_header(async_storage_client): """POST /invocations/{id}/cancel does NOT include x-agent-session-id.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py index 5c31f78b6a8a..aaf36e6da05d 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py @@ -23,7 +23,9 @@ from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) _HAS_OTEL = True except ImportError: @@ -63,8 +65,15 @@ def _get_spans(): def _make_server_with_child_span(): """Server whose handler creates a child span (simulating a framework).""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with 
patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): app = InvocationAgentServerHost() child_tracer = trace.get_tracer("test.framework") @@ -78,16 +87,25 @@ async def handle(request: Request) -> Response: def _make_streaming_server_with_child_span(): """Server with streaming response whose handler creates a child span.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): app = InvocationAgentServerHost() child_tracer = trace.get_tracer("test.framework") @app.invoke_handler async def handle(request: Request) -> StreamingResponse: with child_tracer.start_as_current_span("framework_invoke_agent"): + async def generate(): yield b"chunk\n" + return StreamingResponse(generate(), media_type="text/plain") return app @@ -95,11 +113,19 @@ async def generate(): def _assert_child_parented(spans, streaming: bool = False): """Assert the framework span is a child of the invoke_agent span.""" - parent_spans = [s for s in spans if "invoke_agent" in s.name and s.name != "framework_invoke_agent"] + parent_spans = [ + s + for s in spans + if "invoke_agent" in s.name and s.name != "framework_invoke_agent" + ] child_spans = [s for s in spans if s.name == "framework_invoke_agent"] - assert len(parent_spans) >= 1, f"Expected invoke_agent span, got: {[s.name for s in spans]}" - assert len(child_spans) == 1, f"Expected framework span, got: {[s.name for s in spans]}" + assert ( + len(parent_spans) >= 1 + ), f"Expected invoke_agent span, got: {[s.name for s in spans]}" + assert ( + len(child_spans) == 1 + ), f"Expected framework span, got: {[s.name for s in 
spans]}" parent = parent_spans[0] child = child_spans[0] diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing.py index 082ad23549ed..cf07cc03c113 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing.py @@ -28,7 +28,9 @@ from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) _HAS_OTEL = True except ImportError: @@ -70,10 +72,18 @@ def _get_spans(): # Helper: create tracing-enabled server # --------------------------------------------------------------------------- + def _make_tracing_server(**kwargs): """Create an InvocationAgentServerHost with tracing enabled.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): server = InvocationAgentServerHost(**kwargs) @server.invoke_handler @@ -86,8 +96,15 @@ async def handle(request: Request) -> Response: def _make_tracing_server_with_get_cancel(**kwargs): """Create a tracing-enabled server with get/cancel handlers.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", 
create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): server = InvocationAgentServerHost(**kwargs) store: dict[str, bytes] = {} @@ -103,7 +120,9 @@ async def get_handler(request: Request) -> Response: inv_id = request.path_params["invocation_id"] if inv_id in store: return Response(content=store[inv_id]) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) @server.cancel_invocation_handler async def cancel_handler(request: Request) -> Response: @@ -111,15 +130,24 @@ async def cancel_handler(request: Request) -> Response: if inv_id in store: del store[inv_id] return JSONResponse({"status": "cancelled"}) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) return server def _make_failing_tracing_server(**kwargs): """Create a tracing-enabled server whose handler raises.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): server = InvocationAgentServerHost(**kwargs) @server.invoke_handler @@ -131,8 +159,15 @@ async def handle(request: Request) -> Response: def _make_streaming_tracing_server(**kwargs): """Create a tracing-enabled server with streaming response.""" - with 
patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): server = InvocationAgentServerHost(**kwargs) @server.invoke_handler @@ -150,6 +185,7 @@ async def generate(): # Tracing disabled by default # --------------------------------------------------------------------------- + def test_tracing_disabled_by_default(): """Invoke spans are still created by the global tracer when tracing is not explicitly configured.""" if _MODULE_EXPORTER: @@ -176,6 +212,7 @@ async def handle(request: Request) -> Response: # Tracing enabled creates invoke span with correct name # --------------------------------------------------------------------------- + def test_tracing_enabled_creates_invoke_span(): """Tracing enabled creates a span named 'invoke_agent'.""" server = _make_tracing_server() @@ -192,6 +229,7 @@ def test_tracing_enabled_creates_invoke_span(): # Invoke error records exception # --------------------------------------------------------------------------- + def test_invoke_error_records_exception(): """When handler raises, the span records the exception.""" server = _make_failing_tracing_server() @@ -211,6 +249,7 @@ def test_invoke_error_records_exception(): # GET/cancel create spans # --------------------------------------------------------------------------- + def test_get_invocation_creates_span(): """GET /invocations/{id} creates a span.""" server = _make_tracing_server_with_get_cancel() @@ -241,10 +280,18 @@ def test_cancel_invocation_creates_span(): # Tracing via env var # --------------------------------------------------------------------------- + def 
test_tracing_via_appinsights_env_var(): """Tracing is enabled when APPLICATIONINSIGHTS_CONNECTION_STRING is set.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): - with patch("azure.ai.agentserver.core._tracing._setup_distro_export", create=True): + with patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): + with patch( + "azure.ai.agentserver.core._tracing._setup_distro_export", create=True + ): app = InvocationAgentServerHost() @app.invoke_handler @@ -263,6 +310,7 @@ async def handle(request: Request) -> Response: # No tracing when no endpoints configured # --------------------------------------------------------------------------- + def test_no_tracing_when_no_endpoints(): """When no connection string or OTLP endpoint is set, configure_observability still runs (for console logging) but tracing spans are not exported.""" @@ -293,6 +341,7 @@ async def handle(request: Request) -> Response: # Traceparent propagation # --------------------------------------------------------------------------- + def test_traceparent_propagation(): """Server propagates traceparent header into span context.""" server = _make_tracing_server() @@ -322,6 +371,7 @@ def test_traceparent_propagation(): # Streaming spans # --------------------------------------------------------------------------- + def test_streaming_creates_span(): """Streaming response creates and completes a span.""" server = _make_streaming_tracing_server() @@ -338,6 +388,7 @@ def test_streaming_creates_span(): # GenAI attributes on invoke span # --------------------------------------------------------------------------- + def test_genai_attributes_on_invoke_span(): """Invoke span has GenAI semantic convention attributes.""" server = _make_tracing_server() @@ -358,6 +409,7 @@ def test_genai_attributes_on_invoke_span(): # Session ID in 
microsoft.session.id # --------------------------------------------------------------------------- + def test_session_id_in_conversation_id(): """Session ID is set as microsoft.session.id on invoke span.""" server = _make_tracing_server() @@ -378,6 +430,7 @@ def test_session_id_in_conversation_id(): # GenAI attributes on get_invocation span # --------------------------------------------------------------------------- + def test_genai_attributes_on_get_span(): """GET invocation span has GenAI attributes.""" server = _make_tracing_server_with_get_cancel() @@ -398,6 +451,7 @@ def test_genai_attributes_on_get_span(): # Namespaced invocation_id attribute # --------------------------------------------------------------------------- + def test_namespaced_invocation_id_attribute(): """Invoke span has azure.ai.agentserver.invocations.invocation_id.""" server = _make_tracing_server() @@ -416,12 +470,16 @@ def test_namespaced_invocation_id_attribute(): # Agent name/version in span names # --------------------------------------------------------------------------- + def test_agent_name_in_span_name(): """Agent name from env var appears in span name.""" - with patch.dict(os.environ, { - "FOUNDRY_AGENT_NAME": "my-agent", - "FOUNDRY_AGENT_VERSION": "2.0", - }): + with patch.dict( + os.environ, + { + "FOUNDRY_AGENT_NAME": "my-agent", + "FOUNDRY_AGENT_VERSION": "2.0", + }, + ): server = _make_tracing_server() client = TestClient(server) @@ -456,6 +514,6 @@ def test_agent_name_only_in_span_name(): # Project endpoint attribute # --------------------------------------------------------------------------- + def test_project_endpoint_env_var(): """FOUNDRY_PROJECT_ENDPOINT constant matches the expected env var name.""" - diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py index 359799ce90f3..3a3a63bc5a39 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py @@ -38,7 +38,9 @@ def _flush_provider(): provider.force_flush() -def _poll_appinsights(logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT): +def _poll_appinsights( + logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT +): """Poll Application Insights until the KQL query returns >= 1 row or timeout.""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: @@ -57,6 +59,7 @@ def _poll_appinsights(logs_client, resource_id, query, *, timeout=_APPINSIGHTS_P # E2E test # --------------------------------------------------------------------------- + class TestInvocationTracingE2E: """Verify InvocationAgentServerHost auto-creates traced spans that land in App Insights.""" @@ -76,7 +79,9 @@ async def handle(request: Request) -> Response: return Response(content=body, media_type="application/octet-stream") transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: resp = await client.post("/invocations", content=b"hello e2e") assert resp.status_code == 200