Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion stackchan_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from logging import getLogger
from typing import Awaitable, Callable, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from .speech_recognition import create_speech_recognizer
from .speech_synthesis import create_speech_synthesizer
Expand All @@ -14,6 +15,15 @@
logger = getLogger(__name__)


class StackChanInfo(BaseModel):
    """Response model describing one StackChan device known to the server."""

    # Key the device's proxy is registered under: the WebSocket client's host,
    # or "unknown" when the client address is not available.
    ip: str
    # Lower-cased name of the firmware state last recorded for the device.
    state: str


class SpeakRequest(BaseModel):
    """Request body accepted by the ``/speak`` endpoint."""

    # Text handed to the target proxy's ``speak()`` call.
    text: str


class StackChanApp:
def __init__(
self,
Expand All @@ -25,6 +35,7 @@ def __init__(
self.fastapi = FastAPI(title="StackChan WebSocket Server")
self._setup_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None
self._talk_session_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None
self._proxies: dict[str, WsProxy] = {}

@self.fastapi.get("/health")
async def _health() -> dict[str, str]:
Expand All @@ -34,6 +45,34 @@ async def _health() -> dict[str, str]:
async def _ws_audio(websocket: WebSocket):
await self._handle_ws(websocket)

@self.fastapi.get("/v1/stackchan", response_model=list[StackChanInfo])
async def _list_stackchans():
    """Return one ``StackChanInfo`` entry per currently registered proxy."""
    infos = []
    for ip, proxy in self._proxies.items():
        # The API exposes the enum member name in lower case.
        state_name = proxy.current_state.name.lower()
        infos.append(StackChanInfo(ip=ip, state=state_name))
    return infos

@self.fastapi.get("/v1/stackchan/{stackchan_ip}", response_model=StackChanInfo)
async def _get_stackchan(stackchan_ip: str):
    """Return the info for a single device.

    Raises:
        HTTPException: 404 when no proxy is registered under ``stackchan_ip``.
    """
    if (proxy := self._proxies.get(stackchan_ip)) is None:
        raise HTTPException(status_code=404, detail="stackchan not connected")
    state_name = proxy.current_state.name.lower()
    return StackChanInfo(ip=stackchan_ip, state=state_name)

@self.fastapi.post("/v1/stackchan/{stackchan_ip}/wakeword", status_code=204)
async def _trigger_wakeword(stackchan_ip: str):
    """Fire a wakeword on the addressed device.

    Raises:
        HTTPException: 404 when no proxy is registered under ``stackchan_ip``.
    """
    target = self._proxies.get(stackchan_ip)
    if target is not None:
        target.trigger_wakeword()
        return
    raise HTTPException(status_code=404, detail="stackchan not connected")

@self.fastapi.post("/v1/stackchan/{stackchan_ip}/speak", status_code=204)
async def _speak(stackchan_ip: str, body: SpeakRequest):
    """Ask the addressed device's proxy to speak the text in the request body.

    Raises:
        HTTPException: 404 when no proxy is registered under ``stackchan_ip``.
    """
    target = self._proxies.get(stackchan_ip)
    if target is None:
        raise HTTPException(status_code=404, detail="stackchan not connected")
    text = body.text
    # Awaited, so the handler returns only after speak() has completed.
    await target.speak(text)

def setup(self, fn: Callable[["WsProxy"], Awaitable[None]]):
self._setup_fn = fn
return fn
Expand All @@ -44,11 +83,21 @@ def talk_session(self, fn: Callable[["WsProxy"], Awaitable[None]]):

async def _handle_ws(self, websocket: WebSocket) -> None:
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown"

# Close any existing connection from the same IP before registering the new one.
existing = self._proxies.get(client_ip)
if existing is not None:
logger.info("Duplicate connection from %s, closing old one", client_ip)
await existing.close()
self._proxies.pop(client_ip, None)

proxy = WsProxy(
websocket,
speech_recognizer=self.speech_recognizer,
speech_synthesizer=self.speech_synthesizer,
)
self._proxies[client_ip] = proxy
await proxy.start()
try:
if self._setup_fn:
Expand Down Expand Up @@ -82,6 +131,7 @@ async def _handle_ws(self, websocket: WebSocket) -> None:
pass
finally:
await proxy.close()
self._proxies.pop(client_ip, None)

def run(self, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
import uvicorn
Expand Down
11 changes: 11 additions & 0 deletions stackchan_server/ws_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,25 @@ def __init__(
self._closed = False

self._down_seq = 0
self._current_firmware_state: FirmwareState = FirmwareState.IDLE

@property
def closed(self) -> bool:
    """Read-only view of the proxy's ``_closed`` flag."""
    return self._closed

@property
def current_state(self) -> FirmwareState:
    """Firmware state most recently recorded on this proxy (read-only)."""
    return self._current_firmware_state

@property
def receive_task(self) -> Optional[asyncio.Task]:
    """Read-only access to ``_receiving_task``; may be ``None``."""
    return self._receiving_task

def trigger_wakeword(self) -> None:
    """Fire a simulated WAKEWORD_EVT on behalf of the Web API."""
    logger.info("Triggered wakeword via API")
    # wait_for_talk_session() checks this event.
    wakeword_event = self._wakeword_event
    wakeword_event.set()

async def wait_for_talk_session(self) -> None:
while True:
if self._wakeword_event.is_set():
Expand Down Expand Up @@ -232,6 +242,7 @@ def _handle_state_event(self, msg_type: int, payload: bytes) -> None:
raw_state = int(payload[0])
try:
state = FirmwareState(raw_state)
self._current_firmware_state = state
logger.info("Received firmware state=%s(%d)", state.name, raw_state)
except ValueError:
logger.info("Received firmware state=%d", raw_state)
Expand Down
Loading