-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
82 lines (63 loc) · 2.73 KB
/
test_model.py
File metadata and controls
82 lines (63 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from switchfl.switch_env import ASyncSwitchEnv
import matplotlib.pyplot as plt
import numpy as np
import os
from switchfl.distr_q import DistrQLearning
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
import time
# Stochastic malfunction model for the rail environment: trains break down at
# the configured rate of occurrence and stay blocked for a duration drawn
# between ``min_duration`` and ``max_duration`` steps.
_MALFUNCTION_CONFIG = {
    "malfunction_rate": 0.01,  # rate of malfunction occurrence
    "min_duration": 5,         # minimal duration of a malfunction
    "max_duration": 15,        # maximal duration of a malfunction
}
stochastic_data = MalfunctionParameters(**_MALFUNCTION_CONFIG)
mf = ParamMalfunctionGen(stochastic_data)
if __name__=='__main__':
    # --- Experiment output ---
    # NOTE(review): hard-coded user-specific path — consider making this
    # configurable (CLI arg / env var) before sharing the script.
    # Output directory for experiment results
    out_dir = '/home/gianvito/Desktop/debug'
    os.makedirs(out_dir, exist_ok=True)

    # --- Environment setup ---
    # Fixed seed so the rail layout, line assignment and training run are
    # reproducible across executions.
    random_seed = 450565
    rail_env = RailEnv(
        width=18,
        height=18,
        rail_generator=sparse_rail_generator(
            max_num_cities=5,
            grid_mode=True,
            max_rails_between_cities=1,
            max_rail_pairs_in_city=1,
            seed=random_seed,
        ),
        line_generator=sparse_line_generator(seed=random_seed),
        number_of_agents=2,
        malfunction_generator=mf  # stochastic breakdowns (configured at module level)
    )
    num_episodes = 5
    # -------------------------------------------------------------------------------------
    # DO NOT MODIFY BELOW THIS LINE
    # -------------------------------------------------------------------------------------
    # Wrap the Flatland env in the asynchronous switch-based environment and
    # train a distributed tabular Q-learning model on it.
    env = ASyncSwitchEnv(rail_env, render_mode="human", max_steps=100_000)
    model = DistrQLearning(env=env,
                           gamma = 1.,  # undiscounted returns
                           epsilon = 0.5,  # initial exploration rate
                           epsilon_decay_rate = 0.9997,
                           lr = 0.1,  # learning rate
                           lr_decay_rate = 1.0,  # 1.0 → constant learning rate
                           default_q = 0.,  # initial Q-value for unseen entries (assumed — confirm in DistrQLearning)
                           seed = random_seed)
    start_time = time.time()
    model.learn(num_episodes=num_episodes, out_dir=out_dir, checkpoint_freq=10000)
    model.save(os.path.join(out_dir, "distr_q_model.pkl"))
    # NOTE: elapsed_time covers learn() AND save(), so "seconds per episode"
    # slightly overstates the pure training time.
    elapsed_time = time.time() - start_time

    # Timing summary. NOTE(review): the env.*_time attributes below are assumed
    # to be accumulated counters maintained by ASyncSwitchEnv — confirm against
    # switchfl.switch_env.
    print("DONE!")
    print(f"TOTAL TIME: {elapsed_time:.1f} seconds")
    print(f"Seconds per episode: {elapsed_time / num_episodes:.1f}")
    print(f"Flatland step time: {env.flatland_step_time:.1f} seconds")
    print(f"Total step time: {env.step_time:.1f} seconds")
    print(f"Total last time: {env.last_time:.1f} seconds")
    print(f"Action selection time: {env.action_selection_time:.1f} seconds")
    print(f"Update time: {env.update_time:.1f} seconds")
    print(f"Flatland reset time: {env.reset_time:.1f} seconds")
    print(f"Total reset time: {env.reset_total_time:.1f} seconds")