1414from game .types import Move
1515
1616if TYPE_CHECKING :
17+ import torch .nn as nn
18+
1719 from engine .mcts import MCTS
1820
1921InferenceMode = Literal ["fast" , "strong" ]
@@ -57,7 +59,9 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5759
5860
5961class _SystemLike (Protocol ):
60- model : Any
62+ @property
63+ def model (self ) -> nn .Module :
64+ ...
6165
6266 def eval (self ) -> _SystemLike :
6367 ...
@@ -69,6 +73,28 @@ def load_state_dict(self, state_dict: dict[str, object]) -> object:
6973 ...
7074
7175
class _CheckpointSystemAdapter:
    """Expose a plain torch module through the ``_SystemLike`` interface.

    The adapter is pure delegation: ``model`` hands back the wrapped module,
    while ``eval``, ``to`` and ``load_state_dict`` forward to it. ``eval`` and
    ``to`` return the adapter itself (not the module) so calls can be chained
    the same way as on the other inference systems.
    """

    def __init__(self, model: nn.Module) -> None:
        """Keep a reference to ``model``; nothing is copied or moved."""
        self._model = model

    @property
    def model(self) -> nn.Module:
        """The wrapped network (read-only attribute)."""
        return self._model

    def eval(self) -> _CheckpointSystemAdapter:
        """Call ``eval()`` on the wrapped module and return this adapter."""
        self._model.eval()
        return self

    def to(self, device: str) -> _CheckpointSystemAdapter:
        """Call ``to(device)`` on the wrapped module and return this adapter."""
        self._model.to(device)
        return self

    def load_state_dict(self, state_dict: dict[str, object]) -> object:
        """Forward ``state_dict`` to the wrapped module and return its result."""
        return self._model.load_state_dict(state_dict)
97+
7298@lru_cache (maxsize = 1 )
7399def _get_torch_module () -> ModuleType | None :
74100 """Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
@@ -165,17 +191,31 @@ def _extract_arch_kwargs(raw_kwargs: ModelInitKwargs) -> dict[str, Any]:
165191 allowed = ("d_model" , "nhead" , "num_layers" , "dim_feedforward" , "dropout" )
166192 return {key : raw_kwargs [key ] for key in allowed if key in raw_kwargs }
167193
@staticmethod
def _extract_model_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return ``state_dict`` with a uniform ``model.`` key prefix removed.

    Training checkpoints prefix model params with ``model.`` (Lightning
    module layout), whereas runtime inference loads into the raw network,
    so the prefix has to go before loading.

    The rule is all-or-nothing: the prefix is stripped only when *every* key
    carries it. If any key lacks it, the input mapping is returned unchanged
    (the same object, not a copy). Only one leading ``model.`` is removed
    per key.
    """
    prefix = "model."
    # A single un-prefixed key means this is not the uniform training layout.
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}
201+
def _build_legacy_system(self) -> _SystemLike:
    """Instantiate the legacy system from the configured architecture kwargs.

    ``LegacyAtaxxSystem`` is imported inside the method rather than at module
    import time. It receives only the keys that ``_extract_arch_kwargs``
    keeps from ``self.model_kwargs``.
    """
    from inference.legacy_model import LegacyAtaxxSystem

    arch_kwargs = self._extract_arch_kwargs(self.model_kwargs)
    return LegacyAtaxxSystem(**arch_kwargs)
172206
173- def _load_system (self ) -> _SystemLike :
174- from model .system import AtaxxZero
def _build_spatial_system(self) -> _SystemLike:
    """Construct the transformer network and wrap it as an inference system.

    ``AtaxxTransformerNet`` is imported inside the method and built from the
    keys that ``_extract_arch_kwargs`` keeps from ``self.model_kwargs``. No
    weights are loaded here; the returned ``_CheckpointSystemAdapter`` holds
    the freshly constructed network.
    """
    from model.transformer import AtaxxTransformerNet

    arch_kwargs = self._extract_arch_kwargs(self.model_kwargs)
    network = AtaxxTransformerNet(**arch_kwargs)
    return _CheckpointSystemAdapter(network)
175212
213+ def _load_system (self ) -> _SystemLike :
176214 torch_module = self ._require_torch ()
177215 ckpt = self .checkpoint_path
178216 if ckpt .suffix == ".ckpt" :
217+ from model .system import AtaxxZero
218+
179219 try :
180220 return AtaxxZero .load_from_checkpoint (str (ckpt ), map_location = self .device )
181221 except RuntimeError as exc :
@@ -191,9 +231,9 @@ def _load_system(self) -> _SystemLike:
191231 if not isinstance (state_dict_obj , dict ):
192232 raise ValueError ("Checkpoint dictionary must contain key 'state_dict'." )
193233
194- system = AtaxxZero ( ** self .model_kwargs )
234+ system = self ._build_spatial_system ( )
195235 try :
196- system .load_state_dict (state_dict_obj )
236+ system .load_state_dict (self . _extract_model_state_dict ( state_dict_obj ) )
197237 except RuntimeError as exc :
198238 if self ._is_legacy_state_dict (state_dict_obj ):
199239 legacy_system = self ._build_legacy_system ()
0 commit comments