From c43b70f57b5a4db7c2901bb85b87b8c95a2f6eac Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Tue, 5 May 2026 14:03:26 +0000 Subject: [PATCH] Enqueue exceptions in ActiveTask. --- src/a2a/server/agent_execution/active_task.py | 125 +++++++++--------- .../default_request_handler_v2.py | 8 -- tests/integration/test_scenarios.py | 9 ++ .../agent_execution/test_active_task.py | 44 +++--- 4 files changed, 100 insertions(+), 86 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index 6498bc772..268dd1999 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -104,36 +104,60 @@ async def run(self) -> None: ) except Exception as e: logger.exception('Consumer[%s]: Failed', self.active_task._task_id) - async with self.active_task._lock: - await self.active_task._mark_task_as_failed(e) + + updated_task = None + task = await self.active_task._task_manager.get_task() + if task: + handled_event = TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_FAILED, + ), + ) + updated_task = await self._handle_task_event(handled_event) + + await self._enqueue_to_subscribers(cast('Event', e), updated_task) async def _process_event(self, event: Event) -> None: updated_task = None + handled_event: ( + Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent + | PushNotificationEvent + | None + ) = None + + if isinstance(event, _RequestCompleted): + logger.debug( + 'Consumer[%s]: Request completed', self.active_task._task_id + ) + self.active_task._request_lock.release() + elif isinstance(event, _RequestStarted): + logger.debug( + 'Consumer[%s]: Request started', self.active_task._task_id + ) + self.message_to_save = event.request_context.message + elif isinstance(event, BaseException): + raise event + elif isinstance(event, Message): + self._handle_message_event(event) + elif isinstance( + event, + TaskStatusUpdateEvent + | TaskArtifactUpdateEvent + | PushNotificationEvent + | Task, + ): + updated_task = await self._handle_task_event(event) + handled_event = updated_task if isinstance(event, Task) else event - try: - if isinstance(event, _RequestCompleted): - logger.debug( - 'Consumer[%s]: Request completed', self.active_task._task_id - ) - self.active_task._request_lock.release() - elif isinstance(event, _RequestStarted): - logger.debug( - 'Consumer[%s]: Request started', self.active_task._task_id - ) - self.message_to_save = event.request_context.message - elif isinstance(event, Message): - self._handle_message_event(event) - else: - updated_task = await self._handle_task_event(event) - if isinstance(event, Task): - event = updated_task - - if updated_task is not None: - await self._update_task_state(updated_task, event) - self.active_task._task_created.set() + if updated_task is not None and handled_event is not None: + await self._update_task_state(updated_task, handled_event) + self.active_task._task_created.set() - finally: - await self._enqueue_to_subscribers(event, updated_task) + await self._enqueue_to_subscribers(event, updated_task) def _handle_message_event(self, event: Message) -> None: if self.task_mode is True: @@ -286,9 +310,6 @@ class ActiveTask: - `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical lifecycle state changes, such as starting the task, subscribing, and determining if cleanup is safe to trigger. - - mutation to the observable result state (like `_exception`, - or `_is_finished`) notifies waiting coroutines (like `wait()`). - `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way for external observers and internal loops to check if the ActiveTask has permanently ceased execution and closed its queues. @@ -349,10 +370,6 @@ def __init__( # Protected by `_lock`. self._reference_count = 0 - # Holds any fatal exception that crashed the producer or consumer. - # TODO: Synchronize exception handling (ideally mix it in the queue). - self._exception: Exception | None = None - # Queue for incoming requests self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = ( _create_async_queue() @@ -481,7 +498,6 @@ async def _run_producer(self) -> None: _RequestStarted(request_id, request_context), ) ) - await self._agent_executor.execute( request_context, self._event_queue_agent ) @@ -489,14 +505,10 @@ async def _run_producer(self) -> None: 'Producer[%s]: Execution finished successfully', self._task_id, ) - finally: - logger.debug( - 'Producer[%s]: Enqueuing request completed event', - self._task_id, - ) await self._event_queue_agent.enqueue_event( cast('Event', _RequestCompleted(request_id)) ) + finally: self._request_queue.task_done() except asyncio.CancelledError: logger.debug('Producer[%s]: Cancelled', self._task_id) @@ -516,8 +528,7 @@ async def _run_producer(self) -> None: request_context.context_id or '', ) self._task_created.set() - async with self._lock: - await self._mark_task_as_failed(e) + await self._event_queue_agent.enqueue_event(cast('Event', e)) finally: self._request_queue.shutdown(immediate=True) @@ -537,7 +548,7 @@ async def _run_consumer(self) -> None: logger.debug('Consumer[%s]: Finishing', self._task_id) await self._maybe_cleanup() - async def subscribe( # noqa: PLR0912, PLR0915 + async def subscribe( self, *, request: RequestContext | None = None, @@ -554,12 +565,6 @@ async def subscribe( # noqa: PLR0912, PLR0915 logger.debug('Subscribe[%s]: New subscriber', self._task_id) async with self._lock: - if self._exception: - logger.debug( - 'Subscribe[%s]: Failed, exception already set', - self._task_id, - ) - raise self._exception if self._is_finished.is_set(): raise InvalidParamsError( f'Task {self._task_id} is already completed.' @@ -585,17 +590,23 @@ async def subscribe( # noqa: PLR0912, PLR0915 while True: try: - if self._exception: - raise self._exception - dequeued = await tapped_queue.dequeue_event() event, updated_task = cast('Any', dequeued) logger.debug( - 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + 'Subscriber[%s] Dequeued event [%s]:\n %s\nUpdated task:\n%s\n', self._task_id, + type(event).__name__, event, updated_task, ) + if isinstance(event, BaseException): + logger.debug( + 'Subscriber[%s]: Raising exception: %s', + self._task_id, + event, + ) + raise event + if replace_status_update_with_task and isinstance( event, TaskStatusUpdateEvent ): @@ -605,8 +616,6 @@ async def subscribe( # noqa: PLR0912, PLR0915 updated_task, ) event = updated_task - if self._exception: - raise self._exception from None if isinstance(event, _RequestCompleted): if ( request_id is not None @@ -629,8 +638,6 @@ async def subscribe( # noqa: PLR0912, PLR0915 finally: tapped_queue.task_done() except (QueueShutDown, asyncio.CancelledError): - if self._exception: - raise self._exception from None break finally: logger.debug('Subscribe[%s]: Unsubscribing', self._task_id) @@ -714,9 +721,9 @@ async def _maybe_cleanup(self) -> None: logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) self._on_cleanup(self) - async def _mark_task_as_failed(self, exception: Exception) -> None: - if self._exception is None: - self._exception = exception + async def _mark_task_as_failed(self, exception: Exception) -> Task | None: + logger.debug('Marking task %s as failed: %s', self._task_id, exception) + task = None if self._task_created.is_set(): try: task = await self._task_manager.get_task() @@ -732,10 +739,10 @@ async def _mark_task_as_failed(self, exception: Exception) -> None: ) except QueueShutDown: pass + return task async def get_task(self) -> Task: """Get task from db.""" - # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). await self._task_created.wait() task = await self._task_manager.get_task() if not task: diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index b5a70fa8b..30304609a 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -37,7 +37,6 @@ SubscribeToTaskRequest, Task, TaskPushNotificationConfig, - TaskStatusUpdateEvent, ) from a2a.utils.errors import ( ExtendedAgentCardNotConfiguredError, @@ -252,13 +251,6 @@ async def on_message_send( # noqa: D102 type(event).__name__, event, ) - if isinstance(event, TaskStatusUpdateEvent): - self._validate_task_id_match(task_id, event.task_id) - event = await active_task.get_task() - logger.debug( - 'Replaced TaskStatusUpdateEvent with Task: %s', event - ) - if isinstance(event, Task) and ( params.configuration.return_immediately or event.status.state diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index bf1c6d3af..3f2383fae 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -461,6 +461,15 @@ async def cancel( (task,) = (await client.list_tasks(ListTasksRequest())).tasks assert task.status.state == TaskState.TASK_STATE_FAILED + if streaming: + with pytest.raises( + InvalidParamsError, + match='Task .* is already completed', + ): + await client.subscribe( + SubscribeToTaskRequest(id=task.id) + ).__anext__() + # Scenario 12/13: Exception after initial event @pytest.mark.timeout(2.0) diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index 0ed960641..ce9e2c068 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -316,26 +316,42 @@ async def test_active_task_subscribe_exception_handling( active_task: ActiveTask, agent_executor: Mock, request_context: Mock, + task_manager: Mock, ) -> None: """Test exception handling in subscribe.""" - agent_executor.execute = AsyncMock( - side_effect=ValueError('Producer failure') + event = asyncio.Event() + + task_manager.get_task.return_value = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) + async def execute_mock(req, q): + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + ) + ) + await event.wait() + raise ValueError('Producer failure') + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + await active_task.enqueue_request(request_context) await active_task.start( call_context=ServerCallContext(), create_task_if_missing=True ) - # Give it a moment to fail - for _ in range(10): - if active_task._exception: - break - await asyncio.sleep(0.05) + subscriber = active_task.subscribe() + task = await anext(subscriber) + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + + # Now trigger the exception + event.set() with pytest.raises(ValueError, match='Producer failure'): - async for _ in active_task.subscribe(): - pass + await anext(subscriber) @pytest.mark.asyncio async def test_active_task_cancel_not_started( @@ -766,16 +782,6 @@ async def test_active_task_maybe_cleanup_not_finished( await active_task._maybe_cleanup() on_cleanup.assert_not_called() - @pytest.mark.asyncio - async def test_active_task_subscribe_exception_already_set( - self, active_task: ActiveTask - ) -> None: - """Test subscribe when exception is already set.""" - active_task._exception = ValueError('Pre-existing error') - with pytest.raises(ValueError, match='Pre-existing error'): - async for _ in active_task.subscribe(): - pass - @pytest.mark.asyncio async def test_active_task_subscribe_inner_exception( self,