Skip to content

Commit c781f56

Browse files
committed
Fix inference fallback, stabilize match sync, and add regression coverage
1 parent 12fcfcf commit c781f56

9 files changed

Lines changed: 163 additions & 20 deletions

File tree

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ INFERENCE_MODE_DEFAULT="fast" # fast | strong
4545
INFERENCE_DEVICE="auto" # auto | cpu | cuda
4646
INFERENCE_MCTS_SIMS=160
4747
INFERENCE_C_PUCT=1.5
48+
INFERENCE_FALLBACK_HEURISTIC_LEVEL="easy" # easy | normal | hard
4849

4950
# Auth / JWT
5051
AUTH_JWT_SECRET="change_me_with_a_long_random_secret"

src/api/config/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4+
from typing import Literal
45
from urllib.parse import quote, urlsplit, urlunsplit
56

67
from pydantic import computed_field, model_validator
@@ -60,6 +61,7 @@ class Settings(BaseSettings):
6061
inference_mcts_sims: int = 160
6162
inference_c_puct: float = 1.5
6263
inference_prefer_onnx: bool = True
64+
inference_fallback_heuristic_level: Literal["easy", "normal", "hard"] = "easy"
6365

6466
# Auth/JWT
6567
auth_jwt_secret: str = ""

src/api/modules/gameplay/router.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Annotated
4+
from typing import Annotated, Literal
55
from uuid import UUID
66

77
import numpy as np
@@ -19,6 +19,7 @@
1919

2020
from agents.heuristic import heuristic_move
2121
from agents.random_agent import random_move
22+
from api.config import Settings, get_settings
2223
from api.db.enums import GameStatus
2324
from api.db.models import Game, User
2425
from api.deps.auth import get_auth_service_dep, get_current_user_dep
@@ -48,6 +49,14 @@
4849
CURRENT_USER_DEP = Depends(get_current_user_dep)
4950
AUTH_SERVICE_DEP = Depends(get_auth_service_dep)
5051
logger = logging.getLogger(__name__)
52+
FALLBACK_MODE_BY_LEVEL: dict[
53+
Literal["easy", "normal", "hard"],
54+
Literal["heuristic_easy", "heuristic_normal", "heuristic_hard"],
55+
] = {
56+
"easy": "heuristic_easy",
57+
"normal": "heuristic_normal",
58+
"hard": "heuristic_hard",
59+
}
5160

5261

5362
def _resolve_inference_service(request: Request) -> InferenceService:
@@ -59,6 +68,13 @@ def _resolve_inference_service(request: Request) -> InferenceService:
5968
return provider()
6069

6170

71+
def _resolve_settings(request: Request) -> Settings:
72+
state_settings = getattr(request.app.state, "settings", None)
73+
if isinstance(state_settings, Settings):
74+
return state_settings
75+
return get_settings()
76+
77+
6278
async def _to_game_response(
6379
gameplay_service: GameplayService,
6480
game: Game,
@@ -154,26 +170,29 @@ def post_move(
154170
except HTTPException as exc:
155171
if exc.status_code != status.HTTP_503_SERVICE_UNAVAILABLE:
156172
raise
173+
settings = _resolve_settings(http_request)
174+
fallback_level = settings.inference_fallback_heuristic_level
175+
fallback_mode = FALLBACK_MODE_BY_LEVEL[fallback_level]
157176
# Keep PvE matches playable when model artifacts are missing in runtime.
158177
logger.warning(
159-
"Inference unavailable on /gameplay/move; falling back to heuristic_hard",
160-
extra={"detail": exc.detail},
178+
"Inference unavailable on /gameplay/move; falling back to heuristic",
179+
extra={"detail": exc.detail, "fallback_level": fallback_level},
161180
)
162181
rng = np.random.default_rng()
163-
fallback_move = heuristic_move(board=board, rng=rng, level="hard")
182+
fallback_move = heuristic_move(board=board, rng=rng, level=fallback_level)
164183
if fallback_move is None:
165184
return MoveResponse(
166185
move=None,
167186
action_idx=ACTION_SPACE.pass_index,
168187
value=0.0,
169-
mode="heuristic_hard",
188+
mode=fallback_mode,
170189
)
171190
r1, c1, r2, c2 = fallback_move
172191
return MoveResponse(
173192
move=MovePayload(r1=r1, c1=c1, r2=r2, c2=c2),
174193
action_idx=ACTION_SPACE.encode(fallback_move),
175194
value=0.0,
176-
mode="heuristic_hard",
195+
mode=fallback_mode,
177196
)
178197

179198
rng = np.random.default_rng()

src/api/modules/matches/router.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MATCHES_SERVICE_DEP = Depends(get_matches_service_dep)
3636
CURRENT_USER_DEP = Depends(get_current_user_dep)
3737
AUTH_SERVICE_DEP = Depends(get_auth_service_dep)
38+
INVITATIONS_WS_REFRESH_S = 8.0
3839

3940

4041
@router.post(
@@ -233,8 +234,11 @@ async def invitations_ws(
233234
}
234235
)
235236
try:
236-
# Lower polling pressure on DB while still keeping invitation UI responsive.
237-
await asyncio.wait_for(websocket.receive_text(), timeout=2.5)
237+
# Lower DB pressure: invitation updates do not need sub-second cadence.
238+
await asyncio.wait_for(
239+
websocket.receive_text(),
240+
timeout=INVITATIONS_WS_REFRESH_S,
241+
)
238242
except (TimeoutError, asyncio.TimeoutError):
239243
continue
240244
except (WebSocketDisconnect, asyncio.TimeoutError):

src/inference/service.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from game.types import Move
1515

1616
if TYPE_CHECKING:
17+
import torch.nn as nn
18+
1719
from engine.mcts import MCTS
1820

1921
InferenceMode = Literal["fast", "strong"]
@@ -57,7 +59,9 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5759

5860

5961
class _SystemLike(Protocol):
60-
model: Any
62+
@property
63+
def model(self) -> nn.Module:
64+
...
6165

6266
def eval(self) -> _SystemLike:
6367
...
@@ -69,6 +73,28 @@ def load_state_dict(self, state_dict: dict[str, object]) -> object:
6973
...
7074

7175

76+
class _CheckpointSystemAdapter:
77+
"""Minimal runtime wrapper to use plain torch modules as inference systems."""
78+
79+
def __init__(self, model: nn.Module) -> None:
80+
self._model = model
81+
82+
@property
83+
def model(self) -> nn.Module:
84+
return self._model
85+
86+
def eval(self) -> _CheckpointSystemAdapter:
87+
self._model.eval()
88+
return self
89+
90+
def to(self, device: str) -> _CheckpointSystemAdapter:
91+
self._model.to(device)
92+
return self
93+
94+
def load_state_dict(self, state_dict: dict[str, object]) -> object:
95+
return self._model.load_state_dict(state_dict)
96+
97+
7298
@lru_cache(maxsize=1)
7399
def _get_torch_module() -> ModuleType | None:
74100
"""Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
@@ -165,17 +191,31 @@ def _extract_arch_kwargs(raw_kwargs: ModelInitKwargs) -> dict[str, Any]:
165191
allowed = ("d_model", "nhead", "num_layers", "dim_feedforward", "dropout")
166192
return {key: raw_kwargs[key] for key in allowed if key in raw_kwargs}
167193

194+
@staticmethod
195+
def _extract_model_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
196+
# Training checkpoints prefix model params with `model.` (Lightning module layout).
197+
# Runtime inference uses the raw network, so we strip this prefix when present.
198+
if all(key.startswith("model.") for key in state_dict):
199+
return {key.removeprefix("model."): value for key, value in state_dict.items()}
200+
return state_dict
201+
168202
def _build_legacy_system(self) -> _SystemLike:
169203
from inference.legacy_model import LegacyAtaxxSystem
170204

171205
return LegacyAtaxxSystem(**self._extract_arch_kwargs(self.model_kwargs))
172206

173-
def _load_system(self) -> _SystemLike:
174-
from model.system import AtaxxZero
207+
def _build_spatial_system(self) -> _SystemLike:
208+
from model.transformer import AtaxxTransformerNet
209+
210+
model = AtaxxTransformerNet(**self._extract_arch_kwargs(self.model_kwargs))
211+
return _CheckpointSystemAdapter(model)
175212

213+
def _load_system(self) -> _SystemLike:
176214
torch_module = self._require_torch()
177215
ckpt = self.checkpoint_path
178216
if ckpt.suffix == ".ckpt":
217+
from model.system import AtaxxZero
218+
179219
try:
180220
return AtaxxZero.load_from_checkpoint(str(ckpt), map_location=self.device)
181221
except RuntimeError as exc:
@@ -191,9 +231,9 @@ def _load_system(self) -> _SystemLike:
191231
if not isinstance(state_dict_obj, dict):
192232
raise ValueError("Checkpoint dictionary must contain key 'state_dict'.")
193233

194-
system = AtaxxZero(**self.model_kwargs)
234+
system = self._build_spatial_system()
195235
try:
196-
system.load_state_dict(state_dict_obj)
236+
system.load_state_dict(self._extract_model_state_dict(state_dict_obj))
197237
except RuntimeError as exc:
198238
if self._is_legacy_state_dict(state_dict_obj):
199239
legacy_system = self._build_legacy_system()

tests/test_api_move.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
1111

1212
from api.app import create_app
13+
from api.config import Settings
1314
from api.deps.inference import get_inference_service_dep
1415
from api.modules.gameplay.schemas import MoveRequest
1516
from game.actions import ACTION_SPACE
@@ -119,6 +120,27 @@ def _unavailable_inference() -> _StubInferenceService:
119120
response = client.post("/api/v1/gameplay/move", json=payload)
120121
self.assertEqual(response.status_code, 200)
121122

123+
body = response.json()
124+
self.assertEqual(body["mode"], "heuristic_easy")
125+
self.assertIsInstance(body["action_idx"], int)
126+
127+
def test_move_endpoint_fallback_level_honors_settings(self) -> None:
128+
app = create_app(settings=Settings(inference_fallback_heuristic_level="hard"))
129+
130+
def _unavailable_inference() -> _StubInferenceService:
131+
raise HTTPException(
132+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
133+
detail="Inference unavailable in test",
134+
)
135+
136+
app.dependency_overrides[get_inference_service_dep] = _unavailable_inference
137+
client = TestClient(app)
138+
139+
board = AtaxxBoard()
140+
payload = MoveRequest(board=board_to_state(board), mode="fast").model_dump()
141+
response = client.post("/api/v1/gameplay/move", json=payload)
142+
self.assertEqual(response.status_code, 200)
143+
122144
body = response.json()
123145
self.assertEqual(body["mode"], "heuristic_hard")
124146
self.assertIsInstance(body["action_idx"], int)

tests/test_inference_service.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from inference.legacy_model import LegacyAtaxxSystem
1818
from inference.service import InferenceService
1919
from model.system import AtaxxZero
20+
from model.transformer import AtaxxTransformerNet
2021

2122

2223
class TestInferenceService(unittest.TestCase):
@@ -58,6 +59,49 @@ def test_fast_mode_returns_legal_action(self) -> None:
5859
self.assertIn(result.move, legal_moves)
5960
self.assertTrue(-1.0 <= result.value <= 1.0)
6061

62+
def test_pt_checkpoint_load_does_not_require_lightning_runtime(self) -> None:
63+
with tempfile.TemporaryDirectory() as tmp_dir:
64+
model = AtaxxTransformerNet(
65+
d_model=64,
66+
nhead=8,
67+
num_layers=2,
68+
dim_feedforward=128,
69+
dropout=0.0,
70+
)
71+
ckpt_path = Path(tmp_dir) / "spatial.pt"
72+
state_dict = {f"model.{key}": value for key, value in model.state_dict().items()}
73+
torch.save({"state_dict": state_dict}, ckpt_path)
74+
75+
native_import = __import__
76+
77+
def guarded_import(
78+
name: str,
79+
globals_: dict[str, object] | None = None,
80+
locals_: dict[str, object] | None = None,
81+
fromlist: tuple[str, ...] = (),
82+
level: int = 0,
83+
) -> object:
84+
if name.startswith("pytorch_lightning"):
85+
raise ModuleNotFoundError("pytorch_lightning blocked by test")
86+
return native_import(name, globals_, locals_, fromlist, level)
87+
88+
with patch("builtins.__import__", side_effect=guarded_import):
89+
service = InferenceService(
90+
checkpoint_path=ckpt_path,
91+
device="cpu",
92+
model_kwargs={
93+
"d_model": 64,
94+
"nhead": 8,
95+
"num_layers": 2,
96+
"dim_feedforward": 128,
97+
"dropout": 0.0,
98+
},
99+
)
100+
result = service.predict(AtaxxBoard(), mode="fast")
101+
102+
self.assertEqual(result.mode, "fast")
103+
self.assertIsNotNone(result.move)
104+
61105
def test_strong_mode_returns_legal_action(self) -> None:
62106
with tempfile.TemporaryDirectory() as tmp_dir:
63107
system = self._tiny_system()

web/src/pages/match/MatchPage.tsx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ const AI_THINK_DELAY_MS = 460;
6262
const AI_PREVIEW_MS = 420;
6363
const INFECTION_STEP_MS = 90;
6464
const INFECTION_BURST_MS = 420;
65-
const OUTGOING_INVITE_POLL_MS = 2500;
65+
const OUTGOING_INVITE_POLL_MS = 4000;
6666
const UI_TICK_MS = 120;
6767
const INTRO_COUNTDOWN_START = 3;
6868
const HOVER_SFX_MIN_GAP_MS = 120;
@@ -383,6 +383,7 @@ export function MatchPage(): JSX.Element {
383383
const gameplayWsRef = useRef<WebSocket | null>(null);
384384
const lastWsPlyRef = useRef(-1);
385385
const persistQueueRef = useRef<Promise<void>>(Promise.resolve());
386+
const latestBoardRef = useRef<BoardState>(board);
386387
const failedPersistOpsRef = useRef<PendingPersistOperation[]>([]);
387388
const unmountCleanupTriggeredRef = useRef(false);
388389
const unmountCleanupStateRef = useRef<{
@@ -416,6 +417,10 @@ export function MatchPage(): JSX.Element {
416417
}
417418
}, [accessToken]);
418419

420+
useEffect(() => {
421+
latestBoardRef.current = board;
422+
}, [board]);
423+
419424
useEffect(() => {
420425
unmountCleanupStateRef.current = {
421426
accessToken,
@@ -1372,7 +1377,10 @@ export function MatchPage(): JSX.Element {
13721377

13731378
if (event.move.board_after !== null) {
13741379
const boardAfter = event.move.board_after as BoardState;
1375-
setBoard(boardAfter);
1380+
// Ignore stale snapshots: delayed WS frames used to overwrite a newer local board.
1381+
if (boardAfter.half_moves >= latestBoardRef.current.half_moves) {
1382+
setBoard(boardAfter);
1383+
}
13761384
}
13771385
const remoteMove =
13781386
event.move.r1 === null || event.move.c1 === null || event.move.r2 === null || event.move.c2 === null
@@ -1798,14 +1806,15 @@ export function MatchPage(): JSX.Element {
17981806

17991807
const persistMoveWithRetry = useCallback(
18001808
async (operation: PendingPersistOperation) => {
1801-
if (!canPersist || accessToken === null) {
1809+
const token = lastAccessTokenRef.current;
1810+
if (!canPersist || token === null) {
18021811
throw new Error("Persistencia no disponible.");
18031812
}
18041813
let lastError: unknown = null;
18051814
for (let attempt = 1; attempt <= PERSIST_MAX_RETRIES; attempt += 1) {
18061815
try {
18071816
await storeManualMove(
1808-
accessToken,
1817+
token,
18091818
operation.gameId,
18101819
operation.beforeBoard,
18111820
operation.move,
@@ -1826,7 +1835,7 @@ export function MatchPage(): JSX.Element {
18261835
}
18271836
throw lastError;
18281837
},
1829-
[accessToken, canPersist],
1838+
[canPersist],
18301839
);
18311840

18321841
const disableRemotePersistence = useCallback(

0 commit comments

Comments
 (0)