diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e2ec69a15..20774001c 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import AsyncIterable try: @@ -53,13 +53,11 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: def _get_metadata_value( context: grpc.aio.ServicerContext, key: str ) -> list[str]: - md = context.invocation_metadata - raw_values: list[str | bytes] = [] - if isinstance(md, Metadata): - raw_values = md.get_all(key) - elif isinstance(md, Sequence): - lower_key = key.lower() - raw_values = [e for (k, e) in md if k.lower() == lower_key] + md = context.invocation_metadata() + raw_values: list[str | bytes] = ( + md.get_all(key) if isinstance(md, Metadata) else [] + ) + return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values] diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 26f923c14..5c7e26ef1 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -321,7 +321,7 @@ async def test_send_message_with_extensions( mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, ) -> None: - mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'bar'), ) @@ -361,7 +361,7 @@ async def test_send_message_with_comma_separated_extensions( mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, ) -> None: - mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo ,, bar,'), (HTTP_EXTENSION_HEADER, 'baz , bar'), ) @@ -386,7 +386,7 @@ async def test_send_streaming_message_with_extensions( mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, ) -> None: - mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'bar'), )