From 931f990817942ead4071d3245d73d36bac0cdf67 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 17 Mar 2026 18:39:16 +0000 Subject: [PATCH 01/17] wip refactoring --- .../request_handlers/rest_handler_v2.py | 156 +++++ src/a2a/server/router/__init__.py | 21 + src/a2a/server/router/agent_card_router.py | 76 +++ src/a2a/server/router/jsonrpc_dispatcher.py | 615 ++++++++++++++++++ src/a2a/server/router/jsonrpc_router.py | 129 ++++ src/a2a/server/router/rest_router.py | 349 ++++++++++ tck/sut_agent.py | 20 +- 7 files changed, 1360 insertions(+), 6 deletions(-) create mode 100644 src/a2a/server/request_handlers/rest_handler_v2.py create mode 100644 src/a2a/server/router/__init__.py create mode 100644 src/a2a/server/router/agent_card_router.py create mode 100644 src/a2a/server/router/jsonrpc_dispatcher.py create mode 100644 src/a2a/server/router/jsonrpc_router.py create mode 100644 src/a2a/server/router/rest_router.py diff --git a/src/a2a/server/request_handlers/rest_handler_v2.py b/src/a2a/server/request_handlers/rest_handler_v2.py new file mode 100644 index 00000000..5df91eaa --- /dev/null +++ b/src/a2a/server/request_handlers/rest_handler_v2.py @@ -0,0 +1,156 @@ +import logging + +from collections.abc import AsyncIterator +from typing import Any + +from google.protobuf.json_format import ( + MessageToDict, +) + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import ( + AgentCard, +) +from a2a.utils import proto_utils +from a2a.utils.errors import TaskNotFoundError +from a2a.utils.helpers import validate, validate_async_generator +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RESTHandlerV2: + """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses.""" + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + ): + self.agent_card = agent_card + self.request_handler = request_handler + + async def on_message_send( + self, + params: a2a_pb2.SendMessageRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + task_or_message = await self.request_handler.on_message_send( + params, context + ) + if isinstance(task_or_message, a2a_pb2.Task): + response = a2a_pb2.SendMessageResponse(task=task_or_message) + else: + response = a2a_pb2.SendMessageResponse(message=task_or_message) + return MessageToDict(response) + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( + self, + params: a2a_pb2.SendMessageRequest, + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + async for event in self.request_handler.on_message_send_stream( + params, context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + async def on_cancel_task( + self, + params: a2a_pb2.CancelTaskRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + task = await self.request_handler.on_cancel_task(params, context) + if task: + return MessageToDict(task) + raise TaskNotFoundError + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_subscribe_to_task( + self, + params: a2a_pb2.SubscribeToTaskRequest, + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + async for event in self.request_handler.on_subscribe_to_task( + params, context + ): + yield MessageToDict(proto_utils.to_stream_response(event)) + + async def get_push_notification( + self, + params: a2a_pb2.GetTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + return MessageToDict(config) + + @validate( + lambda self: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification( + self, + params: a2a_pb2.TaskPushNotificationConfig, + context: ServerCallContext, + ) -> dict[str, Any]: + config = ( + await self.request_handler.on_create_task_push_notification_config( + params, context + ) + ) + return MessageToDict(config) + + async def on_get_task( + self, + params: a2a_pb2.GetTaskRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + task = await self.request_handler.on_get_task(params, context) + if task: + return MessageToDict(task) + raise TaskNotFoundError + + async def delete_push_notification( + self, + params: a2a_pb2.DeleteTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + await self.request_handler.on_delete_task_push_notification_config( + params, context + ) + return {} + + async def list_tasks( + self, + params: a2a_pb2.ListTasksRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + result = await self.request_handler.on_list_tasks(params, context) + return MessageToDict(result, always_print_fields_with_no_presence=True) + + async def list_push_notifications( + self, + params: a2a_pb2.ListTaskPushNotificationConfigsRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + result = ( + await self.request_handler.on_list_task_push_notification_configs( + params, context + ) + ) + return MessageToDict(result) diff --git a/src/a2a/server/router/__init__.py b/src/a2a/server/router/__init__.py new file mode 100644 index 00000000..d7bdd5d7 --- /dev/null +++ b/src/a2a/server/router/__init__.py @@ -0,0 +1,21 @@ +"""A2A JSON-RPC Applications.""" + +from a2a.server.router.jsonrpc_router import JsonRpcRouter +from a2a.server.router.rest_router import RestRouter +from a2a.server.router.agent_card_router import AgentCardRouter +from a2a.server.apps.jsonrpc.jsonrpc_app import ( + CallContextBuilder, + DefaultCallContextBuilder, + StarletteUserProxy, +) +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication + + +__all__ = [ + 'A2AFastAPIApplication', + 'A2AStarletteApplication', + 'CallContextBuilder', + 'DefaultCallContextBuilder', + 'JSONRPCApplication', + 'StarletteUserProxy', +] diff --git a/src/a2a/server/router/agent_card_router.py b/src/a2a/server/router/agent_card_router.py new file mode 100644 index 00000000..3ff57b6f --- /dev/null +++ b/src/a2a/server/router/agent_card_router.py @@ -0,0 +1,76 @@ +import logging + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from fastapi import APIRouter + from starlette.requests import Request + from starlette.responses import JSONResponse, Response +else: + try: + from fastapi import APIRouter + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + except ImportError: + APIRouter = Any + Request = Any + Response = Any + JSONResponse = Any + +from a2a.server.request_handlers.response_helpers import agent_card_to_dict +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from a2a.utils.helpers import maybe_await + + +logger = logging.getLogger(__name__) + + +class AgentCardRouter: + """A FastAPI router implementing the A2A protocol agent card endpoints.""" + + def __init__( + self, + agent_card: AgentCard, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + rpc_url: str = '', + ) -> None: + """Initializes the AgentCardRouter. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL prefix for the endpoint. + """ + self.agent_card = agent_card + self.card_modifier = card_modifier + + self.router = APIRouter() + self._setup_router(agent_card_url, rpc_url) + + def _setup_router( + self, + agent_card_url: str, + rpc_url: str, + ) -> None: + """Configures the APIRouter with the A2A endpoints. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL prefix for the endpoint. + """ + + @self.router.get(f'{rpc_url}{agent_card_url}') + async def get_agent_card(request: Request) -> Response: + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + return JSONResponse(agent_card_to_dict(card_to_serve)) diff --git a/src/a2a/server/router/jsonrpc_dispatcher.py b/src/a2a/server/router/jsonrpc_dispatcher.py new file mode 100644 index 00000000..f2332abd --- /dev/null +++ b/src/a2a/server/router/jsonrpc_dispatcher.py @@ -0,0 +1,615 @@ +"""JSON-RPC application for A2A server.""" + +import contextlib +import json +import logging +import traceback + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from google.protobuf.json_format import ParseDict +from jsonrpc.jsonrpc2 import JSONRPC20Request + +from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import User as A2AUser +from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) +from a2a.server.context import ServerCallContext +from a2a.server.jsonrpc_models import ( + InternalError, + InvalidParamsError, + InvalidRequestError, + JSONParseError, + JSONRPCError, + MethodNotFoundError, +) +from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.response_helpers import ( + agent_card_to_dict, + build_error_response, +) +from a2a.types import A2ARequest +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + SendMessageRequest, + SubscribeToTaskRequest, + TaskPushNotificationConfig, +) +from a2a.utils.errors import ( + A2AError, + UnsupportedOperationError, +) +from a2a.utils.helpers import maybe_await + + +INTERNAL_ERROR_CODE = -32603 + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from fastapi import FastAPI + from sse_starlette.sse import EventSourceResponse + from starlette.applications import Starlette + from starlette.authentication import BaseUser + from starlette.exceptions import HTTPException + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + try: + # Starlette v0.48.0 + from starlette.status import HTTP_413_CONTENT_TOO_LARGE + except ImportError: + from starlette.status import ( # type: ignore[no-redef] + HTTP_413_REQUEST_ENTITY_TOO_LARGE as HTTP_413_CONTENT_TOO_LARGE, + ) + + _package_starlette_installed = True +else: + FastAPI = Any + try: + from sse_starlette.sse import EventSourceResponse + from starlette.applications import Starlette + from starlette.authentication import BaseUser + from starlette.exceptions import HTTPException + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + try: + # Starlette v0.48.0 + from starlette.status import HTTP_413_CONTENT_TOO_LARGE + except ImportError: + from starlette.status import ( + HTTP_413_REQUEST_ENTITY_TOO_LARGE as HTTP_413_CONTENT_TOO_LARGE, + ) + + _package_starlette_installed = True + except ImportError: + _package_starlette_installed = False + # Provide placeholder types for runtime type hinting when dependencies are not installed. + # These will not be used if the code path that needs them is guarded by _http_server_installed. + EventSourceResponse = Any + Starlette = Any + BaseUser = Any + HTTPException = Any + Request = Any + JSONResponse = Any + Response = Any + HTTP_413_CONTENT_TOO_LARGE = Any + + +class StarletteUserProxy(A2AUser): + """Adapts the Starlette User class to the A2A user representation.""" + + def __init__(self, user: BaseUser): + self._user = user + + @property + def is_authenticated(self) -> bool: + """Returns whether the current user is authenticated.""" + return self._user.is_authenticated + + @property + def user_name(self) -> str: + """Returns the user name of the current user.""" + return self._user.display_name + + +class CallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, request: Request) -> ServerCallContext: + """Builds a ServerCallContext from a Starlette Request.""" + + +class DefaultCallContextBuilder(CallContextBuilder): + """A default implementation of CallContextBuilder.""" + + def build(self, request: Request) -> ServerCallContext: + """Builds a ServerCallContext from a Starlette Request. + + Args: + request: The incoming Starlette Request object. + + Returns: + A ServerCallContext instance populated with user and state + information from the request. + """ + user: A2AUser = UnauthenticatedUser() + state = {} + with contextlib.suppress(Exception): + user = StarletteUserProxy(request.user) + state['auth'] = request.auth + state['headers'] = dict(request.headers) + return ServerCallContext( + user=user, + state=state, + requested_extensions=get_requested_extensions( + request.headers.getlist(HTTP_EXTENSION_HEADER) + ), + ) + + +class JsonRpcDispatcher: + """Base class for A2A JSONRPC applications. + + Handles incoming JSON-RPC requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + # Method-to-model mapping for centralized routing + # Proto types don't have model_fields, so we define the mapping explicitly + # Method names match gRPC service method names + METHOD_TO_MODEL: dict[str, type] = { + 'SendMessage': SendMessageRequest, + 'SendStreamingMessage': SendMessageRequest, # Same proto type as SendMessage + 'GetTask': GetTaskRequest, + 'ListTasks': ListTasksRequest, + 'CancelTask': CancelTaskRequest, + 'CreateTaskPushNotificationConfig': TaskPushNotificationConfig, + 'GetTaskPushNotificationConfig': GetTaskPushNotificationConfigRequest, + 'ListTaskPushNotificationConfigs': ListTaskPushNotificationConfigsRequest, + 'DeleteTaskPushNotificationConfig': DeleteTaskPushNotificationConfigRequest, + 'SubscribeToTask': SubscribeToTaskRequest, + 'GetExtendedAgentCard': GetExtendedAgentCardRequest, + } + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + http_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB + enable_v0_3_compat: bool = False, + ) -> None: + """Initializes the JSONRPCApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + """ + if not _package_starlette_installed: + raise ImportError( + 'Packages `starlette` and `sse-starlette` are required to use the' + ' `JSONRPCApplication`. They can be added as a part of `a2a-sdk`' + ' optional dependencies, `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.extended_agent_card = extended_agent_card + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier + self.handler = JSONRPCHandler( + agent_card=agent_card, + request_handler=http_handler, + extended_agent_card=extended_agent_card, + extended_card_modifier=extended_card_modifier, + ) + self._context_builder = context_builder or DefaultCallContextBuilder() + self._max_content_length = max_content_length + self.enable_v0_3_compat = enable_v0_3_compat + self._v03_adapter: JSONRPC03Adapter | None = None + + if self.enable_v0_3_compat: + self._v03_adapter = JSONRPC03Adapter( + agent_card=agent_card, + http_handler=http_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + + def _generate_error_response( + self, + request_id: str | int | None, + error: Exception | JSONRPCError | A2AError, + ) -> JSONResponse: + """Creates a Starlette JSONResponse for a JSON-RPC error. + + Logs the error based on its type. + + Args: + request_id: The ID of the request that caused the error. + error: The error object (one of the JSONRPCError types). + + Returns: + A `JSONResponse` object formatted as a JSON-RPC error response. + """ + if not isinstance(error, A2AError | JSONRPCError): + error = InternalError(message=str(error)) + + response_data = build_error_response(request_id, error) + error_info = response_data.get('error', {}) + code = error_info.get('code') + message = error_info.get('message') + data = error_info.get('data') + + log_level = logging.WARNING + if code == INTERNAL_ERROR_CODE: + log_level = logging.ERROR + + logger.log( + log_level, + "Request Error (ID: %s): Code=%s, Message='%s'%s", + request_id, + code, + message, + f', Data={data}' if data else '', + ) + return JSONResponse( + response_data, + status_code=200, + ) + + def _allowed_content_length(self, request: Request) -> bool: + """Checks if the request content length is within the allowed maximum. + + Args: + request: The incoming Starlette Request object. + + Returns: + False if the content length is larger than the allowed maximum, True otherwise. + """ + if self._max_content_length is not None: + with contextlib.suppress(ValueError): + content_length = int(request.headers.get('content-length', '0')) + if content_length and content_length > self._max_content_length: + return False + return True + + async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 + """Handles incoming POST requests to the main A2A endpoint. + + Parses the request body as JSON, validates it against A2A request types, + dispatches it to the appropriate handler method, and returns the response. + Handles JSON parsing errors, validation errors, and other exceptions, + returning appropriate JSON-RPC error responses. + + Args: + request: The incoming Starlette Request object. + + Returns: + A Starlette Response object (JSONResponse or EventSourceResponse). + + Raises: + (Implicitly handled): Various exceptions are caught and converted + into JSON-RPC error responses by this method. + """ + request_id = None + body = None + + try: + body = await request.json() + if isinstance(body, dict): + request_id = body.get('id') + # Ensure request_id is valid for JSON-RPC response (str/int/None only) + if request_id is not None and not isinstance( + request_id, str | int + ): + request_id = None + # Treat payloads lager than allowed as invalid request (-32600) before routing + if not self._allowed_content_length(request): + return self._generate_error_response( + request_id, + InvalidRequestError(message='Payload too large'), + ) + logger.debug('Request body: %s', body) + # 1) Validate base JSON-RPC structure only (-32600 on failure) + try: + base_request = JSONRPC20Request.from_data(body) + if not isinstance(base_request, JSONRPC20Request): + # Batch requests are not supported + return self._generate_error_response( + request_id, + InvalidRequestError( + message='Batch requests are not supported' + ), + ) + if body.get('jsonrpc') != '2.0': + return self._generate_error_response( + request_id, + InvalidRequestError( + message="Invalid request: 'jsonrpc' must be exactly '2.0'" + ), + ) + except Exception as e: + logger.exception('Failed to validate base JSON-RPC request') + return self._generate_error_response( + request_id, + InvalidRequestError(data=str(e)), + ) + + # 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure) + method: str | None = base_request.method + request_id = base_request._id # noqa: SLF001 + + if not method: + return self._generate_error_response( + request_id, + InvalidRequestError(message='Method is required'), + ) + + if ( + self.enable_v0_3_compat + and self._v03_adapter + and self._v03_adapter.supports_method(method) + ): + return await self._v03_adapter.handle_request( + request_id=request_id, + method=method, + body=body, + request=request, + ) + + model_class = self.METHOD_TO_MODEL.get(method) + if not model_class: + return self._generate_error_response( + request_id, MethodNotFoundError() + ) + try: + # Parse the params field into the proto message type + params = body.get('params', {}) + specific_request = ParseDict(params, model_class()) + except Exception as e: + logger.exception('Failed to parse request params') + return self._generate_error_response( + request_id, + InvalidParamsError(data=str(e)), + ) + + # 3) Build call context and wrap the request for downstream handling + call_context = self._context_builder.build(request) + call_context.tenant = getattr(specific_request, 'tenant', '') + call_context.state['method'] = method + call_context.state['request_id'] = request_id + + # Route streaming requests by method name + if method in ('SendStreamingMessage', 'SubscribeToTask'): + return await self._process_streaming_request( + request_id, specific_request, call_context + ) + + return await self._process_non_streaming_request( + request_id, specific_request, call_context + ) + except json.decoder.JSONDecodeError as e: + traceback.print_exc() + return self._generate_error_response( + None, JSONParseError(message=str(e)) + ) + except HTTPException as e: + if e.status_code == HTTP_413_CONTENT_TOO_LARGE: + return self._generate_error_response( + request_id, + InvalidRequestError(message='Payload too large'), + ) + raise e + except Exception as e: + logger.exception('Unhandled exception') + return self._generate_error_response( + request_id, InternalError(message=str(e)) + ) + + async def _process_streaming_request( + self, + request_id: str | int | None, + request_obj: A2ARequest, + context: ServerCallContext, + ) -> Response: + """Processes streaming requests (SendStreamingMessage or SubscribeToTask). + + Args: + request_id: The ID of the request. + request_obj: The proto request message. + context: The ServerCallContext for the request. + + Returns: + An `EventSourceResponse` object to stream results to the client. + """ + handler_result: Any = None + # Check for streaming message request (same type as SendMessage, but handled differently) + if isinstance( + request_obj, + SendMessageRequest, + ): + handler_result = self.handler.on_message_send_stream( + request_obj, context + ) + elif isinstance(request_obj, SubscribeToTaskRequest): + handler_result = self.handler.on_subscribe_to_task( + request_obj, context + ) + + return self._create_response(context, handler_result) + + async def _process_non_streaming_request( + self, + request_id: str | int | None, + request_obj: A2ARequest, + context: ServerCallContext, + ) -> Response: + """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). + + Args: + request_id: The ID of the request. + request_obj: The proto request message. + context: The ServerCallContext for the request. + + Returns: + A `JSONResponse` object containing the result or error. + """ + handler_result: Any = None + match request_obj: + case SendMessageRequest(): + handler_result = await self.handler.on_message_send( + request_obj, context + ) + case CancelTaskRequest(): + handler_result = await self.handler.on_cancel_task( + request_obj, context + ) + case GetTaskRequest(): + handler_result = await self.handler.on_get_task( + request_obj, context + ) + case ListTasksRequest(): + handler_result = await self.handler.list_tasks( + request_obj, context + ) + case TaskPushNotificationConfig(): + handler_result = ( + await self.handler.set_push_notification_config( + request_obj, + context, + ) + ) + case GetTaskPushNotificationConfigRequest(): + handler_result = ( + await self.handler.get_push_notification_config( + request_obj, + context, + ) + ) + case ListTaskPushNotificationConfigsRequest(): + handler_result = ( + await self.handler.list_push_notification_configs( + request_obj, + context, + ) + ) + case DeleteTaskPushNotificationConfigRequest(): + handler_result = ( + await self.handler.delete_push_notification_config( + request_obj, + context, + ) + ) + case GetExtendedAgentCardRequest(): + handler_result = ( + await self.handler.get_authenticated_extended_card( + request_obj, + context, + ) + ) + case _: + logger.error( + 'Unhandled validated request type: %s', type(request_obj) + ) + error = UnsupportedOperationError( + message=f'Request type {type(request_obj).__name__} is unknown.' + ) + return self._generate_error_response(request_id, error) + + return self._create_response(context, handler_result) + + def _create_response( + self, + context: ServerCallContext, + handler_result: AsyncGenerator[dict[str, Any]] | dict[str, Any], + ) -> Response: + """Creates a Starlette Response based on the result from the request handler. + + Handles: + - AsyncGenerator for Server-Sent Events (SSE). + - Dict responses from handlers. + + Args: + context: The ServerCallContext provided to the request handler. + handler_result: The result from a request handler method. Can be an + async generator for streaming or a dict for non-streaming. + + Returns: + A Starlette JSONResponse or EventSourceResponse. + """ + headers = {} + if exts := context.activated_extensions: + headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) + if isinstance(handler_result, AsyncGenerator): + # Result is a stream of dict objects + async def event_generator( + stream: AsyncGenerator[dict[str, Any]], + ) -> AsyncGenerator[dict[str, str]]: + async for item in stream: + yield {'data': json.dumps(item)} + + return EventSourceResponse( + event_generator(handler_result), headers=headers + ) + + # handler_result is a dict (JSON-RPC response) + return JSONResponse(handler_result, headers=headers) + + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: + """Handles GET requests for the agent card endpoint. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the agent card data. + """ + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) + + return JSONResponse( + agent_card_to_dict( + card_to_serve, + ) + ) diff --git a/src/a2a/server/router/jsonrpc_router.py b/src/a2a/server/router/jsonrpc_router.py new file mode 100644 index 00000000..0e34849b --- /dev/null +++ b/src/a2a/server/router/jsonrpc_router.py @@ -0,0 +1,129 @@ +import logging + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from fastapi import APIRouter, FastAPI + + _package_fastapi_installed = True +else: + try: + from fastapi import APIRouter, FastAPI + + _package_fastapi_installed = True + except ImportError: + APIRouter = Any + FastAPI = Any + + _package_fastapi_installed = False + + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.router.jsonrpc_dispatcher import ( + CallContextBuilder, + JsonRpcDispatcher, +) +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, + DEFAULT_RPC_URL, +) + + +logger = logging.getLogger(__name__) + + +class JsonRpcRouter: + """A FastAPI application implementing the A2A protocol server endpoints. + + Handles incoming JSON-RPC requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + http_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB + enable_v0_3_compat: bool = False, + agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + rpc_url: str = '', + ) -> None: + """Initializes the A2AFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + """ + if not _package_fastapi_installed: + raise ImportError( + 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + self.dispatcher = JsonRpcDispatcher( + agent_card=agent_card, + http_handler=http_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + max_content_length=max_content_length, + enable_v0_3_compat=enable_v0_3_compat, + ) + self.router = APIRouter() + self._setup_router(agent_card_url, rpc_url) + + def _setup_router( + self, + agent_card_url: str, + rpc_url: str, + ) -> None: + """Configures the APIRouter with the A2A endpoints. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL for the A2A JSON-RPC endpoint. + """ + self.router.post( + rpc_url, + openapi_extra={ + 'requestBody': { + 'content': { + 'application/json': { + 'schema': { + '$ref': '#/components/schemas/A2ARequest' + } + } + }, + 'required': True, + 'description': 'A2ARequest', + } + }, + )(self.dispatcher._handle_requests) diff --git a/src/a2a/server/router/rest_router.py b/src/a2a/server/router/rest_router.py new file mode 100644 index 00000000..97fc4798 --- /dev/null +++ b/src/a2a/server/router/rest_router.py @@ -0,0 +1,349 @@ +import logging + +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from fastapi import APIRouter, FastAPI + from sse_starlette.sse import EventSourceResponse + from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_fastapi_installed = True +else: + try: + from fastapi import APIRouter, FastAPI + from sse_starlette.sse import EventSourceResponse + from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_fastapi_installed = True + except ImportError: + APIRouter = Any + FastAPI = Any + EventSourceResponse = Any + Request = Any + JSONResponse = Any + Response = Any + StarletteHTTPException = Any + + _package_fastapi_installed = False + +import json + +from google.protobuf.json_format import MessageToDict, Parse + +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + DefaultCallContextBuilder, +) +from a2a.server.apps.rest.fastapi_app import _HTTP_TO_GRPC_STATUS_MAP +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.rest_handler_v2 import RESTHandlerV2 +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils import proto_utils +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, + DEFAULT_RPC_URL, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InvalidRequestError, +) +from a2a.utils.helpers import maybe_await + + +logger = logging.getLogger(__name__) + + +class RestRouter: + """A FastAPI application implementing the A2A protocol server endpoints. + + Handles incoming JSON-REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + http_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB + enable_v0_3_compat: bool = False, + agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + rpc_url: str = '', + ) -> None: + """Initializes the A2AFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + """ + if not _package_fastapi_installed: + raise ImportError( + 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.http_handler = http_handler + self.rest_handler = RESTHandlerV2( + agent_card=agent_card, request_handler=http_handler + ) + self.extended_agent_card = extended_agent_card + self._context_builder = context_builder or DefaultCallContextBuilder() + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier + self.max_content_length = max_content_length + self.enable_v0_3_compat = enable_v0_3_compat + + self._v03_adapter = None + if enable_v0_3_compat: + from a2a.compat.v0_3.rest_adapter import ( + REST03Adapter as V03RESTAdapter, + ) + + self._v03_adapter = V03RESTAdapter( + agent_card=agent_card, + http_handler=http_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + ) + + self.router = APIRouter() + self._setup_router(agent_card_url, rpc_url) + + def _build_call_context(self, request: Request) -> ServerCallContext: + call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context + + def _setup_router( + self, + agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + rpc_url: str = '', + **kwargs: Any, + ) -> None: + """Builds and returns the FastAPI application instance.""" + + if self.enable_v0_3_compat and self._v03_adapter: + v03_router = APIRouter() + for route, callback in self._v03_adapter.routes().items(): + v03_router.add_api_route( + f'{rpc_url}{route[0]}', callback, methods=[route[1]] + ) + self.router.include_router(v03_router) + + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = {} + + async def message_send(request: Request) -> Response: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + result = await self.rest_handler.on_message_send( + params, context + ) + return JSONResponse(result) + + base_routes[('/message:send', 'POST')] = message_send + + async def message_stream(request: Request) -> EventSourceResponse: + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + + async def event_generator( + stream: AsyncIterator[dict[str, Any]], + ) -> AsyncIterator[str]: + async for item in stream: + yield json.dumps(item) + + return EventSourceResponse( + event_generator( + self.rest_handler.on_message_send_stream( + params, context + ) + ) + ) + + base_routes[('/message:stream', 'POST')] = message_stream + + async def cancel_task(request: Request) -> Response: + task_id = request.path_params['id'] + params = a2a_pb2.CancelTaskRequest(id=task_id) + context = self._build_call_context(request) + result = await self.rest_handler.on_cancel_task(params, context) + return JSONResponse(result) + + base_routes[('/tasks/{id}:cancel', 'POST')] = cancel_task + + async def subscribe_task(request: Request) -> EventSourceResponse: + import contextlib # noqa: PLC0415 + with contextlib.suppress(ValueError, RuntimeError, OSError): + await request.body() + task_id = request.path_params['id'] + params = a2a_pb2.SubscribeToTaskRequest(id=task_id) + context = self._build_call_context(request) + + async def event_generator( + stream: AsyncIterator[dict[str, Any]], + ) -> AsyncIterator[str]: + async for item in stream: + yield json.dumps(item) + + return EventSourceResponse( + event_generator( + self.rest_handler.on_subscribe_to_task(params, context) + ) + ) + + base_routes[('/tasks/{id}:subscribe', 'GET')] = subscribe_task + base_routes[('/tasks/{id}:subscribe', 'POST')] = subscribe_task + + async def get_task(request: Request) -> Response: + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + context = self._build_call_context(request) + result = await self.rest_handler.on_get_task(params, context) + return JSONResponse(result) + + base_routes[('/tasks/{id}', 'GET')] = get_task + + async def get_push_notification(request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.GetTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + result = await self.rest_handler.get_push_notification( + params, context + ) + return JSONResponse(result) + + base_routes[('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET')] = get_push_notification + + async def delete_push_notification(request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + result = await self.rest_handler.delete_push_notification( + params, context + ) + return JSONResponse(result) + + base_routes[('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE')] = delete_push_notification + + async def set_push_notification(request: Request) -> Response: + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + result = await self.rest_handler.set_push_notification( + params, context + ) + return JSONResponse(result) + + base_routes[('/tasks/{id}/pushNotificationConfigs', 'POST')] = set_push_notification + + async def list_push_notifications(request: Request) -> Response: + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + result = await self.rest_handler.list_push_notifications( + params, context + ) + return JSONResponse(result) + + base_routes[('/tasks/{id}/pushNotificationConfigs', 'GET')] = list_push_notifications + + async def list_tasks(request: Request) -> Response: + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + context = self._build_call_context(request) + result = await self.rest_handler.list_tasks(params, context) + return JSONResponse(result) + + base_routes[('/tasks', 'GET')] = list_tasks + + if self.agent_card.capabilities.extended_agent_card: + async def get_extended_agent_card(request: Request) -> Response: + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = self.extended_agent_card or self.agent_card + + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + + return JSONResponse( + MessageToDict( + card_to_serve, preserving_proto_field_name=True + ) + ) + + base_routes[('/extendedAgentCard', 'GET')] = get_extended_agent_card + + routes: dict[tuple[str, str], Callable[[Request], Any]] = { + (p, method): handler + for (path, method), handler in base_routes.items() + for p in (path, f'/{{tenant}}{path}') + } + + for (path, method), handler in routes.items(): + self.router.add_api_route( + f'{rpc_url}{path}', handler, methods=[method] + ) + diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 7196b828..bb31975e 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -12,6 +12,7 @@ import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc import a2a.types.a2a_pb2_grpc as a2a_grpc +from fastapi import FastAPI from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor @@ -20,6 +21,7 @@ A2ARESTFastAPIApplication, A2AStarletteApplication, ) +from a2a.server.router import JsonRpcRouter, RestRouter, AgentCardRouter from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, @@ -44,6 +46,7 @@ JSONRPC_URL = '/a2a/jsonrpc' REST_URL = '/a2a/rest' +AGENT_CARD_URL = '/.well-known/agent-card.json' logging.basicConfig(level=logging.INFO) logger = logging.getLogger('SUTAgent') @@ -196,22 +199,27 @@ def serve(task_store: TaskStore) -> None: task_store=task_store, ) - main_app = Starlette() + main_app = FastAPI() + + # Agent Card + agent_card_router = AgentCardRouter( + agent_card=agent_card, + ) + main_app.include_router(agent_card_router.router) # JSONRPC - jsonrpc_server = A2AStarletteApplication( + jsonrpc_router = JsonRpcRouter( agent_card=agent_card, http_handler=request_handler, ) - jsonrpc_server.add_routes_to_app(main_app, rpc_url=JSONRPC_URL) + main_app.include_router(jsonrpc_router.router, prefix=JSONRPC_URL) # REST - rest_server = A2ARESTFastAPIApplication( + rest_router = RestRouter( agent_card=agent_card, http_handler=request_handler, ) - rest_app = rest_server.build(rpc_url=REST_URL) - main_app.mount('', rest_app) + main_app.include_router(rest_router.router, prefix=REST_URL) config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' From 749a760564ee49fefcec29d04e5260598474ba75 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 18 Mar 2026 11:30:26 +0000 Subject: [PATCH 02/17] wip --- src/a2a/server/router/__init__.py | 6 +- src/a2a/server/router/agent_card_router.py | 13 +- src/a2a/server/router/jsonrpc_dispatcher.py | 183 +++++++++-------- src/a2a/server/router/jsonrpc_router.py | 52 ++--- src/a2a/server/router/rest_router.py | 217 ++++++++------------ tck/sut_agent.py | 9 +- 6 files changed, 212 insertions(+), 268 deletions(-) diff --git a/src/a2a/server/router/__init__.py b/src/a2a/server/router/__init__.py index d7bdd5d7..bbd27494 100644 --- a/src/a2a/server/router/__init__.py +++ b/src/a2a/server/router/__init__.py @@ -1,14 +1,14 @@ """A2A JSON-RPC Applications.""" -from a2a.server.router.jsonrpc_router import JsonRpcRouter -from a2a.server.router.rest_router import RestRouter -from a2a.server.router.agent_card_router import AgentCardRouter from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, DefaultCallContextBuilder, StarletteUserProxy, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.router.agent_card_router import AgentCardRouter +from a2a.server.router.jsonrpc_router import JsonRpcRouter +from a2a.server.router.rest_router import RestRouter __all__ = [ diff --git a/src/a2a/server/router/agent_card_router.py b/src/a2a/server/router/agent_card_router.py index 3ff57b6f..2055d8d3 100644 --- a/src/a2a/server/router/agent_card_router.py +++ b/src/a2a/server/router/agent_card_router.py @@ -5,16 +5,16 @@ if TYPE_CHECKING: - from fastapi import APIRouter from starlette.requests import Request from starlette.responses import JSONResponse, Response + from starlette.routing import Router else: try: - from fastapi import APIRouter from starlette.requests import Request from starlette.responses import JSONResponse, Response + from starlette.routing import Router except ImportError: - APIRouter = Any + Router = Any Request = Any Response = Any JSONResponse = Any @@ -51,7 +51,7 @@ def __init__( self.agent_card = agent_card self.card_modifier = card_modifier - self.router = APIRouter() + self.router = Router() self._setup_router(agent_card_url, rpc_url) def _setup_router( @@ -66,7 +66,6 @@ def _setup_router( rpc_url: The URL prefix for the endpoint. """ - @self.router.get(f'{rpc_url}{agent_card_url}') async def get_agent_card(request: Request) -> Response: card_to_serve = self.agent_card if self.card_modifier: @@ -74,3 +73,7 @@ async def get_agent_card(request: Request) -> Response: self.card_modifier(card_to_serve) ) return JSONResponse(agent_card_to_dict(card_to_serve)) + + self.router.add_route( + f'{rpc_url}{agent_card_url}', get_agent_card, methods=['GET'] + ) diff --git a/src/a2a/server/router/jsonrpc_dispatcher.py b/src/a2a/server/router/jsonrpc_dispatcher.py index f2332abd..fba7b381 100644 --- a/src/a2a/server/router/jsonrpc_dispatcher.py +++ b/src/a2a/server/router/jsonrpc_dispatcher.py @@ -9,8 +9,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import ParseDict -from jsonrpc.jsonrpc2 import JSONRPC20Request +from google.protobuf.json_format import MessageToDict, ParseDict +from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser @@ -28,10 +28,8 @@ JSONRPCError, MethodNotFoundError, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, build_error_response, ) from a2a.types import A2ARequest @@ -45,11 +43,16 @@ ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageRequest, + SendMessageResponse, SubscribeToTaskRequest, + Task, TaskPushNotificationConfig, ) +from a2a.utils import proto_utils from a2a.utils.errors import ( A2AError, + ExtendedAgentCardNotConfiguredError, + TaskNotFoundError, UnsupportedOperationError, ) from a2a.utils.helpers import maybe_await @@ -234,12 +237,7 @@ def __init__( # noqa: PLR0913 self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self.handler = JSONRPCHandler( - agent_card=agent_card, - request_handler=http_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=extended_card_modifier, - ) + self.http_handler = http_handler self._context_builder = context_builder or DefaultCallContextBuilder() self._max_content_length = max_content_length self.enable_v0_3_compat = enable_v0_3_compat @@ -419,15 +417,22 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 call_context.state['method'] = method call_context.state['request_id'] = request_id + # Route streaming requests by method name # Route streaming requests by method name if method in ('SendStreamingMessage', 'SubscribeToTask'): - return await self._process_streaming_request( + handler_result = await self._process_streaming_request( request_id, specific_request, call_context ) + else: + try: + raw_result = await self._process_non_streaming_request( + request_id, specific_request, call_context + ) + handler_result = JSONRPC20Response(result=raw_result, _id=request_id).data + except Exception as e: + return self._generate_error_response(request_id, e) - return await self._process_non_streaming_request( - request_id, specific_request, call_context - ) + return self._create_response(call_context, handler_result) except json.decoder.JSONDecodeError as e: traceback.print_exc() return self._generate_error_response( @@ -451,7 +456,7 @@ async def _process_streaming_request( request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: + ) -> AsyncGenerator[dict[str, Any], None]: """Processes streaming requests (SendStreamingMessage or SubscribeToTask). Args: @@ -460,31 +465,43 @@ async def _process_streaming_request( context: The ServerCallContext for the request. Returns: - An `EventSourceResponse` object to stream results to the client. + An async generator yielding JSON-RPC response dicts. """ - handler_result: Any = None - # Check for streaming message request (same type as SendMessage, but handled differently) - if isinstance( - request_obj, - SendMessageRequest, - ): - handler_result = self.handler.on_message_send_stream( - request_obj, context + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' ) + + stream: AsyncGenerator | None = None + if isinstance(request_obj, SendMessageRequest): + stream = self.http_handler.on_message_send_stream(request_obj, context) elif isinstance(request_obj, SubscribeToTaskRequest): - handler_result = self.handler.on_subscribe_to_task( - request_obj, context - ) + stream = self.http_handler.on_subscribe_to_task(request_obj, context) + + if stream is None: + raise UnsupportedOperationError(message='Stream not supported') + + async def _wrap_stream(st: AsyncGenerator) -> AsyncGenerator[dict[str, Any], None]: + try: + async for event in st: + stream_response = proto_utils.to_stream_response(event) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + yield JSONRPC20Response(result=result, _id=request_id).data + except Exception as e: + error = e if isinstance(e, A2AError) else InternalError(message=str(e)) + yield build_error_response(request_id, error) - return self._create_response(context, handler_result) + return _wrap_stream(stream) async def _process_non_streaming_request( self, request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: - """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). + ) -> Any: + """Processes non-streaming requests and returns the raw result data. Args: request_id: The ID of the request. @@ -492,71 +509,80 @@ async def _process_non_streaming_request( context: The ServerCallContext for the request. Returns: - A `JSONResponse` object containing the result or error. + The raw response object or dict, to be wrapped in a success response. """ - handler_result: Any = None match request_obj: case SendMessageRequest(): - handler_result = await self.handler.on_message_send( + task_or_message = await self.http_handler.on_message_send( request_obj, context ) + if isinstance(task_or_message, Task): + msg_response = SendMessageResponse(task=task_or_message) + else: + msg_response = SendMessageResponse(message=task_or_message) + return MessageToDict(msg_response) case CancelTaskRequest(): - handler_result = await self.handler.on_cancel_task( - request_obj, context - ) + task = await self.http_handler.on_cancel_task(request_obj, context) + if not task: + raise TaskNotFoundError() + return MessageToDict(task, preserving_proto_field_name=False) case GetTaskRequest(): - handler_result = await self.handler.on_get_task( - request_obj, context - ) + task = await self.http_handler.on_get_task(request_obj, context) + if not task: + raise TaskNotFoundError() + return MessageToDict(task, preserving_proto_field_name=False) case ListTasksRequest(): - handler_result = await self.handler.list_tasks( - request_obj, context + tasks_response = await self.http_handler.on_list_tasks(request_obj, context) + return MessageToDict( + tasks_response, + preserving_proto_field_name=False, + always_print_fields_with_no_presence=True, ) case TaskPushNotificationConfig(): - handler_result = ( - await self.handler.set_push_notification_config( - request_obj, - context, + if not self.agent_card.capabilities.push_notifications: + raise UnsupportedOperationError( + message='Push notifications are not supported by the agent' ) + result_config = await self.http_handler.on_create_task_push_notification_config( + request_obj, context ) + return MessageToDict(result_config, preserving_proto_field_name=False) case GetTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.get_push_notification_config( - request_obj, - context, - ) + config = await self.http_handler.on_get_task_push_notification_config( + request_obj, context ) + return MessageToDict(config, preserving_proto_field_name=False) case ListTaskPushNotificationConfigsRequest(): - handler_result = ( - await self.handler.list_push_notification_configs( - request_obj, - context, - ) + list_push_response = await self.http_handler.on_list_task_push_notification_configs( + request_obj, context ) + return MessageToDict(list_push_response, preserving_proto_field_name=False) case DeleteTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.delete_push_notification_config( - request_obj, - context, - ) + await self.http_handler.on_delete_task_push_notification_config( + request_obj, context ) + return None case GetExtendedAgentCardRequest(): - handler_result = ( - await self.handler.get_authenticated_extended_card( - request_obj, - context, + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='The agent does not have an extended agent card configured' ) - ) + base_card = self.extended_agent_card or self.agent_card + card_to_serve = base_card + if self.extended_card_modifier and context: + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(base_card)) + return MessageToDict(card_to_serve, preserving_proto_field_name=False) case _: logger.error( 'Unhandled validated request type: %s', type(request_obj) ) - error = UnsupportedOperationError( + raise UnsupportedOperationError( message=f'Request type {type(request_obj).__name__} is unknown.' ) - return self._generate_error_response(request_id, error) - - return self._create_response(context, handler_result) def _create_response( self, @@ -594,22 +620,3 @@ async def event_generator( # handler_result is a dict (JSON-RPC response) return JSONResponse(handler_result, headers=headers) - - async def _handle_get_agent_card(self, request: Request) -> JSONResponse: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return JSONResponse( - agent_card_to_dict( - card_to_serve, - ) - ) diff --git a/src/a2a/server/router/jsonrpc_router.py b/src/a2a/server/router/jsonrpc_router.py index 0e34849b..4cec9522 100644 --- a/src/a2a/server/router/jsonrpc_router.py +++ b/src/a2a/server/router/jsonrpc_router.py @@ -5,19 +5,18 @@ if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI + from starlette.routing import Router - _package_fastapi_installed = True + _package_starlette_installed = True else: try: - from fastapi import APIRouter, FastAPI + from starlette.routing import Router - _package_fastapi_installed = True + _package_starlette_installed = True except ImportError: - APIRouter = Any - FastAPI = Any + Router = Any - _package_fastapi_installed = False + _package_starlette_installed = False from a2a.server.context import ServerCallContext @@ -27,10 +26,6 @@ JsonRpcDispatcher, ) from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) logger = logging.getLogger(__name__) @@ -56,9 +51,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB enable_v0_3_compat: bool = False, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, rpc_url: str = '', ) -> None: """Initializes the A2AFastAPIApplication. @@ -72,18 +65,14 @@ def __init__( # noqa: PLR0913 context_builder: The CallContextBuilder used to construct the ServerCallContext passed to the http_handler. If None, no ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ - if not _package_fastapi_installed: + if not _package_starlette_installed: raise ImportError( - 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' + 'The `starlette` package is required to use the `JsonRpcRouter`.' ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) @@ -94,15 +83,13 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, enable_v0_3_compat=enable_v0_3_compat, ) - self.router = APIRouter() - self._setup_router(agent_card_url, rpc_url) + self.router = Router() + self._setup_router(rpc_url) def _setup_router( self, - agent_card_url: str, rpc_url: str, ) -> None: """Configures the APIRouter with the A2A endpoints. @@ -111,19 +98,8 @@ def _setup_router( agent_card_url: The URL for the agent card endpoint. rpc_url: The URL for the A2A JSON-RPC endpoint. """ - self.router.post( + self.router.add_route( rpc_url, - openapi_extra={ - 'requestBody': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/A2ARequest' - } - } - }, - 'required': True, - 'description': 'A2ARequest', - } - }, - )(self.dispatcher._handle_requests) + self.dispatcher._handle_requests, + methods=['POST'] + ) diff --git a/src/a2a/server/router/rest_router.py b/src/a2a/server/router/rest_router.py index 97fc4798..f528f437 100644 --- a/src/a2a/server/router/rest_router.py +++ b/src/a2a/server/router/rest_router.py @@ -5,32 +5,31 @@ if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI from sse_starlette.sse import EventSourceResponse from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response + from starlette.routing import Router - _package_fastapi_installed = True + _package_starlette_installed = True else: try: - from fastapi import APIRouter, FastAPI from sse_starlette.sse import EventSourceResponse from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response + from starlette.routing import Router - _package_fastapi_installed = True + _package_starlette_installed = True except ImportError: - APIRouter = Any - FastAPI = Any + Router = Any EventSourceResponse = Any Request = Any JSONResponse = Any Response = Any StarletteHTTPException = Any - _package_fastapi_installed = False + _package_starlette_installed = False import json @@ -40,20 +39,16 @@ CallContextBuilder, DefaultCallContextBuilder, ) -from a2a.server.apps.rest.fastapi_app import _HTTP_TO_GRPC_STATUS_MAP from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler_v2 import RESTHandlerV2 from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) from a2a.utils.errors import ( ExtendedAgentCardNotConfiguredError, InvalidRequestError, + TaskNotFoundError, + UnsupportedOperationError, ) from a2a.utils.helpers import maybe_await @@ -81,9 +76,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB enable_v0_3_compat: bool = False, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, rpc_url: str = '', ) -> None: """Initializes the A2AFastAPIApplication. @@ -102,27 +95,21 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ - if not _package_fastapi_installed: + if not _package_starlette_installed: raise ImportError( - 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' + 'The `starlette` package is required to use the `RestRouter`.' ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) self.agent_card = agent_card self.http_handler = http_handler - self.rest_handler = RESTHandlerV2( - agent_card=agent_card, request_handler=http_handler - ) self.extended_agent_card = extended_agent_card self._context_builder = context_builder or DefaultCallContextBuilder() self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self.max_content_length = max_content_length self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter = None @@ -138,8 +125,8 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, ) - self.router = APIRouter() - self._setup_router(agent_card_url, rpc_url) + self.router = Router() + self._setup_router(rpc_url) def _build_call_context(self, request: Request) -> ServerCallContext: call_context = self._context_builder.build(request) @@ -149,33 +136,27 @@ def _build_call_context(self, request: Request) -> ServerCallContext: def _setup_router( self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', + rpc_url, **kwargs: Any, ) -> None: """Builds and returns the FastAPI application instance.""" - if self.enable_v0_3_compat and self._v03_adapter: - v03_router = APIRouter() for route, callback in self._v03_adapter.routes().items(): - v03_router.add_api_route( + self.router.add_route( f'{rpc_url}{route[0]}', callback, methods=[route[1]] ) - self.router.include_router(v03_router) - - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = {} async def message_send(request: Request) -> Response: body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) context = self._build_call_context(request) - result = await self.rest_handler.on_message_send( - params, context - ) - return JSONResponse(result) - - base_routes[('/message:send', 'POST')] = message_send + task_or_message = await self.http_handler.on_message_send(params, context) + if isinstance(task_or_message, a2a_pb2.Task): + response = a2a_pb2.SendMessageResponse(task=task_or_message) + else: + response = a2a_pb2.SendMessageResponse(message=task_or_message) + return JSONResponse(MessageToDict(response)) async def message_stream(request: Request) -> EventSourceResponse: try: @@ -185,68 +166,54 @@ async def message_stream(request: Request) -> EventSourceResponse: message=f'Failed to pre-consume request body: {e}' ) from e + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError(message='Streaming is not supported by the agent') + body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) context = self._build_call_context(request) - async def event_generator( - stream: AsyncIterator[dict[str, Any]], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse( - event_generator( - self.rest_handler.on_message_send_stream( - params, context - ) - ) - ) + async def event_generator() -> AsyncIterator[str]: + async for event in self.http_handler.on_message_send_stream(params, context): + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) - base_routes[('/message:stream', 'POST')] = message_stream + return EventSourceResponse(event_generator()) async def cancel_task(request: Request) -> Response: task_id = request.path_params['id'] params = a2a_pb2.CancelTaskRequest(id=task_id) context = self._build_call_context(request) - result = await self.rest_handler.on_cancel_task(params, context) - return JSONResponse(result) - - base_routes[('/tasks/{id}:cancel', 'POST')] = cancel_task + task = await self.http_handler.on_cancel_task(params, context) + if not task: + raise TaskNotFoundError() + return JSONResponse(MessageToDict(task)) async def subscribe_task(request: Request) -> EventSourceResponse: import contextlib # noqa: PLC0415 with contextlib.suppress(ValueError, RuntimeError, OSError): await request.body() task_id = request.path_params['id'] + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError(message='Streaming is not supported by the agent') params = a2a_pb2.SubscribeToTaskRequest(id=task_id) context = self._build_call_context(request) - async def event_generator( - stream: AsyncIterator[dict[str, Any]], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) + async def event_generator() -> AsyncIterator[str]: + async for event in self.http_handler.on_subscribe_to_task(params, context): + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) - return EventSourceResponse( - event_generator( - self.rest_handler.on_subscribe_to_task(params, context) - ) - ) - - base_routes[('/tasks/{id}:subscribe', 'GET')] = subscribe_task - base_routes[('/tasks/{id}:subscribe', 'POST')] = subscribe_task + return EventSourceResponse(event_generator()) async def get_task(request: Request) -> Response: params = a2a_pb2.GetTaskRequest() proto_utils.parse_params(request.query_params, params) params.id = request.path_params['id'] context = self._build_call_context(request) - result = await self.rest_handler.on_get_task(params, context) - return JSONResponse(result) - - base_routes[('/tasks/{id}', 'GET')] = get_task + task = await self.http_handler.on_get_task(params, context) + if not task: + raise TaskNotFoundError() + return JSONResponse(MessageToDict(task)) async def get_push_notification(request: Request) -> Response: task_id = request.path_params['id'] @@ -255,12 +222,8 @@ async def get_push_notification(request: Request) -> Response: task_id=task_id, id=push_id ) context = self._build_call_context(request) - result = await self.rest_handler.get_push_notification( - params, context - ) - return JSONResponse(result) - - base_routes[('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET')] = get_push_notification + config = await self.http_handler.on_get_task_push_notification_config(params, context) + return JSONResponse(MessageToDict(config)) async def delete_push_notification(request: Request) -> Response: task_id = request.path_params['id'] @@ -269,72 +232,72 @@ async def delete_push_notification(request: Request) -> Response: task_id=task_id, id=push_id ) context = self._build_call_context(request) - result = await self.rest_handler.delete_push_notification( - params, context - ) - return JSONResponse(result) - - base_routes[('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE')] = delete_push_notification + await self.http_handler.on_delete_task_push_notification_config(params, context) + return JSONResponse({}) async def set_push_notification(request: Request) -> Response: + if not self.agent_card.capabilities.push_notifications: + raise UnsupportedOperationError(message='Push notifications are not supported by the agent') body = await request.body() params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) params.task_id = request.path_params['id'] context = self._build_call_context(request) - result = await self.rest_handler.set_push_notification( - params, context - ) - return JSONResponse(result) - - base_routes[('/tasks/{id}/pushNotificationConfigs', 'POST')] = set_push_notification + config = await self.http_handler.on_create_task_push_notification_config(params, context) + return JSONResponse(MessageToDict(config)) async def list_push_notifications(request: Request) -> Response: params = a2a_pb2.ListTaskPushNotificationConfigsRequest() proto_utils.parse_params(request.query_params, params) params.task_id = request.path_params['id'] context = self._build_call_context(request) - result = await self.rest_handler.list_push_notifications( - params, context - ) - return JSONResponse(result) - - base_routes[('/tasks/{id}/pushNotificationConfigs', 'GET')] = list_push_notifications + result = await self.http_handler.on_list_task_push_notification_configs(params, context) + return JSONResponse(MessageToDict(result)) async def list_tasks(request: Request) -> Response: params = a2a_pb2.ListTasksRequest() proto_utils.parse_params(request.query_params, params) context = self._build_call_context(request) - result = await self.rest_handler.list_tasks(params, context) - return JSONResponse(result) - - base_routes[('/tasks', 'GET')] = list_tasks - - if self.agent_card.capabilities.extended_agent_card: - async def get_extended_agent_card(request: Request) -> Response: - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card or self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await( - self.card_modifier(card_to_serve) - ) - - return JSONResponse( - MessageToDict( - card_to_serve, preserving_proto_field_name=True - ) + result = await self.http_handler.on_list_tasks(params, context) + return JSONResponse(MessageToDict(result, always_print_fields_with_no_presence=True)) + + async def get_extended_agent_card(request: Request) -> Response: + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' ) + card_to_serve = self.extended_agent_card or self.agent_card - base_routes[('/extendedAgentCard', 'GET')] = get_extended_agent_card + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + + return JSONResponse( + MessageToDict( + card_to_serve, preserving_proto_field_name=True + ) + ) + + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { + ('/message:send', 'POST'): message_send, + ('/message:stream', 'POST'): message_stream, + ('/tasks/{id}:cancel', 'POST'): cancel_task, + ('/tasks/{id}:subscribe', 'GET'): subscribe_task, + ('/tasks/{id}:subscribe', 'POST'): subscribe_task, + ('/tasks/{id}', 'GET'): get_task, + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): get_push_notification, + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): delete_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'POST'): set_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'GET'): list_push_notifications, + ('/tasks', 'GET'): list_tasks, + ('/extendedAgentCard', 'GET'): get_extended_agent_card, + } routes: dict[tuple[str, str], Callable[[Request], Any]] = { (p, method): handler @@ -343,7 +306,7 @@ async def get_extended_agent_card(request: Request) -> Response: } for (path, method), handler in routes.items(): - self.router.add_api_route( + self.router.add_route( f'{rpc_url}{path}', handler, methods=[method] ) diff --git a/tck/sut_agent.py b/tck/sut_agent.py index bb31975e..00bde9bb 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -8,25 +8,20 @@ import grpc.aio import uvicorn -from starlette.applications import Starlette +from fastapi import FastAPI import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc import a2a.types.a2a_pb2_grpc as a2a_grpc -from fastapi import FastAPI from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import ( - A2ARESTFastAPIApplication, - A2AStarletteApplication, -) -from a2a.server.router import JsonRpcRouter, RestRouter, AgentCardRouter from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) from a2a.server.request_handlers.grpc_handler import GrpcHandler +from a2a.server.router import AgentCardRouter, JsonRpcRouter, RestRouter from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ( From b491c52739c45db59e3863153bfb037c81b3540e Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 18 Mar 2026 14:09:54 +0000 Subject: [PATCH 03/17] wip --- src/a2a/server/router/agent_card_router.py | 32 +- src/a2a/server/router/jsonrpc_router.py | 48 +-- src/a2a/server/router/rest_router.py | 331 +++++++++++---------- tck/sut_agent.py | 16 +- 4 files changed, 203 insertions(+), 224 deletions(-) diff --git a/src/a2a/server/router/agent_card_router.py b/src/a2a/server/router/agent_card_router.py index 2055d8d3..04e280f7 100644 --- a/src/a2a/server/router/agent_card_router.py +++ b/src/a2a/server/router/agent_card_router.py @@ -7,21 +7,20 @@ if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router + from starlette.routing import Route else: try: from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router + from starlette.routing import Route except ImportError: - Router = Any + Route = Any Request = Any Response = Any JSONResponse = Any from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH from a2a.utils.helpers import maybe_await @@ -36,8 +35,7 @@ def __init__( agent_card: AgentCard, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', + card_url: str = '/', ) -> None: """Initializes the AgentCardRouter. @@ -45,27 +43,11 @@ def __init__( agent_card: The AgentCard describing the agent's capabilities. card_modifier: An optional callback to dynamically modify the public agent card before it is served. - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL prefix for the endpoint. + card_url: The URL for the agent card endpoint. """ self.agent_card = agent_card self.card_modifier = card_modifier - self.router = Router() - self._setup_router(agent_card_url, rpc_url) - - def _setup_router( - self, - agent_card_url: str, - rpc_url: str, - ) -> None: - """Configures the APIRouter with the A2A endpoints. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL prefix for the endpoint. - """ - async def get_agent_card(request: Request) -> Response: card_to_serve = self.agent_card if self.card_modifier: @@ -74,6 +56,4 @@ async def get_agent_card(request: Request) -> Response: ) return JSONResponse(agent_card_to_dict(card_to_serve)) - self.router.add_route( - f'{rpc_url}{agent_card_url}', get_agent_card, methods=['GET'] - ) + self.route = Route(path=card_url, endpoint=get_agent_card, methods=['GET']) diff --git a/src/a2a/server/router/jsonrpc_router.py b/src/a2a/server/router/jsonrpc_router.py index 4cec9522..e16c6158 100644 --- a/src/a2a/server/router/jsonrpc_router.py +++ b/src/a2a/server/router/jsonrpc_router.py @@ -31,8 +31,11 @@ logger = logging.getLogger(__name__) +from starlette.routing import Route + + class JsonRpcRouter: - """A FastAPI application implementing the A2A protocol server endpoints. + """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. Handles incoming JSON-RPC requests, routes them to the appropriate handler methods, and manages response generation including Server-Sent Events @@ -42,7 +45,7 @@ class JsonRpcRouter: def __init__( # noqa: PLR0913 self, agent_card: AgentCard, - http_handler: RequestHandler, + request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -52,23 +55,11 @@ def __init__( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = '', + rpc_url: str = '/', ) -> None: - """Initializes the A2AFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + """Initializes the JsonRpcRouter. + + ... (docstrings remain the same) ... """ if not _package_starlette_installed: raise ImportError( @@ -76,30 +67,19 @@ def __init__( # noqa: PLR0913 ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) + self.dispatcher = JsonRpcDispatcher( agent_card=agent_card, - http_handler=http_handler, + http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, enable_v0_3_compat=enable_v0_3_compat, ) - self.router = Router() - self._setup_router(rpc_url) - def _setup_router( - self, - rpc_url: str, - ) -> None: - """Configures the APIRouter with the A2A endpoints. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - """ - self.router.add_route( - rpc_url, - self.dispatcher._handle_requests, + self.route = Route( + path=rpc_url, + endpoint=self.dispatcher._handle_requests, methods=['POST'] ) diff --git a/src/a2a/server/router/rest_router.py b/src/a2a/server/router/rest_router.py index f528f437..cf31231a 100644 --- a/src/a2a/server/router/rest_router.py +++ b/src/a2a/server/router/rest_router.py @@ -44,6 +44,10 @@ from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) from a2a.utils.errors import ( ExtendedAgentCardNotConfiguredError, InvalidRequestError, @@ -67,7 +71,7 @@ class RestRouter: def __init__( # noqa: PLR0913 self, agent_card: AgentCard, - http_handler: RequestHandler, + request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -83,12 +87,12 @@ def __init__( # noqa: PLR0913 Args: agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A + httpr: The handler instance responsible for processing A2A requests via http. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no + ServerCallContext passed to the httpr. If None, no ServerCallContext is passed. card_modifier: An optional callback to dynamically modify the public agent card before it is served. @@ -105,7 +109,7 @@ def __init__( # noqa: PLR0913 ) self.agent_card = agent_card - self.http_handler = http_handler + self.request_handler = request_handler self.extended_agent_card = extended_agent_card self._context_builder = context_builder or DefaultCallContextBuilder() self.card_modifier = card_modifier @@ -120,7 +124,7 @@ def __init__( # noqa: PLR0913 self._v03_adapter = V03RESTAdapter( agent_card=agent_card, - http_handler=http_handler, + httpr=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, ) @@ -136,8 +140,7 @@ def _build_call_context(self, request: Request) -> ServerCallContext: def _setup_router( self, - rpc_url, - **kwargs: Any, + rpc_url: str, ) -> None: """Builds and returns the FastAPI application instance.""" if self.enable_v0_3_compat and self._v03_adapter: @@ -146,157 +149,19 @@ def _setup_router( f'{rpc_url}{route[0]}', callback, methods=[route[1]] ) - async def message_send(request: Request) -> Response: - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - context = self._build_call_context(request) - task_or_message = await self.http_handler.on_message_send(params, context) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return JSONResponse(MessageToDict(response)) - - async def message_stream(request: Request) -> EventSourceResponse: - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - if not self.agent_card.capabilities.streaming: - raise UnsupportedOperationError(message='Streaming is not supported by the agent') - - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - context = self._build_call_context(request) - - async def event_generator() -> AsyncIterator[str]: - async for event in self.http_handler.on_message_send_stream(params, context): - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) - - return EventSourceResponse(event_generator()) - - async def cancel_task(request: Request) -> Response: - task_id = request.path_params['id'] - params = a2a_pb2.CancelTaskRequest(id=task_id) - context = self._build_call_context(request) - task = await self.http_handler.on_cancel_task(params, context) - if not task: - raise TaskNotFoundError() - return JSONResponse(MessageToDict(task)) - - async def subscribe_task(request: Request) -> EventSourceResponse: - import contextlib # noqa: PLC0415 - with contextlib.suppress(ValueError, RuntimeError, OSError): - await request.body() - task_id = request.path_params['id'] - if not self.agent_card.capabilities.streaming: - raise UnsupportedOperationError(message='Streaming is not supported by the agent') - params = a2a_pb2.SubscribeToTaskRequest(id=task_id) - context = self._build_call_context(request) - - async def event_generator() -> AsyncIterator[str]: - async for event in self.http_handler.on_subscribe_to_task(params, context): - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) - - return EventSourceResponse(event_generator()) - - async def get_task(request: Request) -> Response: - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - context = self._build_call_context(request) - task = await self.http_handler.on_get_task(params, context) - if not task: - raise TaskNotFoundError() - return JSONResponse(MessageToDict(task)) - - async def get_push_notification(request: Request) -> Response: - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.GetTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - context = self._build_call_context(request) - config = await self.http_handler.on_get_task_push_notification_config(params, context) - return JSONResponse(MessageToDict(config)) - - async def delete_push_notification(request: Request) -> Response: - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - context = self._build_call_context(request) - await self.http_handler.on_delete_task_push_notification_config(params, context) - return JSONResponse({}) - - async def set_push_notification(request: Request) -> Response: - if not self.agent_card.capabilities.push_notifications: - raise UnsupportedOperationError(message='Push notifications are not supported by the agent') - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - params.task_id = request.path_params['id'] - context = self._build_call_context(request) - config = await self.http_handler.on_create_task_push_notification_config(params, context) - return JSONResponse(MessageToDict(config)) - - async def list_push_notifications(request: Request) -> Response: - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - context = self._build_call_context(request) - result = await self.http_handler.on_list_task_push_notification_configs(params, context) - return JSONResponse(MessageToDict(result)) - - async def list_tasks(request: Request) -> Response: - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - context = self._build_call_context(request) - result = await self.http_handler.on_list_tasks(params, context) - return JSONResponse(MessageToDict(result, always_print_fields_with_no_presence=True)) - - async def get_extended_agent_card(request: Request) -> Response: - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card or self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await( - self.card_modifier(card_to_serve) - ) - - return JSONResponse( - MessageToDict( - card_to_serve, preserving_proto_field_name=True - ) - ) - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): message_send, - ('/message:stream', 'POST'): message_stream, - ('/tasks/{id}:cancel', 'POST'): cancel_task, - ('/tasks/{id}:subscribe', 'GET'): subscribe_task, - ('/tasks/{id}:subscribe', 'POST'): subscribe_task, - ('/tasks/{id}', 'GET'): get_task, - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): get_push_notification, - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): delete_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'POST'): set_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'GET'): list_push_notifications, - ('/tasks', 'GET'): list_tasks, - ('/extendedAgentCard', 'GET'): get_extended_agent_card, + ('/message:send', 'POST'): self._message_send, + ('/message:stream', 'POST'): self._message_stream, + ('/tasks/{id}:cancel', 'POST'): self._cancel_task, + ('/tasks/{id}:subscribe', 'GET'): self._subscribe_task, + ('/tasks/{id}:subscribe', 'POST'): self._subscribe_task, + ('/tasks/{id}', 'GET'): self._get_task, + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): self._get_push_notification, + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): self._delete_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'POST'): self._set_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'GET'): self._list_push_notifications, + ('/tasks', 'GET'): self._list_tasks, + ('/extendedAgentCard', 'GET'): self._get_extended_agent_card, } routes: dict[tuple[str, str], Callable[[Request], Any]] = { @@ -310,3 +175,155 @@ async def get_extended_agent_card(request: Request) -> Response: f'{rpc_url}{path}', handler, methods=[method] ) + @rest_error_handler + async def _message_send(self, request: Request) -> Response: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + task_or_message = await self.request_handler.on_message_send(params, context) + if isinstance(task_or_message, a2a_pb2.Task): + response = a2a_pb2.SendMessageResponse(task=task_or_message) + else: + response = a2a_pb2.SendMessageResponse(message=task_or_message) + return JSONResponse(MessageToDict(response)) + + @rest_stream_error_handler + async def _message_stream(self, request: Request) -> EventSourceResponse: + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError(message='Streaming is not supported by the agent') + + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + context = self._build_call_context(request) + + async def event_generator() -> AsyncIterator[str]: + async for event in self.request_handler.on_message_send_stream(params, context): + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + async def _cancel_task(self, request: Request) -> Response: + task_id = request.path_params['id'] + params = a2a_pb2.CancelTaskRequest(id=task_id) + context = self._build_call_context(request) + task = await self.request_handler.on_cancel_task(params, context) + if not task: + raise TaskNotFoundError + return JSONResponse(MessageToDict(task)) + + @rest_stream_error_handler + async def _subscribe_task(self, request: Request) -> EventSourceResponse: + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + task_id = request.path_params['id'] + if not self.agent_card.capabilities.streaming: + raise UnsupportedOperationError(message='Streaming is not supported by the agent') + params = a2a_pb2.SubscribeToTaskRequest(id=task_id) + context = self._build_call_context(request) + + async def event_generator() -> AsyncIterator[str]: + async for event in self.request_handler.on_subscribe_to_task(params, context): + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + async def _get_task(self, request: Request) -> Response: + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + context = self._build_call_context(request) + task = await self.request_handler.on_get_task(params, context) + if not task: + raise TaskNotFoundError + return JSONResponse(MessageToDict(task)) + + @rest_error_handler + async def _get_push_notification(self, request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.GetTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + config = await self.request_handler.on_get_task_push_notification_config(params, context) + return JSONResponse(MessageToDict(config)) + + @rest_error_handler + async def _delete_push_notification(self, request: Request) -> Response: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + context = self._build_call_context(request) + await self.request_handler.on_delete_task_push_notification_config(params, context) + return JSONResponse({}) + + @rest_error_handler + async def _set_push_notification(self, request: Request) -> Response: + if not self.agent_card.capabilities.push_notifications: + raise UnsupportedOperationError(message='Push notifications are not supported by the agent') + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + config = await self.request_handler.on_create_task_push_notification_config(params, context) + return JSONResponse(MessageToDict(config)) + + @rest_error_handler + async def _list_push_notifications(self, request: Request) -> Response: + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + context = self._build_call_context(request) + result = await self.request_handler.on_list_task_push_notification_configs(params, context) + return JSONResponse(MessageToDict(result)) + + @rest_error_handler + async def _list_tasks(self, request: Request) -> Response: + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + context = self._build_call_context(request) + result = await self.request_handler.on_list_tasks(params, context) + return JSONResponse(MessageToDict(result, always_print_fields_with_no_presence=True)) + + @rest_error_handler + async def _get_extended_agent_card(self, request: Request) -> Response: + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = self.extended_agent_card or self.agent_card + + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + + return JSONResponse( + MessageToDict( + card_to_serve, preserving_proto_field_name=True + ) + ) + diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 00bde9bb..cedd67fe 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -8,7 +8,7 @@ import grpc.aio import uvicorn -from fastapi import FastAPI +from starlette.applications import Starlette import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc import a2a.types.a2a_pb2_grpc as a2a_grpc @@ -194,27 +194,29 @@ def serve(task_store: TaskStore) -> None: task_store=task_store, ) - main_app = FastAPI() + main_app = Starlette() # Agent Card agent_card_router = AgentCardRouter( agent_card=agent_card, + card_url=AGENT_CARD_URL, ) - main_app.include_router(agent_card_router.router) + main_app.routes.append(agent_card_router.route) # JSONRPC jsonrpc_router = JsonRpcRouter( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, + rpc_url=JSONRPC_URL, ) - main_app.include_router(jsonrpc_router.router, prefix=JSONRPC_URL) + main_app.routes.append(jsonrpc_router.route) # REST rest_router = RestRouter( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, ) - main_app.include_router(rest_router.router, prefix=REST_URL) + main_app.mount(REST_URL, rest_router.router) config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' From 2ad23bc98997ccbb59bff0198b06d7d83242007f Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 08:49:53 +0000 Subject: [PATCH 04/17] wip --- src/a2a/server/apps/__init__.py | 18 - src/a2a/server/apps/jsonrpc/__init__.py | 20 - src/a2a/server/apps/jsonrpc/fastapi_app.py | 148 ---- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 640 ------------------ src/a2a/server/apps/jsonrpc/starlette_app.py | 169 ----- src/a2a/server/apps/rest/__init__.py | 8 - src/a2a/server/apps/rest/fastapi_app.py | 194 ------ src/a2a/server/apps/rest/rest_adapter.py | 296 -------- src/a2a/server/request_handlers/__init__.py | 4 - .../request_handlers/jsonrpc_handler.py | 474 ------------- .../server/request_handlers/rest_handler.py | 321 --------- .../request_handlers/rest_handler_v2.py | 156 ----- src/a2a/server/router/__init__.py | 9 +- src/a2a/server/router/agent_card_router.py | 18 +- src/a2a/server/router/jsonrpc_router.py | 9 +- src/a2a/server/router/rest_router.py | 43 +- tck/sut_agent.py | 4 +- tests/server/apps/jsonrpc/test_fastapi_app.py | 79 --- tests/server/apps/jsonrpc/test_jsonrpc_app.py | 444 ------------ .../server/apps/jsonrpc/test_serialization.py | 280 -------- .../server/apps/jsonrpc/test_starlette_app.py | 81 --- .../server/apps/rest/test_rest_fastapi_app.py | 156 +---- tests/server/router/test_jsonrpc_router.py | 275 ++++++++ tests/server/test_integration.py | 63 +- 24 files changed, 401 insertions(+), 3508 deletions(-) delete mode 100644 src/a2a/server/apps/__init__.py delete mode 100644 src/a2a/server/apps/jsonrpc/__init__.py delete mode 100644 src/a2a/server/apps/jsonrpc/fastapi_app.py delete mode 100644 src/a2a/server/apps/jsonrpc/jsonrpc_app.py delete mode 100644 src/a2a/server/apps/jsonrpc/starlette_app.py delete mode 100644 src/a2a/server/apps/rest/__init__.py delete mode 100644 src/a2a/server/apps/rest/fastapi_app.py delete mode 100644 src/a2a/server/apps/rest/rest_adapter.py delete mode 100644 src/a2a/server/request_handlers/jsonrpc_handler.py delete mode 100644 src/a2a/server/request_handlers/rest_handler.py delete mode 100644 src/a2a/server/request_handlers/rest_handler_v2.py delete mode 100644 tests/server/apps/jsonrpc/test_fastapi_app.py delete mode 100644 tests/server/apps/jsonrpc/test_jsonrpc_app.py delete mode 100644 tests/server/apps/jsonrpc/test_serialization.py delete mode 100644 tests/server/apps/jsonrpc/test_starlette_app.py create mode 100644 tests/server/router/test_jsonrpc_router.py diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py deleted file mode 100644 index 579deaa5..00000000 --- a/src/a2a/server/apps/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""HTTP application components for the A2A server.""" - -from a2a.server.apps.jsonrpc import ( - A2AFastAPIApplication, - A2AStarletteApplication, - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.apps.rest import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2AFastAPIApplication', - 'A2ARESTFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'JSONRPCApplication', -] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py deleted file mode 100644 index 1121fdbc..00000000 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""A2A JSON-RPC Applications.""" - -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - DefaultCallContextBuilder, - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication - - -__all__ = [ - 'A2AFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'DefaultCallContextBuilder', - 'JSONRPCApplication', - 'StarletteUserProxy', -] diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py deleted file mode 100644 index 0ec9d1ab..00000000 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import FastAPI - - _package_fastapi_installed = True -else: - try: - from fastapi import FastAPI - - _package_fastapi_installed = True - except ImportError: - FastAPI = Any - - _package_fastapi_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AFastAPIApplication(JSONRPCApplication): - """A FastAPI application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' - ' It can be added as a part of `a2a-sdk` optional dependencies,' - ' `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def add_routes_to_app( - self, - app: FastAPI, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the FastAPI application. - - Args: - app: The FastAPI application to add the routes to. - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - """ - app.post( - rpc_url, - openapi_extra={ - 'requestBody': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/A2ARequest' - } - } - }, - 'required': True, - 'description': 'A2ARequest', - } - }, - )(self._handle_requests) - app.get(agent_card_url)(self._handle_get_agent_card) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py deleted file mode 100644 index 0d79b10e..00000000 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ /dev/null @@ -1,640 +0,0 @@ -"""JSON-RPC application for A2A server.""" - -import contextlib -import json -import logging -import traceback - -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import ParseDict -from jsonrpc.jsonrpc2 import JSONRPC20Request - -from a2a.auth.user import UnauthenticatedUser -from a2a.auth.user import User as A2AUser -from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter -from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, - get_requested_extensions, -) -from a2a.server.context import ServerCallContext -from a2a.server.jsonrpc_models import ( - InternalError, - InvalidParamsError, - InvalidRequestError, - JSONParseError, - JSONRPCError, - MethodNotFoundError, -) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, - build_error_response, -) -from a2a.types import A2ARequest -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTasksRequest, - SendMessageRequest, - SubscribeToTaskRequest, - TaskPushNotificationConfig, -) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) -from a2a.utils.errors import ( - A2AError, - UnsupportedOperationError, -) -from a2a.utils.helpers import maybe_await - - -INTERNAL_ERROR_CODE = -32603 - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from fastapi import FastAPI - from sse_starlette.sse import EventSourceResponse - from starlette.applications import Starlette - from starlette.authentication import BaseUser - from starlette.exceptions import HTTPException - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - try: - # Starlette v0.48.0 - from starlette.status import HTTP_413_CONTENT_TOO_LARGE - except ImportError: - from starlette.status import ( # type: ignore[no-redef] - HTTP_413_REQUEST_ENTITY_TOO_LARGE as HTTP_413_CONTENT_TOO_LARGE, - ) - - _package_starlette_installed = True -else: - FastAPI = Any - try: - from sse_starlette.sse import EventSourceResponse - from starlette.applications import Starlette - from starlette.authentication import BaseUser - from starlette.exceptions import HTTPException - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - try: - # Starlette v0.48.0 - from starlette.status import HTTP_413_CONTENT_TOO_LARGE - except ImportError: - from starlette.status import ( - HTTP_413_REQUEST_ENTITY_TOO_LARGE as HTTP_413_CONTENT_TOO_LARGE, - ) - - _package_starlette_installed = True - except ImportError: - _package_starlette_installed = False - # Provide placeholder types for runtime type hinting when dependencies are not installed. - # These will not be used if the code path that needs them is guarded by _http_server_installed. - EventSourceResponse = Any - Starlette = Any - BaseUser = Any - HTTPException = Any - Request = Any - JSONResponse = Any - Response = Any - HTTP_413_CONTENT_TOO_LARGE = Any - - -class StarletteUserProxy(A2AUser): - """Adapts the Starlette User class to the A2A user representation.""" - - def __init__(self, user: BaseUser): - self._user = user - - @property - def is_authenticated(self) -> bool: - """Returns whether the current user is authenticated.""" - return self._user.is_authenticated - - @property - def user_name(self) -> str: - """Returns the user name of the current user.""" - return self._user.display_name - - -class CallContextBuilder(ABC): - """A class for building ServerCallContexts using the Starlette Request.""" - - @abstractmethod - def build(self, request: Request) -> ServerCallContext: - """Builds a ServerCallContext from a Starlette Request.""" - - -class DefaultCallContextBuilder(CallContextBuilder): - """A default implementation of CallContextBuilder.""" - - def build(self, request: Request) -> ServerCallContext: - """Builds a ServerCallContext from a Starlette Request. - - Args: - request: The incoming Starlette Request object. - - Returns: - A ServerCallContext instance populated with user and state - information from the request. - """ - user: A2AUser = UnauthenticatedUser() - state = {} - with contextlib.suppress(Exception): - user = StarletteUserProxy(request.user) - state['auth'] = request.auth - state['headers'] = dict(request.headers) - return ServerCallContext( - user=user, - state=state, - requested_extensions=get_requested_extensions( - request.headers.getlist(HTTP_EXTENSION_HEADER) - ), - ) - - -class JSONRPCApplication(ABC): - """Base class for A2A JSONRPC applications. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - # Method-to-model mapping for centralized routing - # Proto types don't have model_fields, so we define the mapping explicitly - # Method names match gRPC service method names - METHOD_TO_MODEL: dict[str, type] = { - 'SendMessage': SendMessageRequest, - 'SendStreamingMessage': SendMessageRequest, # Same proto type as SendMessage - 'GetTask': GetTaskRequest, - 'ListTasks': ListTasksRequest, - 'CancelTask': CancelTaskRequest, - 'CreateTaskPushNotificationConfig': TaskPushNotificationConfig, - 'GetTaskPushNotificationConfig': GetTaskPushNotificationConfigRequest, - 'ListTaskPushNotificationConfigs': ListTaskPushNotificationConfigsRequest, - 'DeleteTaskPushNotificationConfig': DeleteTaskPushNotificationConfigRequest, - 'SubscribeToTask': SubscribeToTaskRequest, - 'GetExtendedAgentCard': GetExtendedAgentCardRequest, - } - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the JSONRPCApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use the' - ' `JSONRPCApplication`. They can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = JSONRPCHandler( - agent_card=agent_card, - request_handler=http_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=extended_card_modifier, - ) - self._context_builder = context_builder or DefaultCallContextBuilder() - self._max_content_length = max_content_length - self.enable_v0_3_compat = enable_v0_3_compat - self._v03_adapter: JSONRPC03Adapter | None = None - - if self.enable_v0_3_compat: - self._v03_adapter = JSONRPC03Adapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - - def _generate_error_response( - self, - request_id: str | int | None, - error: Exception | JSONRPCError | A2AError, - ) -> JSONResponse: - """Creates a Starlette JSONResponse for a JSON-RPC error. - - Logs the error based on its type. - - Args: - request_id: The ID of the request that caused the error. - error: The error object (one of the JSONRPCError types). - - Returns: - A `JSONResponse` object formatted as a JSON-RPC error response. - """ - if not isinstance(error, A2AError | JSONRPCError): - error = InternalError(message=str(error)) - - response_data = build_error_response(request_id, error) - error_info = response_data.get('error', {}) - code = error_info.get('code') - message = error_info.get('message') - data = error_info.get('data') - - log_level = logging.WARNING - if code == INTERNAL_ERROR_CODE: - log_level = logging.ERROR - - logger.log( - log_level, - "Request Error (ID: %s): Code=%s, Message='%s'%s", - request_id, - code, - message, - f', Data={data}' if data else '', - ) - return JSONResponse( - response_data, - status_code=200, - ) - - def _allowed_content_length(self, request: Request) -> bool: - """Checks if the request content length is within the allowed maximum. - - Args: - request: The incoming Starlette Request object. - - Returns: - False if the content length is larger than the allowed maximum, True otherwise. - """ - if self._max_content_length is not None: - with contextlib.suppress(ValueError): - content_length = int(request.headers.get('content-length', '0')) - if content_length and content_length > self._max_content_length: - return False - return True - - async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 - """Handles incoming POST requests to the main A2A endpoint. - - Parses the request body as JSON, validates it against A2A request types, - dispatches it to the appropriate handler method, and returns the response. - Handles JSON parsing errors, validation errors, and other exceptions, - returning appropriate JSON-RPC error responses. - - Args: - request: The incoming Starlette Request object. - - Returns: - A Starlette Response object (JSONResponse or EventSourceResponse). - - Raises: - (Implicitly handled): Various exceptions are caught and converted - into JSON-RPC error responses by this method. - """ - request_id = None - body = None - - try: - body = await request.json() - if isinstance(body, dict): - request_id = body.get('id') - # Ensure request_id is valid for JSON-RPC response (str/int/None only) - if request_id is not None and not isinstance( - request_id, str | int - ): - request_id = None - # Treat payloads lager than allowed as invalid request (-32600) before routing - if not self._allowed_content_length(request): - return self._generate_error_response( - request_id, - InvalidRequestError(message='Payload too large'), - ) - logger.debug('Request body: %s', body) - # 1) Validate base JSON-RPC structure only (-32600 on failure) - try: - base_request = JSONRPC20Request.from_data(body) - if not isinstance(base_request, JSONRPC20Request): - # Batch requests are not supported - return self._generate_error_response( - request_id, - InvalidRequestError( - message='Batch requests are not supported' - ), - ) - if body.get('jsonrpc') != '2.0': - return self._generate_error_response( - request_id, - InvalidRequestError( - message="Invalid request: 'jsonrpc' must be exactly '2.0'" - ), - ) - except Exception as e: - logger.exception('Failed to validate base JSON-RPC request') - return self._generate_error_response( - request_id, - InvalidRequestError(data=str(e)), - ) - - # 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure) - method: str | None = base_request.method - request_id = base_request._id # noqa: SLF001 - - if not method: - return self._generate_error_response( - request_id, - InvalidRequestError(message='Method is required'), - ) - - if ( - self.enable_v0_3_compat - and self._v03_adapter - and self._v03_adapter.supports_method(method) - ): - return await self._v03_adapter.handle_request( - request_id=request_id, - method=method, - body=body, - request=request, - ) - - model_class = self.METHOD_TO_MODEL.get(method) - if not model_class: - return self._generate_error_response( - request_id, MethodNotFoundError() - ) - try: - # Parse the params field into the proto message type - params = body.get('params', {}) - specific_request = ParseDict(params, model_class()) - except Exception as e: - logger.exception('Failed to parse request params') - return self._generate_error_response( - request_id, - InvalidParamsError(data=str(e)), - ) - - # 3) Build call context and wrap the request for downstream handling - call_context = self._context_builder.build(request) - call_context.tenant = getattr(specific_request, 'tenant', '') - call_context.state['method'] = method - call_context.state['request_id'] = request_id - - # Route streaming requests by method name - if method in ('SendStreamingMessage', 'SubscribeToTask'): - return await self._process_streaming_request( - request_id, specific_request, call_context - ) - - return await self._process_non_streaming_request( - request_id, specific_request, call_context - ) - except json.decoder.JSONDecodeError as e: - traceback.print_exc() - return self._generate_error_response( - None, JSONParseError(message=str(e)) - ) - except HTTPException as e: - if e.status_code == HTTP_413_CONTENT_TOO_LARGE: - return self._generate_error_response( - request_id, - InvalidRequestError(message='Payload too large'), - ) - raise e - except Exception as e: - logger.exception('Unhandled exception') - return self._generate_error_response( - request_id, InternalError(message=str(e)) - ) - - async def _process_streaming_request( - self, - request_id: str | int | None, - request_obj: A2ARequest, - context: ServerCallContext, - ) -> Response: - """Processes streaming requests (SendStreamingMessage or SubscribeToTask). - - Args: - request_id: The ID of the request. - request_obj: The proto request message. - context: The ServerCallContext for the request. - - Returns: - An `EventSourceResponse` object to stream results to the client. - """ - handler_result: Any = None - # Check for streaming message request (same type as SendMessage, but handled differently) - if isinstance( - request_obj, - SendMessageRequest, - ): - handler_result = self.handler.on_message_send_stream( - request_obj, context - ) - elif isinstance(request_obj, SubscribeToTaskRequest): - handler_result = self.handler.on_subscribe_to_task( - request_obj, context - ) - - return self._create_response(context, handler_result) - - async def _process_non_streaming_request( - self, - request_id: str | int | None, - request_obj: A2ARequest, - context: ServerCallContext, - ) -> Response: - """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). - - Args: - request_id: The ID of the request. - request_obj: The proto request message. - context: The ServerCallContext for the request. - - Returns: - A `JSONResponse` object containing the result or error. - """ - handler_result: Any = None - match request_obj: - case SendMessageRequest(): - handler_result = await self.handler.on_message_send( - request_obj, context - ) - case CancelTaskRequest(): - handler_result = await self.handler.on_cancel_task( - request_obj, context - ) - case GetTaskRequest(): - handler_result = await self.handler.on_get_task( - request_obj, context - ) - case ListTasksRequest(): - handler_result = await self.handler.list_tasks( - request_obj, context - ) - case TaskPushNotificationConfig(): - handler_result = ( - await self.handler.set_push_notification_config( - request_obj, - context, - ) - ) - case GetTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.get_push_notification_config( - request_obj, - context, - ) - ) - case ListTaskPushNotificationConfigsRequest(): - handler_result = ( - await self.handler.list_push_notification_configs( - request_obj, - context, - ) - ) - case DeleteTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.delete_push_notification_config( - request_obj, - context, - ) - ) - case GetExtendedAgentCardRequest(): - handler_result = ( - await self.handler.get_authenticated_extended_card( - request_obj, - context, - ) - ) - case _: - logger.error( - 'Unhandled validated request type: %s', type(request_obj) - ) - error = UnsupportedOperationError( - message=f'Request type {type(request_obj).__name__} is unknown.' - ) - return self._generate_error_response(request_id, error) - - return self._create_response(context, handler_result) - - def _create_response( - self, - context: ServerCallContext, - handler_result: AsyncGenerator[dict[str, Any]] | dict[str, Any], - ) -> Response: - """Creates a Starlette Response based on the result from the request handler. - - Handles: - - AsyncGenerator for Server-Sent Events (SSE). - - Dict responses from handlers. - - Args: - context: The ServerCallContext provided to the request handler. - handler_result: The result from a request handler method. Can be an - async generator for streaming or a dict for non-streaming. - - Returns: - A Starlette JSONResponse or EventSourceResponse. - """ - headers = {} - if exts := context.activated_extensions: - headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) - if isinstance(handler_result, AsyncGenerator): - # Result is a stream of dict objects - async def event_generator( - stream: AsyncGenerator[dict[str, Any]], - ) -> AsyncGenerator[dict[str, str]]: - async for item in stream: - yield {'data': json.dumps(item)} - - return EventSourceResponse( - event_generator(handler_result), headers=headers - ) - - # handler_result is a dict (JSON-RPC response) - return JSONResponse(handler_result, headers=headers) - - async def _handle_get_agent_card(self, request: Request) -> JSONResponse: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return JSONResponse( - agent_card_to_dict( - card_to_serve, - ) - ) - - @abstractmethod - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI | Starlette: - """Builds and returns the JSONRPC application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured JSONRPC application instance. - """ - raise NotImplementedError( - 'Subclasses must implement the build method to create the application instance.' - ) diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py deleted file mode 100644 index 553fa250..00000000 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ /dev/null @@ -1,169 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - -else: - try: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - except ImportError: - Starlette = Any - Route = Any - - _package_starlette_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AStarletteApplication(JSONRPCApplication): - """A Starlette application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AStarletteApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use the' - ' `A2AStarletteApplication`. It can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def routes( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> list[Route]: - """Returns the Starlette Routes for handling A2A requests. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - - Returns: - A list of Starlette Route objects. - """ - return [ - Route( - rpc_url, - self._handle_requests, - methods=['POST'], - name='a2a_handler', - ), - Route( - agent_card_url, - self._handle_get_agent_card, - methods=['GET'], - name='agent_card', - ), - ] - - def add_routes_to_app( - self, - app: Starlette, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the Starlette application. - - Args: - app: The Starlette application to add the routes to. - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - """ - routes = self.routes( - agent_card_url=agent_card_url, - rpc_url=rpc_url, - ) - app.routes.extend(routes) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> Starlette: - """Builds and returns the Starlette application instance. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - **kwargs: Additional keyword arguments to pass to the Starlette constructor. - - Returns: - A configured Starlette application instance. - """ - app = Starlette(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py deleted file mode 100644 index bafe4cb6..00000000 --- a/src/a2a/server/apps/rest/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""A2A REST Applications.""" - -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2ARESTFastAPIApplication', -] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py deleted file mode 100644 index ea9a501b..00000000 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ /dev/null @@ -1,194 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True -else: - try: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True - except ImportError: - APIRouter = Any - FastAPI = Any - Request = Any - Response = Any - StarletteHTTPException = Any - - _package_fastapi_installed = False - - -from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder -from a2a.server.apps.rest.rest_adapter import RESTAdapter -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - -logger = logging.getLogger(__name__) - - -_HTTP_TO_GRPC_STATUS_MAP = { - 400: 'INVALID_ARGUMENT', - 401: 'UNAUTHENTICATED', - 403: 'PERMISSION_DENIED', - 404: 'NOT_FOUND', - 405: 'UNIMPLEMENTED', - 409: 'ALREADY_EXISTS', - 415: 'INVALID_ARGUMENT', - 422: 'INVALID_ARGUMENT', - 500: 'INTERNAL', - 501: 'UNIMPLEMENTED', - 502: 'INTERNAL', - 503: 'UNAVAILABLE', - 504: 'DEADLINE_EXCEEDED', -} - - -class A2ARESTFastAPIApplication: - """A FastAPI application implementing the A2A protocol server REST endpoints. - - Handles incoming REST requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - enable_v0_3_compat: bool = False, - ): - """Initializes the A2ARESTFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol - endpoints under the '/v0.3' path prefix using REST03Adapter. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`. It can be added as a part of' - ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' - ) - self._adapter = RESTAdapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - self.enable_v0_3_compat = enable_v0_3_compat - self._v03_adapter = None - - if self.enable_v0_3_compat: - self._v03_adapter = REST03Adapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A REST endpoint base path. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - @app.exception_handler(StarletteHTTPException) - async def http_exception_handler( - request: Request, exc: StarletteHTTPException - ) -> Response: - """Catches framework-level HTTP exceptions. - - For example, 404 Not Found for bad routes, 422 Unprocessable Entity - for schema validation, and formats them into the A2A standard - google.rpc.Status JSON format (AIP-193). - """ - grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get( - exc.status_code, 'UNKNOWN' - ) - return JSONResponse( - status_code=exc.status_code, - content={ - 'error': { - 'code': exc.status_code, - 'status': grpc_status, - 'message': str(exc.detail) - if hasattr(exc, 'detail') - else 'HTTP Exception', - } - }, - media_type='application/json', - ) - - if self.enable_v0_3_compat and self._v03_adapter: - v03_adapter = self._v03_adapter - v03_router = APIRouter() - for route, callback in v03_adapter.routes().items(): - v03_router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - app.include_router(v03_router) - - router = APIRouter() - for route, callback in self._adapter.routes().items(): - router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - - @router.get(f'{rpc_url}{agent_card_url}') - async def get_agent_card(request: Request) -> Response: - card = await self._adapter.handle_get_agent_card(request) - return JSONResponse(card) - - app.include_router(router) - - return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py deleted file mode 100644 index 6b8abb99..00000000 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ /dev/null @@ -1,296 +0,0 @@ -import functools -import json -import logging - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import MessageToDict - -from a2a.utils.helpers import maybe_await - - -if TYPE_CHECKING: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - -else: - try: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - except ImportError: - EventSourceResponse = Any - Request = Any - JSONResponse = Any - Response = Any - - _package_starlette_installed = False - -from a2a.server.apps.jsonrpc import ( - CallContextBuilder, - DefaultCallContextBuilder, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, -) -from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, -) - - -logger = logging.getLogger(__name__) - - -class RESTAdapterInterface(ABC): - """Interface for RESTAdapter.""" - - @abstractmethod - async def handle_get_agent_card( - self, request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - - @abstractmethod - def routes(self) -> dict[tuple[str, str], Callable[['Request'], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers.""" - - -class RESTAdapter(RESTAdapterInterface): - """Adapter to make RequestHandler work with RESTful API. - - Defines REST requests processors and the routes to attach them too, as well as - manages response generation including Server-Sent Events (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - ): - """Initializes the RESTApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `RESTAdapter`. They can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = RESTHandler( - agent_card=agent_card, request_handler=http_handler - ) - self._context_builder = context_builder or DefaultCallContextBuilder() - - @rest_error_handler - async def _handle_request( - self, - method: Callable[[Request, ServerCallContext], Awaitable[Any]], - request: Request, - ) -> Response: - call_context = self._build_call_context(request) - - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - self, - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = self._build_call_context(request) - - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse( - event_generator(method(request, call_context)) - ) - - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return agent_card_to_dict(card_to_serve) - - async def _handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Hook for per credential agent card response. - - If a dynamic card is needed based on the credentials provided in the request - override this method and return the customized content. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the authenticated card. - """ - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers. - - This method maps URL paths and HTTP methods to the appropriate handler - functions from the RESTHandler. It can be used by a web framework - (like Starlette or FastAPI) to set up the application's endpoints. - - Returns: - A dictionary where each key is a tuple of (path, http_method) and - the value is the callable handler for that route. - """ - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - self._handle_request, self.handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - self._handle_request, self.handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - self._handle_request, self.handler.on_get_task - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'GET', - ): functools.partial( - self._handle_request, self.handler.get_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'DELETE', - ): functools.partial( - self._handle_request, self.handler.delete_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'POST', - ): functools.partial( - self._handle_request, self.handler.set_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'GET', - ): functools.partial( - self._handle_request, self.handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - self._handle_request, self.handler.list_tasks - ), - } - - if self.agent_card.capabilities.extended_agent_card: - base_routes[('/extendedAgentCard', 'GET')] = functools.partial( - self._handle_request, self._handle_authenticated_agent_card - ) - - routes: dict[tuple[str, str], Callable[[Request], Any]] = { - (p, method): handler - for (path, method), handler in base_routes.items() - for p in (path, f'/{{tenant}}{path}') - } - - return routes - - def _build_call_context(self, request: Request) -> ServerCallContext: - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e2..ef4f0a74 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -5,13 +5,11 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, ) -from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -40,8 +38,6 @@ def __init__(self, *args, **kwargs): __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', - 'JSONRPCHandler', - 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py deleted file mode 100644 index e7d5b75a..00000000 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ /dev/null @@ -1,474 +0,0 @@ -"""JSON-RPC handler for A2A server requests.""" - -import logging - -from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any - -from google.protobuf.json_format import MessageToDict -from jsonrpc.jsonrpc2 import JSONRPC20Response - -from a2a.server.context import ServerCallContext -from a2a.server.jsonrpc_models import ( - InternalError as JSONRPCInternalError, -) -from a2a.server.jsonrpc_models import ( - JSONRPCError, -) -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTasksRequest, - SendMessageRequest, - SendMessageResponse, - SubscribeToTaskRequest, - Task, - TaskPushNotificationConfig, -) -from a2a.utils import proto_utils -from a2a.utils.errors import ( - JSON_RPC_ERROR_CODE_MAP, - A2AError, - ContentTypeNotSupportedError, - ExtendedAgentCardNotConfiguredError, - ExtensionSupportRequiredError, - InternalError, - InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, - PushNotificationNotSupportedError, - TaskNotCancelableError, - TaskNotFoundError, - UnsupportedOperationError, - VersionNotSupportedError, -) -from a2a.utils.helpers import maybe_await, validate, validate_async_generator -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -EXCEPTION_MAP: dict[type[A2AError], type[JSONRPCError]] = { - TaskNotFoundError: JSONRPCError, - TaskNotCancelableError: JSONRPCError, - PushNotificationNotSupportedError: JSONRPCError, - UnsupportedOperationError: JSONRPCError, - ContentTypeNotSupportedError: JSONRPCError, - InvalidAgentResponseError: JSONRPCError, - ExtendedAgentCardNotConfiguredError: JSONRPCError, - InternalError: JSONRPCInternalError, - InvalidParamsError: JSONRPCError, - InvalidRequestError: JSONRPCError, - MethodNotFoundError: JSONRPCError, - ExtensionSupportRequiredError: JSONRPCError, - VersionNotSupportedError: JSONRPCError, -} - - -def _build_success_response( - request_id: str | int | None, result: Any -) -> dict[str, Any]: - """Build a JSON-RPC success response dict.""" - return JSONRPC20Response(result=result, _id=request_id).data - - -def _build_error_response( - request_id: str | int | None, error: Exception -) -> dict[str, Any]: - """Build a JSON-RPC error response dict.""" - jsonrpc_error: JSONRPCError - if isinstance(error, A2AError): - error_type = type(error) - model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError) - code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603) - jsonrpc_error = model_class( - code=code, - message=str(error), - ) - else: - jsonrpc_error = JSONRPCInternalError(message=str(error)) - - error_dict = jsonrpc_error.model_dump(exclude_none=True) - return JSONRPC20Response(error=error_dict, _id=request_id).data - - -@trace_class(kind=SpanKind.SERVER) -class JSONRPCHandler: - """Maps incoming JSON-RPC requests to the appropriate request handler method and formats responses.""" - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - ): - """Initializes the JSONRPCHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - extended_agent_card: An optional, distinct Extended AgentCard to be served - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - """ - self.agent_card = agent_card - self.request_handler = request_handler - self.extended_agent_card = extended_agent_card - self.extended_card_modifier = extended_card_modifier - self.card_modifier = card_modifier - - def _get_request_id( - self, context: ServerCallContext | None - ) -> str | int | None: - """Get the JSON-RPC request ID from the context.""" - if context is None: - return None - return context.state.get('request_id') - - async def on_message_send( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' JSON-RPC method. - - Args: - request: The incoming `SendMessageRequest` proto message. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task_or_message = await self.request_handler.on_message_send( - request, context - ) - if isinstance(task_or_message, Task): - response = SendMessageResponse(task=task_or_message) - else: - response = SendMessageResponse(message=task_or_message) - - result = MessageToDict(response) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'message/stream' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SendMessageRequest` object (for streaming). - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_message_send_stream( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - async def on_cancel_task( - self, - request: CancelTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' JSON-RPC method. - - Args: - request: The incoming `CancelTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_cancel_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: SubscribeToTaskRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'SubscribeToTask' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SubscribeToTaskRequest` object. - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_subscribe_to_task( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - async def get_push_notification_config( - self, - request: GetTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - config = ( - await self.request_handler.on_get_task_push_notification_config( - request, context - ) - ) - result = MessageToDict(config, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification_config( - self, - request: TaskPushNotificationConfig, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' JSON-RPC method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator). - """ - request_id = self._get_request_id(context) - try: - # Pass the full request to the handler - result_config = await self.request_handler.on_create_task_push_notification_config( - request, context - ) - result = MessageToDict( - result_config, preserving_proto_field_name=False - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def on_get_task( - self, - request: GetTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_get_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - async def list_tasks( - self, - request: ListTasksRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' JSON-RPC method. - - Args: - request: The incoming `ListTasksRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_tasks( - request, context - ) - result = MessageToDict( - response, - preserving_proto_field_name=False, - always_print_fields_with_no_presence=True, - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def list_push_notification_configs( - self, - request: ListTaskPushNotificationConfigsRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'ListTaskPushNotificationConfigs' JSON-RPC method. - - Args: - request: The incoming `ListTaskPushNotificationConfigsRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_task_push_notification_configs( - request, context - ) - # response is a ListTaskPushNotificationConfigsResponse proto - result = MessageToDict(response, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - async def delete_push_notification_config( - self, - request: DeleteTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' JSON-RPC method. - - Args: - request: The incoming `DeleteTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - await self.request_handler.on_delete_task_push_notification_config( - request, context - ) - return _build_success_response(request_id, None) - except A2AError as e: - return _build_error_response(request_id, e) - - async def get_authenticated_extended_card( - self, - request: GetExtendedAgentCardRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'agent/authenticatedExtendedCard' JSON-RPC method. - - Args: - request: The incoming `GetExtendedAgentCardRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='The agent does not have an extended agent card configured' - ) - - base_card = self.extended_agent_card - if base_card is None: - base_card = self.agent_card - - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - result = MessageToDict(card_to_serve, preserving_proto_field_name=False) - return _build_success_response(request_id, result) diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py deleted file mode 100644 index f01c1371..00000000 --- a/src/a2a/server/request_handlers/rest_handler.py +++ /dev/null @@ -1,321 +0,0 @@ -import logging - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import ( - MessageToDict, - Parse, -) - - -if TYPE_CHECKING: - from starlette.requests import Request -else: - try: - from starlette.requests import Request - except ImportError: - Request = Any - - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - GetTaskPushNotificationConfigRequest, - SubscribeToTaskRequest, -) -from a2a.utils import proto_utils -from a2a.utils.errors import TaskNotFoundError -from a2a.utils.helpers import validate, validate_async_generator -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.SERVER) -class RESTHandler: - """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. - - This uses the protobuf definitions of the gRPC service as the source of truth. By - doing this, it ensures that this implementation and the gRPC transcoding - (via Envoy) are equivalent. This handler should be used if using the gRPC handler - with Envoy is not feasible for a given deployment solution. Use this handler - and a related application if you desire to ONLY server the RESTful API. - """ - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - ): - """Initializes the RESTHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - """ - self.agent_card = agent_card - self.request_handler = request_handler - - async def on_message_send( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the result (Task or Message) - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - task_or_message = await self.request_handler.on_message_send( - params, context - ) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return MessageToDict(response) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'message/stream' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - async def on_cancel_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the updated Task - """ - task_id = request.path_params['id'] - task = await self.request_handler.on_cancel_task( - CancelTaskRequest(id=task_id), context - ) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'SubscribeToTask' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - """ - task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( - SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) - - async def get_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = GetTaskPushNotificationConfigRequest( - task_id=task_id, - id=push_id, - ) - config = ( - await self.request_handler.on_get_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' REST method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config object. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator), A2AError if processing error is - found. - """ - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - # Set the parent to the task resource name format - params.task_id = request.path_params['id'] - config = ( - await self.request_handler.on_create_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - async def on_get_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/{id}' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `Task` object containing the Task. - """ - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - task = await self.request_handler.on_get_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - async def delete_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - An empty `dict` representing the empty response. - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - await self.request_handler.on_delete_task_push_notification_config( - params, context - ) - return {} - - async def list_tasks( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `Task` objects. - """ - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - - result = await self.request_handler.on_list_tasks(params, context) - return MessageToDict(result, always_print_fields_with_no_presence=True) - - async def list_push_notifications( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `TaskPushNotificationConfig` objects. - """ - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - - result = ( - await self.request_handler.on_list_task_push_notification_configs( - params, context - ) - ) - return MessageToDict(result) diff --git a/src/a2a/server/request_handlers/rest_handler_v2.py b/src/a2a/server/request_handlers/rest_handler_v2.py deleted file mode 100644 index 5df91eaa..00000000 --- a/src/a2a/server/request_handlers/rest_handler_v2.py +++ /dev/null @@ -1,156 +0,0 @@ -import logging - -from collections.abc import AsyncIterator -from typing import Any - -from google.protobuf.json_format import ( - MessageToDict, -) - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, -) -from a2a.utils import proto_utils -from a2a.utils.errors import TaskNotFoundError -from a2a.utils.helpers import validate, validate_async_generator -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.SERVER) -class RESTHandlerV2: - """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses.""" - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - ): - self.agent_card = agent_card - self.request_handler = request_handler - - async def on_message_send( - self, - params: a2a_pb2.SendMessageRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - task_or_message = await self.request_handler.on_message_send( - params, context - ) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return MessageToDict(response) - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - params: a2a_pb2.SendMessageRequest, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - async def on_cancel_task( - self, - params: a2a_pb2.CancelTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - task = await self.request_handler.on_cancel_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - params: a2a_pb2.SubscribeToTaskRequest, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - async for event in self.request_handler.on_subscribe_to_task( - params, context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) - - async def get_push_notification( - self, - params: a2a_pb2.GetTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - config = ( - await self.request_handler.on_get_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification( - self, - params: a2a_pb2.TaskPushNotificationConfig, - context: ServerCallContext, - ) -> dict[str, Any]: - config = ( - await self.request_handler.on_create_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - async def on_get_task( - self, - params: a2a_pb2.GetTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - task = await self.request_handler.on_get_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - async def delete_push_notification( - self, - params: a2a_pb2.DeleteTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - await self.request_handler.on_delete_task_push_notification_config( - params, context - ) - return {} - - async def list_tasks( - self, - params: a2a_pb2.ListTasksRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - result = await self.request_handler.on_list_tasks(params, context) - return MessageToDict(result, always_print_fields_with_no_presence=True) - - async def list_push_notifications( - self, - params: a2a_pb2.ListTaskPushNotificationConfigsRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - result = ( - await self.request_handler.on_list_task_push_notification_configs( - params, context - ) - ) - return MessageToDict(result) diff --git a/src/a2a/server/router/__init__.py b/src/a2a/server/router/__init__.py index bbd27494..7a25a53e 100644 --- a/src/a2a/server/router/__init__.py +++ b/src/a2a/server/router/__init__.py @@ -1,21 +1,20 @@ """A2A JSON-RPC Applications.""" -from a2a.server.apps.jsonrpc.jsonrpc_app import ( +from a2a.server.router.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, StarletteUserProxy, ) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication from a2a.server.router.agent_card_router import AgentCardRouter from a2a.server.router.jsonrpc_router import JsonRpcRouter from a2a.server.router.rest_router import RestRouter __all__ = [ - 'A2AFastAPIApplication', - 'A2AStarletteApplication', 'CallContextBuilder', 'DefaultCallContextBuilder', - 'JSONRPCApplication', 'StarletteUserProxy', + 'AgentCardRouter', + 'JsonRpcRouter', + 'RestRouter', ] diff --git a/src/a2a/server/router/agent_card_router.py b/src/a2a/server/router/agent_card_router.py index 04e280f7..aaa9799a 100644 --- a/src/a2a/server/router/agent_card_router.py +++ b/src/a2a/server/router/agent_card_router.py @@ -1,19 +1,22 @@ import logging -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Route + from starlette.routing import Router, Route else: try: + from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Route + from starlette.routing import Router, Route except ImportError: + Middleware = Any Route = Any Request = Any Response = Any @@ -36,6 +39,7 @@ def __init__( card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, card_url: str = '/', + middleware: Sequence['Middleware'] | None = None, ) -> None: """Initializes the AgentCardRouter. @@ -56,4 +60,10 @@ async def get_agent_card(request: Request) -> Response: ) return JSONResponse(agent_card_to_dict(card_to_serve)) - self.route = Route(path=card_url, endpoint=get_agent_card, methods=['GET']) + self.route = Route( + path=card_url, + endpoint=get_agent_card, + methods=['GET'], + middleware=middleware, + ) + diff --git a/src/a2a/server/router/jsonrpc_router.py b/src/a2a/server/router/jsonrpc_router.py index e16c6158..050a156a 100644 --- a/src/a2a/server/router/jsonrpc_router.py +++ b/src/a2a/server/router/jsonrpc_router.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any @@ -31,7 +31,8 @@ logger = logging.getLogger(__name__) -from starlette.routing import Route +from starlette.middleware import Middleware +from starlette.routing import Route, Router class JsonRpcRouter: @@ -56,6 +57,7 @@ def __init__( # noqa: PLR0913 | None = None, enable_v0_3_compat: bool = False, rpc_url: str = '/', + middleware: Sequence[Middleware] | None = None, ) -> None: """Initializes the JsonRpcRouter. @@ -81,5 +83,6 @@ def __init__( # noqa: PLR0913 self.route = Route( path=rpc_url, endpoint=self.dispatcher._handle_requests, - methods=['POST'] + methods=['POST'], + middleware=middleware, ) diff --git a/src/a2a/server/router/rest_router.py b/src/a2a/server/router/rest_router.py index cf31231a..2da54b27 100644 --- a/src/a2a/server/router/rest_router.py +++ b/src/a2a/server/router/rest_router.py @@ -1,28 +1,31 @@ import logging -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from sse_starlette.sse import EventSourceResponse from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router + from starlette.routing import Route _package_starlette_installed = True else: try: from sse_starlette.sse import EventSourceResponse from starlette.exceptions import HTTPException as StarletteHTTPException + from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router + from starlette.routing import Route _package_starlette_installed = True except ImportError: - Router = Any + Middleware = Any + Route = Any EventSourceResponse = Any Request = Any JSONResponse = Any @@ -35,7 +38,7 @@ from google.protobuf.json_format import MessageToDict, Parse -from a2a.server.apps.jsonrpc import ( +from a2a.server.router.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, ) @@ -82,6 +85,7 @@ def __init__( # noqa: PLR0913 | None = None, enable_v0_3_compat: bool = False, rpc_url: str = '', + middleware: Sequence['Middleware'] | None = None, ) -> None: """Initializes the A2AFastAPIApplication. @@ -129,8 +133,7 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, ) - self.router = Router() - self._setup_router(rpc_url) + self._setup_routes(rpc_url) def _build_call_context(self, request: Request) -> ServerCallContext: call_context = self._context_builder.build(request) @@ -138,17 +141,23 @@ def _build_call_context(self, request: Request) -> ServerCallContext: call_context.tenant = request.path_params['tenant'] return call_context - def _setup_router( + def _setup_routes( self, rpc_url: str, ) -> None: """Builds and returns the FastAPI application instance.""" + self.routes = [] if self.enable_v0_3_compat and self._v03_adapter: for route, callback in self._v03_adapter.routes().items(): - self.router.add_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] + self.routes.append( + Route( + path=f'{rpc_url}{route[0]}', + endpoint=callback, + methods=[route[1]], + ) ) + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { ('/message:send', 'POST'): self._message_send, ('/message:stream', 'POST'): self._message_stream, @@ -164,16 +173,16 @@ def _setup_router( ('/extendedAgentCard', 'GET'): self._get_extended_agent_card, } - routes: dict[tuple[str, str], Callable[[Request], Any]] = { - (p, method): handler + self.routes.extend([ + Route( + path=f'{rpc_url}{p}', + endpoint=handler, + methods=[method], + ) for (path, method), handler in base_routes.items() for p in (path, f'/{{tenant}}{path}') - } + ]) - for (path, method), handler in routes.items(): - self.router.add_route( - f'{rpc_url}{path}', handler, methods=[method] - ) @rest_error_handler async def _message_send(self, request: Request) -> Response: diff --git a/tck/sut_agent.py b/tck/sut_agent.py index cedd67fe..2102fa75 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -215,8 +215,10 @@ def serve(task_store: TaskStore) -> None: rest_router = RestRouter( agent_card=agent_card, request_handler=request_handler, + rpc_url=REST_URL, ) - main_app.mount(REST_URL, rest_router.router) + main_app.routes.extend(rest_router.routes) + config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' diff --git a/tests/server/apps/jsonrpc/test_fastapi_app.py b/tests/server/apps/jsonrpc/test_fastapi_app.py deleted file mode 100644 index 11831df5..00000000 --- a/tests/server/apps/jsonrpc/test_fastapi_app.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import fastapi_app -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AFastAPIApplication Tests --- - - -class TestA2AFastAPIApplicationOptionalDeps: - # Running tests in this class requires the optional dependency fastapi to be - # present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_fastapi_is_present(self): - try: - import fastapi as _fastapi # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' the optional dependency fastapi to be present in the test' - ' environment. Run `uv sync --dev ...` before running the test' - ' suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_fastapi_not_installed(self): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - def test_create_a2a_fastapi_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AFastAPIApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2AFastAPIApplication instance should not raise ImportError' - ) - - def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_fastapi_not_installed: Any, - ): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2AFastAPIApplication`' - ), - ): - _app = A2AFastAPIApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py deleted file mode 100644 index ab220e9c..00000000 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ /dev/null @@ -1,444 +0,0 @@ -# ruff: noqa: INP001 -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from starlette.responses import JSONResponse -from starlette.testclient import TestClient - - -# Attempt to import StarletteBaseUser, fallback to MagicMock if not available -try: - from starlette.authentication import BaseUser as StarletteBaseUser -except ImportError: - StarletteBaseUser = MagicMock() # type: ignore - -from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.apps.jsonrpc import ( - jsonrpc_app, # Keep this import for optional deps test -) -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import ( - RequestHandler, -) # For mock spec -from a2a.types.a2a_pb2 import ( - AgentCard, - Message, - Part, - Role, -) - - -# --- StarletteUserProxy Tests --- - - -class TestStarletteUserProxy: - def test_starlette_user_proxy_is_authenticated_true(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = True - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is True - - def test_starlette_user_proxy_is_authenticated_false(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = False - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is False - - def test_starlette_user_proxy_user_name(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.display_name = 'Test User DisplayName' - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.user_name == 'Test User DisplayName' - - def test_starlette_user_proxy_user_name_raises_attribute_error(self): - """ - Tests that if the underlying starlette user object is missing the - display_name attribute, the proxy currently raises an AttributeError. - """ - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - # Ensure display_name is not present on the mock to trigger AttributeError - del starlette_user_mock.display_name - - proxy = StarletteUserProxy(starlette_user_mock) - with pytest.raises(AttributeError, match='display_name'): - _ = proxy.user_name - - -# --- JSONRPCApplication Tests (Selected) --- - - -@pytest.fixture -def mock_handler(): - handler = AsyncMock(spec=RequestHandler) - # Return a proto Message object directly - the handler wraps it in SendMessageResponse - handler.on_message_send.return_value = Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - return handler - - -@pytest.fixture -def test_app(mock_handler): - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - # Set up capabilities.streaming to avoid validation issues - mock_agent_card.capabilities = MagicMock() - mock_agent_card.capabilities.streaming = False - return A2AStarletteApplication( - agent_card=mock_agent_card, http_handler=mock_handler - ) - - -@pytest.fixture -def client(test_app): - return TestClient(test_app.build()) - - -def _make_send_message_request( - text: str = 'hi', tenant: str | None = None -) -> dict: - """Helper to create a JSON-RPC send message request.""" - params: dict[str, Any] = { - 'message': { - 'messageId': '1', - 'role': 'ROLE_USER', - 'parts': [{'text': text}], - } - } - if tenant is not None: - params['tenant'] = tenant - - return { - 'jsonrpc': '2.0', - 'id': '1', - 'method': 'SendMessage', - 'params': params, - } - - -class TestJSONRPCApplicationSetup: # Renamed to avoid conflict - def test_jsonrpc_app_build_method_abstract_raises_typeerror( - self, - ): # Renamed test - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in JSONRPCApplication.__init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed in __init__ - mock_agent_card.url = 'http://mockurl.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists - - # This will fail at definition time if an abstract method is not implemented - with pytest.raises( - TypeError, - match=r".*abstract class IncompleteJSONRPCApp .* abstract method '?build'?", - ): - - class IncompleteJSONRPCApp(JSONRPCApplication): - # Intentionally not implementing 'build' - def some_other_method(self): - pass - - IncompleteJSONRPCApp( - agent_card=mock_agent_card, http_handler=mock_handler - ) # type: ignore[abstract] - - -class TestJSONRPCApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401, PLC0415 - import starlette as _starlette # noqa: F401, PLC0415 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = jsonrpc_app._package_starlette_installed - jsonrpc_app._package_starlette_installed = False - yield - jsonrpc_app._package_starlette_installed = pkg_starlette_installed_flag - - def test_create_jsonrpc_based_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - - try: - _app = MockJSONRPCApp(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating a' - ' JSONRPCApplication-based instance should not raise' - ' ImportError' - ) - - def test_create_jsonrpc_based_app_with_missing_deps_raises_importerror( - self, mock_app_params: dict, mark_pkg_starlette_not_installed: Any - ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - - with pytest.raises( - ImportError, - match=( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `JSONRPCApplication`' - ), - ): - _app = MockJSONRPCApp(**mock_app_params) - - -class TestJSONRPCApplicationExtensions: - def test_request_with_single_extension(self, client, mock_handler): - headers = {HTTP_EXTENSION_HEADER: 'foo'} - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert isinstance(call_context, ServerCallContext) - assert call_context.requested_extensions == {'foo'} - - def test_request_with_comma_separated_extensions( - self, client, mock_handler - ): - headers = {HTTP_EXTENSION_HEADER: 'foo, bar'} - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.requested_extensions == {'foo', 'bar'} - - def test_request_with_comma_separated_extensions_no_space( - self, client, mock_handler - ): - headers = [ - (HTTP_EXTENSION_HEADER, 'foo, bar'), - (HTTP_EXTENSION_HEADER, 'baz'), - ] - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.requested_extensions == {'foo', 'bar', 'baz'} - - def test_method_added_to_call_context_state(self, client, mock_handler): - response = client.post( - '/', - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.state['method'] == 'SendMessage' - - def test_request_with_multiple_extension_headers( - self, client, mock_handler - ): - headers = [ - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'bar'), - ] - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.requested_extensions == {'foo', 'bar'} - - def test_response_with_activated_extensions(self, client, mock_handler): - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - # Return a proto Message object directly - return Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - - mock_handler.on_message_send.side_effect = side_effect - - response = client.post( - '/', - json=_make_send_message_request(), - ) - response.raise_for_status() - - assert response.status_code == 200 - assert HTTP_EXTENSION_HEADER in response.headers - assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { - 'foo', - 'baz', - } - - -class TestJSONRPCApplicationTenant: - def test_tenant_extraction_from_params(self, client, mock_handler): - tenant_id = 'my-tenant-123' - response = client.post( - '/', - json=_make_send_message_request(tenant=tenant_id), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert isinstance(call_context, ServerCallContext) - assert call_context.tenant == tenant_id - - def test_no_tenant_extraction(self, client, mock_handler): - response = client.post( - '/', - json=_make_send_message_request(tenant=None), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert isinstance(call_context, ServerCallContext) - assert call_context.tenant == '' - - -class TestJSONRPCApplicationV03Compat: - def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - mock_agent_card.capabilities = MagicMock() - mock_agent_card.capabilities.streaming = False - - app = A2AStarletteApplication( - agent_card=mock_agent_card, - http_handler=mock_handler, - enable_v0_3_compat=True, - ) - - client = TestClient(app.build()) - - request_data = { - 'jsonrpc': '2.0', - 'id': '1', - 'method': 'message/send', - 'params': { - 'message': { - 'messageId': 'msg-1', - 'role': 'ROLE_USER', - 'parts': [{'text': 'Hello'}], - } - }, - } - - with patch.object( - app._v03_adapter, 'handle_request', new_callable=AsyncMock - ) as mock_handle: - mock_handle.return_value = JSONResponse( - {'jsonrpc': '2.0', 'id': '1', 'result': {}} - ) - - response = client.post('/', json=request_data) - - response.raise_for_status() - assert mock_handle.called - assert mock_handle.call_args[1]['method'] == 'message/send' - - def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - mock_agent_card.capabilities = MagicMock() - mock_agent_card.capabilities.streaming = False - - app = A2AStarletteApplication( - agent_card=mock_agent_card, - http_handler=mock_handler, - enable_v0_3_compat=False, - ) - - client = TestClient(app.build()) - - request_data = { - 'jsonrpc': '2.0', - 'id': '1', - 'method': 'message/send', - 'params': { - 'message': { - 'messageId': 'msg-1', - 'role': 'ROLE_USER', - 'parts': [{'text': 'Hello'}], - } - }, - } - - response = client.post('/', json=request_data) - - assert response.status_code == 200 - # Should return MethodNotFoundError because the v0.3 method is not recognized - # without the adapter enabled. - resp_json = response.json() - assert 'error' in resp_json - assert resp_json['error']['code'] == -32601 - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py deleted file mode 100644 index 825f8e2a..00000000 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Tests for JSON-RPC serialization behavior.""" - -from unittest import mock - -import pytest -from starlette.testclient import TestClient - -from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication -from a2a.server.jsonrpc_models import JSONParseError -from a2a.types import ( - InvalidRequestError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentInterface, - AgentCard, - AgentSkill, - APIKeySecurityScheme, - Message, - Part, - Role, - SecurityRequirement, - SecurityScheme, -) - - -@pytest.fixture -def minimal_agent_card(): - """Provides a minimal AgentCard for testing.""" - return AgentCard( - name='TestAgent', - description='A test agent.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/agent', protocol_binding='HTTP+JSON' - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - skills=[ - AgentSkill( - id='skill-1', - name='Test Skill', - description='A test skill', - tags=['test'], - ) - ], - ) - - -@pytest.fixture -def agent_card_with_api_key(): - """Provides an AgentCard with an APIKeySecurityScheme for testing serialization.""" - api_key_scheme = APIKeySecurityScheme( - name='X-API-KEY', - location='header', - ) - - security_scheme = SecurityScheme(api_key_security_scheme=api_key_scheme) - - card = AgentCard( - name='APIKeyAgent', - description='An agent that uses API Key auth.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/apikey-agent', - protocol_binding='HTTP+JSON', - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - ) - # Add security scheme to the map - card.security_schemes['api_key_auth'].CopyFrom(security_scheme) - - return card - - -def test_starlette_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AStarletteApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - assert ( - response_data['supportedInterfaces'][0]['url'] - == 'http://example.com/agent' - ) - assert response_data['version'] == '1.0.0' - - -def test_starlette_agent_card_with_api_key_scheme( - agent_card_with_api_key: AgentCard, -): - """Tests that the A2AStarletteApplication endpoint correctly serializes API key schemes.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(agent_card_with_api_key, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - # Check security schemes are serialized - assert 'securitySchemes' in response_data - assert 'api_key_auth' in response_data['securitySchemes'] - - -def test_fastapi_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AFastAPIApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - - -def test_handle_invalid_json(minimal_agent_card: AgentCard): - """Test handling of malformed JSON.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.post( - '/', - content='{ "jsonrpc": "2.0", "method": "test", "id": 1, "params": { "key": "value" }', - ) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == JSONParseError().code - - -def test_handle_oversized_payload(minimal_agent_card: AgentCard): - """Test handling of oversized JSON payloads.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == -32600 - - -@pytest.mark.parametrize( - 'max_content_length', - [ - None, - 11 * 1024 * 1024, - 30 * 1024 * 1024, - ], -) -def test_handle_oversized_payload_with_max_content_length( - minimal_agent_card: AgentCard, - max_content_length: int | None, -): - """Test handling of JSON payloads with sizes within custom max_content_length.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication( - minimal_agent_card, handler, max_content_length=max_content_length - ) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - # When max_content_length is set, requests up to that size should not be - # rejected due to payload size. The request might fail for other reasons, - # but it shouldn't be an InvalidRequestError related to the content length. - if max_content_length is not None: - assert data['error']['code'] != -32600 - - -def test_handle_unicode_characters(minimal_agent_card: AgentCard): - """Test handling of unicode characters in JSON payload.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - unicode_text = 'こんにちは世界' # "Hello world" in Japanese - - # Mock a handler response - handler.on_message_send.return_value = Message( - role=Role.ROLE_AGENT, - parts=[Part(text=f'Received: {unicode_text}')], - message_id='response-unicode', - ) - - unicode_payload = { - 'jsonrpc': '2.0', - 'method': 'SendMessage', - 'id': 'unicode_test', - 'params': { - 'message': { - 'role': 'ROLE_USER', - 'parts': [{'text': unicode_text}], - 'messageId': 'msg-unicode', - } - }, - } - - response = client.post('/', json=unicode_payload) - - # We are testing that the server can correctly deserialize the unicode payload - assert response.status_code == 200 - data = response.json() - # Check that we got a result (handler was called) - if 'result' in data: - # Response should contain the unicode text - result = data['result'] - if 'message' in result: - assert ( - result['message']['parts'][0]['text'] - == f'Received: {unicode_text}' - ) - elif 'parts' in result: - assert result['parts'][0]['text'] == f'Received: {unicode_text}' - - -def test_fastapi_sub_application(minimal_agent_card: AgentCard): - """ - Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application. - """ - from fastapi import FastAPI - - handler = mock.AsyncMock() - sub_app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - app_instance = FastAPI() - app_instance.mount('/a2a', sub_app_instance.build()) - client = TestClient(app_instance) - - response = client.get('/a2a/openapi.json') - assert response.status_code == 200 - response_data = response.json() - - # The generated a2a.json (OpenAPI 2.0 / Swagger) does not typically include a 'servers' block - # unless specifically configured or converted to OpenAPI 3.0. - # FastAPI usually generates OpenAPI 3.0 schemas which have 'servers'. - # When we inject the raw Swagger 2.0 schema, it won't have 'servers'. - # We check if it is indeed the injected schema by checking for 'swagger': '2.0' - # or by checking for 'basePath' if we want to test path correctness. - - if response_data.get('swagger') == '2.0': - # It's the injected Swagger 2.0 schema - pass - else: - # It's an auto-generated OpenAPI 3.0+ schema (fallback or otherwise) - assert 'servers' in response_data - assert response_data['servers'] == [{'url': '/a2a'}] diff --git a/tests/server/apps/jsonrpc/test_starlette_app.py b/tests/server/apps/jsonrpc/test_starlette_app.py deleted file mode 100644 index fa686871..00000000 --- a/tests/server/apps/jsonrpc/test_starlette_app.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import starlette_app -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AStarletteApplication Tests --- - - -class TestA2AStarletteApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401 - import starlette as _starlette # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = ( - starlette_app._package_starlette_installed - ) - starlette_app._package_starlette_installed = False - yield - starlette_app._package_starlette_installed = ( - pkg_starlette_installed_flag - ) - - def test_create_a2a_starlette_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AStarletteApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' A2AStarletteApplication instance should not raise ImportError' - ) - - def test_create_a2a_starlette_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_starlette_not_installed: Any, - ): - with pytest.raises( - ImportError, - match='Packages `starlette` and `sse-starlette` are required', - ): - _app = A2AStarletteApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index c8510023..131d9f11 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -10,9 +10,7 @@ from google.protobuf import json_format from httpx import ASGITransport, AsyncClient -from a2a.server.apps.rest import fastapi_app, rest_adapter -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication -from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.server.router.rest_router import RestRouter from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( @@ -74,12 +72,14 @@ async def extended_card_modifier() -> MagicMock | None: @pytest.fixture async def streaming_app( streaming_agent_card: AgentCard, request_handler: RequestHandler -) -> FastAPI: - """Builds the FastAPI application for testing streaming endpoints.""" - - return A2ARESTFastAPIApplication( - streaming_agent_card, request_handler - ).build(agent_card_url='/well-known/agent-card.json', rpc_url='') +) -> Any: + from starlette.applications import Starlette + router = RestRouter( + streaming_agent_card, request_handler, rpc_url='' + ) + app = Starlette() + app.mount('', router.router) + return app @pytest.fixture @@ -95,14 +95,25 @@ async def app( agent_card: AgentCard, request_handler: RequestHandler, extended_card_modifier: MagicMock | None, -) -> FastAPI: - """Builds the FastAPI application for testing.""" - - return A2ARESTFastAPIApplication( +) -> Any: + from starlette.applications import Starlette + from a2a.server.router.agent_card_router import AgentCardRouter + # Return Starlette app + app_instance = Starlette() + + rest_router = RestRouter( agent_card, request_handler, extended_card_modifier=extended_card_modifier, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') + rpc_url='' + ) + app_instance.mount('', rest_router.router) + + # Also Agent card endpoint? if needed in tests + card_router = AgentCardRouter(agent_card, card_url='/well-known/agent.json') + app_instance.routes.append(card_router.route) + + return app_instance @pytest.fixture @@ -112,93 +123,18 @@ async def client(app: FastAPI) -> AsyncClient: ) -@pytest.fixture -def mark_pkg_starlette_not_installed(): - pkg_starlette_installed_flag = rest_adapter._package_starlette_installed - rest_adapter._package_starlette_installed = False - yield - rest_adapter._package_starlette_installed = pkg_starlette_installed_flag - - -@pytest.fixture -def mark_pkg_fastapi_not_installed(): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = RESTAdapter(agent_card, request_handler) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' RESTAdapter instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_starlette_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - r'Packages `starlette` and `sse-starlette` are required to use' - r' the `RESTAdapter`.' - ), - ): - _app = RESTAdapter(agent_card, request_handler) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2ARESTFastAPIApplication instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_fastapi_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`' - ), - ): - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) +# --- RestRouter Tests --- @pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_v0_3_compat( +async def test_create_rest_router_with_v0_3_compat( agent_card: AgentCard, request_handler: RequestHandler ): - app = A2ARESTFastAPIApplication( - agent_card, request_handler, enable_v0_3_compat=True - ).build(agent_card_url='/well-known/agent.json', rpc_url='') - - routes = [getattr(route, 'path', '') for route in app.routes] + router = RestRouter( + agent_card, request_handler, enable_v0_3_compat=True, rpc_url='' + ) + # Check if a route from v0.3 compatibility is present + routes = [getattr(route, 'path', '') for route in router.router.routes] assert '/v1/message:send' in routes @@ -692,33 +628,5 @@ async def test_tenant_extraction_extended_agent_card( assert context.tenant == '' -@pytest.mark.anyio -async def test_global_http_exception_handler_returns_rpc_status( - client: AsyncClient, -) -> None: - """Test that a standard FastAPI 404 is transformed into the A2A google.rpc.Status format.""" - - # Send a request to an endpoint that does not exist - response = await client.get('/non-existent-route') - - # Verify it returns a 404 with standard application/json - assert response.status_code == 404 - assert response.headers.get('content-type') == 'application/json' - - data = response.json() - - # Assert the payload is wrapped in the "error" envelope - assert 'error' in data - error_payload = data['error'] - - # Assert it has the correct AIP-193 format - assert error_payload['code'] == 404 - assert error_payload['status'] == 'NOT_FOUND' - assert 'Not Found' in error_payload['message'] - - # Standard HTTP errors shouldn't leak details - assert 'details' not in error_payload - - if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/router/test_jsonrpc_router.py b/tests/server/router/test_jsonrpc_router.py new file mode 100644 index 00000000..850b5193 --- /dev/null +++ b/tests/server/router/test_jsonrpc_router.py @@ -0,0 +1,275 @@ +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.server.router import JsonRpcRouter, StarletteUserProxy +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types.a2a_pb2 import ( + AgentCard, + Message, + Part, + Role, +) + + +@pytest.fixture +def mock_handler(): + handler = AsyncMock(spec=RequestHandler) + handler.on_message_send.return_value = Message( + message_id='test', + role=Role.ROLE_AGENT, + parts=[Part(text='response message')], + ) + return handler + + +@pytest.fixture +def test_app(mock_handler): + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.capabilities = MagicMock() + mock_agent_card.capabilities.streaming = False + + from starlette.applications import Starlette + app = Starlette() + router = JsonRpcRouter(mock_agent_card, mock_handler) + app.routes.append(router.route) + return app + + +@pytest.fixture +def client(test_app): + return TestClient(test_app) + + +def _make_send_message_request( + text: str = 'hi', tenant: str | None = None +) -> dict: + params: dict[str, Any] = { + 'message': { + 'messageId': '1', + 'role': 'ROLE_USER', + 'parts': [{'text': text}], + } + } + if tenant is not None: + params['tenant'] = tenant + + return { + 'jsonrpc': '2.0', + 'id': '1', + 'method': 'SendMessage', + 'params': params, + } + + +class TestJSONRPCApplicationExtensions: + def test_request_with_single_extension(self, client, mock_handler): + headers = {HTTP_EXTENSION_HEADER: 'foo'} + response = client.post( + '/', + headers=headers, + json=_make_send_message_request(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo'} + + def test_request_with_comma_separated_extensions( + self, client, mock_handler + ): + headers = {HTTP_EXTENSION_HEADER: 'foo, bar'} + response = client.post( + '/', + headers=headers, + json=_make_send_message_request(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_request_with_comma_separated_extensions_no_space( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo, bar'), + (HTTP_EXTENSION_HEADER, 'baz'), + ] + response = client.post( + '/', + headers=headers, + json=_make_send_message_request(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar', 'baz'} + + def test_method_added_to_call_context_state(self, client, mock_handler): + response = client.post( + '/', + json=_make_send_message_request(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.state['method'] == 'SendMessage' + + def test_request_with_multiple_extension_headers( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), + ] + response = client.post( + '/', + headers=headers, + json=_make_send_message_request(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_response_with_activated_extensions(self, client, mock_handler): + def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + return Message( + message_id='test', + role=Role.ROLE_AGENT, + parts=[Part(text='response message')], + ) + + mock_handler.on_message_send.side_effect = side_effect + + response = client.post( + '/', + json=_make_send_message_request(), + ) + response.raise_for_status() + + assert response.status_code == 200 + assert HTTP_EXTENSION_HEADER in response.headers + assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { + 'foo', + 'baz', + } + + +class TestJSONRPCApplicationTenant: + def test_tenant_extraction_from_params(self, client, mock_handler): + tenant_id = 'my-tenant-123' + response = client.post( + '/', + json=_make_send_message_request(tenant=tenant_id), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.tenant == tenant_id + + def test_no_tenant_extraction(self, client, mock_handler): + response = client.post( + '/', + json=_make_send_message_request(tenant=None), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.tenant == '' + + +class TestJSONRPCApplicationV03Compat: + def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.capabilities = MagicMock() + mock_agent_card.capabilities.streaming = False + + from starlette.applications import Starlette + app = Starlette() + router = JsonRpcRouter(mock_agent_card, mock_handler, enable_v0_3_compat=True) + app.routes.append(router.route) + + client = TestClient(app) + + request_data = { + 'jsonrpc': '2.0', + 'id': '1', + 'method': 'message/send', + 'params': { + 'message': { + 'messageId': 'msg-1', + 'role': 'ROLE_USER', + 'parts': [{'text': 'Hello'}], + } + }, + } + + # Instead of _v03_adapter, the handler handles it or it's dispatcher + with patch.object( + router.dispatcher, '_process_non_streaming_request', new_callable=AsyncMock + ) as mock_handle: + mock_handle.return_value = {'jsonrpc': '2.0', 'id': '1', 'result': {}} + + response = client.post('/', json=request_data) + + response.raise_for_status() + assert mock_handle.called + + def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.capabilities = MagicMock() + mock_agent_card.capabilities.streaming = False + + from starlette.applications import Starlette + app = Starlette() + router = JsonRpcRouter(mock_agent_card, mock_handler, enable_v0_3_compat=False) + app.routes.append(router.route) + + client = TestClient(app) + + request_data = { + 'jsonrpc': '2.0', + 'id': '1', + 'method': 'message/send', + 'params': { + 'message': { + 'messageId': 'msg-1', + 'role': 'ROLE_USER', + 'parts': [{'text': 'Hello'}], + } + }, + } + + response = client.post('/', json=request_data) + + assert response.status_code == 200 + resp_json = response.json() + assert 'error' in resp_json + assert resp_json['error']['code'] == -32601 + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index e6bb5f88..0e30fe30 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,10 +18,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.apps import ( - A2AFastAPIApplication, - A2AStarletteApplication, -) +from a2a.server.router.agent_card_router import AgentCardRouter +from a2a.server.router.jsonrpc_router import JsonRpcRouter from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( InternalError, @@ -148,14 +146,35 @@ def handler(): return handler +class AppBuilder: + def __init__(self, agent_card, handler, card_modifier=None): + self.agent_card = agent_card + self.handler = handler + self.card_modifier = card_modifier + + def build(self, rpc_url='/', agent_card_url=AGENT_CARD_WELL_KNOWN_PATH, middleware=None, routes=None): + from starlette.applications import Starlette + app_instance = Starlette(middleware=middleware, routes=routes or []) + + # Agent card router + card_router = AgentCardRouter(self.agent_card, card_url=agent_card_url, card_modifier=self.card_modifier) + app_instance.routes.append(card_router.route) + + # JSON-RPC router + rpc_router = JsonRpcRouter(self.agent_card, self.handler, rpc_url=rpc_url) + app_instance.routes.append(rpc_router.route) + + return app_instance + + @pytest.fixture def app(agent_card: AgentCard, handler: mock.AsyncMock): - return A2AStarletteApplication(agent_card, handler) + return AppBuilder(agent_card, handler) @pytest.fixture -def client(app: A2AStarletteApplication, **kwargs): - """Create a test client with the Starlette app.""" +def client(app, **kwargs): + """Create a test client with the app builder.""" return TestClient(app.build(**kwargs)) @@ -173,7 +192,7 @@ def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): def test_agent_card_custom_url( - app: A2AStarletteApplication, agent_card: AgentCard + app, agent_card: AgentCard ): """Test the agent card endpoint with a custom URL.""" client = TestClient(app.build(agent_card_url='/my-agent')) @@ -184,7 +203,7 @@ def test_agent_card_custom_url( def test_starlette_rpc_endpoint_custom_url( - app: A2AStarletteApplication, handler: mock.AsyncMock + app, handler: mock.AsyncMock ): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value @@ -207,7 +226,7 @@ def test_starlette_rpc_endpoint_custom_url( def test_fastapi_rpc_endpoint_custom_url( - app: A2AFastAPIApplication, handler: mock.AsyncMock + app, handler: mock.AsyncMock ): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value @@ -230,7 +249,7 @@ def test_fastapi_rpc_endpoint_custom_url( def test_starlette_build_with_extra_routes( - app: A2AStarletteApplication, agent_card: AgentCard + app, agent_card: AgentCard ): """Test building the app with additional routes.""" @@ -254,7 +273,7 @@ def custom_handler(request): def test_fastapi_build_with_extra_routes( - app: A2AFastAPIApplication, agent_card: AgentCard + app, agent_card: AgentCard ): """Test building the app with additional routes.""" @@ -278,7 +297,7 @@ def custom_handler(request): def test_fastapi_build_custom_agent_card_path( - app: A2AFastAPIApplication, agent_card: AgentCard + app, agent_card: AgentCard ): """Test building the app with a custom agent card path.""" @@ -467,7 +486,7 @@ def test_get_push_notification_config( handler.on_get_task_push_notification_config.assert_awaited_once() -def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): +def test_server_auth(app, handler: mock.AsyncMock): class TestAuthMiddleware(AuthenticationBackend): async def authenticate( self, conn: HTTPConnection @@ -530,7 +549,7 @@ async def authenticate( @pytest.mark.asyncio async def test_message_send_stream( - app: A2AStarletteApplication, handler: mock.AsyncMock + app, handler: mock.AsyncMock ) -> None: """Test streaming message sending.""" @@ -606,7 +625,7 @@ async def stream_generator(): @pytest.mark.asyncio async def test_task_resubscription( - app: A2AStarletteApplication, handler: mock.AsyncMock + app, handler: mock.AsyncMock ) -> None: """Test task resubscription streaming.""" @@ -738,7 +757,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( + app_instance = AppBuilder( agent_card, handler, card_modifier=modifier ) client = TestClient(app_instance.build()) @@ -763,7 +782,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( + app_instance = AppBuilder( agent_card, handler, card_modifier=modifier ) client = TestClient(app_instance.build()) @@ -788,7 +807,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( + app_instance = AppBuilder( agent_card, handler, card_modifier=modifier ) client = TestClient(app_instance.build()) @@ -810,7 +829,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( + app_instance = AppBuilder( agent_card, handler, card_modifier=modifier ) client = TestClient(app_instance.build()) @@ -924,7 +943,7 @@ def test_agent_card_backward_compatibility_supports_extended_card( ): """Test that supportsAuthenticatedExtendedCard is injected when extended_agent_card is True.""" agent_card.capabilities.extended_agent_card = True - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200 @@ -937,7 +956,7 @@ def test_agent_card_backward_compatibility_no_extended_card( ): """Test that supportsAuthenticatedExtendedCard is absent when extended_agent_card is False.""" agent_card.capabilities.extended_agent_card = False - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200 From 2d954eb0c0a08dde419e5389806f15a3a98cad7b Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 09:07:29 +0000 Subject: [PATCH 05/17] wip --- src/a2a/server/router/__init__.py | 20 --- src/a2a/server/routes/__init__.py | 20 +++ .../agent_card_route.py} | 12 +- .../{router => routes}/jsonrpc_dispatcher.py | 52 ++++++-- .../jsonrpc_route.py} | 17 ++- .../rest_router.py => routes/rest_routes.py} | 125 ++++++++++++------ tck/sut_agent.py | 9 +- tests/server/apps/rest/__init__.py | 0 tests/server/routes/test_agent_card_route.py | 55 ++++++++ .../test_jsonrpc_route.py} | 28 +++- .../test_rest_routes.py} | 30 ++--- tests/server/test_integration.py | 77 +++++------ 12 files changed, 284 insertions(+), 161 deletions(-) delete mode 100644 src/a2a/server/router/__init__.py create mode 100644 src/a2a/server/routes/__init__.py rename src/a2a/server/{router/agent_card_router.py => routes/agent_card_route.py} (88%) rename src/a2a/server/{router => routes}/jsonrpc_dispatcher.py (94%) rename src/a2a/server/{router/jsonrpc_router.py => routes/jsonrpc_route.py} (93%) rename src/a2a/server/{router/rest_router.py => routes/rest_routes.py} (81%) delete mode 100644 tests/server/apps/rest/__init__.py create mode 100644 tests/server/routes/test_agent_card_route.py rename tests/server/{router/test_jsonrpc_router.py => routes/test_jsonrpc_route.py} (93%) rename tests/server/{apps/rest/test_rest_fastapi_app.py => routes/test_rest_routes.py} (97%) diff --git a/src/a2a/server/router/__init__.py b/src/a2a/server/router/__init__.py deleted file mode 100644 index 7a25a53e..00000000 --- a/src/a2a/server/router/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""A2A JSON-RPC Applications.""" - -from a2a.server.router.jsonrpc_dispatcher import ( - CallContextBuilder, - DefaultCallContextBuilder, - StarletteUserProxy, -) -from a2a.server.router.agent_card_router import AgentCardRouter -from a2a.server.router.jsonrpc_router import JsonRpcRouter -from a2a.server.router.rest_router import RestRouter - - -__all__ = [ - 'CallContextBuilder', - 'DefaultCallContextBuilder', - 'StarletteUserProxy', - 'AgentCardRouter', - 'JsonRpcRouter', - 'RestRouter', -] diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py new file mode 100644 index 00000000..d9d5ee22 --- /dev/null +++ b/src/a2a/server/routes/__init__.py @@ -0,0 +1,20 @@ +"""A2A Server Routes.""" + +from a2a.server.routes.agent_card_route import AgentCardRoute +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + DefaultCallContextBuilder, + StarletteUserProxy, +) +from a2a.server.routes.jsonrpc_route import JsonRpcRoute +from a2a.server.routes.rest_routes import RestRoutes + + +__all__ = [ + 'AgentCardRoute', + 'CallContextBuilder', + 'DefaultCallContextBuilder', + 'JsonRpcRoute', + 'RestRoutes', + 'StarletteUserProxy', +] diff --git a/src/a2a/server/router/agent_card_router.py b/src/a2a/server/routes/agent_card_route.py similarity index 88% rename from src/a2a/server/router/agent_card_router.py rename to src/a2a/server/routes/agent_card_route.py index aaa9799a..feee2937 100644 --- a/src/a2a/server/router/agent_card_router.py +++ b/src/a2a/server/routes/agent_card_route.py @@ -8,13 +8,14 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router, Route + from starlette.routing import Route else: try: from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Router, Route + from starlette.routing import Route + except ImportError: Middleware = Any Route = Any @@ -30,8 +31,8 @@ logger = logging.getLogger(__name__) -class AgentCardRouter: - """A FastAPI router implementing the A2A protocol agent card endpoints.""" +class AgentCardRoute: + """Provides the Starlette Route for the A2A protocol agent card endpoint.""" def __init__( self, @@ -41,7 +42,7 @@ def __init__( card_url: str = '/', middleware: Sequence['Middleware'] | None = None, ) -> None: - """Initializes the AgentCardRouter. + """Initializes the AgentCardRoute. Args: agent_card: The AgentCard describing the agent's capabilities. @@ -66,4 +67,3 @@ async def get_agent_card(request: Request) -> Response: methods=['GET'], middleware=middleware, ) - diff --git a/src/a2a/server/router/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py similarity index 94% rename from src/a2a/server/router/jsonrpc_dispatcher.py rename to src/a2a/server/routes/jsonrpc_dispatcher.py index fba7b381..ff2806c8 100644 --- a/src/a2a/server/router/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -428,7 +428,9 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 raw_result = await self._process_non_streaming_request( request_id, specific_request, call_context ) - handler_result = JSONRPC20Response(result=raw_result, _id=request_id).data + handler_result = JSONRPC20Response( + result=raw_result, _id=request_id + ).data except Exception as e: return self._generate_error_response(request_id, e) @@ -474,14 +476,20 @@ async def _process_streaming_request( stream: AsyncGenerator | None = None if isinstance(request_obj, SendMessageRequest): - stream = self.http_handler.on_message_send_stream(request_obj, context) + stream = self.http_handler.on_message_send_stream( + request_obj, context + ) elif isinstance(request_obj, SubscribeToTaskRequest): - stream = self.http_handler.on_subscribe_to_task(request_obj, context) + stream = self.http_handler.on_subscribe_to_task( + request_obj, context + ) if stream is None: - raise UnsupportedOperationError(message='Stream not supported') + raise UnsupportedOperationError(message='Stream not supported') - async def _wrap_stream(st: AsyncGenerator) -> AsyncGenerator[dict[str, Any], None]: + async def _wrap_stream( + st: AsyncGenerator, + ) -> AsyncGenerator[dict[str, Any], None]: try: async for event in st: stream_response = proto_utils.to_stream_response(event) @@ -490,7 +498,11 @@ async def _wrap_stream(st: AsyncGenerator) -> AsyncGenerator[dict[str, Any], Non ) yield JSONRPC20Response(result=result, _id=request_id).data except Exception as e: - error = e if isinstance(e, A2AError) else InternalError(message=str(e)) + error = ( + e + if isinstance(e, A2AError) + else InternalError(message=str(e)) + ) yield build_error_response(request_id, error) return _wrap_stream(stream) @@ -522,17 +534,21 @@ async def _process_non_streaming_request( msg_response = SendMessageResponse(message=task_or_message) return MessageToDict(msg_response) case CancelTaskRequest(): - task = await self.http_handler.on_cancel_task(request_obj, context) + task = await self.http_handler.on_cancel_task( + request_obj, context + ) if not task: - raise TaskNotFoundError() + raise TaskNotFoundError return MessageToDict(task, preserving_proto_field_name=False) case GetTaskRequest(): task = await self.http_handler.on_get_task(request_obj, context) if not task: - raise TaskNotFoundError() + raise TaskNotFoundError return MessageToDict(task, preserving_proto_field_name=False) case ListTasksRequest(): - tasks_response = await self.http_handler.on_list_tasks(request_obj, context) + tasks_response = await self.http_handler.on_list_tasks( + request_obj, context + ) return MessageToDict( tasks_response, preserving_proto_field_name=False, @@ -546,7 +562,9 @@ async def _process_non_streaming_request( result_config = await self.http_handler.on_create_task_push_notification_config( request_obj, context ) - return MessageToDict(result_config, preserving_proto_field_name=False) + return MessageToDict( + result_config, preserving_proto_field_name=False + ) case GetTaskPushNotificationConfigRequest(): config = await self.http_handler.on_get_task_push_notification_config( request_obj, context @@ -556,7 +574,9 @@ async def _process_non_streaming_request( list_push_response = await self.http_handler.on_list_task_push_notification_configs( request_obj, context ) - return MessageToDict(list_push_response, preserving_proto_field_name=False) + return MessageToDict( + list_push_response, preserving_proto_field_name=False + ) case DeleteTaskPushNotificationConfigRequest(): await self.http_handler.on_delete_task_push_notification_config( request_obj, context @@ -574,8 +594,12 @@ async def _process_non_streaming_request( self.extended_card_modifier(base_card, context) ) elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - return MessageToDict(card_to_serve, preserving_proto_field_name=False) + card_to_serve = await maybe_await( + self.card_modifier(base_card) + ) + return MessageToDict( + card_to_serve, preserving_proto_field_name=False + ) case _: logger.error( 'Unhandled validated request type: %s', type(request_obj) diff --git a/src/a2a/server/router/jsonrpc_router.py b/src/a2a/server/routes/jsonrpc_route.py similarity index 93% rename from src/a2a/server/router/jsonrpc_router.py rename to src/a2a/server/routes/jsonrpc_route.py index 050a156a..e7be63c1 100644 --- a/src/a2a/server/router/jsonrpc_router.py +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -3,8 +3,12 @@ from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any +from starlette.middleware import Middleware +from starlette.routing import Route + if TYPE_CHECKING: + from starlette.routing import Router _package_starlette_installed = True @@ -21,7 +25,7 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.router.jsonrpc_dispatcher import ( +from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, JsonRpcDispatcher, ) @@ -31,11 +35,10 @@ logger = logging.getLogger(__name__) -from starlette.middleware import Middleware -from starlette.routing import Route, Router -class JsonRpcRouter: + +class JsonRpcRoute: """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. Handles incoming JSON-RPC requests, routes them to the appropriate @@ -59,13 +62,13 @@ def __init__( # noqa: PLR0913 rpc_url: str = '/', middleware: Sequence[Middleware] | None = None, ) -> None: - """Initializes the JsonRpcRouter. - + """Initializes the JsonRpcRoute. + ... (docstrings remain the same) ... """ if not _package_starlette_installed: raise ImportError( - 'The `starlette` package is required to use the `JsonRpcRouter`.' + 'The `starlette` package is required to use the `JsonRpcRoute`.' ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) diff --git a/src/a2a/server/router/rest_router.py b/src/a2a/server/routes/rest_routes.py similarity index 81% rename from src/a2a/server/router/rest_router.py rename to src/a2a/server/routes/rest_routes.py index 2da54b27..5620bd5e 100644 --- a/src/a2a/server/router/rest_router.py +++ b/src/a2a/server/routes/rest_routes.py @@ -38,12 +38,12 @@ from google.protobuf.json_format import MessageToDict, Parse -from a2a.server.router.jsonrpc_dispatcher import ( +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, ) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils @@ -63,8 +63,8 @@ logger = logging.getLogger(__name__) -class RestRouter: - """A FastAPI application implementing the A2A protocol server endpoints. +class RestRoutes: + """Provides the Starlette Routes for the A2A protocol REST endpoints. Handles incoming JSON-REST requests, routes them to the appropriate handler methods, and manages response generation including Server-Sent Events @@ -87,7 +87,7 @@ def __init__( # noqa: PLR0913 rpc_url: str = '', middleware: Sequence['Middleware'] | None = None, ) -> None: - """Initializes the A2AFastAPIApplication. + """Initializes the RestRoutes. Args: agent_card: The AgentCard describing the agent's capabilities. @@ -107,7 +107,7 @@ def __init__( # noqa: PLR0913 """ if not _package_starlette_installed: raise ImportError( - 'The `starlette` package is required to use the `RestRouter`.' + 'The `starlette` package is required to use the `RestRoutes`.' ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) @@ -145,7 +145,7 @@ def _setup_routes( self, rpc_url: str, ) -> None: - """Builds and returns the FastAPI application instance.""" + """Sets up the Starlette routes.""" self.routes = [] if self.enable_v0_3_compat and self._v03_adapter: for route, callback in self._v03_adapter.routes().items(): @@ -157,7 +157,6 @@ def _setup_routes( ) ) - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { ('/message:send', 'POST'): self._message_send, ('/message:stream', 'POST'): self._message_stream, @@ -165,24 +164,37 @@ def _setup_routes( ('/tasks/{id}:subscribe', 'GET'): self._subscribe_task, ('/tasks/{id}:subscribe', 'POST'): self._subscribe_task, ('/tasks/{id}', 'GET'): self._get_task, - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): self._get_push_notification, - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): self._delete_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'POST'): self._set_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'GET'): self._list_push_notifications, + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): self._get_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'DELETE', + ): self._delete_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'POST', + ): self._set_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'GET', + ): self._list_push_notifications, ('/tasks', 'GET'): self._list_tasks, ('/extendedAgentCard', 'GET'): self._get_extended_agent_card, } - self.routes.extend([ - Route( - path=f'{rpc_url}{p}', - endpoint=handler, - methods=[method], - ) - for (path, method), handler in base_routes.items() - for p in (path, f'/{{tenant}}{path}') - ]) - + self.routes.extend( + [ + Route( + path=f'{rpc_url}{p}', + endpoint=handler, + methods=[method], + ) + for (path, method), handler in base_routes.items() + for p in (path, f'/{{tenant}}{path}') + ] + ) @rest_error_handler async def _message_send(self, request: Request) -> Response: @@ -190,7 +202,9 @@ async def _message_send(self, request: Request) -> Response: params = a2a_pb2.SendMessageRequest() Parse(body, params) context = self._build_call_context(request) - task_or_message = await self.request_handler.on_message_send(params, context) + task_or_message = await self.request_handler.on_message_send( + params, context + ) if isinstance(task_or_message, a2a_pb2.Task): response = a2a_pb2.SendMessageResponse(task=task_or_message) else: @@ -207,7 +221,9 @@ async def _message_stream(self, request: Request) -> EventSourceResponse: ) from e if not self.agent_card.capabilities.streaming: - raise UnsupportedOperationError(message='Streaming is not supported by the agent') + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' + ) body = await request.body() params = a2a_pb2.SendMessageRequest() @@ -215,8 +231,12 @@ async def _message_stream(self, request: Request) -> EventSourceResponse: context = self._build_call_context(request) async def event_generator() -> AsyncIterator[str]: - async for event in self.request_handler.on_message_send_stream(params, context): - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + async for event in self.request_handler.on_message_send_stream( + params, context + ): + yield json.dumps( + MessageToDict(proto_utils.to_stream_response(event)) + ) return EventSourceResponse(event_generator()) @@ -240,13 +260,19 @@ async def _subscribe_task(self, request: Request) -> EventSourceResponse: ) from e task_id = request.path_params['id'] if not self.agent_card.capabilities.streaming: - raise UnsupportedOperationError(message='Streaming is not supported by the agent') + raise UnsupportedOperationError( + message='Streaming is not supported by the agent' + ) params = a2a_pb2.SubscribeToTaskRequest(id=task_id) context = self._build_call_context(request) async def event_generator() -> AsyncIterator[str]: - async for event in self.request_handler.on_subscribe_to_task(params, context): - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + async for event in self.request_handler.on_subscribe_to_task( + params, context + ): + yield json.dumps( + MessageToDict(proto_utils.to_stream_response(event)) + ) return EventSourceResponse(event_generator()) @@ -269,7 +295,11 @@ async def _get_push_notification(self, request: Request) -> Response: task_id=task_id, id=push_id ) context = self._build_call_context(request) - config = await self.request_handler.on_get_task_push_notification_config(params, context) + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) return JSONResponse(MessageToDict(config)) @rest_error_handler @@ -280,19 +310,27 @@ async def _delete_push_notification(self, request: Request) -> Response: task_id=task_id, id=push_id ) context = self._build_call_context(request) - await self.request_handler.on_delete_task_push_notification_config(params, context) + await self.request_handler.on_delete_task_push_notification_config( + params, context + ) return JSONResponse({}) @rest_error_handler async def _set_push_notification(self, request: Request) -> Response: if not self.agent_card.capabilities.push_notifications: - raise UnsupportedOperationError(message='Push notifications are not supported by the agent') + raise UnsupportedOperationError( + message='Push notifications are not supported by the agent' + ) body = await request.body() params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) params.task_id = request.path_params['id'] context = self._build_call_context(request) - config = await self.request_handler.on_create_task_push_notification_config(params, context) + config = ( + await self.request_handler.on_create_task_push_notification_config( + params, context + ) + ) return JSONResponse(MessageToDict(config)) @rest_error_handler @@ -301,7 +339,11 @@ async def _list_push_notifications(self, request: Request) -> Response: proto_utils.parse_params(request.query_params, params) params.task_id = request.path_params['id'] context = self._build_call_context(request) - result = await self.request_handler.on_list_task_push_notification_configs(params, context) + result = ( + await self.request_handler.on_list_task_push_notification_configs( + params, context + ) + ) return JSONResponse(MessageToDict(result)) @rest_error_handler @@ -310,7 +352,9 @@ async def _list_tasks(self, request: Request) -> Response: proto_utils.parse_params(request.query_params, params) context = self._build_call_context(request) result = await self.request_handler.on_list_tasks(params, context) - return JSONResponse(MessageToDict(result, always_print_fields_with_no_presence=True)) + return JSONResponse( + MessageToDict(result, always_print_fields_with_no_presence=True) + ) @rest_error_handler async def _get_extended_agent_card(self, request: Request) -> Response: @@ -326,13 +370,8 @@ async def _get_extended_agent_card(self, request: Request) -> Response: self.extended_card_modifier(card_to_serve, context) ) elif self.card_modifier: - card_to_serve = await maybe_await( - self.card_modifier(card_to_serve) - ) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return JSONResponse( - MessageToDict( - card_to_serve, preserving_proto_field_name=True - ) + MessageToDict(card_to_serve, preserving_proto_field_name=True) ) - diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 2102fa75..8612e47f 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -21,7 +21,7 @@ DefaultRequestHandler, ) from a2a.server.request_handlers.grpc_handler import GrpcHandler -from a2a.server.router import AgentCardRouter, JsonRpcRouter, RestRouter +from a2a.server.routes import AgentCardRoute, JsonRpcRoute, RestRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ( @@ -197,14 +197,14 @@ def serve(task_store: TaskStore) -> None: main_app = Starlette() # Agent Card - agent_card_router = AgentCardRouter( + agent_card_router = AgentCardRoute( agent_card=agent_card, card_url=AGENT_CARD_URL, ) main_app.routes.append(agent_card_router.route) # JSONRPC - jsonrpc_router = JsonRpcRouter( + jsonrpc_router = JsonRpcRoute( agent_card=agent_card, request_handler=request_handler, rpc_url=JSONRPC_URL, @@ -212,14 +212,13 @@ def serve(task_store: TaskStore) -> None: main_app.routes.append(jsonrpc_router.route) # REST - rest_router = RestRouter( + rest_router = RestRoutes( agent_card=agent_card, request_handler=request_handler, rpc_url=REST_URL, ) main_app.routes.extend(rest_router.routes) - config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' ) diff --git a/tests/server/apps/rest/__init__.py b/tests/server/apps/rest/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/server/routes/test_agent_card_route.py b/tests/server/routes/test_agent_card_route.py new file mode 100644 index 00000000..f117d37e --- /dev/null +++ b/tests/server/routes/test_agent_card_route.py @@ -0,0 +1,55 @@ +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from a2a.server.routes import AgentCardRoute +from a2a.types import a2a_pb2 +from a2a.server.request_handlers.response_helpers import agent_card_to_dict + + +@pytest.fixture +def mock_agent_card(): + return a2a_pb2.AgentCard( + name='test-agent', + version='1.0.0', + documentation_url='http://localhost:8000', + ) + + +@pytest.fixture +def test_app(mock_agent_card): + app = Starlette() + card_route = AgentCardRoute(mock_agent_card) + app.routes.append(card_route.route) + return app + + +@pytest.fixture +def client(test_app): + return TestClient(test_app) + + +def test_agent_card_route_returns_json(client, mock_agent_card): + response = client.get('/') + assert response.status_code == 200 + + # The route returns JSON, not protobuf SerializeToString() + actual_json = response.json() + expected_json = agent_card_to_dict(mock_agent_card) + + assert actual_json == expected_json + + +def test_agent_card_route_with_modifier(mock_agent_card): + async def modifier(card): + card.name = 'modified-agent' + return card + + card_route = AgentCardRoute(mock_agent_card, card_modifier=modifier) + app = Starlette() + app.routes.append(card_route.route) + client = TestClient(app) + + response = client.get('/') + assert response.status_code == 200 + assert response.json()['name'] == 'modified-agent' diff --git a/tests/server/router/test_jsonrpc_router.py b/tests/server/routes/test_jsonrpc_route.py similarity index 93% rename from tests/server/router/test_jsonrpc_router.py rename to tests/server/routes/test_jsonrpc_route.py index 850b5193..226265ba 100644 --- a/tests/server/router/test_jsonrpc_router.py +++ b/tests/server/routes/test_jsonrpc_route.py @@ -7,7 +7,8 @@ from starlette.testclient import TestClient from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.router import JsonRpcRouter, StarletteUserProxy +from a2a.server.routes import JsonRpcRoute, StarletteUserProxy + from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( @@ -35,10 +36,11 @@ def test_app(mock_handler): mock_agent_card.url = 'http://mockurl.com' mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - + from starlette.applications import Starlette + app = Starlette() - router = JsonRpcRouter(mock_agent_card, mock_handler) + router = JsonRpcRoute(mock_agent_card, mock_handler) app.routes.append(router.route) return app @@ -207,8 +209,11 @@ def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): mock_agent_card.capabilities.streaming = False from starlette.applications import Starlette + app = Starlette() - router = JsonRpcRouter(mock_agent_card, mock_handler, enable_v0_3_compat=True) + router = JsonRpcRoute( + mock_agent_card, mock_handler, enable_v0_3_compat=True + ) app.routes.append(router.route) client = TestClient(app) @@ -228,9 +233,15 @@ def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): # Instead of _v03_adapter, the handler handles it or it's dispatcher with patch.object( - router.dispatcher, '_process_non_streaming_request', new_callable=AsyncMock + router.dispatcher, + '_process_non_streaming_request', + new_callable=AsyncMock, ) as mock_handle: - mock_handle.return_value = {'jsonrpc': '2.0', 'id': '1', 'result': {}} + mock_handle.return_value = { + 'jsonrpc': '2.0', + 'id': '1', + 'result': {}, + } response = client.post('/', json=request_data) @@ -244,8 +255,11 @@ def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): mock_agent_card.capabilities.streaming = False from starlette.applications import Starlette + app = Starlette() - router = JsonRpcRouter(mock_agent_card, mock_handler, enable_v0_3_compat=False) + router = JsonRpcRoute( + mock_agent_card, mock_handler, enable_v0_3_compat=False + ) app.routes.append(router.route) client = TestClient(app) diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/routes/test_rest_routes.py similarity index 97% rename from tests/server/apps/rest/test_rest_fastapi_app.py rename to tests/server/routes/test_rest_routes.py index 131d9f11..1a5729bd 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/routes/test_rest_routes.py @@ -10,7 +10,7 @@ from google.protobuf import json_format from httpx import ASGITransport, AsyncClient -from a2a.server.router.rest_router import RestRouter +from a2a.server.routes import RestRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( @@ -74,11 +74,10 @@ async def streaming_app( streaming_agent_card: AgentCard, request_handler: RequestHandler ) -> Any: from starlette.applications import Starlette - router = RestRouter( - streaming_agent_card, request_handler, rpc_url='' - ) + + router = RestRoutes(streaming_agent_card, request_handler, rpc_url='') app = Starlette() - app.mount('', router.router) + app.routes.extend(router.routes) return app @@ -97,22 +96,23 @@ async def app( extended_card_modifier: MagicMock | None, ) -> Any: from starlette.applications import Starlette - from a2a.server.router.agent_card_router import AgentCardRouter + from a2a.server.routes import AgentCardRoute + # Return Starlette app app_instance = Starlette() - - rest_router = RestRouter( + + rest_router = RestRoutes( agent_card, request_handler, extended_card_modifier=extended_card_modifier, - rpc_url='' + rpc_url='', ) - app_instance.mount('', rest_router.router) - + app_instance.routes.extend(rest_router.routes) + # Also Agent card endpoint? if needed in tests - card_router = AgentCardRouter(agent_card, card_url='/well-known/agent.json') + card_router = AgentCardRoute(agent_card, card_url='/well-known/agent.json') app_instance.routes.append(card_router.route) - + return app_instance @@ -130,11 +130,11 @@ async def client(app: FastAPI) -> AsyncClient: async def test_create_rest_router_with_v0_3_compat( agent_card: AgentCard, request_handler: RequestHandler ): - router = RestRouter( + router = RestRoutes( agent_card, request_handler, enable_v0_3_compat=True, rpc_url='' ) # Check if a route from v0.3 compatibility is present - routes = [getattr(route, 'path', '') for route in router.router.routes] + routes = [getattr(route, 'path', '') for route in router.routes] assert '/v1/message:send' in routes diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 0e30fe30..02c5c8b2 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,8 +18,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.router.agent_card_router import AgentCardRouter -from a2a.server.router.jsonrpc_router import JsonRpcRouter +from a2a.server.routes import AgentCardRoute, JsonRpcRoute + from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( InternalError, @@ -152,18 +152,31 @@ def __init__(self, agent_card, handler, card_modifier=None): self.handler = handler self.card_modifier = card_modifier - def build(self, rpc_url='/', agent_card_url=AGENT_CARD_WELL_KNOWN_PATH, middleware=None, routes=None): + def build( + self, + rpc_url='/', + agent_card_url=AGENT_CARD_WELL_KNOWN_PATH, + middleware=None, + routes=None, + ): from starlette.applications import Starlette + app_instance = Starlette(middleware=middleware, routes=routes or []) - + # Agent card router - card_router = AgentCardRouter(self.agent_card, card_url=agent_card_url, card_modifier=self.card_modifier) + card_router = AgentCardRoute( + self.agent_card, + card_url=agent_card_url, + card_modifier=self.card_modifier, + ) app_instance.routes.append(card_router.route) - + # JSON-RPC router - rpc_router = JsonRpcRouter(self.agent_card, self.handler, rpc_url=rpc_url) + rpc_router = JsonRpcRoute( + self.agent_card, self.handler, rpc_url=rpc_url + ) app_instance.routes.append(rpc_router.route) - + return app_instance @@ -191,9 +204,7 @@ def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): assert 'streaming' in data['capabilities'] -def test_agent_card_custom_url( - app, agent_card: AgentCard -): +def test_agent_card_custom_url(app, agent_card: AgentCard): """Test the agent card endpoint with a custom URL.""" client = TestClient(app.build(agent_card_url='/my-agent')) response = client.get('/my-agent') @@ -202,9 +213,7 @@ def test_agent_card_custom_url( assert data['name'] == agent_card.name -def test_starlette_rpc_endpoint_custom_url( - app, handler: mock.AsyncMock -): +def test_starlette_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -225,9 +234,7 @@ def test_starlette_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_fastapi_rpc_endpoint_custom_url( - app, handler: mock.AsyncMock -): +def test_fastapi_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -248,9 +255,7 @@ def test_fastapi_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_starlette_build_with_extra_routes( - app, agent_card: AgentCard -): +def test_starlette_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -272,9 +277,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_with_extra_routes( - app, agent_card: AgentCard -): +def test_fastapi_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -296,9 +299,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_custom_agent_card_path( - app, agent_card: AgentCard -): +def test_fastapi_build_custom_agent_card_path(app, agent_card: AgentCard): """Test building the app with a custom agent card path.""" test_app = app.build(agent_card_url='/agent-card') @@ -548,9 +549,7 @@ async def authenticate( @pytest.mark.asyncio -async def test_message_send_stream( - app, handler: mock.AsyncMock -) -> None: +async def test_message_send_stream(app, handler: mock.AsyncMock) -> None: """Test streaming message sending.""" # Setup mock streaming response @@ -624,9 +623,7 @@ async def stream_generator(): @pytest.mark.asyncio -async def test_task_resubscription( - app, handler: mock.AsyncMock -) -> None: +async def test_task_resubscription(app, handler: mock.AsyncMock) -> None: """Test task resubscription streaming.""" # Setup mock streaming response @@ -757,9 +754,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = AppBuilder( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -782,9 +777,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = AppBuilder( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -807,9 +800,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = AppBuilder( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -829,9 +820,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = AppBuilder( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) From 6a23a42db9863eb0059431bb113cd3838af490a3 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 10:08:38 +0000 Subject: [PATCH 06/17] wip --- src/a2a/compat/v0_3/jsonrpc_adapter.py | 2 +- src/a2a/compat/v0_3/rest_adapter.py | 5 ++- src/a2a/server/routes/agent_card_route.py | 14 +++++++++ src/a2a/server/routes/jsonrpc_dispatcher.py | 19 +++++------ src/a2a/server/routes/jsonrpc_route.py | 35 ++++++++++++++------- src/a2a/server/routes/rest_routes.py | 17 +++++----- 6 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index cdb701b5..f9483191 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from starlette.requests import Request - from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder + from a2a.server.routes.jsonrpc_dispatcher import CallContextBuilder from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import AgentCard diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index b0296e40..d5ad567e 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -33,11 +33,10 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.apps.jsonrpc.jsonrpc_app import ( +from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, ) -from a2a.server.apps.rest.rest_adapter import RESTAdapterInterface from a2a.server.context import ServerCallContext from a2a.utils.error_handlers import ( rest_error_handler, @@ -53,7 +52,7 @@ logger = logging.getLogger(__name__) -class REST03Adapter(RESTAdapterInterface): +class REST03Adapter: """Adapter to make RequestHandler work with v0.3 RESTful API. Defines v0.3 REST request processors and their routes, as well as managing response generation including Server-Sent Events (SSE). diff --git a/src/a2a/server/routes/agent_card_route.py b/src/a2a/server/routes/agent_card_route.py index feee2937..b5481bff 100644 --- a/src/a2a/server/routes/agent_card_route.py +++ b/src/a2a/server/routes/agent_card_route.py @@ -9,6 +9,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route + + _package_starlette_installed = True else: try: from starlette.middleware import Middleware @@ -16,6 +18,7 @@ from starlette.responses import JSONResponse, Response from starlette.routing import Route + _package_starlette_installed = True except ImportError: Middleware = Any Route = Any @@ -23,6 +26,8 @@ Response = Any JSONResponse = Any + _package_starlette_installed = False + from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import AgentCard from a2a.utils.helpers import maybe_await @@ -49,7 +54,16 @@ def __init__( card_modifier: An optional callback to dynamically modify the public agent card before it is served. card_url: The URL for the agent card endpoint. + middleware: An optional list of Starlette middleware to apply to the + agent card endpoint. """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `JsonRpcRoute`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + self.agent_card = agent_card self.card_modifier = card_modifier diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index ff2806c8..eda82fa5 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -311,7 +311,7 @@ def _allowed_content_length(self, request: Request) -> bool: return False return True - async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 + async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 """Handles incoming POST requests to the main A2A endpoint. Parses the request body as JSON, validates it against A2A request types, @@ -417,7 +417,7 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 call_context.state['method'] = method call_context.state['request_id'] = request_id - # Route streaming requests by method name + handler_result: AsyncGenerator[dict[str, Any], None] | dict[str, Any] # Route streaming requests by method name if method in ('SendStreamingMessage', 'SubscribeToTask'): handler_result = await self._process_streaming_request( @@ -431,8 +431,8 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 handler_result = JSONRPC20Response( result=raw_result, _id=request_id ).data - except Exception as e: - return self._generate_error_response(request_id, e) + except A2AError as e: + handler_result = build_error_response(request_id, e) return self._create_response(call_context, handler_result) except json.decoder.JSONDecodeError as e: @@ -497,17 +497,12 @@ async def _wrap_stream( stream_response, preserving_proto_field_name=False ) yield JSONRPC20Response(result=result, _id=request_id).data - except Exception as e: - error = ( - e - if isinstance(e, A2AError) - else InternalError(message=str(e)) - ) - yield build_error_response(request_id, error) + except A2AError as e: + yield build_error_response(request_id, e) return _wrap_stream(stream) - async def _process_non_streaming_request( + async def _process_non_streaming_request( # noqa: PLR0911, PLR0912 self, request_id: str | int | None, request_obj: A2ARequest, diff --git a/src/a2a/server/routes/jsonrpc_route.py b/src/a2a/server/routes/jsonrpc_route.py index e7be63c1..536d6315 100644 --- a/src/a2a/server/routes/jsonrpc_route.py +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -3,21 +3,21 @@ from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any -from starlette.middleware import Middleware -from starlette.routing import Route - if TYPE_CHECKING: - - from starlette.routing import Router + from starlette.middleware import Middleware + from starlette.routing import Route, Router _package_starlette_installed = True else: try: - from starlette.routing import Router + from starlette.middleware import Middleware + from starlette.routing import Route, Router _package_starlette_installed = True except ImportError: + Middleware = Any + Route = Any Router = Any _package_starlette_installed = False @@ -35,9 +35,6 @@ logger = logging.getLogger(__name__) - - - class JsonRpcRoute: """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. @@ -64,7 +61,23 @@ def __init__( # noqa: PLR0913 ) -> None: """Initializes the JsonRpcRoute. - ... (docstrings remain the same) ... + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + middleware: An optional list of Starlette middleware to apply to the routes. """ if not _package_starlette_installed: raise ImportError( @@ -85,7 +98,7 @@ def __init__( # noqa: PLR0913 self.route = Route( path=rpc_url, - endpoint=self.dispatcher._handle_requests, + endpoint=self.dispatcher.handle_requests, methods=['POST'], middleware=middleware, ) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 5620bd5e..01b571b8 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -3,6 +3,10 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any +from a2a.compat.v0_3.rest_adapter import ( + REST03Adapter as V03RESTAdapter, +) + if TYPE_CHECKING: from sse_starlette.sse import EventSourceResponse @@ -91,12 +95,12 @@ def __init__( # noqa: PLR0913 Args: agent_card: The AgentCard describing the agent's capabilities. - httpr: The handler instance responsible for processing A2A + request_handler: The handler instance responsible for processing A2A requests via http. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the httpr. If None, no + ServerCallContext passed to the request_handler. If None, no ServerCallContext is passed. card_modifier: An optional callback to dynamically modify the public agent card before it is served. @@ -104,6 +108,8 @@ def __init__( # noqa: PLR0913 the extended agent card before it is served. It receives the call context. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + middleware: An optional list of Starlette middleware to apply to the routes. """ if not _package_starlette_installed: raise ImportError( @@ -122,17 +128,14 @@ def __init__( # noqa: PLR0913 self._v03_adapter = None if enable_v0_3_compat: - from a2a.compat.v0_3.rest_adapter import ( - REST03Adapter as V03RESTAdapter, - ) - self._v03_adapter = V03RESTAdapter( agent_card=agent_card, - httpr=request_handler, + http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, ) + self._setup_routes(rpc_url) def _build_call_context(self, request: Request) -> ServerCallContext: From 5ba961e6a077cd55c52174a3ace443e2100b6de0 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 11:09:52 +0000 Subject: [PATCH 07/17] wip --- src/a2a/compat/v0_3/jsonrpc_adapter.py | 2 +- src/a2a/compat/v0_3/rest_adapter.py | 2 +- src/a2a/server/routes/__init__.py | 4 +- src/a2a/server/routes/agent_card_route.py | 16 +- src/a2a/server/routes/jsonrpc_dispatcher.py | 4 +- src/a2a/server/routes/jsonrpc_route.py | 16 +- src/a2a/server/routes/rest_routes.py | 1 - tck/sut_agent.py | 22 +- tests/compat/v0_3/test_jsonrpc_app_compat.py | 11 +- .../v0_3/test_rest_fastapi_app_compat.py | 19 +- tests/e2e/push_notifications/agent_app.py | 30 +- tests/integration/test_agent_card.py | 28 +- .../test_client_server_integration.py | 73 +- tests/integration/test_end_to_end.py | 22 +- tests/integration/test_tenant.py | 12 +- .../request_handlers/test_jsonrpc_handler.py | 1453 ----------------- tests/server/routes/test_agent_card_route.py | 6 +- tests/server/routes/test_jsonrpc_route.py | 14 +- tests/server/routes/test_rest_routes.py | 6 +- tests/server/test_integration.py | 10 +- 20 files changed, 177 insertions(+), 1574 deletions(-) delete mode 100644 tests/server/request_handlers/test_jsonrpc_handler.py diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index f9483191..b62c6f66 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: from starlette.requests import Request - from a2a.server.routes.jsonrpc_dispatcher import CallContextBuilder from a2a.server.request_handlers.request_handler import RequestHandler + from a2a.server.routes.jsonrpc_dispatcher import CallContextBuilder from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index d5ad567e..91186059 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -33,11 +33,11 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler +from a2a.server.context import ServerCallContext from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, ) -from a2a.server.context import ServerCallContext from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index d9d5ee22..2d8fc39c 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -1,12 +1,12 @@ """A2A Server Routes.""" -from a2a.server.routes.agent_card_route import AgentCardRoute +from a2a.server.routes.agent_card_route import AgentCardRoutes from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, StarletteUserProxy, ) -from a2a.server.routes.jsonrpc_route import JsonRpcRoute +from a2a.server.routes.jsonrpc_route import JsonRpcRoutes from a2a.server.routes.rest_routes import RestRoutes diff --git a/src/a2a/server/routes/agent_card_route.py b/src/a2a/server/routes/agent_card_route.py index b5481bff..782d0712 100644 --- a/src/a2a/server/routes/agent_card_route.py +++ b/src/a2a/server/routes/agent_card_route.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) -class AgentCardRoute: +class AgentCardRoutes: """Provides the Starlette Route for the A2A protocol agent card endpoint.""" def __init__( @@ -75,9 +75,11 @@ async def get_agent_card(request: Request) -> Response: ) return JSONResponse(agent_card_to_dict(card_to_serve)) - self.route = Route( - path=card_url, - endpoint=get_agent_card, - methods=['GET'], - middleware=middleware, - ) + self.routes = [ + Route( + path=card_url, + endpoint=get_agent_card, + methods=['GET'], + middleware=middleware, + ) + ] diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index eda82fa5..4bab6855 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -417,7 +417,9 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, call_context.state['method'] = method call_context.state['request_id'] = request_id - handler_result: AsyncGenerator[dict[str, Any], None] | dict[str, Any] + handler_result: ( + AsyncGenerator[dict[str, Any], None] | dict[str, Any] + ) # Route streaming requests by method name if method in ('SendStreamingMessage', 'SubscribeToTask'): handler_result = await self._process_streaming_request( diff --git a/src/a2a/server/routes/jsonrpc_route.py b/src/a2a/server/routes/jsonrpc_route.py index 536d6315..82a78774 100644 --- a/src/a2a/server/routes/jsonrpc_route.py +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -class JsonRpcRoute: +class JsonRpcRoutes: """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. Handles incoming JSON-RPC requests, routes them to the appropriate @@ -96,9 +96,11 @@ def __init__( # noqa: PLR0913 enable_v0_3_compat=enable_v0_3_compat, ) - self.route = Route( - path=rpc_url, - endpoint=self.dispatcher.handle_requests, - methods=['POST'], - middleware=middleware, - ) + self.routes = [ + Route( + path=rpc_url, + endpoint=self.dispatcher.handle_requests, + methods=['POST'], + middleware=middleware, + ) + ] diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 01b571b8..117465a7 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -135,7 +135,6 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, ) - self._setup_routes(rpc_url) def _build_call_context(self, request: Request) -> ServerCallContext: diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 8612e47f..f2cc03b3 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -21,7 +21,7 @@ DefaultRequestHandler, ) from a2a.server.request_handlers.grpc_handler import GrpcHandler -from a2a.server.routes import AgentCardRoute, JsonRpcRoute, RestRoutes +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ( @@ -194,30 +194,30 @@ def serve(task_store: TaskStore) -> None: task_store=task_store, ) - main_app = Starlette() - # Agent Card - agent_card_router = AgentCardRoute( + agent_card_routes = AgentCardRoutes( agent_card=agent_card, card_url=AGENT_CARD_URL, ) - main_app.routes.append(agent_card_router.route) - # JSONRPC - jsonrpc_router = JsonRpcRoute( + jsonrpc_routes = JsonRpcRoutes( agent_card=agent_card, request_handler=request_handler, rpc_url=JSONRPC_URL, ) - main_app.routes.append(jsonrpc_router.route) - # REST - rest_router = RestRoutes( + rest_routes = RestRoutes( agent_card=agent_card, request_handler=request_handler, rpc_url=REST_URL, ) - main_app.routes.extend(rest_router.routes) + + routes = [ + *agent_card_routes.routes, + *jsonrpc_routes.routes, + *rest_routes.routes, + ] + main_app = Starlette(routes=routes) config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 4f09bb23..4b344c67 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -6,7 +6,8 @@ import pytest from starlette.testclient import TestClient -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from starlette.applications import Starlette +from a2a.server.routes import JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,16 +51,18 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False mock_agent_card.capabilities.push_notifications = True mock_agent_card.capabilities.extended_agent_card = True - return A2AStarletteApplication( + router = JsonRpcRoutes( agent_card=mock_agent_card, - http_handler=mock_handler, + request_handler=mock_handler, enable_v0_3_compat=True, + rpc_url='/', ) + return Starlette(routes=router.routes) @pytest.fixture def client(test_app): - return TestClient(test_app.build()) + return TestClient(test_app) def test_send_message_v03_compat( diff --git a/tests/compat/v0_3/test_rest_fastapi_app_compat.py b/tests/compat/v0_3/test_rest_fastapi_app_compat.py index 8625b7e0..2b2ca22d 100644 --- a/tests/compat/v0_3/test_rest_fastapi_app_compat.py +++ b/tests/compat/v0_3/test_rest_fastapi_app_compat.py @@ -9,7 +9,8 @@ from google.protobuf import json_format from httpx import ASGITransport, AsyncClient -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import RestRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,17 +51,19 @@ async def request_handler() -> RequestHandler: async def app( agent_card: AgentCard, request_handler: RequestHandler, -) -> FastAPI: - """Builds the FastAPI application for testing.""" - return A2ARESTFastAPIApplication( - agent_card, - request_handler, +) -> Starlette: + """Builds the Starlette application for testing.""" + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=request_handler, enable_v0_3_compat=True, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') + rpc_url='', + ) + return Starlette(routes=rest_routes.routes) @pytest.fixture -async def client(app: FastAPI) -> AsyncClient: +async def client(app: Starlette) -> AsyncClient: return AsyncClient( transport=ASGITransport(app=app), base_url='http://testapp' ) diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index ca1a234b..ec2456d6 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -3,7 +3,8 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import RestRoutes from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler @@ -139,17 +140,20 @@ def create_agent_app( ) -> FastAPI: """Creates a new HTTP+REST FastAPI application for the test agent.""" push_config_store = InMemoryPushNotificationConfigStore() - app = A2ARESTFastAPIApplication( - agent_card=test_agent_card(url), - http_handler=DefaultRequestHandler( - agent_executor=TestAgentExecutor(), - task_store=InMemoryTaskStore(), - push_config_store=push_config_store, - push_sender=BasePushNotificationSender( - httpx_client=notification_client, - config_store=push_config_store, - context=ServerCallContext(), - ), + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + context=ServerCallContext(), ), ) - return app.build() + rest_routes = RestRoutes( + agent_card=test_agent_card(url), + request_handler=handler, + extended_agent_card=test_agent_card(url), + rpc_url='', + ) + return FastAPI(routes=rest_routes.routes) diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 0af06ad7..710399e4 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -4,7 +4,8 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -69,16 +70,27 @@ async def test_agent_card_integration() -> None: app = FastAPI() # Mount JSONRPC application - # In JSONRPCApplication, the default agent_card_url is AGENT_CARD_WELL_KNOWN_PATH - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + jsonrpc_routes = [ + *AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ).routes, + *JsonRpcRoutes( + agent_card=agent_card, request_handler=handler, rpc_url='/' + ).routes, + ] + jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) # Mount REST application - rest_app = A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + rest_routes = [ + *AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ).routes, + *RestRoutes( + agent_card=agent_card, request_handler=handler, rpc_url='' + ).routes, + ] + rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) expected_content = { diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e239d780..2d2d3ac9 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -23,7 +23,8 @@ with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -220,10 +221,14 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2AFastAPIApplication( - agent_card, mock_request_handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -239,10 +244,13 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2ARESTFastAPIApplication( - agent_card, mock_request_handler, extended_agent_card=agent_card + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='', ) - app = app_builder.build() + app = Starlette(routes=rest_routes.routes) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -619,12 +627,16 @@ async def test_json_transport_get_signed_base_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, - card_modifier=signer, # Sign the base card + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer ) - app = app_builder.build() + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -639,7 +651,8 @@ async def test_json_transport_get_signed_base_card( # Verification happens here result = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # Create transport with the verified card @@ -684,15 +697,15 @@ async def test_client_get_signed_extended_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = JsonRpcTransport( @@ -753,16 +766,17 @@ async def test_client_get_signed_base_and_extended_cards( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer + ) + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - card_modifier=signer, # Sign the base card - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -777,7 +791,8 @@ async def test_client_get_signed_base_and_extended_cards( # 1. Fetch base card base_card = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # 2. Create transport with base card diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf..a2735e9e 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -10,7 +10,8 @@ from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -171,8 +172,13 @@ def base_e2e_setup(): @pytest.fixture def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2ARESTFastAPIApplication(agent_card, handler) - app = app_builder.build() + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='', + ) + app = Starlette(routes=rest_routes.routes) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) @@ -192,10 +198,14 @@ def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: @pytest.fixture def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2AFastAPIApplication( - agent_card, handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 903b90a2..10f92461 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -19,7 +19,8 @@ from a2a.client import ClientConfig, ClientFactory from a2a.utils.constants import TransportProtocol -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from starlette.applications import Starlette +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.context import ServerCallContext @@ -197,10 +198,13 @@ def jsonrpc_agent_card(self): @pytest.fixture def server_app(self, jsonrpc_agent_card, mock_handler): - app = A2AStarletteApplication( + jsonrpc_routes = JsonRpcRoutes( agent_card=jsonrpc_agent_card, - http_handler=mock_handler, - ).build(rpc_url='/jsonrpc') + request_handler=mock_handler, + extended_agent_card=jsonrpc_agent_card, + rpc_url='/jsonrpc', + ) + app = Starlette(routes=jsonrpc_routes.routes) return app @pytest.mark.asyncio diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py deleted file mode 100644 index cbdf6b5e..00000000 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ /dev/null @@ -1,1453 +0,0 @@ -import asyncio -import unittest -import unittest.async_case - -from collections.abc import AsyncGenerator -from typing import Any, NoReturn -from unittest.mock import ANY, AsyncMock, MagicMock, call, patch - -import httpx -import pytest - -from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.agent_execution.request_context_builder import ( - RequestContextBuilder, -) -from a2a.server.context import ServerCallContext -from a2a.server.events import QueueManager -from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import DefaultRequestHandler, JSONRPCHandler -from a2a.server.tasks import ( - BasePushNotificationSender, - InMemoryPushNotificationConfigStore, - PushNotificationConfigStore, - PushNotificationSender, - TaskStore, -) -from a2a.types import ( - InternalError, - TaskNotFoundError, - UnsupportedOperationError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentCard, - AgentInterface, - Artifact, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTaskPushNotificationConfigsResponse, - ListTasksResponse, - Message, - Part, - TaskPushNotificationConfig, - Role, - SendMessageConfiguration, - SendMessageRequest, - TaskPushNotificationConfig, - SubscribeToTaskRequest, - Task, - TaskArtifactUpdateEvent, - TaskPushNotificationConfig, - TaskState, - TaskStatus, - TaskStatusUpdateEvent, -) - - -# Helper function to create a minimal Task proto -def create_task( - task_id: str = 'task_123', context_id: str = 'session-xyz' -) -> Task: - return Task( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - - -# Helper function to create a Message proto -def create_message( - message_id: str = '111', - role: Role = Role.ROLE_AGENT, - text: str = 'test message', - task_id: str | None = None, - context_id: str | None = None, -) -> Message: - msg = Message( - message_id=message_id, - role=role, - parts=[Part(text=text)], - ) - if task_id: - msg.task_id = task_id - if context_id: - msg.context_id = context_id - return msg - - -# Helper functions for checking JSON-RPC response structure -def is_success_response(response: dict[str, Any]) -> bool: - """Check if response is a successful JSON-RPC response.""" - return 'result' in response and 'error' not in response - - -def is_error_response(response: dict[str, Any]) -> bool: - """Check if response is an error JSON-RPC response.""" - return 'error' in response - - -def get_error_code(response: dict[str, Any]) -> int | None: - """Get error code from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('code') - return None - - -def get_error_message(response: dict[str, Any]) -> str | None: - """Get error message from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('message') - return None - - -class TestJSONRPCtHandler(unittest.async_case.IsolatedAsyncioTestCase): - @pytest.fixture(autouse=True) - def init_fixtures(self) -> None: - self.mock_agent_card = MagicMock( - spec=AgentCard, - ) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.extended_agent_card = True - self.mock_agent_card.capabilities.streaming = True - self.mock_agent_card.capabilities.push_notifications = True - - # Mock supported_interfaces list - interface = MagicMock(spec=AgentInterface) - interface.url = 'http://agent.example.com/api' - self.mock_agent_card.supported_interfaces = [interface] - - async def test_on_get_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - request = GetTaskRequest(id=f'{task_id}') - response = await handler.on_get_task(request, call_context) - # Response is now a dict with 'result' key for success - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - assert response['result']['id'] == task_id - mock_task_store.get.assert_called_once_with(f'{task_id}', ANY) - - async def test_on_get_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = GetTaskRequest(id='nonexistent_id') - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - response = await handler.on_get_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - - async def test_on_list_tasks_success(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task1 = create_task() - task2 = create_task() - task2.id = 'task_456' - mock_result = ListTasksResponse( - next_page_token='123', - tasks=[task1, task2], - ) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest( - page_size=10, - page_token='token', - ) - call_context = ServerCallContext(state={'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 2) - self.assertEqual(response['result']['nextPageToken'], '123') - - async def test_on_list_tasks_error(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - request_handler.on_list_tasks.side_effect = InternalError( - message='DB down' - ) - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = ServerCallContext(state={'request_id': '2'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['message'], 'DB down') - - async def test_on_list_tasks_empty(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_result = ListTasksResponse(page_size=10) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = ServerCallContext(state={'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 0) - self.assertIn('nextPageToken', response['result']) - self.assertEqual(response['result']['nextPageToken'], '') - self.assertIn('pageSize', response['result']) - self.assertEqual(response['result']['pageSize'], 10) - self.assertIn('totalSize', response['result']) - self.assertEqual(response['result']['totalSize'], 0) - - async def test_on_cancel_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - - async def streaming_coro(): - mock_task.status.state = TaskState.TASK_STATE_CANCELED - yield mock_task - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - assert response['result']['id'] == task_id # type: ignore - assert ( - response['result']['status']['state'] == 'TASK_STATE_CANCELED' - ) # type: ignore - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_supported(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': '1'} - ) - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = CancelTaskRequest(id='nonexistent_id') - call_context = ServerCallContext(state={'request_id': '1'}) - response = await handler.on_cancel_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - mock_task_store.get.assert_called_once_with('nonexistent_id', ANY) - mock_agent_executor.cancel.assert_not_called() - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_new_message_with_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, context_id=mock_task.context_id - ), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Allow the background event loop to start the execution_task - import asyncio - - await asyncio.sleep(0) - - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - - async def test_on_message_stream_new_message_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - assert mock_task.history is not None and len(mock_task.history) == 1 - - async def test_set_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_push_notification_store = AsyncMock( - spec=PushNotificationConfigStore - ) - - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=mock_push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - context = ServerCallContext() - response = await handler.set_push_notification_config(request, context) - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, request, context - ) - - async def test_get_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - push_notification_store = InMemoryPushNotificationConfigStore() - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - push_config = TaskPushNotificationConfig( - id='default', url='http://example.com' - ) - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - id='default', - ) - await handler.set_push_notification_config(request, ServerCallContext()) - - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - get_response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - self.assertIsInstance(get_response, dict) - self.assertTrue(is_success_response(get_response)) - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_send_push_notification_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) - push_notification_store = InMemoryPushNotificationConfigStore() - push_notification_sender = BasePushNotificationSender( - mock_httpx_client, push_notification_store, ServerCallContext() - ) - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - push_sender=push_notification_sender, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - mock_httpx_client.post.return_value = httpx.Response(200) - request = SendMessageRequest( - message=create_message(), - configuration=SendMessageConfiguration( - accepted_output_modes=['text'], - task_push_notification_config=TaskPushNotificationConfig( - url='http://example.com' - ), - ), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - - async def test_on_resubscribe_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store, mock_queue_manager - ) - self.mock_agent_card = MagicMock(spec=AgentCard) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.streaming = True - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_queue_manager.tap.return_value = EventQueue() - request = SubscribeToTaskRequest(id=f'{mock_task.id}') - response = handler.on_subscribe_to_task( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert ( - len(collected_events) == len(events) + 1 - ) # First event is task itself - assert mock_task.history is not None and len(mock_task.history) == 0 - - async def test_on_subscribe_no_existing_task_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = SubscribeToTaskRequest(id='nonexistent_id') - response = handler.on_subscribe_to_task(request, ServerCallContext()) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - assert collected_events[0]['error']['code'] == -32001 - - async def test_streaming_not_supported_error( - self, - ) -> None: - """Test that on_message_send_stream raises an error when streaming not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with streaming capability disabled - self.mock_agent_card.capabilities = AgentCapabilities(streaming=False) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = SendMessageRequest( - message=create_message(), - ) - - # Should raise UnsupportedOperationError about streaming not supported - with self.assertRaises(UnsupportedOperationError) as context: - async for _ in handler.on_message_send_stream( - request, ServerCallContext() - ): - pass - - self.assertEqual( - str(context.exception.message), - 'Streaming is not supported by the agent', - ) - - async def test_push_notifications_not_supported_error(self) -> None: - """Test that set_push_notification raises an error when push notifications not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with push notifications capability disabled - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=False, streaming=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = TaskPushNotificationConfig( - task_id='task_123', - url='http://example.com', - ) - - # Should raise UnsupportedOperationError about push notifications not supported - with self.assertRaises(UnsupportedOperationError) as context: - await handler.set_push_notification_config( - request, ServerCallContext() - ) - - self.assertEqual( - str(context.exception.message), - 'Push notifications are not supported by the agent', - ) - - async def test_on_get_push_notification_no_push_config_store(self) -> None: - """Test get_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_set_push_notification_no_push_config_store(self) -> None: - """Test set_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - response = await handler.set_push_notification_config( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_internal_error(self) -> None: - """Test on_message_send with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs) -> NoReturn: - raise InternalError(message='Internal Error') - - # Patch the method to raise an error - with patch.object( - request_handler, 'on_message_send', side_effect=raise_server_error - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_internal_error(self) -> None: - """Test on_message_send_stream with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs): - raise InternalError(message='Internal Error') - yield # Need this to make it an async generator - - # Patch the method to raise an error - with patch.object( - request_handler, - 'on_message_send_stream', - return_value=raise_server_error(), - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - - # Get the single error response - responses = [] - async for response in handler.on_message_send_stream( - request, ServerCallContext() - ): - responses.append(response) - - # Assert - self.assertEqual(len(responses), 1) - self.assertIsInstance(responses[0], dict) - self.assertTrue(is_error_response(responses[0])) - self.assertEqual(responses[0]['error']['code'], -32603) - - async def test_default_request_handler_with_custom_components(self) -> None: - """Test DefaultRequestHandler initialization with custom components.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - mock_push_config_store = AsyncMock(spec=PushNotificationConfigStore) - mock_push_sender = AsyncMock(spec=PushNotificationSender) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) - - # Act - handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - queue_manager=mock_queue_manager, - push_config_store=mock_push_config_store, - push_sender=mock_push_sender, - request_context_builder=mock_request_context_builder, - ) - - # Assert - self.assertEqual(handler.agent_executor, mock_agent_executor) - self.assertEqual(handler.task_store, mock_task_store) - self.assertEqual(handler._queue_manager, mock_queue_manager) - self.assertEqual(handler._push_config_store, mock_push_config_store) - self.assertEqual(handler._push_sender, mock_push_sender) - self.assertEqual( - handler._request_context_builder, mock_request_context_builder - ) - - async def test_on_message_send_error_handling(self) -> None: - """Test error handling in on_message_send when consuming raises A2AError.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Let task exist - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Set up consume_and_break_on_interrupt to raise UnsupportedOperationError - async def consume_raises_error(*args, **kwargs) -> NoReturn: - raise UnsupportedOperationError() - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - side_effect=consume_raises_error, - ): - # Act - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - - response = await handler.on_message_send( - request, ServerCallContext() - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - # Mock returns task with different ID than what will be generated - mock_task_store.get.return_value = None # No existing task - mock_agent_executor.execute.return_value = None - - # Task returned has task_id='task_123' but request_context will have generated UUID - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message(), # No task_id, so UUID is generated - ) - response = await handler.on_message_send( - request, ServerCallContext() - ) - # The task ID mismatch should cause an error - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - events: list[Any] = [create_task()] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message(), - ) - response = handler.on_message_send_stream( - request, ServerCallContext() - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - self.assertEqual(collected_events[0]['error']['code'], -32603) - - async def test_on_get_push_notification(self) -> None: - """Test get_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='config1', url='http://example.com' - ) - request_handler.on_get_task_push_notification_config.return_value = ( - task_push_config - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='config1', - ) - response = await handler.get_push_notification_config( - get_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - self.assertEqual( - response['result']['id'], - 'config1', - ) - self.assertEqual( - response['result']['taskId'], - mock_task.id, - ) - - async def test_on_list_push_notification(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='default', url='http://example.com' - ) - request_handler.on_list_task_push_notification_configs.return_value = ( - ListTaskPushNotificationConfigsResponse(configs=[task_push_config]) - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result contains the response dict with configs field - self.assertIsInstance(response['result'], dict) - - async def test_on_list_push_notification_error(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_list_task_push_notification_configs.side_effect = ( - InternalError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_delete_push_notification(self) -> None: - """Test delete_push_notification_config handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - request_handler.on_delete_task_push_notification_config.return_value = ( - None - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['result'], None) - - async def test_on_delete_push_notification_error(self) -> None: - """Test delete_push_notification_config error handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_delete_task_push_notification_config.side_effect = ( - UnsupportedOperationError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, ServerCallContext() - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_get_authenticated_extended_card_success(self) -> None: - """Test successful retrieval of the authenticated extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_extended_card = AgentCard( - name='Extended Card', - description='More details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.1', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_extended_card, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-1'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-1') - # Result is the agent card proto - - async def test_get_authenticated_extended_card_not_configured(self) -> None: - """Test error when authenticated extended agent card is not configured.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - # We need a proper card here because agent_card_to_dict accesses multiple fields - card = AgentCard( - name='TestAgent', - version='1.0.0', - supported_interfaces=[ - AgentInterface( - url='http://localhost', - protocol_binding='JSONRPC', - protocol_version='1.0.0', - ) - ], - capabilities=AgentCapabilities(extended_agent_card=True), - ) - - handler = JSONRPCHandler( - card, - mock_request_handler, - extended_agent_card=None, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-2'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - # Authenticated Extended Card flag is set with no extended card, - # returns base card in this case. - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-2') - - async def test_get_authenticated_extended_card_with_modifier(self) -> None: - """Test successful retrieval of a dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - async def modifier( - card: AgentCard, context: ServerCallContext - ) -> AgentCard: - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext(state={'foo': 'bar'}) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertFalse(is_error_response(response)) - from google.protobuf.json_format import ParseDict - - modified_card = ParseDict( - response['result'], AgentCard(), ignore_unknown_fields=True - ) - self.assertEqual(modified_card.name, 'Modified Card') - self.assertEqual(modified_card.description, 'Modified for context: bar') - self.assertEqual(modified_card.version, '1.0') - - async def test_get_authenticated_extended_card_with_modifier_sync( - self, - ) -> None: - """Test successful retrieval of a synchronously dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: - # Copy the card by creating a new one with the same fields - from copy import deepcopy - - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-mod'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-mod') - # Result is converted to dict for JSON serialization - modified_card_dict = response['result'] - self.assertEqual(modified_card_dict['name'], 'Modified Card') - self.assertEqual( - modified_card_dict['description'], 'Modified for context: bar' - ) - self.assertEqual(modified_card_dict['version'], '1.0') diff --git a/tests/server/routes/test_agent_card_route.py b/tests/server/routes/test_agent_card_route.py index f117d37e..7cea4e81 100644 --- a/tests/server/routes/test_agent_card_route.py +++ b/tests/server/routes/test_agent_card_route.py @@ -2,7 +2,7 @@ from starlette.applications import Starlette from starlette.testclient import TestClient -from a2a.server.routes import AgentCardRoute +from a2a.server.routes import AgentCardRoutes from a2a.types import a2a_pb2 from a2a.server.request_handlers.response_helpers import agent_card_to_dict @@ -19,7 +19,7 @@ def mock_agent_card(): @pytest.fixture def test_app(mock_agent_card): app = Starlette() - card_route = AgentCardRoute(mock_agent_card) + card_route = AgentCardRoutes(mock_agent_card) app.routes.append(card_route.route) return app @@ -45,7 +45,7 @@ async def modifier(card): card.name = 'modified-agent' return card - card_route = AgentCardRoute(mock_agent_card, card_modifier=modifier) + card_route = AgentCardRoutes(mock_agent_card, card_modifier=modifier) app = Starlette() app.routes.append(card_route.route) client = TestClient(app) diff --git a/tests/server/routes/test_jsonrpc_route.py b/tests/server/routes/test_jsonrpc_route.py index 226265ba..776b227f 100644 --- a/tests/server/routes/test_jsonrpc_route.py +++ b/tests/server/routes/test_jsonrpc_route.py @@ -7,7 +7,7 @@ from starlette.testclient import TestClient from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.routes import JsonRpcRoute, StarletteUserProxy +from a2a.server.routes import JsonRpcRoutes, StarletteUserProxy from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -40,8 +40,8 @@ def test_app(mock_handler): from starlette.applications import Starlette app = Starlette() - router = JsonRpcRoute(mock_agent_card, mock_handler) - app.routes.append(router.route) + jsonrpc_routes = JsonRpcRoutes(mock_agent_card, mock_handler) + app.routes.extend(jsonrpc_routes.routes) return app @@ -211,10 +211,10 @@ def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): from starlette.applications import Starlette app = Starlette() - router = JsonRpcRoute( + routes = JsonRpcRoutes( mock_agent_card, mock_handler, enable_v0_3_compat=True ) - app.routes.append(router.route) + app.routes.extend(routes.routes) client = TestClient(app) @@ -257,10 +257,10 @@ def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): from starlette.applications import Starlette app = Starlette() - router = JsonRpcRoute( + routes = JsonRpcRoutes( mock_agent_card, mock_handler, enable_v0_3_compat=False ) - app.routes.append(router.route) + app.routes.extend(routes.routes) client = TestClient(app) diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py index 1a5729bd..50990f22 100644 --- a/tests/server/routes/test_rest_routes.py +++ b/tests/server/routes/test_rest_routes.py @@ -96,7 +96,7 @@ async def app( extended_card_modifier: MagicMock | None, ) -> Any: from starlette.applications import Starlette - from a2a.server.routes import AgentCardRoute + from a2a.server.routes import AgentCardRoutes # Return Starlette app app_instance = Starlette() @@ -110,8 +110,8 @@ async def app( app_instance.routes.extend(rest_router.routes) # Also Agent card endpoint? if needed in tests - card_router = AgentCardRoute(agent_card, card_url='/well-known/agent.json') - app_instance.routes.append(card_router.route) + card_routes = AgentCardRoutes(agent_card, card_url='/well-known/agent.json') + app_instance.routes.extend(card_routes.routes) return app_instance diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 02c5c8b2..ec974b8c 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,7 +18,7 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.routes import AgentCardRoute, JsonRpcRoute +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -164,18 +164,18 @@ def build( app_instance = Starlette(middleware=middleware, routes=routes or []) # Agent card router - card_router = AgentCardRoute( + card_routes = AgentCardRoutes( self.agent_card, card_url=agent_card_url, card_modifier=self.card_modifier, ) - app_instance.routes.append(card_router.route) + app_instance.routes.extend(card_routes.routes) # JSON-RPC router - rpc_router = JsonRpcRoute( + rpc_routes = JsonRpcRoutes( self.agent_card, self.handler, rpc_url=rpc_url ) - app_instance.routes.append(rpc_router.route) + app_instance.routes.extend(rpc_routes.routes) return app_instance From 32936bdd4440186079c0875cd251395725a76bca Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:13:22 +0100 Subject: [PATCH 08/17] Update tests/server/routes/test_jsonrpc_route.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/server/routes/test_jsonrpc_route.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/routes/test_jsonrpc_route.py b/tests/server/routes/test_jsonrpc_route.py index 776b227f..4acb15ce 100644 --- a/tests/server/routes/test_jsonrpc_route.py +++ b/tests/server/routes/test_jsonrpc_route.py @@ -233,7 +233,7 @@ def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): # Instead of _v03_adapter, the handler handles it or it's dispatcher with patch.object( - router.dispatcher, + routes.dispatcher, '_process_non_streaming_request', new_callable=AsyncMock, ) as mock_handle: From 50c1984d75f8b49de65934f636974bc60f7284bf Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:13:35 +0100 Subject: [PATCH 09/17] Update tests/server/routes/test_agent_card_route.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/server/routes/test_agent_card_route.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/routes/test_agent_card_route.py b/tests/server/routes/test_agent_card_route.py index 7cea4e81..9588910f 100644 --- a/tests/server/routes/test_agent_card_route.py +++ b/tests/server/routes/test_agent_card_route.py @@ -47,7 +47,7 @@ async def modifier(card): card_route = AgentCardRoutes(mock_agent_card, card_modifier=modifier) app = Starlette() - app.routes.append(card_route.route) + app.routes.append(card_route.routes[0]) client = TestClient(app) response = client.get('/') From 292159bcf20483467eb591d0c6d0adc82752403b Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:13:45 +0100 Subject: [PATCH 10/17] Update tests/server/routes/test_agent_card_route.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/server/routes/test_agent_card_route.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/routes/test_agent_card_route.py b/tests/server/routes/test_agent_card_route.py index 9588910f..6ba4dbee 100644 --- a/tests/server/routes/test_agent_card_route.py +++ b/tests/server/routes/test_agent_card_route.py @@ -20,7 +20,7 @@ def mock_agent_card(): def test_app(mock_agent_card): app = Starlette() card_route = AgentCardRoutes(mock_agent_card) - app.routes.append(card_route.route) + app.routes.append(card_route.routes[0]) return app From c42fb8cf60cf8efa2225744b6e6229c105b9c0b8 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:14:02 +0100 Subject: [PATCH 11/17] Update src/a2a/server/routes/__init__.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/routes/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index 2d8fc39c..e03680f8 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -11,7 +11,7 @@ __all__ = [ - 'AgentCardRoute', + 'AgentCardRoutes', 'CallContextBuilder', 'DefaultCallContextBuilder', 'JsonRpcRoute', From 7893240fa16b64752f94b8d6d571c4c8334622b6 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:14:11 +0100 Subject: [PATCH 12/17] Update src/a2a/server/routes/__init__.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/routes/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index e03680f8..e94428a4 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -14,7 +14,7 @@ 'AgentCardRoutes', 'CallContextBuilder', 'DefaultCallContextBuilder', - 'JsonRpcRoute', + 'JsonRpcRoutes', 'RestRoutes', 'StarletteUserProxy', ] From ac1c8729a259467fc0442d964cbcd78d59889688 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 14:14:53 +0100 Subject: [PATCH 13/17] Update src/a2a/server/routes/agent_card_route.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/routes/agent_card_route.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/a2a/server/routes/agent_card_route.py b/src/a2a/server/routes/agent_card_route.py index 782d0712..7e9d2922 100644 --- a/src/a2a/server/routes/agent_card_route.py +++ b/src/a2a/server/routes/agent_card_route.py @@ -59,8 +59,7 @@ def __init__( """ if not _package_starlette_installed: raise ImportError( - 'The `starlette` package is required to use the `JsonRpcRoute`.' - ' It can be added as a part of `a2a-sdk` optional dependencies,' + 'The `starlette` package is required to use the `AgentCardRoutes`.' ' `a2a-sdk[http-server]`.' ) From 94ee11d0c9ccc85403bcfacafbd14dde86da1b1f Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 13:26:46 +0000 Subject: [PATCH 14/17] fix typo --- src/a2a/server/routes/jsonrpc_route.py | 2 +- src/a2a/server/routes/rest_routes.py | 2 ++ .../cross_version/client_server/server_1_0.py | 31 +++++++++++++------ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/a2a/server/routes/jsonrpc_route.py b/src/a2a/server/routes/jsonrpc_route.py index 82a78774..0674af7d 100644 --- a/src/a2a/server/routes/jsonrpc_route.py +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -81,7 +81,7 @@ def __init__( # noqa: PLR0913 """ if not _package_starlette_installed: raise ImportError( - 'The `starlette` package is required to use the `JsonRpcRoute`.' + 'The `starlette` package is required to use the `JsonRpcRoutes`.' ' It can be added as a part of `a2a-sdk` optional dependencies,' ' `a2a-sdk[http-server]`.' ) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 117465a7..8c9c06fe 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -133,6 +133,8 @@ def __init__( # noqa: PLR0913 http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, ) self._setup_routes(rpc_url) diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index e079fdf2..56c21bf3 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -5,7 +5,8 @@ import grpc from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes import JsonRpcRoutes, RestRoutes, AgentCardRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -166,17 +167,29 @@ async def main_async(http_port: int, grpc_port: int): app = FastAPI() app.add_middleware(CustomLoggingMiddleware) - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build() + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + enable_v0_3_compat=True, + ) + jsonrpc_card = AgentCardRoutes( + agent_card=agent_card, + card_url='/.well-known/agent-card.json', + ) + jsonrpc_app = Starlette(routes=jsonrpc_routes.routes + jsonrpc_card.routes) app.mount('/jsonrpc', jsonrpc_app) - app.mount( - '/rest', - A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build(), + rest_routes = RestRoutes( + agent_card=agent_card, + request_handler=handler, + enable_v0_3_compat=True, + ) + rest_card = AgentCardRoutes( + agent_card=agent_card, + card_url='/.well-known/agent-card.json', ) + rest_app = Starlette(routes=rest_routes.routes + rest_card.routes) + app.mount('/rest', rest_app) # Start gRPC Server server = grpc.aio.server() From 90a030f890fdc7e3a9575bfd8ca09efd8dd415ca Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 13:29:42 +0000 Subject: [PATCH 15/17] fix sample --- samples/hello_world_agent.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 38dfdf56..2316aeb8 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -11,7 +11,7 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import GrpcHandler from a2a.server.request_handlers.default_request_handler import ( @@ -190,22 +190,29 @@ async def serve( agent_executor=SampleAgentExecutor(), task_store=task_store ) - rest_app_builder = A2ARESTFastAPIApplication( + agent_card_routes = AgentCardRoutes( agent_card=agent_card, - http_handler=request_handler, + card_url='/.well-known/agent-card.json', + ) + # JSON-RPC + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=request_handler, + rpc_url='/a2a/jsonrpc/', enable_v0_3_compat=True, ) - rest_app = rest_app_builder.build() - - jsonrpc_app_builder = A2AFastAPIApplication( + # REST + rest_routes = RestRoutes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, + rpc_url='/a2a/rest', enable_v0_3_compat=True, ) app = FastAPI() - jsonrpc_app_builder.add_routes_to_app(app, rpc_url='/a2a/jsonrpc/') - app.mount('/a2a/rest', rest_app) + app.routes.extend(agent_card_routes.routes) + app.routes.extend(jsonrpc_routes.routes) + app.routes.extend(rest_routes.routes) grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') From c73aa4eefdcb8daa4566682033d024f577d67ab3 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 14:05:59 +0000 Subject: [PATCH 16/17] fixes --- samples/hello_world_agent.py | 2 +- tests/server/routes/test_jsonrpc_route.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 2316aeb8..bb86c32f 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -11,12 +11,12 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import GrpcHandler from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes, RestRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( diff --git a/tests/server/routes/test_jsonrpc_route.py b/tests/server/routes/test_jsonrpc_route.py index 4acb15ce..dd53922d 100644 --- a/tests/server/routes/test_jsonrpc_route.py +++ b/tests/server/routes/test_jsonrpc_route.py @@ -233,15 +233,13 @@ def test_v0_3_compat_flag_routes_to_dispatcher(self, mock_handler): # Instead of _v03_adapter, the handler handles it or it's dispatcher with patch.object( - routes.dispatcher, - '_process_non_streaming_request', + routes.dispatcher._v03_adapter, + 'handle_request', new_callable=AsyncMock, ) as mock_handle: - mock_handle.return_value = { - 'jsonrpc': '2.0', - 'id': '1', - 'result': {}, - } + mock_handle.return_value = JSONResponse( + {'jsonrpc': '2.0', 'id': '1', 'result': {}} + ) response = client.post('/', json=request_data) From 573eec5c53b85ad96d348878fc6d73ab34e1ff8a Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 14:44:49 +0000 Subject: [PATCH 17/17] add tests --- src/a2a/server/routes/agent_card_route.py | 4 +- src/a2a/server/routes/jsonrpc_dispatcher.py | 28 ++-- src/a2a/server/routes/jsonrpc_route.py | 3 +- src/a2a/server/routes/rest_routes.py | 3 +- .../server/routes/test_jsonrpc_dispatcher.py | 152 ++++++++++++++++++ 5 files changed, 171 insertions(+), 19 deletions(-) create mode 100644 tests/server/routes/test_jsonrpc_dispatcher.py diff --git a/src/a2a/server/routes/agent_card_route.py b/src/a2a/server/routes/agent_card_route.py index 7e9d2922..d1e27d68 100644 --- a/src/a2a/server/routes/agent_card_route.py +++ b/src/a2a/server/routes/agent_card_route.py @@ -31,7 +31,7 @@ from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import AgentCard from a2a.utils.helpers import maybe_await - +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def __init__( agent_card: AgentCard, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, - card_url: str = '/', + card_url: str = AGENT_CARD_WELL_KNOWN_PATH, middleware: Sequence['Middleware'] | None = None, ) -> None: """Initializes the AgentCardRoute. diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 4bab6855..e50ca034 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -194,7 +194,7 @@ class JsonRpcDispatcher: def __init__( # noqa: PLR0913 self, agent_card: AgentCard, - http_handler: RequestHandler, + request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -203,7 +203,6 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB enable_v0_3_compat: bool = False, ) -> None: """Initializes the JSONRPCApplication. @@ -237,16 +236,15 @@ def __init__( # noqa: PLR0913 self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self.http_handler = http_handler + self.request_handler = request_handler self._context_builder = context_builder or DefaultCallContextBuilder() - self._max_content_length = max_content_length self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter: JSONRPC03Adapter | None = None if self.enable_v0_3_compat: self._v03_adapter = JSONRPC03Adapter( agent_card=agent_card, - http_handler=http_handler, + http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, card_modifier=card_modifier, @@ -478,11 +476,11 @@ async def _process_streaming_request( stream: AsyncGenerator | None = None if isinstance(request_obj, SendMessageRequest): - stream = self.http_handler.on_message_send_stream( + stream = self.request_handler.on_message_send_stream( request_obj, context ) elif isinstance(request_obj, SubscribeToTaskRequest): - stream = self.http_handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( request_obj, context ) @@ -522,7 +520,7 @@ async def _process_non_streaming_request( # noqa: PLR0911, PLR0912 """ match request_obj: case SendMessageRequest(): - task_or_message = await self.http_handler.on_message_send( + task_or_message = await self.request_handler.on_message_send( request_obj, context ) if isinstance(task_or_message, Task): @@ -531,19 +529,19 @@ async def _process_non_streaming_request( # noqa: PLR0911, PLR0912 msg_response = SendMessageResponse(message=task_or_message) return MessageToDict(msg_response) case CancelTaskRequest(): - task = await self.http_handler.on_cancel_task( + task = await self.request_handler.on_cancel_task( request_obj, context ) if not task: raise TaskNotFoundError return MessageToDict(task, preserving_proto_field_name=False) case GetTaskRequest(): - task = await self.http_handler.on_get_task(request_obj, context) + task = await self.request_handler.on_get_task(request_obj, context) if not task: raise TaskNotFoundError return MessageToDict(task, preserving_proto_field_name=False) case ListTasksRequest(): - tasks_response = await self.http_handler.on_list_tasks( + tasks_response = await self.request_handler.on_list_tasks( request_obj, context ) return MessageToDict( @@ -556,26 +554,26 @@ async def _process_non_streaming_request( # noqa: PLR0911, PLR0912 raise UnsupportedOperationError( message='Push notifications are not supported by the agent' ) - result_config = await self.http_handler.on_create_task_push_notification_config( + result_config = await self.request_handler.on_create_task_push_notification_config( request_obj, context ) return MessageToDict( result_config, preserving_proto_field_name=False ) case GetTaskPushNotificationConfigRequest(): - config = await self.http_handler.on_get_task_push_notification_config( + config = await self.request_handler.on_get_task_push_notification_config( request_obj, context ) return MessageToDict(config, preserving_proto_field_name=False) case ListTaskPushNotificationConfigsRequest(): - list_push_response = await self.http_handler.on_list_task_push_notification_configs( + list_push_response = await self.request_handler.on_list_task_push_notification_configs( request_obj, context ) return MessageToDict( list_push_response, preserving_proto_field_name=False ) case DeleteTaskPushNotificationConfigRequest(): - await self.http_handler.on_delete_task_push_notification_config( + await self.request_handler.on_delete_task_push_notification_config( request_obj, context ) return None diff --git a/src/a2a/server/routes/jsonrpc_route.py b/src/a2a/server/routes/jsonrpc_route.py index 0674af7d..73bca828 100644 --- a/src/a2a/server/routes/jsonrpc_route.py +++ b/src/a2a/server/routes/jsonrpc_route.py @@ -30,6 +30,7 @@ JsonRpcDispatcher, ) from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import DEFAULT_RPC_URL logger = logging.getLogger(__name__) @@ -56,7 +57,7 @@ def __init__( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = '/', + rpc_url: str = DEFAULT_RPC_URL, middleware: Sequence[Middleware] | None = None, ) -> None: """Initializes the JsonRpcRoute. diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 8c9c06fe..0ab24fcb 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -62,6 +62,7 @@ UnsupportedOperationError, ) from a2a.utils.helpers import maybe_await +from a2a.utils.constants import DEFAULT_RPC_URL logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ def __init__( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = '', + rpc_url: str = DEFAULT_RPC_URL, middleware: Sequence['Middleware'] | None = None, ) -> None: """Initializes the RestRoutes. diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py new file mode 100644 index 00000000..5a216825 --- /dev/null +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -0,0 +1,152 @@ +import pytest +import json +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.datastructures import Headers + +from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types.a2a_pb2 import ( + AgentCard, Message, Role, Part, GetTaskRequest, Task, + ListTasksResponse, TaskPushNotificationConfig, + ListTaskPushNotificationConfigsResponse +) +from a2a.server.jsonrpc_models import ( + JSONParseError, + InvalidRequestError, + MethodNotFoundError, + InvalidParamsError, +) + +@pytest.fixture +def mock_handler(): + handler = AsyncMock(spec=RequestHandler) + return handler + +@pytest.fixture +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = MagicMock() + card.capabilities.streaming = True + card.capabilities.push_notifications = True + return card + +class TestJsonRpcDispatcher: + def _create_request(self, body_dict=None, headers=None, body_bytes=None): + """Helper to create a starlette Request for testing""" + scope = { + 'type': 'http', + 'method': 'POST', + 'path': '/', + 'headers': Headers(headers or {}).raw + } + + async def receive(): + if body_bytes: + return {'type': 'http.request', 'body': body_bytes, 'more_body': False} + return {'type': 'http.request', 'body': json.dumps(body_dict or {}).encode('utf-8'), 'more_body': False} + + return Request(scope, receive) + + @pytest.mark.asyncio + async def test_generate_error_response(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + resp = dispatcher._generate_error_response(1, JSONParseError(message='test')) + assert resp.status_code == 200 + data = json.loads(resp.body.decode()) + assert data['id'] == 1 + assert 'error' in data + assert data['error']['code'] == -32700 + + @pytest.mark.asyncio + @pytest.mark.parametrize("method, params, handler_attr, mock_return, expected_key", [ + ("GetTask", {"id": "task-1"}, "on_get_task", Task(id="task-1"), "id"), + ("SendMessage", {"message": {"parts": [{"text": "hi"}]}}, "on_message_send", Task(id="task-1"), "task"), + ("CancelTask", {"id": "task-1"}, "on_cancel_task", Task(id="task-1"), "id"), + ("ListTasks", {}, "on_list_tasks", ListTasksResponse(tasks=[Task(id="task-1")]), "tasks"), + ("CreateTaskPushNotificationConfig", {"taskId": "task-1"}, "on_create_task_push_notification_config", TaskPushNotificationConfig(task_id="task-1"), "taskId"), + ("GetTaskPushNotificationConfig", {"taskId": "task-1"}, "on_get_task_push_notification_config", TaskPushNotificationConfig(task_id="task-1"), "taskId"), + ("ListTaskPushNotificationConfigs", {}, "on_list_task_push_notification_configs", ListTaskPushNotificationConfigsResponse(configs=[TaskPushNotificationConfig(task_id="task-1")]), "configs"), + ("DeleteTaskPushNotificationConfig", {"taskId": "task-1"}, "on_delete_task_push_notification_config", None, None), + ]) + async def test_handle_requests_success_non_streaming(self, agent_card, mock_handler, method, params, handler_attr, mock_return, expected_key): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req_body = { + 'jsonrpc': '2.0', + 'id': 'msg-1', + 'method': method, + 'params': params + } + req = self._create_request(body_dict=req_body) + + mock_func = getattr(mock_handler, handler_attr) + if hasattr(mock_func, 'return_value'): + mock_func.return_value = mock_return + + resp = await dispatcher.handle_requests(req) + assert resp.status_code == 200 + res = json.loads(resp.body.decode()) + assert res['id'] == 'msg-1' + if expected_key: + assert 'result' in res + assert expected_key in res['result'] + assert mock_func.called + + @pytest.mark.asyncio + async def test_handle_requests_payload_too_large(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler, max_content_length=10) + req = self._create_request( + body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'GetTask'}, + headers={'content-length': '100'} + ) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + assert 'Payload too large' in res['error']['message'] + + @pytest.mark.asyncio + async def test_handle_requests_batch_not_supported(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict=[{'jsonrpc': '2.0'}]) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + # The underlying jsonrpc library formats the exact text differently depending on parse path + assert 'Invalid Request' in res['error']['message'] + + @pytest.mark.asyncio + async def test_handle_requests_invalid_jsonrpc_version(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict={'jsonrpc': '1.0', 'id': '1', 'method': 'GetTask'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32600 + + @pytest.mark.asyncio + async def test_handle_requests_method_not_found(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'UnknownMethod'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32601 + + @pytest.mark.asyncio + async def test_v03_compat_delegation(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler, enable_v0_3_compat=True) + dispatcher._v03_adapter.supports_method = MagicMock(return_value=True) + dispatcher._v03_adapter.handle_request = AsyncMock(return_value=JSONResponse({'v03': 'compat'})) + + req = self._create_request(body_dict={'jsonrpc': '2.0', 'id': '1', 'method': 'message/send'}) + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res == {'v03': 'compat'} + + @pytest.mark.asyncio + async def test_invalid_json_body_error(self, agent_card, mock_handler): + dispatcher = JsonRpcDispatcher(agent_card, mock_handler) + req = self._create_request(body_bytes=b'{"invalid": json}') + + resp = await dispatcher.handle_requests(req) + res = json.loads(resp.body.decode()) + assert res['error']['code'] == -32700