-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathConfig.py
More file actions
74 lines (60 loc) · 1.61 KB
/
Config.py
File metadata and controls
74 lines (60 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import atexit
import os
import random
from pathlib import Path
from urllib.parse import quote

import numpy as np
import torch
from dotenv import load_dotenv
from lightning import seed_everything
# Populate os.environ from a .env file, if one is found, before the settings
# classes below read environment variables. override=False (the library
# default) means variables already present in the environment take precedence.
load_dotenv(override=False)
class _DBConfig:
    """PostgreSQL connection settings read from the environment.

    ``POSTGRES_DB``, ``POSTGRES_PASSWORD`` and ``POSTGRES_USER`` are read once,
    when the class is defined; unset variables become ``None``. ``db_url`` is a
    ``postgresql://`` URL pointing at ``localhost:5432``.
    """

    db_name = os.getenv("POSTGRES_DB")
    db_password = os.getenv("POSTGRES_PASSWORD")
    db_username = os.getenv("POSTGRES_USER")
    # Percent-encode the credentials: characters such as '@', ':', '/' or '#'
    # in a password would otherwise corrupt the URL structure. str() keeps the
    # previous "None" rendering for unset variables instead of raising here.
    db_url = (
        f"postgresql://{quote(str(db_username), safe='')}"
        f":{quote(str(db_password), safe='')}"
        f"@localhost:5432/{db_name}"
    )
class _ConfigPaths:
    """Filesystem locations used by the project, all relative to this file.

    Defining this class has side effects: it creates the ``models`` directory
    next to this file if it is missing, and opens ``train_log.log`` for writing,
    which truncates any previous contents.
    """

    root = Path(__file__).parent
    model_path = root / "models"
    model_path.mkdir(exist_ok=True)
    gui = root / "gui"
    templates = gui / "templates"
    ai_weights = root / 'speed_game.pth'
    # Opened once at class-definition time; nothing in this class closes it.
    # The explicit encoding makes the log's bytes independent of the OS locale.
    train_values_file = (root / "train_log.log").open('w', encoding="utf-8")
class _ConfigAgent:
    """Agent hyper-parameters and the global random seed."""

    # Presumably the hidden-layer widths, first to last; confirm against the
    # model that consumes this tuple.
    hidden_size = (
        256,
        128,
        64,
        32,
        32,
        32,
        32,
        32,
        32,
    )
    initial_train_learning_rate = 1e-3
    # True -> fixed seed (42) for reproducible runs;
    # False -> a fresh random seed each run, printed so the run can be replayed.
    debug = True
    # np.random.seed accepts seeds in [0, 2**32 - 1]. random.randint includes its
    # upper bound, so the previous bound of 2 ** 32 could yield an unusable seed.
    max_random_state = 2 ** 32 - 1
    if not debug:
        random_state = random.randint(0, max_random_state)
        print(f"{random_state=}")
    else:
        random_state = 42
class Config(_ConfigPaths, _ConfigAgent, _DBConfig):
    """Single settings namespace for the project.

    Inherits the path, agent and database settings from the three base classes
    and adds the constants below. All settings are class attributes; the
    meaning and units of each constant are defined by the code that reads it.
    """

    lr_decline_rate = 1000

    # Interval / counter settings.
    agent_print_interval = 10
    results_over_time_counter = 100
    agent_save_interval = 500

    play_beta = 100
    beta = 100

    # Use CUDA whenever PyTorch reports it is available, otherwise the CPU.
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    training_buffer_len = 50000
    min_n_points_to_finish = 15
    n_players = 2
    n_actions = 45
# Seed the random number generators once, at import time, from the configured
# value. seed_everything is Lightning's global seeding helper (workers=True also
# covers dataloader workers, per its docs); the direct NumPy and stdlib calls
# re-apply the same seed to those global generators explicitly.
seed_everything(seed=Config.random_state, workers=True)
np.random.seed(Config.random_state)
random.seed(Config.random_state)
def close():
    """Close the shared training log file (``Config.train_values_file``).

    Registered with ``atexit`` immediately below, so it runs at interpreter exit.
    """
    log_file = Config.train_values_file
    log_file.close()


atexit.register(close)