async def ensure_backend(self) -> None:
    """Make sure a usable backend exists, reloading it if it was unloaded.

    The backend is dropped by ``unload()`` (exposed as ``POST /dev/unload``).
    This method rebuilds it lazily by awaiting ``initialize()`` followed by
    ``load_model()`` with the configured Kokoro V1 model file.

    Callers are serialized on ``self._lock`` so only one of them performs the
    reload. A caller that arrives while the lock is held waits for it instead
    of returning early, so it cannot observe a backend whose model load is
    still in flight.

    Raises:
        Exception: Whatever ``initialize()`` or ``load_model()`` raises. The
            backend is reset to ``None`` first, so the next call retries the
            reload instead of reusing a half-built backend.
    """
    # Fast path. An existing backend is only trusted when nobody holds the
    # lock: a held lock means a reload (or unload) is in progress, and
    # initialize() is expected to set self._backend before load_model() has
    # finished -- NOTE(review): confirm against initialize().
    if self._backend is not None and not self._lock.locked():
        return

    async with self._lock:
        if self._backend is not None:
            # Another caller completed the reload while we were waiting.
            return
        try:
            await self.initialize()
            await self.load_model(self._config.pytorch_kokoro_v1_file)
        except BaseException:
            # BaseException so task cancellation mid-load is covered too.
            # Without this reset a failed load leaves _backend set and every
            # later call would skip the reload.
            self._backend = None
            raise
@router.post("/dev/unload")
async def unload_model(
    tts_service: TTSService = Depends(get_tts_service),
):
    """Release the model from GPU VRAM without stopping the container.

    The model reloads automatically on the next inference request.
    Useful for homelab deployments where GPU memory is shared across services.

    Returns:
        JSONResponse with body ``{"status": "unloaded"}`` on success.

    Raises:
        HTTPException: 503 when the service has no model manager, 500 when
            the manager's ``unload()`` raises.
    """
    manager = tts_service.model_manager
    # Guard before the try block so the 503 is never caught by the generic
    # handler below and needs no re-raise clause.
    if manager is None:
        raise HTTPException(status_code=503, detail={"error": "Model manager not initialized"})

    try:
        await manager.unload()
    except Exception as e:
        logger.error(f"Error unloading model: {e}")
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail={"error": str(e)}) from e

    return JSONResponse({"status": "unloaded"})
@contextmanager
def override_tts_service(service):
    """Override the get_tts_service FastAPI dependency for the duration of the block.

    While the ``with`` block runs, ``app.dependency_overrides`` maps
    ``get_tts_service`` to an async callable returning ``service``. On exit,
    including when the block raises, the mapping is put back exactly as it
    was: an override that was already installed is restored, otherwise the
    entry is removed.

    Args:
        service: Object the overridden dependency should return.
    """
    async def _override():
        return service

    overrides = app.dependency_overrides
    # A private sentinel tells "no previous override" apart from any real value.
    _missing = object()
    previous = overrides.get(get_tts_service, _missing)
    overrides[get_tts_service] = _override
    try:
        yield
    finally:
        if previous is _missing:
            overrides.pop(get_tts_service, None)
        else:
            overrides[get_tts_service] = previous
@pytest.mark.asyncio
async def test_ensure_backend_serializes_concurrent_reloads():
    """Five simultaneous ensure_backend() calls on an empty manager reload exactly once."""
    manager = ModelManager()
    assert manager._backend is None

    stub_backend = MagicMock()
    calls = {"initialize": 0, "load_model": 0}

    async def counting_initialize():
        calls["initialize"] += 1
        # Hand control back to the event loop once so the other callers get
        # the chance to race into ensure_backend() before the backend is set.
        await asyncio.sleep(0)
        manager._backend = stub_backend

    async def counting_load(path):
        calls["load_model"] += 1

    init_patch = patch.object(manager, "initialize", side_effect=counting_initialize)
    load_patch = patch.object(manager, "load_model", side_effect=counting_load)
    with init_patch, load_patch:
        pending = [manager.ensure_backend() for _ in range(5)]
        await asyncio.gather(*pending)

    assert calls["initialize"] == 1
    assert calls["load_model"] == 1
def test_unload_endpoint_500_on_exception():
    """An unexpected error raised by manager.unload() is reported as HTTP 500."""
    failure = RuntimeError("GPU exploded")
    broken_manager = AsyncMock()
    broken_manager.unload = AsyncMock(side_effect=failure)

    with override_tts_service(_mock_service(manager=broken_manager)):
        reply = client.post("/dev/unload")

    assert reply.status_code == 500
    payload = reply.json()
    # The endpoint forwards the exception text in detail.error.
    assert "GPU exploded" in payload["detail"]["error"]