From 3b2cfee65a39421880b84fb933be1313c0e8e518 Mon Sep 17 00:00:00 2001 From: Tyro Sageframe Date: Sun, 31 May 2026 14:29:36 -0400 Subject: [PATCH 1/3] feat(api): add POST /dev/unload to release model from GPU VRAM Allows homelab deployments to free GPU memory when the TTS service is idle without stopping the container. The model reloads lazily on the next request. - ModelManager.unload(): acquires lock, calls backend.unload(), nulls _backend, then calls torch.cuda.empty_cache() if CUDA is available - ModelManager.generate(): lazy reinit when _backend is None (calls initialize() + load_model()) instead of raising RuntimeError - POST /dev/unload: 200 on success, 503 if manager not initialised, 500 on unexpected error - TTSService.model_manager annotated as Optional[ModelManager] for correct mypy narrowing at the endpoint - Full test coverage in api/tests/test_model_unload.py (11 tests) Closes #473 (partial) Co-Authored-By: Claude Sonnet 4.6 --- api/src/inference/model_manager.py | 16 ++- api/src/routers/development.py | 21 +++ api/src/services/tts_service.py | 3 +- api/tests/test_model_unload.py | 211 +++++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 api/tests/test_model_unload.py diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index eb817ecb..9fbb1db1 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -1,7 +1,9 @@ """Kokoro V1 model management.""" +import asyncio from typing import Optional +import torch from loguru import logger from ..core import paths @@ -26,6 +28,7 @@ def __init__(self, config: Optional[ModelConfig] = None): self._config = config or model_config self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1 self._device: Optional[str] = None + self._lock = asyncio.Lock() def _determine_device(self) -> str: """Determine device based on settings.""" @@ -137,7 +140,8 @@ async def generate(self, *args, **kwargs): RuntimeError: If generation fails """ if not self._backend: - raise RuntimeError("Backend not initialized") + await self.initialize() + await self.load_model(self._config.pytorch_kokoro_v1_file) try: async for chunk in self._backend.generate(*args, **kwargs): @@ -153,6 +157,16 @@ def unload_all(self) -> None: self._backend.unload() self._backend = None + async def unload(self) -> None: + """Release model from GPU memory. Reloads automatically on next request.""" + async with self._lock: + if self._backend is not None: + self._backend.unload() + self._backend = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Model unloaded from GPU memory") + @property def current_backend(self) -> str: """Get current backend type.""" diff --git a/api/src/routers/development.py b/api/src/routers/development.py index 8c8ed7e1..b310c81a 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -411,3 +411,24 @@ async def single_output(): "type": "server_error", }, ) + + +@router.post("/dev/unload") +async def unload_model( + tts_service: TTSService = Depends(get_tts_service), +): + """Release the model from GPU VRAM without stopping the container. + + The model reloads automatically on the next inference request. + Useful for homelab deployments where GPU memory is shared across services. + """ + try: + if tts_service.model_manager is None: + raise HTTPException(status_code=503, detail={"error": "Model manager not initialized"}) + await tts_service.model_manager.unload() + return JSONResponse({"status": "unloaded"}) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error unloading model: {e}") + raise HTTPException(status_code=500, detail={"error": str(e)}) diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 1ac74557..352dbe78 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -15,6 +15,7 @@ from ..core.config import settings from ..inference.base import AudioChunk from ..inference.kokoro_v1 import KokoroV1 +from ..inference.model_manager import ModelManager from ..inference.model_manager import get_manager as get_model_manager from ..inference.voice_manager import get_manager as get_voice_manager from ..structures.schemas import NormalizationOptions @@ -33,7 +34,7 @@ class TTSService: def __init__(self, output_dir: str = None): """Initialize service.""" self.output_dir = output_dir - self.model_manager = None + self.model_manager: Optional[ModelManager] = None self._voice_manager = None @classmethod diff --git a/api/tests/test_model_unload.py b/api/tests/test_model_unload.py new file mode 100644 index 00000000..37adfc82 --- /dev/null +++ b/api/tests/test_model_unload.py @@ -0,0 +1,211 @@ +"""Tests for ModelManager.unload(), lazy reinit in generate(), and POST /dev/unload.""" + +import asyncio +from contextlib import contextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from fastapi.testclient import TestClient + +from api.src.inference.base import AudioChunk +from api.src.inference.model_manager import ModelManager +from api.src.main import app +from api.src.routers.development import get_tts_service +from api.src.services.tts_service import TTSService + +client = TestClient(app) + + +@contextmanager +def override_tts_service(service): + """Override the get_tts_service FastAPI dependency for the duration of the block.""" + async def _override(): + return service + + app.dependency_overrides[get_tts_service] = _override + try: + yield + finally: + app.dependency_overrides.pop(get_tts_service, None) + + +# --------------------------------------------------------------------------- +# ModelManager unit tests +# --------------------------------------------------------------------------- + + +def test_manager_init_creates_lock(): + manager = ModelManager() + assert isinstance(manager._lock, asyncio.Lock) + + +@pytest.mark.asyncio +async def test_unload_clears_backend(): + manager = ModelManager() + mock_backend = MagicMock() + manager._backend = mock_backend + + with patch("api.src.inference.model_manager.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = False + await manager.unload() + + mock_backend.unload.assert_called_once() + assert manager._backend is None + + +@pytest.mark.asyncio +async def test_unload_when_already_none_is_noop(): + manager = ModelManager() + assert manager._backend is None + + with patch("api.src.inference.model_manager.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = False + await manager.unload() # must not raise + + assert manager._backend is None + + +@pytest.mark.asyncio +async def test_unload_calls_cuda_empty_cache_when_available(): + manager = ModelManager() + manager._backend = MagicMock() + + with patch("api.src.inference.model_manager.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = True + await manager.unload() + + mock_torch.cuda.empty_cache.assert_called_once() + + +@pytest.mark.asyncio +async def test_unload_skips_cuda_empty_cache_when_unavailable(): + manager = ModelManager() + manager._backend = MagicMock() + + with patch("api.src.inference.model_manager.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = False + await manager.unload() + + mock_torch.cuda.empty_cache.assert_not_called() + + +@pytest.mark.asyncio +async def test_generate_lazy_reinit_when_backend_none(): + """generate() initializes backend lazily when _backend is None.""" + manager = ModelManager() + assert manager._backend is None + + mock_backend = MagicMock() + audio_chunk = AudioChunk(np.zeros(10, dtype=np.float32)) + + async def fake_generate(*args, **kwargs): + yield audio_chunk + + mock_backend.generate = fake_generate + + async def fake_initialize(): + manager._backend = mock_backend + + with ( + patch.object(manager, "initialize", side_effect=fake_initialize) as mock_init, + patch.object(manager, "load_model", new_callable=AsyncMock) as mock_load, + ): + chunks = [] + async for chunk in manager.generate("hello", ("voice", "/path/voice.pt")): + chunks.append(chunk) + + mock_init.assert_called_once() + mock_load.assert_called_once_with(manager._config.pytorch_kokoro_v1_file) + assert len(chunks) == 1 + assert chunks[0] is audio_chunk + + +@pytest.mark.asyncio +async def test_generate_skips_reinit_when_backend_set(): + """generate() does not call initialize/load_model when backend already exists.""" + manager = ModelManager() + mock_backend = MagicMock() + audio_chunk = AudioChunk(np.zeros(10, dtype=np.float32)) + + async def fake_generate(*args, **kwargs): + yield audio_chunk + + mock_backend.generate = fake_generate + manager._backend = mock_backend + + with ( + patch.object(manager, "initialize", new_callable=AsyncMock) as mock_init, + patch.object(manager, "load_model", new_callable=AsyncMock) as mock_load, + ): + chunks = [] + async for chunk in manager.generate("hello", ("voice", "/path/voice.pt")): + chunks.append(chunk) + + mock_init.assert_not_called() + mock_load.assert_not_called() + assert len(chunks) == 1 + + +# --------------------------------------------------------------------------- +# POST /dev/unload endpoint tests +# --------------------------------------------------------------------------- + + +def _mock_service(manager=None): + """Build a TTSService-shaped mock with the given model_manager.""" + service = MagicMock(spec=TTSService) + service.model_manager = manager + return service + + +def test_unload_endpoint_returns_200(): + mock_manager = AsyncMock() + mock_manager.unload = AsyncMock() + service = _mock_service(manager=mock_manager) + + with override_tts_service(service): + response = client.post("/dev/unload") + + assert response.status_code == 200 + assert response.json() == {"status": "unloaded"} + mock_manager.unload.assert_called_once() + + +def test_unload_endpoint_idempotent(): + """Calling /dev/unload twice both succeed — unload is a no-op when already clear.""" + mock_manager = AsyncMock() + mock_manager.unload = AsyncMock() + service = _mock_service(manager=mock_manager) + + with override_tts_service(service): + r1 = client.post("/dev/unload") + r2 = client.post("/dev/unload") + + assert r1.status_code == 200 + assert r2.status_code == 200 + assert mock_manager.unload.call_count == 2 + + +def test_unload_endpoint_503_when_manager_none(): + """Returns 503 when model_manager has not been initialised on the service.""" + service = _mock_service(manager=None) + + with override_tts_service(service): + response = client.post("/dev/unload") + + assert response.status_code == 503 + assert response.json()["detail"]["error"] == "Model manager not initialized" + + +def test_unload_endpoint_500_on_exception(): + """Returns 500 when manager.unload() raises unexpectedly.""" + mock_manager = AsyncMock() + mock_manager.unload = AsyncMock(side_effect=RuntimeError("GPU exploded")) + service = _mock_service(manager=mock_manager) + + with override_tts_service(service): + response = client.post("/dev/unload") + + assert response.status_code == 500 + assert "GPU exploded" in response.json()["detail"]["error"] From 750fabcc93b28072d684eb4db306337b8b481c7a Mon Sep 17 00:00:00 2001 From: Tyro Sageframe Date: Sun, 31 May 2026 21:00:58 -0400 Subject: [PATCH 2/3] fix(model-unload): lazy reload on next request after /dev/unload Add ensure_backend() to ModelManager which reinitialises the backend and reloads the model if /dev/unload was called. All three get_backend() call sites in tts_service now await ensure_backend() first, so the first TTS request after an unload reloads the model automatically rather than raising RuntimeError: Backend not initialized. Co-Authored-By: Claude Sonnet 4.6 --- api/src/inference/model_manager.py | 6 ++++++ api/src/services/tts_service.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 9fbb1db1..dd52d88a 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -101,6 +101,12 @@ async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]: except Exception as e: raise RuntimeError(f"Warmup failed: {e}") + async def ensure_backend(self) -> None: + """Reload the backend if it was unloaded via /dev/unload.""" + if not self._backend: + await self.initialize() + await self.load_model(self._config.pytorch_kokoro_v1_file) + def get_backend(self) -> BaseModelBackend: """Get initialized backend. diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 352dbe78..02de19cc 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -88,7 +88,7 @@ async def _process_chunk( if not tokens and not chunk_text: return - # Get backend + await self.model_manager.ensure_backend() backend = self.model_manager.get_backend() # Generate audio using pre-warmed model @@ -273,7 +273,7 @@ async def generate_audio_stream( chunk_index = 0 current_offset = 0.0 try: - # Get backend + await self.model_manager.ensure_backend() backend = self.model_manager.get_backend() # Get voice path, handling combined voices @@ -466,7 +466,7 @@ async def generate_from_phonemes( """ start_time = time.time() try: - # Get backend and voice path + await self.model_manager.ensure_backend() backend = self.model_manager.get_backend() voice_name, voice_path = await self._get_voices_path(voice) From 36d18f77a405453d0e879678de0a42c7ee8c9c8e Mon Sep 17 00:00:00 2001 From: Tyro Sageframe Date: Tue, 2 Jun 2026 18:00:41 -0400 Subject: [PATCH 3/3] fix(model-unload): add double-checked locking to ensure_backend() Addresses review feedback: the original lazy-reload in generate() ran without a lock, so a burst of requests landing while _backend is None could trigger multiple concurrent initialize()/load_model() calls. - ensure_backend(): fast-path check outside lock, then re-check inside _lock before initializing (double-checked locking pattern) - generate(): routes through ensure_backend() instead of inline check, eliminating the duplicate code path - test_ensure_backend_serializes_concurrent_reloads: 5-way concurrent gather confirms only one initialize/load_model cycle fires Co-Authored-By: Claude Sonnet 4.6 --- api/src/inference/model_manager.py | 14 ++++++++------ api/tests/test_model_unload.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index dd52d88a..231f1f97 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -103,9 +103,12 @@ async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]: async def ensure_backend(self) -> None: """Reload the backend if it was unloaded via /dev/unload.""" - if not self._backend: - await self.initialize() - await self.load_model(self._config.pytorch_kokoro_v1_file) + if self._backend: + return + async with self._lock: + if not self._backend: + await self.initialize() + await self.load_model(self._config.pytorch_kokoro_v1_file) def get_backend(self) -> BaseModelBackend: """Get initialized backend. @@ -145,9 +148,8 @@ async def generate(self, *args, **kwargs): Raises: RuntimeError: If generation fails """ - if not self._backend: - await self.initialize() - await self.load_model(self._config.pytorch_kokoro_v1_file) + await self.ensure_backend() + assert self._backend is not None # ensure_backend loaded it or raised try: async for chunk in self._backend.generate(*args, **kwargs): diff --git a/api/tests/test_model_unload.py b/api/tests/test_model_unload.py index 37adfc82..6b31489a 100644 --- a/api/tests/test_model_unload.py +++ b/api/tests/test_model_unload.py @@ -90,6 +90,36 @@ async def test_unload_skips_cuda_empty_cache_when_unavailable(): mock_torch.cuda.empty_cache.assert_not_called() +@pytest.mark.asyncio +async def test_ensure_backend_serializes_concurrent_reloads(): + """Concurrent callers when _backend is None should trigger only one load cycle.""" + manager = ModelManager() + assert manager._backend is None + + mock_backend = MagicMock() + init_count = 0 + load_count = 0 + + async def fake_initialize(): + nonlocal init_count + init_count += 1 + await asyncio.sleep(0) # yield so other tasks can attempt entry + manager._backend = mock_backend + + async def fake_load(path): + nonlocal load_count + load_count += 1 + + with ( + patch.object(manager, "initialize", side_effect=fake_initialize), + patch.object(manager, "load_model", side_effect=fake_load), + ): + await asyncio.gather(*[manager.ensure_backend() for _ in range(5)]) + + assert init_count == 1 + assert load_count == 1 + + @pytest.mark.asyncio async def test_generate_lazy_reinit_when_backend_none(): """generate() initializes backend lazily when _backend is None."""