From fdb156bb866ae3ea44c7bce2a0b865d0c33966a1 Mon Sep 17 00:00:00 2001 From: knap Date: Mon, 9 Mar 2026 11:52:40 +0000 Subject: [PATCH] feat: implement gRPC rich error details for A2A errors, including new error types and client-side parsing. --- src/a2a/client/transports/grpc.py | 46 +++++++++++++++-- .../server/request_handlers/grpc_handler.py | 39 ++++++++++++--- src/a2a/utils/errors.py | 35 +++++++++++++ tests/client/transports/test_grpc_client.py | 49 +++++++++++++++++-- .../request_handlers/test_grpc_handler.py | 46 +++++++++++++++-- 5 files changed, 197 insertions(+), 18 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 231c1ebb..a33ea06e 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -1,11 +1,12 @@ import logging -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable, Iterable from functools import wraps -from typing import Any, NoReturn +from typing import Any, NoReturn, cast +from a2a.client.errors import A2AClientError, A2AClientTimeoutError from a2a.client.middleware import ClientCallContext -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP +from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, A2AError try: @@ -18,8 +19,12 @@ ) from e +from google.rpc import ( # type: ignore[reportMissingModuleSource] + error_details_pb2, + status_pb2, +) + from a2a.client.client import ClientConfig -from a2a.client.errors import A2AClientError, A2AClientTimeoutError from a2a.client.middleware import ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport @@ -44,6 +49,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER +from a2a.utils.errors import A2A_REASON_TO_ERROR from a2a.utils.telemetry import SpanKind, trace_class @@ -54,14 +60,44 @@ } +def _parse_rich_grpc_error( + value: bytes, original_error: grpc.aio.AioRpcError +) -> None: + try: + status = status_pb2.Status.FromString(value) + for detail in status.details: + if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + error_info = error_details_pb2.ErrorInfo() + detail.Unpack(error_info) + + if error_info.domain == 'a2a-protocol.org': + exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) + if exception_cls: + raise exception_cls(status.message) from original_error # noqa: TRY301 + except Exception as parse_e: + # Don't swallow A2A errors generated above + if isinstance(parse_e, (A2AError, A2AClientError)): + raise parse_e + logger.warning( + 'Failed to parse grpc-status-details-bin', exc_info=parse_e + ) + + def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: raise A2AClientTimeoutError('Client Request timed out') from e + metadata = e.trailing_metadata() + if metadata: + iterable_metadata = cast('Iterable[tuple[str, str | bytes]]', metadata) + for key, value in iterable_metadata: + if key == 'grpc-status-details-bin' and isinstance(value, bytes): + _parse_rich_grpc_error(value, e) + details = e.details() if isinstance(details, str) and ': ' in details: error_type_name, error_message = details.split(': ', 1) - # TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723. + # Leaving as fallback for errors that don't use the rich error details. exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name) if exception_cls: raise exception_cls(error_message) from e diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index fd9d042f..57aea73e 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,7 +3,8 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Awaitable +from collections.abc import AsyncIterable, Awaitable, Callable +from typing import cast try: @@ -16,9 +17,8 @@ "'pip install a2a-sdk[grpc]'" ) from e -from collections.abc import Callable - -from google.protobuf import empty_pb2, message +from google.protobuf import any_pb2, empty_pb2, message +from google.rpc import error_details_pb2, status_pb2 import a2a.types.a2a_pb2_grpc as a2a_grpc @@ -33,7 +33,7 @@ from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils -from a2a.utils.errors import A2AError, TaskNotFoundError +from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError from a2a.utils.helpers import maybe_await, validate, validate_async_generator @@ -419,10 +419,37 @@ async def abort_context( ) -> None: """Sets the grpc errors appropriately in the context.""" code = _ERROR_CODE_MAP.get(type(error)) + + status_value = code.value if code else grpc.StatusCode.UNKNOWN.value + status_code = ( + status_value[0] if isinstance(status_value, tuple) else status_value + ) + error_msg = error.message if hasattr(error, 'message') else str(error) + status = status_pb2.Status(code=status_code, message=error_msg) + + if code: + reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR') + + error_info = error_details_pb2.ErrorInfo( + reason=reason, + domain='a2a-protocol.org', + ) + + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) + + context.set_trailing_metadata( + cast( + 'tuple[tuple[str, str | bytes], ...]', + (('grpc-status-details-bin', status.SerializeToString()),), + ) + ) + if code: await context.abort( code, - f'{type(error).__name__}: {error.message}', + status.message, ) else: await context.abort( diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index 845bbfca..b4e9fbca 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -82,11 +82,26 @@ class MethodNotFoundError(A2AError): message = 'Method not found' +class ExtensionSupportRequiredError(A2AError): + """Exception raised when extension support is required but not present.""" + + message = 'Extension support required' + + +class VersionNotSupportedError(A2AError): + """Exception raised when the requested version is not supported.""" + + message = 'Version not supported' + + # For backward compatibility if needed, or just aliases for clean refactor # We remove the Pydantic models here. __all__ = [ + 'A2A_ERROR_REASONS', + 'A2A_REASON_TO_ERROR', 'JSON_RPC_ERROR_CODE_MAP', + 'ExtensionSupportRequiredError', 'InternalError', 'InvalidAgentResponseError', 'InvalidParamsError', @@ -96,6 +111,7 @@ class MethodNotFoundError(A2AError): 'TaskNotCancelableError', 'TaskNotFoundError', 'UnsupportedOperationError', + 'VersionNotSupportedError', ] @@ -112,3 +128,22 @@ class MethodNotFoundError(A2AError): MethodNotFoundError: -32601, InternalError: -32603, } + + +A2A_ERROR_REASONS = { + TaskNotFoundError: 'TASK_NOT_FOUND', + TaskNotCancelableError: 'TASK_NOT_CANCELABLE', + PushNotificationNotSupportedError: 'PUSH_NOTIFICATION_NOT_SUPPORTED', + UnsupportedOperationError: 'UNSUPPORTED_OPERATION', + ContentTypeNotSupportedError: 'CONTENT_TYPE_NOT_SUPPORTED', + InvalidAgentResponseError: 'INVALID_AGENT_RESPONSE', + AuthenticatedExtendedCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED', + ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED', + VersionNotSupportedError: 'VERSION_NOT_SUPPORTED', + InvalidParamsError: 'INVALID_PARAMS', + InvalidRequestError: 'INVALID_REQUEST', + MethodNotFoundError: 'METHOD_NOT_FOUND', + InternalError: 'INTERNAL_ERROR', +} + +A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()} diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index a070b18f..4984a58b 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -3,10 +3,14 @@ import grpc import pytest +from google.protobuf import any_pb2 +from google.rpc import error_details_pb2, status_pb2 + from a2a.client.middleware import ClientCallContext from a2a.client.transports.grpc import GrpcTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT +from a2a.utils.errors import A2A_ERROR_REASONS from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -257,16 +261,15 @@ async def test_send_message_with_timeout_context( @pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys())) @pytest.mark.asyncio -async def test_grpc_mapped_errors( +async def test_grpc_mapped_errors_legacy( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_message_send_params: SendMessageRequest, error_cls, ) -> None: - """Test handling of mapped gRPC error responses.""" + """Test handling of legacy gRPC error responses.""" error_details = f'{error_cls.__name__}: Mapped Error' - # We must trigger it from a standard transport method call, for example `send_message`. mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError( code=grpc.StatusCode.INTERNAL, initial_metadata=grpc.aio.Metadata(), @@ -278,6 +281,46 @@ async def test_grpc_mapped_errors( await grpc_transport.send_message(sample_message_send_params) +@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys())) +@pytest.mark.asyncio +async def test_grpc_mapped_errors_rich( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_message_send_params: SendMessageRequest, + error_cls, +) -> None: + """Test handling of rich gRPC error responses with Status metadata.""" + + reason = A2A_ERROR_REASONS.get(error_cls, 'UNKNOWN_ERROR') + + error_info = error_details_pb2.ErrorInfo( + reason=reason, + domain='a2a-protocol.org', + ) + + error_details = f'{error_cls.__name__}: Mapped Error' + status = status_pb2.Status( + code=grpc.StatusCode.INTERNAL.value[0], message=error_details + ) + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) + + mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError( + code=grpc.StatusCode.INTERNAL, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata( + ('grpc-status-details-bin', status.SerializeToString()), + ), + details='A generic error message', + ) + + with pytest.raises(error_cls) as excinfo: + await grpc_transport.send_message(sample_message_send_params) + + assert str(excinfo.value) == error_details + + @pytest.mark.asyncio async def test_send_message_message_response( grpc_transport: GrpcTransport, diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 88f050aa..954d7e01 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -5,6 +5,7 @@ import grpc.aio import pytest +from google.rpc import error_details_pb2, status_pb2 from a2a import types from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext @@ -99,7 +100,7 @@ async def test_send_message_server_error( await grpc_handler.SendMessage(request_proto, mock_grpc_context) mock_grpc_context.abort.assert_awaited_once_with( - grpc.StatusCode.INVALID_ARGUMENT, 'InvalidParamsError: Bad params' + grpc.StatusCode.INVALID_ARGUMENT, 'Bad params' ) @@ -138,7 +139,7 @@ async def test_get_task_not_found( await grpc_handler.GetTask(request_proto, mock_grpc_context) mock_grpc_context.abort.assert_awaited_once_with( - grpc.StatusCode.NOT_FOUND, 'TaskNotFoundError: Task not found' + grpc.StatusCode.NOT_FOUND, 'Task not found' ) @@ -157,7 +158,7 @@ async def test_cancel_task_server_error( mock_grpc_context.abort.assert_awaited_once_with( grpc.StatusCode.UNIMPLEMENTED, - 'TaskNotCancelableError: Task cannot be canceled', + 'Task cannot be canceled', ) @@ -379,7 +380,44 @@ async def test_abort_context_error_mapping( # noqa: PLR0913 mock_grpc_context.abort.assert_awaited_once() call_args, _ = mock_grpc_context.abort.call_args assert call_args[0] == grpc_status_code - assert error_message_part in call_args[1] + + # We shouldn't rely on the legacy ExceptionName: message string format + # But for backward compatability fallback it shouldn't fail + mock_grpc_context.set_trailing_metadata.assert_called_once() + metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0] + + assert any(key == 'grpc-status-details-bin' for key, _ in metadata) + + +@pytest.mark.asyncio +async def test_abort_context_rich_error_format( + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, +) -> None: + + error = types.TaskNotFoundError('Could not find the task') + mock_request_handler.on_get_task.side_effect = error + request_proto = a2a_pb2.GetTaskRequest(id='any') + await grpc_handler.GetTask(request_proto, mock_grpc_context) + + mock_grpc_context.set_trailing_metadata.assert_called_once() + metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0] + + bin_values = [v for k, v in metadata if k == 'grpc-status-details-bin'] + assert len(bin_values) == 1 + + status = status_pb2.Status.FromString(bin_values[0]) + assert status.code == grpc.StatusCode.NOT_FOUND.value[0] + assert status.message == 'Could not find the task' + + assert len(status.details) == 1 + + error_info = error_details_pb2.ErrorInfo() + status.details[0].Unpack(error_info) + + assert error_info.reason == 'TASK_NOT_FOUND' + assert error_info.domain == 'a2a-protocol.org' @pytest.mark.asyncio