1- from collections import defaultdict
21from math import sqrt
2+ from collections import defaultdict
33
4+ import torch
45from torch import nn , Tensor
56
67from src .Game import Game
@@ -20,19 +21,35 @@ def search(
2021 state = game .get_state ()
2122 if state not in visited :
2223 visited .add (state )
23- move_scores , v = agent (Tensor ([state ]))
24+ with torch .no_grad ():
25+ move_scores , v = agent (Tensor ([state ]))
2426 tuple (
25- P [state ].__setitem__ (move , move_scores [0 , index ])
27+ P [state ].__setitem__ (move , move_scores [0 , index ]. item () )
2628 for index , move in enumerate (game .all_moves )
2729 )
28- return - v
30+ return - v .item ()
31+ q_state = Q [state ]
32+ p_state = P [state ]
33+ n_state = N [state ]
34+ sqrt_value = sqrt (sum (n_state .values ()))
2935
36+ # def _get_action(game: Game):
37+ # return max(
38+ # game.get_possible_actions(),
39+ # key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
40+ # )
41+ # def _get_action(game: Game):
42+ # best_action = None
43+ # best_value = -float('inf')
44+ # for action in game.all_moves:
45+ # value = q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action])
46+ # if value > best_value and action.is_valid(game):
47+ # best_value, best_action = value, action
48+ # return best_action
3049 action = max (
3150 game .get_possible_actions (),
32- key = lambda action : Q [state ].get (action , 1 )
33- + c * P [state ][action ] * sqrt (sum (N [state ].values ())) / (1 + N [state ][action ]),
51+ key = lambda action : q_state .get (action , 1 ) + c * p_state [action ] * sqrt_value / (1 + n_state [action ]),
3452 )
35-
3653 next_game_state = game .perform (action )
3754 v = search (next_game_state , agent , c , N , visited , P , Q )
3855
0 commit comments