Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions api/src/inference/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Kokoro V1 model management."""

import asyncio
from typing import Optional

import torch
from loguru import logger

from ..core import paths
Expand All @@ -26,6 +28,7 @@ def __init__(self, config: Optional[ModelConfig] = None):
self._config = config or model_config
self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1
self._device: Optional[str] = None
self._lock = asyncio.Lock()

def _determine_device(self) -> str:
"""Determine device based on settings."""
Expand Down Expand Up @@ -98,6 +101,15 @@ async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
except Exception as e:
raise RuntimeError(f"Warmup failed: {e}")

async def ensure_backend(self) -> None:
"""Reload the backend if it was unloaded via /dev/unload."""
if self._backend:
return
async with self._lock:
if not self._backend:
await self.initialize()
await self.load_model(self._config.pytorch_kokoro_v1_file)

def get_backend(self) -> BaseModelBackend:
"""Get initialized backend.

Expand Down Expand Up @@ -136,8 +148,8 @@ async def generate(self, *args, **kwargs):
Raises:
RuntimeError: If generation fails
"""
if not self._backend:
raise RuntimeError("Backend not initialized")
await self.ensure_backend()
assert self._backend is not None # ensure_backend loaded it or raised

try:
async for chunk in self._backend.generate(*args, **kwargs):
Expand All @@ -153,6 +165,16 @@ def unload_all(self) -> None:
self._backend.unload()
self._backend = None

async def unload(self) -> None:
"""Release model from GPU memory. Reloads automatically on next request."""
async with self._lock:
if self._backend is not None:
self._backend.unload()
self._backend = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Model unloaded from GPU memory")

@property
def current_backend(self) -> str:
"""Get current backend type."""
Expand Down
21 changes: 21 additions & 0 deletions api/src/routers/development.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,24 @@ async def single_output():
"type": "server_error",
},
)


@router.post("/dev/unload")
async def unload_model(
tts_service: TTSService = Depends(get_tts_service),
):
"""Release the model from GPU VRAM without stopping the container.

The model reloads automatically on the next inference request.
Useful for homelab deployments where GPU memory is shared across services.
"""
try:
if tts_service.model_manager is None:
raise HTTPException(status_code=503, detail={"error": "Model manager not initialized"})
await tts_service.model_manager.unload()
return JSONResponse({"status": "unloaded"})
except HTTPException:
raise
except Exception as e:
logger.error(f"Error unloading model: {e}")
raise HTTPException(status_code=500, detail={"error": str(e)})
9 changes: 5 additions & 4 deletions api/src/services/tts_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..core.config import settings
from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import ModelManager
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
Expand All @@ -33,7 +34,7 @@ class TTSService:
def __init__(self, output_dir: str = None):
"""Initialize service."""
self.output_dir = output_dir
self.model_manager = None
self.model_manager: Optional[ModelManager] = None
self._voice_manager = None

@classmethod
Expand Down Expand Up @@ -87,7 +88,7 @@ async def _process_chunk(
if not tokens and not chunk_text:
return

# Get backend
await self.model_manager.ensure_backend()
backend = self.model_manager.get_backend()

# Generate audio using pre-warmed model
Expand Down Expand Up @@ -272,7 +273,7 @@ async def generate_audio_stream(
chunk_index = 0
current_offset = 0.0
try:
# Get backend
await self.model_manager.ensure_backend()
backend = self.model_manager.get_backend()

# Get voice path, handling combined voices
Expand Down Expand Up @@ -465,7 +466,7 @@ async def generate_from_phonemes(
"""
start_time = time.time()
try:
# Get backend and voice path
await self.model_manager.ensure_backend()
backend = self.model_manager.get_backend()
voice_name, voice_path = await self._get_voices_path(voice)

Expand Down
241 changes: 241 additions & 0 deletions api/tests/test_model_unload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""Tests for ModelManager.unload(), lazy reinit in generate(), and POST /dev/unload."""

import asyncio
from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

from api.src.inference.base import AudioChunk
from api.src.inference.model_manager import ModelManager
from api.src.main import app
from api.src.routers.development import get_tts_service
from api.src.services.tts_service import TTSService

client = TestClient(app)


@contextmanager
def override_tts_service(service):
"""Override the get_tts_service FastAPI dependency for the duration of the block."""
async def _override():
return service

app.dependency_overrides[get_tts_service] = _override
try:
yield
finally:
app.dependency_overrides.pop(get_tts_service, None)


# ---------------------------------------------------------------------------
# ModelManager unit tests
# ---------------------------------------------------------------------------


def test_manager_init_creates_lock():
manager = ModelManager()
assert isinstance(manager._lock, asyncio.Lock)


@pytest.mark.asyncio
async def test_unload_clears_backend():
manager = ModelManager()
mock_backend = MagicMock()
manager._backend = mock_backend

with patch("api.src.inference.model_manager.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = False
await manager.unload()

mock_backend.unload.assert_called_once()
assert manager._backend is None


@pytest.mark.asyncio
async def test_unload_when_already_none_is_noop():
manager = ModelManager()
assert manager._backend is None

with patch("api.src.inference.model_manager.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = False
await manager.unload() # must not raise

assert manager._backend is None


@pytest.mark.asyncio
async def test_unload_calls_cuda_empty_cache_when_available():
manager = ModelManager()
manager._backend = MagicMock()

with patch("api.src.inference.model_manager.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = True
await manager.unload()

mock_torch.cuda.empty_cache.assert_called_once()


@pytest.mark.asyncio
async def test_unload_skips_cuda_empty_cache_when_unavailable():
manager = ModelManager()
manager._backend = MagicMock()

with patch("api.src.inference.model_manager.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = False
await manager.unload()

mock_torch.cuda.empty_cache.assert_not_called()


@pytest.mark.asyncio
async def test_ensure_backend_serializes_concurrent_reloads():
"""Concurrent callers when _backend is None should trigger only one load cycle."""
manager = ModelManager()
assert manager._backend is None

mock_backend = MagicMock()
init_count = 0
load_count = 0

async def fake_initialize():
nonlocal init_count
init_count += 1
await asyncio.sleep(0) # yield so other tasks can attempt entry
manager._backend = mock_backend

async def fake_load(path):
nonlocal load_count
load_count += 1

with (
patch.object(manager, "initialize", side_effect=fake_initialize),
patch.object(manager, "load_model", side_effect=fake_load),
):
await asyncio.gather(*[manager.ensure_backend() for _ in range(5)])

assert init_count == 1
assert load_count == 1


@pytest.mark.asyncio
async def test_generate_lazy_reinit_when_backend_none():
"""generate() initializes backend lazily when _backend is None."""
manager = ModelManager()
assert manager._backend is None

mock_backend = MagicMock()
audio_chunk = AudioChunk(np.zeros(10, dtype=np.float32))

async def fake_generate(*args, **kwargs):
yield audio_chunk

mock_backend.generate = fake_generate

async def fake_initialize():
manager._backend = mock_backend

with (
patch.object(manager, "initialize", side_effect=fake_initialize) as mock_init,
patch.object(manager, "load_model", new_callable=AsyncMock) as mock_load,
):
chunks = []
async for chunk in manager.generate("hello", ("voice", "/path/voice.pt")):
chunks.append(chunk)

mock_init.assert_called_once()
mock_load.assert_called_once_with(manager._config.pytorch_kokoro_v1_file)
assert len(chunks) == 1
assert chunks[0] is audio_chunk


@pytest.mark.asyncio
async def test_generate_skips_reinit_when_backend_set():
"""generate() does not call initialize/load_model when backend already exists."""
manager = ModelManager()
mock_backend = MagicMock()
audio_chunk = AudioChunk(np.zeros(10, dtype=np.float32))

async def fake_generate(*args, **kwargs):
yield audio_chunk

mock_backend.generate = fake_generate
manager._backend = mock_backend

with (
patch.object(manager, "initialize", new_callable=AsyncMock) as mock_init,
patch.object(manager, "load_model", new_callable=AsyncMock) as mock_load,
):
chunks = []
async for chunk in manager.generate("hello", ("voice", "/path/voice.pt")):
chunks.append(chunk)

mock_init.assert_not_called()
mock_load.assert_not_called()
assert len(chunks) == 1


# ---------------------------------------------------------------------------
# POST /dev/unload endpoint tests
# ---------------------------------------------------------------------------


def _mock_service(manager=None):
"""Build a TTSService-shaped mock with the given model_manager."""
service = MagicMock(spec=TTSService)
service.model_manager = manager
return service


def test_unload_endpoint_returns_200():
mock_manager = AsyncMock()
mock_manager.unload = AsyncMock()
service = _mock_service(manager=mock_manager)

with override_tts_service(service):
response = client.post("/dev/unload")

assert response.status_code == 200
assert response.json() == {"status": "unloaded"}
mock_manager.unload.assert_called_once()


def test_unload_endpoint_idempotent():
"""Calling /dev/unload twice both succeed — unload is a no-op when already clear."""
mock_manager = AsyncMock()
mock_manager.unload = AsyncMock()
service = _mock_service(manager=mock_manager)

with override_tts_service(service):
r1 = client.post("/dev/unload")
r2 = client.post("/dev/unload")

assert r1.status_code == 200
assert r2.status_code == 200
assert mock_manager.unload.call_count == 2


def test_unload_endpoint_503_when_manager_none():
"""Returns 503 when model_manager has not been initialised on the service."""
service = _mock_service(manager=None)

with override_tts_service(service):
response = client.post("/dev/unload")

assert response.status_code == 503
assert response.json()["detail"]["error"] == "Model manager not initialized"


def test_unload_endpoint_500_on_exception():
"""Returns 500 when manager.unload() raises unexpectedly."""
mock_manager = AsyncMock()
mock_manager.unload = AsyncMock(side_effect=RuntimeError("GPU exploded"))
service = _mock_service(manager=mock_manager)

with override_tts_service(service):
response = client.post("/dev/unload")

assert response.status_code == 500
assert "GPU exploded" in response.json()["detail"]["error"]
Loading