Skip to content

Commit 1de4aa7

Browse files
Merge pull request #20 from Association-INTech/split_train
Split train
2 parents 98c0e1b + 4a6137f commit 1de4aa7

11 files changed

Lines changed: 149 additions & 175 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Debug_Wayfinding
1010
.venv
1111
*.onnx
1212
checkpoints
13+
*.wbproj

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ uv sync --extra rpi
3333

3434
Navigate to the scripts directory.
3535
```bash
36-
cd src/Simulateur
36+
cd scripts
3737
```
3838

3939
Run the multi-process training script.
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import os
2+
import sys
3+
4+
from typing import *
5+
6+
import torch.nn as nn
7+
8+
from stable_baselines3 import PPO
9+
from stable_baselines3.common.vec_env import SubprocVecEnv
10+
11+
# Make the simulator sources importable: they live in "src/Simulateur" under
# the directory two levels above this file.
# abspath() guards against a relative __file__ (e.g. `python train.py`), for
# which splitting the raw string on '/' would not yield that directory.
simu_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "src",
    "Simulateur",
)
if simu_path not in sys.path:
    # Front of the path so these modules win over same-named installed ones.
    sys.path.insert(0, simu_path)
14+
15+
from config import *
16+
from TemporalResNetExtractor import TemporalResNetExtractor
17+
from onnx_utils import *
18+
19+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
20+
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
21+
22+
def log(s: str) -> None:
    """Append ``s`` as one line to /tmp/autotech/logs when B_DEBUG is truthy.

    Does nothing when B_DEBUG is falsy. The log file is opened in append mode
    and closed again on every call, so no file handle is left open.

    Args:
        s: Message to write.
    """
    if not B_DEBUG:
        return
    with open("/tmp/autotech/logs", "a") as log_file:
        print(s, file=log_file)
25+
26+
27+
28+
def _checkpoint_index(filename: str) -> Optional[int]:
    """Return N for a checkpoint file named ``N.zip``, or None otherwise.

    ``str.rstrip(".zip")`` strips any trailing run of the characters
    '.', 'z', 'i', 'p' rather than the literal suffix, so the extension is
    split off with ``os.path.splitext`` instead. ``isdecimal`` is used because
    it only accepts characters that ``int()`` can parse.
    """
    stem, extension = os.path.splitext(filename)
    if extension == ".zip" and stem.isdecimal():
        return int(stem)
    return None


if __name__ == "__main__":
    # Scratch directory used for the debug log (see log()).
    os.makedirs("/tmp/autotech/", exist_ok=True)

    # Empty the scratch directory left over from a previous run.
    os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
    if B_DEBUG:
        # "w" truncates the log; close the handle right away instead of
        # leaking it for the lifetime of the process.
        with open("/tmp/autotech/logs", "w") as log_file:
            print("Webots started", file=log_file)

    def make_env(rank: int):
        """Build the gym environment for simulation number ``rank``."""
        log(f"CAREFUL !!! created an SERVER env with {rank=}")
        return WebotsSimulationGymEnvironment(rank)

    # `rank=rank` binds the current loop value; a plain closure would make
    # every factory see the last rank.
    envs = SubprocVecEnv([lambda rank=rank: make_env(rank) for rank in range(n_simulations)])

    ExtractorClass = TemporalResNetExtractor

    policy_kwargs = dict(
        features_extractor_class=ExtractorClass,
        features_extractor_kwargs=dict(
            context_size=context_size,
            lidar_horizontal_resolution=lidar_horizontal_resolution,
            camera_horizontal_resolution=camera_horizontal_resolution,
            device=device,
        ),
        activation_fn=nn.ReLU,
        net_arch=[512, 512, 512],
    )

    ppo_args = dict(
        n_steps=4096,
        n_epochs=10,
        batch_size=256,
        learning_rate=3e-4,
        gamma=0.99,
        verbose=1,
        normalize_advantage=True,
        device=device,
    )

    # Checkpoints are stored next to this script, one directory per extractor
    # class. The trailing "" keeps the trailing separator that the string
    # concatenations below rely on.
    save_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "checkpoints",
        ExtractorClass.__name__,
        "",
    )
    os.makedirs(save_path, exist_ok=True)

    existing_files = os.listdir(save_path)
    print(save_path)
    print(existing_files)

    # Find the checkpoint with the highest numeric index, if any.
    latest_index = -1
    model_name = None
    for candidate in existing_files:
        index = _checkpoint_index(candidate)
        if index is not None and index > latest_index:
            latest_index, model_name = index, candidate

    if model_name is not None:
        print(f"Loading model {save_path + model_name}")
        model = PPO.load(
            save_path + model_name,
            envs,
            **ppo_args,
            policy_kwargs=policy_kwargs
        )
        # Next checkpoint is written under the following index.
        i = latest_index + 1
        print(f"----- Model found, loading {model_name} -----")

    else:
        model = PPO(
            "MlpPolicy",
            envs,
            **ppo_args,
            policy_kwargs=policy_kwargs
        )

        i = 0
        print("----- Model not found, creating a new one -----")

    print("MODEL HAS HYPER PARAMETERS:")
    print(f"{model.learning_rate=}")
    print(f"{model.gamma=}")
    print(f"{model.verbose=}")
    print(f"{model.n_steps=}")
    print(f"{model.n_epochs=}")
    print(f"{model.batch_size=}")
    print(f"{model.device=}")

    log("SERVER : finished executing")

    # obs = envs.reset()
    # while True:
    #     action, _states = model.predict(obs, deterministic=True)  # Use deterministic=True for evaluation
    #     obs, reward, done, info = envs.step(action)
    #     envs.render()  # Optional: visualize the environment

    # Train forever: each round runs the onnx_utils helpers on the current
    # model, trains for 500k timesteps, then saves a new numbered checkpoint.
    while True:
        export_onnx(model)
        test_onnx(model)

        # The plotting callback is only imported (and used) in debug mode;
        # a fresh instance is created for every training round.
        callback = DynamicActionPlotDistributionCallback() if B_DEBUG else None
        model.learn(total_timesteps=500_000, callback=callback)

        model.save(save_path + str(i))

        i += 1
Lines changed: 1 addition & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
11
import os
2-
import time
32
from typing import *
4-
5-
import matplotlib.pyplot as plt
63
import numpy as np
7-
import torch
8-
import torch.nn as nn
9-
import torch.optim as optim
10-
import torch.multiprocessing as mp
11-
12-
from stable_baselines3 import PPO
13-
from stable_baselines3.common.env_checker import check_env
14-
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
15-
164
import gymnasium as gym
175

18-
from onnx_utils import export_onnx, test_onnx
196
from config import *
20-
from CNN1DExtractor import CNN1DExtractor
21-
from TemporalResNetExtractor import TemporalResNetExtractor
22-
from CNN1DResNetExtractor import CNN1DResNetExtractor
23-
24-
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
257

268

279
def log(s: str):
@@ -108,112 +90,4 @@ def step(self, action):
10890
# check if the context is correct
10991
# if self.simulation_rank == 0:
11092
# print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
111-
return obs, reward, done, truncated, info
112-
113-
114-
if __name__ == "__main__":
115-
if not os.path.exists("/tmp/autotech/"):
116-
os.mkdir("/tmp/autotech/")
117-
118-
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
119-
if B_DEBUG:
120-
print("Webots started", file=open("/tmp/autotech/logs", "w"))
121-
122-
def make_env(rank: int):
123-
log(f"CAREFUL !!! created an SERVER env with {rank=}")
124-
return WebotsSimulationGymEnvironment(rank)
125-
126-
envs = SubprocVecEnv([lambda rank=rank : make_env(rank) for rank in range(n_simulations)])
127-
128-
ExtractorClass = TemporalResNetExtractor
129-
130-
policy_kwargs = dict(
131-
features_extractor_class=ExtractorClass,
132-
features_extractor_kwargs=dict(
133-
context_size=context_size,
134-
lidar_horizontal_resolution=lidar_horizontal_resolution,
135-
camera_horizontal_resolution=camera_horizontal_resolution,
136-
device=device
137-
),
138-
activation_fn=nn.ReLU,
139-
net_arch=[512, 512, 512],
140-
)
141-
142-
143-
ppo_args = dict(
144-
n_steps=4096,
145-
n_epochs=10,
146-
batch_size=256,
147-
learning_rate=3e-4,
148-
gamma=0.99,
149-
verbose=1,
150-
normalize_advantage=True,
151-
device=device
152-
)
153-
154-
155-
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
156-
if not os.path.exists(save_path):
157-
os.mkdir(save_path)
158-
159-
print(save_path)
160-
print(os.listdir(save_path))
161-
162-
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
163-
164-
if valid_files:
165-
model_name = max(
166-
valid_files,
167-
key=lambda x : int(x.rstrip(".zip"))
168-
)
169-
print(f"Loading model {save_path + model_name}")
170-
model = PPO.load(
171-
save_path + model_name,
172-
envs,
173-
**ppo_args,
174-
policy_kwargs=policy_kwargs
175-
)
176-
i = int(model_name.rstrip(".zip")) + 1
177-
print(f"----- Model found, loading {model_name} -----")
178-
179-
else:
180-
model = PPO(
181-
"MlpPolicy",
182-
envs,
183-
**ppo_args,
184-
policy_kwargs=policy_kwargs
185-
)
186-
187-
i = 0
188-
print("----- Model not found, creating a new one -----")
189-
190-
print("MODEL HAS HYPER PARAMETERS:")
191-
print(f"{model.learning_rate=}")
192-
print(f"{model.gamma=}")
193-
print(f"{model.verbose=}")
194-
print(f"{model.n_steps=}")
195-
print(f"{model.n_epochs=}")
196-
print(f"{model.batch_size=}")
197-
print(f"{model.device=}")
198-
199-
log(f"SERVER : finished executing")
200-
201-
# obs = envs.reset()
202-
# while True:
203-
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
204-
# obs, reward, done, info = envs.step(action)
205-
# envs.render() # Optional: visualize the environment
206-
207-
208-
while True:
209-
export_onnx(model)
210-
test_onnx(model)
211-
212-
if B_DEBUG:
213-
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
214-
else:
215-
model.learn(total_timesteps=500_000)
216-
217-
model.save(save_path + str(i))
218-
219-
i += 1
93+
return obs, reward, done, truncated, info

src/Simulateur/__init__.py

Whitespace-only changes.

src/Simulateur/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch.cuda import is_available
33

44
n_map = 2
5-
n_simulations = 8
5+
n_simulations = 2
66
n_vehicles = 1
77
n_stupid_vehicles = 0
88
n_actions_steering = 16

src/Simulateur/controllers/controllerWorldSupervisor/controllerWorldSupervisor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import *
33
import numpy as np
44
import gymnasium as gym
5+
import time
56

67
from checkpointmanager import CheckpointManager, checkpoints
78

@@ -219,7 +220,17 @@ def main():
219220
#Prédiction pour séléctionner une action à partir de l"observation
220221
for e in envs:
221222
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : trying to read from fifo")
222-
action = np.frombuffer(e.fifo_r.read(np.dtype(np.int64).itemsize * 2), dtype=np.int64)
223+
224+
timeout = 10 # seconds
225+
start_time = time.time()
226+
227+
while time.time() - start_time < timeout:
228+
raw = e.fifo_r.read(np.dtype(np.int64).itemsize * 2)
229+
if len(raw) == np.dtype(np.int64).itemsize * 2:
230+
# We got the full action data
231+
action = np.frombuffer(raw, dtype=np.int64)
232+
break
233+
223234
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : received {action=}")
224235

225236
obs, reward, done, truncated, info = e.step(action)

src/Simulateur/worlds/.piste.wbproj

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/Simulateur/worlds/.piste0.wbproj

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/Simulateur/worlds/.piste1.wbproj

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)