diff --git a/.env.template b/.env.template index 43d3290..1d82075 100644 --- a/.env.template +++ b/.env.template @@ -21,6 +21,8 @@ STACKCHAN_GOOGLE_CLOUD_STT_LANGUAGE_CODE="ja-JP" # STACKCHAN_USE_WHISPER_SERVER=1 # STACKCHAN_WHISPER_SERVER_URL="http://127.0.0.1:8080/inference" # STACKCHAN_WHISPER_SERVER_MODEL= +# STACKCHAN_WHISPER_SERVER_PROMPT= +# STACKCHAN_WHISPER_SERVER_LANGUAGE="ja" # -- Speech Syntheis -- # Google Cloud TTS @@ -34,6 +36,14 @@ STACKCHAN_GOOGLE_CLOUD_TTS_VOICE_NAME="Despina" STACKCHAN_VOICEVOX_URL="http://localhost:50021" STACKCHAN_VOICEVOX_SPEAKER=1 +# -- Server-side Wakeup Word Detection -- +# Whisper Server +# STACKCHAN_USE_WWD_WHISPER_SERVER=1 +# STACKCHAN_WWD_WHISPER_SERVER_URL="http://127.0.0.1:8080/inference" +# STACKCHAN_WWD_WHISPER_SERVER_MODEL= +# STACKCHAN_WWD_WHISPER_SERVER_LANGUAGE="ja" +# STACKCHAN_WWD_WHISPER_SERVER_PROMPT="日本語で、スタックチャンという名前で、話しかけられるので、話しかけられたことを検出してください" + # -- Claude Agent SDK -- # using Google Cloud Vertex AI CLAUDE_CODE_USE_VERTEX=1 diff --git a/AGENTS.md b/AGENTS.md index 26b7ae4..54be171 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -12,8 +12,8 @@ ## 状態遷移の要点 -- ファームウェア状態: `Idle`, `Listening`, `Thinking`, `Speaking`, `Disconnected` -- サーバーから指示できるのは `StateCmd` の `Idle` / `Listening` / `Thinking` / `Speaking` +- ファームウェア状態: `Idle`, `Listening`, `Thinking`, `Speaking`, `ServerWwd`, `Disconnected` +- サーバーから指示できるのは `StateCmd` の `Idle` / `Listening` / `Thinking` / `Speaking` / `ServerWwd` - `Disconnected` はファームウェア内部状態で、WebSocket 切断時に入る - `WakeWordEvt` を受けるか、REST API の wakeword 擬似発火で talk session が始まる @@ -75,6 +75,7 @@ - `websocket.client.host` を StackChan の識別子として使う - 同一 IP の再接続時は既存接続を置き換える - `listen()` は `Listening` 指示後、音声 uplink 完了を待つ +- サーバーサイド wakeword 検出中は `ServerWwd` を指示する - `speak()` は TTS downlink 送信後、`SpeakDoneEvt` を待つ - `move_servo()` / `wait_servo_complete()` を公開 @@ -106,7 +107,7 @@ - `MoveX`, `MoveY`, `Sleep` を順次処理 - 完了時に `ServoDoneEvt` - `src/display.cpp` - - `Idle=濃いグレー`, `Listening=青`, `Thinking=オレンジ`, 
`Speaking=緑`, `Disconnected=赤` + - `Idle=濃いグレー`, `Listening=青`, `Thinking=オレンジ`, `Speaking=緑`, `ServerWwd=Idle(Server-WWD)`, `Disconnected=赤` ## サンプルアプリ diff --git a/docs/server_ja.md b/docs/server_ja.md index 33767ce..2bc40c9 100644 --- a/docs/server_ja.md +++ b/docs/server_ja.md @@ -65,6 +65,8 @@ STACKCHAN_WHISPER_CLI_VAD_MODEL_PATH="/path/to/whisper.cpp/ggml-silero-v5.1.2.bi `STACKCHAN_WHISPER_SERVER_URL` に Whisper Server の推論エンドポイント URL をそのまま指定します。 未設定時は `http://127.0.0.1:8080/inference` を利用します。 +`STACKCHAN_WHISPER_SERVER_LANGUAGE` を設定すると、その値を `language` パラメータとして各リクエストに含めます。未設定または空文字の場合は `language` を送信しません。 +また、`STACKCHAN_WHISPER_SERVER_PROMPT` を設定すると、whisper-server の各リクエストに `prompt` フィールドとして送信します。 #### 例: Whisper.cppのwhisper-serverの設定 @@ -74,6 +76,8 @@ whisper.cpp/examples/server: https://github.com/ggml-org/whisper.cpp/tree/master STACKCHAN_USE_WHISPER_SERVER=1 STACKCHAN_WHISPER_SERVER_URL="http://127.0.0.1:8080/inference" STACKCHAN_WHISPER_SERVER_MODEL= +STACKCHAN_WHISPER_SERVER_LANGUAGE="ja" +STACKCHAN_WHISPER_SERVER_PROMPT="" ``` #### 例: [Lemonade](https://lemonade-server.ai/) を使う場合 @@ -84,6 +88,28 @@ Lemonade: https://lemonade-server.ai/ STACKCHAN_USE_WHISPER_SERVER=1 STACKCHAN_WHISPER_SERVER_URL=http://localhost:13305/api/v1/audio/transcriptions STACKCHAN_WHISPER_SERVER_MODEL=Whisper-Large-v3-Turbo +STACKCHAN_WHISPER_SERVER_LANGUAGE="ja" +STACKCHAN_WHISPER_SERVER_PROMPT="" +``` + +### (オプション) サーバーサイド wakeword 用 Whisper Server の設定 + +サーバーサイド wakeword 検出を有効にするには、以下を設定します。 + +- `STACKCHAN_USE_WWD_WHISPER_SERVER`: `1` +- `STACKCHAN_WWD_WHISPER_SERVER_URL`: wakeword 検出専用 Whisper Server の推論エンドポイント URL +- `STACKCHAN_WWD_WHISPER_SERVER_MODEL`: wakeword 検出専用に利用するモデル名 +- `STACKCHAN_WWD_WHISPER_SERVER_LANGUAGE`: wakeword 検出専用 Whisper Server リクエストへ渡す language +- `STACKCHAN_WWD_WHISPER_SERVER_PROMPT`: wakeword 検出専用 Whisper Server リクエストへ渡す prompt + +通常の音声認識で使う `STACKCHAN_WHISPER_SERVER_URL` / `STACKCHAN_WHISPER_SERVER_MODEL` とは別設定です。 + +``` 
+STACKCHAN_USE_WWD_WHISPER_SERVER=1 +STACKCHAN_WWD_WHISPER_SERVER_URL="http://127.0.0.1:8080/inference" +STACKCHAN_WWD_WHISPER_SERVER_MODEL= +STACKCHAN_WWD_WHISPER_SERVER_LANGUAGE="ja" +STACKCHAN_WWD_WHISPER_SERVER_PROMPT="日本語で、スタックチャンという名前で、話しかけられるので、話しかけられたことを検出してください" ``` ## 音声合成の設定 diff --git a/docs/websocket_protocols_ja.md b/docs/websocket_protocols_ja.md index b816c13..fc2945c 100644 --- a/docs/websocket_protocols_ja.md +++ b/docs/websocket_protocols_ja.md @@ -28,6 +28,7 @@ | 名前 | 方向 | 用途 | | --- | --- | --- | | `AudioPcm` | CoreS3 → Server | マイク音声 PCM ストリーム | +| `ServerWwdPcm` | CoreS3 → Server | サーバーサイド wakeword 検出専用 PCM ストリーム | | `AudioWav` | Server → CoreS3 | TTS 音声 PCM ストリーム | | `StateCmd` | Server → CoreS3 | 状態遷移指示 | | `WakeWordEvt` | CoreS3 → Server | ウェイクワード検出通知 | @@ -35,6 +36,8 @@ | `SpeakDoneEvt` | CoreS3 → Server | 音声再生完了通知 | | `ServoCmd` | Server → CoreS3 | サーボ動作シーケンス指示 | | `ServoDoneEvt` | CoreS3 → Server | サーボ動作完了通知 | +| `FirmwareMetadata` | CoreS3 → Server | クライアント能力通知 | +| `ServerMetadata` | Server → CoreS3 | サーバー能力通知 | ### `MessageType` 一覧 @@ -62,6 +65,20 @@ - 無音判定は平均絶対振幅 `<= 200` が 3 秒継続したときに発火します。 - 停止時は未送信サンプルを `DATA` で flush してから `END` を送ります。 +## サーバーサイド wakeword 入力 `ServerWwdPcm` + +- 方向: CoreS3 → Server +- フォーマット: PCM16LE / 16kHz / 1ch +- シーケンス: `AudioPcmStart` → `AudioChunk` 複数回 → `AudioPcmEnd` +- `kind`: `MESSAGE_KIND_SERVER_WWD_PCM` +- body は `AudioPcm` と同じ `AudioPcmStart` / `AudioChunk` / `AudioPcmEnd` を使います。 + +### 現行実装メモ + +- `StateCmd(ServerWwd)` を受けた CoreS3 は、この kind で uplink を開始します。 +- 無音 3 秒によるクライアント側自動終了は行いません。 +- サーバーはこの kind だけを server-side wakeword detector にルーティングします。 + ## スピーカ再生 `AudioWav` - 方向: Server → CoreS3 @@ -97,13 +114,18 @@ - `Listening` - `Thinking` - `Speaking` +- `ServerWwd` ### 現行実装メモ -- `proxy.listen()` 開始時に Server が `Listening` を指示します。 +- `proxy.listen()` 開始時に Server が `StateCmd(Listening)` を指示します。 +- サーバーサイド wakeword 検出開始時は `StateCmd(ServerWwd)` を指示します。 - 音声 uplink の `END` を受けると、Server は `Thinking` を指示します。 
- `proxy.speak()` 完了後、Server は `Idle` を指示します。 +> [!NOTE] +> `ServerWwd` の場合、CoreS3 は内部的にマイク uplink を開始しますが、表示は `Idle(Server-WWD)` にし、無音 3 秒による自動終了も行いません。 + ## ウェイクワード検出 `WakeWordEvt` - 方向: CoreS3 → Server @@ -112,6 +134,30 @@ - `Idle` 中のウェイクワード検出をサーバー側に通知します。 - REST API の `POST /v1/stackchan/{ip}/wakeword` は、このイベントをサーバー内部で擬似発火させます。 +## メタデータ交換 `FirmwareMetadata` / `ServerMetadata` + +WebSocket 接続後、能力情報を相互交換します。 + +- CoreS3 → Server: `FirmwareMetadata` + - `has_device_wake_word`: クライアント側 wakeword 対応有無 + - そのほか `device_type`, `display_width`, `display_height`, `has_led`, `servo_type`, `supports_audio_duplex`, `firmware_version` +- Server → CoreS3: `ServerMetadata` + - `has_server_wake_word`: サーバー側 wakeword 対応有無 + - `server_version` + +CoreS3 側は `has_server_wake_word=true` を受けると、デバイス側 wakeword を使わずにサーバー側検出モードで待機します(表示は `Idle(Server-WWD)`)。 + +## サーバーサイド wakeword 検出フロー + +- 環境変数 `STACKCHAN_USE_WWD_WHISPER_SERVER=1` の場合、サーバーは `@app.setup()` 完了後と `Idle` 復帰後に自動でサーバーサイド wakeword 検出を開始します。 +- サーバーは `StateCmd(ServerWwd)` を送信して `MESSAGE_KIND_SERVER_WWD_PCM` のマイク uplink を受信します。 +- 受信した音声の直近 3 秒窓を 0.5 秒ごとに音声認識へ渡し、 + 定義キーワード(例: `スタクチャン`)を含むか判定します。 +- 各判定タイミングの認識結果はすべてログ出力されます。 +- キーワード検出時は内部 wakeword イベントを発火し、通常の `talk_session` フローに進みます。 +- 検出完了時(検出/未検出を問わず)は `StateCmd(Idle)` で待機状態に戻します。 +- この間、CoreS3 の画面表示は `Listening` ではなく `Idle(Server-WWD)` を維持します。 + ## 状態通知 `StateEvt` - 方向: CoreS3 → Server @@ -124,6 +170,7 @@ - `Listening` - `Thinking` - `Speaking` +- `ServerWwd` - CoreS3 は状態遷移の entry hook で送信します。 - WebSocket 切断中は `Disconnected` 状態になりますが、切断時は uplink 送信できないため `StateEvt` では通知されません。 diff --git a/firmware/include/listening.hpp b/firmware/include/listening.hpp index 0e18ba8..0a89b25 100644 --- a/firmware/include/listening.hpp +++ b/firmware/include/listening.hpp @@ -10,6 +10,12 @@ class Listening { public: + enum class SessionMode + { + Speech, + WakeWord, + }; + Listening(WebSocketsClient &ws, StateMachine &sm, int sampleRate); // allocate buffers / reset counters; call once 
from setup @@ -19,6 +25,10 @@ class Listening void begin(); void end(); + // Idle(Server-WWD) のままマイク uplink を開始/終了する + bool beginWakeWordStreaming(); + void endWakeWordStreaming(); + // begin a new streaming session (sends START); returns false if WS not connected bool startStreaming(); @@ -34,7 +44,11 @@ class Listening // 無音が所定時間続いているか判定 bool shouldStopForSilence() const; + bool isWakeWordStreaming() const { return streaming_ && session_mode_ == SessionMode::WakeWord; } + private: + bool beginStreamingSession(SessionMode mode, bool auto_stop_for_silence); + void stopMicrophoneOnly(); void updateLevelStats(const int16_t *samples, size_t sampleCount); bool sendPacket(stackchan_websocket_v1_MessageType type, const int16_t *samples, size_t sampleCount); void ringPush(const int16_t *src, size_t samples); @@ -56,6 +70,8 @@ class Listening uint32_t seq_counter_ = 0; bool streaming_ = false; bool events_registered_ = false; + SessionMode session_mode_ = SessionMode::Speech; + bool auto_stop_for_silence_ = true; // 無音判定関連 int32_t last_level_ = 0; diff --git a/firmware/include/metadata.hpp b/firmware/include/metadata.hpp index f490abd..c97b4eb 100644 --- a/firmware/include/metadata.hpp +++ b/firmware/include/metadata.hpp @@ -29,6 +29,7 @@ extern ServerMetadataState g_server_metadata; void initializeFirmwareMetadata(); void resetServerMetadata(); bool shouldUseDeviceWakeWord(); +bool shouldUseServerWakeWord(); void setFirmwareMetadataMessage( stackchan_websocket_v1_WebSocketMessage &message, uint32_t seq); diff --git a/firmware/include/state_machine.hpp b/firmware/include/state_machine.hpp index d5bcd62..a3f4b97 100644 --- a/firmware/include/state_machine.hpp +++ b/firmware/include/state_machine.hpp @@ -14,7 +14,8 @@ class StateMachine Listening = 1, Thinking = 2, Speaking = 3, - Disconnected = 4, + ServerWwd = 4, + Disconnected = 5, }; StateMachine() = default; @@ -25,6 +26,7 @@ class StateMachine bool isListening() const; bool isThinking() const; bool isSpeaking() const; 
+ bool isServerWwd() const; bool isDisconnected() const; using Callback = std::function; @@ -33,8 +35,8 @@ class StateMachine private: State state_ = Disconnected; - std::array, 5> entry_events_{}; - std::array, 5> exit_events_{}; + std::array, 6> entry_events_{}; + std::array, 6> exit_events_{}; }; const char *stateToString(StateMachine::State state); diff --git a/firmware/lib/generated_protobuf/websocket-message.pb.h b/firmware/lib/generated_protobuf/websocket-message.pb.h index 8e0c222..cc98ef3 100644 --- a/firmware/lib/generated_protobuf/websocket-message.pb.h +++ b/firmware/lib/generated_protobuf/websocket-message.pb.h @@ -21,7 +21,8 @@ typedef enum _stackchan_websocket_v1_MessageKind { stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVO_CMD = 7, stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVO_DONE_EVT = 8, stackchan_websocket_v1_MessageKind_MESSAGE_KIND_FIRMWARE_METADATA = 9, - stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_METADATA = 10 + stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_METADATA = 10, + stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_WWD_PCM = 11 } stackchan_websocket_v1_MessageKind; typedef enum _stackchan_websocket_v1_MessageType { @@ -35,7 +36,8 @@ typedef enum _stackchan_websocket_v1_StackchanState { stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_IDLE = 0, stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_LISTENING = 1, stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_THINKING = 2, - stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SPEAKING = 3 + stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SPEAKING = 3, + stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SERVER_WWD = 4 } stackchan_websocket_v1_StackchanState; typedef enum _stackchan_websocket_v1_ServoOperation { @@ -165,16 +167,16 @@ extern "C" { /* Helper constants for enums */ #define _stackchan_websocket_v1_MessageKind_MIN stackchan_websocket_v1_MessageKind_MESSAGE_KIND_UNSPECIFIED -#define 
_stackchan_websocket_v1_MessageKind_MAX stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_METADATA -#define _stackchan_websocket_v1_MessageKind_ARRAYSIZE ((stackchan_websocket_v1_MessageKind)(stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_METADATA+1)) +#define _stackchan_websocket_v1_MessageKind_MAX stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_WWD_PCM +#define _stackchan_websocket_v1_MessageKind_ARRAYSIZE ((stackchan_websocket_v1_MessageKind)(stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_WWD_PCM+1)) #define _stackchan_websocket_v1_MessageType_MIN stackchan_websocket_v1_MessageType_MESSAGE_TYPE_UNSPECIFIED #define _stackchan_websocket_v1_MessageType_MAX stackchan_websocket_v1_MessageType_MESSAGE_TYPE_END #define _stackchan_websocket_v1_MessageType_ARRAYSIZE ((stackchan_websocket_v1_MessageType)(stackchan_websocket_v1_MessageType_MESSAGE_TYPE_END+1)) #define _stackchan_websocket_v1_StackchanState_MIN stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_IDLE -#define _stackchan_websocket_v1_StackchanState_MAX stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SPEAKING -#define _stackchan_websocket_v1_StackchanState_ARRAYSIZE ((stackchan_websocket_v1_StackchanState)(stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SPEAKING+1)) +#define _stackchan_websocket_v1_StackchanState_MAX stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SERVER_WWD +#define _stackchan_websocket_v1_StackchanState_ARRAYSIZE ((stackchan_websocket_v1_StackchanState)(stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SERVER_WWD+1)) #define _stackchan_websocket_v1_ServoOperation_MIN stackchan_websocket_v1_ServoOperation_SERVO_OPERATION_SLEEP #define _stackchan_websocket_v1_ServoOperation_MAX stackchan_websocket_v1_ServoOperation_SERVO_OPERATION_MOVE_Y diff --git a/firmware/src/display.cpp b/firmware/src/display.cpp index a33cabf..07f0815 100644 --- a/firmware/src/display.cpp +++ b/firmware/src/display.cpp @@ -4,6 +4,7 @@ #include "config.h" #include 
"display.hpp" +#include "metadata.hpp" #if USE_STACKCHAN_BSP #define GFXModule M5StackChan.Display() @@ -119,6 +120,11 @@ void Display::drawForState(StateMachine::State state) font_color = TFT_BLACK; led_color = Adafruit_NeoPixel::ColorHSV(kLedHueGreen, 255, ledValueFromBrightness()); break; + case StateMachine::ServerWwd: + bg_color = TFT_DARKGRAY; + font_color = TFT_WHITE; + led_color = Adafruit_NeoPixel::ColorHSV(0, 0, 0); + break; case StateMachine::Disconnected: bg_color = TFT_RED; font_color = TFT_WHITE; diff --git a/firmware/src/listening.cpp b/firmware/src/listening.cpp index edb2e35..256d26a 100644 --- a/firmware/src/listening.cpp +++ b/firmware/src/listening.cpp @@ -42,21 +42,45 @@ void Listening::init() void Listening::begin() { M5.Mic.begin(); - startStreaming(); + beginStreamingSession(SessionMode::Speech, true); } void Listening::end() { stopStreaming(); - M5.Mic.end(); + stopMicrophoneOnly(); +} + +bool Listening::beginWakeWordStreaming() +{ + if (streaming_) + { + return session_mode_ == SessionMode::WakeWord; + } + + M5.Mic.begin(); + return beginStreamingSession(SessionMode::WakeWord, false); +} + +void Listening::endWakeWordStreaming() +{ + stopStreaming(); + stopMicrophoneOnly(); } bool Listening::startStreaming() +{ + return beginStreamingSession(SessionMode::Speech, true); +} + +bool Listening::beginStreamingSession(SessionMode mode, bool auto_stop_for_silence) { ring_write_ = ring_read_ = ring_available_ = 0; seq_counter_ = 0; last_level_ = 0; silence_since_ms_ = 0; + session_mode_ = mode; + auto_stop_for_silence_ = auto_stop_for_silence; streaming_ = true; return sendPacket(stackchan_websocket_v1_MessageType_MESSAGE_TYPE_START, nullptr, 0); } @@ -90,9 +114,18 @@ bool Listening::stopStreaming() streaming_ = false; ok = sendPacket(stackchan_websocket_v1_MessageType_MESSAGE_TYPE_END, nullptr, 0) && ok; + session_mode_ = SessionMode::Speech; + auto_stop_for_silence_ = true; return ok; } +void Listening::stopMicrophoneOnly() +{ + session_mode_ = 
SessionMode::Speech; + auto_stop_for_silence_ = true; + M5.Mic.end(); +} + void Listening::loop() { if (!streaming_) @@ -123,13 +156,20 @@ void Listening::loop() { streaming_ = false; log_i("WS send failed (data)"); - state_.setState(StateMachine::Idle); + if (session_mode_ == SessionMode::Speech) + { + state_.setState(StateMachine::Idle); + } + else + { + stopMicrophoneOnly(); + } return; } } // 無音が3秒続いたら終了 - if (shouldStopForSilence()) + if (auto_stop_for_silence_ && shouldStopForSilence()) { log_i("Auto stop: silence detected (avg=%ld)", static_cast(last_level_)); if (!stopStreaming()) @@ -196,7 +236,10 @@ bool Listening::sendPacket(stackchan_websocket_v1_MessageType type, const int16_ auto &message = g_listening_tx_message; message = stackchan_websocket_v1_WebSocketMessage_init_zero; - message.kind = stackchan_websocket_v1_MessageKind_MESSAGE_KIND_AUDIO_PCM; + message.kind = + (session_mode_ == SessionMode::WakeWord) + ? stackchan_websocket_v1_MessageKind_MESSAGE_KIND_SERVER_WWD_PCM + : stackchan_websocket_v1_MessageKind_MESSAGE_KIND_AUDIO_PCM; message.message_type = type; message.seq = seq_counter_++; diff --git a/firmware/src/main.cpp b/firmware/src/main.cpp index 6fb95c4..5263e24 100644 --- a/firmware/src/main.cpp +++ b/firmware/src/main.cpp @@ -249,6 +249,28 @@ bool applyRemoteStateCommand(const stackchan_websocket_v1_StateCommand &command) case stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SPEAKING: stateMachine.setState(StateMachine::Speaking); return true; + case stackchan_websocket_v1_StackchanState_STACKCHAN_STATE_SERVER_WWD: + if (!shouldUseServerWakeWord()) + { + log_w("Server-side wakeword is not available"); + return false; + } + if (stateMachine.getState() == StateMachine::ServerWwd) + { + return true; + } + if (stateMachine.getState() != StateMachine::Idle) + { + log_w("Cannot enter server-side wakeword from state=%u", static_cast(stateMachine.getState())); + return false; + } + if (!listening.beginWakeWordStreaming()) + { + log_w("Failed 
to start server-side wakeword streaming"); + return false; + } + stateMachine.setState(StateMachine::ServerWwd); + return true; default: log_w("Unknown remote state"); return false; @@ -323,6 +345,11 @@ void handleWsEvent(WStype_t type, uint8_t *payload, size_t length) case WStype_DISCONNECTED: // M5.Display.println("WS: disconnected"); log_i("WS disconnected"); + if (listening.isWakeWordStreaming()) + { + log_i("Stopping server-side wakeword uplink because WS disconnected"); + listening.endWakeWordStreaming(); + } resetServerMetadata(); stateMachine.setState(StateMachine::Disconnected); break; @@ -513,6 +540,13 @@ void setup() listening.end(); }); + stateMachine.addStateEntryEvent(StateMachine::ServerWwd, [](StateMachine::State, StateMachine::State) { + notifyCurrentState(StateMachine::ServerWwd); + }); + stateMachine.addStateExitEvent(StateMachine::ServerWwd, [](StateMachine::State, StateMachine::State) { + listening.endWakeWordStreaming(); + }); + stateMachine.addStateEntryEvent(StateMachine::Speaking, [](StateMachine::State, StateMachine::State) { notifyCurrentState(StateMachine::Speaking); speaking.begin(); @@ -542,7 +576,11 @@ void loop() { case StateMachine::Idle: handleTouchWakeWordInput(); - if (shouldUseDeviceWakeWord()) + if (listening.isWakeWordStreaming()) + { + listening.loop(); + } + else if (shouldUseDeviceWakeWord()) { wakeUpWord.loop(); } @@ -550,6 +588,9 @@ void loop() case StateMachine::Listening: listening.loop(); break; + case StateMachine::ServerWwd: + listening.loop(); + break; case StateMachine::Thinking: // Wait for server side command / audio stream. 
break; diff --git a/firmware/src/metadata.cpp b/firmware/src/metadata.cpp index 5e6cd9b..579e9a8 100644 --- a/firmware/src/metadata.cpp +++ b/firmware/src/metadata.cpp @@ -73,6 +73,11 @@ bool shouldUseDeviceWakeWord() return g_server_metadata.available && !g_server_metadata.has_server_wake_word; } +bool shouldUseServerWakeWord() +{ + return g_server_metadata.available && g_server_metadata.has_server_wake_word; +} + void setFirmwareMetadataMessage( stackchan_websocket_v1_WebSocketMessage &message, uint32_t seq) diff --git a/firmware/src/state_machine.cpp b/firmware/src/state_machine.cpp index 2432cd2..ea38cfd 100644 --- a/firmware/src/state_machine.cpp +++ b/firmware/src/state_machine.cpp @@ -13,6 +13,8 @@ const char *stateToString(StateMachine::State s) return "Thinking"; case StateMachine::Speaking: return "Speaking"; + case StateMachine::ServerWwd: + return "Idle(Server-WWD)"; case StateMachine::Disconnected: return "Disconnected"; default: @@ -66,6 +68,11 @@ bool StateMachine::isThinking() const return state_ == Thinking; } +bool StateMachine::isServerWwd() const +{ + return state_ == ServerWwd; +} + bool StateMachine::isDisconnected() const { return state_ == Disconnected; diff --git a/protobuf/websocket-message.proto b/protobuf/websocket-message.proto index c643673..4d288ef 100644 --- a/protobuf/websocket-message.proto +++ b/protobuf/websocket-message.proto @@ -46,6 +46,7 @@ enum MessageKind { MESSAGE_KIND_SERVO_DONE_EVT = 8; MESSAGE_KIND_FIRMWARE_METADATA = 9; MESSAGE_KIND_SERVER_METADATA = 10; + MESSAGE_KIND_SERVER_WWD_PCM = 11; } enum MessageType { @@ -60,6 +61,7 @@ enum StackchanState { STACKCHAN_STATE_LISTENING = 1; STACKCHAN_STATE_THINKING = 2; STACKCHAN_STATE_SPEAKING = 3; + STACKCHAN_STATE_SERVER_WWD = 4; } enum ServoOperation { diff --git a/stackchan_server/app.py b/stackchan_server/app.py index 14496d2..5921d9a 100644 --- a/stackchan_server/app.py +++ b/stackchan_server/app.py @@ -99,6 +99,8 @@ async def _handle_ws(self, websocket: WebSocket) -> 
None: if self._setup_fn: await self._setup_fn(proxy) + await proxy.enable_auto_server_wakeword_detection() + while not proxy.closed: if not self._talk_session_fn: await asyncio.sleep(0.05) diff --git a/stackchan_server/generated_protobuf/websocket_message_pb2.py b/stackchan_server/generated_protobuf/websocket_message_pb2.py index a7d7a4e..1237224 100644 --- a/stackchan_server/generated_protobuf/websocket_message_pb2.py +++ b/stackchan_server/generated_protobuf/websocket_message_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17websocket-message.proto\x12\x16stackchan.websocket.v1\"\x96\x08\n\x10WebSocketMessage\x12\x31\n\x04kind\x18\x01 \x01(\x0e\x32#.stackchan.websocket.v1.MessageKind\x12\x39\n\x0cmessage_type\x18\x02 \x01(\x0e\x32#.stackchan.websocket.v1.MessageType\x12\x0b\n\x03seq\x18\x03 \x01(\r\x12@\n\x0f\x61udio_pcm_start\x18\n \x01(\x0b\x32%.stackchan.websocket.v1.AudioPcmStartH\x00\x12<\n\x0e\x61udio_pcm_data\x18\x0b \x01(\x0b\x32\".stackchan.websocket.v1.AudioChunkH\x00\x12<\n\raudio_pcm_end\x18\x0c \x01(\x0b\x32#.stackchan.websocket.v1.AudioPcmEndH\x00\x12@\n\x0f\x61udio_wav_start\x18\x14 \x01(\x0b\x32%.stackchan.websocket.v1.AudioWavStartH\x00\x12<\n\x0e\x61udio_wav_data\x18\x15 \x01(\x0b\x32\".stackchan.websocket.v1.AudioChunkH\x00\x12<\n\raudio_wav_end\x18\x16 \x01(\x0b\x32#.stackchan.websocket.v1.AudioWavEndH\x00\x12\x39\n\tstate_cmd\x18\x1e \x01(\x0b\x32$.stackchan.websocket.v1.StateCommandH\x00\x12>\n\rwake_word_evt\x18\x1f \x01(\x0b\x32%.stackchan.websocket.v1.WakeWordEventH\x00\x12\x37\n\tstate_evt\x18 \x01(\x0b\x32\".stackchan.websocket.v1.StateEventH\x00\x12@\n\x0espeak_done_evt\x18! 
\x01(\x0b\x32&.stackchan.websocket.v1.SpeakDoneEventH\x00\x12\x41\n\tservo_cmd\x18\" \x01(\x0b\x32,.stackchan.websocket.v1.ServoCommandSequenceH\x00\x12@\n\x0eservo_done_evt\x18# \x01(\x0b\x32&.stackchan.websocket.v1.ServoDoneEventH\x00\x12\x45\n\x11\x66irmware_metadata\x18$ \x01(\x0b\x32(.stackchan.websocket.v1.FirmwareMetadataH\x00\x12\x41\n\x0fserver_metadata\x18% \x01(\x0b\x32&.stackchan.websocket.v1.ServerMetadataH\x00\x42\x06\n\x04\x62ody\"\x0f\n\rAudioPcmStart\"\r\n\x0b\x41udioPcmEnd\"6\n\rAudioWavStart\x12\x13\n\x0bsample_rate\x18\x01 \x01(\r\x12\x10\n\x08\x63hannels\x18\x02 \x01(\r\"\r\n\x0b\x41udioWavEnd\"\x1f\n\nAudioChunk\x12\x11\n\tpcm_bytes\x18\x01 \x01(\x0c\"E\n\x0cStateCommand\x12\x35\n\x05state\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.StackchanState\"!\n\rWakeWordEvent\x12\x10\n\x08\x64\x65tected\x18\x01 \x01(\x08\"C\n\nStateEvent\x12\x35\n\x05state\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.StackchanState\"\x1e\n\x0eSpeakDoneEvent\x12\x0c\n\x04\x64one\x18\x01 \x01(\x08\"N\n\x14ServoCommandSequence\x12\x36\n\x08\x63ommands\x18\x01 \x03(\x0b\x32$.stackchan.websocket.v1.ServoCommand\"f\n\x0cServoCommand\x12\x32\n\x02op\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.ServoOperation\x12\r\n\x05\x61ngle\x18\x02 \x01(\x11\x12\x13\n\x0b\x64uration_ms\x18\x03 \x01(\x11\"\x1e\n\x0eServoDoneEvent\x12\x0c\n\x04\x64one\x18\x01 \x01(\x08\"\x99\x02\n\x10\x46irmwareMetadata\x12\x37\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\".stackchan.websocket.v1.DeviceType\x12\x15\n\rdisplay_width\x18\x02 \x01(\r\x12\x16\n\x0e\x64isplay_height\x18\x03 \x01(\r\x12\x1c\n\x14has_device_wake_word\x18\x04 \x01(\x08\x12\x0f\n\x07has_led\x18\x05 \x01(\x08\x12\x35\n\nservo_type\x18\x06 \x01(\x0e\x32!.stackchan.websocket.v1.ServoType\x12\x1d\n\x15supports_audio_duplex\x18\x07 \x01(\x08\x12\x18\n\x10\x66irmware_version\x18\x08 \x01(\t\"F\n\x0eServerMetadata\x12\x1c\n\x14has_server_wake_word\x18\x01 \x01(\x08\x12\x16\n\x0eserver_version\x18\x02 
\x01(\t*\xdf\x02\n\x0bMessageKind\x12\x1c\n\x18MESSAGE_KIND_UNSPECIFIED\x10\x00\x12\x1a\n\x16MESSAGE_KIND_AUDIO_PCM\x10\x01\x12\x1a\n\x16MESSAGE_KIND_AUDIO_WAV\x10\x02\x12\x1a\n\x16MESSAGE_KIND_STATE_CMD\x10\x03\x12\x1e\n\x1aMESSAGE_KIND_WAKE_WORD_EVT\x10\x04\x12\x1a\n\x16MESSAGE_KIND_STATE_EVT\x10\x05\x12\x1f\n\x1bMESSAGE_KIND_SPEAK_DONE_EVT\x10\x06\x12\x1a\n\x16MESSAGE_KIND_SERVO_CMD\x10\x07\x12\x1f\n\x1bMESSAGE_KIND_SERVO_DONE_EVT\x10\x08\x12\"\n\x1eMESSAGE_KIND_FIRMWARE_METADATA\x10\t\x12 \n\x1cMESSAGE_KIND_SERVER_METADATA\x10\n*p\n\x0bMessageType\x12\x1c\n\x18MESSAGE_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12MESSAGE_TYPE_START\x10\x01\x12\x15\n\x11MESSAGE_TYPE_DATA\x10\x02\x12\x14\n\x10MESSAGE_TYPE_END\x10\x03*\x85\x01\n\x0eStackchanState\x12\x18\n\x14STACKCHAN_STATE_IDLE\x10\x00\x12\x1d\n\x19STACKCHAN_STATE_LISTENING\x10\x01\x12\x1c\n\x18STACKCHAN_STATE_THINKING\x10\x02\x12\x1c\n\x18STACKCHAN_STATE_SPEAKING\x10\x03*c\n\x0eServoOperation\x12\x19\n\x15SERVO_OPERATION_SLEEP\x10\x00\x12\x1a\n\x16SERVO_OPERATION_MOVE_X\x10\x01\x12\x1a\n\x16SERVO_OPERATION_MOVE_Y\x10\x02*\x85\x01\n\nDeviceType\x12\x1b\n\x17\x44\x45VICE_TYPE_UNSPECIFIED\x10\x00\x12\x1e\n\x1a\x44\x45VICE_TYPE_M5STACK_CORES3\x10\x01\x12\x1a\n\x16\x44\x45VICE_TYPE_M5ATOM_S3R\x10\x02\x12\x1e\n\x1a\x44\x45VICE_TYPE_M5ATOM_ECHOS3R\x10\x03*i\n\tServoType\x12\x1a\n\x16SERVO_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fSERVO_TYPE_NONE\x10\x01\x12\x13\n\x0fSERVO_TYPE_SG90\x10\x02\x12\x16\n\x12SERVO_TYPE_SCS0009\x10\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17websocket-message.proto\x12\x16stackchan.websocket.v1\"\x96\x08\n\x10WebSocketMessage\x12\x31\n\x04kind\x18\x01 \x01(\x0e\x32#.stackchan.websocket.v1.MessageKind\x12\x39\n\x0cmessage_type\x18\x02 \x01(\x0e\x32#.stackchan.websocket.v1.MessageType\x12\x0b\n\x03seq\x18\x03 \x01(\r\x12@\n\x0f\x61udio_pcm_start\x18\n \x01(\x0b\x32%.stackchan.websocket.v1.AudioPcmStartH\x00\x12<\n\x0e\x61udio_pcm_data\x18\x0b 
\x01(\x0b\x32\".stackchan.websocket.v1.AudioChunkH\x00\x12<\n\raudio_pcm_end\x18\x0c \x01(\x0b\x32#.stackchan.websocket.v1.AudioPcmEndH\x00\x12@\n\x0f\x61udio_wav_start\x18\x14 \x01(\x0b\x32%.stackchan.websocket.v1.AudioWavStartH\x00\x12<\n\x0e\x61udio_wav_data\x18\x15 \x01(\x0b\x32\".stackchan.websocket.v1.AudioChunkH\x00\x12<\n\raudio_wav_end\x18\x16 \x01(\x0b\x32#.stackchan.websocket.v1.AudioWavEndH\x00\x12\x39\n\tstate_cmd\x18\x1e \x01(\x0b\x32$.stackchan.websocket.v1.StateCommandH\x00\x12>\n\rwake_word_evt\x18\x1f \x01(\x0b\x32%.stackchan.websocket.v1.WakeWordEventH\x00\x12\x37\n\tstate_evt\x18 \x01(\x0b\x32\".stackchan.websocket.v1.StateEventH\x00\x12@\n\x0espeak_done_evt\x18! \x01(\x0b\x32&.stackchan.websocket.v1.SpeakDoneEventH\x00\x12\x41\n\tservo_cmd\x18\" \x01(\x0b\x32,.stackchan.websocket.v1.ServoCommandSequenceH\x00\x12@\n\x0eservo_done_evt\x18# \x01(\x0b\x32&.stackchan.websocket.v1.ServoDoneEventH\x00\x12\x45\n\x11\x66irmware_metadata\x18$ \x01(\x0b\x32(.stackchan.websocket.v1.FirmwareMetadataH\x00\x12\x41\n\x0fserver_metadata\x18% \x01(\x0b\x32&.stackchan.websocket.v1.ServerMetadataH\x00\x42\x06\n\x04\x62ody\"\x0f\n\rAudioPcmStart\"\r\n\x0b\x41udioPcmEnd\"6\n\rAudioWavStart\x12\x13\n\x0bsample_rate\x18\x01 \x01(\r\x12\x10\n\x08\x63hannels\x18\x02 \x01(\r\"\r\n\x0b\x41udioWavEnd\"\x1f\n\nAudioChunk\x12\x11\n\tpcm_bytes\x18\x01 \x01(\x0c\"E\n\x0cStateCommand\x12\x35\n\x05state\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.StackchanState\"!\n\rWakeWordEvent\x12\x10\n\x08\x64\x65tected\x18\x01 \x01(\x08\"C\n\nStateEvent\x12\x35\n\x05state\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.StackchanState\"\x1e\n\x0eSpeakDoneEvent\x12\x0c\n\x04\x64one\x18\x01 \x01(\x08\"N\n\x14ServoCommandSequence\x12\x36\n\x08\x63ommands\x18\x01 \x03(\x0b\x32$.stackchan.websocket.v1.ServoCommand\"f\n\x0cServoCommand\x12\x32\n\x02op\x18\x01 \x01(\x0e\x32&.stackchan.websocket.v1.ServoOperation\x12\r\n\x05\x61ngle\x18\x02 \x01(\x11\x12\x13\n\x0b\x64uration_ms\x18\x03 
\x01(\x11\"\x1e\n\x0eServoDoneEvent\x12\x0c\n\x04\x64one\x18\x01 \x01(\x08\"\x99\x02\n\x10\x46irmwareMetadata\x12\x37\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\".stackchan.websocket.v1.DeviceType\x12\x15\n\rdisplay_width\x18\x02 \x01(\r\x12\x16\n\x0e\x64isplay_height\x18\x03 \x01(\r\x12\x1c\n\x14has_device_wake_word\x18\x04 \x01(\x08\x12\x0f\n\x07has_led\x18\x05 \x01(\x08\x12\x35\n\nservo_type\x18\x06 \x01(\x0e\x32!.stackchan.websocket.v1.ServoType\x12\x1d\n\x15supports_audio_duplex\x18\x07 \x01(\x08\x12\x18\n\x10\x66irmware_version\x18\x08 \x01(\t\"F\n\x0eServerMetadata\x12\x1c\n\x14has_server_wake_word\x18\x01 \x01(\x08\x12\x16\n\x0eserver_version\x18\x02 \x01(\t*\x80\x03\n\x0bMessageKind\x12\x1c\n\x18MESSAGE_KIND_UNSPECIFIED\x10\x00\x12\x1a\n\x16MESSAGE_KIND_AUDIO_PCM\x10\x01\x12\x1a\n\x16MESSAGE_KIND_AUDIO_WAV\x10\x02\x12\x1a\n\x16MESSAGE_KIND_STATE_CMD\x10\x03\x12\x1e\n\x1aMESSAGE_KIND_WAKE_WORD_EVT\x10\x04\x12\x1a\n\x16MESSAGE_KIND_STATE_EVT\x10\x05\x12\x1f\n\x1bMESSAGE_KIND_SPEAK_DONE_EVT\x10\x06\x12\x1a\n\x16MESSAGE_KIND_SERVO_CMD\x10\x07\x12\x1f\n\x1bMESSAGE_KIND_SERVO_DONE_EVT\x10\x08\x12\"\n\x1eMESSAGE_KIND_FIRMWARE_METADATA\x10\t\x12 
\n\x1cMESSAGE_KIND_SERVER_METADATA\x10\n\x12\x1f\n\x1bMESSAGE_KIND_SERVER_WWD_PCM\x10\x0b*p\n\x0bMessageType\x12\x1c\n\x18MESSAGE_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12MESSAGE_TYPE_START\x10\x01\x12\x15\n\x11MESSAGE_TYPE_DATA\x10\x02\x12\x14\n\x10MESSAGE_TYPE_END\x10\x03*\xa5\x01\n\x0eStackchanState\x12\x18\n\x14STACKCHAN_STATE_IDLE\x10\x00\x12\x1d\n\x19STACKCHAN_STATE_LISTENING\x10\x01\x12\x1c\n\x18STACKCHAN_STATE_THINKING\x10\x02\x12\x1c\n\x18STACKCHAN_STATE_SPEAKING\x10\x03\x12\x1e\n\x1aSTACKCHAN_STATE_SERVER_WWD\x10\x04*c\n\x0eServoOperation\x12\x19\n\x15SERVO_OPERATION_SLEEP\x10\x00\x12\x1a\n\x16SERVO_OPERATION_MOVE_X\x10\x01\x12\x1a\n\x16SERVO_OPERATION_MOVE_Y\x10\x02*\x85\x01\n\nDeviceType\x12\x1b\n\x17\x44\x45VICE_TYPE_UNSPECIFIED\x10\x00\x12\x1e\n\x1a\x44\x45VICE_TYPE_M5STACK_CORES3\x10\x01\x12\x1a\n\x16\x44\x45VICE_TYPE_M5ATOM_S3R\x10\x02\x12\x1e\n\x1a\x44\x45VICE_TYPE_M5ATOM_ECHOS3R\x10\x03*i\n\tServoType\x12\x1a\n\x16SERVO_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fSERVO_TYPE_NONE\x10\x01\x12\x13\n\x0fSERVO_TYPE_SG90\x10\x02\x12\x16\n\x12SERVO_TYPE_SCS0009\x10\x03\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,17 +32,17 @@ if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None _globals['_MESSAGEKIND']._serialized_start=2016 - _globals['_MESSAGEKIND']._serialized_end=2367 - _globals['_MESSAGETYPE']._serialized_start=2369 - _globals['_MESSAGETYPE']._serialized_end=2481 - _globals['_STACKCHANSTATE']._serialized_start=2484 - _globals['_STACKCHANSTATE']._serialized_end=2617 - _globals['_SERVOOPERATION']._serialized_start=2619 - _globals['_SERVOOPERATION']._serialized_end=2718 - _globals['_DEVICETYPE']._serialized_start=2721 - _globals['_DEVICETYPE']._serialized_end=2854 - _globals['_SERVOTYPE']._serialized_start=2856 - _globals['_SERVOTYPE']._serialized_end=2961 + _globals['_MESSAGEKIND']._serialized_end=2400 + _globals['_MESSAGETYPE']._serialized_start=2402 + 
_globals['_MESSAGETYPE']._serialized_end=2514 + _globals['_STACKCHANSTATE']._serialized_start=2517 + _globals['_STACKCHANSTATE']._serialized_end=2682 + _globals['_SERVOOPERATION']._serialized_start=2684 + _globals['_SERVOOPERATION']._serialized_end=2783 + _globals['_DEVICETYPE']._serialized_start=2786 + _globals['_DEVICETYPE']._serialized_end=2919 + _globals['_SERVOTYPE']._serialized_start=2921 + _globals['_SERVOTYPE']._serialized_end=3026 _globals['_WEBSOCKETMESSAGE']._serialized_start=52 _globals['_WEBSOCKETMESSAGE']._serialized_end=1098 _globals['_AUDIOPCMSTART']._serialized_start=1100 diff --git a/stackchan_server/protobuf_ws.py b/stackchan_server/protobuf_ws.py index 8569004..b652a2c 100644 --- a/stackchan_server/protobuf_ws.py +++ b/stackchan_server/protobuf_ws.py @@ -92,7 +92,10 @@ def encode_audio_wav_end_message(seq: int) -> bytes: return message.SerializeToString() -def encode_state_command_message(seq: int, state_id: int) -> bytes: +def encode_state_command_message( + seq: int, + state_id: int, +) -> bytes: message = _new_message( ws_pb2.MESSAGE_KIND_STATE_CMD, ws_pb2.MESSAGE_TYPE_DATA, diff --git a/stackchan_server/server_wwd.py b/stackchan_server/server_wwd.py new file mode 100644 index 0000000..43bb73e --- /dev/null +++ b/stackchan_server/server_wwd.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +import asyncio +from logging import getLogger +from typing import Any, Awaitable, Callable, Optional + +from .wakeup_word_detection import ( + WakeWordDetectionError, + WakeWordDetectionTimeout, + create_server_side_wake_word_detector, +) + +logger = getLogger(__name__) + +_SERVER_WAKEWORD_RESTART_DELAY_SECONDS = 0.25 +_TRAILING_PCM_DRAIN_SECONDS = 1.0 + + +class ServerWwdController: + def __init__( + self, + *, + send_state_command: Callable[[int], Awaitable[None]], + set_current_state: Callable[[int], None], + close_websocket: Callable[[int, str], Awaitable[None]], + current_state: Callable[[], int], + is_closed: Callable[[], bool], + 
on_detected: Callable[[], None], + server_wwd_state: int, + idle_state: int, + ) -> None: + self._send_state_command = send_state_command + self._set_current_state = set_current_state + self._close_websocket = close_websocket + self._current_state = current_state + self._is_closed = is_closed + self._on_detected = on_detected + self._server_wwd_state = server_wwd_state + self._idle_state = idle_state + + self._detector = create_server_side_wake_word_detector() + self._task: Optional[asyncio.Task[bool]] = None + self._restart_task: Optional[asyncio.Task[None]] = None + self._auto_start = False + self._suppress_restart_once = False + self._drain_trailing_pcm_until_end = False + self._drain_trailing_pcm_deadline: float | None = None + + @property + def available(self) -> bool: + return self._detector is not None + + @property + def auto_start_enabled(self) -> bool: + return self._auto_start + + async def enable_auto_detection(self) -> None: + self._auto_start = True + + async def start_if_available(self) -> bool: + if ( + self._is_closed() + or self._detector is None + or self._current_state() != self._idle_state + ): + return False + + if self._task is not None and not self._task.done(): + return True + + self._cancel_restart_task() + self._task = asyncio.create_task( + self._run_detection(), + name="server-side-wakeword-detection", + ) + return True + + async def stop(self, *, suppress_restart: bool = True) -> None: + self._cancel_restart_task() + task = self._task + if task is None: + return + + if suppress_restart and not task.done(): + self._suppress_restart_once = True + + if task.done(): + self._task = None + try: + await task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Server-side wake-word detection task failed") + return + + task.cancel() + self._task = None + try: + await task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Server-side wake-word detection task failed") + + async def 
handle_pcm_message(self, message: Any, *, ws_pb2: Any) -> bool: + body_name = message.WhichOneof("body") + + if self._should_drain_trailing_pcm(): + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_START + and body_name == "audio_pcm_start" + ): + logger.info( + "Received a new server-side wake-word PCM START while draining trailing audio; resuming normal routing" + ) + self._clear_trailing_pcm_drain() + elif ( + message.message_type == ws_pb2.MESSAGE_TYPE_DATA + and body_name == "audio_pcm_data" + ): + logger.info( + "Discarding trailing server-side wake-word PCM DATA payload_bytes=%d", + len(message.audio_pcm_data.pcm_bytes), + ) + return True + elif ( + message.message_type == ws_pb2.MESSAGE_TYPE_END + and body_name == "audio_pcm_end" + ): + logger.info("Finished draining trailing server-side wake-word PCM") + self._clear_trailing_pcm_drain() + return True + + detector = self._detector + if detector is None or not detector.running: + logger.info( + "Ignoring server-side wake-word PCM while detector is inactive type=%s body=%s", + message.message_type, + body_name, + ) + return True + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_START + and body_name == "audio_pcm_start" + ): + await detector.handle_start() + return True + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_DATA + and body_name == "audio_pcm_data" + ): + payload = bytes(message.audio_pcm_data.pcm_bytes) + await detector.handle_data(payload) + return True + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_END + and body_name == "audio_pcm_end" + ): + await detector.handle_end() + return True + + await self._close_websocket(1003, "unknown server wake-word PCM protobuf body") + return False + + def schedule_restart( + self, + delay_seconds: float = _SERVER_WAKEWORD_RESTART_DELAY_SECONDS, + ) -> None: + if not self._auto_start or self._is_closed(): + return + + self._cancel_restart_task() + self._restart_task = asyncio.create_task( + self._restart_after_delay(delay_seconds), + 
name="server-side-wakeword-restart", + ) + + async def _run_detection(self) -> bool: + detector = self._detector + if detector is None: + return False + + detected = False + should_restart = False + try: + await detector.start() + await self._send_state_command(self._server_wwd_state) + detected = await detector.wait_result() + if detected: + self._on_detected() + return detected + except asyncio.CancelledError: + raise + except WakeWordDetectionTimeout as exc: + logger.info("Server-side wake-word detection stopped: %s", exc) + return False + except WakeWordDetectionError as exc: + logger.warning("Server-side wake-word detection stopped: %s", exc) + return False + except Exception: + logger.exception("Server-side wake-word detection failed") + return False + finally: + await detector.stop() + self._arm_trailing_pcm_drain() + if not self._is_closed(): + self._set_current_state(self._idle_state) + try: + await self._send_state_command(self._idle_state) + except Exception: + logger.exception( + "Failed to return firmware to idle after wake-word detection" + ) + suppress_restart = self._suppress_restart_once + self._suppress_restart_once = False + should_restart = ( + self._auto_start + and not detected + and not suppress_restart + and not self._is_closed() + ) + if self._task is asyncio.current_task(): + self._task = None + if should_restart: + self.schedule_restart() + + def _cancel_restart_task(self) -> None: + task = self._restart_task + if task is None: + return + self._restart_task = None + task.cancel() + + async def _restart_after_delay(self, delay_seconds: float) -> None: + try: + await asyncio.sleep(delay_seconds) + if self._is_closed(): + return + await self.start_if_available() + except asyncio.CancelledError: + raise + finally: + if self._restart_task is asyncio.current_task(): + self._restart_task = None + + def _arm_trailing_pcm_drain( + self, + timeout_seconds: float = _TRAILING_PCM_DRAIN_SECONDS, + ) -> None: + loop = asyncio.get_running_loop() + 
self._drain_trailing_pcm_until_end = True + self._drain_trailing_pcm_deadline = loop.time() + timeout_seconds + + def _clear_trailing_pcm_drain(self) -> None: + self._drain_trailing_pcm_until_end = False + self._drain_trailing_pcm_deadline = None + + def _should_drain_trailing_pcm(self) -> bool: + if not self._drain_trailing_pcm_until_end: + return False + deadline = self._drain_trailing_pcm_deadline + if deadline is None: + return True + if asyncio.get_running_loop().time() <= deadline: + return True + + logger.info( + "Trailing PCM drain window expired before END arrived; resuming normal routing" + ) + self._clear_trailing_pcm_drain() + return False + + +__all__ = ["ServerWwdController"] diff --git a/stackchan_server/speech_recognition/whisper_server.py b/stackchan_server/speech_recognition/whisper_server.py index 99dd811..d508cb5 100644 --- a/stackchan_server/speech_recognition/whisper_server.py +++ b/stackchan_server/speech_recognition/whisper_server.py @@ -25,9 +25,10 @@ class WhisperServerSpeechToTextConfig(BaseSettings): url: str = _DEFAULT_SERVER_URL - language: str = "auto" + language: str = "" detect_language: bool = False response_format: str = "verbose_json" + prompt: str = "" silence_rms_threshold: float = _DEFAULT_SILENCE_RMS_THRESHOLD request_timeout_seconds: float = 60.0 model: str = "" @@ -73,9 +74,15 @@ async def transcribe(self, pcm_bytes: bytes) -> str: def _request_transcript(self, wav_bytes: bytes, language: str) -> str: fields = { "response_format": self._conf.response_format, - "language": language, } + normalized_language = language.strip() + if normalized_language: + fields["language"] = normalized_language + + if self._conf.prompt: + fields["prompt"] = self._conf.prompt + if self._conf.model: fields["model"] = self._conf.model diff --git a/stackchan_server/wakeup_word_detection/__init__.py b/stackchan_server/wakeup_word_detection/__init__.py new file mode 100644 index 0000000..198a4fb --- /dev/null +++ 
b/stackchan_server/wakeup_word_detection/__init__.py @@ -0,0 +1,17 @@ +from .create import create_server_side_wake_word_detector +from .whisper_server import ( + WakeWordDetectionError, + WakeWordDetectionTimeout, + WhisperServerWakeWordDetector, + WhisperServerWakeWordDetectorConfig, + WhisperServerWakeWordSpeechToTextConfig, +) + +__all__ = [ + "create_server_side_wake_word_detector", + "WhisperServerWakeWordDetector", + "WhisperServerWakeWordDetectorConfig", + "WhisperServerWakeWordSpeechToTextConfig", + "WakeWordDetectionError", + "WakeWordDetectionTimeout", +] diff --git a/stackchan_server/wakeup_word_detection/create.py b/stackchan_server/wakeup_word_detection/create.py new file mode 100644 index 0000000..6d7520d --- /dev/null +++ b/stackchan_server/wakeup_word_detection/create.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from pydantic_settings import BaseSettings + +from .whisper_server import WhisperServerWakeWordDetector + + +class _CreateWhisperServerWakeWordDetectorEnv(BaseSettings): + use_wwd_whisper_server: bool = False + + class Config: + env_prefix = "STACKCHAN_" + + +def create_server_side_wake_word_detector() -> WhisperServerWakeWordDetector | None: + env = _CreateWhisperServerWakeWordDetectorEnv() + if not env.use_wwd_whisper_server: + return None + + return WhisperServerWakeWordDetector() + + +__all__ = ["create_server_side_wake_word_detector"] diff --git a/stackchan_server/wakeup_word_detection/whisper_server.py b/stackchan_server/wakeup_word_detection/whisper_server.py new file mode 100644 index 0000000..6b25fe6 --- /dev/null +++ b/stackchan_server/wakeup_word_detection/whisper_server.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import asyncio +import unicodedata +from logging import getLogger + +from pydantic import Field +from pydantic_settings import BaseSettings + +from ..speech_recognition.whisper_server import ( + WhisperServerSpeechToText, + WhisperServerSpeechToTextConfig, +) +from ..static import 
LISTEN_AUDIO_FORMAT + +logger = getLogger(__name__) + + +class WakeWordDetectionError(Exception): + pass + + +class WakeWordDetectionTimeout(WakeWordDetectionError): + pass + + +class WhisperServerWakeWordDetectorConfig(BaseSettings): + keywords: list[str] = Field(default_factory=lambda: ["スタックチャン"]) + window_seconds: float = 3.0 + interval_seconds: float = 0.5 + timeout_seconds: float = 300.0 + + class Config: + env_prefix = "STACKCHAN_WWD_" + + +class WhisperServerWakeWordSpeechToTextConfig(WhisperServerSpeechToTextConfig): + class Config(WhisperServerSpeechToTextConfig.Config): + env_prefix = "STACKCHAN_WWD_WHISPER_SERVER_" + + +class WhisperServerWakeWordDetector: + def __init__( + self, + *, + recognizer: WhisperServerSpeechToText | None = None, + config: WhisperServerWakeWordDetectorConfig | None = None, + ) -> None: + self.config = config or WhisperServerWakeWordDetectorConfig() + self.recognizer = recognizer or WhisperServerSpeechToText( + config=WhisperServerWakeWordSpeechToTextConfig() + ) + self._pcm_buffer = bytearray() + self._running = False + self._detected = False + self._streaming_started = False + self._error: Exception | None = None + self._last_inference_at = 0.0 + self._inference_task: asyncio.Task[None] | None = None + self._event = asyncio.Event() + self._lock = asyncio.Lock() + self._streaming_ended = False + + @property + def running(self) -> bool: + return self._running + + async def start(self) -> None: + await self.stop() + self._pcm_buffer = bytearray() + self._running = True + self._detected = False + self._streaming_started = False + self._streaming_ended = False + self._error = None + self._last_inference_at = 0.0 + self._event.clear() + logger.info("Server-side wake-word detection started") + + async def stop(self) -> None: + self._running = False + if self._inference_task is not None: + self._inference_task.cancel() + try: + await self._inference_task + except asyncio.CancelledError: + pass + self._inference_task = None + 
self._event.set() + + async def handle_start(self) -> None: + if not self._running: + return + self._streaming_started = True + self._streaming_ended = False + self._pcm_buffer = bytearray() + self._last_inference_at = 0.0 + logger.info("Server-side wake-word stream START") + + async def handle_data(self, payload: bytes) -> None: + if not self._running: + return + if not self._streaming_started: + logger.warning( + "Ignoring stale server-side wake-word DATA before START payload_bytes=%d", + len(payload), + ) + return + if self._streaming_ended: + logger.warning( + "Ignoring stale server-side wake-word DATA after END payload_bytes=%d", + len(payload), + ) + return + + self._pcm_buffer.extend(payload) + self._truncate_buffer_to_window() + + loop = asyncio.get_running_loop() + now = loop.time() + if (now - self._last_inference_at) < self.config.interval_seconds: + return + if self._inference_task is not None and not self._inference_task.done(): + return + + self._last_inference_at = now + window_bytes = bytes(self._pcm_buffer) + self._inference_task = asyncio.create_task(self._run_inference(window_bytes)) + + async def handle_end(self) -> None: + if not self._running: + return + if not self._streaming_started: + logger.warning("Ignoring stale server-side wake-word END before START") + return + if self._streaming_ended: + logger.warning("Ignoring duplicate server-side wake-word END") + return + self._streaming_ended = True + logger.info("Server-side wake-word stream END") + if self._inference_task is not None and not self._inference_task.done(): + try: + await self._inference_task + except Exception as exc: # pragma: no cover + self._error = exc + if not self._detected: + self._event.set() + + async def wait_result(self, timeout_seconds: float | None = None) -> bool: + if not self._running: + raise WakeWordDetectionError("Server-side wake-word detection is not running") + + timeout = ( + timeout_seconds + if timeout_seconds is not None + else 
self.config.timeout_seconds + ) + try: + await asyncio.wait_for(self._event.wait(), timeout=timeout) + except asyncio.TimeoutError as exc: + raise WakeWordDetectionTimeout( + "Server-side wake-word detection timed out" + ) from exc + + if self._error is not None: + raise WakeWordDetectionError(str(self._error)) from self._error + + return self._detected + + async def _run_inference(self, pcm_bytes: bytes) -> None: + if not pcm_bytes: + return + + try: + async with self._lock: + transcript = await self.recognizer.transcribe(pcm_bytes) + except Exception as exc: # pragma: no cover + logger.exception("Server-side wake-word transcription failed") + self._error = exc + self._event.set() + return + + logger.info("Server-side wake-word transcript: %s", transcript) + + if self._contains_wake_word(transcript): + logger.info("Server-side wake-word detected") + self._detected = True + self._event.set() + + def _contains_wake_word(self, transcript: str) -> bool: + normalized_transcript = _normalize_text(transcript) + if not normalized_transcript: + return False + + for keyword in self.config.keywords: + normalized_keyword = _normalize_text(keyword) + if normalized_keyword and normalized_keyword in normalized_transcript: + return True + return False + + def _truncate_buffer_to_window(self) -> None: + sample_rate = LISTEN_AUDIO_FORMAT.sample_rate_hz + channels = LISTEN_AUDIO_FORMAT.channels + sample_width = LISTEN_AUDIO_FORMAT.sample_width + bytes_per_second = sample_rate * channels * sample_width + max_bytes = max(1, int(bytes_per_second * self.config.window_seconds)) + if len(self._pcm_buffer) <= max_bytes: + return + del self._pcm_buffer[: len(self._pcm_buffer) - max_bytes] + + +def _normalize_text(text: str) -> str: + normalized = unicodedata.normalize("NFKC", text or "") + return "".join(normalized.lower().split()) + + +__all__ = [ + "WhisperServerWakeWordDetector", + "WhisperServerWakeWordDetectorConfig", + "WhisperServerWakeWordSpeechToTextConfig", + 
"WakeWordDetectionError", + "WakeWordDetectionTimeout", +] diff --git a/stackchan_server/ws_proxy.py b/stackchan_server/ws_proxy.py index 1c45236..1c414cc 100644 --- a/stackchan_server/ws_proxy.py +++ b/stackchan_server/ws_proxy.py @@ -12,7 +12,6 @@ from fastapi import WebSocket, WebSocketDisconnect from google.protobuf.message import DecodeError -from pydantic_settings import BaseSettings from . import __version__ from .generated_protobuf import websocket_message_pb2 as _ws_pb2 @@ -23,6 +22,7 @@ encode_state_command_message, parse_websocket_message, ) +from .server_wwd import ServerWwdController from .speak import SpeakHandler from .static import LISTEN_AUDIO_FORMAT from .types import SpeechRecognizer, SpeechSynthesizer @@ -45,22 +45,12 @@ _DEBUG_RECORDING_ENABLED = os.getenv("DEBUG_RECODING") == "1" -class _WakeWordServerConfig(BaseSettings): - no_use_client_wakeup_word: bool = False - use_open_wake_word: bool = False - - class Config: - env_prefix = "STACKCHAN_" - - -_WAKEWORD_SERVER_CONFIG = _WakeWordServerConfig() - - class FirmwareState(IntEnum): IDLE = 0 LISTENING = 1 THINKING = 2 SPEAKING = 3 + SERVER_WWD = 4 class ServoMoveType(StrEnum): @@ -129,7 +119,6 @@ def __init__( recordings_dir=self.recordings_dir, debug_recording=self._debug_recording, ) - self._receiving_task: Optional[asyncio.Task] = None self._closed = False @@ -144,6 +133,18 @@ def __init__( self._servo_done_counter = 0 self._servo_sent_counter = 0 self._pending_servo_wait_targets: deque[int] = deque() + self._server_wwd = ServerWwdController( + send_state_command=self.send_state_command, + set_current_state=lambda state: setattr( + self, "_current_firmware_state", FirmwareState(state) + ), + close_websocket=self.ws.close, + current_state=lambda: int(self._current_firmware_state), + is_closed=lambda: self._closed, + on_detected=self._wakeword_event.set, + server_wwd_state=int(FirmwareState.SERVER_WWD), + idle_state=int(FirmwareState.IDLE), + ) @property def closed(self) -> bool: @@ -157,6 
+158,10 @@ def current_state(self) -> FirmwareState: def receive_task(self) -> Optional[asyncio.Task]: return self._receiving_task + @property + def has_server_wakeword_detector(self) -> bool: + return self._server_wwd.available + def trigger_wakeword(self) -> None: """Web API から擬似的に WAKEWORD_EVT を発火させる。""" logger.info("Triggered wakeword via API") @@ -165,6 +170,7 @@ def trigger_wakeword(self) -> None: async def wait_for_talk_session(self) -> None: while True: if self._wakeword_event.is_set(): + await self._server_wwd.stop() self._wakeword_event.clear() return if self._closed: @@ -172,6 +178,7 @@ async def wait_for_talk_session(self) -> None: await asyncio.sleep(0.05) async def listen(self) -> str: + await self._server_wwd.stop() return await self._listener.listen( send_state_command=self.send_state_command, is_closed=lambda: self._closed, @@ -188,11 +195,16 @@ async def speak(self, text: str) -> None: is_closed=lambda: self._closed, ) - async def send_state_command(self, state_id: int | FirmwareState) -> None: + async def send_state_command( + self, + state_id: int | FirmwareState, + ) -> None: await self._send_state_command(state_id) async def reset_state(self) -> None: await self.send_state_command(FirmwareState.IDLE) + self._current_firmware_state = FirmwareState.IDLE + self._server_wwd.schedule_restart() async def move_servo(self, commands: Sequence[ServoCommand]) -> None: previous_counter = self._servo_sent_counter @@ -200,7 +212,7 @@ async def move_servo(self, commands: Sequence[ServoCommand]) -> None: self._servo_sent_counter = target_counter self._pending_servo_wait_targets.append(target_counter) try: - await self.ws.send_bytes( + await self._send_ws_bytes( encode_servo_command_message(self._next_down_seq(), commands) ) except Exception: @@ -232,62 +244,49 @@ async def start(self) -> None: async def close(self) -> None: self._closed = True + await self._server_wwd.stop() if self._receiving_task: self._receiving_task.cancel() with 
suppress(asyncio.CancelledError): - await self._receiving_task + try: + await self._receiving_task + except RuntimeError as exc: + if not self._is_closed_websocket_runtime_error(exc): + raise await self._listener.close() async def start_talking(self, text: str) -> None: await self.speak(text) + async def enable_auto_server_wakeword_detection(self) -> None: + await self._server_wwd.enable_auto_detection() + if self.firmware_metadata is not None: + await self._server_wwd.start_if_available() + async def _receive_loop(self) -> None: try: while True: - raw_message = await self.ws.receive_bytes() + try: + raw_message = await self.ws.receive_bytes() + except RuntimeError as exc: + if self._is_closed_websocket_runtime_error(exc): + break + raise try: message = parse_websocket_message(raw_message) except DecodeError: await self.ws.close(code=1003, reason="invalid protobuf message") break + if message.kind == ws_pb2.MESSAGE_KIND_SERVER_WWD_PCM: + if not await self._server_wwd.handle_pcm_message(message, ws_pb2=ws_pb2): + break + continue + if message.kind == ws_pb2.MESSAGE_KIND_AUDIO_PCM: - body_name = message.WhichOneof("body") - - if ( - message.message_type == ws_pb2.MESSAGE_TYPE_START - and body_name == "audio_pcm_start" - ): - if not await self._listener.handle_start(self.ws): - break - continue - - if ( - message.message_type == ws_pb2.MESSAGE_TYPE_DATA - and body_name == "audio_pcm_data" - ): - payload = bytes(message.audio_pcm_data.pcm_bytes) - if not await self._listener.handle_data( - self.ws, len(payload), payload - ): - break - continue - - if ( - message.message_type == ws_pb2.MESSAGE_TYPE_END - and body_name == "audio_pcm_end" - ): - await self._listener.handle_end( - self.ws, - payload_bytes=0, - payload=b"", - send_state_command=self.send_state_command, - thinking_state=FirmwareState.THINKING, - ) - continue - - await self.ws.close(code=1003, reason="unknown PCM protobuf body") - break + if not await self._handle_audio_pcm_message(message): + break + 
continue if message.kind == ws_pb2.MESSAGE_KIND_WAKE_WORD_EVT: self._handle_wakeword_event(message) @@ -316,6 +315,38 @@ async def _receive_loop(self) -> None: finally: self._closed = True + async def _handle_audio_pcm_message(self, message: Any) -> bool: + body_name = message.WhichOneof("body") + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_START + and body_name == "audio_pcm_start" + ): + return await self._listener.handle_start(self.ws) + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_DATA + and body_name == "audio_pcm_data" + ): + payload = bytes(message.audio_pcm_data.pcm_bytes) + return await self._listener.handle_data(self.ws, len(payload), payload) + + if ( + message.message_type == ws_pb2.MESSAGE_TYPE_END + and body_name == "audio_pcm_end" + ): + await self._listener.handle_end( + self.ws, + payload_bytes=0, + payload=b"", + send_state_command=self.send_state_command, + thinking_state=FirmwareState.THINKING, + ) + return True + + await self.ws.close(code=1003, reason="unknown PCM protobuf body") + return False + def _handle_wakeword_event(self, message: Any) -> None: if message.message_type != ws_pb2.MESSAGE_TYPE_DATA: return @@ -355,24 +386,20 @@ async def _handle_firmware_metadata(self, message: Any) -> None: self.firmware_metadata.firmware_version, ) self.server_metadata = self._build_server_metadata(self.firmware_metadata) - await self.ws.send_bytes( + await self._send_ws_bytes( encode_server_metadata_message( self._next_down_seq(), has_server_wake_word=self.server_metadata.has_server_wake_word, server_version=self.server_metadata.server_version, ) ) + if self._server_wwd.auto_start_enabled: + await self._server_wwd.start_if_available() def _build_server_metadata( self, firmware_metadata: FirmwareMetadata ) -> ServerMetadata: - should_use_server_wake_word = ( - _WAKEWORD_SERVER_CONFIG.use_open_wake_word - and ( - _WAKEWORD_SERVER_CONFIG.no_use_client_wakeup_word - or not firmware_metadata.has_device_wake_word - ) - ) + 
should_use_server_wake_word = self._server_wwd.available return ServerMetadata( has_server_wake_word=should_use_server_wake_word, server_version=__version__, @@ -410,11 +437,36 @@ def _handle_servo_done_event(self, message: Any) -> None: self._servo_done_counter += 1 logger.info("Received servo done event") - async def _send_state_command(self, state_id: int | FirmwareState) -> None: - await self.ws.send_bytes( - encode_state_command_message(self._next_down_seq(), int(state_id)) + async def _send_state_command( + self, + state_id: int | FirmwareState, + ) -> None: + await self._send_ws_bytes( + encode_state_command_message( + self._next_down_seq(), + int(state_id), + ) + ) + + async def _send_ws_bytes(self, data: bytes) -> None: + try: + await self.ws.send_bytes(data) + except RuntimeError as exc: + self._raise_websocket_disconnect_from_runtime_error(exc) + + def _is_closed_websocket_runtime_error(self, exc: RuntimeError) -> bool: + message = str(exc) + return ( + 'Cannot call "send" once a close message has been sent.' in message + or 'WebSocket is not connected. Need to call "accept" first.' in message ) + def _raise_websocket_disconnect_from_runtime_error(self, exc: RuntimeError) -> None: + if not self._is_closed_websocket_runtime_error(exc): + raise exc + self._closed = True + raise WebSocketDisconnect() from exc + async def _wait_for_counter( self, *,