Skip to content

Commit e819cc0

Browse files
authored
Merge pull request #17 from 74th/feat/whisper-cpp
Enhance server-based speech recognition and configuration options
2 parents 79d47e3 + e964b13 commit e819cc0

10 files changed

Lines changed: 292 additions & 124 deletions

File tree

example_apps/echo.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from logging import getLogger
66

77
from stackchan_server.app import StackChanApp
8-
from stackchan_server.speech_recognition import WhisperCppSpeechToText
8+
from stackchan_server.speech_recognition import WhisperCppSpeechToText, WhisperServerSpeechToText
99
from stackchan_server.speech_synthesis import VoiceVoxSpeechSynthesizer
1010
from stackchan_server.ws_proxy import EmptyTranscriptError, WsProxy
1111

@@ -17,7 +17,14 @@
1717
)
1818

1919
def _create_app() -> StackChanApp:
20+
whisper_server_url = os.getenv("STACKCHAN_WHISPER_SERVER_URL")
21+
whisper_server_port = os.getenv("STACKCHAN_WHISPER_SERVER_PORT")
2022
whisper_model = os.getenv("STACKCHAN_WHISPER_MODEL")
23+
if whisper_server_url or whisper_server_port:
24+
return StackChanApp(
25+
speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url),
26+
speech_synthesizer=VoiceVoxSpeechSynthesizer(),
27+
)
2128
if whisper_model:
2229
return StackChanApp(
2330
speech_recognizer=WhisperCppSpeechToText(
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
# Launch whisper.cpp's whisper-server for Japanese speech recognition,
# configured entirely through STACKCHAN_* environment variables.
#
# Required environment variables:
#   STACKCHAN_WHISPER_SERVER_PORT  TCP port the server listens on
#   STACKCHAN_WHISPER_MODEL        path to the whisper GGML/GGUF model file
#   STACKCHAN_WHISPER_VAD_MODEL    path to the VAD model used with --vad
set -xe

# Quote every expansion so a path containing whitespace does not word-split
# into multiple arguments, and an empty/unset variable produces an obvious
# empty-argument error from whisper-server instead of silently shifting the
# remaining flags.
whisper-server \
    --host 0.0.0.0 \
    --port "${STACKCHAN_WHISPER_SERVER_PORT}" \
    -m "${STACKCHAN_WHISPER_MODEL}" \
    -l ja \
    -nt \
    --vad \
    -vm "${STACKCHAN_WHISPER_VAD_MODEL}" \
    -vt 0.6 \
    -vspd 250 \
    -vsd 400 \
    -vp 30

stackchan_server/listen.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from fastapi import WebSocket, WebSocketDisconnect
1111

12+
from .static import LISTEN_AUDIO_FORMAT
1213
from .types import SpeechRecognizer, StreamingSpeechRecognizer, StreamingSpeechSession
1314

1415
logger = getLogger(__name__)
@@ -29,20 +30,13 @@ def __init__(
2930
speech_recognizer: SpeechRecognizer,
3031
recordings_dir: Path,
3132
debug_recording: bool,
32-
sample_rate_hz: int,
33-
channels: int,
34-
sample_width: int,
3533
listen_audio_timeout_seconds: float,
36-
language_code: str = "ja-JP",
3734
) -> None:
3835
self.speech_recognizer = speech_recognizer
3936
self.recordings_dir = recordings_dir
4037
self.debug_recording = debug_recording
41-
self.sample_rate_hz = sample_rate_hz
42-
self.channels = channels
43-
self.sample_width = sample_width
38+
self.audio_format = LISTEN_AUDIO_FORMAT
4439
self.listen_audio_timeout_seconds = listen_audio_timeout_seconds
45-
self.language_code = language_code
4640

4741
self._pcm_buffer = bytearray()
4842
self._streaming = False
@@ -96,12 +90,7 @@ async def handle_start(self, websocket: WebSocket) -> bool:
9690
self._message_error = None
9791
if isinstance(self.speech_recognizer, StreamingSpeechRecognizer):
9892
try:
99-
self._speech_stream = await self.speech_recognizer.start_stream(
100-
sample_rate_hz=self.sample_rate_hz,
101-
channels=self.channels,
102-
sample_width=self.sample_width,
103-
language_code=self.language_code,
104-
)
93+
self._speech_stream = await self.speech_recognizer.start_stream()
10594
except Exception:
10695
asyncio.create_task(websocket.close(code=1011, reason="speech streaming failed"))
10796
return False
@@ -113,7 +102,7 @@ async def handle_data(self, websocket: WebSocket, payload_bytes: int, payload: b
113102
await self._abort_speech_stream()
114103
asyncio.create_task(websocket.close(code=1003, reason="data received before start"))
115104
return False
116-
if payload_bytes % (self.sample_width * self.channels) != 0:
105+
if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0:
117106
await self._abort_speech_stream()
118107
asyncio.create_task(websocket.close(code=1003, reason="invalid pcm chunk length"))
119108
return False
@@ -142,7 +131,7 @@ async def handle_end(
142131
await self._abort_speech_stream()
143132
await websocket.close(code=1003, reason="end received before start")
144133
return
145-
if payload_bytes % (self.sample_width * self.channels) != 0:
134+
if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0:
146135
await self._abort_speech_stream()
147136
await websocket.close(code=1003, reason="invalid pcm tail length")
148137
return
@@ -155,19 +144,21 @@ async def handle_end(
155144
await websocket.close(code=1011, reason="speech streaming failed")
156145
return
157146

158-
if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % (self.sample_width * self.channels) != 0:
147+
if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % (
148+
self.audio_format.sample_width * self.audio_format.channels
149+
) != 0:
159150
await self._abort_speech_stream()
160151
await websocket.close(code=1003, reason="invalid accumulated pcm length")
161152
return
162153

163154
await send_state_command(thinking_state)
164155

165-
frames = len(self._pcm_buffer) // (self.sample_width * self.channels)
166-
duration_seconds = frames / float(self.sample_rate_hz)
156+
frames = len(self._pcm_buffer) // (self.audio_format.sample_width * self.audio_format.channels)
157+
duration_seconds = frames / float(self.audio_format.sample_rate_hz)
167158
ws_meta = {
168-
"sample_rate": self.sample_rate_hz,
159+
"sample_rate": self.audio_format.sample_rate_hz,
169160
"frames": frames,
170-
"channels": self.channels,
161+
"channels": self.audio_format.channels,
171162
"duration_seconds": round(duration_seconds, 3),
172163
}
173164
if self.debug_recording:
@@ -197,9 +188,9 @@ def _save_wav(self, pcm_bytes: bytes) -> tuple[Path, str]:
197188
filepath = self.recordings_dir / filename
198189

199190
with wave.open(str(filepath), "wb") as wav_fp:
200-
wav_fp.setnchannels(self.channels)
201-
wav_fp.setsampwidth(self.sample_width)
202-
wav_fp.setframerate(self.sample_rate_hz)
191+
wav_fp.setnchannels(self.audio_format.channels)
192+
wav_fp.setsampwidth(self.audio_format.sample_width)
193+
wav_fp.setframerate(self.audio_format.sample_rate_hz)
203194
wav_fp.writeframes(pcm_bytes)
204195

205196
logger.info("Saved WAV: %s", filename)
@@ -211,13 +202,7 @@ async def _transcribe_async(self, pcm_bytes: bytes) -> str:
211202
return await self._transcribe(pcm_bytes)
212203

213204
async def _transcribe(self, pcm_bytes: bytes) -> str:
214-
transcript = await self.speech_recognizer.transcribe(
215-
pcm_bytes,
216-
sample_rate_hz=self.sample_rate_hz,
217-
channels=self.channels,
218-
sample_width=self.sample_width,
219-
language_code=self.language_code,
220-
)
205+
transcript = await self.speech_recognizer.transcribe(pcm_bytes)
221206
if transcript:
222207
logger.info("Transcript: %s", transcript)
223208
return transcript

stackchan_server/speech_recognition/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
from ..types import SpeechRecognizer
44
from .google_cloud import GoogleCloudSpeechToText
55
from .whisper_cpp import WhisperCppSpeechToText
6+
from .whisper_server import WhisperServerSpeechToText
67

78

89
def create_speech_recognizer() -> SpeechRecognizer:
910
return GoogleCloudSpeechToText()
1011

1112

12-
__all__ = ["GoogleCloudSpeechToText", "WhisperCppSpeechToText", "create_speech_recognizer"]
13+
__all__ = [
14+
"GoogleCloudSpeechToText",
15+
"WhisperCppSpeechToText",
16+
"WhisperServerSpeechToText",
17+
"create_speech_recognizer",
18+
]

stackchan_server/speech_recognition/google_cloud.py

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from google.cloud import speech
77

8+
from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE
89
from ..types import StreamingSpeechRecognizer, StreamingSpeechSession
910

1011
logger = getLogger(__name__)
@@ -15,25 +16,13 @@ class _GoogleCloudStreamingSession(StreamingSpeechSession):
1516
def __init__(
1617
self,
1718
client: speech.SpeechAsyncClient,
18-
*,
19-
sample_rate_hz: int,
20-
channels: int,
21-
sample_width: int,
22-
language_code: str,
2319
) -> None:
24-
if channels != 1:
25-
raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}")
26-
if sample_width != 2:
27-
raise ValueError(
28-
f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}"
29-
)
30-
3120
self._client = client
3221
self._config = speech.StreamingRecognitionConfig(
3322
config=speech.RecognitionConfig(
3423
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
35-
sample_rate_hertz=sample_rate_hz,
36-
language_code=language_code,
24+
sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
25+
language_code=LISTEN_LANGUAGE_CODE,
3726
),
3827
interim_results=False,
3928
single_utterance=False,
@@ -109,47 +98,19 @@ class GoogleCloudSpeechToText(StreamingSpeechRecognizer):
10998
def __init__(self, client: speech.SpeechAsyncClient | None = None) -> None:
11099
self._client = client or speech.SpeechAsyncClient()
111100

112-
async def transcribe(
113-
self,
114-
pcm_bytes: bytes,
115-
*,
116-
sample_rate_hz: int,
117-
channels: int,
118-
sample_width: int,
119-
language_code: str = "ja-JP",
120-
) -> str:
121-
if channels != 1:
122-
raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}")
123-
if sample_width != 2:
124-
raise ValueError(
125-
f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}"
126-
)
127-
101+
async def transcribe(self, pcm_bytes: bytes) -> str:
128102
audio = speech.RecognitionAudio(content=pcm_bytes)
129103
config = speech.RecognitionConfig(
130104
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
131-
sample_rate_hertz=sample_rate_hz,
132-
language_code=language_code,
105+
sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
106+
language_code=LISTEN_LANGUAGE_CODE,
133107
)
134108
response = await self._client.recognize(config=config, audio=audio)
135109

136110
return "".join(result.alternatives[0].transcript for result in response.results)
137111

138-
async def start_stream(
139-
self,
140-
*,
141-
sample_rate_hz: int,
142-
channels: int,
143-
sample_width: int,
144-
language_code: str = "ja-JP",
145-
) -> StreamingSpeechSession:
146-
return _GoogleCloudStreamingSession(
147-
self._client,
148-
sample_rate_hz=sample_rate_hz,
149-
channels=channels,
150-
sample_width=sample_width,
151-
language_code=language_code,
152-
)
112+
async def start_stream(self) -> StreamingSpeechSession:
113+
return _GoogleCloudStreamingSession(self._client)
153114

154115

155116
__all__ = ["GoogleCloudSpeechToText"]

stackchan_server/speech_recognition/whisper_cpp.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from logging import getLogger
1313
from pathlib import Path
1414

15+
from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE
1516
from ..types import SpeechRecognizer
1617

1718
logger = getLogger(__name__)
@@ -26,7 +27,7 @@ class WhisperCppSpeechToText(SpeechRecognizer):
2627
def __init__(
2728
self,
2829
*,
29-
model_path: str | Path,
30+
model_path: str | Path | None = None,
3031
cli_path: str = "whisper-cli",
3132
threads: int | None = None,
3233
translate: bool = False,
@@ -40,7 +41,10 @@ def __init__(
4041
vad_speech_pad_ms: int = _DEFAULT_VAD_SPEECH_PAD_MS,
4142
silence_rms_threshold: float = _DEFAULT_SILENCE_RMS_THRESHOLD,
4243
) -> None:
43-
self._model_path = Path(model_path)
44+
resolved_model_path = model_path or os.getenv("STACKCHAN_WHISPER_MODEL")
45+
if not resolved_model_path:
46+
raise ValueError("whisper.cpp model_path is required or set STACKCHAN_WHISPER_MODEL")
47+
self._model_path = Path(resolved_model_path)
4448
self._cli_path = cli_path
4549
self._threads = threads
4650
self._translate = translate
@@ -54,19 +58,7 @@ def __init__(
5458
self._vad_speech_pad_ms = vad_speech_pad_ms
5559
self._silence_rms_threshold = silence_rms_threshold
5660

57-
async def transcribe(
58-
self,
59-
pcm_bytes: bytes,
60-
*,
61-
sample_rate_hz: int,
62-
channels: int,
63-
sample_width: int,
64-
language_code: str = "ja-JP",
65-
) -> str:
66-
if channels != 1:
67-
raise ValueError(f"whisper.cpp only supports mono input here: channels={channels}")
68-
if sample_width != 2:
69-
raise ValueError(f"whisper.cpp expects 16-bit PCM here: sample_width={sample_width}")
61+
async def transcribe(self, pcm_bytes: bytes) -> str:
7062
if not self._model_path.is_file():
7163
raise FileNotFoundError(f"whisper.cpp model not found: {self._model_path}")
7264
if _pcm_rms_level(pcm_bytes) < self._silence_rms_threshold:
@@ -81,7 +73,7 @@ async def transcribe(
8173
if cli_path is None:
8274
raise FileNotFoundError(f"whisper.cpp CLI not found in PATH: {self._cli_path}")
8375

84-
language = _normalize_language(language_code)
76+
language = _normalize_language(LISTEN_LANGUAGE_CODE)
8577
with tempfile.TemporaryDirectory(prefix="stackchan_whisper_") as temp_dir_name:
8678
temp_dir = Path(temp_dir_name)
8779
wav_path = temp_dir / "input.wav"
@@ -90,9 +82,9 @@ async def transcribe(
9082
_write_wav(
9183
wav_path,
9284
pcm_bytes,
93-
sample_rate_hz=sample_rate_hz,
94-
channels=channels,
95-
sample_width=sample_width,
85+
sample_rate_hz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
86+
channels=LISTEN_AUDIO_FORMAT.channels,
87+
sample_width=LISTEN_AUDIO_FORMAT.sample_width,
9688
)
9789

9890
command = [

0 commit comments

Comments
 (0)