diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h index cb40b949c1..ef16d958f0 100644 --- a/pufferlib/ocean/drive/datatypes.h +++ b/pufferlib/ocean/drive/datatypes.h @@ -183,7 +183,7 @@ struct Agent { int closest_path_idx_wp; // Metrics and status tracking - float metrics_array[10]; // [collision, offroad, red_light, reached_goal, lane_dist, + float metrics_array[11]; // [collision, offroad, red_light, reached_goal, lane_dist, // lane_angle, comfort_violation, velocity_progress, speed_limit, avg_displacement_error] int collision_state; int aabb_collision_state; diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index e00342c1e5..95b0231311 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -60,6 +60,13 @@ #define CONTROL_WOSAC 2 #define CONTROL_SDC_ONLY 3 +// Lane selection scoring +#define LANE_SELECTION_DISTANCE_WEIGHT 0.6f +#define LANE_SELECTION_HEADING_WEIGHT 0.4f +#define LANE_DISTANCE_NORMALIZATION 4.0f +#define LANE_SWITCH_THRESHOLD 0.05f // Hysteresis: new lane must be 5% better to switch +#define LANE_ALIGN_COS_THRESHOLD 0.5f + // Minimum distance to goal position #define MIN_DISTANCE_TO_GOAL 2.0f @@ -78,8 +85,15 @@ // Metrics array indices #define COLLISION_IDX 0 #define OFFROAD_IDX 1 -#define REACHED_GOAL_IDX 2 -#define LANE_ALIGNED_IDX 3 +#define RED_LIGHT_IDX 2 +#define REACHED_GOAL_IDX 3 +#define LANE_DIST_IDX 4 +#define LANE_ANGLE_IDX 5 +#define COMFORT_VIOLATION_IDX 6 +#define VELOCITY_PROGRESS_IDX 7 +#define SPEED_LIMIT_IDX 8 +#define AVG_DISPLACEMENT_ERROR_IDX 9 +#define LANE_ALIGNED_IDX 10 // Grid cell size #define GRID_CELL_SIZE 5.0f @@ -96,9 +110,11 @@ #define ROAD_FEATURES_ONEHOT 14 #define PARTNER_FEATURES 8 +#define MAX_CHECKED_LANES 32 + // Ego features depend on dynamics model -#define EGO_FEATURES_CLASSIC 8 -#define EGO_FEATURES_JERK 11 +#define EGO_FEATURES_CLASSIC 11 +#define EGO_FEATURES_JERK 14 // Observation normalization constants #define MAX_SPEED 100.0f @@ -329,6 +345,11 @@ float normalize_heading(float heading) { return heading; } +static float compute_heading_diff(float heading1, float heading2) { + float heading_diff = heading1 - heading2; + return normalize_heading(heading_diff); +} + // Note: added for 2.5D typedef struct { float dis; @@ -850,6 +871,37 @@ void load_map_binary(const char *filename, Drive *env) { // ======================================== // void compute_multi_segment_alignment(void){} +static float compute_multi_segment_alignment(RoadMapElement *element, int center_seg_idx) { + // NOTE: This function returns the average heading in radians for a lane segment, + // with more weight given to the center segment. + + float avg_heading = 0.0f; + float total_weight = 0.0f; + + int start = (center_seg_idx > 0) ? (center_seg_idx - 1) : center_seg_idx; + int end = (center_seg_idx < element->segment_length - 2) ? (center_seg_idx + 1) : (element->segment_length - 2); + + for (int seg_idx = start; seg_idx <= end; seg_idx++) { + if (seg_idx < 0 || seg_idx >= element->segment_length - 1) + continue; + + float dx = element->x[seg_idx + 1] - element->x[seg_idx]; + float dy = element->y[seg_idx + 1] - element->y[seg_idx]; + float seg_heading = atan2f(dy, dx); + + float weight = (seg_idx == center_seg_idx) ? 2.0f : 1.0f; + + if (total_weight == 0.0f) { + avg_heading = seg_heading; + } else { + float angle_diff = compute_heading_diff(seg_heading, avg_heading); + avg_heading += weight * angle_diff / (total_weight + weight); + } + total_weight += weight; + } + + return avg_heading; +} // void get_drivable_lane_indices(void){} @@ -860,6 +912,60 @@ void load_map_binary(const char *filename, Drive *env) { // void compute_remaining_lane_distance(void){} // void find_closest_segment_on_lane(void){} +static float find_closest_segment_on_lane(RoadMapElement *lane, float agent_x, float agent_y, int *out_segment_idx) { + int num_segments = lane->segment_length - 1; + if (num_segments < 1) { + *out_segment_idx = 0; + return 1e9f; + } + + float min_dist_sq = 1e18f; + int closest_idx = 0; + float closest_cross = 0.0f; + + for (int seg_idx = 0; seg_idx < num_segments; seg_idx++) { + float seg_start_x = lane->x[seg_idx]; + float seg_start_y = lane->y[seg_idx]; + float seg_end_x = lane->x[seg_idx + 1]; + float seg_end_y = lane->y[seg_idx + 1]; + + float seg_dx = seg_end_x - seg_start_x; + float seg_dy = seg_end_y - seg_start_y; + float seg_length_sq = seg_dx * seg_dx + seg_dy * seg_dy; + + float to_agent_x = agent_x - seg_start_x; + float to_agent_y = agent_y - seg_start_y; + + // cross > 0 means agent is left of lane direction + float cross = seg_dx * to_agent_y - seg_dy * to_agent_x; + + float dist_sq; + if (seg_length_sq > 1e-6f) { + float t = (to_agent_x * seg_dx + to_agent_y * seg_dy) / seg_length_sq; + if (t <= 0.0f) { + dist_sq = to_agent_x * to_agent_x + to_agent_y * to_agent_y; + } else if (t >= 1.0f) { + float dx = agent_x - seg_end_x; + float dy = agent_y - seg_end_y; + dist_sq = dx * dx + dy * dy; + } else { + dist_sq = (cross * cross) / seg_length_sq; + } + } else { + dist_sq = to_agent_x * to_agent_x + to_agent_y * to_agent_y; + } + + if (dist_sq < min_dist_sq) { + min_dist_sq = dist_sq; + closest_idx = seg_idx; + closest_cross = cross; + } + } + + *out_segment_idx = closest_idx; + float abs_dist = sqrtf(min_dist_sq); + return (closest_cross >= 0.0f) ? -abs_dist : abs_dist; +} // void compute_log_trajectory_distance(void){} @@ -1100,7 +1206,10 @@ void reset_agent_metrics(Drive *env, int agent_idx) { Agent *agent = &env->agents[agent_idx]; agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad + agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; // goal reached agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle + agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center agent->collision_state = 0; agent->aabb_collision_state = 0; } @@ -1155,6 +1264,8 @@ void set_start_position(Drive *env) { e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned + e->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle + e->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center e->respawn_timestep = -1; e->stopped = 0; e->removed = 0; @@ -1530,8 +1641,11 @@ void compute_agent_metrics(Drive *env, int agent_idx) { float sin_heading = sinf(agent->sim_heading); float min_distance = (float)INT16_MAX; - int closest_lane_entity_idx = -1; - int closest_lane_geometry_idx = -1; + float best_score = 1e9f; + int best_candidate_entity_idx = -1; + int best_candidate_geometry_idx = -1; + float best_candidate_signed_lane_distance = 0.0f; + float best_candidate_lane_heading = 0.0f; float corners[4][2]; for (int i = 0; i < 4; i++) { @@ -1541,6 +1655,12 @@ void compute_agent_metrics(Drive *env, int agent_idx) { agent->sim_y + (offsets[i][0] * half_length * sin_heading + offsets[i][1] * half_width * cos_heading); } int list_size = 0; + // Vehicle-width based distance threshold (3x width) + float max_distance_threshold = 3.0f * agent->sim_width; + + // Track already-checked drivable lanes to avoid redundant processing + int checked_lanes[MAX_CHECKED_LANES]; + int num_checked_lanes = 0; GridMapEntity *entity_list = checkNeighbors(env, agent->sim_x, agent->sim_y, collision_offsets, COLLISION_RANGE * COLLISION_RANGE, &list_size); for (int i = 0; i < list_size; i++) { @@ -1550,6 +1670,7 @@ void compute_agent_metrics(Drive *env, int agent_idx) { continue; RoadMapElement *entity; entity = &env->road_elements[entity_list[i].entity_idx]; + int entity_idx = entity_list[i].entity_idx; // Check for offroad collision with road edges if (entity->type == ROAD_EDGE) { @@ -1571,41 +1692,88 @@ void compute_agent_metrics(Drive *env, int agent_idx) { break; // Find closest point on the road centerline to the agent - if (entity->type == ROAD_LANE) { - int entity_idx = entity_list[i].entity_idx; - int geometry_idx = entity_list[i].geometry_idx; + if (is_drivable_road_lane(entity->type) || entity->type == ROAD_LANE) { + // Check if we've already processed this lane (skip duplicates) + int already_checked = 0; + for (int c = 0; c < num_checked_lanes; c++) { + if (checked_lanes[c] == entity_idx) { + already_checked = 1; + break; + } + } + if (already_checked) + continue; - float start[2] = {entity->x[geometry_idx], entity->y[geometry_idx]}; - float end[2] = {entity->x[geometry_idx + 1], entity->y[geometry_idx + 1]}; + // Mark this lane as checked + if (num_checked_lanes < MAX_CHECKED_LANES) { + checked_lanes[num_checked_lanes++] = entity_idx; + } - float dist = point_to_segment_distance_2d(agent->sim_x, agent->sim_y, start[0], start[1], end[0], end[1]); - float heading_diff = fabsf(atan2f(end[1] - start[1], end[0] - start[0]) - agent->sim_heading); + // Find closest segment on this lane (returns signed distance) + int closest_segment_idx; + float signed_dist = find_closest_segment_on_lane(entity, agent->sim_x, agent->sim_y, &closest_segment_idx); + float abs_dist = fabsf(signed_dist); - // Normalize heading difference to [0, pi] - if (heading_diff > M_PI) - heading_diff = 2.0f * M_PI - heading_diff; + if (abs_dist > max_distance_threshold) + continue; // Skip this lane, too far away - // Penalize if heading differs by more than 30 degrees - if (heading_diff > (M_PI / 6.0f)) - dist += 3.0f; + // Compute lane heading using multi-segment alignment + float avg_lane_heading = compute_multi_segment_alignment(entity, closest_segment_idx); - if (dist < min_distance) { - min_distance = dist; - closest_lane_entity_idx = entity_idx; - closest_lane_geometry_idx = geometry_idx; + // Compute heading alignment penalty (0.0 = perfect, 1.0 = opposite) + float heading_diff = compute_heading_diff(agent->sim_heading, avg_lane_heading); + float heading_penalty = fabsf(heading_diff) / M_PI; // Normalize to [0, 1] + + // Normalize distance for scoring + float distance_penalty = abs_dist / LANE_DISTANCE_NORMALIZATION; + + // Combined score using defined weights + float score = + LANE_SELECTION_DISTANCE_WEIGHT * distance_penalty + LANE_SELECTION_HEADING_WEIGHT * heading_penalty; + + // Hysteresis: penalize switching away from current lane + if (agent->current_lane_index != entity_idx && agent->current_lane_index != -1) { + score += LANE_SWITCH_THRESHOLD; + } + + // Track best candidate + if (score < best_score) { + min_distance = abs_dist; + best_score = score; + best_candidate_entity_idx = entity_idx; + best_candidate_geometry_idx = closest_segment_idx; + best_candidate_signed_lane_distance = signed_dist; + best_candidate_lane_heading = avg_lane_heading; } } } + // Update lane alignment metric (running average) + if (best_candidate_entity_idx != -1) { + agent->current_lane_index = best_candidate_entity_idx; + agent->current_lane_geometry_idx = best_candidate_geometry_idx; + + // Lane distance and angle metrics (GIGAFLOW Frenet coordinates) + // x_f = lateral offset from lane center (left = negative, right = positive) + agent->metrics_array[LANE_DIST_IDX] = best_candidate_signed_lane_distance; + // theta_f = angle relative to lane heading + float theta_f = compute_heading_diff(agent->sim_heading, best_candidate_lane_heading); + agent->metrics_array[LANE_ANGLE_IDX] = cosf(theta_f); // Store cos(θ_f) + } else { + // Agent not on any lane - use "bad" values to indicate offroad state + agent->current_lane_index = -1; + agent->current_lane_geometry_idx = -1; + agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // Max distance (far from lane) + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; + } + // check if aligned with closest lane and set current lane - // 4.0m threshold: agents more than 4 meters from any lane are considered off-road - if (min_distance > 4.0f || closest_lane_entity_idx == -1) { + if (min_distance > max_distance_threshold || best_candidate_entity_idx == -1) { agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; agent->current_lane_index = -1; } else { - agent->current_lane_index = closest_lane_entity_idx; - int lane_aligned = - check_lane_aligned(agent, &env->road_elements[closest_lane_entity_idx], closest_lane_geometry_idx); + agent->current_lane_index = best_candidate_entity_idx; + int lane_aligned = (fabs(agent->metrics_array[LANE_ANGLE_IDX]) > 0.965) ? 1 : 0; agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned; } @@ -1663,6 +1831,13 @@ void compute_observations(Drive *env) { float v_dot_heading = ego_entity->sim_vx * cos_heading + ego_entity->sim_vy * sin_heading; float signed_speed = copysignf(speed_magnitude, v_dot_heading); + // Adding speed limit calculation + float speed_limit = 20.0f; + // We need to add speed limit calculation + + // Adding lane angle and center information + float lane_center_dist = ego_entity->metrics_array[LANE_DIST_IDX] / LANE_DISTANCE_NORMALIZATION; + lane_center_dist = fmaxf(-1.0f, fminf(1.0f, lane_center_dist)); // Set goal distances float goal_x = ego_entity->goal_position_x - ego_entity->sim_x; float goal_y = ego_entity->goal_position_y - ego_entity->sim_y; @@ -1688,8 +1863,14 @@ void compute_observations(Drive *env) { (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3]; obs[9] = ego_entity->a_lat / JERK_LAT[2]; obs[10] = (ego_entity->respawn_timestep != -1) ? 1 : 0; + obs[11] = fminf(speed_limit / MAX_SPEED, 1.0f); + obs[12] = lane_center_dist; + obs[13] = ego_entity->metrics_array[LANE_ANGLE_IDX]; } else { obs[7] = (ego_entity->respawn_timestep != -1) ? 1 : 0; + obs[8] = fminf(speed_limit / MAX_SPEED, 1.0f); + obs[9] = lane_center_dist; + obs[10] = ego_entity->metrics_array[LANE_ANGLE_IDX]; } // Relative Pos of other cars @@ -1836,6 +2017,8 @@ void respawn_agent(Drive *env, int agent_idx) { agent->metrics_array[OFFROAD_IDX] = 0.0f; agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle + agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center agent->respawn_timestep = env->timestep; agent->collided_before_goal = 0; @@ -2106,6 +2289,8 @@ void c_reset(Drive *env) { agent->metrics_array[OFFROAD_IDX] = 0.0f; agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle + agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center agent->stopped = 0; agent->removed = 0; diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 8119415eec..252097b4be 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -24,7 +24,7 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs): self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories # Determine ego dimension from environment's dynamics model - self.ego_dim = 11 if env.dynamics_model == "jerk" else 8 + self.ego_dim = env.ego_features self.ego_encoder = nn.Sequential( pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)),