Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions maia2/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_preds(model, dataloader, all_moves_dict_reversed):
legal_moves = legal_moves.to(device)

logits_maia, _, logits_value = model(boards, elos_self, elos_oppo)
logits_maia_legal = logits_maia * legal_moves
logits_maia_legal = logits_maia + legal_moves.log()
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

legal_moves.log() will produce -inf for masked entries; if a position ever has zero legal moves (e.g., checkmate/stalemate) or the mask is accidentally all zeros, softmax over all -inf yields NaN probabilities (0/0). Consider explicitly handling the all-masked case (early return / raise) or using a safe fallback before softmax.

Suggested change
logits_maia_legal = logits_maia + legal_moves.log()
logits_maia_legal = logits_maia + legal_moves.log()
# If a position has no legal moves (all zeros in the mask), logits_maia_legal will be all -inf;
# softmax over all -inf yields NaNs. For such rows, fall back to the unmasked logits.
has_legal = legal_moves.sum(dim=-1, keepdim=True) > 0
if not has_legal.all():
logits_maia_legal = torch.where(has_legal, logits_maia_legal, logits_maia)

Copilot uses AI. Check for mistakes.
probs = logits_maia_legal.softmax(dim=-1).cpu().tolist()

logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist()
Expand Down Expand Up @@ -154,7 +154,7 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo):
legal_moves = legal_moves.unsqueeze(dim=0).to(device)

logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo)
logits_maia_legal = logits_maia * legal_moves
logits_maia_legal = logits_maia + legal_moves.log()
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same masking approach here: if legal_moves is all zeros, legal_moves.log() makes the entire row -inf and softmax will return NaNs. Please add a guard for terminal positions / empty legal-move sets or otherwise ensure at least one legal move before applying softmax.

Suggested change
logits_maia_legal = logits_maia + legal_moves.log()
# Guard against the case where there are no legal moves (all zeros), which would make
# legal_moves.log() equal to -inf everywhere and softmax return NaNs.
if (legal_moves > 0).any():
legal_moves_for_mask = legal_moves
else:
# In terminal positions with no legal moves, use an all-ones mask so that .log()
# produces zeros and does not introduce -inf values.
legal_moves_for_mask = torch.ones_like(legal_moves)
logits_maia_legal = logits_maia + legal_moves_for_mask.log()

Copilot uses AI. Check for mistakes.
probs = logits_maia_legal.softmax(dim=-1).cpu().tolist()

logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item()
Expand Down