From 0c1a6cecfacb50a0e0b4f0ec4a8b93bc6930d7fe Mon Sep 17 00:00:00 2001 From: aniebietafia Date: Sat, 30 May 2026 23:22:46 +0100 Subject: [PATCH] feat(elevenlabs): integrate ElevenLabs STT & TTS with fallback support - Implement ElevenLabs Scribe STT batch service and real-time streaming WebSocket client. - Implement ElevenLabs TTS batch and stream services using eleven_flash_v2_5. - Update STTWorker and TTSWorker to dispatch ElevenLabs and support primary/fallback providers. - Add active and fallback controls (ACTIVE_STT_PROVIDER, STT_FALLBACK_PROVIDER, etc.) in settings. - Add unit and pipeline integration tests for the ElevenLabs pipeline path. - Resolve mypy, ruff lint, and pytest import paths across modules and tests. Signed-off-by: aniebietafia --- app/core/config.py | 22 +- .../elevenlabs_stt/__init__.py | 13 + .../elevenlabs_stt/config.py | 30 ++ .../elevenlabs_stt/service.py | 166 ++++++++++ .../elevenlabs_stt/streaming.py | 308 ++++++++++++++++++ .../elevenlabs_tts/__init__.py | 8 + .../elevenlabs_tts/config.py | 70 ++++ .../elevenlabs_tts/service.py | 197 +++++++++++ app/services/stt_worker.py | 211 +++++++++++- app/services/tts_worker.py | 75 +++++ pyproject.toml | 1 + tests/external_services/__init__.py | 1 + .../external_services/test_elevenlabs_stt.py | 63 ++++ .../external_services/test_elevenlabs_tts.py | 118 +++++++ tests/test_kafka/test_pipeline.py | 145 +++++++++ 15 files changed, 1423 insertions(+), 5 deletions(-) create mode 100644 app/external_services/elevenlabs_stt/__init__.py create mode 100644 app/external_services/elevenlabs_stt/config.py create mode 100644 app/external_services/elevenlabs_stt/service.py create mode 100644 app/external_services/elevenlabs_stt/streaming.py create mode 100644 app/external_services/elevenlabs_tts/__init__.py create mode 100644 app/external_services/elevenlabs_tts/config.py create mode 100644 app/external_services/elevenlabs_tts/service.py create mode 100644 tests/external_services/__init__.py create mode 100644 tests/external_services/test_elevenlabs_stt.py create mode 100644 tests/external_services/test_elevenlabs_tts.py diff --git a/app/core/config.py b/app/core/config.py index 55d92a5..de32b63 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -66,6 +66,7 @@ class Settings(BaseSettings): DEEPL_API_KEY: str | None = None VOICE_AI_API_KEY: str | None = None OPENAI_API_KEY: str | None = None + ELEVEN_LABS_API_KEY: str | None = None # Google OAuth GOOGLE_CLIENT_ID: str | None = None @@ -101,12 +102,31 @@ class Settings(BaseSettings): DEEPGRAM_TTS_API_URL: str = "https://api.deepgram.com/v1/speak" DEEPGRAM_TTS_MODEL: str = "aura-2-thalia" + # AI Pipeline — STT (ElevenLabs) + ELEVENLABS_STT_API_URL: str = "https://api.elevenlabs.io/v1/speech-to-text" + ELEVENLABS_STT_WS_URL: str = "wss://api.elevenlabs.io/v1/speech-to-text/realtime" + ELEVENLABS_STT_MODEL: str = "scribe_v2" + ELEVENLABS_STT_REALTIME_MODEL: str = "scribe_v2_realtime" + ELEVENLABS_STT_USE_STREAMING: bool = True + + # AI Pipeline — TTS (ElevenLabs) + ELEVENLABS_TTS_API_URL: str = "https://api.elevenlabs.io/v1/text-to-speech" + ELEVENLABS_TTS_MODEL: str = "eleven_flash_v2_5" + ELEVENLABS_TTS_VOICE_ID: str = "JBFqnCBsd6RMkjVDRZzb" + ELEVENLABS_TTS_OUTPUT_FORMAT: str = "pcm_24000" + ELEVENLABS_TTS_USE_STREAMING: bool = True + # AI Pipeline — Audio Settings PIPELINE_AUDIO_SAMPLE_RATE: int = 24000 PIPELINE_AUDIO_ENCODING: str = "linear16" # "linear16" or "opus" - ACTIVE_TTS_PROVIDER: str = "deepgram" # "deepgram", "openai", or "voiceai" + ACTIVE_TTS_PROVIDER: str = ( + "elevenlabs" # "elevenlabs", "deepgram", "openai", or "voiceai" + ) TTS_FALLBACK_PROVIDER: str = "voiceai" # fallback when primary fails TTS_FALLBACK_ENABLED: bool = True # auto-fallback on provider failure + ACTIVE_STT_PROVIDER: str = "deepgram" # "deepgram" or "elevenlabs" + STT_FALLBACK_PROVIDER: str = "elevenlabs" # fallback when primary fails + STT_FALLBACK_ENABLED: bool = True # auto-fallback on provider failure # Mailgun Email Service MAILGUN_API_KEY: str | None = None diff --git a/app/external_services/elevenlabs_stt/__init__.py b/app/external_services/elevenlabs_stt/__init__.py new file mode 100644 index 0000000..700d22b --- /dev/null +++ b/app/external_services/elevenlabs_stt/__init__.py @@ -0,0 +1,13 @@ +"""ElevenLabs Speech-to-Text (scribe_v2 / scribe_v2_realtime) service package.""" + +from app.external_services.elevenlabs_stt.service import ( + ElevenLabsSTTService, + get_elevenlabs_stt_service, +) +from app.external_services.elevenlabs_stt.streaming import ElevenLabsStreamingSTT + +__all__ = [ + "ElevenLabsSTTService", + "ElevenLabsStreamingSTT", + "get_elevenlabs_stt_service", +] diff --git a/app/external_services/elevenlabs_stt/config.py b/app/external_services/elevenlabs_stt/config.py new file mode 100644 index 0000000..481e7b1 --- /dev/null +++ b/app/external_services/elevenlabs_stt/config.py @@ -0,0 +1,30 @@ +"""ElevenLabs STT configuration helpers. + +Provides header generation and utility settings for ElevenLabs Speech-to-Text. +""" + +from app.core.config import settings + + +def get_elevenlabs_stt_headers() -> dict[str, str]: + """Build HTTP headers for ElevenLabs STT API requests. + + Returns: + dict[str, str]: Authentication headers. + """ + if not settings.ELEVEN_LABS_API_KEY: + raise RuntimeError("ELEVEN_LABS_API_KEY is not configured.") + return { + "xi-api-key": settings.ELEVEN_LABS_API_KEY, + } + + +def get_stt_language_code(language: str) -> str | None: + """Standardize language code for ElevenLabs STT. + + ElevenLabs STT supports ISO 639-1 language codes (e.g. "en", "de", "es"). + If language is provided, we extract the base language part (e.g. 'en-US' -> 'en'). + """ + if not language: + return None + return language.split("-")[0].lower() diff --git a/app/external_services/elevenlabs_stt/service.py b/app/external_services/elevenlabs_stt/service.py new file mode 100644 index 0000000..1b16f91 --- /dev/null +++ b/app/external_services/elevenlabs_stt/service.py @@ -0,0 +1,166 @@ +"""ElevenLabs Speech-to-Text service module. + +Wraps the ElevenLabs Scribe STT API for transcription of pre-recorded +audio chunks. +""" + +import logging +import time + +import httpx + +from app.core.circuit_breaker import AsyncCircuitBreaker +from app.core.config import settings +from app.external_services.elevenlabs_stt.config import ( + get_elevenlabs_stt_headers, + get_stt_language_code, +) + +logger = logging.getLogger(__name__) + + +class ElevenLabsSTTService: + """Stateless service for converting audio bytes to text via ElevenLabs Scribe API. + + Provides a centralized client to execute audio transcription calls against + the ElevenLabs speech-to-text endpoint. + """ + + def __init__(self, timeout: float = 30.0) -> None: + self._timeout = timeout + self._client: httpx.AsyncClient | None = None + self._breaker = AsyncCircuitBreaker() + + @property + def client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=self._timeout) + return self._client + + async def transcribe( + self, + audio_bytes: bytes, + *, + language: str = "en", + sample_rate: int = 24000, + encoding: str = "linear16", + ) -> dict: + """Send raw audio to ElevenLabs Speech-to-Text and return transcription results. + + Args: + audio_bytes: Raw audio data. + language: ISO 639-1 language hint for the STT model. + sample_rate: Audio sample rate in Hz. + encoding: Audio encoding format. + + Returns: + A dict with keys ``text``, ``confidence``, + ``detected_language``, and ``latency_ms``. + + Raises: + httpx.HTTPStatusError: On non-2xx responses from ElevenLabs. + """ + headers = get_elevenlabs_stt_headers() + model_id = settings.ELEVENLABS_STT_MODEL or "scribe_v2" + lang_code = get_stt_language_code(language) + + # We must package raw PCM bytes in a WAV container or send with correct mimetype + # ElevenLabs speech-to-text accepts multiple audio formats. For raw PCM, + # uploading with an arbitrary filename like 'audio.raw' with content_type + # 'audio/wav' + # or 'audio/raw' or wrapping it in a simple WAV header is best. + # But wait! If it is raw pcm 24kHz linear16, we can write a simple WAV header + # or we can send it as audio/wav. Let's wrap it in a basic WAV header so the API + # knows sample rate and channel count, preventing transcription errors. + + wav_data = audio_bytes + if encoding.lower() in ("linear16", "raw", "pcm"): + # Construct a basic WAV header + # 44 bytes header for PCM + num_channels = 1 + bytes_per_sample = 2 # 16-bit + byte_rate = sample_rate * num_channels * bytes_per_sample + block_align = num_channels * bytes_per_sample + data_size = len(audio_bytes) + file_size = 36 + data_size + + header = bytearray(44) + header[0:4] = b"RIFF" + header[4:8] = file_size.to_bytes(4, "little") + header[8:12] = b"WAVE" + header[12:16] = b"fmt " + header[16:20] = (16).to_bytes(4, "little") # Subchunk1Size (16 for PCM) + header[20:22] = (1).to_bytes(2, "little") # AudioFormat (1 for PCM) + header[22:24] = num_channels.to_bytes(2, "little") + header[24:28] = sample_rate.to_bytes(4, "little") + header[28:32] = byte_rate.to_bytes(4, "little") + header[32:34] = block_align.to_bytes(2, "little") + header[34:36] = (bytes_per_sample * 8).to_bytes( + 2, "little" + ) # BitsPerSample + header[36:40] = b"data" + header[40:44] = data_size.to_bytes(4, "little") + + wav_data = bytes(header) + audio_bytes + + # Form data for ElevenLabs + files = {"file": ("audio.wav", wav_data, "audio/wav")} + data = { + "model_id": model_id, + } + if lang_code: + data["language_code"] = lang_code + + async def _call() -> httpx.Response: + resp = await self.client.post( + settings.ELEVENLABS_STT_API_URL, + headers=headers, + files=files, + data=data, + ) + resp.raise_for_status() + return resp + + start = time.monotonic() + response = await self._breaker.call(_call) + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("ElevenLabs STT API completed in %.1fms", elapsed_ms) + + resp_json = response.json() + + # ElevenLabs response structure is expected to have 'text', and optionally + # 'language_code' / 'detected_language' + text = resp_json.get("text", "") + detected_language = resp_json.get("language_code", language) + + # Scribe API might return a list of words, let's look for average confidence if + # available, else default to 1.0 + words = resp_json.get("words", []) + confidence = 1.0 + if words: + confidences = [w.get("confidence", 1.0) for w in words if "confidence" in w] + if confidences: + confidence = sum(confidences) / len(confidences) + + return { + "text": text, + "confidence": confidence, + "detected_language": detected_language, + "latency_ms": round(elapsed_ms, 1), + } + + +# ── Module-level singleton ──────────────────────────────────────────── +_stt_service: ElevenLabsSTTService | None = None + + +def get_elevenlabs_stt_service() -> ElevenLabsSTTService: + """Retrieve the singleton instance of the ElevenLabsSTTService. + + Returns: + ElevenLabsSTTService: The service instance. + """ + global _stt_service # noqa: PLW0603 + if _stt_service is None: + _stt_service = ElevenLabsSTTService() + return _stt_service diff --git a/app/external_services/elevenlabs_stt/streaming.py b/app/external_services/elevenlabs_stt/streaming.py new file mode 100644 index 0000000..517302c --- /dev/null +++ b/app/external_services/elevenlabs_stt/streaming.py @@ -0,0 +1,308 @@ +"""ElevenLabs Speech-to-Text WebSocket streaming client. + +Connects to ElevenLabs scribe_v2_realtime WebSocket endpoint using raw websockets +to support real-time audio transcription. +""" + +import asyncio +import base64 +import contextlib +import json +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +import websockets +from websockets.client import WebSocketClientProtocol # type: ignore[attr-defined] + +from app.core.config import settings +from app.external_services.elevenlabs_stt.config import get_stt_language_code + +logger = logging.getLogger(__name__) + +# Maximum number of automatic reconnection attempts before giving up. +_MAX_RECONNECT_ATTEMPTS = 3 + + +class ElevenLabsStreamingSTT: + """Wrapper around ElevenLabs Scribe Real-time WebSocket connection for live STT.""" + + def __init__( + self, + api_key: str, + room_id: str, + user_id: str, + on_transcript: Callable[[str, bool, float], Coroutine[Any, Any, None]], + language: str = "en", + model: str = "scribe_v2_realtime", + sample_rate: int = 24000, + ) -> None: + """Initialize the ElevenLabs streaming client. + + Args: + api_key: The ElevenLabs API key. + room_id: The meeting room identifier. + user_id: The participant user identifier. + on_transcript: Async callback function for transcript results. + Called with parameters (text, is_final, confidence). + language: ISO 639-1 language code. + model: ElevenLabs model name. Defaults to ``"scribe_v2_realtime"``. + sample_rate: Sample rate in Hz. Defaults to ``24000``. + """ + self._api_key = api_key + self.room_id = room_id + self.user_id = user_id + self.language = language + self.model = model + self.sample_rate = sample_rate + self._on_transcript = on_transcript + + self._websocket: WebSocketClientProtocol | None = None + self._listen_task: asyncio.Task | None = None + self._reconnect_task: asyncio.Task | None = None + self._background_tasks: set[asyncio.Task] = set() + + self._connected = False + self._intentional_close = False + self._reconnect_attempts = 0 + self.last_activity = asyncio.get_event_loop().time() + self._session_start_future: asyncio.Future[None] = asyncio.Future() + + async def connect(self) -> None: + """Establish the WebSocket connection to ElevenLabs.""" + logger.info( + "Connecting to ElevenLabs streaming STT for room=%s user=%s lang=%s", + self.room_id, + self.user_id, + self.language, + ) + + lang_code = get_stt_language_code(self.language) or "en" + + # Build query parameters + ws_url = ( + settings.ELEVENLABS_STT_WS_URL + or "wss://api.elevenlabs.io/v1/speech-to-text/realtime" + ) + url = ( + f"{ws_url}?model_id={self.model}" + f"&xi-api-key={self._api_key}" + f"&language_code={lang_code}" + f"&sample_rate={self.sample_rate}" + ) + + self._websocket = await websockets.connect( + url, + ping_interval=20, + ping_timeout=20, + ) + + self._connected = True + self._intentional_close = False + self._reconnect_attempts = 0 + self.last_activity = asyncio.get_event_loop().time() + self._session_start_future = asyncio.Future() + + # Start listening loop as background task + self._listen_task = asyncio.create_task(self._listen_loop()) + + # Wait up to 5s for session_start message from ElevenLabs + try: + await asyncio.wait_for(self._session_start_future, timeout=5.0) + logger.info( + "ElevenLabs STT session started successfully for room=%s user=%s", + self.room_id, + self.user_id, + ) + except TimeoutError: + logger.warning( + "Timeout waiting for ElevenLabs session_start " + "message, proceeding anyway." + ) + + async def send_audio(self, audio_bytes: bytes) -> None: + """Send raw audio bytes to the ElevenLabs WebSocket stream.""" + if not self._connected or not self._websocket: + raise RuntimeError("ElevenLabs STT connection not established") + + self.last_activity = asyncio.get_event_loop().time() + + # ElevenLabs accepts base64-encoded audio chunk JSON messages + base64_audio = base64.b64encode(audio_bytes).decode("utf-8") + message = { + "message_type": "input_audio_chunk", + "audio_base_64": base64_audio, + "commit": False, + } + await self._websocket.send(json.dumps(message)) + + async def close(self) -> None: + """Gracefully close the ElevenLabs WebSocket stream connection.""" + self._intentional_close = True + self._connected = False + + if self._reconnect_task: + self._reconnect_task.cancel() + self._reconnect_task = None + + if self._websocket: + try: + # Send end of stream message + end_message = {"message_type": "end_of_stream"} + await self._websocket.send(json.dumps(end_message)) + # Wait briefly for server to finalize + await asyncio.sleep(0.2) + except Exception as e: + logger.warning( + "Error sending end_of_stream to ElevenLabs for room=%s user=%s: %s", + self.room_id, + self.user_id, + e, + ) + + try: + await self._websocket.close() + except Exception as e: + logger.warning( + "Error closing ElevenLabs WebSocket for room=%s user=%s: %s", + self.room_id, + self.user_id, + e, + ) + self._websocket = None + + if self._listen_task: + self._listen_task.cancel() + self._listen_task = None + + # Await any remaining callback tasks + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + logger.info( + "ElevenLabs streaming STT connection closed for room=%s user=%s", + self.room_id, + self.user_id, + ) + + async def _reconnect(self) -> None: + """Attempt to reconnect with exponential backoff.""" + while self._reconnect_attempts < _MAX_RECONNECT_ATTEMPTS: + self._reconnect_attempts += 1 + backoff = min(2**self._reconnect_attempts, 10) + logger.warning( + "ElevenLabs reconnect attempt %d/%d for room=%s " + "user=%s (backoff=%.1fs)", + self._reconnect_attempts, + _MAX_RECONNECT_ATTEMPTS, + self.room_id, + self.user_id, + backoff, + ) + await asyncio.sleep(backoff) + + try: + # Clean up old connection resources + if self._websocket: + with contextlib.suppress(Exception): + await self._websocket.close() + self._websocket = None + + if self._listen_task: + self._listen_task.cancel() + self._listen_task = None + + # Reconnect + await self.connect() + logger.info( + "ElevenLabs reconnected successfully for room=%s user=%s", + self.room_id, + self.user_id, + ) + return + except Exception as e: + logger.error( + "ElevenLabs reconnect attempt %d failed for room=%s user=%s: %s", + self._reconnect_attempts, + self.room_id, + self.user_id, + e, + ) + + logger.error( + "ElevenLabs reconnection exhausted (%d attempts) for room=%s user=%s", + _MAX_RECONNECT_ATTEMPTS, + self.room_id, + self.user_id, + ) + + async def _listen_loop(self) -> None: + """Receive and process messages from ElevenLabs WebSocket.""" + try: + while self._connected and self._websocket: + message_str = await self._websocket.recv() + self.last_activity = asyncio.get_event_loop().time() + + try: + message = json.loads(message_str) + except json.JSONDecodeError: + logger.warning( + "Failed to decode JSON from ElevenLabs WebSocket: %s", + message_str, + ) + continue + + self._process_message(message) + + except websockets.exceptions.ConnectionClosed as e: + logger.info( + "ElevenLabs STT WebSocket closed: code=%s, reason=%s", e.code, e.reason + ) + self._handle_unexpected_disconnect() + except asyncio.CancelledError: + pass + except Exception as e: + logger.error( + "Error in ElevenLabs STT listen loop for room=%s user=%s: %s", + self.room_id, + self.user_id, + e, + ) + self._handle_unexpected_disconnect() + + def _process_message(self, message: dict[str, Any]) -> None: + """Process decoded message from ElevenLabs real-time STT WebSocket.""" + msg_type = message.get("message_type") + + if msg_type == "session_start": + if not self._session_start_future.done(): + self._session_start_future.set_result(None) + return + + if msg_type in ("partial_transcript", "committed_transcript"): + is_final = msg_type == "committed_transcript" + transcript = message.get("text", "").strip() + + # Estimate confidence or use 1.0/0.5 + confidence = 1.0 if is_final else 0.5 + + if transcript: + # Launch callback in background task + task = asyncio.create_task( + self._on_transcript(transcript, is_final, confidence) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + def _handle_unexpected_disconnect(self) -> None: + """Handle unexpected disconnect by scheduling a reconnection.""" + self._connected = False + if not self._intentional_close: + logger.warning( + "Unexpected ElevenLabs disconnect for room=%s " + "user=%s, scheduling reconnect", + self.room_id, + self.user_id, + ) + self._reconnect_task = asyncio.create_task(self._reconnect()) diff --git a/app/external_services/elevenlabs_tts/__init__.py b/app/external_services/elevenlabs_tts/__init__.py new file mode 100644 index 0000000..839f4ed --- /dev/null +++ b/app/external_services/elevenlabs_tts/__init__.py @@ -0,0 +1,8 @@ +"""ElevenLabs TTS service package.""" + +from app.external_services.elevenlabs_tts.service import ( + ElevenLabsTTSService, + get_elevenlabs_tts_service, +) + +__all__ = ["ElevenLabsTTSService", "get_elevenlabs_tts_service"] diff --git a/app/external_services/elevenlabs_tts/config.py b/app/external_services/elevenlabs_tts/config.py new file mode 100644 index 0000000..4ce97ca --- /dev/null +++ b/app/external_services/elevenlabs_tts/config.py @@ -0,0 +1,70 @@ +"""ElevenLabs TTS configuration helpers. + +Provides header generation and language mapping for the ElevenLabs TTS API. +""" + +from app.core.config import settings + +# ElevenLabs language mappings for multilingual models +# Maps ISO 639-1 code to ElevenLabs language codes. +_LANGUAGE_MAP: dict[str, str] = { + "ar": "ar", + "bg": "bg", + "zh": "cmn", + "hr": "hr", + "cs": "cs", + "da": "da", + "nl": "nl", + "en": "en", + "fil": "fil", + "fi": "fi", + "fr": "fr", + "de": "de", + "el": "el", + "hi": "hi", + "id": "id", + "it": "it", + "ja": "ja", + "ko": "ko", + "ms": "ms", + "pl": "pl", + "pt": "pt", + "ro": "ro", + "ru": "ru", + "sk": "sk", + "es": "es", + "sv": "sv", + "ta": "ta", + "tr": "tr", + "uk": "uk", +} + +_DEFAULT_LANGUAGE = "en" + + +def get_elevenlabs_tts_headers() -> dict[str, str]: + """Build HTTP headers for ElevenLabs TTS API requests. + + Returns: + dict[str, str]: Authorization and content-type headers. + """ + return { + "xi-api-key": settings.ELEVEN_LABS_API_KEY or "", + "Content-Type": "application/json", + } + + +def get_language_code(language: str) -> str: + """Resolve an ISO 639-1 language code to ElevenLabs language code. + + Args: + language: ISO 639-1 language code (e.g. 'en', 'zh'). + + Returns: + str: The ElevenLabs language code. + """ + if not language: + return _DEFAULT_LANGUAGE + # Extract prefix if language has a locale (e.g., 'en-US' -> 'en') + base_lang = language.split("-")[0].lower() + return _LANGUAGE_MAP.get(base_lang, _DEFAULT_LANGUAGE) diff --git a/app/external_services/elevenlabs_tts/service.py b/app/external_services/elevenlabs_tts/service.py new file mode 100644 index 0000000..e0b27ec --- /dev/null +++ b/app/external_services/elevenlabs_tts/service.py @@ -0,0 +1,197 @@ +"""ElevenLabs TTS (eleven_flash_v2_5) Text-to-Speech service module. + +Wraps the ElevenLabs TTS API to convert translated text into synthesized audio. +Supports multilingual voices and raw PCM output. +""" + +import logging +import time +from collections.abc import AsyncGenerator + +import httpx + +from app.core.circuit_breaker import AsyncCircuitBreaker +from app.core.config import settings +from app.external_services.elevenlabs_tts.config import ( + get_elevenlabs_tts_headers, + get_language_code, +) + +logger = logging.getLogger(__name__) + + +class ElevenLabsTTSService: + """Stateless service for converting text to speech via ElevenLabs. + + Provides both batch and streaming synthesis methods. Uses the configured + ``ELEVEN_LABS_API_KEY`` from Settings. + """ + + def __init__(self, timeout: float = 60.0) -> None: + self._timeout = timeout + self._client: httpx.AsyncClient | None = None + self._breaker = AsyncCircuitBreaker() + + @property + def client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=self._timeout) + return self._client + + async def synthesize( + self, + text: str, + *, + language: str = "en", + encoding: str = "linear16", # default is linear16/pcm + ) -> dict: + """Convert text to audio bytes via ElevenLabs TTS. + + Args: + text: The text to synthesize. + language: ISO 639-1 language code for voice configuration. + Defaults to ``"en"``. + encoding: Output encoding. Defaults to ``"linear16"``. + + Returns: + dict: A dictionary containing ``audio_bytes``, ``sample_rate``, + and ``latency_ms``. + """ + headers = get_elevenlabs_tts_headers() + voice_id = settings.ELEVENLABS_TTS_VOICE_ID or "JBFqnCBsd6RMkjVDRZzb" + model_id = settings.ELEVENLABS_TTS_MODEL or "eleven_flash_v2_5" + lang_code = get_language_code(language) + + logger.debug("Synthesizing text with encoding %s", encoding) + + # ElevenLabs accepts pcm_24000 for 24kHz raw PCM + # Note: If encoding is not linear16, we could support other formats, + # but the pipeline expects 24kHz raw PCM for linear16. + output_format = settings.ELEVENLABS_TTS_OUTPUT_FORMAT or "pcm_24000" + + url = f"{settings.ELEVENLABS_TTS_API_URL.rstrip('/')}/{voice_id}" + params = {"output_format": output_format} + + payload = { + "text": text, + "model_id": model_id, + "language_code": lang_code, + "voice_settings": { + "stability": 0.5, + "similarity_boost": 0.75, + "style": 0.0, + "use_speaker_boost": True, + }, + } + + async def _call() -> httpx.Response: + resp = await self.client.post( + url, + headers=headers, + json=payload, + params=params, + ) + resp.raise_for_status() + return resp + + start = time.monotonic() + response = await self._breaker.call(_call) + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("ElevenLabs TTS API completed in %.1fms", elapsed_ms) + + sample_rate = 24000 + if "24000" in output_format: + sample_rate = 24000 + elif "16000" in output_format: + sample_rate = 16000 + elif "44100" in output_format: + sample_rate = 44100 + + return { + "audio_bytes": response.content, + "sample_rate": sample_rate, + "latency_ms": round(elapsed_ms, 1), + } + + async def synthesize_stream( + self, + text: str, + *, + language: str = "en", + encoding: str = "linear16", + ) -> AsyncGenerator[dict, None]: + """Stream TTS audio chunks via ElevenLabs streaming endpoint. + + Args: + text: The text to synthesize. + language: ISO 639-1 language code. Defaults to ``"en"``. + encoding: Output encoding. Defaults to ``"linear16"``. + + Yields: + dict: A dictionary containing ``audio_bytes`` and ``sample_rate``. + """ + headers = get_elevenlabs_tts_headers() + voice_id = settings.ELEVENLABS_TTS_VOICE_ID or "JBFqnCBsd6RMkjVDRZzb" + model_id = settings.ELEVENLABS_TTS_MODEL or "eleven_flash_v2_5" + lang_code = get_language_code(language) + + logger.debug("Synthesizing stream with encoding %s", encoding) + output_format = settings.ELEVENLABS_TTS_OUTPUT_FORMAT or "pcm_24000" + + url = f"{settings.ELEVENLABS_TTS_API_URL.rstrip('/')}/{voice_id}/stream" + params = {"output_format": output_format} + + payload = { + "text": text, + "model_id": model_id, + "language_code": lang_code, + "voice_settings": { + "stability": 0.5, + "similarity_boost": 0.75, + "style": 0.0, + "use_speaker_boost": True, + }, + } + + sample_rate = 24000 + if "24000" in output_format: + sample_rate = 24000 + elif "16000" in output_format: + sample_rate = 16000 + elif "44100" in output_format: + sample_rate = 44100 + + start = time.monotonic() + async with self.client.stream( + "POST", + url, + headers=headers, + json=payload, + params=params, + ) as response: + response.raise_for_status() + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("ElevenLabs TTS Stream initiated in %.1fms", elapsed_ms) + + async for chunk in response.aiter_bytes(chunk_size=4096): + if chunk: + yield { + "audio_bytes": chunk, + "sample_rate": sample_rate, + } + + +# ── Module-level singleton ──────────────────────────────────────────── +_tts_service: ElevenLabsTTSService | None = None + + +def get_elevenlabs_tts_service() -> ElevenLabsTTSService: + """Retrieve the singleton instance of the ElevenLabsTTSService. + + Returns: + ElevenLabsTTSService: The service instance. + """ + global _tts_service # noqa: PLW0603 + if _tts_service is None: + _tts_service = ElevenLabsTTSService() + return _tts_service diff --git a/app/services/stt_worker.py b/app/services/stt_worker.py index 2be4c9c..f7579fc 100644 --- a/app/services/stt_worker.py +++ b/app/services/stt_worker.py @@ -11,6 +11,7 @@ from typing import Any from app.external_services.deepgram.service import get_deepgram_stt_service +from app.external_services.elevenlabs_stt.service import get_elevenlabs_stt_service from app.kafka.consumer import BaseConsumer from app.kafka.schemas import BaseEvent from app.kafka.topics import AUDIO_RAW, TEXT_ORIGINAL @@ -85,12 +86,214 @@ async def handle(self, event: BaseEvent[Any]) -> None: from app.core.config import settings - use_streaming = settings.DEEPGRAM_USE_STREAMING and settings.DEEPGRAM_API_KEY + provider = settings.ACTIVE_STT_PROVIDER.lower() + try: + await self._dispatch_stt_provider( + provider, payload, buffer_key, audio_bytes + ) + except Exception as primary_err: + if not settings.STT_FALLBACK_ENABLED: + raise + fallback = settings.STT_FALLBACK_PROVIDER.lower() + if fallback == provider: + raise + logger.warning( + "Primary STT provider '%s' failed: %s. Falling back to '%s'.", + provider, + primary_err, + fallback, + ) + await self._dispatch_stt_provider( + fallback, payload, buffer_key, audio_bytes + ) + + async def _dispatch_stt_provider( + self, provider: str, payload: Any, buffer_key: str, audio_bytes: bytes + ) -> None: + from app.core.config import settings + + if provider == "elevenlabs": + use_streaming = settings.ELEVENLABS_STT_USE_STREAMING and bool( + settings.ELEVEN_LABS_API_KEY + ) + if use_streaming: + await self._handle_elevenlabs_streaming( + payload, buffer_key, audio_bytes + ) + else: + await self._handle_elevenlabs_batch(payload, buffer_key, audio_bytes) + else: # deepgram (default) + use_streaming = bool( + settings.DEEPGRAM_USE_STREAMING and settings.DEEPGRAM_API_KEY + ) + if use_streaming: + await self._handle_streaming(payload, buffer_key, audio_bytes) + else: + await self._handle_batch(payload, buffer_key, audio_bytes) + + async def _handle_elevenlabs_streaming( + self, payload: Any, buffer_key: str, audio_bytes: bytes + ) -> None: + """Stream raw audio chunks to ElevenLabs WebSocket.""" + from app.core.config import settings + + conn = self._streaming_connections.get(buffer_key) + try: + if not conn: + + async def on_transcript( + transcript_text: str, is_final: bool, confidence: float + ) -> None: + await self._on_streaming_transcript( + payload, buffer_key, transcript_text, is_final, confidence + ) + + from app.external_services.elevenlabs_stt.streaming import ( + ElevenLabsStreamingSTT, + ) + + if not settings.ELEVEN_LABS_API_KEY: + raise ValueError( + "ELEVEN_LABS_API_KEY must be set for ElevenLabs streaming STT" + ) + + conn = ElevenLabsStreamingSTT( + api_key=settings.ELEVEN_LABS_API_KEY, + room_id=payload.room_id, + user_id=payload.user_id, + on_transcript=on_transcript, + language=payload.source_language, + model=settings.ELEVENLABS_STT_REALTIME_MODEL, + sample_rate=payload.sample_rate, + ) + self._streaming_connections[buffer_key] = conn + await conn.connect() - if use_streaming: - await self._handle_streaming(payload, buffer_key, audio_bytes) + await conn.send_audio(audio_bytes) + except Exception as e: + logger.error( + "Error in ElevenLabs streaming connection for %s: %s", + buffer_key, + e, + ) + if conn: + task = asyncio.create_task(conn.close()) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + self._streaming_connections.pop(buffer_key, None) + + async def _handle_elevenlabs_batch( + self, payload: Any, buffer_key: str, audio_bytes: bytes + ) -> None: + """Buffer raw audio chunks and call ElevenLabs batch transcription.""" + from app.core.config import settings + + pipeline_start = time.monotonic() + + if buffer_key not in self._audio_buffers: + self._audio_buffers[buffer_key] = [] + + self._audio_buffers[buffer_key].append(audio_bytes) + + if len(self._audio_buffers[buffer_key]) < self.BUFFER_SIZE: + return + + full_audio = b"".join(self._audio_buffers[buffer_key]) + self._audio_buffers[buffer_key] = [] + + if not settings.ELEVEN_LABS_API_KEY: + logger.info("ELEVEN_LABS_API_KEY not set. Mocking ElevenLabs STT response.") + result: dict[str, Any] = { + "text": "Hello, this is a simulated ElevenLabs transcription.", + "detected_language": payload.source_language, + "confidence": 1.0, + } else: - await self._handle_batch(payload, buffer_key, audio_bytes) + stt_service = get_elevenlabs_stt_service() + result = await stt_service.transcribe( + full_audio, + language=payload.source_language, + sample_rate=payload.sample_rate, + encoding=payload.encoding.value, + ) + + text = result.get("text", "").strip() + if not text: + return + + transcription_payload = TranscriptionPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + text=text, + source_language=result.get("detected_language", payload.source_language), + is_final=True, + confidence=result.get("confidence", 0.0), + ) + transcription_event = TranscriptionEvent(payload=transcription_payload) + + await self._producer.send( + TEXT_ORIGINAL, transcription_event, key=payload.room_id + ) + + try: + from app.services.connection_manager import get_connection_manager + + manager = get_connection_manager() + task = asyncio.create_task( + manager.broadcast_to_room( + payload.room_id, + { + "type": "active_speaker_changed", + "user_id": payload.user_id, + }, + ) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + except Exception as e: + logger.error("Failed to broadcast active speaker: %s", e) + + try: + import json as _json + + from app.modules.auth.token_store import _get_redis_client + + participants = await self._state.get_participants(payload.room_id) + speaker_name = participants.get(payload.user_id, {}).get( + "display_name", "Speaker" + ) + + redis = _get_redis_client() + caption_msg = { + "event": "caption", + "type": "original", + "speaker_id": payload.user_id, + "speaker_name": speaker_name, + "text": text, + "source_language": transcription_payload.source_language, + "is_final": True, + "sequence_number": payload.sequence_number, + "timestamp_ms": int(time.time() * 1000), + } + await redis.publish( + f"pipeline:captions:{payload.room_id}", + _json.dumps(caption_msg), + ) + except Exception as redis_err: + logger.warning("Redis caption publish failed: %s", redis_err) + + elapsed_ms = (time.monotonic() - pipeline_start) * 1000 + logger.info( + "STT (ElevenLabs Batch): seq=%d room=%s user=%s " + "text='%s' confidence=%.2f latency=%.1fms", + payload.sequence_number, + payload.room_id, + payload.user_id, + text, + result.get("confidence", 0.0), + elapsed_ms, + ) async def _handle_streaming( self, payload: Any, buffer_key: str, audio_bytes: bytes diff --git a/app/services/tts_worker.py b/app/services/tts_worker.py index a835e72..1a64daf 100644 --- a/app/services/tts_worker.py +++ b/app/services/tts_worker.py @@ -16,6 +16,7 @@ from app.core.config import settings from app.external_services.deepgram_tts.service import get_deepgram_tts_service +from app.external_services.elevenlabs_tts.service import get_elevenlabs_tts_service from app.external_services.openai_tts.service import get_openai_tts_service from app.external_services.voiceai.service import get_voiceai_tts_service from app.external_services.voiceai.websocket_streaming import get_voiceai_ws_tts_service @@ -115,6 +116,12 @@ async def _dispatch_provider( pipeline_start: float, ) -> None: """Route synthesis to the specified provider.""" + if provider == "elevenlabs" and settings.ELEVENLABS_TTS_USE_STREAMING: + await self._handle_elevenlabs_streaming( + payload, text, encoding, pipeline_start + ) + return + use_ws = provider == "voiceai" and settings.VOICEAI_USE_WEBSOCKET use_streaming = ( provider == "voiceai" and settings.VOICEAI_USE_STREAMING and not use_ws @@ -264,6 +271,69 @@ async def _handle_http_streaming( elapsed_ms, ) + async def _handle_elevenlabs_streaming( + self, + payload: Any, + text: str, + encoding: str, + pipeline_start: float, + ) -> None: + """Handle ElevenLabs HTTP streaming path.""" + accumulated_bytes = bytearray() + sample_rate = 24000 + + async for chunk_data in get_elevenlabs_tts_service().synthesize_stream( + text=text, + language=payload.target_language, + encoding=encoding, + ): + chunk_bytes = chunk_data["audio_bytes"] + sample_rate = chunk_data["sample_rate"] + accumulated_bytes.extend(chunk_bytes) + + chunk_b64 = base64.b64encode(chunk_bytes).decode("ascii") + synth_payload = SynthesizedAudioPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + audio_data=chunk_b64, + target_language=payload.target_language, + sample_rate=sample_rate, + encoding=AudioEncoding(encoding), + ) + synth_event = SynthesizedAudioEvent(payload=synth_payload) + try: + await self._publish_audio_to_redis(synth_event) + except Exception as redis_err: + logger.warning("Redis audio egress publish failed: %s", redis_err) + + if accumulated_bytes: + full_audio_b64 = base64.b64encode(accumulated_bytes).decode("ascii") + final_payload = SynthesizedAudioPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + audio_data=full_audio_b64, + target_language=payload.target_language, + sample_rate=sample_rate, + encoding=AudioEncoding(encoding), + ) + final_event = SynthesizedAudioEvent(payload=final_payload) + await self._producer.send( + AUDIO_SYNTHESIZED, final_event, key=payload.room_id + ) + + elapsed_ms = (time.monotonic() - pipeline_start) * 1000 + logger.info( + "TTS (ElevenLabs Stream Final): seq=%d room=%s lang=%s " + "provider=elevenlabs audio_size=%d latency=%.1fms", + payload.sequence_number, + payload.room_id, + payload.target_language, + len(accumulated_bytes), + elapsed_ms, + ) + async def _handle_batch_synthesis( self, payload: Any, @@ -345,6 +415,11 @@ async def _synthesize( """ provider = (provider or settings.ACTIVE_TTS_PROVIDER).lower() + if provider == "elevenlabs": + return await get_elevenlabs_tts_service().synthesize( + text, language=language, encoding=encoding + ) + if provider == "deepgram": return await get_deepgram_tts_service().synthesize( text, language=language, encoding=encoding diff --git a/pyproject.toml b/pyproject.toml index d127e0c..dc3de21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,3 +227,4 @@ ignore_missing_imports = true [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" +pythonpath = ["."] diff --git a/tests/external_services/__init__.py b/tests/external_services/__init__.py new file mode 100644 index 0000000..5cc763a --- /dev/null +++ b/tests/external_services/__init__.py @@ -0,0 +1 @@ +"""Unit tests for external services.""" diff --git a/tests/external_services/test_elevenlabs_stt.py b/tests/external_services/test_elevenlabs_stt.py new file mode 100644 index 0000000..38d29c8 --- /dev/null +++ b/tests/external_services/test_elevenlabs_stt.py @@ -0,0 +1,63 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.core.config import settings +from app.external_services.elevenlabs_stt.config import get_stt_language_code +from app.external_services.elevenlabs_stt.service import get_elevenlabs_stt_service + + +def test_elevenlabs_stt_language_mapping(): + assert get_stt_language_code("en-US") == "en" + assert get_stt_language_code("de-DE") == "de" + assert get_stt_language_code("fr") == "fr" + assert get_stt_language_code(None) is None + + +@pytest.mark.asyncio +async def test_elevenlabs_stt_transcribe(): + # Setup + settings.ELEVEN_LABS_API_KEY = "test_stt_key" + settings.ELEVENLABS_STT_MODEL = "scribe_v2" + + response_data = { + "text": "Hello this is Scribe transcribing.", + "language_code": "en", + "words": [ + {"text": "Hello", "start": 0.0, "end": 0.5, "confidence": 0.98}, + {"text": "this", "start": 0.5, "end": 0.8, "confidence": 0.99}, + ], + } + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=response_data) + mock_response.raise_for_status = MagicMock() + + service = get_elevenlabs_stt_service() + + with patch.object(service.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + # 24000Hz PCM audio chunk: + # (24000 samples/sec * 2 bytes/sample * 0.1s = 4800 bytes) + fake_pcm = b"\x00" * 4800 + + result = await service.transcribe( + fake_pcm, + language="en", + sample_rate=24000, + encoding="linear16", + ) + + # Assertions + mock_post.assert_called_once() + _args, kwargs = mock_post.call_args + assert kwargs["headers"]["xi-api-key"] == "test_stt_key" + + assert result["text"] == "Hello this is Scribe transcribing." + assert result["detected_language"] == "en" + # confidence is average of words: (0.98 + 0.99) / 2 = 0.985 + assert result["confidence"] == 0.985 + assert "latency_ms" in result diff --git a/tests/external_services/test_elevenlabs_tts.py b/tests/external_services/test_elevenlabs_tts.py new file mode 100644 index 0000000..116f689 --- /dev/null +++ b/tests/external_services/test_elevenlabs_tts.py @@ -0,0 +1,118 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.core.config import settings +from app.external_services.elevenlabs_tts.config import get_language_code +from app.external_services.elevenlabs_tts.service import get_elevenlabs_tts_service + + +def test_elevenlabs_tts_language_mapping(): + assert get_language_code("en") == "en" + assert get_language_code("en-US") == "en" + assert get_language_code("zh") == "cmn" + assert get_language_code("zh-CN") == "cmn" + assert get_language_code("de") == "de" + assert get_language_code("unknown") == "en" + assert get_language_code(None) == "en" + + +@pytest.mark.asyncio +async def test_elevenlabs_tts_synthesize(): + # Setup + settings.ELEVEN_LABS_API_KEY = "test_key" + settings.ELEVENLABS_TTS_VOICE_ID = "voice_abc" + settings.ELEVENLABS_TTS_MODEL = "eleven_flash_v2_5" + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.content = b"audio_bytes_123" + mock_response.raise_for_status = MagicMock() + + service = get_elevenlabs_tts_service() + + # We patch the httpx client's post method + with patch.object(service.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + result = await service.synthesize("Hello world", language="en") + + # Assertions + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + assert "voice_abc" in args[0] + assert kwargs["headers"]["xi-api-key"] == "test_key" + assert kwargs["params"]["output_format"] == "pcm_24000" + + assert result["audio_bytes"] == b"audio_bytes_123" + assert result["sample_rate"] == 24000 + assert "latency_ms" in result + + +@pytest.mark.asyncio +async def test_elevenlabs_tts_synthesize_stream(): + # Setup + settings.ELEVEN_LABS_API_KEY = "test_key" + settings.ELEVENLABS_TTS_VOICE_ID = "voice_abc" + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + # Mock aiter_bytes async generator + async def mock_aiter_bytes(chunk_size=None): + _ = chunk_size + yield b"stream_chunk_123" + + mock_response.aiter_bytes = mock_aiter_bytes + + service = get_elevenlabs_tts_service() + + # Mock the client stream context manager + class MockStreamContext: + async def __aenter__(self): + return mock_response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + with patch.object( + service.client, "stream", return_value=MockStreamContext() + ) as mock_stream: + chunks = [] + async for chunk_data in service.synthesize_stream( + "Hello stream", language="en" + ): + chunks.append(chunk_data["audio_bytes"]) + assert chunk_data["sample_rate"] == 24000 + + mock_stream.assert_called_once() + assert b"".join(chunks) == b"stream_chunk_123" + + +@pytest.mark.asyncio +async def test_elevenlabs_tts_circuit_breaker(): + # Setup + settings.ELEVEN_LABS_API_KEY = "test_key" + settings.ELEVENLABS_TTS_VOICE_ID = "voice_abc" + + service = get_elevenlabs_tts_service() + # Reset breaker state + service._breaker.failure_count = 0 + service._breaker.state = "closed" + + # Mock post to raise HTTPStatusError + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 500 + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "500 Internal Server Error", request=None, response=mock_response + ) + ) + + with patch.object(service.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + await service.synthesize("Hello fails", language="en") diff --git a/tests/test_kafka/test_pipeline.py b/tests/test_kafka/test_pipeline.py index eb4254d..a728ec7 100644 --- a/tests/test_kafka/test_pipeline.py +++ b/tests/test_kafka/test_pipeline.py @@ -320,3 +320,148 @@ async def mock_generator(*_args, **_kwargs): decoded = base64.b64decode(synth_event.payload.audio_data) assert decoded == b"chunk1chunk2" + + +@pytest.mark.asyncio +async def test_tts_worker_handle_elevenlabs(mock_producer, base_translation_event): + worker = TTSWorker(producer=mock_producer) + + with ( + patch( + "app.services.tts_worker.get_elevenlabs_tts_service" + ) as mock_get_elevenlabs, + patch("app.services.tts_worker.settings") as mock_settings, + patch("app.modules.auth.token_store._get_redis_client") as mock_get_redis, + ): + mock_settings.ACTIVE_TTS_PROVIDER = "elevenlabs" + mock_settings.ELEVENLABS_TTS_USE_STREAMING = False + mock_settings.PIPELINE_AUDIO_ENCODING = "linear16" + + mock_elevenlabs = AsyncMock() + mock_elevenlabs.synthesize.return_value = { + "audio_bytes": b"elevenlabs_batch_bytes", + "sample_rate": 24000, + } + mock_get_elevenlabs.return_value = mock_elevenlabs + + redis_mock = MagicMock() + redis_mock.publish = AsyncMock() + mock_get_redis.return_value = redis_mock + + await worker.handle(base_translation_event) + + mock_elevenlabs.synthesize.assert_called_once_with( + "Bonjour le monde", + language="fr", + encoding="linear16", + ) + + mock_producer.send.assert_called_once() + args, _kwargs = mock_producer.send.call_args + assert args[0] == "audio.synthesized" + + synth_event = args[1] + assert synth_event.payload.sample_rate == 24000 + decoded = base64.b64decode(synth_event.payload.audio_data) + assert decoded == b"elevenlabs_batch_bytes" + + +@pytest.mark.asyncio +async def test_tts_worker_handle_elevenlabs_streaming( + mock_producer, base_translation_event +): + worker = TTSWorker(producer=mock_producer) + + with ( + patch( + "app.services.tts_worker.get_elevenlabs_tts_service" + ) as mock_get_elevenlabs, + patch("app.services.tts_worker.settings") as mock_settings, + patch("app.modules.auth.token_store._get_redis_client") as mock_get_redis, + ): + mock_settings.ACTIVE_TTS_PROVIDER = "elevenlabs" + mock_settings.ELEVENLABS_TTS_USE_STREAMING = True + mock_settings.PIPELINE_AUDIO_ENCODING = "linear16" + + mock_elevenlabs = AsyncMock() + + async def mock_generator(*_args, **_kwargs): + yield {"audio_bytes": b"el_chunk1", "sample_rate": 24000} + yield {"audio_bytes": b"el_chunk2", "sample_rate": 24000} + + mock_elevenlabs.synthesize_stream = mock_generator + mock_get_elevenlabs.return_value = mock_elevenlabs + + redis_mock = MagicMock() + redis_mock.publish = AsyncMock() + mock_get_redis.return_value = redis_mock + + await worker.handle(base_translation_event) + + assert redis_mock.publish.call_count == 2 + + mock_producer.send.assert_called_once() + args, _kwargs = mock_producer.send.call_args + assert args[0] == "audio.synthesized" + + synth_event = args[1] + assert synth_event.payload.sample_rate == 24000 + decoded = base64.b64decode(synth_event.payload.audio_data) + assert decoded == b"el_chunk1el_chunk2" + + +@pytest.mark.asyncio +async def test_stt_worker_handle_elevenlabs_batch( + mock_producer, base_audio_chunk_event +): + worker = STTWorker(producer=mock_producer) + + with ( + patch( + "app.services.stt_worker.get_elevenlabs_stt_service" + ) as mock_get_elevenlabs_stt, + patch("app.core.config.settings") as mock_settings, + patch("app.services.connection_manager.get_connection_manager") as mock_get_cm, + patch("app.modules.auth.token_store._get_redis_client") as mock_get_redis, + ): + mock_settings.ACTIVE_STT_PROVIDER = "elevenlabs" + mock_settings.ELEVEN_LABS_API_KEY = "fake-key" + mock_settings.ELEVENLABS_STT_USE_STREAMING = False + mock_settings.STT_FALLBACK_ENABLED = True + mock_settings.STT_FALLBACK_PROVIDER = "deepgram" + + mock_stt_svc = AsyncMock() + mock_stt_svc.transcribe.return_value = { + "text": "Hello ElevenLabs Scribe", + "confidence": 0.97, + "detected_language": "en", + } + mock_get_elevenlabs_stt.return_value = mock_stt_svc + + redis_mock = MagicMock() + redis_mock.publish = AsyncMock() + mock_get_redis.return_value = redis_mock + + cm_mock = MagicMock() + cm_mock.broadcast_to_room = AsyncMock() + mock_get_cm.return_value = cm_mock + + mock_state = AsyncMock() + mock_state.get_participants.return_value = { + "user456": {"display_name": "Speaker Name"} + } + worker._state = mock_state + + for _ in range(STTWorker.BUFFER_SIZE): + await worker.handle(base_audio_chunk_event) + + mock_stt_svc.transcribe.assert_called_once_with( + b"fake_audio" * STTWorker.BUFFER_SIZE, + language="en", + sample_rate=16000, + encoding="linear16", + ) + mock_producer.send.assert_called_once() + args, _ = mock_producer.send.call_args + assert args[0] == "text.original" + assert args[1].payload.text == "Hello ElevenLabs Scribe"