Skip to content

Commit 2d8e113

Browse files
authored
Merge pull request #4 from Tesla2000/feature/python-speedup
Feature/python speedup
2 parents 6340e62 + f2408eb commit 2d8e113

7 files changed

Lines changed: 36 additions & 13 deletions

File tree

Config.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@ class _ConfigAgent:
2727
c = 0.2
2828
learning_rate = 1e-5
2929
debug = False
30-
pretrain = True
30+
# pretrain = True
31+
pretrain = False
3132

3233

3334
class Config(_ConfigPaths, _ConfigAgent):
@@ -38,7 +39,7 @@ class Config(_ConfigPaths, _ConfigAgent):
3839
train_batch_size = 128
3940
training_buffer_len = 100_000
4041
min_n_points_to_finish = 15
41-
n_simulations = 100
42+
n_simulations = 1000
4243
n_games = None
4344
n_players = 2
4445
n_actions = 45

agent/Agent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
class Agent(nn.Module):
1010
_input_size_dictionary = {
11-
2: 205,
11+
2: 211,
1212
}
1313

1414
def __init__(

agent/policy.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ def policy(
1818
P = defaultdict(dict)
1919
Q = defaultdict(dict)
2020
initial_state = game.get_state()
21-
all_moves = game.get_possible_actions()
21+
all_moves = game.all_moves
2222
for _ in range(n_simulations):
2323
search(game.copy(), agent, c, N, visited, P, Q)
2424
pi = np.array([N[initial_state][a] for a in all_moves])

agent/search.py

Lines changed: 24 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1-
from collections import defaultdict
21
from math import sqrt
2+
from collections import defaultdict
33

4+
import torch
45
from torch import nn, Tensor
56

67
from src.Game import Game
@@ -20,19 +21,35 @@ def search(
2021
state = game.get_state()
2122
if state not in visited:
2223
visited.add(state)
23-
move_scores, v = agent(Tensor([state]))
24+
with torch.no_grad():
25+
move_scores, v = agent(Tensor([state]))
2426
tuple(
25-
P[state].__setitem__(move, move_scores[0, index])
27+
P[state].__setitem__(move, move_scores[0, index].item())
2628
for index, move in enumerate(game.all_moves)
2729
)
28-
return -v
30+
return -v.item()
31+
q_state = Q[state]
32+
p_state = P[state]
33+
n_state = N[state]
34+
sqrt_value = sqrt(sum(n_state.values()))
2935

36+
# def _get_action(game: Game):
37+
# return max(
38+
# game.get_possible_actions(),
39+
# key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
40+
# )
41+
# def _get_action(game: Game):
42+
# best_action = None
43+
# best_value = -float('inf')
44+
# for action in game.all_moves:
45+
# value = q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action])
46+
# if value > best_value and action.is_valid(game):
47+
# best_value, best_action = value, action
48+
# return best_action
3049
action = max(
3150
game.get_possible_actions(),
32-
key=lambda action: Q[state].get(action, 1)
33-
+ c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]),
51+
key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
3452
)
35-
3653
next_game_state = game.perform(action)
3754
v = search(next_game_state, agent, c, N, visited, P, Q)
3855

agent/self_play.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def self_play(
3232
def _perform_game(
3333
game: Game, states: list, id_to_agent: dict[int, Agent]
3434
) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
35-
for turn in tqdm(count()):
35+
for _ in tqdm(count()):
3636
agent = id_to_agent[game.current_player.id]
3737
pi, action = policy(game, agent, Config.c, Config.n_simulations)
3838
states.append((game, pi / pi.sum(), 0))
@@ -47,7 +47,6 @@ def _perform_game(
4747
int(result[state[0].current_player.id] == 1),
4848
)
4949
for state in states
50-
if state[1] != game.null_move
5150
),
5251
id_to_agent[
5352
next(player.id for player in game.players if result[player.id])

src/Game.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,11 @@ def get_possible_actions(self) -> tuple[Move, ...]:
155155
(self.null_move,) if self.null_move.is_valid(self) else tuple()
156156
)
157157

158+
def get_possible_action_indexes(self) -> tuple[int, ...]:
159+
return tuple(index for index, move in enumerate(self.all_moves) if move.is_valid(self)) or (
160+
(self.null_move,) if self.null_move.is_valid(self) else tuple()
161+
)
162+
158163
combos = combinations([{field.name: 1} for field in fields(BasicResources)], 3)
159164
all_moves = list(
160165
GrabThreeResource(BasicResources(**res_1, **res_2, **res_3))

src/StateExtractor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def get_state(cls, game: "Game") -> tuple:
4848
tuple(iter(aristocrat.cost))
4949
for aristocrat in game.board.aristocrats
5050
),
51+
iter(game.board.resources),
5152
chain.from_iterable(
5253
(
5354
*tuple(iter(player.resources)),

0 commit comments

Comments
 (0)