Skip to content

Commit e15391b

Browse files
committed
Improve bot first-move latency via inference prewarm and UI flow/perf polish
1 parent c4d7662 commit e15391b

18 files changed

Lines changed: 722 additions & 123 deletions

src/api/app.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncIterator
4+
from contextlib import asynccontextmanager
5+
36
from fastapi import FastAPI
47
from fastapi.middleware.cors import CORSMiddleware
58
from fastapi.staticfiles import StaticFiles
69

710
from api.config import Settings, get_settings
11+
from api.deps.inference import preload_inference_service
812
from api.error_handling import register_error_handlers
913
from api.modules.auth.rate_limit import AuthRateLimiter
1014
from api.modules.auth.router import router as auth_router
@@ -24,11 +28,20 @@
2428
def create_app(settings: Settings | None = None) -> FastAPI:
2529
cfg = settings or get_settings()
2630
configure_logging(cfg)
31+
32+
@asynccontextmanager
33+
async def _lifespan(_app: FastAPI) -> AsyncIterator[None]:
34+
# Preload once at process startup to remove cold-start lag from the
35+
# first model move users see in live matches.
36+
preload_inference_service()
37+
yield
38+
2739
app = FastAPI(
2840
title=cfg.app_name,
2941
debug=cfg.app_debug,
3042
docs_url=cfg.docs_url,
3143
redoc_url=cfg.redoc_url,
44+
lifespan=_lifespan,
3245
)
3346
register_error_handlers(app)
3447
if cfg.app_log_requests:

src/api/deps/inference.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
from functools import lru_cache
45

56
from fastapi import HTTPException, status
@@ -8,6 +9,8 @@
89
from api.inference_artifacts import resolve_artifact_uri
910
from inference.service import InferenceService
1011

12+
logger = logging.getLogger(__name__)
13+
1114

1215
@lru_cache(maxsize=1)
1316
def _build_inference_service(
@@ -54,3 +57,18 @@ def get_inference_service_dep() -> InferenceService:
5457
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
5558
detail=f"Inference service unavailable: {exc}",
5659
) from exc
60+
61+
62+
def preload_inference_service() -> InferenceService | None:
    """Warm the shared inference service ahead of the first bot move.

    Best-effort by design: any failure is logged and swallowed so neither API
    startup nor request handling ever breaks because inference is unavailable.

    Returns:
        The warmed service on success, ``None`` when preload was skipped.
    """
    warmed: InferenceService | None = None
    try:
        candidate = get_inference_service_dep()
        candidate.warmup(mode="fast")
        warmed = candidate
    except HTTPException as exc:
        # Expected when artifacts/config are missing; the per-request dependency
        # will surface the same 503 to callers that actually need inference.
        logger.warning("Inference preload skipped.", extra={"detail": str(exc.detail)})
    except Exception:  # pragma: no cover - defensive path for runtime-specific failures.
        logger.exception("Inference preload crashed unexpectedly.")
    return warmed

src/api/modules/gameplay/router.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
from agents.heuristic import heuristic_move
2121
from agents.random_agent import random_move
2222
from api.config import Settings, get_settings
23-
from api.db.enums import GameStatus
23+
from api.db.enums import AgentType, GameStatus
2424
from api.db.models import Game, User
2525
from api.deps.auth import get_auth_service_dep, get_current_user_dep
2626
from api.deps.gameplay import get_gameplay_service_dep
27-
from api.deps.inference import get_inference_service_dep
27+
from api.deps.inference import get_inference_service_dep, preload_inference_service
2828
from api.modules.auth.service import AuthService
2929
from api.modules.gameplay.schemas import (
3030
GameCreateRequest,
@@ -285,6 +285,8 @@ async def post_game(
285285
status_code=status.HTTP_400_BAD_REQUEST,
286286
detail=str(exc),
287287
) from exc
288+
if game.player1_agent == AgentType.MODEL or game.player2_agent == AgentType.MODEL:
289+
preload_inference_service()
288290
return await _to_game_response(gameplay_service, game)
289291

290292

src/api/modules/matches/model_runtime.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,17 @@ def resolve_model_inference_service(
7070
)
7171

7272

73-
__all__ = ["resolve_model_inference_service"]
73+
def prewarm_model_inference_service(
    *,
    version: ModelVersion,
    base_service: InferenceService | None,
) -> InferenceService:
    """Resolve the inference service for *version* and run one fast warmup pass.

    Unlike the API-level preload helper, this propagates resolution/warmup
    errors to the caller; callers wanting best-effort semantics must catch
    them themselves.
    """
    resolved = resolve_model_inference_service(version=version, base_service=base_service)
    resolved.warmup(mode="fast")
    return resolved
84+
85+
86+
__all__ = ["prewarm_model_inference_service", "resolve_model_inference_service"]

src/api/modules/matches/service.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
WinnerSide,
1717
)
1818
from api.db.models import BotProfile, Game, GameMove, User
19-
from api.modules.matches.model_runtime import resolve_model_inference_service
19+
from api.modules.matches.model_runtime import (
20+
prewarm_model_inference_service,
21+
resolve_model_inference_service,
22+
)
2023
from api.modules.matches.repository import MatchesRepository
2124
from api.modules.matches.schemas import MatchCreateRequest, MatchMoveRequest
2225
from api.modules.ranking.service import RankingService
@@ -52,6 +55,11 @@ async def create_match(self, payload: MatchCreateRequest, actor_user_id: UUID) -
5255
player2_agent = profile.agent_type
5356
if profile.agent_type == AgentType.MODEL and model_version_id is None:
5457
model_version_id = player2.model_version_id
58+
if profile.agent_type == AgentType.MODEL and model_version_id is not None:
59+
await self._prewarm_model_runtime(
60+
version_id=model_version_id,
61+
fallback_service=None,
62+
)
5563

5664
game = Game(
5765
season_id=payload.season_id,
@@ -96,6 +104,11 @@ async def create_invitation(
96104

97105
if opponent.is_bot:
98106
profile = await self._get_enabled_bot_profile(opponent.id)
107+
if profile.agent_type == AgentType.MODEL and opponent.model_version_id is not None:
108+
await self._prewarm_model_runtime(
109+
version_id=opponent.model_version_id,
110+
fallback_service=None,
111+
)
99112
now = datetime.now(timezone.utc).replace(tzinfo=None)
100113
return await self.repository.create_game(
101114
Game(
@@ -468,5 +481,22 @@ async def _resolve_model_bot_inference_service(
468481
f"Model version '{version.name}' has no usable local artifact for inference."
469482
) from None
470483

484+
async def _prewarm_model_runtime(
    self,
    *,
    version_id: UUID,
    fallback_service: InferenceService | None,
) -> None:
    """Best-effort warmup of the model runtime backing *version_id*.

    Quietly does nothing when the version is unknown or when warmup fails
    for an expected reason (missing artifact, bad config, runtime error):
    warmup is purely a latency optimization and must never block match
    creation.
    """
    version = await self.repository.get_model_version(version_id)
    if version is None:
        return
    try:
        prewarm_model_inference_service(version=version, base_service=fallback_service)
    except (FileNotFoundError, RuntimeError, ValueError):
        # Swallow only the anticipated failure modes; anything else should
        # still surface as a real error.
        return
500+
471501

472502

src/inference/service.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(
154154
"Inference initialization failed: neither torch checkpoint nor ONNX session is available."
155155
)
156156
self._mcts: MCTS | None = None
157+
self._is_warmed_up = False
157158

158159
@staticmethod
159160
def _resolve_device(device: str) -> str:
@@ -414,6 +415,15 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
414415
value = float(value_tensor.item())
415416
return InferenceResult(move=move, action_idx=action_idx, value=value, mode="strong")
416417

418+
def warmup(self, *, mode: InferenceMode = "fast") -> None:
    """Run one throwaway prediction so real requests skip cold-start latency.

    Idempotent: after any successful warmup, later calls return immediately.
    NOTE(review): the warm flag is mode-agnostic — warming with "fast" also
    marks "strong" as warm; confirm strong-mode first-call latency is
    acceptable under that assumption.
    """
    if not self._is_warmed_up:
        self.predict(board=AtaxxBoard(), mode=mode)
        self._is_warmed_up = True
426+
417427
def predict(self, board: AtaxxBoard, *, mode: InferenceMode = "fast") -> InferenceResult:
418428
if board.is_game_over():
419429
return InferenceResult(

tests/test_api_games.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import unittest
55
from pathlib import Path
6+
from unittest.mock import patch
67
from uuid import UUID, uuid4
78

89
from fastapi.testclient import TestClient
@@ -102,6 +103,34 @@ def test_create_game(self) -> None:
102103
self.assertEqual(payload["player2_agent"], "heuristic")
103104
self.assertIn("id", payload)
104105

106+
def test_create_game_prewarms_inference_when_model_agent_is_present(self) -> None:
    """Creating a game with a model agent must trigger the inference prewarm hook."""
    client, _ = self._client_with_stub()
    request_body = {
        "queue_type": "vs_ai",
        "player1_agent": "human",
        "player2_agent": "model",
    }
    with patch("api.modules.gameplay.router.preload_inference_service") as preload:
        response = client.post("/api/v1/gameplay/games", json=request_body)
    self.assertEqual(response.status_code, 201)
    preload.assert_called_once()
119+
120+
def test_create_game_skips_prewarm_for_non_model_agents(self) -> None:
    """Games without any model agent must not pay the inference preload cost."""
    client, _ = self._client_with_stub()
    request_body = {
        "queue_type": "vs_ai",
        "player1_agent": "human",
        "player2_agent": "heuristic",
    }
    with patch("api.modules.gameplay.router.preload_inference_service") as preload:
        response = client.post("/api/v1/gameplay/games", json=request_body)
    self.assertEqual(response.status_code, 201)
    preload.assert_not_called()
133+
105134
def test_get_game_by_id(self) -> None:
106135
client, _ = self._client_with_stub()
107136
created = client.post(

tests/test_api_inference_dep.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import sys
44
import unittest
55
from pathlib import Path
6-
from unittest.mock import patch
6+
from unittest.mock import Mock, patch
77

88
from fastapi import HTTPException
99

1010
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
1111

12-
from api.deps.inference import get_inference_service_dep
12+
from api.deps.inference import get_inference_service_dep, preload_inference_service
1313

1414

1515
class TestApiInferenceDep(unittest.TestCase):
@@ -22,7 +22,23 @@ def test_maps_module_not_found_to_http_503(self, *_: object) -> None:
2222
self.assertEqual(ctx.exception.status_code, 503)
2323
self.assertIn("Inference service unavailable", str(ctx.exception.detail))
2424

25+
@patch("api.deps.inference.get_inference_service_dep")
def test_preload_inference_service_warms_up_once(self, get_dep: Mock) -> None:
    """A successful preload returns the resolved service after a fast warmup."""
    stub_service = Mock()
    get_dep.return_value = stub_service
    self.assertIs(preload_inference_service(), stub_service)
    stub_service.warmup.assert_called_once_with(mode="fast")
34+
35+
@patch(
    "api.deps.inference.get_inference_service_dep",
    side_effect=HTTPException(status_code=503, detail="inference unavailable"),
)
def test_preload_inference_service_returns_none_when_unavailable(self, *_: object) -> None:
    """When the dependency raises 503, preload degrades to ``None`` instead of raising."""
    self.assertIsNone(preload_inference_service())
41+
2542

2643
if __name__ == "__main__":
2744
unittest.main()
28-

tests/test_matches_service_model_inference.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,50 @@ async def _run() -> None:
172172

173173
asyncio.run(_run())
174174

175+
def test_create_invitation_prewarms_runtime_for_model_bot(self) -> None:
    """Inviting a model-backed bot triggers a runtime prewarm before play starts."""

    async def _exercise() -> None:
        async with self.sessionmaker() as session:
            matches = MatchesService(repository=MatchesRepository(session=session))
            challenger = User(username="human-c", email="human-c@example.com", is_active=True)
            model_version = ModelVersion(
                name="ub_policy_spatial_v2",
                checkpoint_uri="checkpoints/policy_spatial_v2.ckpt",
                is_active=False,
            )
            model_bot = User(
                username="ub_bogonet_warmup",
                email="bogonet@example.com",
                is_active=True,
                is_bot=True,
                bot_kind=BotKind.MODEL,
                model_version_id=model_version.id,
            )
            bot_profile = BotProfile(
                user_id=model_bot.id,
                agent_type=AgentType.MODEL,
                model_mode="fast",
                enabled=True,
            )
            session.add_all([challenger, model_version, model_bot, bot_profile])
            await session.commit()

            with patch("api.modules.matches.service.prewarm_model_inference_service") as prewarm:
                game = await matches.create_invitation(
                    actor_user_id=challenger.id,
                    opponent_user_id=model_bot.id,
                    rated=False,
                )

            self.assertEqual(game.status, GameStatus.IN_PROGRESS)
            self.assertEqual(game.player2_agent, AgentType.MODEL)
            prewarm.assert_called_once()

    asyncio.run(_exercise())
218+
175219

176220
if __name__ == "__main__":
177221
unittest.main()

0 commit comments

Comments
 (0)