diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 02c418eb3..e8614bbb6 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import A2A_REASON_TO_ERROR +from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, InvalidParamsError from a2a.utils.telemetry import SpanKind, trace_class @@ -61,17 +61,30 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: # Use grpc_status to cleanly extract the rich Status from the call status = rpc_status.from_call(cast('grpc.Call', e)) + data = None if status is not None: + exception_cls: type[A2AError] | None = None for detail in status.details: + if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): + bad_request = error_details_pb2.BadRequest() + detail.Unpack(bad_request) + errors = [ + {'field': v.field, 'message': v.description} + for v in bad_request.field_violations + ] + data = {'errors': errors} + exception_cls = InvalidParamsError + break if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) - if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) - if exception_cls: - raise exception_cls(status.message) from e + break + + if exception_cls: + raise exception_cls(status.message, data=data) from e raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 9854aabb0..eca6c4897 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -318,9 +318,10 @@ def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception: """Creates the appropriate A2AError from a JSON-RPC error dictionary.""" code = error_dict.get('code') message = error_dict.get('message', str(error_dict)) + data = error_dict.get('data') if isinstance(code, int) and code in _JSON_RPC_ERROR_CODE_TO_A2A_ERROR: - return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message) + return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message, data=data) # Fallback to general A2AClientError return A2AClientError(f'JSON-RPC Error {code}: {message}') diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 6b8abb99e..e44120f3d 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -159,7 +159,7 @@ async def event_generator( yield json.dumps(item) return EventSourceResponse( - event_generator(method(request, call_context)) + event_generator(await method(request, call_context)) ) async def handle_get_agent_card( diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index c641b0f12..99bb81fc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -18,7 +18,10 @@ InMemoryQueueManager, QueueManager, ) -from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate_request_params, +) from a2a.server.tasks import ( PushNotificationConfigStore, PushNotificationEvent, @@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913 # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() + @validate_request_params async def on_get_task( self, params: GetTaskRequest, @@ -133,6 +137,7 @@ async def on_get_task( return apply_history_length(task, params) + @validate_request_params async def on_list_tasks( self, params: ListTasksRequest, @@ -154,6 +159,7 @@ async def on_list_tasks( return page + @validate_request_params async def on_cancel_task( self, params: CancelTaskRequest, @@ -317,6 +323,7 @@ async def _send_push_notification_if_needed( ): await self._push_sender.send_notification(task_id, event) + @validate_request_params async def on_message_send( self, params: SendMessageRequest, @@ -386,6 +393,7 @@ async def push_notification_callback(event: Event) -> None: return result + @validate_request_params async def on_message_send_stream( self, params: SendMessageRequest, @@ -474,6 +482,7 @@ async def _cleanup_producer( async with self._running_agents_lock: self._running_agents.pop(task_id, None) + @validate_request_params async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -499,6 +508,7 @@ async def on_create_task_push_notification_config( return params + @validate_request_params async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -530,6 +540,7 @@ async def on_get_task_push_notification_config( raise InternalError(message='Push notification config not found') + @validate_request_params async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -572,6 +583,7 @@ async def on_subscribe_to_task( async for event in result_aggregator.consume_and_emit(consumer): yield event + @validate_request_params async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -597,6 +609,7 @@ async def on_list_task_push_notification_configs( configs=push_notification_config_list ) + @validate_request_params async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index b290fbf44..e9f3f2fe8 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -403,16 +403,32 @@ async def abort_context( error.message if hasattr(error, 'message') else str(error) ) - # Create standard Status and pack the ErrorInfo + # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - detail = any_pb2.Any() - detail.Pack(error_info) - status.details.append(detail) + + if ( + isinstance(error, types.InvalidParamsError) + and error.data + and error.data.get('errors') + ): + bad_request = error_details_pb2.BadRequest() + for err_dict in error.data['errors']: + violation = bad_request.field_violations.add() + violation.field = err_dict.get('field', '') + violation.description = err_dict.get('message', '') + any_bad_request = any_pb2.Any() + any_bad_request.Pack(bad_request) + status.details.append(any_bad_request) + else: + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) # Use grpc_status to safely generate standard trailing metadata rich_status = rpc_status.to_status(status) new_metadata: list[tuple[str, str | bytes]] = [] + trailing = context.trailing_metadata() if trailing: for k, v in trailing: diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index 06188e412..b13b84007 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -96,6 +96,7 @@ def _build_error_response( jsonrpc_error = model_class( code=code, message=str(error), + data=error.data, ) else: jsonrpc_error = JSONRPCInternalError(message=str(error)) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 120a71e37..ed1ba6e52 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,5 +1,11 @@ +import functools +import inspect + from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable +from typing import Any + +from google.protobuf.message import Message as ProtoMessage from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event @@ -19,6 +25,7 @@ TaskPushNotificationConfig, ) from a2a.utils.errors import UnsupportedOperationError +from a2a.utils.proto_utils import validate_proto_required_fields class RequestHandler(ABC): @@ -218,3 +225,36 @@ async def on_delete_task_push_notification_config( Returns: None """ + + +def validate_request_params(method: Callable) -> Callable: + """Decorator for RequestHandler methods to validate required fields on incoming requests.""" + if inspect.iscoroutinefunction(method): + + @functools.wraps(method) + async def async_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> Any: + if params is not None: + validate_proto_required_fields(params) + return await method(self, params, context, *args, **kwargs) + + return async_wrapper + + @functools.wraps(method) + def sync_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> Any: + if params is not None: + validate_proto_required_fields(params) + return method(self, params, context, *args, **kwargs) + + return sync_wrapper diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index af889d9df..50a9f2ac6 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -116,11 +116,14 @@ async def on_message_send_stream( body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) + stream = self.request_handler.on_message_send_stream(params, context) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def on_cancel_task( @@ -167,10 +170,15 @@ async def on_subscribe_to_task( JSON serialized objects containing streaming events """ task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) + ) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + yield MessageToDict(proto_utils.to_stream_response(event)) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def get_push_notification( diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index a16542d97..c87fa7372 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -21,9 +21,10 @@ class A2AError(Exception): message: str = 'A2A Error' data: dict | None = None - def __init__(self, message: str | None = None): + def __init__(self, message: str | None = None, data: dict | None = None): if message: self.message = message + self.data = data super().__init__(self.message) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index cdfc306f4..34de6e47a 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -17,11 +17,15 @@ This module provides helper functions for common proto type operations. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict +from google.api.field_behavior_pb2 import FieldBehavior, field_behavior +from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import ParseDict from google.protobuf.message import Message as ProtobufMessage +from a2a.utils.errors import InvalidParamsError + if TYPE_CHECKING: from starlette.datastructures import QueryParams @@ -189,3 +193,106 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: processed[k] = parsed_val ParseDict(processed, message, ignore_unknown_fields=True) + + +class ValidationDetail(TypedDict): + """Structured validation error detail.""" + + field: str + message: str + + +def _check_required_field_violation( + msg: ProtobufMessage, field: FieldDescriptor +) -> ValidationDetail | None: + """Check if a required field is missing or invalid.""" + val = getattr(msg, field.name) + if field.is_repeated: + if not val: + return ValidationDetail( + field=field.name, + message='Field must contain at least one element.', + ) + elif field.has_presence: + if not msg.HasField(field.name): + return ValidationDetail( + field=field.name, message='Field is required.' + ) + elif val == field.default_value: + return ValidationDetail(field=field.name, message='Field is required.') + return None + + +def _append_nested_errors( + errors: list[ValidationDetail], + prefix: str, + sub_errs: list[ValidationDetail], +) -> None: + """Format nested validation errors and append to errors list.""" + for sub in sub_errs: + sub_field = sub['field'] + errors.append( + ValidationDetail( + field=f'{prefix}.{sub_field}' if sub_field else prefix, + message=sub['message'], + ) + ) + + +def _recurse_validation( + msg: ProtobufMessage, field: FieldDescriptor +) -> list[ValidationDetail]: + """Recurse validation for nested messages and map fields.""" + errors: list[ValidationDetail] = [] + if field.type != FieldDescriptor.TYPE_MESSAGE: + return errors + + val = getattr(msg, field.name) + if not field.is_repeated: + if msg.HasField(field.name): + sub_errs = _validate_proto_required_fields_internal(val) + _append_nested_errors(errors, field.name, sub_errs) + elif field.message_type.GetOptions().map_entry: + for k, v in val.items(): + if isinstance(v, ProtobufMessage): + sub_errs = _validate_proto_required_fields_internal(v) + _append_nested_errors(errors, f'{field.name}[{k}]', sub_errs) + else: + for i, item in enumerate(val): + sub_errs = _validate_proto_required_fields_internal(item) + _append_nested_errors(errors, f'{field.name}[{i}]', sub_errs) + return errors + + +def _validate_proto_required_fields_internal( + msg: ProtobufMessage, +) -> list[ValidationDetail]: + """Internal validation that returns a list of error dictionaries.""" + desc = msg.DESCRIPTOR + errors: list[ValidationDetail] = [] + + for field in desc.fields: + options = field.GetOptions() + if FieldBehavior.REQUIRED in options.Extensions[field_behavior]: + violation = _check_required_field_violation(msg, field) + if violation: + errors.append(violation) + errors.extend(_recurse_validation(msg, field)) + return errors + + +def validate_proto_required_fields(msg: ProtobufMessage) -> None: + """Validate that all fields marked as REQUIRED are present on the proto message. + + Args: + msg: The Protobuf message to validate. + + Raises: + InvalidParamsError: If a required field is missing or empty. + """ + errors = _validate_proto_required_fields_internal(msg) + + if errors: + raise InvalidParamsError( + message='Validation failed', data={'errors': errors} + ) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf3..efeea9ad6 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import NamedTuple +from typing import Any, NamedTuple import grpc import httpx @@ -31,6 +31,7 @@ a2a_pb2_grpc, ) from a2a.utils import TransportProtocol +from a2a.utils.errors import InvalidParamsError def assert_message_matches(message, expected_role, expected_text): @@ -546,3 +547,34 @@ async def test_end_to_end_input_required(transport_setups): ], ) assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done') + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'empty_request, expected_fields', + [ + ( + SendMessageRequest(), + {'message'}, + ), + ( + SendMessageRequest(message=Message()), + {'message.message_id', 'message.role', 'message.parts'}, + ), + ], +) +async def test_end_to_end_validation_errors( + transport_setups, + empty_request: SendMessageRequest, + expected_fields: set[str], +) -> None: + client = transport_setups.client + + with pytest.raises(InvalidParamsError) as exc_info: + async for _ in client.send_message(request=empty_request): + pass + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == expected_fields + + await client.close() diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index ba2627e38..3d22813c6 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -451,7 +451,9 @@ async def test_on_cancel_task_invalid_result_type(): # Mock ResultAggregator to return a Message mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) mock_result_aggregator_instance.consume_all.return_value = Message( - message_id='unexpected_msg', role=Role.ROLE_AGENT, parts=[] + message_id='unexpected_msg', + role=Role.ROLE_AGENT, + parts=[Part(text='Test')], ) request_handler = DefaultRequestHandler( @@ -524,7 +526,7 @@ async def test_on_message_send_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -630,7 +632,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): message=Message( role=Role.ROLE_USER, message_id='msg_non_blocking', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -750,7 +752,11 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): accepted_output_modes=['text/plain'], # Added required field ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_push', parts=[]), + message=Message( + role=Role.ROLE_USER, + message_id='msg_push', + parts=[Part(text='Test')], + ), configuration=message_config, ) @@ -815,7 +821,11 @@ async def test_on_message_send_no_result_from_aggregator(): request_context_builder=mock_request_context_builder, ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_no_res', parts=[]) + message=Message( + role=Role.ROLE_USER, + message_id='msg_no_res', + parts=[Part(text='Test')], + ) ) mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) @@ -863,7 +873,9 @@ async def test_on_message_send_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_id_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_id_mismatch', + parts=[Part(text='Test')], ) ) @@ -1067,7 +1079,9 @@ async def test_on_message_send_interrupted_flow(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_interrupt', parts=[] + role=Role.ROLE_USER, + message_id='msg_interrupt', + parts=[Part(text='Test')], ) ) @@ -1178,7 +1192,7 @@ async def test_on_message_send_stream_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_stream_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -1460,7 +1474,7 @@ async def test_stream_disconnect_then_resubscribe_receives_future_events(): message=Message( role=Role.ROLE_USER, message_id='msg_reconn', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1558,7 +1572,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea message=Message( role=Role.ROLE_USER, message_id='mid', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1698,7 +1712,7 @@ async def cancel( message=Message( role=Role.ROLE_USER, message_id='msg_persist', - parts=[], + parts=[Part(text='Test')], ) ) @@ -1785,7 +1799,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): message=Message( role=Role.ROLE_USER, message_id='mid_track', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1890,7 +1904,9 @@ async def test_on_message_send_stream_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_stream_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_stream_mismatch', + parts=[Part(text='Test')], ) ) @@ -2586,7 +2602,7 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2627,7 +2643,7 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal_stream', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2869,7 +2885,9 @@ async def test_on_message_send_negative_history_length_error(): accepted_output_modes=['text/plain'], ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg1', parts=[]), + message=Message( + role=Role.ROLE_USER, message_id='msg1', parts=[Part(text='Test')] + ), configuration=message_config, ) context = create_server_call_context() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 6a53541f3..e2c760bae 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -239,3 +239,44 @@ def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams: return httpx.Request( 'GET', 'http://api.example.com', params=rest_dict ).url.params + + +class TestValidateProtoRequiredFields: + """Tests for validate_proto_required_fields function.""" + + def test_valid_required_fields(self): + """Test with all required fields present.""" + msg = Message( + message_id='msg-1', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + proto_utils.validate_proto_required_fields(msg) + + def test_missing_required_fields(self): + """Test with empty message raising InvalidParamsError containing all errors.""" + from a2a.utils.errors import InvalidParamsError + + msg = Message() + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(msg) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + assert {e['field'] for e in errors} == {'message_id', 'role', 'parts'} + + def test_nested_required_fields(self): + """Test nested required fields inside TaskStatus.""" + from a2a.utils.errors import InvalidParamsError + + # Task Status requires 'state' + task = Task(id='task-1', status=TaskStatus()) + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(task) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + fields = [e['field'] for e in errors] + assert 'status.state' in fields