From f04563e529b49247076478072b90d7146b757dba Mon Sep 17 00:00:00 2001 From: c <37263590+Aphroq@users.noreply.github.com> Date: Sun, 10 May 2026 17:42:36 +0000 Subject: [PATCH] Await realtime background tasks during cleanup --- src/agents/realtime/session.py | 44 ++++++--- tests/realtime/test_session.py | 113 +++++++++++++++++++++- tests/realtime/test_session_exceptions.py | 33 ++++--- 3 files changed, 167 insertions(+), 23 deletions(-) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index fbb204502a..41ca1a950b 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -72,6 +72,7 @@ ) REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE +_BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS = 1.0 class _RealtimeSessionClosedSentinel: @@ -1051,11 +1052,35 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_guardrail_tasks(self) -> None: - for task in self._guardrail_tasks: + async def _cleanup_background_tasks( + self, tasks_set: set[asyncio.Task[Any]], task_kind: str + ) -> None: + tasks = list(tasks_set) + if not tasks: + return + + for task in tasks: if not task.done(): task.cancel() - self._guardrail_tasks.clear() + + done, pending = await asyncio.wait( + tasks, + timeout=_BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS, + ) + if done: + await asyncio.gather(*done, return_exceptions=True) + if pending: + logger.warning( + "Timed out waiting for %s Realtime %s background task(s) " + "to finish during cleanup; continuing shutdown.", + len(pending), + task_kind, + ) + + tasks_set.difference_update(tasks) + + async def _cleanup_guardrail_tasks(self) -> None: + await self._cleanup_background_tasks(self._guardrail_tasks, "guardrail") def _enqueue_tool_call_task( self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent @@ -1089,11 +1114,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_tool_call_tasks(self) -> None: - for task in self._tool_call_tasks: - if not task.done(): - task.cancel() - self._tool_call_tasks.clear() + async def _cleanup_tool_call_tasks(self) -> None: + await self._cleanup_background_tasks(self._tool_call_tasks, "tool-call") def _wake_event_iterators(self) -> None: for _ in range(self._event_iterator_waiters): @@ -1105,9 +1127,9 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return - # Cancel and cleanup guardrail tasks - self._cleanup_guardrail_tasks() - self._cleanup_tool_call_tasks() + # Cancel and clean up background tasks. + await self._cleanup_guardrail_tasks() + await self._cleanup_tool_call_tasks() # Remove ourselves as a listener self._model.remove_listener(self) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 000ecf9930..2d630f2730 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -59,7 +59,11 @@ RealtimeModelSendSessionUpdate, RealtimeModelSendUserInput, ) -from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output +from agents.realtime.session import ( + REJECTION_MESSAGE, + RealtimeSession, + _serialize_tool_output, +) from agents.run_context import RunContextWrapper from agents.tool import FunctionTool, tool_namespace from agents.tool_context import ToolContext @@ -352,6 +356,113 @@ def mock_function_tool(): return tool +@pytest.mark.asyncio +async def test_cleanup_awaits_cancelled_background_tasks(mock_model, mock_agent): + session = RealtimeSession(mock_model, mock_agent, None) + guardrail_started = asyncio.Event() + guardrail_finished = asyncio.Event() + tool_started = asyncio.Event() + tool_finished = asyncio.Event() + guardrail_wait = asyncio.Event() + tool_wait = asyncio.Event() + + async def guardrail_task(): + guardrail_started.set() + try: + await guardrail_wait.wait() + finally: + await asyncio.sleep(0) + guardrail_finished.set() + + async def tool_call_task(): + tool_started.set() + try: + await tool_wait.wait() + finally: + await asyncio.sleep(0) + tool_finished.set() + + guardrail = asyncio.create_task(guardrail_task()) + tool_call = asyncio.create_task(tool_call_task()) + session._guardrail_tasks.add(guardrail) + session._tool_call_tasks.add(tool_call) + + await asyncio.wait_for(guardrail_started.wait(), timeout=1) + await asyncio.wait_for(tool_started.wait(), timeout=1) + + await session._cleanup() + + assert guardrail.cancelled() + assert tool_call.cancelled() + assert guardrail_finished.is_set() + assert tool_finished.is_set() + assert not session._guardrail_tasks + assert not session._tool_call_tasks + + +@pytest.mark.asyncio +async def test_cleanup_bounds_cancellation_resistant_background_tasks( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, + mock_model, + mock_agent, +): + monkeypatch.setattr( + "agents.realtime.session._BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS", + 0.01, + ) + caplog.set_level("WARNING", logger="openai.agents") + + session = RealtimeSession(mock_model, mock_agent, None) + guardrail_started = asyncio.Event() + guardrail_cancelled = asyncio.Event() + guardrail_release = asyncio.Event() + tool_started = asyncio.Event() + tool_cancelled = asyncio.Event() + tool_release = asyncio.Event() + + async def cancellation_resistant_task( + started: asyncio.Event, + cancelled: asyncio.Event, + release: asyncio.Event, + ) -> None: + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancelled.set() + await release.wait() + + guardrail = asyncio.create_task( + cancellation_resistant_task(guardrail_started, guardrail_cancelled, guardrail_release) + ) + tool_call = asyncio.create_task( + cancellation_resistant_task(tool_started, tool_cancelled, tool_release) + ) + session._guardrail_tasks.add(guardrail) + session._tool_call_tasks.add(tool_call) + + await asyncio.wait_for(guardrail_started.wait(), timeout=1) + await asyncio.wait_for(tool_started.wait(), timeout=1) + + try: + await asyncio.wait_for(session._cleanup(), timeout=1) + + assert guardrail_cancelled.is_set() + assert tool_cancelled.is_set() + assert not guardrail.done() + assert not tool_call.done() + assert not session._guardrail_tasks + assert not session._tool_call_tasks + assert session._closed is True + assert "Realtime guardrail background task" in caplog.text + assert "Realtime tool-call background task" in caplog.text + finally: + guardrail_release.set() + tool_release.set() + await asyncio.gather(guardrail, tool_call, return_exceptions=True) + + @pytest.fixture def mock_handoff(): handoff = Mock(spec=Handoff) diff --git a/tests/realtime/test_session_exceptions.py b/tests/realtime/test_session_exceptions.py index da93902368..e08d4b2f2f 100644 --- a/tests/realtime/test_session_exceptions.py +++ b/tests/realtime/test_session_exceptions.py @@ -249,16 +249,26 @@ async def test_exception_during_guardrail_processing( session = RealtimeSession(fake_model, fake_agent, None) - # Add some fake guardrail tasks - fake_task1 = Mock() - fake_task1.done.return_value = False - fake_task1.cancel = Mock() + task_started = asyncio.Event() + task_finished = asyncio.Event() + task_wait = asyncio.Event() - fake_task2 = Mock() - fake_task2.done.return_value = True - fake_task2.cancel = Mock() + async def long_running_task(): + task_started.set() + try: + await task_wait.wait() + finally: + task_finished.set() + + async def completed_task(): + return None + + pending_task = asyncio.create_task(long_running_task()) + done_task = asyncio.create_task(completed_task()) + await asyncio.wait_for(task_started.wait(), timeout=1) + await done_task - session._guardrail_tasks = {fake_task1, fake_task2} + session._guardrail_tasks = {pending_task, done_task} fake_model.set_next_events([exception_event]) @@ -267,9 +277,10 @@ async def test_exception_during_guardrail_processing( async for _event in session: pass - # Verify guardrail tasks were properly cleaned up - fake_task1.cancel.assert_called_once() - fake_task2.cancel.assert_not_called() # Already done + # Verify guardrail tasks were properly cleaned up. + assert pending_task.cancelled() + assert done_task.done() + assert task_finished.is_set() assert len(session._guardrail_tasks) == 0 @pytest.mark.asyncio