|
16 | 16 | WinnerSide, |
17 | 17 | ) |
18 | 18 | from api.db.models import BotProfile, Game, GameMove, User |
19 | | -from api.modules.matches.model_runtime import resolve_model_inference_service |
| 19 | +from api.modules.matches.model_runtime import ( |
| 20 | + prewarm_model_inference_service, |
| 21 | + resolve_model_inference_service, |
| 22 | +) |
20 | 23 | from api.modules.matches.repository import MatchesRepository |
21 | 24 | from api.modules.matches.schemas import MatchCreateRequest, MatchMoveRequest |
22 | 25 | from api.modules.ranking.service import RankingService |
@@ -52,6 +55,11 @@ async def create_match(self, payload: MatchCreateRequest, actor_user_id: UUID) - |
52 | 55 | player2_agent = profile.agent_type |
53 | 56 | if profile.agent_type == AgentType.MODEL and model_version_id is None: |
54 | 57 | model_version_id = player2.model_version_id |
| 58 | + if profile.agent_type == AgentType.MODEL and model_version_id is not None: |
| 59 | + await self._prewarm_model_runtime( |
| 60 | + version_id=model_version_id, |
| 61 | + fallback_service=None, |
| 62 | + ) |
55 | 63 |
|
56 | 64 | game = Game( |
57 | 65 | season_id=payload.season_id, |
@@ -96,6 +104,11 @@ async def create_invitation( |
96 | 104 |
|
97 | 105 | if opponent.is_bot: |
98 | 106 | profile = await self._get_enabled_bot_profile(opponent.id) |
| 107 | + if profile.agent_type == AgentType.MODEL and opponent.model_version_id is not None: |
| 108 | + await self._prewarm_model_runtime( |
| 109 | + version_id=opponent.model_version_id, |
| 110 | + fallback_service=None, |
| 111 | + ) |
99 | 112 | now = datetime.now(timezone.utc).replace(tzinfo=None) |
100 | 113 | return await self.repository.create_game( |
101 | 114 | Game( |
@@ -468,5 +481,22 @@ async def _resolve_model_bot_inference_service( |
468 | 481 | f"Model version '{version.name}' has no usable local artifact for inference." |
469 | 482 | ) from None |
470 | 483 |
|
| 484 | + async def _prewarm_model_runtime( |
| 485 | + self, |
| 486 | + *, |
| 487 | + version_id: UUID, |
| 488 | + fallback_service: InferenceService | None, |
| 489 | + ) -> None: |
| 490 | + version = await self.repository.get_model_version(version_id) |
| 491 | + if version is None: |
| 492 | + return |
| 493 | + try: |
| 494 | + prewarm_model_inference_service( |
| 495 | + version=version, |
| 496 | + base_service=fallback_service, |
| 497 | + ) |
| 498 | + except (FileNotFoundError, RuntimeError, ValueError): |
| 499 | + return |
| 500 | + |
471 | 501 |
|
472 | 502 |
|
0 commit comments