diff --git a/example_apps/echo.py b/example_apps/echo.py index 5185c92..8fb47f3 100644 --- a/example_apps/echo.py +++ b/example_apps/echo.py @@ -1,14 +1,18 @@ from __future__ import annotations -from logging import StreamHandler, getLogger +import logging +import os +from logging import getLogger from stackchan_server.app import StackChanApp from stackchan_server.ws_proxy import EmptyTranscriptError, WsProxy - logger = getLogger(__name__) -logger.addHandler(StreamHandler()) -logger.setLevel("DEBUG") +logging.basicConfig( + level=os.getenv("STACKCHAN_LOG_LEVEL", "INFO"), + format="%(asctime)s.%(msecs)03d %(levelname)s:%(name)s:%(message)s", + datefmt="%H:%M:%S", +) app = StackChanApp() diff --git a/misc/api_test/test_cloud_sst.py b/misc/api_test/test_cloud_sst.py index 643f216..f38cc43 100644 --- a/misc/api_test/test_cloud_sst.py +++ b/misc/api_test/test_cloud_sst.py @@ -1,4 +1,3 @@ -import os from google.cloud import speech sst_client = speech.SpeechClient() diff --git a/pyproject.toml b/pyproject.toml index 220a0cb..7458d32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "fastapi>=0.128.0", + "google-genai>=1.59.0", "google-cloud-speech>=2.35.0", "uvicorn[standard]>=0.40.0", "voicevox-client>=1.1.0", @@ -20,9 +21,7 @@ dev = [ "ruff>=0.15.2", "ty>=0.0.17", ] -example-gemini = [ - "google-genai>=1.59.0", -] +example-gemini = [] example-claude-agent-sdk = [ "claude-agent-sdk>=0.1.39", ] diff --git a/stackchan_server/app.py b/stackchan_server/app.py index c192556..46a3006 100644 --- a/stackchan_server/app.py +++ b/stackchan_server/app.py @@ -7,15 +7,21 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect from .speech_recognition import create_speech_recognizer -from .types import SpeechRecognizer +from .speech_synthesis import create_speech_synthesizer +from .types import SpeechRecognizer, SpeechSynthesizer from .ws_proxy import WsProxy logger = getLogger(__name__) class StackChanApp: - def __init__(self, speech_recognizer: SpeechRecognizer | None = None) -> None: + def __init__( + self, + speech_recognizer: SpeechRecognizer | None = None, + speech_synthesizer: SpeechSynthesizer | None = None, + ) -> None: self.speech_recognizer = speech_recognizer or create_speech_recognizer() + self.speech_synthesizer = speech_synthesizer or create_speech_synthesizer() self.fastapi = FastAPI(title="StackChan WebSocket Server") self._setup_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None self._talk_session_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None @@ -38,7 +44,11 @@ def talk_session(self, fn: Callable[["WsProxy"], Awaitable[None]]): async def _handle_ws(self, websocket: WebSocket) -> None: await websocket.accept() - proxy = WsProxy(websocket, speech_recognizer=self.speech_recognizer) + proxy = WsProxy( + websocket, + speech_recognizer=self.speech_recognizer, + speech_synthesizer=self.speech_synthesizer, + ) await proxy.start() try: if self._setup_fn: diff --git a/stackchan_server/speak.py b/stackchan_server/speak.py new file mode 100644 index 0000000..b748a37 --- /dev/null +++ b/stackchan_server/speak.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import asyncio +import io +import struct +import wave +from datetime import UTC, datetime +from logging import getLogger +from pathlib import Path +from typing import Awaitable, Callable + +from fastapi import WebSocket, WebSocketDisconnect + +from .listen import TimeoutError +from .types import AudioFormat, SpeechSynthesizer, StreamingSpeechSynthesizer + +logger = getLogger(__name__) + + +class SpeakHandler: + def __init__( + self, + *, + websocket: WebSocket, + ws_header_fmt: str, + wav_kind: int, + start_msg_type: int, + data_msg_type: int, + end_msg_type: int, + down_wav_chunk: int, + down_segment_millis: int, + down_segment_stagger_millis: int, + sample_width: int, + speech_synthesizer: SpeechSynthesizer, + recordings_dir: Path, + debug_recording: bool, + ) -> None: + self.ws = websocket + self.ws_header_fmt = ws_header_fmt + self.wav_kind = wav_kind + self.start_msg_type = start_msg_type + self.data_msg_type = data_msg_type + self.end_msg_type = end_msg_type + self.down_wav_chunk = down_wav_chunk + self.down_segment_millis = down_segment_millis + self.down_segment_stagger_millis = down_segment_stagger_millis + self.sample_width = sample_width + self.speech_synthesizer = speech_synthesizer + self.recordings_dir = recordings_dir + self.debug_recording = debug_recording + + self._speaking = False + self._speak_finished_counter = 0 + + @property + def speaking(self) -> bool: + return self._speaking + + def handle_speak_done_event(self) -> None: + self._speak_finished_counter += 1 + self._speaking = False + logger.info("Received speak done event") + + async def speak( + self, + text: str, + *, + next_seq: Callable[[], int], + send_state_command: Callable[[int], Awaitable[None]], + idle_state: int, + is_closed: Callable[[], bool], + ) -> None: + start_counter = self._speak_finished_counter + await self._start_talking_stream(text, next_seq=next_seq) + if not self._speaking: + return + await self._wait_for_speaking_finished( + min_counter=start_counter + 1, + timeout_seconds=120.0, + is_closed=is_closed, + ) + if not is_closed(): + await send_state_command(idle_state) + + async def _wait_for_speaking_finished( + self, + *, + min_counter: int, + timeout_seconds: float | None, + is_closed: Callable[[], bool], + ) -> None: + loop = asyncio.get_running_loop() + deadline = (loop.time() + timeout_seconds) if timeout_seconds else None + while True: + if self._speak_finished_counter >= min_counter: + return + if is_closed(): + raise WebSocketDisconnect() + if deadline and loop.time() >= deadline: + raise TimeoutError("Timed out waiting for speaking finished event") + await asyncio.sleep(0.05) + + async def _start_talking_stream(self, text: str, *, next_seq: Callable[[], int]) -> None: + self._speaking = True + try: + if isinstance(self.speech_synthesizer, StreamingSpeechSynthesizer): + await self._start_talking_streaming( + text, + self.speech_synthesizer, + next_seq=next_seq, + ) + return + wav_bytes = await self.speech_synthesizer.synthesize(text) + logger.info("Synthesized wav_bytes=%d text_chars=%d", len(wav_bytes), len(text)) + pcm_bytes, tts_sample_rate, tts_channels, tts_sample_width = self._extract_pcm(wav_bytes) + logger.info( + "Synthesized audio sample_rate=%d channels=%d sample_width=%d pcm_bytes=%d", + tts_sample_rate, + tts_channels, + tts_sample_width, + len(pcm_bytes), + ) + if len(pcm_bytes) == 0: + logger.warning("Synthesized audio is empty") + self._speaking = False + return + + if tts_sample_width != self.sample_width: + await self.ws.send_json({"error": f"unsupported sample width {tts_sample_width}"}) + self._speaking = False + return + + if self.debug_recording: + filepath, filename = self._save_wav(wav_bytes) + logger.info("Saved synthesized WAV: %s", filename) + await self.ws.send_json({"tts_debug_path": f"recordings/{filename}", "tts_debug_bytes": len(wav_bytes)}) + + bytes_per_second = tts_sample_rate * tts_channels * tts_sample_width + segment_bytes = int(bytes_per_second * (self.down_segment_millis / 1000)) + + if segment_bytes <= 0: + await self.ws.send_json({"error": "invalid segment size computed"}) + self._speaking = False + return + + await self._send_segments( + pcm_bytes, + tts_sample_rate, + tts_channels, + segment_bytes, + next_seq=next_seq, + ) + except Exception as exc: # pragma: no cover + self._speaking = False + logger.exception("Speech synthesis failed") + await self.ws.send_json({"error": f"speech synthesis failed: {exc}"}) + + async def _start_talking_streaming( + self, + text: str, + speech_synthesizer: StreamingSpeechSynthesizer, + *, + next_seq: Callable[[], int], + ) -> None: + output_format = speech_synthesizer.output_format + logger.info( + "Streaming synthesized audio sample_rate=%d channels=%d sample_width=%d", + output_format.sample_rate_hz, + output_format.channels, + output_format.sample_width, + ) + if output_format.sample_width != self.sample_width: + await self.ws.send_json({"error": f"unsupported sample width {output_format.sample_width}"}) + self._speaking = False + return + + bytes_per_second = ( + output_format.sample_rate_hz * output_format.channels * output_format.sample_width + ) + segment_bytes = int(bytes_per_second * (self.down_segment_millis / 1000)) + if segment_bytes <= 0: + await self.ws.send_json({"error": "invalid segment size computed"}) + self._speaking = False + return + + pending = bytearray() + saved_pcm = bytearray() + segment_count = 0 + base_time: float | None = None + async for chunk in speech_synthesizer.synthesize_stream(text): + pending.extend(chunk) + if self.debug_recording: + saved_pcm.extend(chunk) + while len(pending) >= segment_bytes: + segment = bytes(pending[:segment_bytes]) + del pending[:segment_bytes] + base_time = await self._wait_for_segment_slot(segment_count, base_time=base_time) + await self._send_segment( + segment, + output_format.sample_rate_hz, + output_format.channels, + next_seq=next_seq, + ) + segment_count += 1 + if pending: + base_time = await self._wait_for_segment_slot(segment_count, base_time=base_time) + await self._send_segment( + bytes(pending), + output_format.sample_rate_hz, + output_format.channels, + next_seq=next_seq, + ) + segment_count += 1 + logger.info("Prepared %d playback segments from streaming TTS", segment_count) + + if self.debug_recording and saved_pcm: + wav_bytes = self._wrap_pcm_as_wav(bytes(saved_pcm), output_format) + filepath, filename = self._save_wav(wav_bytes) + logger.info("Saved synthesized WAV: %s", filename) + await self.ws.send_json({"tts_debug_path": f"recordings/{filename}", "tts_debug_bytes": len(wav_bytes)}) + + if segment_count == 0: + logger.warning("Synthesized audio is empty") + self._speaking = False + + def _extract_pcm(self, wav_bytes: bytes) -> tuple[bytes, int, int, int]: + with wave.open(io.BytesIO(wav_bytes), "rb") as wf: + pcm_bytes = wf.readframes(wf.getnframes()) + tts_sample_rate = wf.getframerate() + tts_channels = wf.getnchannels() + tts_sample_width = wf.getsampwidth() + return pcm_bytes, tts_sample_rate, tts_channels, tts_sample_width + + def _save_wav(self, wav_bytes: bytes) -> tuple[Path, str]: + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + filename = f"tts_ws_{timestamp}.wav" + filepath = self.recordings_dir / filename + filepath.write_bytes(wav_bytes) + return filepath, filename + + def _wrap_pcm_as_wav(self, pcm_bytes: bytes, audio_format: AudioFormat) -> bytes: + with io.BytesIO() as buffer: + with wave.open(buffer, "wb") as wav_fp: + wav_fp.setnchannels(audio_format.channels) + wav_fp.setsampwidth(audio_format.sample_width) + wav_fp.setframerate(audio_format.sample_rate_hz) + wav_fp.writeframes(pcm_bytes) + return buffer.getvalue() + + async def _wait_for_segment_slot(self, segment_index: int, *, base_time: float | None) -> float: + loop = asyncio.get_running_loop() + if base_time is None: + return loop.time() + + if segment_index == 0: + target_ms = 0 + elif segment_index == 1: + target_ms = self.down_segment_stagger_millis + else: + target_ms = self.down_segment_stagger_millis + (segment_index - 1) * self.down_segment_millis + + target_time = base_time + target_ms / 1000 + now = loop.time() + if target_time > now: + await asyncio.sleep(target_time - now) + return base_time + + async def _send_segments( + self, + pcm_bytes: bytes, + tts_sample_rate: int, + tts_channels: int, + segment_bytes: int, + *, + next_seq: Callable[[], int], + ) -> None: + segments: list[bytes] = [] + offset = 0 + total = len(pcm_bytes) + while offset < total: + segments.append(pcm_bytes[offset : offset + segment_bytes]) + offset += segment_bytes + logger.info("Prepared %d playback segments", len(segments)) + + loop = asyncio.get_running_loop() + base_time = loop.time() + + for idx, segment in enumerate(segments): + if idx == 0: + target_ms = 0 + elif idx == 1: + target_ms = self.down_segment_stagger_millis + else: + target_ms = self.down_segment_stagger_millis + (idx - 1) * self.down_segment_millis + + target_time = base_time + target_ms / 1000 + now = loop.time() + if target_time > now: + await asyncio.sleep(target_time - now) + + await self._send_segment(segment, tts_sample_rate, tts_channels, next_seq=next_seq) + + async def _send_segment( + self, + segment_pcm: bytes, + tts_sample_rate: int, + tts_channels: int, + *, + next_seq: Callable[[], int], + ) -> None: + logger.info("Sending segment bytes=%d", len(segment_pcm)) + start_payload = struct.pack(" SpeechSynthesizer: + return VoiceVoxSpeechSynthesizer() + + +__all__ = ["GoogleCloudTextToSpeech", "VoiceVoxSpeechSynthesizer", "create_speech_synthesizer"] diff --git a/stackchan_server/speech_synthesis/google_cloud.py b/stackchan_server/speech_synthesis/google_cloud.py new file mode 100644 index 0000000..8f3feab --- /dev/null +++ b/stackchan_server/speech_synthesis/google_cloud.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import io +import os +import wave +from collections.abc import AsyncIterator +from logging import getLogger +from typing import Any + +from google import genai +from google.genai import types + +from ..types import AudioFormat, StreamingSpeechSynthesizer + +logger = getLogger(__name__) + +_DEFAULT_MODEL = "gemini-2.5-flash-tts" +_DEFAULT_LOCATION = "global" +_PCM_SAMPLE_RATE_HZ = 24000 +_PCM_CHANNELS = 1 +_PCM_SAMPLE_WIDTH = 2 +_OUTPUT_FORMAT = AudioFormat( + sample_rate_hz=_PCM_SAMPLE_RATE_HZ, + channels=_PCM_CHANNELS, + sample_width=_PCM_SAMPLE_WIDTH, +) + + +def create_vertexai_client() -> Any: + return genai.Client( + vertexai=True, + project=os.getenv("GOOGLE_CLOUD_PROJECT"), + location=os.getenv("GOOGLE_CLOUD_LOCATION") or os.getenv("GOOGLE_CLOUD_REGION") or _DEFAULT_LOCATION, + ).aio + + +class GoogleCloudTextToSpeech(StreamingSpeechSynthesizer): + def __init__( + self, + *, + model: str = _DEFAULT_MODEL, + language_code: str = "ja-JP", + voice_name: str = "Despina", + style_instructions: str | None = None, + client: Any | None = None, + ) -> None: + self._model = model + self._language_code = language_code + self._voice_name = voice_name + self._style_instructions = style_instructions + self._client = client or create_vertexai_client() + + @property + def output_format(self) -> AudioFormat: + return _OUTPUT_FORMAT + + async def synthesize(self, text: str) -> bytes: + pcm_bytes = bytearray() + async for chunk in self.synthesize_stream(text): + pcm_bytes.extend(chunk) + logger.info( + "Gemini TTS response pcm_bytes=%d model=%s language_code=%s voice_name=%s", + len(pcm_bytes), + self._model, + self._language_code, + self._voice_name, + ) + return self._wrap_pcm_as_wav(bytes(pcm_bytes)) + + async def synthesize_stream(self, text: str) -> AsyncIterator[bytes]: + logger.info( + "Requesting Gemini TTS model=%s language_code=%s voice_name=%s text_chars=%d", + self._model, + self._language_code, + self._voice_name, + len(text), + ) + async for response in await self._client.models.generate_content_stream( + model=self._model, + contents=self._build_contents(text), + config=types.GenerateContentConfig( + response_modalities=["AUDIO"], + speech_config=types.SpeechConfig( + language_code=self._language_code, + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name=self._voice_name, + ) + ), + ), + ), + ): + chunk = self._extract_audio_bytes(response) + if chunk: + yield chunk + + def _build_contents(self, text: str) -> str: + if not self._style_instructions: + return text + return f"{self._style_instructions}\n\n{text}" + + def _extract_audio_bytes(self, response: types.GenerateContentResponse) -> bytes: + pcm_bytes = bytearray() + if not response.candidates: + return b"" + for candidate in response.candidates: + if not candidate.content or not candidate.content.parts: + continue + for part in candidate.content.parts: + if part.inline_data and isinstance(part.inline_data.data, bytes): + pcm_bytes.extend(part.inline_data.data) + return bytes(pcm_bytes) + + def _wrap_pcm_as_wav(self, pcm_bytes: bytes) -> bytes: + with io.BytesIO() as buffer: + with wave.open(buffer, "wb") as wav_fp: + wav_fp.setnchannels(_PCM_CHANNELS) + wav_fp.setsampwidth(_PCM_SAMPLE_WIDTH) + wav_fp.setframerate(_PCM_SAMPLE_RATE_HZ) + wav_fp.writeframes(pcm_bytes) + return buffer.getvalue() + + +__all__ = ["GoogleCloudTextToSpeech", "create_vertexai_client"] diff --git a/stackchan_server/speech_synthesis/voicevox.py b/stackchan_server/speech_synthesis/voicevox.py new file mode 100644 index 0000000..8f1f1c9 --- /dev/null +++ b/stackchan_server/speech_synthesis/voicevox.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import os + +from vvclient import Client as VVClient + +from ..types import SpeechSynthesizer + + +def create_voicevox_client() -> VVClient: + voicevox_url = os.getenv("STACKCHAN_VOICEVOX_URL", "http://localhost:50021") + return VVClient(base_uri=voicevox_url) + + +class VoiceVoxSpeechSynthesizer(SpeechSynthesizer): + def __init__(self, speaker: int = 29) -> None: + self._speaker = speaker + + async def synthesize(self, text: str) -> bytes: + async with create_voicevox_client() as client: + audio_query = await client.create_audio_query(text, speaker=self._speaker) + return await audio_query.synthesis(speaker=self._speaker) + + +__all__ = ["VoiceVoxSpeechSynthesizer", "create_voicevox_client"] diff --git a/stackchan_server/types.py b/stackchan_server/types.py index 041f7d1..ab63bad 100644 --- a/stackchan_server/types.py +++ b/stackchan_server/types.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Protocol, runtime_checkable +from dataclasses import dataclass +from typing import AsyncIterator, Protocol, runtime_checkable @runtime_checkable @@ -37,4 +38,31 @@ async def start_stream( ) -> StreamingSpeechSession: ... -__all__ = ["SpeechRecognizer", "StreamingSpeechRecognizer", "StreamingSpeechSession"] +@runtime_checkable +class SpeechSynthesizer(Protocol): + async def synthesize(self, text: str) -> bytes: ... + + +@dataclass(frozen=True) +class AudioFormat: + sample_rate_hz: int + channels: int + sample_width: int + + +@runtime_checkable +class StreamingSpeechSynthesizer(SpeechSynthesizer, Protocol): + @property + def output_format(self) -> AudioFormat: ... + + def synthesize_stream(self, text: str) -> AsyncIterator[bytes]: ... + + +__all__ = [ + "AudioFormat", + "SpeechRecognizer", + "StreamingSpeechRecognizer", + "StreamingSpeechSession", + "SpeechSynthesizer", + "StreamingSpeechSynthesizer", +] diff --git a/stackchan_server/ws_proxy.py b/stackchan_server/ws_proxy.py index 596284b..1b9bd3b 100644 --- a/stackchan_server/ws_proxy.py +++ b/stackchan_server/ws_proxy.py @@ -1,10 +1,8 @@ from __future__ import annotations import asyncio -import io import os import struct -import wave from contextlib import suppress from enum import IntEnum from logging import getLogger @@ -12,10 +10,10 @@ from typing import Optional from fastapi import WebSocket, WebSocketDisconnect -from vvclient import Client as VVClient from .listen import EmptyTranscriptError, ListenHandler, TimeoutError -from .types import SpeechRecognizer +from .speak import SpeakHandler +from .types import SpeechRecognizer, SpeechSynthesizer logger = getLogger(__name__) @@ -57,16 +55,16 @@ class _WsMsgType(IntEnum): DATA = 2 END = 3 - -def create_voicevox_client() -> VVClient: - voicevox_url = os.getenv("STACKCHAN_VOICEVOX_URL", "http://localhost:50021") - return VVClient(base_uri=voicevox_url) - - class WsProxy: - def __init__(self, websocket: WebSocket, speech_recognizer: SpeechRecognizer): + def __init__( + self, + websocket: WebSocket, + speech_recognizer: SpeechRecognizer, + speech_synthesizer: SpeechSynthesizer, + ): self.ws = websocket self.speech_recognizer = speech_recognizer + self.speech_synthesizer = speech_synthesizer self.recordings_dir = _RECORDINGS_DIR self._debug_recording = _DEBUG_RECORDING_ENABLED if self._debug_recording: @@ -82,13 +80,25 @@ def __init__(self, websocket: WebSocket, speech_recognizer: SpeechRecognizer): sample_width=_SAMPLE_WIDTH, listen_audio_timeout_seconds=_LISTEN_AUDIO_TIMEOUT_SECONDS, ) + self._speaker = SpeakHandler( + websocket=self.ws, + ws_header_fmt=_WS_HEADER_FMT, + wav_kind=_WsKind.WAV.value, + start_msg_type=_WsMsgType.START.value, + data_msg_type=_WsMsgType.DATA.value, + end_msg_type=_WsMsgType.END.value, + down_wav_chunk=_DOWN_WAV_CHUNK, + down_segment_millis=_DOWN_SEGMENT_MILLIS, + down_segment_stagger_millis=_DOWN_SEGMENT_STAGGER_MILLIS, + sample_width=_SAMPLE_WIDTH, + speech_synthesizer=self.speech_synthesizer, + recordings_dir=self.recordings_dir, + debug_recording=self._debug_recording, + ) self._receiving_task: Optional[asyncio.Task] = None self._closed = False - self._speaking = False - self._speak_finished_counter = 0 - self._down_seq = 0 @property @@ -117,33 +127,13 @@ async def listen(self) -> str: ) async def speak(self, text: str) -> None: - start_counter = self._speak_finished_counter - await self._start_talking_stream(text) - if not self._speaking: - return - await self._wait_for_speaking_finished( - min_counter=start_counter + 1, - timeout_seconds=120.0, + await self._speaker.speak( + text, + next_seq=self._next_down_seq, + send_state_command=self.send_state_command, + idle_state=FirmwareState.IDLE, + is_closed=lambda: self._closed, ) - if not self._closed: - await self.send_state_command(FirmwareState.IDLE) - - async def _wait_for_speaking_finished( - self, - *, - min_counter: int = 0, - timeout_seconds: Optional[float] = None, - ) -> None: - loop = asyncio.get_running_loop() - deadline = (loop.time() + timeout_seconds) if timeout_seconds else None - while True: - if self._speak_finished_counter >= min_counter: - return - if self._closed: - raise WebSocketDisconnect() - if deadline and loop.time() >= deadline: - raise TimeoutError("Timed out waiting for speaking finished event") - await asyncio.sleep(0.05) async def send_state_command(self, state_id: int | FirmwareState) -> None: await self._send_state_command(state_id) @@ -166,36 +156,6 @@ async def close(self) -> None: async def start_talking(self, text: str) -> None: await self.speak(text) - async def _start_talking_stream(self, text: str) -> None: - self._speaking = True - try: - async with create_voicevox_client() as client: - audio_query = await client.create_audio_query(text, speaker=29) - wav_bytes = await audio_query.synthesis(speaker=29) - - pcm_bytes, tts_sample_rate, tts_channels, tts_sample_width = self._extract_pcm(wav_bytes) - if len(pcm_bytes) == 0: - self._speaking = False - return - - if tts_sample_width != _SAMPLE_WIDTH: - await self.ws.send_json({"error": f"unsupported sample width {tts_sample_width}"}) - self._speaking = False - return - - bytes_per_second = tts_sample_rate * tts_channels * tts_sample_width - segment_bytes = int(bytes_per_second * (_DOWN_SEGMENT_MILLIS / 1000)) - - if segment_bytes <= 0: - await self.ws.send_json({"error": "invalid segment size computed"}) - self._speaking = False - return - - await self._send_segments(pcm_bytes, tts_sample_rate, tts_channels, segment_bytes) - except Exception as exc: # pragma: no cover - self._speaking = False - await self.ws.send_json({"error": f"voicevox synthesis failed: {exc}"}) - async def _receive_loop(self) -> None: try: while True: @@ -255,7 +215,6 @@ async def _receive_loop(self) -> None: pass finally: self._closed = True - self._speaking = False def _handle_wakeword_event(self, msg_type: int, payload: bytes) -> None: if msg_type != _WsMsgType.DATA: @@ -282,9 +241,7 @@ def _handle_speak_done_event(self, msg_type: int, payload: bytes) -> None: return if len(payload) < 1: return - self._speak_finished_counter += 1 - self._speaking = False - logger.info("Received speak done event") + self._speaker.handle_speak_done_event() async def _send_state_command(self, state_id: int | FirmwareState) -> None: payload = struct.pack(" None: await self.ws.send_bytes(hdr + payload) self._down_seq += 1 - def _extract_pcm(self, wav_bytes: bytes) -> tuple[bytes, int, int, int]: - with wave.open(io.BytesIO(wav_bytes), "rb") as wf: - pcm_bytes = wf.readframes(wf.getnframes()) - tts_sample_rate = wf.getframerate() - tts_channels = wf.getnchannels() - tts_sample_width = wf.getsampwidth() - return pcm_bytes, tts_sample_rate, tts_channels, tts_sample_width - - async def _send_segments(self, pcm_bytes: bytes, tts_sample_rate: int, tts_channels: int, segment_bytes: int) -> None: - segments: list[bytes] = [] - offset = 0 - total = len(pcm_bytes) - while offset < total: - segments.append(pcm_bytes[offset : offset + segment_bytes]) - offset += segment_bytes - - loop = asyncio.get_running_loop() - base_time = loop.time() - - for idx, segment in enumerate(segments): - if idx == 0: - target_ms = 0 - elif idx == 1: - target_ms = _DOWN_SEGMENT_STAGGER_MILLIS - else: - target_ms = _DOWN_SEGMENT_STAGGER_MILLIS + (idx - 1) * _DOWN_SEGMENT_MILLIS - - target_time = base_time + target_ms / 1000 - now = loop.time() - if target_time > now: - await asyncio.sleep(target_time - now) - - await self._send_segment(segment, tts_sample_rate, tts_channels) - - async def _send_segment(self, segment_pcm: bytes, tts_sample_rate: int, tts_channels: int) -> None: - logger.info("Sending segment bytes=%d", len(segment_pcm)) - start_payload = struct.pack(" int: + seq = self._down_seq self._down_seq += 1 + return seq -__all__ = ["WsProxy", "FirmwareState", "TimeoutError", "EmptyTranscriptError", "create_voicevox_client"] +__all__ = ["WsProxy", "FirmwareState", "TimeoutError", "EmptyTranscriptError"] diff --git a/uv.lock b/uv.lock index 98700d1..5e62090 100644 --- a/uv.lock +++ b/uv.lock @@ -1411,6 +1411,7 @@ source = { editable = "." } dependencies = [ { name = "fastapi" }, { name = "google-cloud-speech" }, + { name = "google-genai" }, { name = "uvicorn", extra = ["standard"] }, { name = "voicevox-client" }, ] @@ -1423,14 +1424,12 @@ dev = [ example-claude-agent-sdk = [ { name = "claude-agent-sdk" }, ] -example-gemini = [ - { name = "google-genai" }, -] [package.metadata] requires-dist = [ { name = "fastapi", specifier = ">=0.128.0" }, { name = "google-cloud-speech", specifier = ">=2.35.0" }, + { name = "google-genai", specifier = ">=1.59.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.40.0" }, { name = "voicevox-client", specifier = ">=1.1.0" }, ] @@ -1441,7 +1440,7 @@ dev = [ { name = "ty", specifier = ">=0.0.17" }, ] example-claude-agent-sdk = [{ name = "claude-agent-sdk", specifier = ">=0.1.39" }] -example-gemini = [{ name = "google-genai", specifier = ">=1.59.0" }] +example-gemini = [] [[package]] name = "websockets"