99
1010from fastapi import WebSocket , WebSocketDisconnect
1111
12+ from .static import LISTEN_AUDIO_FORMAT
1213from .types import SpeechRecognizer , StreamingSpeechRecognizer , StreamingSpeechSession
1314
1415logger = getLogger (__name__ )
@@ -29,20 +30,13 @@ def __init__(
2930 speech_recognizer : SpeechRecognizer ,
3031 recordings_dir : Path ,
3132 debug_recording : bool ,
32- sample_rate_hz : int ,
33- channels : int ,
34- sample_width : int ,
3533 listen_audio_timeout_seconds : float ,
36- language_code : str = "ja-JP" ,
3734 ) -> None :
3835 self .speech_recognizer = speech_recognizer
3936 self .recordings_dir = recordings_dir
4037 self .debug_recording = debug_recording
41- self .sample_rate_hz = sample_rate_hz
42- self .channels = channels
43- self .sample_width = sample_width
38+ self .audio_format = LISTEN_AUDIO_FORMAT
4439 self .listen_audio_timeout_seconds = listen_audio_timeout_seconds
45- self .language_code = language_code
4640
4741 self ._pcm_buffer = bytearray ()
4842 self ._streaming = False
@@ -96,12 +90,7 @@ async def handle_start(self, websocket: WebSocket) -> bool:
9690 self ._message_error = None
9791 if isinstance (self .speech_recognizer , StreamingSpeechRecognizer ):
9892 try :
99- self ._speech_stream = await self .speech_recognizer .start_stream (
100- sample_rate_hz = self .sample_rate_hz ,
101- channels = self .channels ,
102- sample_width = self .sample_width ,
103- language_code = self .language_code ,
104- )
93+ self ._speech_stream = await self .speech_recognizer .start_stream ()
10594 except Exception :
10695 asyncio .create_task (websocket .close (code = 1011 , reason = "speech streaming failed" ))
10796 return False
@@ -113,7 +102,7 @@ async def handle_data(self, websocket: WebSocket, payload_bytes: int, payload: b
113102 await self ._abort_speech_stream ()
114103 asyncio .create_task (websocket .close (code = 1003 , reason = "data received before start" ))
115104 return False
116- if payload_bytes % (self .sample_width * self .channels ) != 0 :
105+ if payload_bytes % (self .audio_format . sample_width * self . audio_format .channels ) != 0 :
117106 await self ._abort_speech_stream ()
118107 asyncio .create_task (websocket .close (code = 1003 , reason = "invalid pcm chunk length" ))
119108 return False
@@ -142,7 +131,7 @@ async def handle_end(
142131 await self ._abort_speech_stream ()
143132 await websocket .close (code = 1003 , reason = "end received before start" )
144133 return
145- if payload_bytes % (self .sample_width * self .channels ) != 0 :
134+ if payload_bytes % (self .audio_format . sample_width * self . audio_format .channels ) != 0 :
146135 await self ._abort_speech_stream ()
147136 await websocket .close (code = 1003 , reason = "invalid pcm tail length" )
148137 return
@@ -155,19 +144,21 @@ async def handle_end(
155144 await websocket .close (code = 1011 , reason = "speech streaming failed" )
156145 return
157146
158- if len (self ._pcm_buffer ) == 0 or len (self ._pcm_buffer ) % (self .sample_width * self .channels ) != 0 :
147+ if len (self ._pcm_buffer ) == 0 or len (self ._pcm_buffer ) % (
148+ self .audio_format .sample_width * self .audio_format .channels
149+ ) != 0 :
159150 await self ._abort_speech_stream ()
160151 await websocket .close (code = 1003 , reason = "invalid accumulated pcm length" )
161152 return
162153
163154 await send_state_command (thinking_state )
164155
165- frames = len (self ._pcm_buffer ) // (self .sample_width * self .channels )
166- duration_seconds = frames / float (self .sample_rate_hz )
156+ frames = len (self ._pcm_buffer ) // (self .audio_format . sample_width * self . audio_format .channels )
157+ duration_seconds = frames / float (self .audio_format . sample_rate_hz )
167158 ws_meta = {
168- "sample_rate" : self .sample_rate_hz ,
159+ "sample_rate" : self .audio_format . sample_rate_hz ,
169160 "frames" : frames ,
170- "channels" : self .channels ,
161+ "channels" : self .audio_format . channels ,
171162 "duration_seconds" : round (duration_seconds , 3 ),
172163 }
173164 if self .debug_recording :
@@ -197,9 +188,9 @@ def _save_wav(self, pcm_bytes: bytes) -> tuple[Path, str]:
197188 filepath = self .recordings_dir / filename
198189
199190 with wave .open (str (filepath ), "wb" ) as wav_fp :
200- wav_fp .setnchannels (self .channels )
201- wav_fp .setsampwidth (self .sample_width )
202- wav_fp .setframerate (self .sample_rate_hz )
191+ wav_fp .setnchannels (self .audio_format . channels )
192+ wav_fp .setsampwidth (self .audio_format . sample_width )
193+ wav_fp .setframerate (self .audio_format . sample_rate_hz )
203194 wav_fp .writeframes (pcm_bytes )
204195
205196 logger .info ("Saved WAV: %s" , filename )
async def _transcribe_async(self, pcm_bytes: bytes) -> str:
    """Transcribe raw PCM audio asynchronously.

    Thin async wrapper that delegates directly to ``_transcribe`` and
    returns its transcript unchanged.
    """
    result = await self._transcribe(pcm_bytes)
    return result
async def _transcribe(self, pcm_bytes: bytes) -> str:
    """Run the configured speech recognizer over raw PCM bytes.

    The recognizer now carries its own audio-format/language configuration,
    so only the PCM payload is passed through. A non-empty transcript is
    logged before being returned; an empty transcript is returned silently.
    """
    text = await self.speech_recognizer.transcribe(pcm_bytes)
    if text:
        # Lazy %-formatting keeps the log call cheap when INFO is disabled.
        logger.info("Transcript: %s", text)
    return text
0 commit comments