diff --git a/stackchan_server/app.py b/stackchan_server/app.py index 46a3006..03d6933 100644 --- a/stackchan_server/app.py +++ b/stackchan_server/app.py @@ -4,7 +4,8 @@ from logging import getLogger from typing import Awaitable, Callable, Optional -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from pydantic import BaseModel from .speech_recognition import create_speech_recognizer from .speech_synthesis import create_speech_synthesizer @@ -14,6 +15,15 @@ logger = getLogger(__name__) +class StackChanInfo(BaseModel): + ip: str + state: str + + +class SpeakRequest(BaseModel): + text: str + + class StackChanApp: def __init__( self, @@ -25,6 +35,7 @@ def __init__( self.fastapi = FastAPI(title="StackChan WebSocket Server") self._setup_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None self._talk_session_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None + self._proxies: dict[str, WsProxy] = {} @self.fastapi.get("/health") async def _health() -> dict[str, str]: @@ -34,6 +45,34 @@ async def _health() -> dict[str, str]: async def _ws_audio(websocket: WebSocket): await self._handle_ws(websocket) + @self.fastapi.get("/v1/stackchan", response_model=list[StackChanInfo]) + async def _list_stackchans(): + return [ + StackChanInfo(ip=ip, state=proxy.current_state.name.lower()) + for ip, proxy in self._proxies.items() + ] + + @self.fastapi.get("/v1/stackchan/{stackchan_ip}", response_model=StackChanInfo) + async def _get_stackchan(stackchan_ip: str): + proxy = self._proxies.get(stackchan_ip) + if proxy is None: + raise HTTPException(status_code=404, detail="stackchan not connected") + return StackChanInfo(ip=stackchan_ip, state=proxy.current_state.name.lower()) + + @self.fastapi.post("/v1/stackchan/{stackchan_ip}/wakeword", status_code=204) + async def _trigger_wakeword(stackchan_ip: str): + proxy = self._proxies.get(stackchan_ip) + if proxy is None: + raise HTTPException(status_code=404, detail="stackchan not connected") + proxy.trigger_wakeword() + + @self.fastapi.post("/v1/stackchan/{stackchan_ip}/speak", status_code=204) + async def _speak(stackchan_ip: str, body: SpeakRequest): + proxy = self._proxies.get(stackchan_ip) + if proxy is None: + raise HTTPException(status_code=404, detail="stackchan not connected") + await proxy.speak(body.text) + def setup(self, fn: Callable[["WsProxy"], Awaitable[None]]): self._setup_fn = fn return fn @@ -44,11 +83,21 @@ def talk_session(self, fn: Callable[["WsProxy"], Awaitable[None]]): async def _handle_ws(self, websocket: WebSocket) -> None: await websocket.accept() + client_ip = websocket.client.host if websocket.client else "unknown" + + # 同一 IP からの既存接続があれば切断する + existing = self._proxies.get(client_ip) + if existing is not None: + logger.info("Duplicate connection from %s, closing old one", client_ip) + await existing.close() + self._proxies.pop(client_ip, None) + proxy = WsProxy( websocket, speech_recognizer=self.speech_recognizer, speech_synthesizer=self.speech_synthesizer, ) + self._proxies[client_ip] = proxy await proxy.start() try: if self._setup_fn: @@ -82,6 +131,7 @@ async def _handle_ws(self, websocket: WebSocket) -> None: pass finally: await proxy.close() + self._proxies.pop(client_ip, None) def run(self, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None: import uvicorn diff --git a/stackchan_server/ws_proxy.py b/stackchan_server/ws_proxy.py index 1b9bd3b..5541b54 100644 --- a/stackchan_server/ws_proxy.py +++ b/stackchan_server/ws_proxy.py @@ -100,15 +100,25 @@ def __init__( self._closed = False self._down_seq = 0 + self._current_firmware_state: FirmwareState = FirmwareState.IDLE @property def closed(self) -> bool: return self._closed + @property + def current_state(self) -> FirmwareState: + return self._current_firmware_state + @property def receive_task(self) -> Optional[asyncio.Task]: return self._receiving_task + def trigger_wakeword(self) -> None: + """Web API から擬似的に WAKEWORD_EVT を発火させる。""" + logger.info("Triggered wakeword via API") + self._wakeword_event.set() + async def wait_for_talk_session(self) -> None: while True: if self._wakeword_event.is_set(): @@ -232,6 +242,7 @@ def _handle_state_event(self, msg_type: int, payload: bytes) -> None: raw_state = int(payload[0]) try: state = FirmwareState(raw_state) + self._current_firmware_state = state logger.info("Received firmware state=%s(%d)", state.name, raw_state) except ValueError: logger.info("Received firmware state=%d", raw_state)