Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
)

REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE
_BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS = 1.0


class _RealtimeSessionClosedSentinel:
Expand Down Expand Up @@ -1051,11 +1052,35 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
for task in self._guardrail_tasks:
async def _cleanup_background_tasks(
self, tasks_set: set[asyncio.Task[Any]], task_kind: str
) -> None:
tasks = list(tasks_set)
if not tasks:
return

for task in tasks:
if not task.done():
task.cancel()
self._guardrail_tasks.clear()

done, pending = await asyncio.wait(
tasks,
timeout=_BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS,
)
if done:
await asyncio.gather(*done, return_exceptions=True)
if pending:
logger.warning(
"Timed out waiting for %s Realtime %s background task(s) "
"to finish during cleanup; continuing shutdown.",
len(pending),
task_kind,
)

tasks_set.difference_update(tasks)

async def _cleanup_guardrail_tasks(self) -> None:
await self._cleanup_background_tasks(self._guardrail_tasks, "guardrail")

def _enqueue_tool_call_task(
self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
Expand Down Expand Up @@ -1089,11 +1114,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_tool_call_tasks(self) -> None:
for task in self._tool_call_tasks:
if not task.done():
task.cancel()
self._tool_call_tasks.clear()
async def _cleanup_tool_call_tasks(self) -> None:
await self._cleanup_background_tasks(self._tool_call_tasks, "tool-call")

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
Expand All @@ -1105,9 +1127,9 @@ async def _cleanup(self) -> None:
self._wake_event_iterators()
return

# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
# Cancel and clean up background tasks.
await self._cleanup_guardrail_tasks()
await self._cleanup_tool_call_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down
113 changes: 112 additions & 1 deletion tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@
RealtimeModelSendSessionUpdate,
RealtimeModelSendUserInput,
)
from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output
from agents.realtime.session import (
REJECTION_MESSAGE,
RealtimeSession,
_serialize_tool_output,
)
from agents.run_context import RunContextWrapper
from agents.tool import FunctionTool, tool_namespace
from agents.tool_context import ToolContext
Expand Down Expand Up @@ -352,6 +356,113 @@ def mock_function_tool():
return tool


@pytest.mark.asyncio
async def test_cleanup_awaits_cancelled_background_tasks(mock_model, mock_agent):
session = RealtimeSession(mock_model, mock_agent, None)
guardrail_started = asyncio.Event()
guardrail_finished = asyncio.Event()
tool_started = asyncio.Event()
tool_finished = asyncio.Event()
guardrail_wait = asyncio.Event()
tool_wait = asyncio.Event()

async def guardrail_task():
guardrail_started.set()
try:
await guardrail_wait.wait()
finally:
await asyncio.sleep(0)
guardrail_finished.set()

async def tool_call_task():
tool_started.set()
try:
await tool_wait.wait()
finally:
await asyncio.sleep(0)
tool_finished.set()

guardrail = asyncio.create_task(guardrail_task())
tool_call = asyncio.create_task(tool_call_task())
session._guardrail_tasks.add(guardrail)
session._tool_call_tasks.add(tool_call)

await asyncio.wait_for(guardrail_started.wait(), timeout=1)
await asyncio.wait_for(tool_started.wait(), timeout=1)

await session._cleanup()

assert guardrail.cancelled()
assert tool_call.cancelled()
assert guardrail_finished.is_set()
assert tool_finished.is_set()
assert not session._guardrail_tasks
assert not session._tool_call_tasks


@pytest.mark.asyncio
async def test_cleanup_bounds_cancellation_resistant_background_tasks(
monkeypatch: pytest.MonkeyPatch,
caplog: pytest.LogCaptureFixture,
mock_model,
mock_agent,
):
monkeypatch.setattr(
"agents.realtime.session._BACKGROUND_TASK_CLEANUP_TIMEOUT_SECONDS",
0.01,
)
caplog.set_level("WARNING", logger="openai.agents")

session = RealtimeSession(mock_model, mock_agent, None)
guardrail_started = asyncio.Event()
guardrail_cancelled = asyncio.Event()
guardrail_release = asyncio.Event()
tool_started = asyncio.Event()
tool_cancelled = asyncio.Event()
tool_release = asyncio.Event()

async def cancellation_resistant_task(
started: asyncio.Event,
cancelled: asyncio.Event,
release: asyncio.Event,
) -> None:
started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
cancelled.set()
await release.wait()

guardrail = asyncio.create_task(
cancellation_resistant_task(guardrail_started, guardrail_cancelled, guardrail_release)
)
tool_call = asyncio.create_task(
cancellation_resistant_task(tool_started, tool_cancelled, tool_release)
)
session._guardrail_tasks.add(guardrail)
session._tool_call_tasks.add(tool_call)

await asyncio.wait_for(guardrail_started.wait(), timeout=1)
await asyncio.wait_for(tool_started.wait(), timeout=1)

try:
await asyncio.wait_for(session._cleanup(), timeout=1)

assert guardrail_cancelled.is_set()
assert tool_cancelled.is_set()
assert not guardrail.done()
assert not tool_call.done()
assert not session._guardrail_tasks
assert not session._tool_call_tasks
assert session._closed is True
assert "Realtime guardrail background task" in caplog.text
assert "Realtime tool-call background task" in caplog.text
finally:
guardrail_release.set()
tool_release.set()
await asyncio.gather(guardrail, tool_call, return_exceptions=True)


@pytest.fixture
def mock_handoff():
handoff = Mock(spec=Handoff)
Expand Down
33 changes: 22 additions & 11 deletions tests/realtime/test_session_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,26 @@ async def test_exception_during_guardrail_processing(

session = RealtimeSession(fake_model, fake_agent, None)

# Add some fake guardrail tasks
fake_task1 = Mock()
fake_task1.done.return_value = False
fake_task1.cancel = Mock()
task_started = asyncio.Event()
task_finished = asyncio.Event()
task_wait = asyncio.Event()

fake_task2 = Mock()
fake_task2.done.return_value = True
fake_task2.cancel = Mock()
async def long_running_task():
task_started.set()
try:
await task_wait.wait()
finally:
task_finished.set()

async def completed_task():
return None

pending_task = asyncio.create_task(long_running_task())
done_task = asyncio.create_task(completed_task())
await asyncio.wait_for(task_started.wait(), timeout=1)
await done_task

session._guardrail_tasks = {fake_task1, fake_task2}
session._guardrail_tasks = {pending_task, done_task}

fake_model.set_next_events([exception_event])

Expand All @@ -267,9 +277,10 @@ async def test_exception_during_guardrail_processing(
async for _event in session:
pass

# Verify guardrail tasks were properly cleaned up
fake_task1.cancel.assert_called_once()
fake_task2.cancel.assert_not_called() # Already done
# Verify guardrail tasks were properly cleaned up.
assert pending_task.cancelled()
assert done_task.done()
assert task_finished.is_set()
assert len(session._guardrail_tasks) == 0

@pytest.mark.asyncio
Expand Down