From 9562522c40a4bae6a326fdf2f9ba52722493ad07 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 11:38:19 +0000 Subject: [PATCH 1/5] feat(server): validate fields presence according to `google.api.field_behavior` annotations --- .../default_request_handler.py | 15 ++- .../request_handlers/request_handler.py | 43 ++++++- src/a2a/utils/errors.py | 3 +- src/a2a/utils/proto_utils.py | 109 +++++++++++++++++- .../test_default_request_handler.py | 50 +++++--- 5 files changed, 200 insertions(+), 20 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index c641b0f12..99bb81fc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -18,7 +18,10 @@ InMemoryQueueManager, QueueManager, ) -from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate_request_params, +) from a2a.server.tasks import ( PushNotificationConfigStore, PushNotificationEvent, @@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913 # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() + @validate_request_params async def on_get_task( self, params: GetTaskRequest, @@ -133,6 +137,7 @@ async def on_get_task( return apply_history_length(task, params) + @validate_request_params async def on_list_tasks( self, params: ListTasksRequest, @@ -154,6 +159,7 @@ async def on_list_tasks( return page + @validate_request_params async def on_cancel_task( self, params: CancelTaskRequest, @@ -317,6 +323,7 @@ async def _send_push_notification_if_needed( ): await self._push_sender.send_notification(task_id, event) + @validate_request_params async def on_message_send( self, params: SendMessageRequest, @@ -386,6 +393,7 @@ async def push_notification_callback(event: Event) -> None: return result + @validate_request_params async def on_message_send_stream( self, params: SendMessageRequest, @@ -474,6 +482,7 @@ async def _cleanup_producer( async with self._running_agents_lock: self._running_agents.pop(task_id, None) + @validate_request_params async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -499,6 +508,7 @@ async def on_create_task_push_notification_config( return params + @validate_request_params async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -530,6 +540,7 @@ async def on_get_task_push_notification_config( raise InternalError(message='Push notification config not found') + @validate_request_params async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -572,6 +583,7 @@ async def on_subscribe_to_task( async for event in result_aggregator.consume_and_emit(consumer): yield event + @validate_request_params async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -597,6 +609,7 @@ async def on_list_task_push_notification_configs( configs=push_notification_config_list ) + @validate_request_params async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 120a71e37..6fa68b084 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,5 +1,11 @@ +import functools +import inspect + from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable +from typing import Any + +from google.protobuf.message import Message as ProtoMessage from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event @@ -19,6 +25,7 @@ TaskPushNotificationConfig, ) from a2a.utils.errors import UnsupportedOperationError +from a2a.utils.proto_utils import validate_proto_required_fields class RequestHandler(ABC): @@ -218,3 +225,37 @@ async def on_delete_task_push_notification_config( Returns: None """ + + +def validate_request_params(method: Callable) -> Callable: + """Decorator for RequestHandler methods to validate required fields on incoming requests.""" + if inspect.isasyncgenfunction(method): + + @functools.wraps(method) + async def async_generator_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator: + if params is not None: + validate_proto_required_fields(params) + async for item in method(self, params, context, *args, **kwargs): + yield item + + return async_generator_wrapper + + @functools.wraps(method) + async def async_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> Any: + if params is not None: + validate_proto_required_fields(params) + return await method(self, params, context, *args, **kwargs) + + return async_wrapper diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index a16542d97..c87fa7372 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -21,9 +21,10 @@ class A2AError(Exception): message: str = 'A2A Error' data: dict | None = None - def __init__(self, message: str | None = None): + def __init__(self, message: str | None = None, data: dict | None = None): if message: self.message = message + self.data = data super().__init__(self.message) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index cdfc306f4..34de6e47a 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -17,11 +17,15 @@ This module provides helper functions for common proto type operations. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict +from google.api.field_behavior_pb2 import FieldBehavior, field_behavior +from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import ParseDict from google.protobuf.message import Message as ProtobufMessage +from a2a.utils.errors import InvalidParamsError + if TYPE_CHECKING: from starlette.datastructures import QueryParams @@ -189,3 +193,106 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: processed[k] = parsed_val ParseDict(processed, message, ignore_unknown_fields=True) + + +class ValidationDetail(TypedDict): + """Structured validation error detail.""" + + field: str + message: str + + +def _check_required_field_violation( + msg: ProtobufMessage, field: FieldDescriptor +) -> ValidationDetail | None: + """Check if a required field is missing or invalid.""" + val = getattr(msg, field.name) + if field.is_repeated: + if not val: + return ValidationDetail( + field=field.name, + message='Field must contain at least one element.', + ) + elif field.has_presence: + if not msg.HasField(field.name): + return ValidationDetail( + field=field.name, message='Field is required.' + ) + elif val == field.default_value: + return ValidationDetail(field=field.name, message='Field is required.') + return None + + +def _append_nested_errors( + errors: list[ValidationDetail], + prefix: str, + sub_errs: list[ValidationDetail], +) -> None: + """Format nested validation errors and append to errors list.""" + for sub in sub_errs: + sub_field = sub['field'] + errors.append( + ValidationDetail( + field=f'{prefix}.{sub_field}' if sub_field else prefix, + message=sub['message'], + ) + ) + + +def _recurse_validation( + msg: ProtobufMessage, field: FieldDescriptor +) -> list[ValidationDetail]: + """Recurse validation for nested messages and map fields.""" + errors: list[ValidationDetail] = [] + if field.type != FieldDescriptor.TYPE_MESSAGE: + return errors + + val = getattr(msg, field.name) + if not field.is_repeated: + if msg.HasField(field.name): + sub_errs = _validate_proto_required_fields_internal(val) + _append_nested_errors(errors, field.name, sub_errs) + elif field.message_type.GetOptions().map_entry: + for k, v in val.items(): + if isinstance(v, ProtobufMessage): + sub_errs = _validate_proto_required_fields_internal(v) + _append_nested_errors(errors, f'{field.name}[{k}]', sub_errs) + else: + for i, item in enumerate(val): + sub_errs = _validate_proto_required_fields_internal(item) + _append_nested_errors(errors, f'{field.name}[{i}]', sub_errs) + return errors + + +def _validate_proto_required_fields_internal( + msg: ProtobufMessage, +) -> list[ValidationDetail]: + """Internal validation that returns a list of error dictionaries.""" + desc = msg.DESCRIPTOR + errors: list[ValidationDetail] = [] + + for field in desc.fields: + options = field.GetOptions() + if FieldBehavior.REQUIRED in options.Extensions[field_behavior]: + violation = _check_required_field_violation(msg, field) + if violation: + errors.append(violation) + errors.extend(_recurse_validation(msg, field)) + return errors + + +def validate_proto_required_fields(msg: ProtobufMessage) -> None: + """Validate that all fields marked as REQUIRED are present on the proto message. + + Args: + msg: The Protobuf message to validate. + + Raises: + InvalidParamsError: If a required field is missing or empty. + """ + errors = _validate_proto_required_fields_internal(msg) + + if errors: + raise InvalidParamsError( + message='Validation failed', data={'errors': errors} + ) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index ba2627e38..3d22813c6 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -451,7 +451,9 @@ async def test_on_cancel_task_invalid_result_type(): # Mock ResultAggregator to return a Message mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) mock_result_aggregator_instance.consume_all.return_value = Message( - message_id='unexpected_msg', role=Role.ROLE_AGENT, parts=[] + message_id='unexpected_msg', + role=Role.ROLE_AGENT, + parts=[Part(text='Test')], ) request_handler = DefaultRequestHandler( @@ -524,7 +526,7 @@ async def test_on_message_send_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -630,7 +632,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): message=Message( role=Role.ROLE_USER, message_id='msg_non_blocking', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -750,7 +752,11 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): accepted_output_modes=['text/plain'], # Added required field ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_push', parts=[]), + message=Message( + role=Role.ROLE_USER, + message_id='msg_push', + parts=[Part(text='Test')], + ), configuration=message_config, ) @@ -815,7 +821,11 @@ async def test_on_message_send_no_result_from_aggregator(): request_context_builder=mock_request_context_builder, ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_no_res', parts=[]) + message=Message( + role=Role.ROLE_USER, + message_id='msg_no_res', + parts=[Part(text='Test')], + ) ) mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) @@ -863,7 +873,9 @@ async def test_on_message_send_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_id_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_id_mismatch', + parts=[Part(text='Test')], ) ) @@ -1067,7 +1079,9 @@ async def test_on_message_send_interrupted_flow(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_interrupt', parts=[] + role=Role.ROLE_USER, + message_id='msg_interrupt', + parts=[Part(text='Test')], ) ) @@ -1178,7 +1192,7 @@ async def test_on_message_send_stream_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_stream_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -1460,7 +1474,7 @@ async def test_stream_disconnect_then_resubscribe_receives_future_events(): message=Message( role=Role.ROLE_USER, message_id='msg_reconn', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1558,7 +1572,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea message=Message( role=Role.ROLE_USER, message_id='mid', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1698,7 +1712,7 @@ async def cancel( message=Message( role=Role.ROLE_USER, message_id='msg_persist', - parts=[], + parts=[Part(text='Test')], ) ) @@ -1785,7 +1799,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): message=Message( role=Role.ROLE_USER, message_id='mid_track', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1890,7 +1904,9 @@ async def test_on_message_send_stream_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_stream_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_stream_mismatch', + parts=[Part(text='Test')], ) ) @@ -2586,7 +2602,7 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2627,7 +2643,7 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal_stream', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2869,7 +2885,9 @@ async def test_on_message_send_negative_history_length_error(): accepted_output_modes=['text/plain'], ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg1', parts=[]), + message=Message( + role=Role.ROLE_USER, message_id='msg1', parts=[Part(text='Test')] + ), configuration=message_config, ) context = create_server_call_context() From 1c648a21dcd7088f0e08e87c3f68016838226fb7 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:14:43 +0000 Subject: [PATCH 2/5] WIP --- src/a2a/client/transports/grpc.py | 20 +++++++-- src/a2a/client/transports/jsonrpc.py | 3 +- .../server/request_handlers/grpc_handler.py | 21 ++++++++-- .../request_handlers/jsonrpc_handler.py | 1 + tests/integration/test_end_to_end.py | 34 ++++++++++++++- tests/utils/test_proto_utils.py | 41 +++++++++++++++++++ 6 files changed, 110 insertions(+), 10 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 02c418eb3..0945f3bca 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -61,17 +61,29 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: # Use grpc_status to cleanly extract the rich Status from the call status = rpc_status.from_call(cast('grpc.Call', e)) + data = None if status is not None: + exception_cls = None for detail in status.details: - if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): + bad_request = error_details_pb2.BadRequest() + detail.Unpack(bad_request) + errors = [ + {'field': v.field, 'message': v.description} + for v in bad_request.field_violations + ] + data = {'errors': errors} + # Infer InvalidParamsError from BadRequest details + exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS') + elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) - if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) - if exception_cls: - raise exception_cls(status.message) from e + + if exception_cls: + raise exception_cls(status.message, data=data) from e raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 9854aabb0..eca6c4897 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -318,9 +318,10 @@ def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception: """Creates the appropriate A2AError from a JSON-RPC error dictionary.""" code = error_dict.get('code') message = error_dict.get('message', str(error_dict)) + data = error_dict.get('data') if isinstance(code, int) and code in _JSON_RPC_ERROR_CODE_TO_A2A_ERROR: - return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message) + return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message, data=data) # Fallback to general A2AClientError return A2AClientError(f'JSON-RPC Error {code}: {message}') diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 326dea236..05277426d 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -438,16 +438,29 @@ async def abort_context( error.message if hasattr(error, 'message') else str(error) ) - # Create standard Status and pack the ErrorInfo + # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - detail = any_pb2.Any() - detail.Pack(error_info) - status.details.append(detail) + + # Exclusive details based on error type: + if error.data and error.data.get('errors'): + bad_request = error_details_pb2.BadRequest() + for err_dict in error.data['errors']: + violation = bad_request.field_violations.add() + violation.field = err_dict.get('field', '') + violation.description = err_dict.get('message', '') + any_bad_request = any_pb2.Any() + any_bad_request.Pack(bad_request) + status.details.append(any_bad_request) + else: + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) # Use grpc_status to safely generate standard trailing metadata rich_status = rpc_status.to_status(status) new_metadata: list[tuple[str, str | bytes]] = [] + trailing = context.trailing_metadata() if trailing: for k, v in trailing: diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index e7d5b75ad..0fe6c56bd 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -92,6 +92,7 @@ def _build_error_response( jsonrpc_error = model_class( code=code, message=str(error), + data=error.data, ) else: jsonrpc_error = JSONRPCInternalError(message=str(error)) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf3..11d1b4562 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import NamedTuple +from typing import Any, NamedTuple import grpc import httpx @@ -31,6 +31,7 @@ a2a_pb2_grpc, ) from a2a.utils import TransportProtocol +from a2a.utils.errors import InvalidParamsError def assert_message_matches(message, expected_role, expected_text): @@ -546,3 +547,34 @@ async def test_end_to_end_input_required(transport_setups): ], ) assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done') + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'empty_request, expected_fields', + [ + ( + SendMessageRequest(), + ['message'], + ), + ( + SendMessageRequest(message=Message()), + ['message.message_id', 'message.role', 'message.parts'], + ), + ], +) +async def test_end_to_end_validation_errors( + transport_setups, + empty_request: SendMessageRequest, + expected_fields: list[str], +) -> None: + client = transport_setups.client + + with pytest.raises(InvalidParamsError) as exc_info: + async for _ in client.send_message(request=empty_request): + pass + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == set(expected_fields) + + await client.close() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 6a53541f3..e2c760bae 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -239,3 +239,44 @@ def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams: return httpx.Request( 'GET', 'http://api.example.com', params=rest_dict ).url.params + + +class TestValidateProtoRequiredFields: + """Tests for validate_proto_required_fields function.""" + + def test_valid_required_fields(self): + """Test with all required fields present.""" + msg = Message( + message_id='msg-1', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + proto_utils.validate_proto_required_fields(msg) + + def test_missing_required_fields(self): + """Test with empty message raising InvalidParamsError containing all errors.""" + from a2a.utils.errors import InvalidParamsError + + msg = Message() + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(msg) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + assert {e['field'] for e in errors} == {'message_id', 'role', 'parts'} + + def test_nested_required_fields(self): + """Test nested required fields inside TaskStatus.""" + from a2a.utils.errors import InvalidParamsError + + # Task Status requires 'state' + task = Task(id='task-1', status=TaskStatus()) + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(task) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + fields = [e['field'] for e in errors] + assert 'status.state' in fields From 73e54452d8d573da87eeeff467fc405e2d0cfb31 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:46:11 +0000 Subject: [PATCH 3/5] WIP --- src/a2a/client/transports/grpc.py | 5 ++--- .../server/request_handlers/grpc_handler.py | 3 +-- tests/integration/test_end_to_end.py | 19 +++++++++++++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 0945f3bca..c91b78a0c 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import A2A_REASON_TO_ERROR +from a2a.utils.errors import InvalidParamsError from a2a.utils.telemetry import SpanKind, trace_class @@ -74,8 +74,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: for v in bad_request.field_violations ] data = {'errors': errors} - # Infer InvalidParamsError from BadRequest details - exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS') + exception_cls = InvalidParamsError elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 05277426d..503faadbb 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -441,8 +441,7 @@ async def abort_context( # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - # Exclusive details based on error type: - if error.data and error.data.get('errors'): + if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'): bad_request = error_details_pb2.BadRequest() for err_dict in error.data['errors']: violation = bad_request.field_violations.add() diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 11d1b4562..af9be2e83 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -570,11 +570,22 @@ async def test_end_to_end_validation_errors( ) -> None: client = transport_setups.client - with pytest.raises(InvalidParamsError) as exc_info: + try: async for _ in client.send_message(request=empty_request): pass - - errors = exc_info.value.data.get('errors', []) - assert {e['field'] for e in errors} == set(expected_fields) + except Exception as e: + # ASGITransport propagates server-side generator crashes as ExceptionGroups + exc = e + if hasattr(e, 'exceptions') and len(e.exceptions) == 1: + exc = e.exceptions[0] + + if not isinstance(exc, InvalidParamsError): + raise e + + errors = exc.data.get('errors', []) if exc.data else [] + assert {e['field'] for e in errors} == set(expected_fields) + return + + pytest.fail('InvalidParamsError was not raised') await client.close() From b891f344f4f098d8d2b771fae67a1be076422829 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:50:08 +0000 Subject: [PATCH 4/5] Cosmetics --- src/a2a/client/transports/grpc.py | 6 ++++-- src/a2a/client/transports/http_helpers.py | 5 +++++ tests/integration/test_end_to_end.py | 25 +++++++---------------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index c91b78a0c..d4cf35e31 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import InvalidParamsError +from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, InvalidParamsError from a2a.utils.telemetry import SpanKind, trace_class @@ -64,7 +64,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: data = None if status is not None: - exception_cls = None + exception_cls: type[A2AError] | None = None for detail in status.details: if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): bad_request = error_details_pb2.BadRequest() @@ -75,11 +75,13 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: ] data = {'errors': errors} exception_cls = InvalidParamsError + break elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) + break if exception_cls: raise exception_cls(status.message, data=data) from e diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 0a5721b50..43accadd2 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -40,6 +40,11 @@ def handle_http_exceptions( raise A2AClientError(f'Network communication error: {e}') from e except json.JSONDecodeError as e: raise A2AClientError(f'JSON Decode Error: {e}') from e + except Exception as e: + # ASGITransport propagates local server-side generator crashes as ExceptionGroups + if hasattr(e, 'exceptions') and len(e.exceptions) == 1: + raise e.exceptions[0] from e + raise e def get_http_args(context: ClientCallContext | None) -> dict[str, Any]: diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index af9be2e83..efeea9ad6 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -555,37 +555,26 @@ async def test_end_to_end_input_required(transport_setups): [ ( SendMessageRequest(), - ['message'], + {'message'}, ), ( SendMessageRequest(message=Message()), - ['message.message_id', 'message.role', 'message.parts'], + {'message.message_id', 'message.role', 'message.parts'}, ), ], ) async def test_end_to_end_validation_errors( transport_setups, empty_request: SendMessageRequest, - expected_fields: list[str], + expected_fields: set[str], ) -> None: client = transport_setups.client - try: + with pytest.raises(InvalidParamsError) as exc_info: async for _ in client.send_message(request=empty_request): pass - except Exception as e: - # ASGITransport propagates server-side generator crashes as ExceptionGroups - exc = e - if hasattr(e, 'exceptions') and len(e.exceptions) == 1: - exc = e.exceptions[0] - - if not isinstance(exc, InvalidParamsError): - raise e - - errors = exc.data.get('errors', []) if exc.data else [] - assert {e['field'] for e in errors} == set(expected_fields) - return - - pytest.fail('InvalidParamsError was not raised') + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == expected_fields await client.close() From 02593433dd1ba09625a0e57c9e409792708e5011 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 20 Mar 2026 15:13:55 +0000 Subject: [PATCH 5/5] wip --- src/a2a/client/transports/grpc.py | 2 +- src/a2a/server/apps/rest/rest_adapter.py | 2 +- .../server/request_handlers/grpc_handler.py | 6 ++++- .../request_handlers/request_handler.py | 17 +++++++------ .../server/request_handlers/rest_handler.py | 24 ++++++++++++------- 5 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index d4cf35e31..e8614bbb6 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -76,7 +76,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: data = {'errors': errors} exception_cls = InvalidParamsError break - elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) if error_info.domain == 'a2a-protocol.org': diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 6b8abb99e..e44120f3d 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -159,7 +159,7 @@ async def event_generator( yield json.dumps(item) return EventSourceResponse( - event_generator(method(request, call_context)) + event_generator(await method(request, call_context)) ) async def handle_get_agent_card( diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 4a66e95b7..e9f3f2fe8 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -406,7 +406,11 @@ async def abort_context( # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'): + if ( + isinstance(error, types.InvalidParamsError) + and error.data + and error.data.get('errors') + ): bad_request = error_details_pb2.BadRequest() for err_dict in error.data['errors']: violation = bad_request.field_violations.add() diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 6fa68b084..ed1ba6e52 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -229,25 +229,24 @@ async def on_delete_task_push_notification_config( def validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" - if inspect.isasyncgenfunction(method): + if inspect.iscoroutinefunction(method): @functools.wraps(method) - async def async_generator_wrapper( + async def async_wrapper( self: RequestHandler, params: ProtoMessage, context: ServerCallContext, *args: Any, **kwargs: Any, - ) -> AsyncGenerator: + ) -> Any: if params is not None: validate_proto_required_fields(params) - async for item in method(self, params, context, *args, **kwargs): - yield item + return await method(self, params, context, *args, **kwargs) - return async_generator_wrapper + return async_wrapper @functools.wraps(method) - async def async_wrapper( + def sync_wrapper( self: RequestHandler, params: ProtoMessage, context: ServerCallContext, @@ -256,6 +255,6 @@ async def async_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - return await method(self, params, context, *args, **kwargs) + return method(self, params, context, *args, **kwargs) - return async_wrapper + return sync_wrapper diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index af889d9df..50a9f2ac6 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -116,11 +116,14 @@ async def on_message_send_stream( body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) + stream = self.request_handler.on_message_send_stream(params, context) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def on_cancel_task( @@ -167,10 +170,15 @@ async def on_subscribe_to_task( JSON serialized objects containing streaming events """ task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) + ) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + yield MessageToDict(proto_utils.to_stream_response(event)) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def get_push_notification(