diff --git a/maia2/__init__.py b/maia2/__init__.py index 784611b..43c24d8 100644 --- a/maia2/__init__.py +++ b/maia2/__init__.py @@ -1,3 +1,3 @@ """An amazing sample package for maia2.""" -__version__ = "0.9" +__version__ = "0.9.1" diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..6fe4e55 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -60,7 +60,7 @@ def get_preds(model, dataloader, all_moves_dict_reversed): legal_moves = legal_moves.to(device) logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia.masked_fill(legal_moves == 0, float('-inf')) probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() @@ -154,7 +154,7 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): legal_moves = legal_moves.unsqueeze(dim=0).to(device) logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia.masked_fill(legal_moves == 0, float('-inf')) probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item() diff --git a/maia2/main.py b/maia2/main.py index 32b061d..b84c8f0 100644 --- a/maia2/main.py +++ b/maia2/main.py @@ -84,16 +84,19 @@ def process_per_game(game, white_elo, black_elo, white_win, cfg): def game_filter(game): - + white_elo = game.headers.get("WhiteElo", "?") black_elo = game.headers.get("BlackElo", "?") time_control = game.headers.get("TimeControl", "?") result = game.headers.get("Result", "?") event = game.headers.get("Event", "?") - + if white_elo == "?" or black_elo == "?" or time_control == "?" or result == "?" or event == "?": return + if game.headers.get("WhiteTitle") == "BOT" or game.headers.get("BlackTitle") == "BOT": + return + if 'Rated' not in event: return diff --git a/maia2/requirements.txt b/maia2/requirements.txt index a94f2cc..0fd8809 100644 --- a/maia2/requirements.txt +++ b/maia2/requirements.txt @@ -5,6 +5,6 @@ numpy==2.1.3 pandas==2.2.3 pyyaml==6.0.2 pyzstd==0.15.9 -Requests==2.32.3 +requests==2.32.3 torch==2.4.0 tqdm==4.65.0 \ No newline at end of file