diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 38dfdf56..bb86c32f 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -11,12 +11,12 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import GrpcHandler from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( @@ -190,22 +190,29 @@ async def serve( agent_executor=SampleAgentExecutor(), task_store=task_store ) - rest_app_builder = A2ARESTFastAPIApplication( + agent_card_routes = AgentCardRoutes( agent_card=agent_card, - http_handler=request_handler, + card_url='/.well-known/agent-card.json', + ) + # JSON-RPC + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=request_handler, + rpc_url='/a2a/jsonrpc/', enable_v0_3_compat=True, ) - rest_app = rest_app_builder.build() - - jsonrpc_app_builder = A2AFastAPIApplication( + # REST + rest_routes = RestRoutes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, + rpc_url='/a2a/rest', enable_v0_3_compat=True, ) app = FastAPI() - jsonrpc_app_builder.add_routes_to_app(app, rpc_url='/a2a/jsonrpc/') - app.mount('/a2a/rest', rest_app) + app.routes.extend(agent_card_routes.routes) + app.routes.extend(jsonrpc_routes.routes) + app.routes.extend(rest_routes.routes) grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index cdb701b5..b62c6f66 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: from starlette.requests import Request - from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder from a2a.server.request_handlers.request_handler import RequestHandler + from a2a.server.routes.jsonrpc_dispatcher import CallContextBuilder from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index b0296e40..91186059 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -33,12 +33,11 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.apps.jsonrpc.jsonrpc_app import ( +from a2a.server.context import ServerCallContext +from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, ) -from a2a.server.apps.rest.rest_adapter import RESTAdapterInterface -from a2a.server.context import ServerCallContext from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, @@ -53,7 +52,7 @@ logger = logging.getLogger(__name__) -class REST03Adapter(RESTAdapterInterface): +class REST03Adapter: """Adapter to make RequestHandler work with v0.3 RESTful API. Defines v0.3 REST request processors and their routes, as well as managing response generation including Server-Sent Events (SSE). diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py deleted file mode 100644 index 579deaa5..00000000 --- a/src/a2a/server/apps/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""HTTP application components for the A2A server.""" - -from a2a.server.apps.jsonrpc import ( - A2AFastAPIApplication, - A2AStarletteApplication, - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.apps.rest import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2AFastAPIApplication', - 'A2ARESTFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'JSONRPCApplication', -] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py deleted file mode 100644 index 1121fdbc..00000000 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""A2A JSON-RPC Applications.""" - -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - DefaultCallContextBuilder, - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication - - -__all__ = [ - 'A2AFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'DefaultCallContextBuilder', - 'JSONRPCApplication', - 'StarletteUserProxy', -] diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py deleted file mode 100644 index 0ec9d1ab..00000000 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import FastAPI - - _package_fastapi_installed = True -else: - try: - from fastapi import FastAPI - - _package_fastapi_installed = True - except ImportError: - FastAPI = Any - - _package_fastapi_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AFastAPIApplication(JSONRPCApplication): - """A FastAPI application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' - ' It can be added as a part of `a2a-sdk` optional dependencies,' - ' `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def add_routes_to_app( - self, - app: FastAPI, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the FastAPI application. - - Args: - app: The FastAPI application to add the routes to. - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - """ - app.post( - rpc_url, - openapi_extra={ - 'requestBody': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/A2ARequest' - } - } - }, - 'required': True, - 'description': 'A2ARequest', - } - }, - )(self._handle_requests) - app.get(agent_card_url)(self._handle_get_agent_card) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py deleted file mode 100644 index 553fa250..00000000 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ /dev/null @@ -1,169 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - -else: - try: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - except ImportError: - Starlette = Any - Route = Any - - _package_starlette_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AStarletteApplication(JSONRPCApplication): - """A Starlette application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AStarletteApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use the' - ' `A2AStarletteApplication`. It can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def routes( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> list[Route]: - """Returns the Starlette Routes for handling A2A requests. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - - Returns: - A list of Starlette Route objects. - """ - return [ - Route( - rpc_url, - self._handle_requests, - methods=['POST'], - name='a2a_handler', - ), - Route( - agent_card_url, - self._handle_get_agent_card, - methods=['GET'], - name='agent_card', - ), - ] - - def add_routes_to_app( - self, - app: Starlette, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the Starlette application. - - Args: - app: The Starlette application to add the routes to. - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - """ - routes = self.routes( - agent_card_url=agent_card_url, - rpc_url=rpc_url, - ) - app.routes.extend(routes) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> Starlette: - """Builds and returns the Starlette application instance. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - **kwargs: Additional keyword arguments to pass to the Starlette constructor. - - Returns: - A configured Starlette application instance. - """ - app = Starlette(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py deleted file mode 100644 index bafe4cb6..00000000 --- a/src/a2a/server/apps/rest/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""A2A REST Applications.""" - -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2ARESTFastAPIApplication', -] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py deleted file mode 100644 index ea9a501b..00000000 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ /dev/null @@ -1,194 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True -else: - try: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True - except ImportError: - APIRouter = Any - FastAPI = Any - Request = Any - Response = Any - StarletteHTTPException = Any - - _package_fastapi_installed = False - - -from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder -from a2a.server.apps.rest.rest_adapter import RESTAdapter -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - -logger = logging.getLogger(__name__) - - -_HTTP_TO_GRPC_STATUS_MAP = { - 400: 'INVALID_ARGUMENT', - 401: 'UNAUTHENTICATED', - 403: 'PERMISSION_DENIED', - 404: 'NOT_FOUND', - 405: 'UNIMPLEMENTED', - 409: 'ALREADY_EXISTS', - 415: 'INVALID_ARGUMENT', - 422: 'INVALID_ARGUMENT', - 500: 'INTERNAL', - 501: 'UNIMPLEMENTED', - 502: 'INTERNAL', - 503: 'UNAVAILABLE', - 504: 'DEADLINE_EXCEEDED', -} - - -class A2ARESTFastAPIApplication: - """A FastAPI application implementing the A2A protocol server REST endpoints. - - Handles incoming REST requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - enable_v0_3_compat: bool = False, - ): - """Initializes the A2ARESTFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol - endpoints under the '/v0.3' path prefix using REST03Adapter. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`. It can be added as a part of' - ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' - ) - self._adapter = RESTAdapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - self.enable_v0_3_compat = enable_v0_3_compat - self._v03_adapter = None - - if self.enable_v0_3_compat: - self._v03_adapter = REST03Adapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A REST endpoint base path. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - @app.exception_handler(StarletteHTTPException) - async def http_exception_handler( - request: Request, exc: StarletteHTTPException - ) -> Response: - """Catches framework-level HTTP exceptions. - - For example, 404 Not Found for bad routes, 422 Unprocessable Entity - for schema validation, and formats them into the A2A standard - google.rpc.Status JSON format (AIP-193). - """ - grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get( - exc.status_code, 'UNKNOWN' - ) - return JSONResponse( - status_code=exc.status_code, - content={ - 'error': { - 'code': exc.status_code, - 'status': grpc_status, - 'message': str(exc.detail) - if hasattr(exc, 'detail') - else 'HTTP Exception', - } - }, - media_type='application/json', - ) - - if self.enable_v0_3_compat and self._v03_adapter: - v03_adapter = self._v03_adapter - v03_router = APIRouter() - for route, callback in v03_adapter.routes().items(): - v03_router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - app.include_router(v03_router) - - router = APIRouter() - for route, callback in self._adapter.routes().items(): - router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - - @router.get(f'{rpc_url}{agent_card_url}') - async def get_agent_card(request: Request) -> Response: - card = await self._adapter.handle_get_agent_card(request) - return JSONResponse(card) - - app.include_router(router) - - return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py deleted file mode 100644 index 6b8abb99..00000000 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ /dev/null @@ -1,296 +0,0 @@ -import functools -import json -import logging - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import MessageToDict - -from a2a.utils.helpers import maybe_await - - -if TYPE_CHECKING: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - -else: - try: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - except ImportError: - EventSourceResponse = Any - Request = Any - JSONResponse = Any - Response = Any - - _package_starlette_installed = False - -from a2a.server.apps.jsonrpc import ( - CallContextBuilder, - DefaultCallContextBuilder, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, -) -from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, -) - - -logger = logging.getLogger(__name__) - - -class RESTAdapterInterface(ABC): - """Interface for RESTAdapter.""" - - @abstractmethod - async def handle_get_agent_card( - self, request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - - @abstractmethod - def routes(self) -> dict[tuple[str, str], Callable[['Request'], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers.""" - - -class RESTAdapter(RESTAdapterInterface): - """Adapter to make RequestHandler work with RESTful API. - - Defines REST requests processors and the routes to attach them too, as well as - manages response generation including Server-Sent Events (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - ): - """Initializes the RESTApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `RESTAdapter`. They can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = RESTHandler( - agent_card=agent_card, request_handler=http_handler - ) - self._context_builder = context_builder or DefaultCallContextBuilder() - - @rest_error_handler - async def _handle_request( - self, - method: Callable[[Request, ServerCallContext], Awaitable[Any]], - request: Request, - ) -> Response: - call_context = self._build_call_context(request) - - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - self, - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = self._build_call_context(request) - - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse( - event_generator(method(request, call_context)) - ) - - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return agent_card_to_dict(card_to_serve) - - async def _handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Hook for per credential agent card response. - - If a dynamic card is needed based on the credentials provided in the request - override this method and return the customized content. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the authenticated card. - """ - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers. - - This method maps URL paths and HTTP methods to the appropriate handler - functions from the RESTHandler. It can be used by a web framework - (like Starlette or FastAPI) to set up the application's endpoints. - - Returns: - A dictionary where each key is a tuple of (path, http_method) and - the value is the callable handler for that route. - """ - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - self._handle_request, self.handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - self._handle_request, self.handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - self._handle_request, self.handler.on_get_task - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'GET', - ): functools.partial( - self._handle_request, self.handler.get_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'DELETE', - ): functools.partial( - self._handle_request, self.handler.delete_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'POST', - ): functools.partial( - self._handle_request, self.handler.set_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'GET', - ): functools.partial( - self._handle_request, self.handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - self._handle_request, self.handler.list_tasks - ), - } - - if self.agent_card.capabilities.extended_agent_card: - base_routes[('/extendedAgentCard', 'GET')] = functools.partial( - self._handle_request, self._handle_authenticated_agent_card - ) - - routes: dict[tuple[str, str], Callable[[Request], Any]] = { - (p, method): handler - for (path, method), handler in base_routes.items() - for p in (path, f'/{{tenant}}{path}') - } - - return routes - - def _build_call_context(self, request: Request) -> ServerCallContext: - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e2..ef4f0a74 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -5,13 +5,11 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, ) -from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -40,8 +38,6 @@ def __init__(self, *args, **kwargs): __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', - 'JSONRPCHandler', - 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py deleted file mode 100644 index e7d5b75a..00000000 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ /dev/null @@ -1,474 +0,0 @@ -"""JSON-RPC handler for A2A server requests.""" - -import logging - -from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any - -from google.protobuf.json_format import MessageToDict -from jsonrpc.jsonrpc2 import JSONRPC20Response - -from a2a.server.context import ServerCallContext -from a2a.server.jsonrpc_models import ( - InternalError as JSONRPCInternalError, -) -from a2a.server.jsonrpc_models import ( - JSONRPCError, -) -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTasksRequest, - SendMessageRequest, - SendMessageResponse, - SubscribeToTaskRequest, - Task, - TaskPushNotificationConfig, -) -from a2a.utils import proto_utils -from a2a.utils.errors import ( - JSON_RPC_ERROR_CODE_MAP, - A2AError, - ContentTypeNotSupportedError, - ExtendedAgentCardNotConfiguredError, - ExtensionSupportRequiredError, - InternalError, - InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, - PushNotificationNotSupportedError, - TaskNotCancelableError, - TaskNotFoundError, - UnsupportedOperationError, - VersionNotSupportedError, -) -from a2a.utils.helpers import maybe_await, validate, validate_async_generator -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -EXCEPTION_MAP: dict[type[A2AError], type[JSONRPCError]] = { - TaskNotFoundError: JSONRPCError, - TaskNotCancelableError: JSONRPCError, - PushNotificationNotSupportedError: JSONRPCError, - UnsupportedOperationError: JSONRPCError, - ContentTypeNotSupportedError: JSONRPCError, - InvalidAgentResponseError: JSONRPCError, - ExtendedAgentCardNotConfiguredError: JSONRPCError, - InternalError: JSONRPCInternalError, - InvalidParamsError: JSONRPCError, - InvalidRequestError: JSONRPCError, - MethodNotFoundError: JSONRPCError, - ExtensionSupportRequiredError: JSONRPCError, - VersionNotSupportedError: JSONRPCError, -} - - -def _build_success_response( - request_id: str | int | None, result: Any -) -> dict[str, Any]: - """Build a JSON-RPC success response dict.""" - return JSONRPC20Response(result=result, _id=request_id).data - - -def _build_error_response( - request_id: str | int | None, error: Exception -) -> dict[str, Any]: - """Build a JSON-RPC error response dict.""" - jsonrpc_error: JSONRPCError - if isinstance(error, A2AError): - error_type = type(error) - model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError) - code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603) - jsonrpc_error = model_class( - code=code, - message=str(error), - ) - else: - jsonrpc_error = JSONRPCInternalError(message=str(error)) - - error_dict = jsonrpc_error.model_dump(exclude_none=True) - return JSONRPC20Response(error=error_dict, _id=request_id).data - - -@trace_class(kind=SpanKind.SERVER) -class JSONRPCHandler: - """Maps incoming JSON-RPC requests to the appropriate request handler method and formats responses.""" - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - ): - """Initializes the JSONRPCHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - extended_agent_card: An optional, distinct Extended AgentCard to be served - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - """ - self.agent_card = agent_card - self.request_handler = request_handler - self.extended_agent_card = extended_agent_card - self.extended_card_modifier = extended_card_modifier - self.card_modifier = card_modifier - - def _get_request_id( - self, context: ServerCallContext | None - ) -> str | int | None: - """Get the JSON-RPC request ID from the context.""" - if context is None: - return None - return context.state.get('request_id') - - async def on_message_send( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' JSON-RPC method. - - Args: - request: The incoming `SendMessageRequest` proto message. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task_or_message = await self.request_handler.on_message_send( - request, context - ) - if isinstance(task_or_message, Task): - response = SendMessageResponse(task=task_or_message) - else: - response = SendMessageResponse(message=task_or_message) - - result = MessageToDict(response) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'message/stream' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SendMessageRequest` object (for streaming). - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_message_send_stream( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - async def on_cancel_task( - self, - request: CancelTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' JSON-RPC method. - - Args: - request: The incoming `CancelTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_cancel_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: SubscribeToTaskRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'SubscribeToTask' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SubscribeToTaskRequest` object. - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_subscribe_to_task( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - async def get_push_notification_config( - self, - request: GetTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - config = ( - await self.request_handler.on_get_task_push_notification_config( - request, context - ) - ) - result = MessageToDict(config, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification_config( - self, - request: TaskPushNotificationConfig, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' JSON-RPC method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator). - """ - request_id = self._get_request_id(context) - try: - # Pass the full request to the handler - result_config = await self.request_handler.on_create_task_push_notification_config( - request, context - ) - result = MessageToDict( - result_config, preserving_proto_field_name=False - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def on_get_task( - self, - request: GetTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_get_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - async def list_tasks( - self, - request: ListTasksRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' JSON-RPC method. - - Args: - request: The incoming `ListTasksRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_tasks( - request, context - ) - result = MessageToDict( - response, - preserving_proto_field_name=False, - always_print_fields_with_no_presence=True, - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def list_push_notification_configs( - self, - request: ListTaskPushNotificationConfigsRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'ListTaskPushNotificationConfigs' JSON-RPC method. - - Args: - request: The incoming `ListTaskPushNotificationConfigsRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_task_push_notification_configs( - request, context - ) - # response is a ListTaskPushNotificationConfigsResponse proto - result = MessageToDict(response, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def delete_push_notification_config( - self, - request: DeleteTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' JSON-RPC method. - - Args: - request: The incoming `DeleteTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - await self.request_handler.on_delete_task_push_notification_config( - request, context - ) - return _build_success_response(request_id, None) - except A2AError as e: - return _build_error_response(request_id, e) - - async def get_authenticated_extended_card( - self, - request: GetExtendedAgentCardRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'agent/authenticatedExtendedCard' JSON-RPC method. - - Args: - request: The incoming `GetExtendedAgentCardRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='The agent does not have an extended agent card configured' - ) - - base_card = self.extended_agent_card - if base_card is None: - base_card = self.agent_card - - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - result = MessageToDict(card_to_serve, preserving_proto_field_name=False) - return _build_success_response(request_id, result) diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py deleted file mode 100644 index f01c1371..00000000 --- a/src/a2a/server/request_handlers/rest_handler.py +++ /dev/null @@ -1,321 +0,0 @@ -import logging - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import ( - MessageToDict, - Parse, -) - - -if TYPE_CHECKING: - from starlette.requests import Request -else: - try: - from starlette.requests import Request - except ImportError: - Request = Any - - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - GetTaskPushNotificationConfigRequest, - SubscribeToTaskRequest, -) -from a2a.utils import proto_utils -from a2a.utils.errors import TaskNotFoundError -from a2a.utils.helpers import validate, validate_async_generator -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.SERVER) -class RESTHandler: - """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. - - This uses the protobuf definitions of the gRPC service as the source of truth. By - doing this, it ensures that this implementation and the gRPC transcoding - (via Envoy) are equivalent. This handler should be used if using the gRPC handler - with Envoy is not feasible for a given deployment solution. Use this handler - and a related application if you desire to ONLY server the RESTful API. - """ - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - ): - """Initializes the RESTHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - """ - self.agent_card = agent_card - self.request_handler = request_handler - - async def on_message_send( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the result (Task or Message) - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - task_or_message = await self.request_handler.on_message_send( - params, context - ) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return MessageToDict(response) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'message/stream' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - async def on_cancel_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the updated Task - """ - task_id = request.path_params['id'] - task = await self.request_handler.on_cancel_task( - CancelTaskRequest(id=task_id), context - ) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'SubscribeToTask' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - """ - task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( - SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) - - async def get_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = GetTaskPushNotificationConfigRequest( - task_id=task_id, - id=push_id, - ) - config = ( - await self.request_handler.on_get_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' REST method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config object. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator), A2AError if processing error is - found. - """ - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - # Set the parent to the task resource name format - params.task_id = request.path_params['id'] - config = ( - await self.request_handler.on_create_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - async def on_get_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/{id}' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `Task` object containing the Task. - """ - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - task = await self.request_handler.on_get_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - async def delete_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - An empty `dict` representing the empty response. - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - await self.request_handler.on_delete_task_push_notification_config( - params, context - ) - return {} - - async def list_tasks( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `Task` objects. - """ - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - - result = await self.request_handler.on_list_tasks(params, context) - return MessageToDict(result, always_print_fields_with_no_presence=True) - - async def list_push_notifications( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `TaskPushNotificationConfig` objects. - """ - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - - result = ( - await self.request_handler.on_list_task_push_notification_configs( - params, context - ) - ) - return MessageToDict(result) diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py new file mode 100644 index 00000000..e94428a4 --- /dev/null +++ b/src/a2a/server/routes/__init__.py @@ -0,0 +1,20 @@ +"""A2A Server Routes.""" + +from a2a.server.routes.agent_card_route import AgentCardRoutes +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + DefaultCallContextBuilder, + StarletteUserProxy, +) +from a2a.server.routes.jsonrpc_route import JsonRpcRoutes +from a2a.server.routes.rest_routes import RestRoutes + + +__all__ = [ + 'AgentCardRoutes', + 'CallContextBuilder', + 'DefaultCallContextBuilder', + 'JsonRpcRoutes', + 'RestRoutes', + 'StarletteUserProxy', +] diff --git a/src/a2a/server/routes/agent_card_route.py b/src/a2a/server/routes/agent_card_route.py new file mode 100644 index 00000000..d1e27d68 --- /dev/null +++ b/src/a2a/server/routes/agent_card_route.py @@ -0,0 +1,84 @@ +import logging + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True +else: + try: + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True + except ImportError: + Middleware = Any + Route = Any + Request = Any + Response = Any + JSONResponse = Any + + _package_starlette_installed = False + +from a2a.server.request_handlers.response_helpers import agent_card_to_dict +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.helpers import maybe_await +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + +logger = logging.getLogger(__name__) + + +class AgentCardRoutes: + """Provides the Starlette Route for the A2A protocol agent card endpoint.""" + + def __init__( + self, + agent_card: AgentCard, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + middleware: Sequence['Middleware'] | None = None, + ) -> None: + """Initializes the AgentCardRoute. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + card_url: The URL for the agent card endpoint. + middleware: An optional list of Starlette middleware to apply to the + agent card endpoint. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `AgentCardRoutes`.' + ' `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.card_modifier = card_modifier + + async def get_agent_card(request: Request) -> Response: + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + return JSONResponse(agent_card_to_dict(card_to_serve)) + + self.routes = [ + Route( + path=card_url, + endpoint=get_agent_card, + methods=['GET'], + middleware=middleware, + ) + ] diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/routes/jsonrpc_dispatcher.py similarity index 77% rename from src/a2a/server/apps/jsonrpc/jsonrpc_app.py rename to src/a2a/server/routes/jsonrpc_dispatcher.py index 0d79b10e..e50ca034 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -9,8 +9,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import ParseDict -from jsonrpc.jsonrpc2 import JSONRPC20Request +from google.protobuf.json_format import MessageToDict, ParseDict +from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser @@ -28,10 +28,8 @@ JSONRPCError, MethodNotFoundError, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, build_error_response, ) from a2a.types import A2ARequest @@ -45,15 +43,16 @@ ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageRequest, + SendMessageResponse, SubscribeToTaskRequest, + Task, TaskPushNotificationConfig, ) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) +from a2a.utils import proto_utils from a2a.utils.errors import ( A2AError, + ExtendedAgentCardNotConfiguredError, + TaskNotFoundError, UnsupportedOperationError, ) from a2a.utils.helpers import maybe_await @@ -167,7 +166,7 @@ def build(self, request: Request) -> ServerCallContext: ) -class JSONRPCApplication(ABC): +class JsonRpcDispatcher: """Base class for A2A JSONRPC applications. Handles incoming JSON-RPC requests, routes them to the appropriate @@ -195,7 +194,7 @@ class JSONRPCApplication(ABC): def __init__( # noqa: PLR0913 self, agent_card: AgentCard, - http_handler: RequestHandler, + request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -204,7 +203,6 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB enable_v0_3_compat: bool = False, ) -> None: """Initializes the JSONRPCApplication. @@ -238,21 +236,15 @@ def __init__( # noqa: PLR0913 self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self.handler = JSONRPCHandler( - agent_card=agent_card, - request_handler=http_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=extended_card_modifier, - ) + self.request_handler = request_handler self._context_builder = context_builder or DefaultCallContextBuilder() - self._max_content_length = max_content_length self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter: JSONRPC03Adapter | None = None if self.enable_v0_3_compat: self._v03_adapter = JSONRPC03Adapter( agent_card=agent_card, - http_handler=http_handler, + http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, card_modifier=card_modifier, @@ -317,7 +309,7 @@ def _allowed_content_length(self, request: Request) -> bool: return False return True - async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 + async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 """Handles incoming POST requests to the main A2A endpoint. Parses the request body as JSON, validates it against A2A request types, @@ -423,15 +415,26 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 call_context.state['method'] = method call_context.state['request_id'] = request_id + handler_result: ( + AsyncGenerator[dict[str, Any], None] | dict[str, Any] + ) # Route streaming requests by method name if method in ('SendStreamingMessage', 'SubscribeToTask'): - return await self._process_streaming_request( + handler_result = await self._process_streaming_request( request_id, specific_request, call_context ) + else: + try: + raw_result = await self._process_non_streaming_request( + request_id, specific_request, call_context + ) + handler_result = JSONRPC20Response( + result=raw_result, _id=request_id + ).data + except A2AError as e: + handler_result = build_error_response(request_id, e) - return await self._process_non_streaming_request( - request_id, specific_request, call_context - ) + return self._create_response(call_context, handler_result) except json.decoder.JSONDecodeError as e: traceback.print_exc() return self._generate_error_response( @@ -455,7 +458,7 @@ async def _process_streaming_request( request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: + ) -> AsyncGenerator[dict[str, Any], None]: """Processes streaming requests (SendStreamingMessage or SubscribeToTask). Args: @@ -464,31 +467,48 @@ async def _process_streaming_request( context: The ServerCallContext for the request. Returns: - An `EventSourceResponse` object to stream results to the client. + An async generator yielding JSON-RPC response dicts. """ - handler_result: Any = None - # Check for streaming message request (same type as SendMessage, but handled differently) - if isinstance( - request_obj, - SendMessageRequest, - ): - handler_result = self.handler.on_message_send_stream( + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' + ) + + stream: AsyncGenerator | None = None + if isinstance(request_obj, SendMessageRequest): + stream = self.request_handler.on_message_send_stream( request_obj, context ) elif isinstance(request_obj, SubscribeToTaskRequest): - handler_result = self.handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( request_obj, context ) - return self._create_response(context, handler_result) + if stream is None: + raise UnsupportedOperationError(message='Stream not supported') + + async def _wrap_stream( + st: AsyncGenerator, + ) -> AsyncGenerator[dict[str, Any], None]: + try: + async for event in st: + stream_response = proto_utils.to_stream_response(event) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + yield JSONRPC20Response(result=result, _id=request_id).data + except A2AError as e: + yield build_error_response(request_id, e) + + return _wrap_stream(stream) - async def _process_non_streaming_request( + async def _process_non_streaming_request( # noqa: PLR0911, PLR0912 self, request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: - """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). + ) -> Any: + """Processes non-streaming requests and returns the raw result data. Args: request_id: The ID of the request. @@ -496,71 +516,92 @@ async def _process_non_streaming_request( context: The ServerCallContext for the request. Returns: - A `JSONResponse` object containing the result or error. + The raw response object or dict, to be wrapped in a success response. """ - handler_result: Any = None match request_obj: case SendMessageRequest(): - handler_result = await self.handler.on_message_send( + task_or_message = await self.request_handler.on_message_send( request_obj, context ) + if isinstance(task_or_message, Task): + msg_response = SendMessageResponse(task=task_or_message) + else: + msg_response = SendMessageResponse(message=task_or_message) + return MessageToDict(msg_response) case CancelTaskRequest(): - handler_result = await self.handler.on_cancel_task( + task = await self.request_handler.on_cancel_task( request_obj, context ) + if not task: + raise TaskNotFoundError + return MessageToDict(task, preserving_proto_field_name=False) case GetTaskRequest(): - handler_result = await self.handler.on_get_task( - request_obj, context - ) + task = await self.request_handler.on_get_task(request_obj, context) + if not task: + raise TaskNotFoundError + return MessageToDict(task, preserving_proto_field_name=False) case ListTasksRequest(): - handler_result = await self.handler.list_tasks( + tasks_response = await self.request_handler.on_list_tasks( request_obj, context ) + return MessageToDict( + tasks_response, + preserving_proto_field_name=False, + always_print_fields_with_no_presence=True, + ) case TaskPushNotificationConfig(): - handler_result = ( - await self.handler.set_push_notification_config( - request_obj, - context, + if not self.agent_card.capabilities.push_notifications: + raise UnsupportedOperationError( + message='Push notifications are not supported by the agent' ) + result_config = await self.request_handler.on_create_task_push_notification_config( + request_obj, context + ) + return MessageToDict( + result_config, preserving_proto_field_name=False ) case GetTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.get_push_notification_config( - request_obj, - context, - ) + config = await self.request_handler.on_get_task_push_notification_config( + request_obj, context ) + return MessageToDict(config, preserving_proto_field_name=False) case ListTaskPushNotificationConfigsRequest(): - handler_result = ( - await self.handler.list_push_notification_configs( - request_obj, - context, - ) + list_push_response = await self.request_handler.on_list_task_push_notification_configs( + request_obj, context + ) + return MessageToDict( + list_push_response, preserving_proto_field_name=False ) case DeleteTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.delete_push_notification_config( - request_obj, - context, - ) + await self.request_handler.on_delete_task_push_notification_config( + request_obj, context ) + return None case GetExtendedAgentCardRequest(): - handler_result = ( - await self.handler.get_authenticated_extended_card( - request_obj, - context, + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='The agent does not have an extended agent card configured' + ) + base_card = self.extended_agent_card or self.agent_card + card_to_serve = base_card + if self.extended_card_modifier and context: + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(base_card) ) + return MessageToDict( + card_to_serve, preserving_proto_field_name=False ) case _: logger.error( 'Unhandled validated request type: %s', type(request_obj) ) - error = UnsupportedOperationError( + raise UnsupportedOperationError( message=f'Request type {type(request_obj).__name__} is unknown.' ) - return self._generate_error_response(request_id, error) - - return self._create_response(context, handler_result) def _create_response( self, @@ -598,43 +639,3 @@ async def event_generator( # handler_result is a dict (JSON-RPC response) return JSONResponse(handler_result, headers=headers) - - async def _handle_get_agent_card(self, request: Request) -> JSONResponse: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return JSONResponse( - agent_card_to_dict( - card_to_serve, - ) - ) - - @abstractmethod - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI | Starlette: - """Builds and returns the JSONRPC application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured JSONRPC application instance. - """ - raise NotImplementedError( - 'Subclasses must implement the build method to create the application instance.' - ) diff --git a/src/a2a/server/routes/jsonrpc_route.py b/src/a2a/server/routes/jsonrpc_route.py new file mode 100644 index 00000000..73bca828 --- /dev/null +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -0,0 +1,107 @@ +import logging + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from starlette.middleware import Middleware + from starlette.routing import Route, Router + + _package_starlette_installed = True +else: + try: + from starlette.middleware import Middleware + from starlette.routing import Route, Router + + _package_starlette_installed = True + except ImportError: + Middleware = Any + Route = Any + Router = Any + + _package_starlette_installed = False + + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + JsonRpcDispatcher, +) +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import DEFAULT_RPC_URL + + +logger = logging.getLogger(__name__) + + +class JsonRpcRoutes: + """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. + + Handles incoming JSON-RPC requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + enable_v0_3_compat: bool = False, + rpc_url: str = DEFAULT_RPC_URL, + middleware: Sequence[Middleware] | None = None, + ) -> None: + """Initializes the JsonRpcRoute. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + middleware: An optional list of Starlette middleware to apply to the routes. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `JsonRpcRoutes`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + + self.dispatcher = JsonRpcDispatcher( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + enable_v0_3_compat=enable_v0_3_compat, + ) + + self.routes = [ + Route( + path=rpc_url, + endpoint=self.dispatcher.handle_requests, + methods=['POST'], + middleware=middleware, + ) + ] diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py new file mode 100644 index 00000000..0ab24fcb --- /dev/null +++ b/src/a2a/server/routes/rest_routes.py @@ -0,0 +1,382 @@ +import logging + +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + +from a2a.compat.v0_3.rest_adapter import ( + REST03Adapter as V03RESTAdapter, +) + + +if TYPE_CHECKING: + from sse_starlette.sse import EventSourceResponse + from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True +else: + try: + from sse_starlette.sse import EventSourceResponse + from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True + except ImportError: + Middleware = Any + Route = Any + EventSourceResponse = Any + Request = Any + JSONResponse = Any + Response = Any + StarletteHTTPException = Any + + _package_starlette_installed = False + +import json + +from google.protobuf.json_format import MessageToDict, Parse + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + DefaultCallContextBuilder, +) +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils import proto_utils +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InvalidRequestError, + TaskNotFoundError, + UnsupportedOperationError, +) +from a2a.utils.helpers import maybe_await +from a2a.utils.constants import DEFAULT_RPC_URL + + +logger = logging.getLogger(__name__) + + +class RestRoutes: + """Provides the Starlette Routes for the A2A protocol REST endpoints. + + Handles incoming JSON-REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + enable_v0_3_compat: bool = False, + rpc_url: str = DEFAULT_RPC_URL, + middleware: Sequence['Middleware'] | None = None, + ) -> None: + """Initializes the RestRoutes. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + middleware: An optional list of Starlette middleware to apply to the routes. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `RestRoutes`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.request_handler = request_handler + self.extended_agent_card = extended_agent_card + self._context_builder = context_builder or DefaultCallContextBuilder() + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier + self.enable_v0_3_compat = enable_v0_3_compat + + self._v03_adapter = None + if enable_v0_3_compat: + self._v03_adapter = V03RESTAdapter( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + + self._setup_routes(rpc_url) + + def _build_call_context(self, request: Request) -> ServerCallContext: + call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context + + def _setup_routes( + self, + rpc_url: str, + ) -> None: + """Sets up the Starlette routes.""" + self.routes = [] + if self.enable_v0_3_compat and self._v03_adapter: + for route, callback in self._v03_adapter.routes().items(): + self.routes.append( + Route( + path=f'{rpc_url}{route[0]}', + endpoint=callback, + methods=[route[1]], + ) + ) + + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { + ('/message:send', 'POST'): self._message_send, + ('/message:stream', 'POST'): self._message_stream, + ('/tasks/{id}:cancel', 'POST'): self._cancel_task, + ('/tasks/{id}:subscribe', 'GET'): self._subscribe_task, + ('/tasks/{id}:subscribe', 'POST'): self._subscribe_task, + ('/tasks/{id}', 'GET'): self._get_task, + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): self._get_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'DELETE', + ): self._delete_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'POST', + ): self._set_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'GET', + ): self._list_push_notifications, + ('/tasks', 'GET'): self._list_tasks, + ('/extendedAgentCard', 'GET'): self._get_extended_agent_card, + } + + self.routes.extend( + [ + Route( + path=f'{rpc_url}{p}', + endpoint=handler, + methods=[method], + ) + for (path, method), handler in base_routes.items() + for p in (path, f'/{{tenant}}{path}') + ] + ) + + @rest_error_handler + async def _message_send(self, request: Request) -> Response: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + task_or_message = await self.request_handler.on_message_send( + params, context + ) + if isinstance(task_or_message, a2a_pb2.Task): + response = a2a_pb2.SendMessageResponse(task=task_or_message) + else: + response = a2a_pb2.SendMessageResponse(message=task_or_message) + return JSONResponse(MessageToDict(response)) + + @rest_stream_error_handler + async def _message_stream(self, request: Request) -> EventSourceResponse: + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' + ) + + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + + async def event_generator() -> AsyncIterator[str]: + async for event in self.request_handler.on_message_send_stream( + params, context + ): + yield json.dumps( + MessageToDict(proto_utils.to_stream_response(event)) + ) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + async def _cancel_task(self, request: Request) -> Response: + task_id = request.path_params['id'] + params = a2a_pb2.CancelTaskRequest(id=task_id) + context = self._build_call_context(request) + task = await self.request_handler.on_cancel_task(params, context) + if not task: + raise TaskNotFoundError + return JSONResponse(MessageToDict(task)) + + @rest_stream_error_handler + async def _subscribe_task(self, request: Request) -> EventSourceResponse: + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + task_id = request.path_params['id'] + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' + ) + params = a2a_pb2.SubscribeToTaskRequest(id=task_id) + context = self._build_call_context(request) + + async def event_generator() -> AsyncIterator[str]: + async for event in self.request_handler.on_subscribe_to_task( + params, context + ): + yield json.dumps( + MessageToDict(proto_utils.to_stream_response(event)) + ) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + async def _get_task(self, request: Request) -> Response: + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + context = self._build_call_context(request) + task = await self.request_handler.on_get_task(params, context) + if not task: + raise TaskNotFoundError + return JSONResponse(MessageToDict(task)) + + @rest_error_handler + async def _get_push_notification(self, request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.GetTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + return JSONResponse(MessageToDict(config)) + + @rest_error_handler + async def _delete_push_notification(self, request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + await self.request_handler.on_delete_task_push_notification_config( + params, context + ) + return JSONResponse({}) + + @rest_error_handler + async def _set_push_notification(self, request: Request) -> Response: + if not self.agent_card.capabilities.push_notifications: + raise UnsupportedOperationError( + message='Push notifications are not supported by the agent' + ) + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + config = ( + await self.request_handler.on_create_task_push_notification_config( + params, context + ) + ) + return JSONResponse(MessageToDict(config)) + + @rest_error_handler + async def _list_push_notifications(self, request: Request) -> Response: + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + result = ( + await self.request_handler.on_list_task_push_notification_configs( + params, context + ) + ) + return JSONResponse(MessageToDict(result)) + + @rest_error_handler + async def _list_tasks(self, request: Request) -> Response: + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + context = self._build_call_context(request) + result = await self.request_handler.on_list_tasks(params, context) + return JSONResponse( + MessageToDict(result, always_print_fields_with_no_presence=True) + ) + + @rest_error_handler + async def _get_extended_agent_card(self, request: Request) -> Response: + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = self.extended_agent_card or self.agent_card + + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) + + return JSONResponse( + MessageToDict(card_to_serve, preserving_proto_field_name=True) + ) diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 7196b828..f2cc03b3 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -16,15 +16,12 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import ( - A2ARESTFastAPIApplication, - A2AStarletteApplication, -) from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) from a2a.server.request_handlers.grpc_handler import GrpcHandler +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ( @@ -44,6 +41,7 @@ JSONRPC_URL = '/a2a/jsonrpc' REST_URL = '/a2a/rest' +AGENT_CARD_URL = '/.well-known/agent-card.json' logging.basicConfig(level=logging.INFO) logger = logging.getLogger('SUTAgent') @@ -196,22 +194,30 @@ def serve(task_store: TaskStore) -> None: task_store=task_store, ) - main_app = Starlette() - + # Agent Card + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, + card_url=AGENT_CARD_URL, + ) # JSONRPC - jsonrpc_server = A2AStarletteApplication( + jsonrpc_routes = JsonRpcRoutes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, + rpc_url=JSONRPC_URL, ) - jsonrpc_server.add_routes_to_app(main_app, rpc_url=JSONRPC_URL) - # REST - rest_server = A2ARESTFastAPIApplication( + rest_routes = RestRoutes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, + rpc_url=REST_URL, ) - rest_app = rest_server.build(rpc_url=REST_URL) - main_app.mount('', rest_app) + + routes = [ + *agent_card_routes.routes, + *jsonrpc_routes.routes, + *rest_routes.routes, + ] + main_app = Starlette(routes=routes) config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 4f09bb23..4b344c67 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -6,7 +6,8 @@ import pytest from starlette.testclient import TestClient -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from starlette.applications import Starlette +from a2a.server.routes import JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,16 +51,18 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False mock_agent_card.capabilities.push_notifications = True mock_agent_card.capabilities.extended_agent_card = True - return A2AStarletteApplication( + router = JsonRpcRoutes( agent_card=mock_agent_card, - http_handler=mock_handler, + request_handler=mock_handler, enable_v0_3_compat=True, + rpc_url='/', ) + return Starlette(routes=router.routes) @pytest.fixture def client(test_app): - return TestClient(test_app.build()) + return TestClient(test_app) def test_send_message_v03_compat( diff --git a/tests/compat/v0_3/test_rest_fastapi_app_compat.py b/tests/compat/v0_3/test_rest_fastapi_app_compat.py index 8625b7e0..2b2ca22d 100644 --- a/tests/compat/v0_3/test_rest_fastapi_app_compat.py +++ b/tests/compat/v0_3/test_rest_fastapi_app_compat.py @@ -9,7 +9,8 @@ from google.protobuf import json_format from httpx import ASGITransport, AsyncClient -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import RestRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,17 +51,19 @@ async def request_handler() -> RequestHandler: async def app( agent_card: AgentCard, request_handler: RequestHandler, -) -> FastAPI: - """Builds the FastAPI application for testing.""" - return A2ARESTFastAPIApplication( - agent_card, - request_handler, +) -> Starlette: + """Builds the Starlette application for testing.""" + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=request_handler, enable_v0_3_compat=True, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') + rpc_url='', + ) + return Starlette(routes=rest_routes.routes) @pytest.fixture -async def client(app: FastAPI) -> AsyncClient: +async def client(app: Starlette) -> AsyncClient: return AsyncClient( transport=ASGITransport(app=app), base_url='http://testapp' ) diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index ca1a234b..ec2456d6 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -3,7 +3,8 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import RestRoutes from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler @@ -139,17 +140,20 @@ def create_agent_app( ) -> FastAPI: """Creates a new HTTP+REST FastAPI application for the test agent.""" push_config_store = InMemoryPushNotificationConfigStore() - app = A2ARESTFastAPIApplication( - agent_card=test_agent_card(url), - http_handler=DefaultRequestHandler( - agent_executor=TestAgentExecutor(), - task_store=InMemoryTaskStore(), - push_config_store=push_config_store, - push_sender=BasePushNotificationSender( - httpx_client=notification_client, - config_store=push_config_store, - context=ServerCallContext(), - ), + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + context=ServerCallContext(), ), ) - return app.build() + rest_routes = RestRoutes( + agent_card=test_agent_card(url), + request_handler=handler, + extended_agent_card=test_agent_card(url), + rpc_url='', + ) + return FastAPI(routes=rest_routes.routes) diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index e079fdf2..56c21bf3 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -5,7 +5,8 @@ import grpc from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import JsonRpcRoutes, RestRoutes, AgentCardRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -166,17 +167,29 @@ async def main_async(http_port: int, grpc_port: int): app = FastAPI() app.add_middleware(CustomLoggingMiddleware) - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build() + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + enable_v0_3_compat=True, + ) + jsonrpc_card = AgentCardRoutes( + agent_card=agent_card, + card_url='/.well-known/agent-card.json', + ) + jsonrpc_app = Starlette(routes=jsonrpc_routes.routes + jsonrpc_card.routes) app.mount('/jsonrpc', jsonrpc_app) - app.mount( - '/rest', - A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build(), + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=handler, + enable_v0_3_compat=True, + ) + rest_card = AgentCardRoutes( + agent_card=agent_card, + card_url='/.well-known/agent-card.json', ) + rest_app = Starlette(routes=rest_routes.routes + rest_card.routes) + app.mount('/rest', rest_app) # Start gRPC Server server = grpc.aio.server() diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 0af06ad7..710399e4 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -4,7 +4,8 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -69,16 +70,27 @@ async def test_agent_card_integration() -> None: app = FastAPI() # Mount JSONRPC application - # In JSONRPCApplication, the default agent_card_url is AGENT_CARD_WELL_KNOWN_PATH - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + jsonrpc_routes = [ + *AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ).routes, + *JsonRpcRoutes( + agent_card=agent_card, request_handler=handler, rpc_url='/' + ).routes, + ] + jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) # Mount REST application - rest_app = A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + rest_routes = [ + *AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ).routes, + *RestRoutes( + agent_card=agent_card, request_handler=handler, rpc_url='' + ).routes, + ] + rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) expected_content = { diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e239d780..2d2d3ac9 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -23,7 +23,8 @@ with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -220,10 +221,14 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2AFastAPIApplication( - agent_card, mock_request_handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -239,10 +244,13 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2ARESTFastAPIApplication( - agent_card, mock_request_handler, extended_agent_card=agent_card + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='', ) - app = app_builder.build() + app = Starlette(routes=rest_routes.routes) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -619,12 +627,16 @@ async def test_json_transport_get_signed_base_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, - card_modifier=signer, # Sign the base card + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer ) - app = app_builder.build() + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -639,7 +651,8 @@ async def test_json_transport_get_signed_base_card( # Verification happens here result = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # Create transport with the verified card @@ -684,15 +697,15 @@ async def test_client_get_signed_extended_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = JsonRpcTransport( @@ -753,16 +766,17 @@ async def test_client_get_signed_base_and_extended_cards( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer + ) + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - card_modifier=signer, # Sign the base card - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -777,7 +791,8 @@ async def test_client_get_signed_base_and_extended_cards( # 1. Fetch base card base_card = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # 2. Create transport with base card diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf..a2735e9e 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -10,7 +10,8 @@ from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -171,8 +172,13 @@ def base_e2e_setup(): @pytest.fixture def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2ARESTFastAPIApplication(agent_card, handler) - app = app_builder.build() + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='', + ) + app = Starlette(routes=rest_routes.routes) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) @@ -192,10 +198,14 @@ def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: @pytest.fixture def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2AFastAPIApplication( - agent_card, handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 903b90a2..10f92461 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -19,7 +19,8 @@ from a2a.client import ClientConfig, ClientFactory from a2a.utils.constants import TransportProtocol -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.context import ServerCallContext @@ -197,10 +198,13 @@ def jsonrpc_agent_card(self): @pytest.fixture def server_app(self, jsonrpc_agent_card, mock_handler): - app = A2AStarletteApplication( + jsonrpc_routes = JsonRpcRoutes( agent_card=jsonrpc_agent_card, - http_handler=mock_handler, - ).build(rpc_url='/jsonrpc') + request_handler=mock_handler, + extended_agent_card=jsonrpc_agent_card, + rpc_url='/jsonrpc', + ) + app = Starlette(routes=jsonrpc_routes.routes) return app @pytest.mark.asyncio diff --git a/tests/server/apps/jsonrpc/test_fastapi_app.py b/tests/server/apps/jsonrpc/test_fastapi_app.py deleted file mode 100644 index 11831df5..00000000 --- a/tests/server/apps/jsonrpc/test_fastapi_app.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import fastapi_app -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AFastAPIApplication Tests --- - - -class TestA2AFastAPIApplicationOptionalDeps: - # Running tests in this class requires the optional dependency fastapi to be - # present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_fastapi_is_present(self): - try: - import fastapi as _fastapi # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' the optional dependency fastapi to be present in the test' - ' environment. Run `uv sync --dev ...` before running the test' - ' suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_fastapi_not_installed(self): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - def test_create_a2a_fastapi_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AFastAPIApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2AFastAPIApplication instance should not raise ImportError' - ) - - def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_fastapi_not_installed: Any, - ): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2AFastAPIApplication`' - ), - ): - _app = A2AFastAPIApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py deleted file mode 100644 index 825f8e2a..00000000 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Tests for JSON-RPC serialization behavior.""" - -from unittest import mock - -import pytest -from starlette.testclient import TestClient - -from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication -from a2a.server.jsonrpc_models import JSONParseError -from a2a.types import ( - InvalidRequestError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentInterface, - AgentCard, - AgentSkill, - APIKeySecurityScheme, - Message, - Part, - Role, - SecurityRequirement, - SecurityScheme, -) - - -@pytest.fixture -def minimal_agent_card(): - """Provides a minimal AgentCard for testing.""" - return AgentCard( - name='TestAgent', - description='A test agent.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/agent', protocol_binding='HTTP+JSON' - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - skills=[ - AgentSkill( - id='skill-1', - name='Test Skill', - description='A test skill', - tags=['test'], - ) - ], - ) - - -@pytest.fixture -def agent_card_with_api_key(): - """Provides an AgentCard with an APIKeySecurityScheme for testing serialization.""" - api_key_scheme = APIKeySecurityScheme( - name='X-API-KEY', - location='header', - ) - - security_scheme = SecurityScheme(api_key_security_scheme=api_key_scheme) - - card = AgentCard( - name='APIKeyAgent', - description='An agent that uses API Key auth.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/apikey-agent', - protocol_binding='HTTP+JSON', - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - ) - # Add security scheme to the map - card.security_schemes['api_key_auth'].CopyFrom(security_scheme) - - return card - - -def test_starlette_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AStarletteApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - assert ( - response_data['supportedInterfaces'][0]['url'] - == 'http://example.com/agent' - ) - assert response_data['version'] == '1.0.0' - - -def test_starlette_agent_card_with_api_key_scheme( - agent_card_with_api_key: AgentCard, -): - """Tests that the A2AStarletteApplication endpoint correctly serializes API key schemes.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(agent_card_with_api_key, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - # Check security schemes are serialized - assert 'securitySchemes' in response_data - assert 'api_key_auth' in response_data['securitySchemes'] - - -def test_fastapi_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AFastAPIApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - - -def test_handle_invalid_json(minimal_agent_card: AgentCard): - """Test handling of malformed JSON.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.post( - '/', - content='{ "jsonrpc": "2.0", "method": "test", "id": 1, "params": { "key": "value" }', - ) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == JSONParseError().code - - -def test_handle_oversized_payload(minimal_agent_card: AgentCard): - """Test handling of oversized JSON payloads.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == -32600 - - -@pytest.mark.parametrize( - 'max_content_length', - [ - None, - 11 * 1024 * 1024, - 30 * 1024 * 1024, - ], -) -def test_handle_oversized_payload_with_max_content_length( - minimal_agent_card: AgentCard, - max_content_length: int | None, -): - """Test handling of JSON payloads with sizes within custom max_content_length.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication( - minimal_agent_card, handler, max_content_length=max_content_length - ) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - # When max_content_length is set, requests up to that size should not be - # rejected due to payload size. The request might fail for other reasons, - # but it shouldn't be an InvalidRequestError related to the content length. - if max_content_length is not None: - assert data['error']['code'] != -32600 - - -def test_handle_unicode_characters(minimal_agent_card: AgentCard): - """Test handling of unicode characters in JSON payload.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - unicode_text = 'こんにちは世界' # "Hello world" in Japanese - - # Mock a handler response - handler.on_message_send.return_value = Message( - role=Role.ROLE_AGENT, - parts=[Part(text=f'Received: {unicode_text}')], - message_id='response-unicode', - ) - - unicode_payload = { - 'jsonrpc': '2.0', - 'method': 'SendMessage', - 'id': 'unicode_test', - 'params': { - 'message': { - 'role': 'ROLE_USER', - 'parts': [{'text': unicode_text}], - 'messageId': 'msg-unicode', - } - }, - } - - response = client.post('/', json=unicode_payload) - - # We are testing that the server can correctly deserialize the unicode payload - assert response.status_code == 200 - data = response.json() - # Check that we got a result (handler was called) - if 'result' in data: - # Response should contain the unicode text - result = data['result'] - if 'message' in result: - assert ( - result['message']['parts'][0]['text'] - == f'Received: {unicode_text}' - ) - elif 'parts' in result: - assert result['parts'][0]['text'] == f'Received: {unicode_text}' - - -def test_fastapi_sub_application(minimal_agent_card: AgentCard): - """ - Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application. - """ - from fastapi import FastAPI - - handler = mock.AsyncMock() - sub_app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - app_instance = FastAPI() - app_instance.mount('/a2a', sub_app_instance.build()) - client = TestClient(app_instance) - - response = client.get('/a2a/openapi.json') - assert response.status_code == 200 - response_data = response.json() - - # The generated a2a.json (OpenAPI 2.0 / Swagger) does not typically include a 'servers' block - # unless specifically configured or converted to OpenAPI 3.0. - # FastAPI usually generates OpenAPI 3.0 schemas which have 'servers'. - # When we inject the raw Swagger 2.0 schema, it won't have 'servers'. - # We check if it is indeed the injected schema by checking for 'swagger': '2.0' - # or by checking for 'basePath' if we want to test path correctness. - - if response_data.get('swagger') == '2.0': - # It's the injected Swagger 2.0 schema - pass - else: - # It's an auto-generated OpenAPI 3.0+ schema (fallback or otherwise) - assert 'servers' in response_data - assert response_data['servers'] == [{'url': '/a2a'}] diff --git a/tests/server/apps/jsonrpc/test_starlette_app.py b/tests/server/apps/jsonrpc/test_starlette_app.py deleted file mode 100644 index fa686871..00000000 --- a/tests/server/apps/jsonrpc/test_starlette_app.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import starlette_app -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AStarletteApplication Tests --- - - -class TestA2AStarletteApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401 - import starlette as _starlette # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = ( - starlette_app._package_starlette_installed - ) - starlette_app._package_starlette_installed = False - yield - starlette_app._package_starlette_installed = ( - pkg_starlette_installed_flag - ) - - def test_create_a2a_starlette_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AStarletteApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' A2AStarletteApplication instance should not raise ImportError' - ) - - def test_create_a2a_starlette_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_starlette_not_installed: Any, - ): - with pytest.raises( - ImportError, - match='Packages `starlette` and `sse-starlette` are required', - ): - _app = A2AStarletteApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/rest/__init__.py b/tests/server/apps/rest/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py deleted file mode 100644 index cbdf6b5e..00000000 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ /dev/null @@ -1,1453 +0,0 @@ -import asyncio -import unittest -import unittest.async_case - -from collections.abc import AsyncGenerator -from typing import Any, NoReturn -from unittest.mock import ANY, AsyncMock, MagicMock, call, patch - -import httpx -import pytest - -from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.agent_execution.request_context_builder import ( - RequestContextBuilder, -) -from a2a.server.context import ServerCallContext -from a2a.server.events import QueueManager -from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import DefaultRequestHandler, JSONRPCHandler -from a2a.server.tasks import ( - BasePushNotificationSender, - InMemoryPushNotificationConfigStore, - PushNotificationConfigStore, - PushNotificationSender, - TaskStore, -) -from a2a.types import ( - InternalError, - TaskNotFoundError, - UnsupportedOperationError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentCard, - AgentInterface, - Artifact, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTaskPushNotificationConfigsResponse, - ListTasksResponse, - Message, - Part, - TaskPushNotificationConfig, - Role, - SendMessageConfiguration, - SendMessageRequest, - TaskPushNotificationConfig, - SubscribeToTaskRequest, - Task, - TaskArtifactUpdateEvent, - TaskPushNotificationConfig, - TaskState, - TaskStatus, - TaskStatusUpdateEvent, -) - - -# Helper function to create a minimal Task proto -def create_task( - task_id: str = 'task_123', context_id: str = 'session-xyz' -) -> Task: - return Task( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - - -# Helper function to create a Message proto -def create_message( - message_id: str = '111', - role: Role = Role.ROLE_AGENT, - text: str = 'test message', - task_id: str | None = None, - context_id: str | None = None, -) -> Message: - msg = Message( - message_id=message_id, - role=role, - parts=[Part(text=text)], - ) - if task_id: - msg.task_id = task_id - if context_id: - msg.context_id = context_id - return msg - - -# Helper functions for checking JSON-RPC response structure -def is_success_response(response: dict[str, Any]) -> bool: - """Check if response is a successful JSON-RPC response.""" - return 'result' in response and 'error' not in response - - -def is_error_response(response: dict[str, Any]) -> bool: - """Check if response is an error JSON-RPC response.""" - return 'error' in response - - -def get_error_code(response: dict[str, Any]) -> int | None: - """Get error code from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('code') - return None - - -def get_error_message(response: dict[str, Any]) -> str | None: - """Get error message from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('message') - return None - - -class TestJSONRPCtHandler(unittest.async_case.IsolatedAsyncioTestCase): - @pytest.fixture(autouse=True) - def init_fixtures(self) -> None: - self.mock_agent_card = MagicMock( - spec=AgentCard, - ) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.extended_agent_card = True - self.mock_agent_card.capabilities.streaming = True - self.mock_agent_card.capabilities.push_notifications = True - - # Mock supported_interfaces list - interface = MagicMock(spec=AgentInterface) - interface.url = 'http://agent.example.com/api' - self.mock_agent_card.supported_interfaces = [interface] - - async def test_on_get_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - request = GetTaskRequest(id=f'{task_id}') - response = await handler.on_get_task(request, call_context) - # Response is now a dict with 'result' key for success - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - assert response['result']['id'] == task_id - mock_task_store.get.assert_called_once_with(f'{task_id}', ANY) - - async def test_on_get_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = GetTaskRequest(id='nonexistent_id') - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - response = await handler.on_get_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - - async def test_on_list_tasks_success(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task1 = create_task() - task2 = create_task() - task2.id = 'task_456' - mock_result = ListTasksResponse( - next_page_token='123', - tasks=[task1, task2], - ) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest( - page_size=10, - page_token='token', - ) - call_context = ServerCallContext(state={'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 2) - self.assertEqual(response['result']['nextPageToken'], '123') - - async def test_on_list_tasks_error(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - request_handler.on_list_tasks.side_effect = InternalError( - message='DB down' - ) - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = ServerCallContext(state={'request_id': '2'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['message'], 'DB down') - - async def test_on_list_tasks_empty(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_result = ListTasksResponse(page_size=10) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = ServerCallContext(state={'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 0) - self.assertIn('nextPageToken', response['result']) - self.assertEqual(response['result']['nextPageToken'], '') - self.assertIn('pageSize', response['result']) - self.assertEqual(response['result']['pageSize'], 10) - self.assertIn('totalSize', response['result']) - self.assertEqual(response['result']['totalSize'], 0) - - async def test_on_cancel_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - - async def streaming_coro(): - mock_task.status.state = TaskState.TASK_STATE_CANCELED - yield mock_task - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - assert response['result']['id'] == task_id # type: ignore - assert ( - response['result']['status']['state'] == 'TASK_STATE_CANCELED' - ) # type: ignore - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_supported(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = CancelTaskRequest(id='nonexistent_id') - call_context = ServerCallContext(state={'request_id': '1'}) - response = await handler.on_cancel_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - mock_task_store.get.assert_called_once_with('nonexistent_id', ANY) - mock_agent_executor.cancel.assert_not_called() - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_new_message_with_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, context_id=mock_task.context_id - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Allow the background event loop to start the execution_task - import asyncio - - await asyncio.sleep(0) - - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - - async def test_on_message_stream_new_message_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - assert mock_task.history is not None and len(mock_task.history) == 1 - - async def test_set_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_push_notification_store = AsyncMock( - spec=PushNotificationConfigStore - ) - - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=mock_push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - context = ServerCallContext() - response = await handler.set_push_notification_config(request, context) - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, request, context - ) - - async def test_get_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - push_notification_store = InMemoryPushNotificationConfigStore() - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - push_config = TaskPushNotificationConfig( - id='default', url='http://example.com' - ) - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - id='default', - ) - await handler.set_push_notification_config(request, ServerCallContext()) - - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - get_response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - self.assertIsInstance(get_response, dict) - self.assertTrue(is_success_response(get_response)) - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_send_push_notification_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) - push_notification_store = InMemoryPushNotificationConfigStore() - push_notification_sender = BasePushNotificationSender( - mock_httpx_client, push_notification_store, ServerCallContext() - ) - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - push_sender=push_notification_sender, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - mock_httpx_client.post.return_value = httpx.Response(200) - request = SendMessageRequest( - message=create_message(), - configuration=SendMessageConfiguration( - accepted_output_modes=['text'], - task_push_notification_config=TaskPushNotificationConfig( - url='http://example.com' - ), - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - - async def test_on_resubscribe_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store, mock_queue_manager - ) - self.mock_agent_card = MagicMock(spec=AgentCard) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.streaming = True - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_queue_manager.tap.return_value = EventQueue() - request = SubscribeToTaskRequest(id=f'{mock_task.id}') - response = handler.on_subscribe_to_task( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert ( - len(collected_events) == len(events) + 1 - ) # First event is task itself - assert mock_task.history is not None and len(mock_task.history) == 0 - - async def test_on_subscribe_no_existing_task_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = SubscribeToTaskRequest(id='nonexistent_id') - response = handler.on_subscribe_to_task(request, ServerCallContext()) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - assert collected_events[0]['error']['code'] == -32001 - - async def test_streaming_not_supported_error( - self, - ) -> None: - """Test that on_message_send_stream raises an error when streaming not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with streaming capability disabled - self.mock_agent_card.capabilities = AgentCapabilities(streaming=False) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = SendMessageRequest( - message=create_message(), - ) - - # Should raise UnsupportedOperationError about streaming not supported - with self.assertRaises(UnsupportedOperationError) as context: - async for _ in handler.on_message_send_stream( - request, ServerCallContext() - ): - pass - - self.assertEqual( - str(context.exception.message), - 'Streaming is not supported by the agent', - ) - - async def test_push_notifications_not_supported_error(self) -> None: - """Test that set_push_notification raises an error when push notifications not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with push notifications capability disabled - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=False, streaming=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = TaskPushNotificationConfig( - task_id='task_123', - url='http://example.com', - ) - - # Should raise UnsupportedOperationError about push notifications not supported - with self.assertRaises(UnsupportedOperationError) as context: - await handler.set_push_notification_config( - request, ServerCallContext() - ) - - self.assertEqual( - str(context.exception.message), - 'Push notifications are not supported by the agent', - ) - - async def test_on_get_push_notification_no_push_config_store(self) -> None: - """Test get_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_set_push_notification_no_push_config_store(self) -> None: - """Test set_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - response = await handler.set_push_notification_config( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_internal_error(self) -> None: - """Test on_message_send with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs) -> NoReturn: - raise InternalError(message='Internal Error') - - # Patch the method to raise an error - with patch.object( - request_handler, 'on_message_send', side_effect=raise_server_error - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_internal_error(self) -> None: - """Test on_message_send_stream with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs): - raise InternalError(message='Internal Error') - yield # Need this to make it an async generator - - # Patch the method to raise an error - with patch.object( - request_handler, - 'on_message_send_stream', - return_value=raise_server_error(), - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - - # Get the single error response - responses = [] - async for response in handler.on_message_send_stream( - request, ServerCallContext() - ): - responses.append(response) - - # Assert - self.assertEqual(len(responses), 1) - self.assertIsInstance(responses[0], dict) - self.assertTrue(is_error_response(responses[0])) - self.assertEqual(responses[0]['error']['code'], -32603) - - async def test_default_request_handler_with_custom_components(self) -> None: - """Test DefaultRequestHandler initialization with custom components.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - mock_push_config_store = AsyncMock(spec=PushNotificationConfigStore) - mock_push_sender = AsyncMock(spec=PushNotificationSender) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) - - # Act - handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - queue_manager=mock_queue_manager, - push_config_store=mock_push_config_store, - push_sender=mock_push_sender, - request_context_builder=mock_request_context_builder, - ) - - # Assert - self.assertEqual(handler.agent_executor, mock_agent_executor) - self.assertEqual(handler.task_store, mock_task_store) - self.assertEqual(handler._queue_manager, mock_queue_manager) - self.assertEqual(handler._push_config_store, mock_push_config_store) - self.assertEqual(handler._push_sender, mock_push_sender) - self.assertEqual( - handler._request_context_builder, mock_request_context_builder - ) - - async def test_on_message_send_error_handling(self) -> None: - """Test error handling in on_message_send when consuming raises A2AError.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Let task exist - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Set up consume_and_break_on_interrupt to raise UnsupportedOperationError - async def consume_raises_error(*args, **kwargs) -> NoReturn: - raise UnsupportedOperationError() - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - side_effect=consume_raises_error, - ): - # Act - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - # Mock returns task with different ID than what will be generated - mock_task_store.get.return_value = None # No existing task - mock_agent_executor.execute.return_value = None - - # Task returned has task_id='task_123' but request_context will have generated UUID - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message(), # No task_id, so UUID is generated - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # The task ID mismatch should cause an error - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - events: list[Any] = [create_task()] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message(), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - self.assertEqual(collected_events[0]['error']['code'], -32603) - - async def test_on_get_push_notification(self) -> None: - """Test get_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='config1', url='http://example.com' - ) - request_handler.on_get_task_push_notification_config.return_value = ( - task_push_config - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='config1', - ) - response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - self.assertEqual( - response['result']['id'], - 'config1', - ) - self.assertEqual( - response['result']['taskId'], - mock_task.id, - ) - - async def test_on_list_push_notification(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='default', url='http://example.com' - ) - request_handler.on_list_task_push_notification_configs.return_value = ( - ListTaskPushNotificationConfigsResponse(configs=[task_push_config]) - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result contains the response dict with configs field - self.assertIsInstance(response['result'], dict) - - async def test_on_list_push_notification_error(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_list_task_push_notification_configs.side_effect = ( - InternalError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_delete_push_notification(self) -> None: - """Test delete_push_notification_config handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - request_handler.on_delete_task_push_notification_config.return_value = ( - None - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['result'], None) - - async def test_on_delete_push_notification_error(self) -> None: - """Test delete_push_notification_config error handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_delete_task_push_notification_config.side_effect = ( - UnsupportedOperationError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_get_authenticated_extended_card_success(self) -> None: - """Test successful retrieval of the authenticated extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_extended_card = AgentCard( - name='Extended Card', - description='More details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.1', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_extended_card, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-1'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-1') - # Result is the agent card proto - - async def test_get_authenticated_extended_card_not_configured(self) -> None: - """Test error when authenticated extended agent card is not configured.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - # We need a proper card here because agent_card_to_dict accesses multiple fields - card = AgentCard( - name='TestAgent', - version='1.0.0', - supported_interfaces=[ - AgentInterface( - url='http://localhost', - protocol_binding='JSONRPC', - protocol_version='1.0.0', - ) - ], - capabilities=AgentCapabilities(extended_agent_card=True), - ) - - handler = JSONRPCHandler( - card, - mock_request_handler, - extended_agent_card=None, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-2'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - # Authenticated Extended Card flag is set with no extended card, - # returns base card in this case. - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-2') - - async def test_get_authenticated_extended_card_with_modifier(self) -> None: - """Test successful retrieval of a dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - async def modifier( - card: AgentCard, context: ServerCallContext - ) -> AgentCard: - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext(state={'foo': 'bar'}) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertFalse(is_error_response(response)) - from google.protobuf.json_format import ParseDict - - modified_card = ParseDict( - response['result'], AgentCard(), ignore_unknown_fields=True - ) - self.assertEqual(modified_card.name, 'Modified Card') - self.assertEqual(modified_card.description, 'Modified for context: bar') - self.assertEqual(modified_card.version, '1.0') - - async def test_get_authenticated_extended_card_with_modifier_sync( - self, - ) -> None: - """Test successful retrieval of a synchronously dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: - # Copy the card by creating a new one with the same fields - from copy import deepcopy - - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-mod'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-mod') - # Result is converted to dict for JSON serialization - modified_card_dict = response['result'] - self.assertEqual(modified_card_dict['name'], 'Modified Card') - self.assertEqual( - modified_card_dict['description'], 'Modified for context: bar' - ) - self.assertEqual(modified_card_dict['version'], '1.0') diff --git a/tests/server/routes/test_agent_card_route.py b/tests/server/routes/test_agent_card_route.py new file mode 100644 index 00000000..6ba4dbee --- /dev/null +++ b/tests/server/routes/test_agent_card_route.py @@ -0,0 +1,55 @@ +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from a2a.server.routes import AgentCardRoutes +from a2a.types import a2a_pb2 +from a2a.server.request_handlers.response_helpers import agent_card_to_dict + + +@pytest.fixture +def mock_agent_card(): + return a2a_pb2.AgentCard( + name='test-agent', + version='1.0.0', + documentation_url='http://localhost:8000', + ) + + +@pytest.fixture +def test_app(mock_agent_card): + app = Starlette() + card_route = AgentCardRoutes(mock_agent_card) + app.routes.append(card_route.routes[0]) + return app + + +@pytest.fixture +def client(test_app): + return TestClient(test_app) + + +def test_agent_card_route_returns_json(client, mock_agent_card): + response = client.get('/') + assert response.status_code == 200 + + # The route returns JSON, not protobuf SerializeToString() + actual_json = response.json() + expected_json = agent_card_to_dict(mock_agent_card) + + assert actual_json == expected_json + + +def test_agent_card_route_with_modifier(mock_agent_card): + async def modifier(card): + card.name = 'modified-agent' + return card + + card_route = AgentCardRoutes(mock_agent_card, card_modifier=modifier) + app = Starlette() + app.routes.append(card_route.routes[0]) + client = TestClient(app) + + response = client.get('/') + assert response.status_code == 200 + assert response.json()['name'] == 'modified-agent' diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py new file mode 100644 index 00000000..5a216825 --- /dev/null +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -0,0 +1,152 @@ +import pytest +import json +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.datastructures import Headers + +from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types.a2a_pb2 import ( + AgentCard, Message, Role, Part, GetTaskRequest, Task, + ListTasksResponse, TaskPushNotificationConfig, + ListTaskPushNotificationConfigsResponse +) +from a2a.server.jsonrpc_models import ( + JSONParseError, + InvalidRequestError, + MethodNotFoundError, + InvalidParamsError, +) + +@pytest.fixture +def mock_handler(): + handler = AsyncMock(spec=RequestHandler) + return handler + +@pytest.fixture +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = MagicMock() + card.capabilities.streaming = True + card.capabilities.push_notifications = True + return card + +class TestJsonRpcDispatcher: + def _create_request(self, body_dict=None, headers=None, body_bytes=None): + """Helper to create a starlette Request for testing""" + scope = { + 'type': 'http', + 'method': 'POST', + 'path': '/', + 'headers': Headers(headers or {}).raw + } + + async def receive(): + if body_bytes: + return {'type': 'http.request', 'body': body_bytes, 'more_body': False} + return {'type': 'http.request', 'body': json.dumps(body_dict or {}).encode('utf-8'), 'more_body': False} + + return Request(scope, receive) + + @pytest.mark.asyncio + async def test_generate_error_response(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + resp = dispatcher._generate_error_response(1, JSONParseError(message='test')) + assert resp.status_code == 200 + data = json.loads(resp.body.decode()) + assert data['id'] == 1 + assert 'error' in data + assert data['error']['code'] == -32700 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method, params, handler_attr, mock_return, expected_key", [ + ("GetTask", {"id": "task-1"}, "on_get_task", Task(id="task-1"), "id"), + ("SendMessage", {"message": {"parts": [{"text": "hi"}]}}, "on_message_send", Task(id="task-1"), "task"), + ("CancelTask", {"id": "task-1"}, "on_cancel_task", Task(id="task-1"), "id"), + ("ListTasks", {}, "on_list_tasks", ListTasksResponse(tasks=[Task(id="task-1")]), "tasks"), + ("CreateTaskPushNotificationConfig", {"taskId": "task-1"}, "on_create_task_push_notification_config", TaskPushNotificationConfig(task_id="task-1"), "taskId"), + ("GetTaskPushNotificationConfig", {"taskId": "task-1"}, "on_get_task_push_notification_config", TaskPushNotificationConfig(task_id="task-1"), "taskId"), + ("ListTaskPushNotificationConfigs", {}, "on_list_task_push_notification_configs", ListTaskPushNotificationConfigsResponse(configs=[TaskPushNotificationConfig(task_id="task-1")]), "configs"), + ("DeleteTaskPushNotificationConfig", {"taskId": "task-1"}, "on_delete_task_push_notification_config", None, None), + ]) + async def test_handle_requests_success_non_streaming(self, agent_card, mock_handler, method, params, handler_attr, mock_return, expected_key): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req_body = { + 'jsonrpc': '2.0', + 'id': 'msg-1', + 'method': method, + 'params': params + } + req = self._create_request(body_dict=req_body) + + mock_func = getattr(mock_handler, handler_attr) + if hasattr(mock_func, 'return_value'): + mock_func.return_value = mock_return + + resp = await dispatcher.handle_requests(req) + assert resp.status_code == 200 + res = json.loads(resp.body.decode()) + assert res['id'] == 'msg-1' + if expected_key: + assert 'result' in res + assert expected_key in res['result'] + assert mock_func.called + + @pytest.mark.asyncio + async def test_handle_requests_payload_too_large(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler, max_content_length=10) + req = self._create_request( + body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'GetTask'}, + headers={'content-length': '100'} + ) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + assert 'Payload too large' in res['error']['message'] + + @pytest.mark.asyncio + async def test_handle_requests_batch_not_supported(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict=[{'jsonrpc': '2.0'}]) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + # The underlying jsonrpc library formats the exact text differently depending on parse path + assert 'Invalid Request' in res['error']['message'] + + @pytest.mark.asyncio + async def test_handle_requests_invalid_jsonrpc_version(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict={'jsonrpc': '1.0', 'id': '1', 'method': 'GetTask'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + + @pytest.mark.asyncio + async def test_handle_requests_method_not_found(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'UnknownMethod'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32601 + + @pytest.mark.asyncio + async def test_v03_compat_delegation(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler, enable_v0_3_compat=True) + dispatcher._v03_adapter.supports_method = MagicMock(return_value=True) + dispatcher._v03_adapter.handle_request = AsyncMock(return_value=JSONResponse({'v03': 'compat'})) + + req = self._create_request(body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'message/send'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res == {'v03': 'compat'} + + @pytest.mark.asyncio + async def test_invalid_json_body_error(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_bytes=b'{"invalid": json}') + + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32700 diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/routes/test_jsonrpc_route.py similarity index 50% rename from tests/server/apps/jsonrpc/test_jsonrpc_app.py rename to tests/server/routes/test_jsonrpc_route.py index ab220e9c..dd53922d 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/routes/test_jsonrpc_route.py @@ -1,4 +1,3 @@ -# ruff: noqa: INP001 from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -7,26 +6,11 @@ from starlette.responses import JSONResponse from starlette.testclient import TestClient - -# Attempt to import StarletteBaseUser, fallback to MagicMock if not available -try: - from starlette.authentication import BaseUser as StarletteBaseUser -except ImportError: - StarletteBaseUser = MagicMock() # type: ignore - from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.apps.jsonrpc import ( - jsonrpc_app, # Keep this import for optional deps test -) -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.routes import JsonRpcRoutes, StarletteUserProxy + from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import ( - RequestHandler, -) # For mock spec +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, Message, @@ -35,49 +19,9 @@ ) -# --- StarletteUserProxy Tests --- - - -class TestStarletteUserProxy: - def test_starlette_user_proxy_is_authenticated_true(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = True - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is True - - def test_starlette_user_proxy_is_authenticated_false(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = False - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is False - - def test_starlette_user_proxy_user_name(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.display_name = 'Test User DisplayName' - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.user_name == 'Test User DisplayName' - - def test_starlette_user_proxy_user_name_raises_attribute_error(self): - """ - Tests that if the underlying starlette user object is missing the - display_name attribute, the proxy currently raises an AttributeError. - """ - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - # Ensure display_name is not present on the mock to trigger AttributeError - del starlette_user_mock.display_name - - proxy = StarletteUserProxy(starlette_user_mock) - with pytest.raises(AttributeError, match='display_name'): - _ = proxy.user_name - - -# --- JSONRPCApplication Tests (Selected) --- - - @pytest.fixture def mock_handler(): handler = AsyncMock(spec=RequestHandler) - # Return a proto Message object directly - the handler wraps it in SendMessageResponse handler.on_message_send.return_value = Message( message_id='test', role=Role.ROLE_AGENT, @@ -90,23 +34,25 @@ def mock_handler(): def test_app(mock_handler): mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' - # Set up capabilities.streaming to avoid validation issues mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - return A2AStarletteApplication( - agent_card=mock_agent_card, http_handler=mock_handler - ) + + from starlette.applications import Starlette + + app = Starlette() + jsonrpc_routes = JsonRpcRoutes(mock_agent_card, mock_handler) + app.routes.extend(jsonrpc_routes.routes) + return app @pytest.fixture def client(test_app): - return TestClient(test_app.build()) + return TestClient(test_app) def _make_send_message_request( text: str = 'hi', tenant: str | None = None ) -> dict: - """Helper to create a JSON-RPC send message request.""" params: dict[str, Any] = { 'message': { 'messageId': '1', @@ -125,112 +71,6 @@ def _make_send_message_request( } -class TestJSONRPCApplicationSetup: # Renamed to avoid conflict - def test_jsonrpc_app_build_method_abstract_raises_typeerror( - self, - ): # Renamed test - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in JSONRPCApplication.__init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed in __init__ - mock_agent_card.url = 'http://mockurl.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists - - # This will fail at definition time if an abstract method is not implemented - with pytest.raises( - TypeError, - match=r".*abstract class IncompleteJSONRPCApp .* abstract method '?build'?", - ): - - class IncompleteJSONRPCApp(JSONRPCApplication): - # Intentionally not implementing 'build' - def some_other_method(self): - pass - - IncompleteJSONRPCApp( - agent_card=mock_agent_card, http_handler=mock_handler - ) # type: ignore[abstract] - - -class TestJSONRPCApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401, PLC0415 - import starlette as _starlette # noqa: F401, PLC0415 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = jsonrpc_app._package_starlette_installed - jsonrpc_app._package_starlette_installed = False - yield - jsonrpc_app._package_starlette_installed = pkg_starlette_installed_flag - - def test_create_jsonrpc_based_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - - try: - _app = MockJSONRPCApp(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating a' - ' JSONRPCApplication-based instance should not raise' - ' ImportError' - ) - - def test_create_jsonrpc_based_app_with_missing_deps_raises_importerror( - self, mock_app_params: dict, mark_pkg_starlette_not_installed: Any - ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - - with pytest.raises( - ImportError, - match=( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `JSONRPCApplication`' - ), - ): - _app = MockJSONRPCApp(**mock_app_params) - - class TestJSONRPCApplicationExtensions: def test_request_with_single_extension(self, client, mock_handler): headers = {HTTP_EXTENSION_HEADER: 'foo'} @@ -312,7 +152,6 @@ def test_response_with_activated_extensions(self, client, mock_handler): def side_effect(request, context: ServerCallContext): context.activated_extensions.add('foo') context.activated_extensions.add('baz') - # Return a proto Message object directly return Message( message_id='test', role=Role.ROLE_AGENT, @@ -363,19 +202,21 @@ def test_no_tenant_extraction(self, client, mock_handler): class TestJSONRPCApplicationV03Compat: - def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): + def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - app = A2AStarletteApplication( - agent_card=mock_agent_card, - http_handler=mock_handler, - enable_v0_3_compat=True, + from starlette.applications import Starlette + + app = Starlette() + routes = JsonRpcRoutes( + mock_agent_card, mock_handler, enable_v0_3_compat=True ) + app.routes.extend(routes.routes) - client = TestClient(app.build()) + client = TestClient(app) request_data = { 'jsonrpc': '2.0', @@ -390,8 +231,11 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): }, } + # Instead of _v03_adapter, the handler handles it or it's dispatcher with patch.object( - app._v03_adapter, 'handle_request', new_callable=AsyncMock + routes.dispatcher._v03_adapter, + 'handle_request', + new_callable=AsyncMock, ) as mock_handle: mock_handle.return_value = JSONResponse( {'jsonrpc': '2.0', 'id': '1', 'result': {}} @@ -401,7 +245,6 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): response.raise_for_status() assert mock_handle.called - assert mock_handle.call_args[1]['method'] == 'message/send' def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): mock_agent_card = MagicMock(spec=AgentCard) @@ -409,13 +252,15 @@ def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - app = A2AStarletteApplication( - agent_card=mock_agent_card, - http_handler=mock_handler, - enable_v0_3_compat=False, + from starlette.applications import Starlette + + app = Starlette() + routes = JsonRpcRoutes( + mock_agent_card, mock_handler, enable_v0_3_compat=False ) + app.routes.extend(routes.routes) - client = TestClient(app.build()) + client = TestClient(app) request_data = { 'jsonrpc': '2.0', @@ -433,8 +278,6 @@ def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): response = client.post('/', json=request_data) assert response.status_code == 200 - # Should return MethodNotFoundError because the v0.3 method is not recognized - # without the adapter enabled. resp_json = response.json() assert 'error' in resp_json assert resp_json['error']['code'] == -32601 diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/routes/test_rest_routes.py similarity index 81% rename from tests/server/apps/rest/test_rest_fastapi_app.py rename to tests/server/routes/test_rest_routes.py index c8510023..50990f22 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/routes/test_rest_routes.py @@ -10,9 +10,7 @@ from google.protobuf import json_format from httpx import ASGITransport, AsyncClient -from a2a.server.apps.rest import fastapi_app, rest_adapter -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication -from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.server.routes import RestRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( @@ -74,12 +72,13 @@ async def extended_card_modifier() -> MagicMock | None: @pytest.fixture async def streaming_app( streaming_agent_card: AgentCard, request_handler: RequestHandler -) -> FastAPI: - """Builds the FastAPI application for testing streaming endpoints.""" +) -> Any: + from starlette.applications import Starlette - return A2ARESTFastAPIApplication( - streaming_agent_card, request_handler - ).build(agent_card_url='/well-known/agent-card.json', rpc_url='') + router = RestRoutes(streaming_agent_card, request_handler, rpc_url='') + app = Starlette() + app.routes.extend(router.routes) + return app @pytest.fixture @@ -95,14 +94,26 @@ async def app( agent_card: AgentCard, request_handler: RequestHandler, extended_card_modifier: MagicMock | None, -) -> FastAPI: - """Builds the FastAPI application for testing.""" +) -> Any: + from starlette.applications import Starlette + from a2a.server.routes import AgentCardRoutes - return A2ARESTFastAPIApplication( + # Return Starlette app + app_instance = Starlette() + + rest_router = RestRoutes( agent_card, request_handler, extended_card_modifier=extended_card_modifier, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') + rpc_url='', + ) + app_instance.routes.extend(rest_router.routes) + + # Also Agent card endpoint? if needed in tests + card_routes = AgentCardRoutes(agent_card, card_url='/well-known/agent.json') + app_instance.routes.extend(card_routes.routes) + + return app_instance @pytest.fixture @@ -112,93 +123,18 @@ async def client(app: FastAPI) -> AsyncClient: ) -@pytest.fixture -def mark_pkg_starlette_not_installed(): - pkg_starlette_installed_flag = rest_adapter._package_starlette_installed - rest_adapter._package_starlette_installed = False - yield - rest_adapter._package_starlette_installed = pkg_starlette_installed_flag - - -@pytest.fixture -def mark_pkg_fastapi_not_installed(): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = RESTAdapter(agent_card, request_handler) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' RESTAdapter instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_starlette_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - r'Packages `starlette` and `sse-starlette` are required to use' - r' the `RESTAdapter`.' - ), - ): - _app = RESTAdapter(agent_card, request_handler) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2ARESTFastAPIApplication instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_fastapi_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`' - ), - ): - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) +# --- RestRouter Tests --- @pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_v0_3_compat( +async def test_create_rest_router_with_v0_3_compat( agent_card: AgentCard, request_handler: RequestHandler ): - app = A2ARESTFastAPIApplication( - agent_card, request_handler, enable_v0_3_compat=True - ).build(agent_card_url='/well-known/agent.json', rpc_url='') - - routes = [getattr(route, 'path', '') for route in app.routes] + router = RestRoutes( + agent_card, request_handler, enable_v0_3_compat=True, rpc_url='' + ) + # Check if a route from v0.3 compatibility is present + routes = [getattr(route, 'path', '') for route in router.routes] assert '/v1/message:send' in routes @@ -692,33 +628,5 @@ async def test_tenant_extraction_extended_agent_card( assert context.tenant == '' -@pytest.mark.anyio -async def test_global_http_exception_handler_returns_rpc_status( - client: AsyncClient, -) -> None: - """Test that a standard FastAPI 404 is transformed into the A2A google.rpc.Status format.""" - - # Send a request to an endpoint that does not exist - response = await client.get('/non-existent-route') - - # Verify it returns a 404 with standard application/json - assert response.status_code == 404 - assert response.headers.get('content-type') == 'application/json' - - data = response.json() - - # Assert the payload is wrapped in the "error" envelope - assert 'error' in data - error_payload = data['error'] - - # Assert it has the correct AIP-193 format - assert error_payload['code'] == 404 - assert error_payload['status'] == 'NOT_FOUND' - assert 'Not Found' in error_payload['message'] - - # Standard HTTP errors shouldn't leak details - assert 'details' not in error_payload - - if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index e6bb5f88..ec974b8c 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,10 +18,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.apps import ( - A2AFastAPIApplication, - A2AStarletteApplication, -) +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes + from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( InternalError, @@ -148,14 +146,48 @@ def handler(): return handler +class AppBuilder: + def __init__(self, agent_card, handler, card_modifier=None): + self.agent_card = agent_card + self.handler = handler + self.card_modifier = card_modifier + + def build( + self, + rpc_url='/', + agent_card_url=AGENT_CARD_WELL_KNOWN_PATH, + middleware=None, + routes=None, + ): + from starlette.applications import Starlette + + app_instance = Starlette(middleware=middleware, routes=routes or []) + + # Agent card router + card_routes = AgentCardRoutes( + self.agent_card, + card_url=agent_card_url, + card_modifier=self.card_modifier, + ) + app_instance.routes.extend(card_routes.routes) + + # JSON-RPC router + rpc_routes = JsonRpcRoutes( + self.agent_card, self.handler, rpc_url=rpc_url + ) + app_instance.routes.extend(rpc_routes.routes) + + return app_instance + + @pytest.fixture def app(agent_card: AgentCard, handler: mock.AsyncMock): - return A2AStarletteApplication(agent_card, handler) + return AppBuilder(agent_card, handler) @pytest.fixture -def client(app: A2AStarletteApplication, **kwargs): - """Create a test client with the Starlette app.""" +def client(app, **kwargs): + """Create a test client with the app builder.""" return TestClient(app.build(**kwargs)) @@ -172,9 +204,7 @@ def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): assert 'streaming' in data['capabilities'] -def test_agent_card_custom_url( - app: A2AStarletteApplication, agent_card: AgentCard -): +def test_agent_card_custom_url(app, agent_card: AgentCard): """Test the agent card endpoint with a custom URL.""" client = TestClient(app.build(agent_card_url='/my-agent')) response = client.get('/my-agent') @@ -183,9 +213,7 @@ def test_agent_card_custom_url( assert data['name'] == agent_card.name -def test_starlette_rpc_endpoint_custom_url( - app: A2AStarletteApplication, handler: mock.AsyncMock -): +def test_starlette_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -206,9 +234,7 @@ def test_starlette_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_fastapi_rpc_endpoint_custom_url( - app: A2AFastAPIApplication, handler: mock.AsyncMock -): +def test_fastapi_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -229,9 +255,7 @@ def test_fastapi_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_starlette_build_with_extra_routes( - app: A2AStarletteApplication, agent_card: AgentCard -): +def test_starlette_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -253,9 +277,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_with_extra_routes( - app: A2AFastAPIApplication, agent_card: AgentCard -): +def test_fastapi_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -277,9 +299,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_custom_agent_card_path( - app: A2AFastAPIApplication, agent_card: AgentCard -): +def test_fastapi_build_custom_agent_card_path(app, agent_card: AgentCard): """Test building the app with a custom agent card path.""" test_app = app.build(agent_card_url='/agent-card') @@ -467,7 +487,7 @@ def test_get_push_notification_config( handler.on_get_task_push_notification_config.assert_awaited_once() -def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): +def test_server_auth(app, handler: mock.AsyncMock): class TestAuthMiddleware(AuthenticationBackend): async def authenticate( self, conn: HTTPConnection @@ -529,9 +549,7 @@ async def authenticate( @pytest.mark.asyncio -async def test_message_send_stream( - app: A2AStarletteApplication, handler: mock.AsyncMock -) -> None: +async def test_message_send_stream(app, handler: mock.AsyncMock) -> None: """Test streaming message sending.""" # Setup mock streaming response @@ -605,9 +623,7 @@ async def stream_generator(): @pytest.mark.asyncio -async def test_task_resubscription( - app: A2AStarletteApplication, handler: mock.AsyncMock -) -> None: +async def test_task_resubscription(app, handler: mock.AsyncMock) -> None: """Test task resubscription streaming.""" # Setup mock streaming response @@ -738,9 +754,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -763,9 +777,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -788,9 +800,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -810,9 +820,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -924,7 +932,7 @@ def test_agent_card_backward_compatibility_supports_extended_card( ): """Test that supportsAuthenticatedExtendedCard is injected when extended_agent_card is True.""" agent_card.capabilities.extended_agent_card = True - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200 @@ -937,7 +945,7 @@ def test_agent_card_backward_compatibility_no_extended_card( ): """Test that supportsAuthenticatedExtendedCard is absent when extended_agent_card is False.""" agent_card.capabilities.extended_agent_card = False - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200