Skip to content

Commit fa881fc

Browse files
Parameter Injection
1 parent fb4186e commit fa881fc

4 files changed

Lines changed: 72 additions & 16 deletions

File tree

pufferlib/config/ocean/drive.ini

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@ reward_bound_goal_radius_min = 2.0
7070
reward_bound_goal_radius_max = 12.0
7171

7272
reward_bound_collision_min = -3.0
73-
reward_bound_collision_max = -2.9
73+
reward_bound_collision_max = -0.1
7474

7575
reward_bound_offroad_min = -3.0
76-
reward_bound_offroad_max = -2.9
76+
reward_bound_offroad_max = -0.1
7777

7878
reward_bound_comfort_min = -0.1
7979
reward_bound_comfort_max = 0.0
8080

81-
reward_bound_lane_align_min = 0.0020
81+
reward_bound_lane_align_min = 0.00020
8282
reward_bound_lane_align_max = 0.0025
8383

8484
reward_bound_lane_center_min = -0.00075
85-
reward_bound_lane_center_max = -0.00065
85+
reward_bound_lane_center_max = -0.000065
8686

8787
reward_bound_velocity_min = 0.0
8888
reward_bound_velocity_max = 0.005
@@ -116,7 +116,7 @@ reward_bound_acc_max = 1.5
116116

117117
[train]
118118
seed=42
119-
total_timesteps = 2_000_000_000
119+
total_timesteps = 10_000_000_000
120120
; learning_rate = 0.02
121121
; gamma = 0.985
122122
anneal_lr = True
@@ -141,11 +141,11 @@ vf_clip_coef = 0.1999999999999999
141141
vf_coef = 2
142142
vtrace_c_clip = 1
143143
vtrace_rho_clip = 1
144-
checkpoint_interval = 250
144+
checkpoint_interval = 1000
145145
; Rendering options
146146
render = True
147147
render_async = False # A render interval below 50 may cause process starvation and slow down training
148-
render_interval = 250
148+
render_interval = 1000
149149
; If True, show exactly what the agent sees in agent observation
150150
obs_only = True
151151
; Show grid lines

pufferlib/ocean/drive/drive.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,9 @@ def __init__(
372372

373373
self.c_envs = binding.vectorize(*env_ids)
374374

375-
def reset(self, seed=0):
376-
binding.vec_reset(self.c_envs, seed)
375+
def reset(self, seed=0, parameters=None):
    """Reset the environments held in ``self.c_envs``.

    Args:
        seed: Seed value forwarded to ``binding.vec_reset``.
        parameters: Optional dict forwarded to ``binding.vec_reset``.
            ``None`` (or any other falsy value) is sent as an empty dict.

    Returns:
        A ``(observations, infos)`` pair; ``infos`` is always an empty list.
    """
    # The binding always receives a dict from this method, never None.
    overrides = parameters if parameters else {}
    binding.vec_reset(self.c_envs, seed, overrides)

    # Restart the step counter and clear the truncation flags in place.
    self.tick = 0
    self.truncations[:] = 0

    infos = []
    return self.observations, infos
@@ -518,7 +519,7 @@ def step(self, actions):
518519
env_ids.append(env_id)
519520
self.c_envs = binding.vectorize(*env_ids)
520521

521-
binding.vec_reset(self.c_envs, seed)
522+
binding.vec_reset(self.c_envs, seed, None)
522523
self.terminals[:] = 1
523524
return (self.observations, self.rewards, self.terminals, self.truncations, info)
524525

pufferlib/ocean/env_binding.h

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,53 @@ static PyObject *vectorize(PyObject *self, PyObject *args) {
473473
return PyLong_FromVoidPtr(vec);
474474
}
475475

476+
// Looks up `key` in `params_dict` and converts the value to a float.
//
// Returns 1 and writes *out when the key is present and holds a Python float
// or int (bool excluded). Returns 0 and leaves *out untouched when the key is
// missing, the value has another type, or the conversion fails.
static int reward_bound_from_dict(PyObject *params_dict, const char *key, float *out) {
    PyObject *item = PyDict_GetItemString(params_dict, key); // borrowed reference
    if (item == NULL) {
        return 0;
    }

    // Accept ints as well as floats so that values such as -3 or 0 are not
    // silently dropped. bool is a subclass of int, so reject it explicitly.
    int is_number = PyFloat_Check(item) || (PyLong_Check(item) && !PyBool_Check(item));
    if (!is_number) {
        return 0;
    }

    double value = PyFloat_AsDouble(item);
    if (value == -1.0 && PyErr_Occurred()) {
        // e.g. an int too large for a double. This function cannot report the
        // error to its caller, so clear it rather than leave an exception pending.
        PyErr_Clear();
        return 0;
    }

    *out = (float)value;
    return 1;
}

// Overrides reward bounds on every environment in `vec` from a Python dict.
//
// `params_dict` maps "reward_bound_<name>_min" / "reward_bound_<name>_max"
// keys to numbers. Each key is optional: a bound whose key is missing (or
// whose value is not a float/int) keeps its current value. A NULL or non-dict
// `params_dict` makes the call a no-op.
//
// Every entry of vec->envs is cast to Drive *.
// NOTE(review): this definition lives in a header without `static`, unlike the
// neighbouring vectorize/vec_reset; linkage is left unchanged here - confirm
// the header is only included from a single translation unit.
void apply_parameters(VecEnv *vec, PyObject *params_dict) {
    if (params_dict == NULL || !PyDict_Check(params_dict)) {
        return;
    }

    // One row per reward coefficient: dict keys for its bounds and the index
    // of the matching slot in Drive::reward_bounds.
    const struct {
        const char *min_key;
        const char *max_key;
        int coef_index;
    } bounds[] = {
        {"reward_bound_goal_radius_min", "reward_bound_goal_radius_max", REWARD_COEF_GOAL_RADIUS},
        {"reward_bound_collision_min", "reward_bound_collision_max", REWARD_COEF_COLLISION},
        {"reward_bound_offroad_min", "reward_bound_offroad_max", REWARD_COEF_OFFROAD},
        {"reward_bound_comfort_min", "reward_bound_comfort_max", REWARD_COEF_COMFORT},
        {"reward_bound_lane_align_min", "reward_bound_lane_align_max", REWARD_COEF_LANE_ALIGN},
        {"reward_bound_lane_center_min", "reward_bound_lane_center_max", REWARD_COEF_LANE_CENTER},
        {"reward_bound_velocity_min", "reward_bound_velocity_max", REWARD_COEF_VELOCITY},
        {"reward_bound_traffic_light_min", "reward_bound_traffic_light_max", REWARD_COEF_TRAFFIC_LIGHT},
        {"reward_bound_center_bias_min", "reward_bound_center_bias_max", REWARD_COEF_CENTER_BIAS},
        {"reward_bound_vel_align_min", "reward_bound_vel_align_max", REWARD_COEF_VEL_ALIGN},
        {"reward_bound_overspeed_min", "reward_bound_overspeed_max", REWARD_COEF_OVERSPEED},
        {"reward_bound_timestep_min", "reward_bound_timestep_max", REWARD_COEF_TIMESTEP},
        {"reward_bound_reverse_min", "reward_bound_reverse_max", REWARD_COEF_REVERSE},
        {"reward_bound_throttle_min", "reward_bound_throttle_max", REWARD_COEF_THROTTLE},
        {"reward_bound_steer_min", "reward_bound_steer_max", REWARD_COEF_STEER},
        {"reward_bound_acc_min", "reward_bound_acc_max", REWARD_COEF_ACC},
    };
    const int num_bounds = (int)(sizeof(bounds) / sizeof(bounds[0]));

    for (int b = 0; b < num_bounds; b++) {
        float min_val = 0.0f;
        float max_val = 0.0f;

        // Convert each Python value once, not once per environment.
        int has_min = reward_bound_from_dict(params_dict, bounds[b].min_key, &min_val);
        int has_max = reward_bound_from_dict(params_dict, bounds[b].max_key, &max_val);
        if (!has_min && !has_max) {
            continue;
        }

        for (int i = 0; i < vec->num_envs; i++) {
            Drive *drive = (Drive *)vec->envs[i];
            if (has_min) {
                drive->reward_bounds[bounds[b].coef_index].min_val = min_val;
            }
            if (has_max) {
                drive->reward_bounds[bounds[b].coef_index].max_val = max_val;
            }
        }
    }
}
519+
476520
static PyObject *vec_reset(PyObject *self, PyObject *args) {
477-
if (PyTuple_Size(args) != 2) {
478-
PyErr_SetString(PyExc_TypeError, "vec_reset requires 2 arguments");
521+
if (PyTuple_Size(args) != 3) {
522+
PyErr_SetString(PyExc_TypeError, "vec_reset requires 3 arguments");
479523
return NULL;
480524
}
481525

@@ -484,6 +528,17 @@ static PyObject *vec_reset(PyObject *self, PyObject *args) {
484528
return NULL;
485529
}
486530

531+
PyObject *params = PyTuple_GetItem(args, 2);
532+
533+
if (params == Py_None) {
534+
// skip parameter logic
535+
} else if (!PyDict_Check(params)) {
536+
PyErr_SetString(PyExc_TypeError, "parameters must be dict or None");
537+
return NULL;
538+
} else {
539+
apply_parameters(vec, params);
540+
}
541+
487542
PyObject *seed_arg = PyTuple_GetItem(args, 1);
488543
if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) {
489544
PyErr_SetString(PyExc_TypeError, "seed must be an integer");

pufferlib/vector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def _worker_process(
231231

232232
start = time.time()
233233
if sem == RESET:
234-
seed = recv_pipe.recv()
235-
_, infos = envs.reset(seed=seed)
234+
seed, parameters = recv_pipe.recv()
235+
_, infos = envs.reset(seed=seed, parameters=parameters)
236236
elif sem == STEP:
237237
_, _, _, _, infos = envs.step(atn_arr)
238238
elif sem == CLOSE:
@@ -503,7 +503,7 @@ def send(self, actions):
503503
self.actions[idxs] = actions
504504
self.buf["semaphores"][idxs] = STEP
505505

506-
def async_reset(self, seed=0):
506+
def async_reset(self, seed=0, parameters=None):
507507
# Flush any waiting workers
508508
while self.waiting_workers:
509509
worker = self.waiting_workers.pop(0)
@@ -528,7 +528,7 @@ def async_reset(self, seed=0):
528528
for i in range(self.num_workers):
529529
start = i * self.envs_per_worker
530530
end = (i + 1) * self.envs_per_worker
531-
self.send_pipes[i].send(seed + i)
531+
self.send_pipes[i].send((seed + i, parameters))
532532

533533
def notify(self):
534534
self.buf["notify"][:] = True

0 commit comments

Comments
 (0)