Skip to content

Commit 8c4e41e

Browse files
committed
fix(google): finalize realtime generation when turn_complete is missing
1 parent d75d61b commit 8c4e41e

2 files changed

Lines changed: 120 additions & 0 deletions

File tree

livekit-plugins/livekit-plugins-google/livekit/plugins/google/realtime/realtime_api.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
INPUT_AUDIO_CHANNELS = 1
3535
OUTPUT_AUDIO_SAMPLE_RATE = 24000
3636
OUTPUT_AUDIO_CHANNELS = 1
37+
TURN_COMPLETE_FALLBACK_SECONDS = 1.0
3738

3839
DEFAULT_IMAGE_ENCODE_OPTIONS = images.EncodeOptions(
3940
format="JPEG",
@@ -463,6 +464,7 @@ def __init__(self, realtime_model: RealtimeModel) -> None:
463464
self._session_should_close = asyncio.Event()
464465
self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
465466
self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
467+
self._turn_complete_fallback_task: asyncio.Task[None] | None = None
466468

467469
self._session_resumption_handle: str | None = (
468470
self._opts.session_resumption.handle
@@ -728,6 +730,7 @@ def truncate(
728730
pass
729731

730732
async def aclose(self) -> None:
733+
self._cancel_turn_complete_fallback()
731734
self._msg_ch.close()
732735
self._session_should_close.set()
733736

@@ -1037,6 +1040,8 @@ def _build_connect_config(self) -> types.LiveConnectConfig:
10371040
return conf
10381041

10391042
def _start_new_generation(self) -> None:
1043+
self._cancel_turn_complete_fallback()
1044+
10401045
if self._current_generation and not self._current_generation._done:
10411046
logger.warning("starting new generation while another is active. Finalizing previous.")
10421047
self._mark_current_generation_done()
@@ -1140,14 +1145,50 @@ def _handle_server_content(self, server_content: types.LiveServerContent) -> Non
11401145
if server_content.generation_complete or server_content.turn_complete:
11411146
current_gen._completed_timestamp = time.time()
11421147

1148+
if server_content.generation_complete and not server_content.turn_complete:
1149+
self._schedule_turn_complete_fallback(current_gen.response_id)
1150+
11431151
if server_content.interrupted and not self._pending_generation_fut:
11441152
# interrupt agent if there is no pending user initiated generation
11451153
self._handle_input_speech_started()
11461154

11471155
if server_content.turn_complete:
11481156
self._mark_current_generation_done()
11491157

1158+
def _cancel_turn_complete_fallback(self) -> None:
    """Cancel the pending turn_complete fallback timer, if any, and drop its handle.

    A stored task that has already finished is left alone; in every case the
    stored reference is reset to ``None`` afterwards.
    """
    if self._turn_complete_fallback_task and not self._turn_complete_fallback_task.done():
        self._turn_complete_fallback_task.cancel()
    self._turn_complete_fallback_task = None
1162+
1163+
def _schedule_turn_complete_fallback(self, response_id: str) -> None:
    """Start (or restart) the timer that finalizes a generation lacking turn_complete.

    Any previously scheduled fallback is cancelled first, so at most one timer
    is tracked at a time.

    Args:
        response_id: Identifier of the generation the timer is armed for; the
            timer only acts if that generation is still the current one when
            it fires.
    """
    self._cancel_turn_complete_fallback()
    self._turn_complete_fallback_task = asyncio.create_task(
        self._wait_for_turn_complete_fallback(
            response_id=response_id,
            timeout=TURN_COMPLETE_FALLBACK_SECONDS,
        ),
        name=f"gemini_turn_complete_fallback_{response_id}",
    )
1172+
1173+
async def _wait_for_turn_complete_fallback(self, *, response_id: str, timeout: float) -> None:
    """Finalize the current generation if turn_complete has not arrived in time.

    Sleeps for ``timeout`` seconds, then marks the current generation done
    unless it has already finished or been replaced by another generation.

    Args:
        response_id: Identifier of the generation this timer was armed for.
        timeout: Seconds to wait before falling back.

    Raises:
        asyncio.CancelledError: If the timer is cancelled while waiting.
    """
    # Cancellation is intentionally not caught here: letting CancelledError
    # propagate makes the task end in the cancelled state instead of
    # pretending it completed normally.
    await asyncio.sleep(timeout)

    current_gen = self._current_generation
    if not current_gen or current_gen._done or current_gen.response_id != response_id:
        return

    logger.warning(
        "Gemini Realtime did not emit turn_complete after generation_complete; "
        f"finalizing generation (response_id={response_id}, timeout={timeout:.2f}s)"
    )

    # Finalizing cancels the stored fallback task. When that task is the one
    # running this coroutine, release the handle first so the task does not
    # request its own cancellation.
    if self._turn_complete_fallback_task is asyncio.current_task():
        self._turn_complete_fallback_task = None
    self._mark_current_generation_done()
1188+
11501189
def _mark_current_generation_done(self) -> None:
1190+
self._cancel_turn_complete_fallback()
1191+
11511192
if not self._current_generation or self._current_generation._done:
11521193
return
11531194

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from types import SimpleNamespace
2+
3+
from livekit import rtc
4+
from livekit.agents import llm, utils
5+
from livekit.plugins.google.realtime.realtime_api import (
6+
RealtimeSession,
7+
_ResponseGeneration, # pyright: ignore[reportPrivateUsage]
8+
)
9+
10+
11+
def _make_session() -> RealtimeSession:
    """Build a bare ``RealtimeSession`` for unit tests without running ``__init__``.

    The instance is created with ``__new__`` and then given a small set of
    attributes by hand; ``emit`` is replaced with a no-op callable.
    """
    session = RealtimeSession.__new__(RealtimeSession)
    session._current_generation = None
    session._turn_complete_fallback_task = None
    session._opts = SimpleNamespace(output_audio_transcription=object())
    session._chat_ctx = llm.ChatContext.empty()
    session._pending_generation_fut = None
    session.emit = lambda *args, **kwargs: None
    return session
20+
21+
22+
def _make_generation(response_id: str = "GR_test") -> _ResponseGeneration:
    """Create a ``_ResponseGeneration`` with fresh channels for use in tests.

    Args:
        response_id: Response identifier to assign to the generation.
    """
    return _ResponseGeneration(
        message_ch=utils.aio.Chan[llm.MessageGeneration](),
        function_ch=utils.aio.Chan[llm.FunctionCall](),
        input_id="GI_test",
        response_id=response_id,
        text_ch=utils.aio.Chan[str](),
        audio_ch=utils.aio.Chan[rtc.AudioFrame](),
    )
31+
32+
33+
def _server_content(
    *, generation_complete: bool = False, turn_complete: bool = False
) -> SimpleNamespace:
    """Return a minimal stand-in for a server content message.

    Only the two completion flags are configurable; the content fields are
    ``None`` and ``interrupted`` is ``False``.

    Args:
        generation_complete: Value for the ``generation_complete`` flag.
        turn_complete: Value for the ``turn_complete`` flag.
    """
    return SimpleNamespace(
        model_turn=None,
        input_transcription=None,
        output_transcription=None,
        generation_complete=generation_complete,
        turn_complete=turn_complete,
        interrupted=False,
    )
44+
45+
46+
async def test_generation_complete_schedules_turn_complete_fallback() -> None:
    """generation_complete without turn_complete arms the fallback and keeps the generation open."""
    sess = _make_session()
    gen = _make_generation("GR_schedule")
    sess._current_generation = gen

    # Record scheduling requests instead of starting a real timer task.
    scheduled: list[str] = []
    sess._schedule_turn_complete_fallback = scheduled.append

    sess._handle_server_content(_server_content(generation_complete=True))

    assert scheduled == [gen.response_id]
    assert not gen._done
58+
59+
60+
async def test_turn_complete_still_finalizes_generation() -> None:
    """A turn_complete message marks the current generation done."""
    sess = _make_session()
    gen = _make_generation("GR_turn")
    sess._current_generation = gen

    sess._handle_server_content(_server_content(turn_complete=True))

    assert gen._done
68+
69+
70+
async def test_turn_complete_fallback_finalizes_generation() -> None:
    """When the fallback timer elapses, the generation is finalized and its channels closed."""
    sess = _make_session()
    gen = _make_generation("GR_fallback")
    sess._current_generation = gen

    # A zero timeout lets the fallback run without any real delay.
    await sess._wait_for_turn_complete_fallback(response_id=gen.response_id, timeout=0.0)

    assert gen._done
    assert gen.audio_ch.closed
    assert gen.message_ch.closed

0 commit comments

Comments
 (0)