Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 171 additions & 40 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,144 @@ void init_action_space() {
}
}

// Report whether timestep `t` of the agent's logged trajectory may be read.
// Returns 1 when `t` lies in [0, array_size) and the per-step validity mask
// is either absent (traj_valid == NULL) or non-zero at `t`; returns 0 otherwise.
static inline int entity_has_valid_timestep(Entity *agent, int t) {
    if (t < 0 || t >= agent->array_size) {
        return 0;
    }
    if (agent->traj_valid == NULL) {
        // No mask: every in-range timestep counts as valid.
        return 1;
    }
    return agent->traj_valid[t] != 0;
}

// Put the agent into its "not present" state: the valid flag is cleared,
// x/y are parked at INVALID_POSITION with z = 0, velocity is zeroed, and the
// heading is reset to 0 with its unit vector set to (1, 0).
static void set_entity_invalid_state(Entity *agent) {
    agent->valid = 0;

    // Position
    agent->x = INVALID_POSITION;
    agent->y = INVALID_POSITION;
    agent->z = 0.0f;

    // Heading angle and its unit vector
    agent->heading = 0.0f;
    agent->heading_x = 1.0f;
    agent->heading_y = 0.0f;

    // Velocity
    agent->vx = 0.0f;
    agent->vy = 0.0f;
    agent->vz = 0.0f;
}

// Set the agent's velocity (vx, vy, vz) for log replay at timestep `t`.
//
// If the trajectory carries logged velocities (traj_vx and traj_vy, plus
// traj_vz when present) they are copied directly. Otherwise the velocity is
// estimated by finite differences over the logged positions: the nearest
// valid timestep on each side of `t` is located and used as an endpoint; when
// one side has no valid neighbour, `t` itself is used as that endpoint. If no
// neighbour exists on either side, or the time span is not positive, the
// velocity is set to zero.
//
// An out-of-range `t` also yields zero velocity. Without that guard the
// one-sided differences would index traj_x/traj_y/traj_z at `t` outside
// [0, array_size).
static void update_replay_velocity(Drive *env, Entity *agent, int t) {
    // Default result; overwritten below whenever a velocity can be determined.
    agent->vx = 0.0f;
    agent->vy = 0.0f;
    agent->vz = 0.0f;

    if (t < 0 || t >= agent->array_size) {
        return;
    }

    if (agent->traj_vx != NULL && agent->traj_vy != NULL) {
        agent->vx = agent->traj_vx[t];
        agent->vy = agent->traj_vy[t];
        agent->vz = (agent->traj_vz != NULL) ? agent->traj_vz[t] : 0.0f;
        return;
    }

    // Nearest valid timestep strictly before t (-1 if none).
    int prev_t = t - 1;
    while (prev_t >= 0 && !entity_has_valid_timestep(agent, prev_t)) {
        prev_t--;
    }

    // Nearest valid timestep strictly after t (array_size if none).
    int next_t = t + 1;
    while (next_t < agent->array_size && !entity_has_valid_timestep(agent, next_t)) {
        next_t++;
    }

    // Pick difference endpoints: central when both neighbours exist,
    // one-sided (anchored at t) when only one does, degenerate (from == to)
    // when neither does.
    int from_t = (prev_t >= 0) ? prev_t : t;
    int to_t = (next_t < agent->array_size) ? next_t : t;

    float dt = (to_t - from_t) * env->dt;
    if (dt > 0.0f) {
        agent->vx = (agent->traj_x[to_t] - agent->traj_x[from_t]) / dt;
        agent->vy = (agent->traj_y[to_t] - agent->traj_y[from_t]) / dt;
        agent->vz = (agent->traj_z[to_t] - agent->traj_z[from_t]) / dt;
    }
}

// Write a neutral classic-dynamics action into slot `action_idx` of env->actions.
// Continuous actions (action_type == 1): both floats of the slot's pair are set to 0.
// Otherwise: the int slot receives the joint index formed from the middle
// acceleration bin and the middle steering bin.
static void set_classic_noop_action(Drive *env, int action_idx) {
    if (env->action_type != 1) {
        int *discrete_actions = (int *)env->actions;
        discrete_actions[action_idx] = (NUM_ACCEL_BINS / 2) * NUM_STEER_BINS + (NUM_STEER_BINS / 2);
        return;
    }

    float (*continuous_actions)[2] = (float (*)[2])env->actions;
    continuous_actions[action_idx][0] = 0.0f;
    continuous_actions[action_idx][1] = 0.0f;
}

static float signed_speed_from_heading(float vx, float vy, float heading) {
float speed_magnitude = sqrtf(vx * vx + vy * vy);
float heading_x = cosf(heading);
float heading_y = sinf(heading);
float dot = vx * heading_x + vy * heading_y;
return copysignf(speed_magnitude, dot);
}

// Infer the classic-dynamics expert action (acceleration, steering angle)
// that carries the agent from logged timestep `t` to `t + 1`.
//
// Returns 1 and writes *accel_out / *steer_out on success. Returns 0 and
// leaves both outputs untouched when:
//   - `t` or `t + 1` is out of range or marked invalid,
//   - env->dt is not strictly positive (every rate below divides by it),
//   - the logged heading change is not finite (the wrap loop below would not
//     terminate on an infinite value).
//
// Velocities come from traj_vx / traj_vy when present, otherwise from forward
// differences of the logged positions. The result is clipped to
// [ACCEL_MIN, ACCEL_MAX] and [STEER_MIN, STEER_MAX].
static int infer_classic_expert_action_from_trajectory(Drive *env, Entity *agent, int t, float *accel_out,
                                                       float *steer_out) {
    if (t < 0 || t + 1 >= agent->array_size) {
        return 0;
    }
    if (!entity_has_valid_timestep(agent, t) || !entity_has_valid_timestep(agent, t + 1)) {
        return 0;
    }
    // Written as a negated '>' so that a NaN dt is rejected as well.
    if (!(env->dt > 0.0f)) {
        return 0;
    }

    float heading_t = agent->traj_heading[t];
    float heading_t1 = agent->traj_heading[t + 1];
    float heading_diff = heading_t1 - heading_t;
    if (!isfinite(heading_diff)) {
        return 0;
    }

    float vx_t, vy_t, vx_t1, vy_t1;
    if (agent->traj_vx != NULL && agent->traj_vy != NULL) {
        vx_t = agent->traj_vx[t];
        vy_t = agent->traj_vy[t];
        vx_t1 = agent->traj_vx[t + 1];
        vy_t1 = agent->traj_vy[t + 1];
    } else {
        vx_t = (agent->traj_x[t + 1] - agent->traj_x[t]) / env->dt;
        vy_t = (agent->traj_y[t + 1] - agent->traj_y[t]) / env->dt;
        if (t + 2 < agent->array_size && entity_has_valid_timestep(agent, t + 2)) {
            vx_t1 = (agent->traj_x[t + 2] - agent->traj_x[t + 1]) / env->dt;
            vy_t1 = (agent->traj_y[t + 2] - agent->traj_y[t + 1]) / env->dt;
        } else {
            // No usable t + 2 sample: assume constant velocity over the step.
            vx_t1 = vx_t;
            vy_t1 = vy_t;
        }
    }

    float speed_t = signed_speed_from_heading(vx_t, vy_t, heading_t);
    float speed_t1 = signed_speed_from_heading(vx_t1, vy_t1, heading_t1);
    float accel = clip((speed_t1 - speed_t) / env->dt, ACCEL_MIN, ACCEL_MAX);

    // Wrap the heading change into [-pi, pi] before turning it into a rate.
    while (heading_diff > M_PI) {
        heading_diff -= 2.0f * M_PI;
    }
    while (heading_diff < -M_PI) {
        heading_diff += 2.0f * M_PI;
    }
    float yaw_rate = heading_diff / env->dt;

    // move_dynamics() applies the acceleration first and then derives the yaw
    // rate from the updated speed, so the steering inversion must use the
    // post-acceleration speed rather than the speed at t. Using speed_t here
    // biases the steering while accelerating and gives zero steering when
    // starting from rest.
    float speed_for_yaw = speed_t + accel * env->dt;
    float steer = 0.0f;

    // Below this speed the ratio yaw_rate / speed is ill-conditioned; keep steer = 0.
    if (fabsf(speed_for_yaw) > 0.1f) {
        float wheelbase = fmaxf(agent->wheelbase, 1e-3f);
        // Forward model: yaw_rate = speed * cos(beta) * tan(steer) / wheelbase
        // with beta = atan(0.5 * tan(steer)). Writing k = yaw_rate * wheelbase / speed
        // gives tan(steer) = k / sqrt(1 - k^2 / 4), which needs |k| < 2.
        float k = (yaw_rate * wheelbase) / speed_for_yaw;
        float clipped_k = clip(k, -1.999f, 1.999f);
        float denom = sqrtf(fmaxf(1.0f - 0.25f * clipped_k * clipped_k, 1e-6f));
        steer = atanf(clipped_k / denom);
    }

    *accel_out = accel;
    *steer_out = clip(steer, STEER_MIN, STEER_MAX);
    return 1;
}

Entity *load_map_binary(const char *filename, Drive *env) {
FILE *file = fopen(filename, "rb");
if (!file)
Expand Down Expand Up @@ -923,22 +1061,8 @@ void set_means(Drive *env) {
void move_expert(Drive *env, float *actions, int agent_idx) {
Entity *agent = &env->entities[agent_idx];
int t = env->timestep;
if (t < 0 || t >= agent->array_size) {
agent->x = INVALID_POSITION;
agent->y = INVALID_POSITION;
agent->z = 0.0f;
agent->heading = 0.0f;
agent->heading_x = 1.0f;
agent->heading_y = 0.0f;
return;
}
if (agent->traj_valid && agent->traj_valid[t] == 0) {
agent->x = INVALID_POSITION;
agent->y = INVALID_POSITION;
agent->z = 0.0f;
agent->heading = 0.0f;
agent->heading_x = 1.0f;
agent->heading_y = 0.0f;
if (!entity_has_valid_timestep(agent, t)) {
set_entity_invalid_state(agent);
return;
}
agent->x = agent->traj_x[t];
Expand All @@ -947,6 +1071,8 @@ void move_expert(Drive *env, float *actions, int agent_idx) {
agent->heading = agent->traj_heading[t];
agent->heading_x = cosf(agent->heading);
agent->heading_y = sinf(agent->heading);
agent->valid = 1;
update_replay_velocity(env, agent, t);
}

bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) {
Expand Down Expand Up @@ -1335,13 +1461,17 @@ bool should_control_agent(Drive *env, int agent_idx, int control_limit) {
}

Entity *entity = &env->entities[agent_idx];
bool is_vehicle = (entity->type == VEHICLE);
bool is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST);

if (env->control_mode == CONTROL_SDC_ONLY) {
return agent_idx == env->sdc_track_index;
return agent_idx == env->sdc_track_index && (env->dynamics_model != CLASSIC || is_vehicle);
}

if (env->dynamics_model == CLASSIC && !is_vehicle) {
return false;
}

bool is_vehicle = (entity->type == VEHICLE);
bool is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST);
bool type_is_valid = false;

switch (env->control_mode) {
Expand Down Expand Up @@ -1696,11 +1826,12 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) {
// Update speed with acceleration
signed_speed = signed_speed + acceleration * env->dt;
signed_speed = clipSpeed(signed_speed);
// Compute yaw rate
float beta = tanh(.5 * tanf(steering));
float wheelbase = fmaxf(agent->wheelbase, 1e-3f);
float tan_steering = tanf(steering);
float beta = atanf(0.5f * tan_steering);

// New heading
float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->length;
float yaw_rate = (signed_speed * cosf(beta) * tan_steering) / wheelbase;

// New velocity
float new_vx = signed_speed * cosf(heading + beta);
Expand All @@ -1709,7 +1840,7 @@ void move_dynamics(Drive *env, int action_idx, int agent_idx) {
// Update position
x = x + (new_vx * env->dt);
y = y + (new_vy * env->dt);
heading = heading + yaw_rate * env->dt;
heading = normalize_heading(heading + yaw_rate * env->dt);

// Apply updates to the agent's state
agent->x = x;
Expand Down Expand Up @@ -2262,31 +2393,31 @@ static void override_action_with_expert(Drive *env, int action_idx, int agent_id
int t = env->timestep;

if (env->dynamics_model == CLASSIC) {
if (t < 0 || t >= agent->array_size)
return;
if (agent->expert_accel == NULL || agent->expert_steering == NULL)
return;
if (agent->expert_accel[t] == -1.0f || agent->expert_steering[t] == -1.0f)
float accel = 0.0f;
float steer = 0.0f;
if (!infer_classic_expert_action_from_trajectory(env, agent, t, &accel, &steer)) {
set_classic_noop_action(env, action_idx);
return;
}

if (env->action_type == 1) { // continuous
float (*action_array_f)[2] = (float (*)[2])env->actions;
action_array_f[action_idx][0] = agent->expert_accel[t] / ACCEL_MAX;
action_array_f[action_idx][1] = agent->expert_steering[t] / STEER_MAX;
action_array_f[action_idx][0] = accel / ACCEL_MAX;
action_array_f[action_idx][1] = steer / STEER_MAX;
} else { // discrete
int best_accel_idx = 0;
float min_accel_diff = fabsf(agent->expert_accel[t] - ACCELERATION_VALUES[0]);
float min_accel_diff = fabsf(accel - ACCELERATION_VALUES[0]);
for (int j = 1; j < NUM_ACCEL_BINS; j++) {
float diff = fabsf(agent->expert_accel[t] - ACCELERATION_VALUES[j]);
float diff = fabsf(accel - ACCELERATION_VALUES[j]);
if (diff < min_accel_diff) {
min_accel_diff = diff;
best_accel_idx = j;
}
}
int best_steer_idx = 0;
float min_steer_diff = fabsf(agent->expert_steering[t] - STEERING_VALUES[0]);
float min_steer_diff = fabsf(steer - STEERING_VALUES[0]);
for (int j = 1; j < NUM_STEER_BINS; j++) {
float diff = fabsf(agent->expert_steering[t] - STEERING_VALUES[j]);
float diff = fabsf(steer - STEERING_VALUES[j]);
if (diff < min_steer_diff) {
min_steer_diff = diff;
best_steer_idx = j;
Expand Down Expand Up @@ -2534,6 +2665,7 @@ void c_collect_expert_data(Drive *env, float *expert_actions_discrete_out, float
int ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES;
int max_obs = ego_dim + PARTNER_FEATURES * (MAX_AGENTS - 1) + ROAD_FEATURES * MAX_ROAD_SEGMENT_OBSERVATIONS;
int original_timestep = env->timestep;
int original_control_mode = env->control_mode;
int is_delta = (env->dynamics_model == DELTA_LOCAL);

// Action dimensions per agent
Expand All @@ -2542,6 +2674,7 @@ void c_collect_expert_data(Drive *env, float *expert_actions_discrete_out, float

// Reset agents to start of trajectory
env->timestep = env->init_steps;
env->control_mode = CONTROL_REPLAY_LOGS;
set_start_position(env);
compute_observations(env);

Expand Down Expand Up @@ -2621,14 +2754,11 @@ void c_collect_expert_data(Drive *env, float *expert_actions_discrete_out, float

} else {
// Classic dynamics: joint accel × steer action
bool is_valid =
(t < agent->array_size && agent->expert_accel != NULL && agent->expert_steering != NULL &&
agent->expert_accel[t] != -1.0f && agent->expert_steering[t] != -1.0f);
float accel = 0.0f;
float steer = 0.0f;
bool is_valid = infer_classic_expert_action_from_trajectory(env, agent, t, &accel, &steer);

if (is_valid) {
float accel = agent->expert_accel[t];
float steer = agent->expert_steering[t];

expert_actions_continuous_out[cont_off + 0] = accel;
expert_actions_continuous_out[cont_off + 1] = steer;

Expand Down Expand Up @@ -2676,6 +2806,7 @@ void c_collect_expert_data(Drive *env, float *expert_actions_discrete_out, float

// Restore original state
env->timestep = original_timestep;
env->control_mode = original_control_mode;
set_start_position(env);
compute_observations(env);
}
Expand Down
Loading