diff --git a/example_apps/echo.py b/example_apps/echo.py index 2a60bfa..64a84f2 100644 --- a/example_apps/echo.py +++ b/example_apps/echo.py @@ -5,7 +5,7 @@ from logging import getLogger from stackchan_server.app import StackChanApp -from stackchan_server.speech_recognition import WhisperCppSpeechToText +from stackchan_server.speech_recognition import WhisperCppSpeechToText, WhisperServerSpeechToText from stackchan_server.speech_synthesis import VoiceVoxSpeechSynthesizer from stackchan_server.ws_proxy import EmptyTranscriptError, WsProxy @@ -17,7 +17,14 @@ ) def _create_app() -> StackChanApp: + whisper_server_url = os.getenv("STACKCHAN_WHISPER_SERVER_URL") + whisper_server_port = os.getenv("STACKCHAN_WHISPER_SERVER_PORT") whisper_model = os.getenv("STACKCHAN_WHISPER_MODEL") + if whisper_server_url or whisper_server_port: + return StackChanApp( + speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url), + speech_synthesizer=VoiceVoxSpeechSynthesizer(), + ) if whisper_model: return StackChanApp( speech_recognizer=WhisperCppSpeechToText( diff --git a/misc/whisper-server/run-whisper-server.sh b/misc/whisper-server/run-whisper-server.sh new file mode 100755 index 0000000..5935b12 --- /dev/null +++ b/misc/whisper-server/run-whisper-server.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -xe + +whisper-server \ + --host 0.0.0.0 \ + --port ${STACKCHAN_WHISPER_SERVER_PORT} \ + -m ${STACKCHAN_WHISPER_MODEL} \ + -l ja \ + -nt \ + --vad \ + -vm ${STACKCHAN_WHISPER_VAD_MODEL} \ + -vt 0.6 \ + -vspd 250 \ + -vsd 400 \ + -vp 30 diff --git a/stackchan_server/listen.py b/stackchan_server/listen.py index c7e5b84..a2c934d 100644 --- a/stackchan_server/listen.py +++ b/stackchan_server/listen.py @@ -9,6 +9,7 @@ from fastapi import WebSocket, WebSocketDisconnect +from .static import LISTEN_AUDIO_FORMAT from .types import SpeechRecognizer, StreamingSpeechRecognizer, StreamingSpeechSession logger = getLogger(__name__) @@ -29,20 +30,13 @@ def __init__( speech_recognizer: SpeechRecognizer, recordings_dir: Path, debug_recording: bool, - sample_rate_hz: int, - channels: int, - sample_width: int, listen_audio_timeout_seconds: float, - language_code: str = "ja-JP", ) -> None: self.speech_recognizer = speech_recognizer self.recordings_dir = recordings_dir self.debug_recording = debug_recording - self.sample_rate_hz = sample_rate_hz - self.channels = channels - self.sample_width = sample_width + self.audio_format = LISTEN_AUDIO_FORMAT self.listen_audio_timeout_seconds = listen_audio_timeout_seconds - self.language_code = language_code self._pcm_buffer = bytearray() self._streaming = False @@ -96,12 +90,7 @@ async def handle_start(self, websocket: WebSocket) -> bool: self._message_error = None if isinstance(self.speech_recognizer, StreamingSpeechRecognizer): try: - self._speech_stream = await self.speech_recognizer.start_stream( - sample_rate_hz=self.sample_rate_hz, - channels=self.channels, - sample_width=self.sample_width, - language_code=self.language_code, - ) + self._speech_stream = await self.speech_recognizer.start_stream() except Exception: asyncio.create_task(websocket.close(code=1011, reason="speech streaming failed")) return False @@ -113,7 +102,7 @@ async def handle_data(self, websocket: WebSocket, payload_bytes: int, payload: b await self._abort_speech_stream() asyncio.create_task(websocket.close(code=1003, reason="data received before start")) return False - if payload_bytes % (self.sample_width * self.channels) != 0: + if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0: await self._abort_speech_stream() asyncio.create_task(websocket.close(code=1003, reason="invalid pcm chunk length")) return False @@ -142,7 +131,7 @@ async def handle_end( await self._abort_speech_stream() await websocket.close(code=1003, reason="end received before start") return - if payload_bytes % (self.sample_width * self.channels) != 0: + if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0: await self._abort_speech_stream() await websocket.close(code=1003, reason="invalid pcm tail length") return @@ -155,19 +144,21 @@ async def handle_end( await websocket.close(code=1011, reason="speech streaming failed") return - if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % (self.sample_width * self.channels) != 0: + if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % ( + self.audio_format.sample_width * self.audio_format.channels + ) != 0: await self._abort_speech_stream() await websocket.close(code=1003, reason="invalid accumulated pcm length") return await send_state_command(thinking_state) - frames = len(self._pcm_buffer) // (self.sample_width * self.channels) - duration_seconds = frames / float(self.sample_rate_hz) + frames = len(self._pcm_buffer) // (self.audio_format.sample_width * self.audio_format.channels) + duration_seconds = frames / float(self.audio_format.sample_rate_hz) ws_meta = { - "sample_rate": self.sample_rate_hz, + "sample_rate": self.audio_format.sample_rate_hz, "frames": frames, - "channels": self.channels, + "channels": self.audio_format.channels, "duration_seconds": round(duration_seconds, 3), } if self.debug_recording: @@ -197,9 +188,9 @@ def _save_wav(self, pcm_bytes: bytes) -> tuple[Path, str]: filepath = self.recordings_dir / filename with wave.open(str(filepath), "wb") as wav_fp: - wav_fp.setnchannels(self.channels) - wav_fp.setsampwidth(self.sample_width) - wav_fp.setframerate(self.sample_rate_hz) + wav_fp.setnchannels(self.audio_format.channels) + wav_fp.setsampwidth(self.audio_format.sample_width) + wav_fp.setframerate(self.audio_format.sample_rate_hz) wav_fp.writeframes(pcm_bytes) logger.info("Saved WAV: %s", filename) @@ -211,13 +202,7 @@ async def _transcribe_async(self, pcm_bytes: bytes) -> str: return await self._transcribe(pcm_bytes) async def _transcribe(self, pcm_bytes: bytes) -> str: - transcript = await self.speech_recognizer.transcribe( - pcm_bytes, - sample_rate_hz=self.sample_rate_hz, - channels=self.channels, - sample_width=self.sample_width, - language_code=self.language_code, - ) + transcript = await self.speech_recognizer.transcribe(pcm_bytes) if transcript: logger.info("Transcript: %s", transcript) return transcript diff --git a/stackchan_server/speech_recognition/__init__.py b/stackchan_server/speech_recognition/__init__.py index f9a066a..aedc100 100644 --- a/stackchan_server/speech_recognition/__init__.py +++ b/stackchan_server/speech_recognition/__init__.py @@ -3,10 +3,16 @@ from ..types import SpeechRecognizer from .google_cloud import GoogleCloudSpeechToText from .whisper_cpp import WhisperCppSpeechToText +from .whisper_server import WhisperServerSpeechToText def create_speech_recognizer() -> SpeechRecognizer: return GoogleCloudSpeechToText() -__all__ = ["GoogleCloudSpeechToText", "WhisperCppSpeechToText", "create_speech_recognizer"] +__all__ = [ + "GoogleCloudSpeechToText", + "WhisperCppSpeechToText", + "WhisperServerSpeechToText", + "create_speech_recognizer", +] diff --git a/stackchan_server/speech_recognition/google_cloud.py b/stackchan_server/speech_recognition/google_cloud.py index 3780ea6..7a3ec73 100644 --- a/stackchan_server/speech_recognition/google_cloud.py +++ b/stackchan_server/speech_recognition/google_cloud.py @@ -5,6 +5,7 @@ from google.cloud import speech +from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE from ..types import StreamingSpeechRecognizer, StreamingSpeechSession logger = getLogger(__name__) @@ -15,25 +16,13 @@ class _GoogleCloudStreamingSession(StreamingSpeechSession): def __init__( self, client: speech.SpeechAsyncClient, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str, ) -> None: - if channels != 1: - raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}") - if sample_width != 2: - raise ValueError( - f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}" - ) - self._client = client self._config = speech.StreamingRecognitionConfig( config=speech.RecognitionConfig( encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, - sample_rate_hertz=sample_rate_hz, - language_code=language_code, + sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz, + language_code=LISTEN_LANGUAGE_CODE, ), interim_results=False, single_utterance=False, @@ -109,47 +98,19 @@ class GoogleCloudSpeechToText(StreamingSpeechRecognizer): def __init__(self, client: speech.SpeechAsyncClient | None = None) -> None: self._client = client or speech.SpeechAsyncClient() - async def transcribe( - self, - pcm_bytes: bytes, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str = "ja-JP", - ) -> str: - if channels != 1: - raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}") - if sample_width != 2: - raise ValueError( - f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}" - ) - + async def transcribe(self, pcm_bytes: bytes) -> str: audio = speech.RecognitionAudio(content=pcm_bytes) config = speech.RecognitionConfig( encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, - sample_rate_hertz=sample_rate_hz, - language_code=language_code, + sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz, + language_code=LISTEN_LANGUAGE_CODE, ) response = await self._client.recognize(config=config, audio=audio) return "".join(result.alternatives[0].transcript for result in response.results) - async def start_stream( - self, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str = "ja-JP", - ) -> StreamingSpeechSession: - return _GoogleCloudStreamingSession( - self._client, - sample_rate_hz=sample_rate_hz, - channels=channels, - sample_width=sample_width, - language_code=language_code, - ) + async def start_stream(self) -> StreamingSpeechSession: + return _GoogleCloudStreamingSession(self._client) __all__ = ["GoogleCloudSpeechToText"] diff --git a/stackchan_server/speech_recognition/whisper_cpp.py b/stackchan_server/speech_recognition/whisper_cpp.py index 97e10da..f0925a2 100644 --- a/stackchan_server/speech_recognition/whisper_cpp.py +++ b/stackchan_server/speech_recognition/whisper_cpp.py @@ -12,6 +12,7 @@ from logging import getLogger from pathlib import Path +from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE from ..types import SpeechRecognizer logger = getLogger(__name__) @@ -26,7 +27,7 @@ class WhisperCppSpeechToText(SpeechRecognizer): def __init__( self, *, - model_path: str | Path, + model_path: str | Path | None = None, cli_path: str = "whisper-cli", threads: int | None = None, translate: bool = False, @@ -40,7 +41,10 @@ def __init__( vad_speech_pad_ms: int = _DEFAULT_VAD_SPEECH_PAD_MS, silence_rms_threshold: float = _DEFAULT_SILENCE_RMS_THRESHOLD, ) -> None: - self._model_path = Path(model_path) + resolved_model_path = model_path or os.getenv("STACKCHAN_WHISPER_MODEL") + if not resolved_model_path: + raise ValueError("whisper.cpp model_path is required or set STACKCHAN_WHISPER_MODEL") + self._model_path = Path(resolved_model_path) self._cli_path = cli_path self._threads = threads self._translate = translate @@ -54,19 +58,7 @@ def __init__( self._vad_speech_pad_ms = vad_speech_pad_ms self._silence_rms_threshold = silence_rms_threshold - async def transcribe( - self, - pcm_bytes: bytes, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str = "ja-JP", - ) -> str: - if channels != 1: - raise ValueError(f"whisper.cpp only supports mono input here: channels={channels}") - if sample_width != 2: - raise ValueError(f"whisper.cpp expects 16-bit PCM here: sample_width={sample_width}") + async def transcribe(self, pcm_bytes: bytes) -> str: if not self._model_path.is_file(): raise FileNotFoundError(f"whisper.cpp model not found: {self._model_path}") if _pcm_rms_level(pcm_bytes) < self._silence_rms_threshold: @@ -81,7 +73,7 @@ async def transcribe( if cli_path is None: raise FileNotFoundError(f"whisper.cpp CLI not found in PATH: {self._cli_path}") - language = _normalize_language(language_code) + language = _normalize_language(LISTEN_LANGUAGE_CODE) with tempfile.TemporaryDirectory(prefix="stackchan_whisper_") as temp_dir_name: temp_dir = Path(temp_dir_name) wav_path = temp_dir / "input.wav" @@ -90,9 +82,9 @@ async def transcribe( _write_wav( wav_path, pcm_bytes, - sample_rate_hz=sample_rate_hz, - channels=channels, - sample_width=sample_width, + sample_rate_hz=LISTEN_AUDIO_FORMAT.sample_rate_hz, + channels=LISTEN_AUDIO_FORMAT.channels, + sample_width=LISTEN_AUDIO_FORMAT.sample_width, ) command = [ diff --git a/stackchan_server/speech_recognition/whisper_server.py b/stackchan_server/speech_recognition/whisper_server.py new file mode 100644 index 0000000..c3b5d69 --- /dev/null +++ b/stackchan_server/speech_recognition/whisper_server.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +import json +import math +import mimetypes +import os +import uuid +from collections.abc import Mapping +from logging import getLogger +from pathlib import Path +from typing import cast +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE +from ..types import SpeechRecognizer + +logger = getLogger(__name__) + +_DEFAULT_SILENCE_RMS_THRESHOLD = 75.0 +_DEFAULT_SERVER_PORT = 8080 + + +class WhisperServerSpeechToText(SpeechRecognizer): + def __init__( + self, + *, + server_url: str | None = None, + language: str | None = None, + detect_language: bool = False, + response_format: str = "verbose_json", + silence_rms_threshold: float = _DEFAULT_SILENCE_RMS_THRESHOLD, + request_timeout_seconds: float = 60.0, + ) -> None: + self._server_url = server_url or _default_server_url() + self._language = language or _normalize_language(LISTEN_LANGUAGE_CODE) + self._detect_language = detect_language + self._response_format = response_format + self._silence_rms_threshold = silence_rms_threshold + self._request_timeout_seconds = request_timeout_seconds + + async def transcribe(self, pcm_bytes: bytes) -> str: + rms_level = _pcm_rms_level(pcm_bytes) + if rms_level < self._silence_rms_threshold: + logger.info( + "Skipping whisper-server transcription because pcm rms %.2f is below silence threshold %.2f", + rms_level, + self._silence_rms_threshold, + ) + return "" + + wav_bytes = _wrap_pcm_as_wav( + pcm_bytes, + sample_rate_hz=LISTEN_AUDIO_FORMAT.sample_rate_hz, + channels=LISTEN_AUDIO_FORMAT.channels, + sample_width=LISTEN_AUDIO_FORMAT.sample_width, + ) + transcript = await asyncio.to_thread( + self._request_transcript, + wav_bytes, + self._language, + ) + if transcript: + logger.info("whisper-server transcript: %s", transcript) + return transcript + + def _request_transcript(self, wav_bytes: bytes, language: str) -> str: + fields = { + "response_format": self._response_format, + "language": language, + } + if self._detect_language: + fields["detect_language"] = "true" + + body, content_type = _encode_multipart_formdata( + fields=fields, + files={"file": ("input.wav", wav_bytes, "audio/wav")}, + ) + request = Request( + self._server_url, + data=body, + headers={"Content-Type": content_type}, + method="POST", + ) + logger.info("Running whisper-server request: POST %s", self._server_url) + try: + with urlopen(request, timeout=self._request_timeout_seconds) as response: + response_body = response.read() + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace").strip() + raise RuntimeError( + f"whisper-server failed: status={exc.code} body={detail or ''}" + ) from exc + except URLError as exc: + raise RuntimeError(f"whisper-server request failed: {exc.reason}") from exc + + if self._response_format == "json": + payload = json.loads(response_body.decode("utf-8")) + text = payload.get("text") + return text.strip() if isinstance(text, str) else "" + + payload = json.loads(response_body.decode("utf-8")) + return _load_transcript_from_verbose_json(payload) + + +def _default_server_url() -> str: + configured = os.getenv("STACKCHAN_WHISPER_SERVER_URL") + if configured: + return configured.rstrip("/") + port = os.getenv("STACKCHAN_WHISPER_SERVER_PORT", str(_DEFAULT_SERVER_PORT)) + return f"http://127.0.0.1:{port}/inference" + + +def _normalize_language(language_code: str) -> str: + if not language_code: + return "" + return language_code.split("-", 1)[0].lower() + + +def _load_transcript_from_verbose_json(payload: object) -> str: + if not isinstance(payload, Mapping): + return "" + payload = cast(Mapping[str, object], payload) + transcription = payload.get("transcription") + if not isinstance(transcription, list): + text = payload.get("text") + return text.strip() if isinstance(text, str) else "" + parts: list[str] = [] + for item in transcription: + if not isinstance(item, Mapping): + continue + item = cast(Mapping[str, object], item) + text = item.get("text") + if isinstance(text, str): + normalized = text.strip() + if normalized: + parts.append(normalized) + return " ".join(parts).strip() + + +def _pcm_rms_level(pcm_bytes: bytes) -> float: + if len(pcm_bytes) < 2: + return 0.0 + sample_count = len(pcm_bytes) // 2 + total = 0.0 + for index in range(0, sample_count * 2, 2): + sample = int.from_bytes(pcm_bytes[index : index + 2], byteorder="little", signed=True) + total += float(sample * sample) + return math.sqrt(total / sample_count) + + +def _wrap_pcm_as_wav( + pcm_bytes: bytes, + *, + sample_rate_hz: int, + channels: int, + sample_width: int, +) -> bytes: + import io + import wave + + with io.BytesIO() as buffer: + with wave.open(buffer, "wb") as wav_fp: + wav_fp.setnchannels(channels) + wav_fp.setsampwidth(sample_width) + wav_fp.setframerate(sample_rate_hz) + wav_fp.writeframes(pcm_bytes) + return buffer.getvalue() + + +def _encode_multipart_formdata( + *, + fields: dict[str, str], + files: dict[str, tuple[str, bytes, str]], +) -> tuple[bytes, str]: + boundary = f"----stackchan-{uuid.uuid4().hex}" + boundary_bytes = boundary.encode("ascii") + lines: list[bytes] = [] + + for key, value in fields.items(): + lines.extend( + [ + b"--" + boundary_bytes, + f'Content-Disposition: form-data; name="{key}"'.encode("utf-8"), + b"", + value.encode("utf-8"), + ] + ) + + for field_name, (filename, content, content_type) in files.items(): + guessed_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + lines.extend( + [ + b"--" + boundary_bytes, + ( + f'Content-Disposition: form-data; name="{field_name}"; filename="{Path(filename).name}"' + ).encode("utf-8"), + f"Content-Type: {guessed_type}".encode("utf-8"), + b"", + content, + ] + ) + + lines.append(b"--" + boundary_bytes + b"--") + lines.append(b"") + body = b"\r\n".join(lines) + return body, f"multipart/form-data; boundary={boundary}" + + +__all__ = ["WhisperServerSpeechToText"] diff --git a/stackchan_server/static.py b/stackchan_server/static.py new file mode 100644 index 0000000..872f1dc --- /dev/null +++ b/stackchan_server/static.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from .types import AudioFormat + +LISTEN_AUDIO_FORMAT = AudioFormat( + sample_rate_hz=16000, + channels=1, + sample_width=2, +) +LISTEN_LANGUAGE_CODE = "ja-JP" + +__all__ = ["LISTEN_AUDIO_FORMAT", "LISTEN_LANGUAGE_CODE"] diff --git a/stackchan_server/types.py b/stackchan_server/types.py index ab63bad..326b8b2 100644 --- a/stackchan_server/types.py +++ b/stackchan_server/types.py @@ -6,15 +6,7 @@ @runtime_checkable class SpeechRecognizer(Protocol): - async def transcribe( - self, - pcm_bytes: bytes, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str = "ja-JP", - ) -> str: ... + async def transcribe(self, pcm_bytes: bytes) -> str: ... @runtime_checkable @@ -28,14 +20,7 @@ async def abort(self) -> None: ... @runtime_checkable class StreamingSpeechRecognizer(SpeechRecognizer, Protocol): - async def start_stream( - self, - *, - sample_rate_hz: int, - channels: int, - sample_width: int, - language_code: str = "ja-JP", - ) -> StreamingSpeechSession: ... + async def start_stream(self) -> StreamingSpeechSession: ... @runtime_checkable diff --git a/stackchan_server/ws_proxy.py b/stackchan_server/ws_proxy.py index 1b9bd3b..5ead596 100644 --- a/stackchan_server/ws_proxy.py +++ b/stackchan_server/ws_proxy.py @@ -13,6 +13,7 @@ from .listen import EmptyTranscriptError, ListenHandler, TimeoutError from .speak import SpeakHandler +from .static import LISTEN_AUDIO_FORMAT from .types import SpeechRecognizer, SpeechSynthesizer logger = getLogger(__name__) @@ -23,10 +24,6 @@ _WS_HEADER_FMT = "