diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index 0fd36d6e7..f4bc8cad5 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -23,7 +23,7 @@ void test_drivenet() { Weights *weights = load_weights("puffer_drive_weights.bin"); DriveNet *net = init_drivenet(weights, num_agents, CLASSIC, 0); - forward(net, observations, actions); + forward(net, NULL, observations, actions); for (int i = 0; i < num_agents * num_actions; i++) { printf("idx: %d, action: %d, logits:", i, actions[i]); for (int j = 0; j < num_actions; j++) { @@ -126,7 +126,7 @@ void demo() { int *actions = (int *)env.actions; // Single integer per agent if (!IsKeyDown(KEY_LEFT_SHIFT)) { - forward(net, env.observations, actions); + forward(net, &env, env.observations, actions); } else { if (env.dynamics_model == CLASSIC) { // Classic dynamics: acceleration and steering diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index fc4aa2bfc..2a2c29e0e 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -350,6 +350,11 @@ struct Drive { float observation_window_size; float polyline_reduction_threshold; float polyline_max_segment_length; + // Predicted trajectory visualization (set from Python, drawn by c_render) + float *predicted_traj_x; // [num_agents * predicted_traj_len] + float *predicted_traj_y; // [num_agents * predicted_traj_len] + int predicted_traj_len; // steps per agent (0 = no trajectory to draw) + int sdc_track_index; int num_tracks_to_predict; int *tracks_to_predict_indices; @@ -374,6 +379,84 @@ void sample_new_goal(Drive *env, int agent_idx); int check_lane_aligned(Agent *car, RoadMapElement *lane, int geometry_idx); void reset_goal_positions(Drive *env); +// ======================================== +// Pure dynamics step functions +// ======================================== + +typedef struct { + float x, y, heading, vx, vy; + float a_long, a_lat, steering_angle; +} DynState; + +// Classic bicycle model: action is (acceleration, steering_angle) +static inline DynState classic_dynamics_step(DynState s, float acceleration, float steering, float length, float dt) { + DynState out = s; + float speed_mag = sqrtf(s.vx * s.vx + s.vy * s.vy); + float v_dot = s.vx * cosf(s.heading) + s.vy * sinf(s.heading); + float signed_speed = copysignf(speed_mag, v_dot); + signed_speed += acceleration * dt; + signed_speed = clipSpeed(signed_speed); + float beta = tanhf(0.5f * tanf(steering)); + float new_vx = signed_speed * cosf(s.heading + beta); + float new_vy = signed_speed * sinf(s.heading + beta); + float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / length; + out.vx = new_vx; + out.vy = new_vy; + out.x = s.x + new_vx * dt; + out.y = s.y + new_vy * dt; + out.heading = s.heading + yaw_rate * dt; + return out; +} + +// Jerk model: action is (jerk_long, jerk_lat). State includes a_long, a_lat, steering_angle. +static inline DynState jerk_dynamics_step(DynState s, float jerk_long, float jerk_lat, float wheelbase, float dt) { + DynState out = s; + float al_new = s.a_long + jerk_long * dt; + float at_new = s.a_lat + jerk_lat * dt; + + // Zero-crossing clamp + al_new = (s.a_long * al_new < 0) ? 0.0f : fminf(fmaxf(al_new, -5.0f), 2.5f); + at_new = (s.a_lat * at_new < 0) ? 0.0f : fminf(fmaxf(at_new, -4.0f), 4.0f); + + // Velocity + float v_dot = s.vx * cosf(s.heading) + s.vy * sinf(s.heading); + float signed_v = copysignf(sqrtf(s.vx * s.vx + s.vy * s.vy), v_dot); + float v_new = signed_v + 0.5f * (al_new + s.a_long) * dt; + v_new = (signed_v * v_new < 0) ? 0.0f : fminf(fmaxf(v_new, -2.0f), SPEED_LIMIT); + + // Steering from lateral acceleration + float curvature = at_new / fmaxf(v_new * v_new, 1e-5f); + float steer_target = atanf(curvature * wheelbase); + float delta_steer = fminf(fmaxf(steer_target - s.steering_angle, -0.6f * dt), 0.6f * dt); + float steer_new = fminf(fmaxf(s.steering_angle + delta_steer, -0.55f), 0.55f); + curvature = tanf(steer_new) / wheelbase; + at_new = v_new * v_new * curvature; + + // Bicycle displacement + float d = 0.5f * (v_new + signed_v) * dt; + float theta = d * curvature; + float dx_local, dy_local; + if (fabsf(curvature) < 1e-5f || fabsf(theta) < 1e-5f) { + dx_local = d; + dy_local = 0.0f; + } else { + dx_local = sinf(theta) / curvature; + dy_local = (1.0f - cosf(theta)) / curvature; + } + float cos_h = cosf(s.heading); + float sin_h = sinf(s.heading); + + out.x = s.x + dx_local * cos_h - dy_local * sin_h; + out.y = s.y + dx_local * sin_h + dy_local * cos_h; + out.heading = s.heading + theta; + out.vx = v_new * cosf(out.heading); + out.vy = v_new * sinf(out.heading); + out.a_long = al_new; + out.a_lat = at_new; + out.steering_angle = steer_new; + return out; +} + // ======================================== // Utility Functions // ======================================== @@ -2273,6 +2356,8 @@ void c_close(Drive *env) { free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env->tracks_to_predict_indices); + free(env->predicted_traj_x); + free(env->predicted_traj_y); free(env->ini_file); } @@ -2918,42 +3003,18 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) { steering = STEERING_VALUES[steering_index]; } - // Current state - float heading = agent->sim_heading; - float vx = agent->sim_vx; - float vy = agent->sim_vy; - - // Calculate current speed (signed based on direction relative to heading) - float speed_magnitude = sqrtf(vx * vx + vy * vy); - float v_dot_heading = vx * cosf(heading) + vy * sinf(heading); - float signed_speed = copysignf(speed_magnitude, v_dot_heading); - - // Update speed with acceleration - signed_speed = signed_speed + acceleration * env->dt; - signed_speed = clipSpeed(signed_speed); - // Compute yaw rate - float beta = tanh(.5 * tanf(steering)); - - // New heading - float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->sim_length; - - // New velocity - float new_vx = signed_speed * cosf(heading + beta); - float new_vy = signed_speed * sinf(heading + beta); - - // Update position - float dx = new_vx * env->dt; - float dy = new_vy * env->dt; - float dheading = yaw_rate * env->dt; - - // Apply updates to the agent's state - agent->sim_x += dx; - agent->sim_y += dy; - agent->sim_heading += dheading; - agent->heading_x = cosf(agent->sim_heading); - agent->heading_y = sinf(agent->sim_heading); - agent->sim_vx = new_vx; - agent->sim_vy = new_vy; + DynState s = {agent->sim_x, agent->sim_y, agent->sim_heading, + agent->sim_vx, agent->sim_vy, 0, 0, 0}; + DynState ns = classic_dynamics_step(s, acceleration, steering, agent->sim_length, env->dt); + float dx = ns.x - s.x; + float dy = ns.y - s.y; + agent->sim_x = ns.x; + agent->sim_y = ns.y; + agent->sim_heading = ns.heading; + agent->heading_x = cosf(ns.heading); + agent->heading_y = sinf(ns.heading); + agent->sim_vx = ns.vx; + agent->sim_vy = ns.vy; agent->cumulative_displacement += sqrtf(dx * dx + dy * dy); } else { // JERK dynamics model @@ -2984,76 +3045,24 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) { a_lat = JERK_LAT[a_lat_idx]; } - // Calculate new acceleration - float a_long_new = agent->a_long + a_long * env->dt; - float a_lat_new = agent->a_lat + a_lat * env->dt; - - // Make it easy to stop with 0 accel - if (agent->a_long * a_long_new < 0) { - a_long_new = 0.0f; - } else { - a_long_new = clip(a_long_new, -5.0f, 2.5f); - } - - if (agent->a_lat * a_lat_new < 0) { - a_lat_new = 0.0f; - } else { - a_lat_new = clip(a_lat_new, -4.0f, 4.0f); - } - - // Calculate new velocity - float v_dot_heading = agent->sim_vx * cosf(agent->sim_heading) + agent->sim_vy * sinf(agent->sim_heading); - float signed_v = copysignf(sqrtf(agent->sim_vx * agent->sim_vx + agent->sim_vy * agent->sim_vy), v_dot_heading); - float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt; - - // Make it easy to stop with 0 vel - if (signed_v * v_new < 0) { - v_new = 0.0f; - } else { - v_new = clip(v_new, -2.0f, 20.0f); - } - - // Calculate new steering angle - float signed_curvature = a_lat_new / fmaxf(v_new * v_new, 1e-5f); - float steering_angle = atanf(signed_curvature * agent->wheelbase); - float delta_steer = clip(steering_angle - agent->steering_angle, -0.6f * env->dt, 0.6f * env->dt); - float new_steering_angle = clip(agent->steering_angle + delta_steer, -0.55f, 0.55f); - - // Update curvature and accel to account for limited steering - signed_curvature = tanf(new_steering_angle) / agent->wheelbase; - a_lat_new = v_new * v_new * signed_curvature; - - // Calculate resulting movement using bicycle dynamics - float d = 0.5f * (v_new + signed_v) * env->dt; - float theta = d * signed_curvature; - float dx_local, dy_local; - - if (fabsf(signed_curvature) < 1e-5f || fabsf(theta) < 1e-5f) { - dx_local = d; - dy_local = 0.0f; - } else { - dx_local = sinf(theta) / signed_curvature; - dy_local = (1.0f - cosf(theta)) / signed_curvature; - } - - float cos_heading = cosf(agent->sim_heading); - float sin_heading = sinf(agent->sim_heading); - float dx = dx_local * cos_heading - dy_local * sin_heading; - float dy = dx_local * sin_heading + dy_local * cos_heading; - - // Update everything - agent->sim_x += dx; - agent->sim_y += dy; - agent->jerk_long = (a_long_new - agent->a_long) / env->dt; - agent->jerk_lat = (a_lat_new - agent->a_lat) / env->dt; - agent->a_long = a_long_new; - agent->a_lat = a_lat_new; - agent->sim_heading = normalize_heading(agent->sim_heading + theta); + DynState s = {agent->sim_x, agent->sim_y, agent->sim_heading, + agent->sim_vx, agent->sim_vy, + agent->a_long, agent->a_lat, agent->steering_angle}; + DynState ns = jerk_dynamics_step(s, a_long, a_lat, agent->wheelbase, env->dt); + float dx = ns.x - s.x; + float dy = ns.y - s.y; + agent->jerk_long = (ns.a_long - agent->a_long) / env->dt; + agent->jerk_lat = (ns.a_lat - agent->a_lat) / env->dt; + agent->sim_x = ns.x; + agent->sim_y = ns.y; + agent->sim_heading = normalize_heading(ns.heading); agent->heading_x = cosf(agent->sim_heading); agent->heading_y = sinf(agent->sim_heading); - agent->sim_vx = v_new * agent->heading_x; - agent->sim_vy = v_new * agent->heading_y; - agent->steering_angle = new_steering_angle; + agent->sim_vx = ns.vx; + agent->sim_vy = ns.vy; + agent->a_long = ns.a_long; + agent->a_lat = ns.a_lat; + agent->steering_angle = ns.steering_angle; agent->cumulative_displacement += sqrtf(dx * dx + dy * dy); agent->cumulative_displacement_since_last_goal += sqrtf(dx * dx + dy * dy); } @@ -4110,6 +4119,32 @@ void c_render(Drive *env) { handle_camera_controls(env->client); draw_scene(env, client, 0, 0, 0, 0); + // Re-enter 3D mode for trajectory drawing (draw_scene calls EndMode3D internally) + BeginMode3D(client->camera); + rlDisableDepthTest(); + if (env->predicted_traj_x != NULL && env->predicted_traj_len > 0) { + int i = env->human_agent_idx; + if (i < env->active_agent_count) { + int agent_idx = env->active_agent_indices[i]; + Agent *agent = &env->agents[agent_idx]; + int tlen = env->predicted_traj_len; + float z = agent->sim_z + 0.1f; + + Vector3 prev = {agent->sim_x, agent->sim_y, z}; + for (int t = 0; t < tlen; t++) { + float tx = env->predicted_traj_x[i * tlen + t]; + float ty = env->predicted_traj_y[i * tlen + t]; + Vector3 curr = {tx, ty, z}; + rlSetLineWidth(3.0f); + DrawLine3D(prev, curr, RED); + DrawCircle3D(curr, 1.5f, (Vector3){0, 0, 1}, 0.0f, RED); + prev = curr; + } + } + } + rlEnableDepthTest(); + EndMode3D(); + if (IsKeyPressed(KEY_TAB) && env->active_agent_count > 0) { env->human_agent_idx = (env->human_agent_idx + 1) % env->active_agent_count; } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index ec13480da..d69b4ed3c 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -721,6 +721,25 @@ def get_road_edge_polylines(self): return polylines + def set_predicted_trajectories(self, action_trajectories): + """Roll out action trajectories through dynamics and set for rendering. + + Args: + action_trajectories: numpy array of shape [num_agents, traj_len] + containing discrete action indices. + """ + _, traj_len = action_trajectories.shape + for i in range(self.num_envs): + start = self.agent_offsets[i] + end = self.agent_offsets[i + 1] + sub_actions = action_trajectories[start:end] + binding.vec_set_trajectory( + self.c_envs, + i, + sub_actions.flatten().astype(np.int32), + traj_len, + ) + def render(self): binding.vec_render(self.c_envs, 0) diff --git a/pufferlib/ocean/drive/drivenet.h b/pufferlib/ocean/drive/drivenet.h index 50b6a987a..81f5aa149 100644 --- a/pufferlib/ocean/drive/drivenet.h +++ b/pufferlib/ocean/drive/drivenet.h @@ -44,6 +44,11 @@ struct DriveNet { Linear *actor; Linear *value_fn; Multidiscrete *multidiscrete; + // Predicted trajectory (rolled out from actor output) + int action_size; + int traj_steps; + float *predicted_traj_x; + float *predicted_traj_y; }; DriveNet *init_drivenet(Weights *weights, int num_agents, int dynamics_model, int reward_conditioning) { @@ -108,6 +113,10 @@ DriveNet *init_drivenet(Weights *weights, int num_agents, int dynamics_model, in memset(net->lstm->state_h, 0, num_agents * NN_HIDDEN_SIZE * sizeof(float)); memset(net->lstm->state_c, 0, num_agents * NN_HIDDEN_SIZE * sizeof(float)); net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, action_dim); + net->action_size = action_size; + net->traj_steps = 80; + net->predicted_traj_x = calloc(num_agents * net->traj_steps, sizeof(float)); + net->predicted_traj_y = calloc(num_agents * net->traj_steps, sizeof(float)); return net; } @@ -138,13 +147,15 @@ void free_drivenet(DriveNet *net) { free(net->shared_embedding); free(net->relu); free(net->multidiscrete); + free(net->predicted_traj_x); + free(net->predicted_traj_y); free(net->actor); free(net->value_fn); free(net->lstm); free(net); } -void forward(DriveNet *net, float *observations, int *actions) { +void forward(DriveNet *net, Drive *env, float *observations, int *actions) { int ego_dim = net->ego_dim; int max_partners = MAX_AGENTS - 1; int max_road_obs = MAX_ROAD_SEGMENT_OBSERVATIONS; @@ -271,4 +282,46 @@ void forward(DriveNet *net, float *observations, int *actions) { // Get action by taking argmax of actor output softmax_multidiscrete(net->multidiscrete, net->actor->output, actions); + + // Roll out predicted trajectories using the reusable dynamics step functions + if (env != NULL) { + int traj_steps = net->traj_steps; + int num_steer = sizeof(STEERING_VALUES) / sizeof(STEERING_VALUES[0]); + int num_lat = sizeof(JERK_LAT) / sizeof(JERK_LAT[0]); + + for (int b = 0; b < net->num_agents; b++) { + int agent_idx = env->active_agent_indices[b]; + Agent *agent = &env->agents[agent_idx]; + int action_val = actions[b]; + + DynState s = {agent->sim_x, agent->sim_y, agent->sim_heading, + agent->sim_vx, agent->sim_vy, + agent->a_long, agent->a_lat, agent->steering_angle}; + + for (int t = 0; t < traj_steps; t++) { + // Apply action on first step, zero-jerk/zero-accel after + if (t == 0) { + if (env->dynamics_model == CLASSIC) { + int accel_idx = action_val / num_steer; + int steer_idx = action_val % num_steer; + s = classic_dynamics_step(s, ACCELERATION_VALUES[accel_idx], + STEERING_VALUES[steer_idx], agent->sim_length, env->dt); + } else { + int al_idx = action_val / num_lat; + int at_idx = action_val % num_lat; + s = jerk_dynamics_step(s, JERK_LONG[al_idx], JERK_LAT[at_idx], + agent->wheelbase, env->dt); + } + } else { + if (env->dynamics_model == CLASSIC) { + s = classic_dynamics_step(s, 0.0f, 0.0f, agent->sim_length, env->dt); + } else { + s = jerk_dynamics_step(s, 0.0f, 0.0f, agent->wheelbase, env->dt); + } + } + net->predicted_traj_x[b * traj_steps + t] = s.x; + net->predicted_traj_y[b * traj_steps + t] = s.y; + } + } + } } diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index 6e2bd719c..cb8d869b7 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -67,7 +67,7 @@ void CloseVideo(VideoRecorder *recorder) { void renderTopDownView(Drive *env, Client *client, int map_height, int obs, int lasers, int trajectories, int frame_count, float *path, int show_human_logs, int show_grid, int img_width, int img_height, - int zoom_in) { + int zoom_in, DriveNet *net) { BeginDrawing(); // Top-down orthographic camera @@ -133,6 +133,31 @@ void renderTopDownView(Drive *env, Client *client, int map_height, int obs, int // Draw scene draw_scene(env, client, 1, obs, lasers, show_grid); + + // Re-enter 3D mode for trajectory drawing (draw_scene calls EndMode3D internally) + BeginMode3D(camera); + rlDisableDepthTest(); + + // Draw predicted trajectories from policy network (ego agent only) + if (net != NULL && net->predicted_traj_x != NULL) { + int ego = env->human_agent_idx; + if (ego < env->active_agent_count && ego < net->num_agents) { + int agent_idx = env->active_agent_indices[ego]; + Agent *agent = &env->agents[agent_idx]; + int tlen = net->traj_steps; + Vector3 prev = {agent->sim_x, agent->sim_y, agent->sim_z + 0.6f}; + for (int t = 0; t < tlen; t++) { + float tx = net->predicted_traj_x[ego * tlen + t]; + float ty = net->predicted_traj_y[ego * tlen + t]; + Vector3 curr = {tx, ty, agent->sim_z + 0.6f}; + DrawLine3D(prev, curr, Fade(SKYBLUE, 0.8f)); + DrawSphere(curr, 0.3f, Fade(SKYBLUE, 0.6f)); + prev = curr; + } + } + } + + rlEnableDepthTest(); EndMode3D(); EndDrawing(); } @@ -193,11 +218,11 @@ static int make_gif_from_frames(const char *pattern, int fps, const char *palett int eval_gif(const char *map_name, const char *policy_name, int show_grid, int obs_only, int lasers, int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown, - const char *output_agent, int num_maps, int zoom_in) { + const char *output_agent, int num_maps, int zoom_in, const char *config_path) { // Parse configuration from INI file env_init_config conf = {0}; - const char *ini_file = "pufferlib/config/ocean/drive.ini"; + const char *ini_file = config_path ? config_path : "pufferlib/config/ocean/drive.ini"; if (ini_parse(ini_file, handler, &conf) < 0) { fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); return -1; @@ -344,6 +369,7 @@ int eval_gif(const char *map_name, const char *policy_name, int show_grid, int o DriveNet *net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, env.reward_conditioning); int frame_count = env.episode_length > 0 ? env.episode_length : TRAJECTORY_LENGTH_DEFAULT; + printf("episode_length=%d, frame_count=%d\n", env.episode_length, frame_count); char filename_topdown[256]; char filename_agent[256]; @@ -400,11 +426,11 @@ int eval_gif(const char *map_name, const char *policy_name, int show_grid, int o for (int i = 0; i < frame_count; i++) { if (i % frame_skip == 0) { renderTopDownView(&env, client, map_height, 0, 0, 0, frame_count, NULL, show_human_logs, show_grid, - img_width, img_height, zoom_in); + img_width, img_height, zoom_in, net); WriteFrame(&topdown_recorder, img_width, img_height); rendered_frames++; } - forward(net, env.observations, (int *)env.actions); + forward(net, &env, env.observations, (int *)env.actions); c_step(&env); } } @@ -422,7 +448,7 @@ int eval_gif(const char *map_name, const char *policy_name, int show_grid, int o WriteFrame(&agent_recorder, img_width, img_height); rendered_frames++; } - forward(net, env.observations, (int *)env.actions); + forward(net, &env, env.observations, (int *)env.actions); c_step(&env); } } @@ -552,6 +578,6 @@ int main(int argc, char *argv[]) { } eval_gif(map_name, policy_name, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, output_topdown, - output_agent, num_maps, zoom_in); + output_agent, num_maps, zoom_in, ini_file); return 0; } diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 1cb1c6f95..a644746a5 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -529,6 +529,65 @@ static PyObject *vec_step(PyObject *self, PyObject *arg) { Py_RETURN_NONE; } +// Set predicted trajectory positions for rendering (rolled out from Python actions). +static PyObject *vec_set_trajectory(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "vec_set_trajectory requires 4 args: vec_env, env_id, action_traj, traj_len"); + return NULL; + } + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); + int env_id = PyLong_AsLong(PyTuple_GetItem(args, 1)); + PyArrayObject *action_arr = (PyArrayObject *)PyTuple_GetItem(args, 2); + int traj_len = PyLong_AsLong(PyTuple_GetItem(args, 3)); + + Env *env = vec->envs[env_id]; + int num_agents = env->active_agent_count; + int *actions = (int *)PyArray_DATA(action_arr); + int num_steer = sizeof(STEERING_VALUES) / sizeof(STEERING_VALUES[0]); + + free(env->predicted_traj_x); + free(env->predicted_traj_y); + int total = num_agents * traj_len; + env->predicted_traj_x = (float *)calloc(total, sizeof(float)); + env->predicted_traj_y = (float *)calloc(total, sizeof(float)); + env->predicted_traj_len = traj_len; + + int num_lat = sizeof(JERK_LAT) / sizeof(JERK_LAT[0]); + + for (int i = 0; i < num_agents; i++) { + int agent_idx = env->active_agent_indices[i]; + Agent *agent = &env->agents[agent_idx]; + + DynState s = { + .x = agent->sim_x, .y = agent->sim_y, .heading = agent->sim_heading, + .vx = agent->sim_vx, .vy = agent->sim_vy, + .a_long = agent->a_long, .a_lat = agent->a_lat, + .steering_angle = agent->steering_angle, + }; + + for (int t = 0; t < traj_len; t++) { + int action_val = actions[i * traj_len + t]; + + if (env->dynamics_model == CLASSIC) { + int accel_idx = action_val / num_steer; + int steer_idx = action_val % num_steer; + s = classic_dynamics_step(s, ACCELERATION_VALUES[accel_idx], + STEERING_VALUES[steer_idx], agent->sim_length, env->dt); + } else { + int a_long_idx = action_val / num_lat; + int a_lat_idx = action_val % num_lat; + s = jerk_dynamics_step(s, JERK_LONG[a_long_idx], + JERK_LAT[a_lat_idx], agent->wheelbase, env->dt); + } + + env->predicted_traj_x[i * traj_len + t] = s.x; + env->predicted_traj_y[i * traj_len + t] = s.y; + } + } + + Py_RETURN_NONE; +} + static PyObject *vec_render(PyObject *self, PyObject *args) { int num_args = PyTuple_Size(args); if (num_args != 2) { @@ -989,6 +1048,7 @@ static PyMethodDef methods[] = { {"vec_step", vec_step, METH_VARARGS, "Step the vector of environments"}, {"vec_log", vec_log, METH_VARARGS, "Log the vector of environments"}, {"vec_render", vec_render, METH_VARARGS, "Render the vector of environments"}, + {"vec_set_trajectory", vec_set_trajectory, METH_VARARGS, "Set predicted trajectory for rendering"}, {"vec_close", vec_close, METH_VARARGS, "Close the vector of environments"}, {"shared", (PyCFunction)my_shared, METH_VARARGS | METH_KEYWORDS, "Shared state"}, {"get_global_agent_state", get_global_agent_state, METH_VARARGS, "Get global agent state"}, diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index aecda002a..e43c2f69a 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -23,23 +23,164 @@ from scripts.export_model_bin import load_config +JERK_LONG = torch.tensor([-15.0, -4.0, 0.0, 4.0]) +JERK_LAT = torch.tensor([-4.0, 0.0, 4.0]) +NUM_LAT = 3 +SPEED_LIMIT = 20.0 + + +class JerkDynamicsRollout(torch.nn.Module): + """Roll out a single discrete action through jerk dynamics for N steps. + + Takes the argmax action from logits, decodes into jerk_long/jerk_lat, + and integrates through the bicycle jerk model in ego frame (starting + at x=0, y=0, heading=0). + + Requires initial state from observation: speed, a_long, a_lat, steering_angle. + """ + + def __init__(self, num_steps=80, dt=0.1): + super().__init__() + self.num_steps = num_steps + self.dt = dt + self.register_buffer("jerk_long", JERK_LONG) + self.register_buffer("jerk_lat", JERK_LAT) + + def forward(self, logits, speed, a_long, a_lat, steering_angle, wheelbase): + """ + Args: + logits: [B, num_actions] — policy output + speed: [B] — signed speed from obs + a_long: [B] — longitudinal acceleration from obs + a_lat: [B] — lateral acceleration from obs + steering_angle: [B] — current steering angle from obs + wheelbase: [B] — vehicle wheelbase from obs + + Returns: + trajectory: [B, num_steps, 3] — (x, y, heading) in ego frame + """ + B = logits.shape[0] + action = logits.argmax(dim=-1) # [B] + a_long_idx = action // NUM_LAT + a_lat_idx = action % NUM_LAT + + jerk_long_val = self.jerk_long[a_long_idx] # [B] + jerk_lat_val = self.jerk_lat[a_lat_idx] # [B] + + # State in ego frame + x = torch.zeros(B, device=logits.device) + y = torch.zeros(B, device=logits.device) + heading = torch.zeros(B, device=logits.device) + v = speed.clone() + al = a_long.clone() + at = a_lat.clone() + steer = steering_angle.clone() + + dt = self.dt + trajectory = torch.zeros(B, self.num_steps, 3, device=logits.device) + + # Build jerk schedule: apply action on step 0, zero-jerk after + jl_schedule = torch.zeros(self.num_steps, B, device=logits.device) + jt_schedule = torch.zeros(self.num_steps, B, device=logits.device) + jl_schedule[0] = jerk_long_val + jt_schedule[0] = jerk_lat_val + + for t in range(self.num_steps): + # Update acceleration + al_new = al + jl_schedule[t] * dt + at_new = at + jt_schedule[t] * dt + + # Clamp acceleration + al_new = torch.clamp(al_new, -5.0, 2.5) + at_new = torch.clamp(at_new, -4.0, 4.0) + + # Zero-crossing: easy stop + al_new = torch.where(al * al_new < 0, torch.zeros_like(al_new), al_new) + at_new = torch.where(at * at_new < 0, torch.zeros_like(at_new), at_new) + + # Update velocity + v_new = v + 0.5 * (al_new + al) * dt + v_new = torch.where(v * v_new < 0, torch.zeros_like(v_new), v_new) + v_new = torch.clamp(v_new, -2.0, SPEED_LIMIT) + + # Steering from lateral acceleration + signed_curvature = at_new / torch.clamp(v_new * v_new, min=1e-5) + steer_target = torch.atan(signed_curvature * wheelbase) + delta_steer = torch.clamp(steer_target - steer, -0.6 * dt, 0.6 * dt) + steer_new = torch.clamp(steer + delta_steer, -0.55, 0.55) + + # Recompute curvature from clamped steering + signed_curvature = torch.tan(steer_new) / wheelbase + at_new = v_new * v_new * signed_curvature + + # Displacement (bicycle model) + d = 0.5 * (v_new + v) * dt + theta = d * signed_curvature + + # Local displacement + small = signed_curvature.abs() < 1e-5 + dx_local = torch.where(small, d, torch.sin(theta) / (signed_curvature + 1e-10)) + dy_local = torch.where(small, torch.zeros_like(d), (1 - torch.cos(theta)) / (signed_curvature + 1e-10)) + + # Rotate to world frame + cos_h = torch.cos(heading) + sin_h = torch.sin(heading) + dx = dx_local * cos_h - dy_local * sin_h + dy = dx_local * sin_h + dy_local * cos_h + + # Update state + x = x + dx + y = y + dy + heading = heading + theta + v = v_new + al = al_new + at = at_new + steer = steer_new + + trajectory[:, t, 0] = x + trajectory[:, t, 1] = y + trajectory[:, t, 2] = heading + + return trajectory + + class OnnxWrapper(torch.nn.Module): - def __init__(self, policy): + def __init__(self, policy, fake_trajectory=False, traj_steps=80, dt=0.1): super().__init__() self.policy = policy + self.fake_trajectory = fake_trajectory + if fake_trajectory: + self.rollout = JerkDynamicsRollout(num_steps=traj_steps, dt=dt) def forward(self, observation, h, c): - # Reconstruct the state dictionary expected by LSTMWrapper - # state must be mutable as forward_eval updates it state = {"lstm_h": h, "lstm_c": c} - - # Call forward_eval logits, value = self.policy.forward_eval(observation, state) - - # Extract updated states new_h = state["lstm_h"] new_c = state["lstm_c"] + if self.fake_trajectory: + # Extract ego state from observation (jerk model layout) + # obs[3] = signed_speed / MAX_SPEED + speed = observation[:, 3] * 100.0 # MAX_SPEED = 100 + # obs[7] = steering_angle / pi + steering_angle = observation[:, 7] * 3.14159265 + # obs[8] = a_long (normalized asymmetrically) + a_long_norm = observation[:, 8] + a_long = torch.where( + a_long_norm < 0, + a_long_norm * 15.0, # JERK_LONG[0] = -15 + a_long_norm * 4.0, # JERK_LONG[3] = 4 + ) + # obs[9] = a_lat / JERK_LAT[2] + a_lat = observation[:, 9] * 4.0 # JERK_LAT[2] = 4 + # obs[5] = sim_length / MAX_VEH_LEN + wheelbase = observation[:, 5] * 5.5 # MAX_VEH_LEN approximate + + # logits[0] for single-head multi-discrete (jerk: MultiDiscrete([12])) + logits_tensor = logits[0] if isinstance(logits, (tuple, list)) else logits + trajectory = self.rollout(logits_tensor, speed, a_long, a_lat, steering_angle, wheelbase) + return logits, value, new_h, new_c, trajectory + return logits, value, new_h, new_c @@ -54,6 +195,14 @@ def export_to_onnx(verify=True): ) parser.add_argument("--output", type=str, help="Output .onnx file path") parser.add_argument("--opset", type=int, default=18, help="ONNX opset version") + parser.add_argument( + "--fake-trajectory", + action="store_true", + help="Add trajectory output by repeating the argmax action through jerk dynamics", + ) + parser.add_argument("--traj-steps", type=int, default=80, help="Number of trajectory rollout steps") + parser.add_argument("--traj-dt", type=float, default=0.1, help="Timestep for trajectory rollout") + parser.add_argument("--render", action="store_true", help="Render eval video after export") args = parser.parse_args() @@ -89,13 +238,15 @@ def export_to_onnx(verify=True): else: state_dict = checkpoint - # Strip compile prefixes + # Strip DDP and compile prefixes new_state_dict = {} for k, v in state_dict.items(): - if k.startswith("_orig_mod."): - new_state_dict[k[10:]] = v - else: - new_state_dict[k] = v + key = k + if key.startswith("module."): + key = key[7:] + if key.startswith("_orig_mod."): + key = key[10:] + new_state_dict[key] = v policy.load_state_dict(new_state_dict) policy.eval() @@ -158,7 +309,12 @@ def export_to_onnx(verify=True): dummy_c = torch.zeros(batch_size, hidden_size) # Wrap policy for export - onnx_policy = OnnxWrapper(policy) + onnx_policy = OnnxWrapper( + policy, + fake_trajectory=args.fake_trajectory, + traj_steps=args.traj_steps, + dt=args.traj_dt, + ) onnx_policy.eval() # Determine output path @@ -172,6 +328,7 @@ def export_to_onnx(verify=True): print(f"Exporting to {args.output}...") # Dynamic axes for batch size flexibility + output_names = ["logits", "value", "lstm_h_out", "lstm_c_out"] dynamic_axes = { "observation": {0: "batch_size"}, "lstm_h_in": {0: "batch_size"}, @@ -181,6 +338,9 @@ def export_to_onnx(verify=True): "lstm_h_out": {0: "batch_size"}, "lstm_c_out": {0: "batch_size"}, } + if args.fake_trajectory: + output_names.append("trajectory") + dynamic_axes["trajectory"] = {0: "batch_size"} dummy_inputs = (dummy_obs, dummy_h, dummy_c) torch.onnx.export( @@ -191,7 +351,7 @@ def export_to_onnx(verify=True): opset_version=args.opset, do_constant_folding=True, input_names=["observation", "lstm_h_in", "lstm_c_in"], - output_names=["logits", "value", "lstm_h_out", "lstm_c_out"], + output_names=output_names, dynamic_axes=dynamic_axes, ) @@ -211,12 +371,16 @@ def export_to_onnx(verify=True): # PyTorch output with torch.no_grad(): - torch_logits, torch_value, torch_h, torch_c = onnx_policy(*dummy_inputs) + torch_outs = onnx_policy(*dummy_inputs) + if args.fake_trajectory: + torch_logits, torch_value, torch_h, torch_c, torch_traj = torch_outs + else: + torch_logits, torch_value, torch_h, torch_c = torch_outs # Output .pt files for testing print(f"Saving test inputs/outputs to {output_dir}") torch.save(dummy_inputs, os.path.join(output_dir, "test_inputs.pt")) - torch.save((torch_logits, torch_value, torch_h, torch_c), os.path.join(output_dir, "test_outputs.pt")) + torch.save(torch_outs, os.path.join(output_dir, "test_outputs.pt")) # ONNX Runtime output ort_inputs = {"observation": dummy_obs.numpy(), "lstm_h_in": dummy_h.numpy(), "lstm_c_in": dummy_c.numpy()} @@ -254,7 +418,88 @@ def compare(name, torch_out, ort_out, atol=1e-5): compare("LSTM h", torch_h, ort_h) compare("LSTM c", torch_c, ort_c) - # Export example input and output to .pt files + if args.fake_trajectory: + ort_traj = ort_outs[4] + compare("Trajectory", torch_traj, ort_traj) + + # Optionally render video using the visualize binary + if args.render: + print("\nExporting weights to .bin for rendering...") + from pufferlib.pufferl import export as export_bin + + bin_path = os.path.splitext(args.output)[0] + ".bin" + export_bin( + env_name=args.env, + path=bin_path, + args={"env_name": config["base"]["env_name"], "load_model_path": args.checkpoint, **config}, + vecenv=vecenv, + policy=policy, + silent=True, + ) + print(f"Saved .bin weights to {bin_path}") + + # Generate INI with gigaflow-matching coefficients pinned (min=max) + import configparser + import tempfile + ini_config = configparser.ConfigParser() + ini_config.read("pufferlib/config/ocean/drive.ini") + # Gigaflow-matching coefficients (sign matches INI convention) + gigaflow_bounds = { + "goal_radius": (10.0, 10.0), + "collision": (-3.0, -3.0), + "offroad": (-3.0, -3.0), + "comfort": (-0.05, -0.05), + "lane_align": (0.025, 0.025), + "vel_align": (1.0, 1.0), + "lane_center": (-0.0038, -0.0038), + "center_bias": (0.0, 0.0), + "velocity": (0.0025, 0.0025), + "reverse": (-0.005, -0.005), + "traffic_light": (-1.0, -1.0), + "timestep": (-0.000025, -0.000025), + "overspeed": (-1.0, -1.0), + "throttle": (1.0, 1.0), + "steer": (1.0, 1.0), + "acc": (1.0, 1.0), + } + for key, (lo, hi) in gigaflow_bounds.items(): + ini_config.set("env", f"reward_bound_{key}_min", str(lo)) + ini_config.set("env", f"reward_bound_{key}_max", str(hi)) + ini_config.set("env", "goal_radius", "10.0") + ini_config.set("env", "max_goal_speed", "3.0") + ini_config.set("env", "episode_length", "1000") + ini_config.set("env", "resample_frequency", "0") + ini_config.set("env", "termination_mode", "0") + ini_config.set("env", "goal_behavior", "1") + fd, ini_path = tempfile.mkstemp(suffix=".ini", prefix="gigaflow_") + with os.fdopen(fd, "w") as f: + ini_config.write(f) + + # Build visualize binary + import subprocess + subprocess.run(["bash", "scripts/build_ocean.sh", "visualize", "local"], check=True) + + # Pick a map + map_dir = config["env"].get("map_dir", "resources/drive/binaries/carla_2D") + map_files = sorted([f for f in os.listdir(map_dir) if f.endswith(".bin")]) + if not map_files: + print(f"No maps in {map_dir}") + else: + map_path = os.path.join(map_dir, map_files[0]) + output_video = os.path.splitext(args.output)[0] + "_video.mp4" + cmd = [ + "./visualize", + "--config", ini_path, + "--policy-name", bin_path, + "--map-name", map_path, + "--view", "both", + "--output-topdown", output_video, + ] + print(f"Running: {' '.join(cmd)}") + subprocess.run(cmd) + print(f"Video saved to {output_video}") + print(f"INI config used: {ini_path}") + # os.remove(ini_path) # keep for debugging if __name__ == "__main__": diff --git a/scripts/onnx_inference.py b/scripts/onnx_inference.py new file mode 100644 index 000000000..17063247b --- /dev/null +++ b/scripts/onnx_inference.py @@ -0,0 +1,521 @@ +"""Utility for running inference on an exported PufferDrive ONNX model. + +Handles observation construction, reward conditioning insertion, and action decoding. + +Usage: + from scripts.onnx_inference import DrivePolicyONNX + + policy = DrivePolicyONNX("path/to/model.onnx") + action, value, trajectory = policy.step(ego_state, partners, road_segments) +""" + +import numpy as np +import onnxruntime as ort + + +# Observation layout constants (must match drive.h) +MAX_AGENTS = 128 +MAX_ROAD_SEGMENT_OBSERVATIONS = 256 +PARTNER_FEATURES = 8 +ROAD_FEATURES = 8 +EGO_FEATURES_JERK = 16 +EGO_FEATURES_CLASSIC = 13 +NUM_REWARD_COEFS = 16 +MAX_SPEED = 100.0 +MAX_VEH_WIDTH = 3.0 +MAX_VEH_LEN = 6.0 +SPEED_LIMIT = 20.0 + +# Reward coefficient indices (must match datatypes.h) +COEF_NAMES = [ + "goal_radius", # 0 + "collision", # 1 + "offroad", # 2 + "comfort", # 3 + "lane_align", # 4 + "lane_center", # 5 + "velocity", # 6 + "traffic_light", # 7 + "center_bias", # 8 + "vel_align", # 9 + "overspeed", # 10 + "timestep", # 11 + "reverse", # 12 + "throttle", # 13 + "steer", # 14 + "acc", # 15 +] + +# Gigaflow-matching defaults +GIGAFLOW_COEFS = { + "goal_radius": 10.0, + "collision": -3.0, + "offroad": -3.0, + "comfort": -0.05, + "lane_align": 0.025, + "lane_center": -0.0038, + "velocity": 0.0025, + "traffic_light": -1.0, + "center_bias": 0.0, + "vel_align": 1.0, + "overspeed": -1.0, + "timestep": -0.000025, + "reverse": -0.005, + "throttle": 1.0, + "steer": 1.0, + "acc": 1.0, +} + +# Jerk action decoding +JERK_LONG = np.array([-15.0, -4.0, 0.0, 4.0]) +JERK_LAT = np.array([-4.0, 0.0, 4.0]) + + +class DrivePolicyONNX: + """Wrapper for running inference on an exported PufferDrive ONNX model. + + Args: + onnx_path: Path to the .onnx model file. + dynamics_model: "jerk" or "classic". + reward_coefs: Dict of reward coefficient name → value. + Defaults to gigaflow-matching values. + reward_bounds: Dict of reward coefficient name → (min, max) for normalization. + If None, uses gigaflow bounds (min=max, normalized to 0). + normalize_coefs: If True, normalize reward coefs to [0, 1] using bounds. + If False, pass raw values (matches turn_off_normalization=1). + """ + + def __init__( + self, + onnx_path, + dynamics_model="jerk", + reward_coefs=None, + reward_bounds=None, + normalize_coefs=False, + ): + sess_options = ort.SessionOptions() + sess_options.intra_op_num_threads = 1 + sess_options.inter_op_num_threads = 1 + self.session = ort.InferenceSession(onnx_path, sess_options) + + self.dynamics_model = dynamics_model + self.ego_dim = EGO_FEATURES_JERK if dynamics_model == "jerk" else EGO_FEATURES_CLASSIC + self.obs_dim = ( + self.ego_dim + + NUM_REWARD_COEFS + + (MAX_AGENTS - 1) * PARTNER_FEATURES + + MAX_ROAD_SEGMENT_OBSERVATIONS * ROAD_FEATURES + ) + + # Reward conditioning + self.reward_coefs = reward_coefs or GIGAFLOW_COEFS + self.normalize_coefs = normalize_coefs + self.reward_bounds = reward_bounds + + # Build the conditioning vector + self._coef_vector = self._build_coef_vector() + + # LSTM state + output_names = [o.name for o in self.session.get_outputs()] + self.has_trajectory = "trajectory" in output_names + hidden_size = self.session.get_inputs()[1].shape[1] + self.hidden_size = hidden_size + self.lstm_h = np.zeros((1, hidden_size), dtype=np.float32) + self.lstm_c = np.zeros((1, hidden_size), dtype=np.float32) + + def _build_coef_vector(self): + """Build the 16-element reward conditioning vector.""" + raw = np.zeros(NUM_REWARD_COEFS, dtype=np.float32) + for i, name in enumerate(COEF_NAMES): + raw[i] = self.reward_coefs.get(name, 0.0) + + if self.normalize_coefs and self.reward_bounds: + normalized = np.zeros_like(raw) + for i, name in enumerate(COEF_NAMES): + lo, hi = self.reward_bounds.get(name, (raw[i], raw[i])) + rng = hi - lo + if rng < 1e-9: + normalized[i] = 0.0 + else: + normalized[i] = np.clip((raw[i] - lo) / rng, 0.0, 1.0) + return normalized + else: + return raw + + def build_observation(self, ego_state, partners=None, road_segments=None): + """Build a flattened observation vector. + + Args: + ego_state: Dict with keys matching the ego observation layout: + For jerk dynamics: + rel_goal_x, rel_goal_y, rel_goal_z: relative goal position + speed: signed speed (m/s) + width: vehicle width (m) + length: vehicle length (m) + collision: bool, currently colliding + steering_angle: current steering angle (rad) + a_long: longitudinal acceleration + a_lat: lateral acceleration + respawned: bool, has been respawned + goal_speed_min, goal_speed_max: normalized goal speed bounds + speed_limit: speed limit (m/s) + lane_center_dist: distance from lane center + lane_angle: cos(heading diff from lane) + + partners: np.array of shape [num_partners, 8] or None. + Each partner: [rel_x, rel_y, rel_z, speed, width, length, heading_x, heading_y] + Padded to MAX_AGENTS-1 with zeros. + + road_segments: np.array of shape [num_segments, 8] or None. + Each segment: [rel_x, rel_y, rel_z, length, width, cos_angle, sin_angle, type] + Padded to MAX_ROAD_SEGMENT_OBSERVATIONS with zeros. + + Returns: + obs: np.array of shape [1, obs_dim], ready for ONNX inference. + """ + obs = np.zeros((1, self.obs_dim), dtype=np.float32) + idx = 0 + + # Ego features + if self.dynamics_model == "jerk": + obs[0, 0] = ego_state.get("rel_goal_x", 0) * 0.005 + obs[0, 1] = ego_state.get("rel_goal_y", 0) * 0.005 + obs[0, 2] = ego_state.get("rel_goal_z", 0) * 0.005 + obs[0, 3] = ego_state.get("speed", 0) / MAX_SPEED + obs[0, 4] = ego_state.get("width", 2.0) / MAX_VEH_WIDTH + obs[0, 5] = ego_state.get("length", 4.5) / MAX_VEH_LEN + obs[0, 6] = 1.0 if ego_state.get("collision", False) else 0.0 + obs[0, 7] = ego_state.get("steering_angle", 0) / np.pi + a_long = ego_state.get("a_long", 0) + obs[0, 8] = a_long / 15.0 if a_long < 0 else a_long / 4.0 + obs[0, 9] = ego_state.get("a_lat", 0) / 4.0 + obs[0, 10] = 1.0 if ego_state.get("respawned", False) else 0.0 + obs[0, 11] = ego_state.get("goal_speed_min", 0) + obs[0, 12] = ego_state.get("goal_speed_max", 0) + obs[0, 13] = min(SPEED_LIMIT / MAX_SPEED, 1.0) + obs[0, 14] = ego_state.get("lane_center_dist", 0) + obs[0, 15] = ego_state.get("lane_angle", 1.0) + idx = EGO_FEATURES_JERK + else: + obs[0, 0] = ego_state.get("rel_goal_x", 0) * 0.005 + obs[0, 1] = ego_state.get("rel_goal_y", 0) * 0.005 + obs[0, 2] = ego_state.get("rel_goal_z", 0) * 0.005 + obs[0, 3] = ego_state.get("speed", 0) / MAX_SPEED + obs[0, 4] = ego_state.get("width", 2.0) / MAX_VEH_WIDTH + obs[0, 5] = ego_state.get("length", 4.5) / MAX_VEH_LEN + obs[0, 6] = 1.0 if ego_state.get("collision", False) else 0.0 + obs[0, 7] = 1.0 if ego_state.get("respawned", False) else 0.0 + obs[0, 8] = ego_state.get("goal_speed_min", 0) + obs[0, 9] = ego_state.get("goal_speed_max", 0) + obs[0, 10] = min(SPEED_LIMIT / MAX_SPEED, 1.0) + obs[0, 11] = ego_state.get("lane_center_dist", 0) + obs[0, 12] = ego_state.get("lane_angle", 1.0) + idx = EGO_FEATURES_CLASSIC + + # Reward conditioning (16 values) + obs[0, idx : idx + NUM_REWARD_COEFS] = self._coef_vector + idx += NUM_REWARD_COEFS + + # Partner features + max_partners = MAX_AGENTS - 1 + if partners is not None: + n = min(len(partners), max_partners) + obs[0, idx : idx + n * PARTNER_FEATURES] = partners[:n].flatten() + idx += max_partners * PARTNER_FEATURES + + # Road segment features + if road_segments is not None: + n = min(len(road_segments), MAX_ROAD_SEGMENT_OBSERVATIONS) + obs[0, idx : idx + n * ROAD_FEATURES] = road_segments[:n].flatten() + + return obs + + def step(self, ego_state, partners=None, road_segments=None): + """Run one inference step. + + Args: + ego_state: Dict of ego features (see build_observation). + partners: [N, 8] array of partner observations, or None. + road_segments: [N, 8] array of road segment observations, or None. + + Returns: + action: int, the discrete action index. + value: float, value estimate. + trajectory: np.array [traj_steps, 3] of (x, y, heading), or None. + """ + obs = self.build_observation(ego_state, partners, road_segments) + + inputs = { + "observation": obs, + "lstm_h_in": self.lstm_h, + "lstm_c_in": self.lstm_c, + } + + outputs = self.session.run(None, inputs) + + logits = outputs[0] # [1, num_actions] + value = outputs[1] # [1, 1] + self.lstm_h = outputs[2] + self.lstm_c = outputs[3] + + action = int(np.argmax(logits[0])) + trajectory = None + if self.has_trajectory and len(outputs) > 4: + trajectory = outputs[4][0] # [traj_steps, 3] + + return action, float(value[0, 0]), trajectory + + def decode_action(self, action): + """Decode a discrete action index into human-readable values. + + Returns: + Dict with jerk_long, jerk_lat (for jerk model) or + acceleration, steering (for classic model). + """ + if self.dynamics_model == "jerk": + num_lat = len(JERK_LAT) + a_long_idx = action // num_lat + a_lat_idx = action % num_lat + return { + "jerk_long": float(JERK_LONG[a_long_idx]), + "jerk_lat": float(JERK_LAT[a_lat_idx]), + "a_long_idx": a_long_idx, + "a_lat_idx": a_lat_idx, + } + else: + from pufferlib.ocean.drive.drive import ACCELERATION_VALUES, STEERING_VALUES + + num_steer = 13 + accel_idx = action // num_steer + steer_idx = action % num_steer + return { + "acceleration": float(ACCELERATION_VALUES[accel_idx]), + "steering": float(STEERING_VALUES[steer_idx]), + "accel_idx": accel_idx, + "steer_idx": steer_idx, + } + + def reset(self): + """Reset LSTM state.""" + self.lstm_h = np.zeros((1, self.hidden_size), dtype=np.float32) + self.lstm_c = np.zeros((1, self.hidden_size), dtype=np.float32) + + +def parse_env_obs_to_ego_state(obs_row, dynamics_model="jerk"): + """Reverse-parse a single agent's env observation into an ego_state dict. + + This undoes the normalization in compute_observations so we can feed + the raw values into build_observation and check the round-trip. + """ + ego = {} + ego["rel_goal_x"] = obs_row[0] / 0.005 + ego["rel_goal_y"] = obs_row[1] / 0.005 + ego["rel_goal_z"] = obs_row[2] / 0.005 + ego["speed"] = obs_row[3] * MAX_SPEED + ego["width"] = obs_row[4] * MAX_VEH_WIDTH + ego["length"] = obs_row[5] * MAX_VEH_LEN + ego["collision"] = obs_row[6] > 0.5 + + if dynamics_model == "jerk": + ego["steering_angle"] = obs_row[7] * np.pi + # a_long: asymmetric normalization + a_long_norm = obs_row[8] + ego["a_long"] = a_long_norm * 15.0 if a_long_norm < 0 else a_long_norm * 4.0 + ego["a_lat"] = obs_row[9] * 4.0 + ego["respawned"] = obs_row[10] > 0.5 + ego["goal_speed_min"] = obs_row[11] + ego["goal_speed_max"] = obs_row[12] + ego["speed_limit"] = obs_row[13] * MAX_SPEED + ego["lane_center_dist"] = obs_row[14] + ego["lane_angle"] = obs_row[15] + ego_dim = EGO_FEATURES_JERK + else: + ego["respawned"] = obs_row[7] > 0.5 + ego["goal_speed_min"] = obs_row[8] + ego["goal_speed_max"] = obs_row[9] + ego["speed_limit"] = obs_row[10] * MAX_SPEED + ego["lane_center_dist"] = obs_row[11] + ego["lane_angle"] = obs_row[12] + ego_dim = EGO_FEATURES_CLASSIC + + # Extract reward conditioning (16 values after ego features) + coef_start = ego_dim + reward_coefs = obs_row[coef_start : coef_start + NUM_REWARD_COEFS] + + # Extract partners + partner_start = ego_dim + NUM_REWARD_COEFS + max_partners = MAX_AGENTS - 1 + partners = obs_row[partner_start : partner_start + max_partners * PARTNER_FEATURES] + partners = partners.reshape(max_partners, PARTNER_FEATURES) + + # Extract road segments + road_start = partner_start + max_partners * PARTNER_FEATURES + roads = obs_row[road_start : road_start + MAX_ROAD_SEGMENT_OBSERVATIONS * ROAD_FEATURES] + roads = roads.reshape(MAX_ROAD_SEGMENT_OBSERVATIONS, ROAD_FEATURES) + + return ego, reward_coefs, partners, roads + + +def verify_onnx_vs_pytorch(onnx_path, checkpoint_path, env_name="puffer_drive", num_steps=5): + """Verify ONNX model matches PyTorch model on real env observations. + + Two checks: + 1. Export correctness: same env obs → same outputs from ONNX and PyTorch. + 2. Observation construction: env obs round-trips through parse → build_observation. + """ + import torch + import importlib + import pufferlib.vector + import pufferlib.models + from pufferlib.ocean.torch import Drive as DrivePolicy + from scripts.export_model_bin import load_config + + config = load_config(env_name) + package = config["base"]["package"] + module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" + env_module = importlib.import_module(module_name) + make_env = env_module.env_creator(env_name) + + vecenv = pufferlib.vector.make( + make_env, env_kwargs=config["env"], backend=pufferlib.vector.Serial, num_envs=1, + ) + + # Load PyTorch policy + policy = DrivePolicy(vecenv.driver_env, **config["policy"]) + if config["base"]["rnn_name"]: + policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"]) + + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint: + state_dict = checkpoint["agent_state_dict"] + else: + state_dict = checkpoint + state_dict = {k.replace("module.", "").replace("_orig_mod.", ""): v for k, v in state_dict.items()} + policy.load_state_dict(state_dict) + policy.eval() + + # Load ONNX model + sess_options = ort.SessionOptions() + sess_options.intra_op_num_threads = 1 + sess_options.inter_op_num_threads = 1 + ort_session = ort.InferenceSession(onnx_path, sess_options) + + # Run both on real env observations + ob, _ = vecenv.reset() + num_agents = ob.shape[0] + hidden_size = config["rnn"]["hidden_size"] + pt_h = torch.zeros(num_agents, hidden_size) + pt_c = torch.zeros(num_agents, hidden_size) + ort_h = np.zeros((num_agents, hidden_size), dtype=np.float32) + ort_c = np.zeros((num_agents, hidden_size), dtype=np.float32) + + # Determine dynamics model + dynamics = config["env"].get("dynamics_model", "jerk") + + # Build an inference wrapper to get build_observation + onnx_policy = DrivePolicyONNX(onnx_path, dynamics_model=dynamics, normalize_coefs=False) + + all_pass = True + for step in range(num_steps): + ob_t = torch.as_tensor(ob).float() + + # --- Check 1: Observation round-trip --- + # Parse env obs → raw fields → rebuild via build_observation → compare + agent_0_obs = ob[0].astype(np.float32) + ego, coefs, partners, roads = parse_env_obs_to_ego_state(agent_0_obs, dynamics) + + # Override the inference wrapper's coef vector with what the env actually produced + onnx_policy._coef_vector = coefs + rebuilt_obs = onnx_policy.build_observation(ego, partners, roads) + rebuilt_obs_flat = rebuilt_obs[0] + + obs_match = np.allclose(agent_0_obs, rebuilt_obs_flat, atol=1e-4) + max_obs_diff = np.abs(agent_0_obs - rebuilt_obs_flat).max() + if not obs_match: + # Find which indices differ + diffs = np.where(np.abs(agent_0_obs - rebuilt_obs_flat) > 1e-4)[0] + print(f"Step {step}: OBS MISMATCH at indices {diffs[:10]} max_diff={max_obs_diff:.6f}") + all_pass = False + else: + print(f"Step {step}: OBS MATCH max_diff={max_obs_diff:.6f}") + + # --- Check 2: Model output match --- + # PyTorch gets env obs, ONNX gets rebuilt obs + with torch.no_grad(): + state = {"lstm_h": pt_h, "lstm_c": pt_c} + pt_logits, pt_value = policy.forward_eval(ob_t, state) + pt_h = state["lstm_h"] + pt_c = state["lstm_c"] + + if isinstance(pt_logits, (tuple, list)): + pt_logits_np = pt_logits[0].numpy() + else: + pt_logits_np = pt_logits.numpy() + + # ONNX gets the REBUILT observation (tests full pipeline) + ort_inputs = { + "observation": rebuilt_obs.astype(np.float32), + "lstm_h_in": ort_h[0:1], + "lstm_c_in": ort_c[0:1], + } + ort_outs = ort_session.run(None, ort_inputs) + ort_logits_np = ort_outs[0] + ort_value_np = ort_outs[1] + ort_h[0:1] = ort_outs[2] + ort_c[0:1] = ort_outs[3] + + logits_match = np.allclose(pt_logits_np[0:1], ort_logits_np, atol=1e-3) + value_match = np.allclose(pt_value[0:1].numpy(), ort_value_np, atol=1e-3) + pt_action = pt_logits_np[0].argmax() + ort_action = ort_logits_np[0].argmax() + action_match = pt_action == ort_action + + status = "PASS" if (logits_match and value_match and action_match) else "FAIL" + if status == "FAIL": + all_pass = False + print(f" MODEL {status}" + f" logits_close={logits_match}" + f" value_close={value_match}" + f" action_match={action_match} (pt={pt_action}, ort={ort_action})" + f" max_logit_diff={np.abs(pt_logits_np[0:1] - ort_logits_np).max():.6f}") + + # Step env + import pufferlib.pytorch + action, _, _ = pufferlib.pytorch.sample_logits(pt_logits) + action_np = action.cpu().numpy().reshape(vecenv.action_space.shape) + ob = vecenv.step(action_np)[0] + + vecenv.close() + if all_pass: + print(f"\nAll {num_steps} steps PASSED — observation normalization and ONNX output verified.") + else: + print(f"\nSome steps FAILED.") + return all_pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="ONNX inference utility") + parser.add_argument("--model", type=str, required=True, help="Path to .onnx model") + parser.add_argument("--dynamics", type=str, default="jerk", choices=["jerk", "classic"]) + parser.add_argument("--verify", type=str, default=None, help="Path to .pt checkpoint to verify against") + parser.add_argument("--steps", type=int, default=5, help="Number of verification steps") + args = parser.parse_args() + + if args.verify: + verify_onnx_vs_pytorch(args.model, args.verify, num_steps=args.steps) + else: + policy = DrivePolicyONNX(args.model, dynamics_model=args.dynamics) + ego = { + "rel_goal_x": 50.0, "rel_goal_y": 0.0, "rel_goal_z": 0.0, + "speed": 10.0, "width": 2.0, "length": 4.5, + "steering_angle": 0.0, "a_long": 0.0, "a_lat": 0.0, + "lane_center_dist": 0.0, "lane_angle": 1.0, + } + action, value, trajectory = policy.step(ego) + decoded = policy.decode_action(action) + print(f"Action: {action} → {decoded}") + print(f"Value: {value:.4f}") + if trajectory is not None: + print(f"Trajectory: {trajectory.shape}, endpoint: ({trajectory[-1, 0]:.1f}, {trajectory[-1, 1]:.1f})")