diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 17989374d6..84287795cc 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -26,6 +26,7 @@ from ..experimental import a2a_experimental from .part_converter import A2APartToGenAIPartConverter from .part_converter import convert_a2a_part_to_genai_part +from .utils import _from_a2a_context_id @a2a_experimental @@ -70,6 +71,10 @@ def _get_user_id(request: RequestContext) -> str: ): return request.call_context.user.user_name + _, user_id, _ = _from_a2a_context_id(request.context_id) + if user_id: + return user_id + # Get user from context id return f'A2A_USER_{request.context_id}' @@ -106,9 +111,11 @@ def convert_a2a_request_to_agent_run_request( genai_parts = [genai_parts] if genai_parts else [] output_parts.extend(genai_parts) + _, _, session_id = _from_a2a_context_id(request.context_id) + return AgentRunRequest( user_id=_get_user_id(request), - session_id=request.context_id, + session_id=session_id, new_message=genai_types.Content( role='user', parts=output_parts, diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index f2f9852ab6..831b4aa00e 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -41,6 +41,7 @@ from ...utils.context_utils import Aclosing from ..converters.request_converter import AgentRunRequest from ..converters.utils import _get_adk_metadata_key +from ..converters.utils import _to_a2a_context_id from ..experimental import a2a_experimental from .a2a_agent_executor_impl import _A2aAgentExecutor as ExecutorImpl from .config import A2aAgentExecutorConfig @@ -53,6 +54,15 @@ logger = logging.getLogger('google_adk.' + __name__) +class _HandleRequestErrorWithContextId(Exception): + """Carries response context ID for error events emitted from execute().""" + + def __init__(self, *, context_id: str, cause: Exception): + super().__init__(str(cause)) + self.context_id = context_id + self.cause = cause + + @a2a_experimental class A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and @@ -139,28 +149,15 @@ async def execute( context = await execute_before_agent_interceptors( context, self._config.execute_interceptors ) - - # for new task, create a task submitted event - if not context.current_task: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.submitted, - message=context.message, - timestamp=datetime.fromtimestamp( - platform_time.get_time(), tz=timezone.utc - ).isoformat(), - ), - context_id=context.context_id, - final=False, - ) - ) - # Handle the request and publish updates to the event queue try: await self._handle_request(context, event_queue) except Exception as e: + failure_context_id = context.context_id + if isinstance(e, _HandleRequestErrorWithContextId): + failure_context_id = e.context_id + e = e.cause + logger.error('Error handling A2A request: %s', e, exc_info=True) # Publish failure event try: @@ -178,7 +175,7 @@ async def execute( parts=[TextPart(text=str(e))], ), ), - context_id=context.context_id, + context_id=failure_context_id, final=True, ) ) @@ -203,114 +200,157 @@ async def _handle_request( # ensure the session exists session = await self._prepare_session(context, run_request, runner) - - # create invocation context - invocation_context = runner._new_invocation_context( - session=session, - new_message=run_request.new_message, - run_config=run_request.run_config, - ) - - executor_context = ExecutorContext( - app_name=runner.app_name, - user_id=run_request.user_id, - session_id=run_request.session_id, + response_context_id = self._get_response_context_id( + context=context, runner=runner, + run_request=run_request, + session_id=session.id, ) + try: + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + message=context.message, + timestamp=datetime.fromtimestamp( + platform_time.get_time(), tz=timezone.utc + ).isoformat(), + ), + context_id=response_context_id, + final=False, + ) + ) + + # create invocation context + invocation_context = runner._new_invocation_context( + session=session, + new_message=run_request.new_message, + run_config=run_request.run_config, + ) + + executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + + # publish the task working event + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.fromtimestamp( + platform_time.get_time(), tz=timezone.utc + ).isoformat(), + ), + context_id=response_context_id, + final=False, + metadata={ + _get_adk_metadata_key('app_name'): runner.app_name, + _get_adk_metadata_key('user_id'): run_request.user_id, + _get_adk_metadata_key('session_id'): run_request.session_id, + }, + ) + ) - # publish the task working event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( + task_result_aggregator = TaskResultAggregator() + async with Aclosing(runner.run_async(**vars(run_request))) as agen: + async for adk_event in agen: + for a2a_event in self._config.event_converter( + adk_event, + invocation_context, + context.task_id, + response_context_id, + self._config.gen_ai_part_converter, + ): + a2a_event = await execute_after_event_interceptors( + a2a_event, + executor_context, + adk_event, + self._config.execute_interceptors, + ) + if a2a_event is None: + continue + + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) + + # publish the task result event - this is final + if ( + task_result_aggregator.task_state == TaskState.working + and task_result_aggregator.task_status_message is not None + and task_result_aggregator.task_status_message.parts + ): + # if task is still working properly, publish the artifact update event as + # the final result according to a2a protocol. + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + last_chunk=True, + context_id=response_context_id, + artifact=Artifact( + artifact_id=platform_uuid.new_uuid(), + parts=task_result_aggregator.task_status_message.parts, + ), + ) + ) + # public the final status update event + final_event = TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.working, + state=TaskState.completed, timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc ).isoformat(), ), - context_id=context.context_id, - final=False, - metadata={ - _get_adk_metadata_key('app_name'): runner.app_name, - _get_adk_metadata_key('user_id'): run_request.user_id, - _get_adk_metadata_key('session_id'): run_request.session_id, - }, + context_id=response_context_id, + final=True, + ) + else: + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.fromtimestamp( + platform_time.get_time(), tz=timezone.utc + ).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=response_context_id, + final=True, ) - ) - task_result_aggregator = TaskResultAggregator() - async with Aclosing(runner.run_async(**vars(run_request))) as agen: - async for adk_event in agen: - for a2a_event in self._config.event_converter( - adk_event, - invocation_context, - context.task_id, - context.context_id, - self._config.gen_ai_part_converter, - ): - a2a_event = await execute_after_event_interceptors( - a2a_event, - executor_context, - adk_event, - self._config.execute_interceptors, - ) - if a2a_event is None: - continue - - task_result_aggregator.process_event(a2a_event) - await event_queue.enqueue_event(a2a_event) - - # publish the task result event - this is final - if ( - task_result_aggregator.task_state == TaskState.working - and task_result_aggregator.task_status_message is not None - and task_result_aggregator.task_status_message.parts - ): - # if task is still working properly, publish the artifact update event as - # the final result according to a2a protocol. - await event_queue.enqueue_event( - TaskArtifactUpdateEvent( - task_id=context.task_id, - last_chunk=True, - context_id=context.context_id, - artifact=Artifact( - artifact_id=platform_uuid.new_uuid(), - parts=task_result_aggregator.task_status_message.parts, - ), - ) - ) - # public the final status update event - final_event = TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.fromtimestamp( - platform_time.get_time(), tz=timezone.utc - ).isoformat(), - ), - context_id=context.context_id, - final=True, - ) - else: - final_event = TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.fromtimestamp( - platform_time.get_time(), tz=timezone.utc - ).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, + final_event = await execute_after_agent_interceptors( + executor_context, + final_event, + self._config.execute_interceptors, ) + await event_queue.enqueue_event(final_event) + except Exception as e: + raise _HandleRequestErrorWithContextId( + context_id=response_context_id, cause=e + ) from e - final_event = await execute_after_agent_interceptors( - executor_context, - final_event, - self._config.execute_interceptors, - ) - await event_queue.enqueue_event(final_event) + def _get_response_context_id( + self, + *, + context: RequestContext, + runner: Runner, + run_request: AgentRunRequest, + session_id: str, + ) -> str: + try: + return _to_a2a_context_id( + runner.app_name, run_request.user_id, session_id + ) + except ValueError: + return context.context_id async def _prepare_session( self, @@ -318,10 +358,18 @@ async def _prepare_session( run_request: AgentRunRequest, runner: Runner, ): - session_id = run_request.session_id - # create a new session if not exists user_id = run_request.user_id + if not session_id: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + ) + run_request.session_id = session.id + return session + + # create a new session if not exists session = await runner.session_service.get_session( app_name=runner.app_name, user_id=user_id, diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index cd284ea313..4fbb29d377 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -18,6 +18,7 @@ from a2a.server.agent_execution import RequestContext from google.adk.a2a.converters.request_converter import _get_user_id from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request +from google.adk.a2a.converters.utils import _to_a2a_context_id from google.adk.runners import RunConfig from google.genai import types as genai_types import pytest @@ -58,6 +59,16 @@ def test_get_user_id_from_context_when_no_call_context(self): # Assert assert result == "A2A_USER_test_context" + def test_get_user_id_from_adk_context_id(self): + """Test getting user ID from ADK-formatted context id.""" + request = Mock(spec=RequestContext) + request.call_context = None + request.context_id = _to_a2a_context_id("app", "user-123", "session-456") + + result = _get_user_id(request) + + assert result == "user-123" + def test_get_user_id_from_context_when_call_context_has_no_user(self): """Test getting user ID from context when call context has no user.""" # Arrange @@ -129,6 +140,27 @@ def test_get_user_id_with_none_context_id(self): class TestConvertA2aRequestToAgentRunRequest: """Test cases for convert_a2a_request_to_agent_run_request function.""" + def test_convert_a2a_request_with_adk_context_id(self): + """Test conversion uses ADK context id for user/session.""" + mock_message = Mock() + mock_message.parts = [Mock()] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = _to_a2a_context_id("app", "user-1", "session-1") + request.call_context = None + request.metadata = {} + + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part = Mock(return_value=mock_genai_part) + + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) + + assert result.user_id == "user-1" + assert result.session_id == "session-1" + def test_convert_a2a_request_basic(self): """Test basic conversion of A2A request to ADK AgentRunRequest.""" # Arrange @@ -164,7 +196,7 @@ def test_convert_a2a_request_basic(self): # Assert assert result is not None assert result.user_id == "test_user" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part1, mock_genai_part2] @@ -213,7 +245,7 @@ def test_convert_a2a_request_multiple_parts(self): # Assert assert result is not None assert result.user_id == "test_user" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [ @@ -261,7 +293,7 @@ def test_convert_a2a_request_empty_parts(self): # Assert assert result is not None assert result.user_id == "A2A_USER_test_context_123" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [] @@ -328,7 +360,7 @@ def test_convert_a2a_request_no_auth(self): # Assert assert result is not None assert result.user_id == "A2A_USER_session_123" - assert result.session_id == "session_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] @@ -370,7 +402,7 @@ def test_end_to_end_conversion_with_auth_user(self): # Assert assert result is not None assert result.user_id == "auth_user" # Should use authenticated user - assert result.session_id == "mysession" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] @@ -404,7 +436,7 @@ def test_end_to_end_conversion_with_fallback_user(self): assert ( result.user_id == "A2A_USER_test_session_456" ) # Should fall back to context ID - assert result.session_id == "test_session_456" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 787b260fe6..10879f1be3 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -25,6 +25,7 @@ from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.converters.utils import _to_a2a_context_id from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.a2a.executor.config import ExecuteInterceptor @@ -111,6 +112,10 @@ async def mock_run_async(**kwargs): # Execute await self.executor.execute(self.mock_context, self.mock_event_queue) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify request converter was called with proper arguments self.mock_request_converter.assert_called_once_with( self.mock_context, self.mock_a2a_part_converter @@ -121,7 +126,7 @@ async def mock_run_async(**kwargs): mock_event, mock_invocation_context, self.mock_context.task_id, - self.mock_context.context_id, + expected_context_id, self.mock_gen_ai_part_converter, ) @@ -132,11 +137,13 @@ async def mock_run_async(**kwargs): ] assert submitted_event.status.state == TaskState.submitted assert submitted_event.final == False + assert submitted_event.context_id == expected_context_id # Verify working event was enqueued working_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][0] assert working_event.status.state == TaskState.working assert working_event.final == False + assert working_event.context_id == expected_context_id # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] @@ -145,6 +152,7 @@ async def mock_run_async(**kwargs): # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") assert final_event.status.state == TaskState.working + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_execute_no_message_error(self): @@ -194,6 +202,10 @@ async def mock_run_async(**kwargs): # Execute await self.executor.execute(self.mock_context, self.mock_event_queue) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify request converter was called with proper arguments self.mock_request_converter.assert_called_once_with( self.mock_context, self.mock_a2a_part_converter @@ -204,7 +216,7 @@ async def mock_run_async(**kwargs): mock_event, mock_invocation_context, self.mock_context.task_id, - self.mock_context.context_id, + expected_context_id, self.mock_gen_ai_part_converter, ) @@ -212,6 +224,7 @@ async def mock_run_async(**kwargs): working_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] assert working_event.status.state == TaskState.working assert working_event.final == False + assert working_event.context_id == expected_context_id # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] @@ -220,6 +233,7 @@ async def mock_run_async(**kwargs): # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") assert final_event.status.state == TaskState.working + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_prepare_session_new_session(self): @@ -617,22 +631,48 @@ async def test_execute_with_exception_handling(self): # Execute (should not raise since we catch the exception) await self.executor.execute(self.mock_context, self.mock_event_queue) - # Verify both submitted and failure events were enqueued - # First call should be submitted event, last should be failure event - assert self.mock_event_queue.enqueue_event.call_count >= 2 - - # Check submitted event (first) - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ - 0 - ] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Request converter error happens before submitted event is enqueued. + assert self.mock_event_queue.enqueue_event.call_count >= 1 # Check failure event (last) failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert failure_event.status.state == TaskState.failed assert failure_event.final == True + @pytest.mark.asyncio + async def test_execute_with_exception_after_mapped_events_uses_mapped_context_id( + self, + ): + """Test failure event context ID stays mapped when handle_request fails.""" + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + self.mock_runner._new_invocation_context.return_value = Mock() + + async def mock_run_async(**kwargs): + raise RuntimeError("stream failure") + yield # pragma: no cover + + self.mock_runner.run_async = mock_run_async + + await self.executor.execute(self.mock_context, self.mock_event_queue) + + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.context_id == expected_context_id + @pytest.mark.asyncio async def test_handle_request_with_aggregator_message(self): """Test that the final task status event includes message from aggregator.""" @@ -850,6 +890,10 @@ async def mock_run_async(**kwargs): self.mock_context, self.mock_event_queue ) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify artifact update event was published artifact_events = [ call[0][0] @@ -859,7 +903,7 @@ async def mock_run_async(**kwargs): assert len(artifact_events) == 1 artifact_event = artifact_events[0] assert artifact_event.task_id == "test-task-id" - assert artifact_event.context_id == "test-context-id" + assert artifact_event.context_id == expected_context_id # Check that artifact parts correspond to message parts assert len(artifact_event.artifact.parts) == len(test_message.parts) assert artifact_event.artifact.parts == test_message.parts @@ -874,7 +918,7 @@ async def mock_run_async(**kwargs): final_event = final_events[-1] # Get the last final event assert final_event.status.state == TaskState.completed assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_handle_request_with_non_working_state_publishes_status_only( @@ -943,6 +987,10 @@ async def mock_run_async(**kwargs): self.mock_context, self.mock_event_queue ) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify no artifact update event was published artifact_events = [ call[0][0] @@ -962,7 +1010,7 @@ async def mock_run_async(**kwargs): assert final_event.status.state == TaskState.auth_required assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event(