Skip to content

Commit d870f87

Browse files
authored
support aligned transcripts from tts (#2580)
1 parent 9c69195 commit d870f87

15 files changed

Lines changed: 339 additions & 65 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"livekit-agents": patch
3+
"livekit-plugins-cartesia": patch
4+
"livekit-plugins-elevenlabs": patch
5+
---
6+
7+
support aligned transcripts with timestamps from tts (#2580)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import asyncio
2+
import logging
3+
from collections.abc import AsyncGenerator, AsyncIterable
4+
5+
from dotenv import load_dotenv
6+
7+
from livekit.agents import Agent, AgentSession, JobContext, WorkerOptions, cli
8+
from livekit.agents.voice.agent import ModelSettings
9+
from livekit.agents.voice.io import TimedString
10+
from livekit.plugins import cartesia, deepgram, openai, silero
11+
12+
logger = logging.getLogger("my-worker")
13+
logger.setLevel(logging.INFO)
14+
15+
load_dotenv()
16+
17+
18+
# This example shows how to obtain the timed transcript from the TTS.
19+
# Right now, it's supported for Cartesia and ElevenLabs TTS (word level timestamps)
20+
# and non-streaming TTS with StreamAdapter (sentence level timestamps).
21+
22+
23+
class MyAgent(Agent):
    """Agent that forwards transcript chunks downstream, logging any
    TTS-aligned (timed) chunks it sees along the way."""

    def __init__(self):
        super().__init__(instructions="You are a helpful assistant.")

        # Placeholder for a deferred shutdown task; unused in this example.
        self._closing_task: asyncio.Task[None] | None = None

    async def transcription_node(
        self, text: AsyncIterable[str | TimedString], model_settings: ModelSettings
    ) -> AsyncGenerator[str | TimedString, None]:
        """Pass every chunk through unchanged; log timed chunks with their timestamps."""
        async for chunk in text:
            if isinstance(chunk, TimedString):
                # TimedString carries word/sentence timing from the TTS alignment
                logger.info(f"TimedString: '{chunk}' ({chunk.start_time} - {chunk.end_time})")
            yield chunk
38+
async def entrypoint(ctx: JobContext):
    """Worker entrypoint: build the voice pipeline, start the session in the
    job's room, and have the agent greet the user."""
    agent = MyAgent()
    session = AgentSession(
        stt=deepgram.STT(),
        llm=openai.LLM(),
        tts=cartesia.TTS(),
        vad=silero.VAD.load(),
        # enable TTS-aligned transcript, can be configured at the Agent level as well
        use_tts_aligned_transcript=True,
    )

    await session.start(agent=agent, room=ctx.room)
    session.generate_reply(instructions="say hello to the user")


if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))

livekit-agents/livekit/agents/tokenize/blingfire.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,32 @@
1212
]
1313

1414

15-
def _split_sentences(text: str, min_sentence_len: int) -> list[tuple[str, int, int]]:
16-
bf_sentences, offsets = blingfire.text_to_sentences_with_offsets(text)
17-
raw_sentences = bf_sentences.split("\n")
15+
def _split_sentences(
    text: str, min_sentence_len: int, *, retain_format: bool = False
) -> list[tuple[str, int, int]]:
    """Split *text* into sentences, returning (sentence, start, end) tuples of
    character offsets into the original text.

    Sentences shorter than *min_sentence_len* (after stripping) are merged into
    the following sentence by leaving the cursor in place. With *retain_format*
    the raw slice (including surrounding whitespace) is kept; otherwise the
    stripped form is returned.
    """
    _, offsets = blingfire.text_to_sentences_with_offsets(text)

    results: list[tuple[str, int, int]] = []
    cursor = 0

    for _, boundary in offsets:
        segment = text[cursor:boundary]
        stripped = segment.strip()
        if not stripped or len(stripped) < min_sentence_len:
            # too short/empty: keep cursor where it is so this text is
            # absorbed into the next sentence
            continue

        results.append((segment if retain_format else stripped, cursor, boundary))
        cursor = boundary

    # flush any trailing text that never reached a sentence boundary
    if cursor < len(text):
        tail = text[cursor:]
        if retain_format:
            results.append((tail, cursor, len(text)))
        else:
            tail_stripped = tail.strip()
            if tail_stripped:
                results.append((tail_stripped, cursor, len(text)))

    return results
4243

@@ -45,6 +46,7 @@ def _split_sentences(text: str, min_sentence_len: int) -> list[tuple[str, int, i
4546
class _TokenizerOptions:
4647
min_sentence_len: int
4748
stream_context_len: int
49+
retain_format: bool
4850

4951

5052
class SentenceTokenizer(tokenizer.SentenceTokenizer):
@@ -53,20 +55,30 @@ def __init__(
5355
*,
5456
min_sentence_len: int = 20,
5557
stream_context_len: int = 10,
58+
retain_format: bool = False,
5659
) -> None:
5760
self._config = _TokenizerOptions(
58-
min_sentence_len=min_sentence_len, stream_context_len=stream_context_len
61+
min_sentence_len=min_sentence_len,
62+
stream_context_len=stream_context_len,
63+
retain_format=retain_format,
5964
)
6065

6166
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
6267
return [
63-
tok[0] for tok in _split_sentences(text, min_sentence_len=self._config.min_sentence_len)
68+
tok[0]
69+
for tok in _split_sentences(
70+
text,
71+
min_sentence_len=self._config.min_sentence_len,
72+
retain_format=self._config.retain_format,
73+
)
6474
]
6575

6676
def stream(self, *, language: str | None = None) -> tokenizer.SentenceStream:
6777
return token_stream.BufferedSentenceStream(
6878
tokenizer=functools.partial(
69-
_split_sentences, min_sentence_len=self._config.min_sentence_len
79+
_split_sentences,
80+
min_sentence_len=self._config.min_sentence_len,
81+
retain_format=self._config.retain_format,
7082
),
7183
min_token_len=self._config.min_sentence_len,
7284
min_ctx_len=self._config.stream_context_len,

livekit-agents/livekit/agents/tts/fallback_adapter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .. import utils
1313
from .._exceptions import APIConnectionError
1414
from ..log import logger
15-
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
15+
from ..types import DEFAULT_API_CONNECT_OPTIONS, USERDATA_TIMED_TRANSCRIPT, APIConnectOptions
1616
from ..utils import aio
1717
from .tts import (
1818
TTS,
@@ -83,6 +83,7 @@ def __init__(
8383
super().__init__(
8484
capabilities=TTSCapabilities(
8585
streaming=all(t.capabilities.streaming for t in tts),
86+
aligned_transcript=all(t.capabilities.aligned_transcript for t in tts),
8687
),
8788
sample_rate=sample_rate,
8889
num_channels=num_channels,
@@ -202,6 +203,9 @@ async def _run(self, output_emitter: AudioEmitter) -> None:
202203
try:
203204
resampler = tts_status.resampler
204205
async for synthesized_audio in self._try_synthesize(tts=tts, recovering=False):
206+
if texts := synthesized_audio.frame.userdata.get(USERDATA_TIMED_TRANSCRIPT):
207+
output_emitter.push_timed_transcript(texts)
208+
205209
if resampler is not None:
206210
for rf in resampler.push(synthesized_audio.frame):
207211
output_emitter.push(rf.data.tobytes())
@@ -341,6 +345,11 @@ async def _forward_input_task() -> None:
341345
),
342346
recovering=False,
343347
):
348+
if texts := synthesized_audio.frame.userdata.get(
349+
USERDATA_TIMED_TRANSCRIPT
350+
):
351+
output_emitter.push_timed_transcript(texts)
352+
344353
if resampler is not None:
345354
for resampled_frame in resampler.push(synthesized_audio.frame):
346355
output_emitter.push(resampled_frame.data.tobytes())

livekit-agents/livekit/agents/tts/stream_adapter.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def __init__(
2929
sentence_tokenizer: NotGivenOr[tokenize.SentenceTokenizer] = NOT_GIVEN,
3030
) -> None:
3131
super().__init__(
32-
capabilities=TTSCapabilities(
33-
streaming=True,
34-
),
32+
capabilities=TTSCapabilities(streaming=True, aligned_transcript=True),
3533
sample_rate=tts.sample_rate,
3634
num_channels=tts.num_channels,
3735
)
3836
self._wrapped_tts = tts
39-
self._sentence_tokenizer = sentence_tokenizer or tokenize.blingfire.SentenceTokenizer()
37+
self._sentence_tokenizer = sentence_tokenizer or tokenize.blingfire.SentenceTokenizer(
38+
retain_format=True
39+
)
4040

4141
@self._wrapped_tts.on("metrics_collected")
4242
def _forward_metrics(*args: Any, **kwargs: Any) -> None:
@@ -91,12 +91,19 @@ async def _forward_input() -> None:
9191
self._sent_stream.end_input()
9292

9393
async def _synthesize() -> None:
    # Synthesize each tokenized sentence in turn, emitting a TimedString
    # for it first so downstream consumers get sentence-level timestamps.
    from ..voice.io import TimedString

    duration = 0.0
    async for ev in self._sent_stream:
        # start_time is the audio duration accumulated so far; end_time is
        # unknown until this sentence's audio has been produced
        output_emitter.push_timed_transcript(TimedString(text=ev.token, start_time=duration))
        async with self._tts._wrapped_tts.synthesize(
            ev.token.strip(), conn_options=self._wrapped_tts_conn_options
        ) as tts_stream:
            async for audio in tts_stream:
                output_emitter.push(audio.frame.data.tobytes())
                duration += audio.frame.duration
    output_emitter.flush()
101108

102109
tasks = [

livekit-agents/livekit/agents/tts/tts.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import AsyncIterable, AsyncIterator
99
from dataclasses import dataclass
1010
from types import TracebackType
11-
from typing import Generic, Literal, TypeVar, Union
11+
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, Union
1212

1313
from pydantic import BaseModel, ConfigDict, Field
1414

@@ -17,9 +17,12 @@
1717
from .._exceptions import APIError
1818
from ..log import logger
1919
from ..metrics import TTSMetrics
20-
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
20+
from ..types import DEFAULT_API_CONNECT_OPTIONS, USERDATA_TIMED_TRANSCRIPT, APIConnectOptions
2121
from ..utils import aio, audio, codecs, log_exceptions
2222

23+
if TYPE_CHECKING:
24+
from ..voice.io import TimedString
25+
2326
lk_dump_tts = int(os.getenv("LK_DUMP_TTS", 0))
2427

2528

@@ -41,6 +44,8 @@ class SynthesizedAudio:
4144
class TTSCapabilities:
4245
streaming: bool
4346
"""Whether this TTS supports streaming (generally using websockets)"""
47+
aligned_transcript: bool = False
48+
"""Whether this TTS supports aligned transcripts with word timestamps"""
4449

4550

4651
class TTSError(BaseModel):
@@ -563,12 +568,15 @@ def initialize(
563568
self._num_channels = num_channels
564569
self._streaming = stream
565570

571+
from ..voice.io import TimedString
572+
566573
self._write_ch = aio.Chan[
567574
Union[
568575
bytes,
569576
AudioEmitter._FlushSegment,
570577
AudioEmitter._StartSegment,
571578
AudioEmitter._EndSegment,
579+
TimedString,
572580
]
573581
]()
574582
self._main_atask = asyncio.create_task(self._main_task(), name="AudioEmitter._main_task")
@@ -622,6 +630,19 @@ def push(self, data: bytes) -> None:
622630

623631
self._write_ch.send_nowait(data)
624632

633+
def push_timed_transcript(self, delta_text: TimedString | list[TimedString]) -> None:
    """Queue one or more timed transcript strings onto the emitter's write channel.

    Raises:
        RuntimeError: if the emitter was never started.
    """
    if not self._started:
        raise RuntimeError("AudioEmitter isn't started")

    # best-effort: silently drop text once the channel is closed
    if self._write_ch.closed:
        return

    items = delta_text if isinstance(delta_text, list) else [delta_text]
    for item in items:
        self._write_ch.send_nowait(item)
645+
625646
def flush(self) -> None:
626647
if not self._started:
627648
raise RuntimeError("AudioEmitter isn't started")
@@ -655,14 +676,17 @@ async def aclose(self) -> None:
655676

656677
@log_exceptions(logger=logger)
657678
async def _main_task(self) -> None:
679+
from ..voice.io import TimedString
680+
658681
audio_decoder: codecs.AudioStreamDecoder | None = None
659682
decode_atask: asyncio.Task | None = None
660683
segment_ctx: AudioEmitter._SegmentContext | None = None
661684
last_frame: rtc.AudioFrame | None = None
662685
debug_frames: list[rtc.AudioFrame] = []
686+
timed_transcripts: list[TimedString] = []
663687

664688
def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False) -> None:
665-
nonlocal last_frame, segment_ctx
689+
nonlocal last_frame, segment_ctx, timed_transcripts
666690
assert segment_ctx is not None
667691

668692
if last_frame is None:
@@ -686,6 +710,7 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
686710
if lk_dump_tts:
687711
debug_frames.append(frame)
688712

713+
frame.userdata[USERDATA_TIMED_TRANSCRIPT] = timed_transcripts
689714
self._dst_ch.send_nowait(
690715
SynthesizedAudio(
691716
frame=frame,
@@ -694,9 +719,11 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
694719
is_final=True,
695720
)
696721
)
722+
timed_transcripts = []
697723
return
698724

699725
if last_frame is not None:
726+
last_frame.userdata[USERDATA_TIMED_TRANSCRIPT] = timed_transcripts
700727
self._dst_ch.send_nowait(
701728
SynthesizedAudio(
702729
frame=last_frame,
@@ -705,6 +732,7 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
705732
is_final=is_final,
706733
)
707734
)
735+
timed_transcripts = []
708736
segment_ctx.audio_duration += last_frame.duration
709737
self._audio_durations[-1] += last_frame.duration
710738

@@ -714,12 +742,13 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
714742
last_frame = frame
715743

716744
def _flush_frame() -> None:
717-
nonlocal last_frame, segment_ctx
745+
nonlocal last_frame, segment_ctx, timed_transcripts
718746
assert segment_ctx is not None
719747

720748
if last_frame is None:
721749
return
722750

751+
last_frame.userdata[USERDATA_TIMED_TRANSCRIPT] = timed_transcripts
723752
self._dst_ch.send_nowait(
724753
SynthesizedAudio(
725754
frame=last_frame,
@@ -728,6 +757,7 @@ def _flush_frame() -> None:
728757
is_final=False, # flush isn't final
729758
)
730759
)
760+
timed_transcripts = []
731761
segment_ctx.audio_duration += last_frame.duration
732762
self._audio_durations[-1] += last_frame.duration
733763

@@ -780,6 +810,10 @@ async def _decode_task() -> None:
780810
audio_byte_stream: audio.AudioByteStream | None = None
781811
try:
782812
async for data in self._write_ch:
813+
if isinstance(data, TimedString):
814+
timed_transcripts.append(data)
815+
continue
816+
783817
if isinstance(data, AudioEmitter._StartSegment):
784818
if segment_ctx:
785819
raise RuntimeError(

livekit-agents/livekit/agents/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
TOPIC_CHAT = "lk.chat"
2626
TOPIC_TRANSCRIPTION = "lk.transcription"
2727

28+
USERDATA_TIMED_TRANSCRIPT = "lk.timed_transcripts"
29+
"""
30+
The key for the timed transcripts in the audio frame userdata.
31+
"""
32+
2833

2934
_T = TypeVar("_T")
3035

0 commit comments

Comments
 (0)