-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path · play.py
More file actions
26 lines (19 loc) · 1.13 KB
/
play.py
File metadata and controls
26 lines (19 loc) · 1.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import argparse
from pathlib import Path
from typing import Optional

import yaml

from stable_eureka import make_env, get_logger, RLEvaluator
# Defaults reproduce the original hard-coded evaluation run; override them on
# the command line instead of editing this file.
DEFAULT_EXP_PATH = Path('/home/rsanchezmo/Projects/stable-eureka/experiments/lunar_lander_gpt4o/2024-06-09/')
DEFAULT_ENV_NAME = 'LunarLander-v2'


def default_model_path(exp_path) -> Path:
    """Return the benchmark model archive inside an experiment directory."""
    return Path(exp_path) / 'code' / 'benchmark' / 'model.zip'


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line overrides for the evaluation run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``exp_path`` (Path), ``model_path`` (Path or None) and
        ``env_name`` (str).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a trained stable-eureka model.')
    parser.add_argument(
        '--exp-path', type=Path, default=DEFAULT_EXP_PATH,
        help='Experiment directory containing config.yaml and code/.')
    parser.add_argument(
        '--model-path', type=Path, default=None,
        help=('Model archive to evaluate (default: <exp-path>/code/benchmark/model.zip; '
              'e.g. <exp-path>/code/iteration_4/sample_3/model.zip for a specific sample).'))
    parser.add_argument(
        '--env-name', default=DEFAULT_ENV_NAME,
        help='Environment identifier passed to make_env.')
    return parser.parse_args(argv)


def main(exp_path=DEFAULT_EXP_PATH,
         model_path: Optional[Path] = None,
         env_name: str = DEFAULT_ENV_NAME) -> None:
    """Build a single evaluation environment and run the model's evaluation.

    Args:
        exp_path: Experiment directory; its ``config.yaml`` supplies the
            environment kwargs, RL algorithm, training-time env settings and
            evaluation settings.
        model_path: Model archive to load. Defaults to
            ``<exp_path>/code/benchmark/model.zip``.
        env_name: Environment identifier passed to ``make_env``.

    Raises:
        OSError: if ``config.yaml`` cannot be opened.
        KeyError: if a required config key (``rl.algo``, ``rl.training``,
            ``rl.evaluation.seed`` / ``num_episodes`` / ``save_gif``,
            ``environment``) is missing.
    """
    exp_path = Path(exp_path)
    if model_path is None:
        model_path = default_model_path(exp_path)

    # Use a context manager so the config file handle is closed promptly.
    with open(exp_path / 'config.yaml', 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    training_cfg = config['rl']['training']
    # Re-create the env with the same wrappers used at training time
    # (atari preprocessing, frame stacking, vectorization mode).
    env = make_env(env_class=env_name,
                   env_kwargs=config['environment'].get('kwargs', None),
                   n_envs=1,
                   is_atari=training_cfg.get('is_atari', False),
                   state_stack=training_cfg.get('state_stack', 1),
                   multithreaded=training_cfg.get('multithreaded', False))

    evaluator = RLEvaluator(model_path, algo=config['rl']['algo'])
    eval_cfg = config['rl']['evaluation']
    evaluator.run(env,
                  seed=eval_cfg['seed'],
                  n_episodes=eval_cfg['num_episodes'],
                  logger=get_logger(),
                  save_gif=eval_cfg['save_gif'])


if __name__ == '__main__':
    args = parse_args()
    main(exp_path=args.exp_path, model_path=args.model_path, env_name=args.env_name)