88from collections .abc import AsyncIterable , AsyncIterator
99from dataclasses import dataclass
1010from types import TracebackType
11- from typing import Generic , Literal , TypeVar , Union
11+ from typing import TYPE_CHECKING , Generic , Literal , TypeVar , Union
1212
1313from pydantic import BaseModel , ConfigDict , Field
1414
1717from .._exceptions import APIError
1818from ..log import logger
1919from ..metrics import TTSMetrics
20- from ..types import DEFAULT_API_CONNECT_OPTIONS , APIConnectOptions
20+ from ..types import DEFAULT_API_CONNECT_OPTIONS , USERDATA_TIMED_TRANSCRIPT , APIConnectOptions
2121from ..utils import aio , audio , codecs , log_exceptions
2222
23+ if TYPE_CHECKING :
24+ from ..voice .io import TimedString
25+
2326lk_dump_tts = int (os .getenv ("LK_DUMP_TTS" , 0 ))
2427
2528
@@ -41,6 +44,8 @@ class SynthesizedAudio:
4144class TTSCapabilities :
4245 streaming : bool
4346 """Whether this TTS supports streaming (generally using websockets)"""
47+ aligned_transcript : bool = False
48+ """Whether this TTS supports aligned transcripts with word timestamps"""
4449
4550
4651class TTSError (BaseModel ):
@@ -563,12 +568,15 @@ def initialize(
563568 self ._num_channels = num_channels
564569 self ._streaming = stream
565570
571+ from ..voice .io import TimedString
572+
566573 self ._write_ch = aio .Chan [
567574 Union [
568575 bytes ,
569576 AudioEmitter ._FlushSegment ,
570577 AudioEmitter ._StartSegment ,
571578 AudioEmitter ._EndSegment ,
579+ TimedString ,
572580 ]
573581 ]()
574582 self ._main_atask = asyncio .create_task (self ._main_task (), name = "AudioEmitter._main_task" )
@@ -622,6 +630,19 @@ def push(self, data: bytes) -> None:
622630
623631 self ._write_ch .send_nowait (data )
624632
def push_timed_transcript(self, delta_text: TimedString | list[TimedString]) -> None:
    """Queue one or more timed transcript items on the emitter's write channel.

    A single item is sent as-is; a ``list`` is unpacked and its elements are
    sent one by one, in order.

    Raises:
        RuntimeError: if the emitter has not been started yet.
    """
    if not self._started:
        raise RuntimeError("AudioEmitter isn't started")

    # Once the write channel is closed, late transcripts are dropped silently.
    if self._write_ch.closed:
        return

    # Normalize the single-item form so both cases share one send loop.
    pending = delta_text if isinstance(delta_text, list) else [delta_text]
    for timed_text in pending:
        self._write_ch.send_nowait(timed_text)
645+
625646 def flush (self ) -> None :
626647 if not self ._started :
627648 raise RuntimeError ("AudioEmitter isn't started" )
@@ -655,14 +676,17 @@ async def aclose(self) -> None:
655676
656677 @log_exceptions (logger = logger )
657678 async def _main_task (self ) -> None :
679+ from ..voice .io import TimedString
680+
658681 audio_decoder : codecs .AudioStreamDecoder | None = None
659682 decode_atask : asyncio .Task | None = None
660683 segment_ctx : AudioEmitter ._SegmentContext | None = None
661684 last_frame : rtc .AudioFrame | None = None
662685 debug_frames : list [rtc .AudioFrame ] = []
686+ timed_transcripts : list [TimedString ] = []
663687
664688 def _emit_frame (frame : rtc .AudioFrame | None = None , * , is_final : bool = False ) -> None :
665- nonlocal last_frame , segment_ctx
689+ nonlocal last_frame , segment_ctx , timed_transcripts
666690 assert segment_ctx is not None
667691
668692 if last_frame is None :
@@ -686,6 +710,7 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
686710 if lk_dump_tts :
687711 debug_frames .append (frame )
688712
713+ frame .userdata [USERDATA_TIMED_TRANSCRIPT ] = timed_transcripts
689714 self ._dst_ch .send_nowait (
690715 SynthesizedAudio (
691716 frame = frame ,
@@ -694,9 +719,11 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
694719 is_final = True ,
695720 )
696721 )
722+ timed_transcripts = []
697723 return
698724
699725 if last_frame is not None :
726+ last_frame .userdata [USERDATA_TIMED_TRANSCRIPT ] = timed_transcripts
700727 self ._dst_ch .send_nowait (
701728 SynthesizedAudio (
702729 frame = last_frame ,
@@ -705,6 +732,7 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
705732 is_final = is_final ,
706733 )
707734 )
735+ timed_transcripts = []
708736 segment_ctx .audio_duration += last_frame .duration
709737 self ._audio_durations [- 1 ] += last_frame .duration
710738
@@ -714,12 +742,13 @@ def _emit_frame(frame: rtc.AudioFrame | None = None, *, is_final: bool = False)
714742 last_frame = frame
715743
716744 def _flush_frame () -> None :
717- nonlocal last_frame , segment_ctx
745+ nonlocal last_frame , segment_ctx , timed_transcripts
718746 assert segment_ctx is not None
719747
720748 if last_frame is None :
721749 return
722750
751+ last_frame .userdata [USERDATA_TIMED_TRANSCRIPT ] = timed_transcripts
723752 self ._dst_ch .send_nowait (
724753 SynthesizedAudio (
725754 frame = last_frame ,
@@ -728,6 +757,7 @@ def _flush_frame() -> None:
728757 is_final = False , # flush isn't final
729758 )
730759 )
760+ timed_transcripts = []
731761 segment_ctx .audio_duration += last_frame .duration
732762 self ._audio_durations [- 1 ] += last_frame .duration
733763
@@ -780,6 +810,10 @@ async def _decode_task() -> None:
780810 audio_byte_stream : audio .AudioByteStream | None = None
781811 try :
782812 async for data in self ._write_ch :
813+ if isinstance (data , TimedString ):
814+ timed_transcripts .append (data )
815+ continue
816+
783817 if isinstance (data , AudioEmitter ._StartSegment ):
784818 if segment_ctx :
785819 raise RuntimeError (
0 commit comments