Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pufferlib/ocean/drive/drive.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void test_drivenet() {
Weights *weights = load_weights("puffer_drive_weights.bin");
DriveNet *net = init_drivenet(weights, num_agents, CLASSIC, 0);

forward(net, observations, actions);
forward(net, NULL, observations, actions);
for (int i = 0; i < num_agents * num_actions; i++) {
printf("idx: %d, action: %d, logits:", i, actions[i]);
for (int j = 0; j < num_actions; j++) {
Expand Down Expand Up @@ -126,7 +126,7 @@ void demo() {
int *actions = (int *)env.actions; // Single integer per agent

if (!IsKeyDown(KEY_LEFT_SHIFT)) {
forward(net, env.observations, actions);
forward(net, &env, env.observations, actions);
} else {
if (env.dynamics_model == CLASSIC) {
// Classic dynamics: acceleration and steering
Expand Down
243 changes: 139 additions & 104 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ struct Drive {
float observation_window_size;
float polyline_reduction_threshold;
float polyline_max_segment_length;
// Predicted trajectory visualization (set from Python, drawn by c_render)
float *predicted_traj_x; // [num_agents * predicted_traj_len]
float *predicted_traj_y; // [num_agents * predicted_traj_len]
int predicted_traj_len; // steps per agent (0 = no trajectory to draw)

int sdc_track_index;
int num_tracks_to_predict;
int *tracks_to_predict_indices;
Expand All @@ -374,6 +379,84 @@ void sample_new_goal(Drive *env, int agent_idx);
int check_lane_aligned(Agent *car, RoadMapElement *lane, int geometry_idx);
void reset_goal_positions(Drive *env);

// ========================================
// Pure dynamics step functions
// ========================================

// Plain-value snapshot of the kinematic state that the dynamics step
// functions take and return. Keep the member order: callers in this file
// fill the struct with a positional initializer list.
typedef struct {
    float x;              // position, first coordinate
    float y;              // position, second coordinate
    float heading;        // angle passed to cosf()/sinf() by the step functions
    float vx;             // velocity, first component
    float vy;             // velocity, second component
    float a_long;         // updated by the jerk step; the classic step copies it through
    float a_lat;          // updated by the jerk step; the classic step copies it through
    float steering_angle; // updated by the jerk step; the classic step copies it through
} DynState;

// Classic bicycle model: advance a state by one step.
//
//   s            - state before the step, passed by value; a_long, a_lat and
//                  steering_angle are carried over to the result unchanged
//   acceleration - added to the signed speed after scaling by dt
//   steering     - steering angle; only ever used through tanf()
//   length       - divisor of the yaw-rate term
//   dt           - duration of the step
//
// Returns the state after the step. The returned heading is the plain sum
// s.heading + yaw_rate * dt; no wrapping is applied here.
static inline DynState classic_dynamics_step(DynState s, float acceleration, float steering, float length, float dt) {
    // Signed speed: magnitude of (vx, vy), carrying the sign of the velocity's
    // projection onto the heading direction.
    float along_heading = s.vx * cosf(s.heading) + s.vy * sinf(s.heading);
    float v = copysignf(sqrtf(s.vx * s.vx + s.vy * s.vy), along_heading);

    // Apply the acceleration, then let clipSpeed() bound the result.
    v += acceleration * dt;
    v = clipSpeed(v);

    // The velocity direction is the heading plus this steering-dependent offset.
    float heading_offset = tanhf(0.5f * tanf(steering));
    float travel_dir = s.heading + heading_offset;
    float yaw_rate = (v * cosf(heading_offset) * tanf(steering)) / length;

    DynState next = s;
    next.vx = v * cosf(travel_dir);
    next.vy = v * sinf(travel_dir);
    // Position advances with the *new* velocity.
    next.x = s.x + next.vx * dt;
    next.y = s.y + next.vy * dt;
    next.heading = s.heading + yaw_rate * dt;
    return next;
}

// Jerk model: action is (jerk_long, jerk_lat). Unlike the classic step, this
// one also reads and updates a_long, a_lat and steering_angle.
//
//   s         - state before the step, passed by value
//   jerk_long - rate of change applied to s.a_long over dt
//   jerk_lat  - rate of change applied to s.a_lat over dt
//   wheelbase - converts between curvature and steering angle
//   dt        - duration of the step
//
// Returns the state after the step. The returned heading is s.heading + theta
// with no wrapping; a_lat in the result is recomputed from the limited
// steering angle rather than taken directly from the integrated jerk.
static inline DynState jerk_dynamics_step(DynState s, float jerk_long, float jerk_lat, float wheelbase, float dt) {
    // --- Accelerations -----------------------------------------------------
    // Integrate the jerk. If the sign flips relative to the previous value the
    // acceleration snaps to zero; otherwise it is clamped to a fixed range.
    float acc_long = s.a_long + jerk_long * dt;
    if (s.a_long * acc_long < 0) {
        acc_long = 0.0f;
    } else {
        acc_long = fminf(fmaxf(acc_long, -5.0f), 2.5f);
    }

    float acc_lat = s.a_lat + jerk_lat * dt;
    if (s.a_lat * acc_lat < 0) {
        acc_lat = 0.0f;
    } else {
        acc_lat = fminf(fmaxf(acc_lat, -4.0f), 4.0f);
    }

    // --- Speed -------------------------------------------------------------
    // Signed speed along the heading, advanced with the average of the old
    // and new longitudinal acceleration. Same zero-crossing rule as above.
    float along_heading = s.vx * cosf(s.heading) + s.vy * sinf(s.heading);
    float v_prev = copysignf(sqrtf(s.vx * s.vx + s.vy * s.vy), along_heading);
    float v_next = v_prev + 0.5f * (acc_long + s.a_long) * dt;
    if (v_prev * v_next < 0) {
        v_next = 0.0f;
    } else {
        v_next = fminf(fmaxf(v_next, -2.0f), SPEED_LIMIT);
    }

    // --- Steering ----------------------------------------------------------
    // Curvature requested by the lateral acceleration (denominator floored so
    // a near-zero speed does not divide by zero), turned into a target angle.
    float wanted_curvature = acc_lat / fmaxf(v_next * v_next, 1e-5f);
    float wanted_steer = atanf(wanted_curvature * wheelbase);
    // Limit how far the angle may move this step, then limit the angle itself.
    float steer_step = fminf(fmaxf(wanted_steer - s.steering_angle, -0.6f * dt), 0.6f * dt);
    float steer = fminf(fmaxf(s.steering_angle + steer_step, -0.55f), 0.55f);
    // Recompute curvature and lateral acceleration from the limited angle.
    float kappa = tanf(steer) / wheelbase;
    acc_lat = v_next * v_next * kappa;

    // --- Displacement ------------------------------------------------------
    // Distance travelled uses the average of old and new speed; theta is the
    // heading change over that distance.
    float arc = 0.5f * (v_next + v_prev) * dt;
    float theta = arc * kappa;
    // Straight-line fallback when the curvature or the turn is tiny.
    int straight = fabsf(kappa) < 1e-5f || fabsf(theta) < 1e-5f;
    float forward = straight ? arc : sinf(theta) / kappa;
    float sideways = straight ? 0.0f : (1.0f - cosf(theta)) / kappa;

    // Rotate the local displacement by the pre-step heading.
    float c = cosf(s.heading);
    float sn = sinf(s.heading);

    DynState next = s;
    next.x = s.x + forward * c - sideways * sn;
    next.y = s.y + forward * sn + sideways * c;
    next.heading = s.heading + theta;
    // Velocity is aligned with the post-step heading.
    next.vx = v_next * cosf(next.heading);
    next.vy = v_next * sinf(next.heading);
    next.a_long = acc_long;
    next.a_lat = acc_lat;
    next.steering_angle = steer;
    return next;
}

// ========================================
// Utility Functions
// ========================================
Expand Down Expand Up @@ -2273,6 +2356,8 @@ void c_close(Drive *env) {
free(env->static_agent_indices);
free(env->expert_static_agent_indices);
free(env->tracks_to_predict_indices);
free(env->predicted_traj_x);
free(env->predicted_traj_y);
free(env->ini_file);
}

Expand Down Expand Up @@ -2918,42 +3003,18 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) {
steering = STEERING_VALUES[steering_index];
}

// Current state
float heading = agent->sim_heading;
float vx = agent->sim_vx;
float vy = agent->sim_vy;

// Calculate current speed (signed based on direction relative to heading)
float speed_magnitude = sqrtf(vx * vx + vy * vy);
float v_dot_heading = vx * cosf(heading) + vy * sinf(heading);
float signed_speed = copysignf(speed_magnitude, v_dot_heading);

// Update speed with acceleration
signed_speed = signed_speed + acceleration * env->dt;
signed_speed = clipSpeed(signed_speed);
// Compute yaw rate
float beta = tanh(.5 * tanf(steering));

// New heading
float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->sim_length;

// New velocity
float new_vx = signed_speed * cosf(heading + beta);
float new_vy = signed_speed * sinf(heading + beta);

// Update position
float dx = new_vx * env->dt;
float dy = new_vy * env->dt;
float dheading = yaw_rate * env->dt;

// Apply updates to the agent's state
agent->sim_x += dx;
agent->sim_y += dy;
agent->sim_heading += dheading;
agent->heading_x = cosf(agent->sim_heading);
agent->heading_y = sinf(agent->sim_heading);
agent->sim_vx = new_vx;
agent->sim_vy = new_vy;
DynState s = {agent->sim_x, agent->sim_y, agent->sim_heading,
agent->sim_vx, agent->sim_vy, 0, 0, 0};
DynState ns = classic_dynamics_step(s, acceleration, steering, agent->sim_length, env->dt);
float dx = ns.x - s.x;
float dy = ns.y - s.y;
agent->sim_x = ns.x;
agent->sim_y = ns.y;
agent->sim_heading = ns.heading;
agent->heading_x = cosf(ns.heading);
agent->heading_y = sinf(ns.heading);
agent->sim_vx = ns.vx;
agent->sim_vy = ns.vy;
agent->cumulative_displacement += sqrtf(dx * dx + dy * dy);
} else {
// JERK dynamics model
Expand Down Expand Up @@ -2984,76 +3045,24 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) {
a_lat = JERK_LAT[a_lat_idx];
}

// Calculate new acceleration
float a_long_new = agent->a_long + a_long * env->dt;
float a_lat_new = agent->a_lat + a_lat * env->dt;

// Make it easy to stop with 0 accel
if (agent->a_long * a_long_new < 0) {
a_long_new = 0.0f;
} else {
a_long_new = clip(a_long_new, -5.0f, 2.5f);
}

if (agent->a_lat * a_lat_new < 0) {
a_lat_new = 0.0f;
} else {
a_lat_new = clip(a_lat_new, -4.0f, 4.0f);
}

// Calculate new velocity
float v_dot_heading = agent->sim_vx * cosf(agent->sim_heading) + agent->sim_vy * sinf(agent->sim_heading);
float signed_v = copysignf(sqrtf(agent->sim_vx * agent->sim_vx + agent->sim_vy * agent->sim_vy), v_dot_heading);
float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt;

// Make it easy to stop with 0 vel
if (signed_v * v_new < 0) {
v_new = 0.0f;
} else {
v_new = clip(v_new, -2.0f, 20.0f);
}

// Calculate new steering angle
float signed_curvature = a_lat_new / fmaxf(v_new * v_new, 1e-5f);
float steering_angle = atanf(signed_curvature * agent->wheelbase);
float delta_steer = clip(steering_angle - agent->steering_angle, -0.6f * env->dt, 0.6f * env->dt);
float new_steering_angle = clip(agent->steering_angle + delta_steer, -0.55f, 0.55f);

// Update curvature and accel to account for limited steering
signed_curvature = tanf(new_steering_angle) / agent->wheelbase;
a_lat_new = v_new * v_new * signed_curvature;

// Calculate resulting movement using bicycle dynamics
float d = 0.5f * (v_new + signed_v) * env->dt;
float theta = d * signed_curvature;
float dx_local, dy_local;

if (fabsf(signed_curvature) < 1e-5f || fabsf(theta) < 1e-5f) {
dx_local = d;
dy_local = 0.0f;
} else {
dx_local = sinf(theta) / signed_curvature;
dy_local = (1.0f - cosf(theta)) / signed_curvature;
}

float cos_heading = cosf(agent->sim_heading);
float sin_heading = sinf(agent->sim_heading);
float dx = dx_local * cos_heading - dy_local * sin_heading;
float dy = dx_local * sin_heading + dy_local * cos_heading;

// Update everything
agent->sim_x += dx;
agent->sim_y += dy;
agent->jerk_long = (a_long_new - agent->a_long) / env->dt;
agent->jerk_lat = (a_lat_new - agent->a_lat) / env->dt;
agent->a_long = a_long_new;
agent->a_lat = a_lat_new;
agent->sim_heading = normalize_heading(agent->sim_heading + theta);
DynState s = {agent->sim_x, agent->sim_y, agent->sim_heading,
agent->sim_vx, agent->sim_vy,
agent->a_long, agent->a_lat, agent->steering_angle};
DynState ns = jerk_dynamics_step(s, a_long, a_lat, agent->wheelbase, env->dt);
float dx = ns.x - s.x;
float dy = ns.y - s.y;
agent->jerk_long = (ns.a_long - agent->a_long) / env->dt;
agent->jerk_lat = (ns.a_lat - agent->a_lat) / env->dt;
agent->sim_x = ns.x;
agent->sim_y = ns.y;
agent->sim_heading = normalize_heading(ns.heading);
agent->heading_x = cosf(agent->sim_heading);
agent->heading_y = sinf(agent->sim_heading);
agent->sim_vx = v_new * agent->heading_x;
agent->sim_vy = v_new * agent->heading_y;
agent->steering_angle = new_steering_angle;
agent->sim_vx = ns.vx;
agent->sim_vy = ns.vy;
agent->a_long = ns.a_long;
agent->a_lat = ns.a_lat;
agent->steering_angle = ns.steering_angle;
agent->cumulative_displacement += sqrtf(dx * dx + dy * dy);
agent->cumulative_displacement_since_last_goal += sqrtf(dx * dx + dy * dy);
}
Expand Down Expand Up @@ -4110,6 +4119,32 @@ void c_render(Drive *env) {
handle_camera_controls(env->client);
draw_scene(env, client, 0, 0, 0, 0);

// Re-enter 3D mode for trajectory drawing (draw_scene calls EndMode3D internally)
BeginMode3D(client->camera);
rlDisableDepthTest();
if (env->predicted_traj_x != NULL && env->predicted_traj_len > 0) {
int i = env->human_agent_idx;
if (i < env->active_agent_count) {
int agent_idx = env->active_agent_indices[i];
Agent *agent = &env->agents[agent_idx];
int tlen = env->predicted_traj_len;
float z = agent->sim_z + 0.1f;

Vector3 prev = {agent->sim_x, agent->sim_y, z};
for (int t = 0; t < tlen; t++) {
float tx = env->predicted_traj_x[i * tlen + t];
float ty = env->predicted_traj_y[i * tlen + t];
Vector3 curr = {tx, ty, z};
rlSetLineWidth(3.0f);
DrawLine3D(prev, curr, RED);
DrawCircle3D(curr, 1.5f, (Vector3){0, 0, 1}, 0.0f, RED);
prev = curr;
}
}
}
rlEnableDepthTest();
EndMode3D();

if (IsKeyPressed(KEY_TAB) && env->active_agent_count > 0) {
env->human_agent_idx = (env->human_agent_idx + 1) % env->active_agent_count;
}
Expand Down
19 changes: 19 additions & 0 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,25 @@ def get_road_edge_polylines(self):

return polylines

def set_predicted_trajectories(self, action_trajectories):
    """Pass each environment's action trajectories to the C binding.

    The rows of ``action_trajectories`` are split per environment using
    consecutive entries of ``self.agent_offsets``. Each slice is flattened,
    cast to int32 and handed to ``binding.vec_set_trajectory`` together with
    the environment index and the trajectory length. The rollout through the
    dynamics and the drawing are presumably done on the C side; that code is
    not visible from this method.

    Args:
        action_trajectories: numpy array of shape [num_agents, traj_len]
            containing discrete action indices.
    """
    # Unpacking (rather than indexing) keeps the requirement that the
    # input is exactly two-dimensional.
    _, traj_len = action_trajectories.shape

    for env_idx in range(self.num_envs):
        first_row = self.agent_offsets[env_idx]
        stop_row = self.agent_offsets[env_idx + 1]
        env_actions = action_trajectories[first_row:stop_row]
        flat_actions = env_actions.flatten().astype(np.int32)
        binding.vec_set_trajectory(self.c_envs, env_idx, flat_actions, traj_len)

def render(self):
binding.vec_render(self.c_envs, 0)

Expand Down
Loading
Loading