Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion example_apps/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from logging import getLogger

from stackchan_server.app import StackChanApp
from stackchan_server.speech_recognition import WhisperCppSpeechToText
from stackchan_server.speech_recognition import WhisperCppSpeechToText, WhisperServerSpeechToText
from stackchan_server.speech_synthesis import VoiceVoxSpeechSynthesizer
from stackchan_server.ws_proxy import EmptyTranscriptError, WsProxy

Expand All @@ -17,7 +17,14 @@
)

def _create_app() -> StackChanApp:
whisper_server_url = os.getenv("STACKCHAN_WHISPER_SERVER_URL")
whisper_server_port = os.getenv("STACKCHAN_WHISPER_SERVER_PORT")
whisper_model = os.getenv("STACKCHAN_WHISPER_MODEL")
if whisper_server_url or whisper_server_port:
return StackChanApp(
speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url),
speech_synthesizer=VoiceVoxSpeechSynthesizer(),
)
if whisper_model:
return StackChanApp(
speech_recognizer=WhisperCppSpeechToText(
Expand Down
15 changes: 15 additions & 0 deletions misc/whisper-server/run-whisper-server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash
# Launch whisper.cpp's whisper-server with VAD enabled, configured via
# STACKCHAN_* environment variables.
set -euo pipefail
set -x

# Fail fast with a clear message when required configuration is missing,
# instead of passing an empty argument to whisper-server.
: "${STACKCHAN_WHISPER_SERVER_PORT:?STACKCHAN_WHISPER_SERVER_PORT must be set}"
: "${STACKCHAN_WHISPER_MODEL:?STACKCHAN_WHISPER_MODEL must be set}"
: "${STACKCHAN_WHISPER_VAD_MODEL:?STACKCHAN_WHISPER_VAD_MODEL must be set}"

# Quote all expansions to avoid word splitting / globbing on paths with spaces.
whisper-server \
  --host 0.0.0.0 \
  --port "${STACKCHAN_WHISPER_SERVER_PORT}" \
  -m "${STACKCHAN_WHISPER_MODEL}" \
  -l ja \
  -nt \
  --vad \
  -vm "${STACKCHAN_WHISPER_VAD_MODEL}" \
  -vt 0.6 \
  -vspd 250 \
  -vsd 400 \
  -vp 30
47 changes: 16 additions & 31 deletions stackchan_server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from fastapi import WebSocket, WebSocketDisconnect

from .static import LISTEN_AUDIO_FORMAT
from .types import SpeechRecognizer, StreamingSpeechRecognizer, StreamingSpeechSession

logger = getLogger(__name__)
Expand All @@ -29,20 +30,13 @@ def __init__(
speech_recognizer: SpeechRecognizer,
recordings_dir: Path,
debug_recording: bool,
sample_rate_hz: int,
channels: int,
sample_width: int,
listen_audio_timeout_seconds: float,
language_code: str = "ja-JP",
) -> None:
self.speech_recognizer = speech_recognizer
self.recordings_dir = recordings_dir
self.debug_recording = debug_recording
self.sample_rate_hz = sample_rate_hz
self.channels = channels
self.sample_width = sample_width
self.audio_format = LISTEN_AUDIO_FORMAT
self.listen_audio_timeout_seconds = listen_audio_timeout_seconds
self.language_code = language_code

self._pcm_buffer = bytearray()
self._streaming = False
Expand Down Expand Up @@ -96,12 +90,7 @@ async def handle_start(self, websocket: WebSocket) -> bool:
self._message_error = None
if isinstance(self.speech_recognizer, StreamingSpeechRecognizer):
try:
self._speech_stream = await self.speech_recognizer.start_stream(
sample_rate_hz=self.sample_rate_hz,
channels=self.channels,
sample_width=self.sample_width,
language_code=self.language_code,
)
self._speech_stream = await self.speech_recognizer.start_stream()
except Exception:
asyncio.create_task(websocket.close(code=1011, reason="speech streaming failed"))
return False
Expand All @@ -113,7 +102,7 @@ async def handle_data(self, websocket: WebSocket, payload_bytes: int, payload: b
await self._abort_speech_stream()
asyncio.create_task(websocket.close(code=1003, reason="data received before start"))
return False
if payload_bytes % (self.sample_width * self.channels) != 0:
if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0:
await self._abort_speech_stream()
asyncio.create_task(websocket.close(code=1003, reason="invalid pcm chunk length"))
return False
Expand Down Expand Up @@ -142,7 +131,7 @@ async def handle_end(
await self._abort_speech_stream()
await websocket.close(code=1003, reason="end received before start")
return
if payload_bytes % (self.sample_width * self.channels) != 0:
if payload_bytes % (self.audio_format.sample_width * self.audio_format.channels) != 0:
await self._abort_speech_stream()
await websocket.close(code=1003, reason="invalid pcm tail length")
return
Expand All @@ -155,19 +144,21 @@ async def handle_end(
await websocket.close(code=1011, reason="speech streaming failed")
return

if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % (self.sample_width * self.channels) != 0:
if len(self._pcm_buffer) == 0 or len(self._pcm_buffer) % (
self.audio_format.sample_width * self.audio_format.channels
) != 0:
await self._abort_speech_stream()
await websocket.close(code=1003, reason="invalid accumulated pcm length")
return

await send_state_command(thinking_state)

frames = len(self._pcm_buffer) // (self.sample_width * self.channels)
duration_seconds = frames / float(self.sample_rate_hz)
frames = len(self._pcm_buffer) // (self.audio_format.sample_width * self.audio_format.channels)
duration_seconds = frames / float(self.audio_format.sample_rate_hz)
ws_meta = {
"sample_rate": self.sample_rate_hz,
"sample_rate": self.audio_format.sample_rate_hz,
"frames": frames,
"channels": self.channels,
"channels": self.audio_format.channels,
"duration_seconds": round(duration_seconds, 3),
}
if self.debug_recording:
Expand Down Expand Up @@ -197,9 +188,9 @@ def _save_wav(self, pcm_bytes: bytes) -> tuple[Path, str]:
filepath = self.recordings_dir / filename

with wave.open(str(filepath), "wb") as wav_fp:
wav_fp.setnchannels(self.channels)
wav_fp.setsampwidth(self.sample_width)
wav_fp.setframerate(self.sample_rate_hz)
wav_fp.setnchannels(self.audio_format.channels)
wav_fp.setsampwidth(self.audio_format.sample_width)
wav_fp.setframerate(self.audio_format.sample_rate_hz)
wav_fp.writeframes(pcm_bytes)

logger.info("Saved WAV: %s", filename)
Expand All @@ -211,13 +202,7 @@ async def _transcribe_async(self, pcm_bytes: bytes) -> str:
return await self._transcribe(pcm_bytes)

async def _transcribe(self, pcm_bytes: bytes) -> str:
transcript = await self.speech_recognizer.transcribe(
pcm_bytes,
sample_rate_hz=self.sample_rate_hz,
channels=self.channels,
sample_width=self.sample_width,
language_code=self.language_code,
)
transcript = await self.speech_recognizer.transcribe(pcm_bytes)
if transcript:
logger.info("Transcript: %s", transcript)
return transcript
Expand Down
8 changes: 7 additions & 1 deletion stackchan_server/speech_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
from ..types import SpeechRecognizer
from .google_cloud import GoogleCloudSpeechToText
from .whisper_cpp import WhisperCppSpeechToText
from .whisper_server import WhisperServerSpeechToText


def create_speech_recognizer() -> SpeechRecognizer:
    """Return the default speech recognizer (Google Cloud Speech-to-Text)."""
    recognizer: SpeechRecognizer = GoogleCloudSpeechToText()
    return recognizer


# Public API of the speech_recognition package.
__all__ = [
    "GoogleCloudSpeechToText",
    "WhisperCppSpeechToText",
    "WhisperServerSpeechToText",
    "create_speech_recognizer",
]
55 changes: 8 additions & 47 deletions stackchan_server/speech_recognition/google_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from google.cloud import speech

from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE
from ..types import StreamingSpeechRecognizer, StreamingSpeechSession

logger = getLogger(__name__)
Expand All @@ -15,25 +16,13 @@ class _GoogleCloudStreamingSession(StreamingSpeechSession):
def __init__(
self,
client: speech.SpeechAsyncClient,
*,
sample_rate_hz: int,
channels: int,
sample_width: int,
language_code: str,
) -> None:
if channels != 1:
raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}")
if sample_width != 2:
raise ValueError(
f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}"
)

self._client = client
self._config = speech.StreamingRecognitionConfig(
config=speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
sample_rate_hertz=sample_rate_hz,
language_code=language_code,
sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
language_code=LISTEN_LANGUAGE_CODE,
),
interim_results=False,
single_utterance=False,
Expand Down Expand Up @@ -109,47 +98,19 @@ class GoogleCloudSpeechToText(StreamingSpeechRecognizer):
def __init__(self, client: speech.SpeechAsyncClient | None = None) -> None:
self._client = client or speech.SpeechAsyncClient()

async def transcribe(
self,
pcm_bytes: bytes,
*,
sample_rate_hz: int,
channels: int,
sample_width: int,
language_code: str = "ja-JP",
) -> str:
if channels != 1:
raise ValueError(f"Google Cloud Speech only supports mono input here: channels={channels}")
if sample_width != 2:
raise ValueError(
f"Google Cloud Speech LINEAR16 requires 16-bit samples here: sample_width={sample_width}"
)

async def transcribe(self, pcm_bytes: bytes) -> str:
audio = speech.RecognitionAudio(content=pcm_bytes)
config = speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
sample_rate_hertz=sample_rate_hz,
language_code=language_code,
sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
language_code=LISTEN_LANGUAGE_CODE,
)
response = await self._client.recognize(config=config, audio=audio)

return "".join(result.alternatives[0].transcript for result in response.results)

async def start_stream(
self,
*,
sample_rate_hz: int,
channels: int,
sample_width: int,
language_code: str = "ja-JP",
) -> StreamingSpeechSession:
return _GoogleCloudStreamingSession(
self._client,
sample_rate_hz=sample_rate_hz,
channels=channels,
sample_width=sample_width,
language_code=language_code,
)
async def start_stream(self) -> StreamingSpeechSession:
return _GoogleCloudStreamingSession(self._client)


__all__ = ["GoogleCloudSpeechToText"]
30 changes: 11 additions & 19 deletions stackchan_server/speech_recognition/whisper_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from logging import getLogger
from pathlib import Path

from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE
from ..types import SpeechRecognizer

logger = getLogger(__name__)
Expand All @@ -26,7 +27,7 @@ class WhisperCppSpeechToText(SpeechRecognizer):
def __init__(
self,
*,
model_path: str | Path,
model_path: str | Path | None = None,
cli_path: str = "whisper-cli",
threads: int | None = None,
translate: bool = False,
Expand All @@ -40,7 +41,10 @@ def __init__(
vad_speech_pad_ms: int = _DEFAULT_VAD_SPEECH_PAD_MS,
silence_rms_threshold: float = _DEFAULT_SILENCE_RMS_THRESHOLD,
) -> None:
self._model_path = Path(model_path)
resolved_model_path = model_path or os.getenv("STACKCHAN_WHISPER_MODEL")
if not resolved_model_path:
raise ValueError("whisper.cpp model_path is required or set STACKCHAN_WHISPER_MODEL")
self._model_path = Path(resolved_model_path)
self._cli_path = cli_path
self._threads = threads
self._translate = translate
Expand All @@ -54,19 +58,7 @@ def __init__(
self._vad_speech_pad_ms = vad_speech_pad_ms
self._silence_rms_threshold = silence_rms_threshold

async def transcribe(
self,
pcm_bytes: bytes,
*,
sample_rate_hz: int,
channels: int,
sample_width: int,
language_code: str = "ja-JP",
) -> str:
if channels != 1:
raise ValueError(f"whisper.cpp only supports mono input here: channels={channels}")
if sample_width != 2:
raise ValueError(f"whisper.cpp expects 16-bit PCM here: sample_width={sample_width}")
async def transcribe(self, pcm_bytes: bytes) -> str:
if not self._model_path.is_file():
raise FileNotFoundError(f"whisper.cpp model not found: {self._model_path}")
if _pcm_rms_level(pcm_bytes) < self._silence_rms_threshold:
Expand All @@ -81,7 +73,7 @@ async def transcribe(
if cli_path is None:
raise FileNotFoundError(f"whisper.cpp CLI not found in PATH: {self._cli_path}")

language = _normalize_language(language_code)
language = _normalize_language(LISTEN_LANGUAGE_CODE)
with tempfile.TemporaryDirectory(prefix="stackchan_whisper_") as temp_dir_name:
temp_dir = Path(temp_dir_name)
wav_path = temp_dir / "input.wav"
Expand All @@ -90,9 +82,9 @@ async def transcribe(
_write_wav(
wav_path,
pcm_bytes,
sample_rate_hz=sample_rate_hz,
channels=channels,
sample_width=sample_width,
sample_rate_hz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
channels=LISTEN_AUDIO_FORMAT.channels,
sample_width=LISTEN_AUDIO_FORMAT.sample_width,
)

command = [
Expand Down
Loading