Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging

from collections.abc import AsyncGenerator, Callable
from collections.abc import AsyncGenerator, Callable, Iterable
from functools import wraps
from typing import Any, NoReturn
from typing import Any, NoReturn, cast

from a2a.client.errors import A2AClientError, A2AClientTimeoutError
from a2a.client.middleware import ClientCallContext
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, A2AError


try:
Expand All @@ -18,8 +19,12 @@
) from e


from google.rpc import ( # type: ignore[reportMissingModuleSource]
error_details_pb2,
status_pb2,
)

from a2a.client.client import ClientConfig
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.optionals import Channel
from a2a.client.transports.base import ClientTransport
Expand All @@ -44,6 +49,7 @@
TaskPushNotificationConfig,
)
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
from a2a.utils.errors import A2A_REASON_TO_ERROR
from a2a.utils.telemetry import SpanKind, trace_class


Expand All @@ -54,47 +60,77 @@
}


def _parse_rich_grpc_error(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value: bytes, original_error: grpc.aio.AioRpcError
) -> None:
try:
status = status_pb2.Status.FromString(value)
for detail in status.details:
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
error_info = error_details_pb2.ErrorInfo()
detail.Unpack(error_info)

if error_info.domain == 'a2a-protocol.org':
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
if exception_cls:
raise exception_cls(status.message) from original_error # noqa: TRY301
except Exception as parse_e:
# Don't swallow A2A errors generated above
if isinstance(parse_e, (A2AError, A2AClientError)):
raise parse_e
logger.warning(
'Failed to parse grpc-status-details-bin', exc_info=parse_e
)


def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
raise A2AClientTimeoutError('Client Request timed out') from e

metadata = e.trailing_metadata()
if metadata:
iterable_metadata = cast('Iterable[tuple[str, str | bytes]]', metadata)
for key, value in iterable_metadata:
if key == 'grpc-status-details-bin' and isinstance(value, bytes):
_parse_rich_grpc_error(value, e)

details = e.details()
if isinstance(details, str) and ': ' in details:
error_type_name, error_message = details.split(': ', 1)
# TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723.
# Leaving as fallback for errors that don't use the rich error details.
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name)
if exception_cls:
raise exception_cls(error_message) from e
Comment on lines 98 to 103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was added in 1.0-dev branch and was never released (#761). It is safe to remove it.

raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e


def _handle_grpc_exception(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return await func(*args, **kwargs)
except grpc.aio.AioRpcError as e:
_map_grpc_error(e)

return wrapper


def _handle_grpc_stream_exception(
func: Callable[..., Any],
) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
async for item in func(*args, **kwargs):
yield item
except grpc.aio.AioRpcError as e:
_map_grpc_error(e)

return wrapper


@trace_class(kind=SpanKind.CLIENT)
class GrpcTransport(ClientTransport):

Check notice on line 133 in src/a2a/client/transports/grpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/compat/v0_3/grpc_transport.py (53-85)
"""A gRPC transport for the A2A client."""

def __init__(
Expand Down
39 changes: 33 additions & 6 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import logging

from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import cast


try:
Expand All @@ -16,9 +17,8 @@
"'pip install a2a-sdk[grpc]'"
) from e

from collections.abc import Callable

from google.protobuf import empty_pb2, message
from google.protobuf import any_pb2, empty_pb2, message
from google.rpc import error_details_pb2, status_pb2

import a2a.types.a2a_pb2_grpc as a2a_grpc

Expand All @@ -33,7 +33,7 @@
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import AgentCard
from a2a.utils import proto_utils
from a2a.utils.errors import A2AError, TaskNotFoundError
from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError
from a2a.utils.helpers import maybe_await, validate, validate_async_generator


Expand Down Expand Up @@ -419,31 +419,58 @@
) -> None:
"""Sets the grpc errors appropriately in the context."""
code = _ERROR_CODE_MAP.get(type(error))

status_value = code.value if code else grpc.StatusCode.UNKNOWN.value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

status_code = (
status_value[0] if isinstance(status_value, tuple) else status_value
)
error_msg = error.message if hasattr(error, 'message') else str(error)
status = status_pb2.Status(code=status_code, message=error_msg)

if code:
reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR')

error_info = error_details_pb2.ErrorInfo(
reason=reason,
domain='a2a-protocol.org',
)

detail = any_pb2.Any()
detail.Pack(error_info)
status.details.append(detail)

context.set_trailing_metadata(
cast(
'tuple[tuple[str, str | bytes], ...]',
(('grpc-status-details-bin', status.SerializeToString()),),
)
)

if code:
await context.abort(
code,
f'{type(error).__name__}: {error.message}',
status.message,
)
else:
await context.abort(
grpc.StatusCode.UNKNOWN,
f'Unknown error type: {error}',
)

def _set_extension_metadata(
self,
context: grpc.aio.ServicerContext,
server_context: ServerCallContext,
) -> None:
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER.lower(), e)
for e in sorted(server_context.activated_extensions)
]
)

def _build_call_context(

Check notice on line 473 in src/a2a/server/request_handlers/grpc_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/compat/v0_3/grpc_handler.py (156-177)
self,
context: grpc.aio.ServicerContext,
request: message.Message,
Expand Down
35 changes: 35 additions & 0 deletions src/a2a/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,26 @@
message = 'Method not found'


class ExtensionSupportRequiredError(A2AError):
"""Exception raised when extension support is required but not present."""

message = 'Extension support required'


class VersionNotSupportedError(A2AError):
"""Exception raised when the requested version is not supported."""

message = 'Version not supported'


# For backward compatibility if needed, or just aliases for clean refactor
# We remove the Pydantic models here.

__all__ = [
'A2A_ERROR_REASONS',
'A2A_REASON_TO_ERROR',
'JSON_RPC_ERROR_CODE_MAP',
'ExtensionSupportRequiredError',
'InternalError',
'InvalidAgentResponseError',
'InvalidParamsError',
Expand All @@ -96,19 +111,39 @@
'TaskNotCancelableError',
'TaskNotFoundError',
'UnsupportedOperationError',
'VersionNotSupportedError',
]


JSON_RPC_ERROR_CODE_MAP: dict[type[A2AError], int] = {
TaskNotFoundError: -32001,
TaskNotCancelableError: -32002,
PushNotificationNotSupportedError: -32003,
UnsupportedOperationError: -32004,
ContentTypeNotSupportedError: -32005,
InvalidAgentResponseError: -32006,
AuthenticatedExtendedCardNotConfiguredError: -32007,
InvalidParamsError: -32602,
InvalidRequestError: -32600,
MethodNotFoundError: -32601,
InternalError: -32603,

Check notice on line 129 in src/a2a/utils/errors.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/response_helpers.py (59-70)
}


A2A_ERROR_REASONS = {
TaskNotFoundError: 'TASK_NOT_FOUND',
TaskNotCancelableError: 'TASK_NOT_CANCELABLE',
PushNotificationNotSupportedError: 'PUSH_NOTIFICATION_NOT_SUPPORTED',
UnsupportedOperationError: 'UNSUPPORTED_OPERATION',
ContentTypeNotSupportedError: 'CONTENT_TYPE_NOT_SUPPORTED',
InvalidAgentResponseError: 'INVALID_AGENT_RESPONSE',
AuthenticatedExtendedCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED',
ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED',
VersionNotSupportedError: 'VERSION_NOT_SUPPORTED',
InvalidParamsError: 'INVALID_PARAMS',
InvalidRequestError: 'INVALID_REQUEST',
MethodNotFoundError: 'METHOD_NOT_FOUND',
InternalError: 'INTERNAL_ERROR',
Comment on lines +143 to +146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe those are JSON-RPC leftovers, we can keep A2A errors only here: https://a2a-protocol.org/latest/specification/#54-error-code-mappings.

}

A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()}
49 changes: 46 additions & 3 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import grpc
import pytest

from google.protobuf import any_pb2
from google.rpc import error_details_pb2, status_pb2

from a2a.client.middleware import ClientCallContext
from a2a.client.transports.grpc import GrpcTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT
from a2a.utils.errors import A2A_ERROR_REASONS
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import (
AgentCapabilities,
Expand Down Expand Up @@ -257,16 +261,15 @@ async def test_send_message_with_timeout_context(

@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
@pytest.mark.asyncio
async def test_grpc_mapped_errors(
async def test_grpc_mapped_errors_legacy(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_message_send_params: SendMessageRequest,
error_cls,
) -> None:
"""Test handling of mapped gRPC error responses."""
"""Test handling of legacy gRPC error responses."""
error_details = f'{error_cls.__name__}: Mapped Error'

# We must trigger it from a standard transport method call, for example `send_message`.
mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError(
code=grpc.StatusCode.INTERNAL,
initial_metadata=grpc.aio.Metadata(),
Expand All @@ -278,6 +281,46 @@ async def test_grpc_mapped_errors(
await grpc_transport.send_message(sample_message_send_params)


@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
@pytest.mark.asyncio
async def test_grpc_mapped_errors_rich(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_message_send_params: SendMessageRequest,
error_cls,
) -> None:
"""Test handling of rich gRPC error responses with Status metadata."""

reason = A2A_ERROR_REASONS.get(error_cls, 'UNKNOWN_ERROR')

error_info = error_details_pb2.ErrorInfo(
reason=reason,
domain='a2a-protocol.org',
)

error_details = f'{error_cls.__name__}: Mapped Error'
status = status_pb2.Status(
code=grpc.StatusCode.INTERNAL.value[0], message=error_details
)
detail = any_pb2.Any()
detail.Pack(error_info)
status.details.append(detail)

mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError(
code=grpc.StatusCode.INTERNAL,
initial_metadata=grpc.aio.Metadata(),
trailing_metadata=grpc.aio.Metadata(
('grpc-status-details-bin', status.SerializeToString()),
),
details='A generic error message',
)

with pytest.raises(error_cls) as excinfo:
await grpc_transport.send_message(sample_message_send_params)

assert str(excinfo.value) == error_details


@pytest.mark.asyncio
async def test_send_message_message_response(
grpc_transport: GrpcTransport,
Expand Down
46 changes: 42 additions & 4 deletions tests/server/request_handlers/test_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import grpc.aio
import pytest

from google.rpc import error_details_pb2, status_pb2
from a2a import types
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.server.context import ServerCallContext
Expand Down Expand Up @@ -99,7 +100,7 @@ async def test_send_message_server_error(
await grpc_handler.SendMessage(request_proto, mock_grpc_context)

mock_grpc_context.abort.assert_awaited_once_with(
grpc.StatusCode.INVALID_ARGUMENT, 'InvalidParamsError: Bad params'
grpc.StatusCode.INVALID_ARGUMENT, 'Bad params'
)


Expand Down Expand Up @@ -138,7 +139,7 @@ async def test_get_task_not_found(
await grpc_handler.GetTask(request_proto, mock_grpc_context)

mock_grpc_context.abort.assert_awaited_once_with(
grpc.StatusCode.NOT_FOUND, 'TaskNotFoundError: Task not found'
grpc.StatusCode.NOT_FOUND, 'Task not found'
)


Expand All @@ -157,7 +158,7 @@ async def test_cancel_task_server_error(

mock_grpc_context.abort.assert_awaited_once_with(
grpc.StatusCode.UNIMPLEMENTED,
'TaskNotCancelableError: Task cannot be canceled',
'Task cannot be canceled',
)


Expand Down Expand Up @@ -379,7 +380,44 @@ async def test_abort_context_error_mapping( # noqa: PLR0913
mock_grpc_context.abort.assert_awaited_once()
call_args, _ = mock_grpc_context.abort.call_args
assert call_args[0] == grpc_status_code
assert error_message_part in call_args[1]

# We shouldn't rely on the legacy ExceptionName: message string format
# But for backward compatability fallback it shouldn't fail
mock_grpc_context.set_trailing_metadata.assert_called_once()
metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0]

assert any(key == 'grpc-status-details-bin' for key, _ in metadata)


@pytest.mark.asyncio
async def test_abort_context_rich_error_format(
grpc_handler: GrpcHandler,
mock_request_handler: AsyncMock,
mock_grpc_context: AsyncMock,
) -> None:

error = types.TaskNotFoundError('Could not find the task')
mock_request_handler.on_get_task.side_effect = error
request_proto = a2a_pb2.GetTaskRequest(id='any')
await grpc_handler.GetTask(request_proto, mock_grpc_context)

mock_grpc_context.set_trailing_metadata.assert_called_once()
metadata = mock_grpc_context.set_trailing_metadata.call_args[0][0]

bin_values = [v for k, v in metadata if k == 'grpc-status-details-bin']
assert len(bin_values) == 1

status = status_pb2.Status.FromString(bin_values[0])
assert status.code == grpc.StatusCode.NOT_FOUND.value[0]
assert status.message == 'Could not find the task'

assert len(status.details) == 1

error_info = error_details_pb2.ErrorInfo()
status.details[0].Unpack(error_info)

assert error_info.reason == 'TASK_NOT_FOUND'
assert error_info.domain == 'a2a-protocol.org'


@pytest.mark.asyncio
Expand Down
Loading