-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_multi_model.py
More file actions
138 lines (108 loc) · 4.27 KB
/
test_multi_model.py
File metadata and controls
138 lines (108 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecVideoRecorder, VecMonitor, VecFrameStack
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
from gym.wrappers.rescale_action import RescaleAction
from gym.spaces import Box
from custom_envs.MultiMerge import MultiMergeAllRewards as MultiMerge
import os
import wandb, glob
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.monitor import Monitor
import argparse
# ---------------------------------------------------------------------------
# Command-line interface: model path, optional rendering flag, the saved
# VecNormalize statistics file, and the name of the experiment config object.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='test PPO multi model')
parser.add_argument("dir", help="model path")
parser.add_argument("--render", default =0, help = "should render default 0")
parser.add_argument("stats_load", help="vec norm stats file")
parser.add_argument("config", help="Config file")
args = parser.parse_args()
# Dynamically import the experiment-config object named on the command line
# from the local `config_file` module (e.g. `python test.py ... myExpConfig`).
module = __import__("config_file",fromlist= [args.config])
exp_config = getattr(module, args.config)
# Static W&B run metadata (logged, not used to build the env).
config = {
"policy_type": "MultiInputPolicy",
"env_name": "SumoRamp()",
}
# NOTE(review): `dir` shadows the `dir` builtin; kept because later code
# (wandb.init and path joins) reads this module-level name.
pdir = os.path.abspath('../')
dir = os.path.join(pdir, 'SBRampSavedFiles/wandbsavedfiles')
# Environment construction parameters pulled from the experiment config;
# these module-level names are closed over by make_env() below.
policy_kwargs = exp_config.policy_kwargs
action_space = exp_config.action_space
image_shape = exp_config.image_shape
obsspaces = exp_config.obsspaces
weights = exp_config.weights
sumoParameters = exp_config.sumoParameters
# Target action range for RescaleAction.
min_action = -1
max_action = +1
# Video-recording settings (configured here; recorder wrapper not applied below).
video_folder = dir + '/logs/videos/'
video_length = 180
def make_env(env_id, rank, seed=0, monitor_dir = None):
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID (kept for the standard SB3
        make_env signature; the env is built from module-level config)
    :param rank: (int) index of the subprocess, used to offset the seed
        and to name the per-env Monitor log file
    :param seed: (int) the initial seed for RNG
    :param monitor_dir: (str or None) directory for Monitor log files;
        when None, Monitor is created without a log file
    :return: (callable) a zero-argument factory producing the wrapped env,
        suitable for SubprocVecEnv/DummyVecEnv
    """
    def _init():
        # Build the raw SUMO merge environment from module-level experiment
        # config (action_space, obsspaces, sumoParameters, weights).
        env = MultiMerge(action_space=action_space, obsspaces=obsspaces,
                         sumoParameters=sumoParameters, weights=weights,
                         isBaseline=False, render=0)
        # Distinct but reproducible seed per subprocess.
        env.seed(seed + rank)
        # Rescale actions to the [min_action, max_action] = [-1, 1] range.
        env = RescaleAction(env, min_action, max_action)
        # One Monitor log file per rank when a directory is given; otherwise
        # Monitor still wraps the env but writes no file (filename=None).
        monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
        if monitor_path is not None:
            os.makedirs(monitor_dir, exist_ok=True)
        env = Monitor(env, filename=monitor_path)
        return env

    set_random_seed(seed)
    return _init
if __name__ == '__main__':
    # One W&B run per evaluation session; files stored under the shared
    # save directory configured at module level.
    run = wandb.init(
        project="Robust-OnRampMerging-Testing",
        dir=dir,
        name=f"multimodal_{args.config}",
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
        magic=True
    )
    env_id = "MultiMerge"
    num_cpu = 1  # Number of processes to use
    # Build the vectorized env with the same wrapper stack used at training
    # time, then restore the saved normalization statistics.
    env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
    env = VecFrameStack(env, n_stack=4)  # stack 4 frames
    env = VecNormalize.load(args.stats_load, env)
    env.training = False  # freeze running obs/reward statistics for evaluation
    # NOTE(review): SB3 docs recommend norm_reward=False at evaluation so
    # logged rewards are in the env's raw scale — confirm True is intended.
    env.norm_reward = True
    env = VecMonitor(venv=env)
    model = PPO.load(args.dir, env)

    obs = env.reset()
    n_games = 10
    for i_games in range(n_games):
        done = False
        obs = env.reset()
        score = 0
        num_collisions = 0
        mergeTime = 0
        velocity_reward = []
        acc_reward = []
        while not done:
            action, _states = model.predict(obs)
            obs, rewards, done, info = env.step(action)
            if int(args.render) == 1:
                env.render()
            score += rewards
            # info[0]['terminal']: -1 signals a collision; non-zero signals
            # the episode resolved (merge time becomes available).
            # TODO(review): confirm terminal-code semantics against MultiMerge.
            if int(info[0]['terminal']) == -1:
                num_collisions += 1
            if int(info[0]['terminal']) != 0:
                mergeTime = info[0]['mergeTime']
            # Per-step reward components, averaged per episode below.
            velocity_reward.append(info[0]['velocity_reward'])
            acc_reward.append(info[0]['acc_reward'])
        print(f"score {score} num_collisions : {num_collisions} , mergetime : {mergeTime}")
        wandb.log({
            "episodic score": score,
            "num_collisions": num_collisions,
            "mergeTime": mergeTime,
            "acc_reward": np.mean(acc_reward),
            "velocity_reward": np.mean(velocity_reward),
        }, step=i_games)
    # Flush metrics and close the run explicitly so the final episode's
    # logs are synced even if the interpreter exits abruptly.
    run.finish()