From 9beb5d97a3c738404a559868103dcba9c828ffff Mon Sep 17 00:00:00 2001 From: mpragnay Date: Wed, 4 Mar 2026 18:22:26 -0500 Subject: [PATCH 1/3] Changed some constants to be read from C bindings --- pufferlib/ocean/torch.py | 390 +++++++++++++++++++++++++++++++++++++++ scripts/export_onnx.py | 56 +++--- 2 files changed, 414 insertions(+), 32 deletions(-) diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 925da37984..310a65372e 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -107,3 +107,393 @@ def decode_actions(self, flat_hidden): value = self.value_fn(flat_hidden) return action, value + + def observation_spec(self): + """Return structured observation specification. + + Documents every feature by name, index, normalization, and expected range. + This is the single source of truth for building input adapters. + + The flat observation vector is laid out as: + [ego_features | reward_conditioning (optional) | partner_features | road_features] + """ + import math + + # --- Constants matching C defines in drive.h / datatypes.h --- + MAX_SPEED = 100.0 + MAX_VEH_WIDTH = 15.0 + MAX_VEH_LEN = 30.0 + LANE_DIST_NORM = 4.0 + SPEED_LIMIT = 20.0 + JERK_LONG_MIN = -15.0 # braking + JERK_LONG_MAX = 4.0 # acceleration + JERK_LAT_MAX = 4.0 + MAX_ROAD_SEGMENT_LENGTH = 100.0 + MAX_ROAD_SCALE = 100.0 + NUM_REWARD_COEFS = 16 + GOAL_POSITION_SCALE = 0.005 # 1/200m — goal relative position normalization + RELATIVE_POSITION_SCALE = 0.02 # 1/50m — partner & road relative position normalization + + # Determine dynamics model from ego_dim + # Jerk: 16 base, Classic: 13 base + base_ego = self.ego_dim + has_conditioning = base_ego > 16 # if > 16, conditioning is appended + if has_conditioning: + base_ego_no_cond = base_ego - NUM_REWARD_COEFS + else: + base_ego_no_cond = base_ego + is_jerk = base_ego_no_cond == 16 + + # --- Ego features --- + ego_spec = [ + { + "name": "goal_rel_x", + "norm": f"* {GOAL_POSITION_SCALE}", + "range": [-1, 1], + "desc": 
"Goal X in ego frame (1/200m scale)", + }, + { + "name": "goal_rel_y", + "norm": f"* {GOAL_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Goal Y in ego frame (1/200m scale)", + }, + { + "name": "goal_rel_z", + "norm": f"* {GOAL_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Goal Z in ego frame (1/200m scale)", + }, + {"name": "signed_speed", "norm": f"/ {MAX_SPEED}", "range": [-1, 1], "desc": "Signed speed along heading"}, + {"name": "vehicle_width", "norm": f"/ {MAX_VEH_WIDTH}", "range": [0, 1], "desc": "Ego vehicle width"}, + {"name": "vehicle_length", "norm": f"/ {MAX_VEH_LEN}", "range": [0, 1], "desc": "Ego vehicle length"}, + {"name": "collision_flag", "norm": "binary", "range": [0, 1], "desc": "1 if currently collided"}, + ] + if is_jerk: + # Jerk dynamics state variable limits (from c_step clipping in drive.h): + # steering_angle: clipped to [-0.55, 0.55] rad (~[-31.5°, 31.5°]) + # a_long: clipped to [-5.0, 2.5] m/s² + # a_lat: clipped to [-4.0, 4.0] m/s² + # Jerk action space: + # JERK_LONG = [-15.0, -4.0, 0.0, 4.0] m/s³ + # JERK_LAT = [-4.0, 0.0, 4.0] m/s³ + ego_spec += [ + { + "name": "steering_angle", + "norm": f"/ π ({math.pi:.4f})", + "range": [-0.175, 0.175], + "physical_range": [-0.55, 0.55], + "unit": "rad", + "desc": "Current steering angle (clipped ±0.55 rad ≈ ±31.5°)", + }, + { + "name": "a_long", + "norm": f"asymmetric: /{-JERK_LONG_MIN} if neg, /{JERK_LONG_MAX} if pos", + "range": [-0.333, 0.625], + "physical_range": [-5.0, 2.5], + "unit": "m/s²", + "desc": "Longitudinal acceleration (clipped [-5.0, 2.5] m/s², asymmetric norm)", + }, + { + "name": "a_lat", + "norm": f"/ {JERK_LAT_MAX}", + "range": [-1, 1], + "physical_range": [-4.0, 4.0], + "unit": "m/s²", + "desc": "Lateral acceleration (clipped ±4.0 m/s²)", + }, + { + "name": "respawned_flag", + "norm": "binary", + "range": [0, 1], + "desc": "1 while agent is in respawn transit (respawn_timestep != -1). " + "GOAL_RESPAWN(0): set to 1 on respawn, cleared to 0 once agent resumes. 
" + "GOAL_GENERATE_NEW(1): always 0 (agent keeps driving, never respawns). " + "GOAL_STOP(2): always 0 (agent stops at goal, never respawns).", + }, + { + "name": "goal_speed_min", + "norm": f"/ {MAX_SPEED}", + "range": [0, 1], + "desc": "Min goal speed (0 if disabled)", + }, + { + "name": "goal_speed_max", + "norm": f"/ {MAX_SPEED}", + "range": [0, 1], + "desc": "Max goal speed (0 if disabled)", + }, + {"name": "speed_limit", "norm": f"/ {MAX_SPEED}, clamped", "range": [0, 1], "desc": "Road speed limit"}, + { + "name": "lane_center_dist", + "norm": f"/ {LANE_DIST_NORM}, clamped", + "range": [-1, 1], + "desc": "Signed distance from lane center", + }, + {"name": "lane_angle_cos", "norm": "raw", "range": [-1, 1], "desc": "cos(heading diff from lane)"}, + ] + else: # Classic + ego_spec += [ + { + "name": "respawned_flag", + "norm": "binary", + "range": [0, 1], + "desc": "1 while agent is in respawn transit (respawn_timestep != -1). " + "GOAL_RESPAWN(0): set to 1 on respawn, cleared to 0 once agent resumes. " + "GOAL_GENERATE_NEW(1): always 0 (agent keeps driving, never respawns). 
" + "GOAL_STOP(2): always 0 (agent stops at goal, never respawns).", + }, + { + "name": "goal_speed_min", + "norm": f"/ {MAX_SPEED}", + "range": [0, 1], + "desc": "Min goal speed (0 if disabled)", + }, + { + "name": "goal_speed_max", + "norm": f"/ {MAX_SPEED}", + "range": [0, 1], + "desc": "Max goal speed (0 if disabled)", + }, + {"name": "speed_limit", "norm": f"/ {MAX_SPEED}, clamped", "range": [0, 1], "desc": "Road speed limit"}, + { + "name": "lane_center_dist", + "norm": f"/ {LANE_DIST_NORM}, clamped", + "range": [-1, 1], + "desc": "Signed distance from lane center", + }, + {"name": "lane_angle_cos", "norm": "raw", "range": [-1, 1], "desc": "cos(heading diff from lane)"}, + ] + + # --- Reward conditioning (optional) --- + reward_coef_names = [ + "goal_radius", + "collision", + "offroad", + "comfort", + "lane_align", + "lane_center", + "velocity", + "traffic_light", + "center_bias", + "vel_align", + "overspeed", + "timestep", + "reverse", + "throttle", + "steer", + "acc", + ] + conditioning_spec = None + if has_conditioning: + conditioning_spec = [ + { + "name": f"reward_coef_{name}", + "norm": "tanh-normalized", + "range": [-1, 1], + "desc": f"Reward conditioning coef for {name}", + } + for name in reward_coef_names + ] + + # --- Partner features (per object, 8 features) --- + partner_spec = [ + { + "name": "rel_x", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Partner X in ego frame (1/50m scale)", + }, + { + "name": "rel_y", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Partner Y in ego frame (1/50m scale)", + }, + { + "name": "rel_z", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Partner Z in ego frame (1/50m scale)", + }, + {"name": "width", "norm": f"/ {MAX_VEH_WIDTH}", "range": [0, 1], "desc": "Partner vehicle width"}, + {"name": "length", "norm": f"/ {MAX_VEH_LEN}", "range": [0, 1], "desc": "Partner vehicle length"}, + {"name": "rel_heading_cos", "norm": "raw", "range": 
[-1, 1], "desc": "cos(partner_heading - ego_heading)"}, + {"name": "rel_heading_sin", "norm": "raw", "range": [-1, 1], "desc": "sin(partner_heading - ego_heading)"}, + {"name": "signed_speed", "norm": f"/ {MAX_SPEED}", "range": [-1, 1], "desc": "Partner signed speed"}, + ] + + # --- Road features (per segment, 8 features) --- + road_spec = [ + { + "name": "rel_x", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Segment midpoint X in ego frame (1/50m scale)", + }, + { + "name": "rel_y", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Segment midpoint Y in ego frame (1/50m scale)", + }, + { + "name": "rel_z", + "norm": f"* {RELATIVE_POSITION_SCALE}", + "range": [-1, 1], + "desc": "Segment midpoint Z in ego frame (1/50m scale)", + }, + {"name": "length", "norm": f"/ {MAX_ROAD_SEGMENT_LENGTH}", "range": [0, 1], "desc": "Segment half-length"}, + {"name": "width", "norm": f"/ {MAX_ROAD_SCALE}", "range": [0, 1], "desc": "Segment width (hardcoded 0.1)"}, + {"name": "cos_angle", "norm": "raw", "range": [-1, 1], "desc": "cos(segment direction - ego heading)"}, + {"name": "sin_angle", "norm": "raw", "range": [-1, 1], "desc": "sin(segment direction - ego heading)"}, + { + "name": "road_type", + "norm": "categorical: type - 4", + "range": [0, 2], + "desc": "Road element type: 0=LANE (drivable surface center-line), " + "1=LINE (painted lane marking/divider between lanes), " + "2=EDGE (road boundary/curb)", + }, + ] + + return { + "layout": "[ego | reward_conditioning? 
| partners | road_segments]", + "total_dim": self.observation_size, + "ego": { + "offset": 0, + "count": 1, + "features_per_object": base_ego_no_cond, + "total_dim": base_ego_no_cond, + "features": ego_spec, + }, + "reward_conditioning": { + "offset": base_ego_no_cond, + "count": NUM_REWARD_COEFS if has_conditioning else 0, + "total_dim": NUM_REWARD_COEFS if has_conditioning else 0, + "features": conditioning_spec, + } + if has_conditioning + else None, + "partners": { + "offset": self.ego_dim, + "count": self.max_partner_objects, + "features_per_object": self.partner_features, + "total_dim": self.max_partner_objects * self.partner_features, + "features": partner_spec, + }, + "road_segments": { + "offset": self.ego_dim + self.max_partner_objects * self.partner_features, + "count": self.max_road_objects, + "features_per_object": self.road_features, + "total_dim": self.max_road_objects * self.road_features, + "features": road_spec, + }, + } + + @staticmethod + def build_structured_observation(dynamics_model="classic", reward_conditioning=False, batch_size=1): + """Build a physically valid dummy observation tensor for export/testing. + + Reads observation dimensions directly from the C binding constants. + All values are within the ranges that compute_observations() in C would produce. 
+ + Args: + dynamics_model: "classic" or "jerk" + reward_conditioning: whether reward conditioning coefficients are appended to ego + batch_size: batch dimension + """ + import math + from pufferlib.ocean.drive import binding + + # --- Dimensions from C binding --- + max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS + max_partner_objects = binding.MAX_AGENTS - 1 + partner_features = binding.PARTNER_FEATURES + road_features = binding.ROAD_FEATURES + + if dynamics_model == "jerk": + ego_dim = binding.EGO_FEATURES_JERK_CONDITIONING if reward_conditioning else binding.EGO_FEATURES_JERK + else: + ego_dim = binding.EGO_FEATURES_CLASSIC_CONDITIONING if reward_conditioning else binding.EGO_FEATURES_CLASSIC + + is_jerk = dynamics_model == "jerk" + has_conditioning = reward_conditioning + base_ego = binding.EGO_FEATURES_JERK if is_jerk else binding.EGO_FEATURES_CLASSIC + + # --- Constants matching C normalization --- + MAX_SPEED = 100.0 + MAX_VEH_WIDTH = 15.0 + MAX_VEH_LEN = 30.0 + SPEED_LIMIT = 20.0 + NUM_REWARD_COEFS = 16 + GOAL_POSITION_SCALE = 0.005 # 1/200m — goal relative position normalization + RELATIVE_POSITION_SCALE = 0.02 # 1/50m — partner & road relative position normalization + + # --- Ego features --- + ego = torch.zeros(batch_size, ego_dim) + # Goal relative position (normalized by *GOAL_POSITION_SCALE, so raw ~[-200, 200] → [-1, 1]) + ego[:, 0] = 30.0 * GOAL_POSITION_SCALE # goal_rel_x: ~30m ahead + ego[:, 1] = 2.0 * GOAL_POSITION_SCALE # goal_rel_y: ~2m lateral + ego[:, 2] = 0.0 # goal_rel_z + ego[:, 3] = 5.0 / MAX_SPEED # signed_speed: 5 m/s + ego[:, 4] = 2.0 / MAX_VEH_WIDTH # vehicle_width: 2m + ego[:, 5] = 4.5 / MAX_VEH_LEN # vehicle_length: 4.5m + ego[:, 6] = 0.0 # collision_flag: no collision + + if is_jerk: + ego[:, 7] = 0.0 # steering_angle: straight + ego[:, 8] = 0.0 # a_long: no acceleration + ego[:, 9] = 0.0 # a_lat: no lateral accel + ego[:, 10] = 0.0 # respawned_flag + ego[:, 11] = 0.0 # goal_speed_min (disabled) + ego[:, 12] = 10.0 / 
MAX_SPEED # goal_speed_max + ego[:, 13] = min(SPEED_LIMIT / MAX_SPEED, 1.0) # speed_limit + ego[:, 14] = 0.05 # lane_center_dist: slightly off-center + ego[:, 15] = 0.98 # lane_angle_cos: well-aligned + else: + ego[:, 7] = 0.0 # respawned_flag + ego[:, 8] = 0.0 # goal_speed_min + ego[:, 9] = 10.0 / MAX_SPEED # goal_speed_max + ego[:, 10] = min(SPEED_LIMIT / MAX_SPEED, 1.0) # speed_limit + ego[:, 11] = 0.05 # lane_center_dist + ego[:, 12] = 0.98 # lane_angle_cos + + # Reward conditioning: tanh-normalized values in [-1, 1] + if has_conditioning: + cond_offset = base_ego + for c in range(NUM_REWARD_COEFS): + ego[:, cond_offset + c] = 0.0 # neutral conditioning + + # --- Partner features (mostly empty = no visible partners) --- + partner_dim = max_partner_objects * partner_features + partners = torch.zeros(batch_size, partner_dim) + # Place one visible partner ~10m ahead, 3m to the right + partners[:, 0] = 10.0 * RELATIVE_POSITION_SCALE # rel_x + partners[:, 1] = 3.0 * RELATIVE_POSITION_SCALE # rel_y + partners[:, 2] = 0.0 # rel_z + partners[:, 3] = 2.0 / MAX_VEH_WIDTH # width + partners[:, 4] = 4.5 / MAX_VEH_LEN # length + partners[:, 5] = 1.0 # rel_heading_cos (same direction) + partners[:, 6] = 0.0 # rel_heading_sin + partners[:, 7] = 8.0 / MAX_SPEED # signed_speed + + # --- Road features --- + road_dim = max_road_objects * road_features + roads = torch.zeros(batch_size, road_dim) + # Place a few road segments nearby + for seg in range(min(5, max_road_objects)): + base = seg * road_features + dist = 3.0 + seg * 4.0 # stagger segments ahead + roads[:, base + 0] = dist * RELATIVE_POSITION_SCALE # rel_x + roads[:, base + 1] = 0.5 * RELATIVE_POSITION_SCALE # rel_y: slightly off-center + roads[:, base + 2] = 0.0 # rel_z + roads[:, base + 3] = 5.0 / 100.0 # length + roads[:, base + 4] = 0.1 / 100.0 # width (hardcoded 0.1 in C) + angle = 0.05 * seg + roads[:, base + 5] = math.cos(angle) # cos_angle + roads[:, base + 6] = math.sin(angle) # sin_angle + roads[:, base + 7] = 
0.0 # road_type: lane (type 4 - 4 = 0) + + obs = torch.cat([ego, partners, roads], dim=1) + return obs diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index aecda002a7..43b8772cf5 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -1,7 +1,8 @@ """Export a trained PufferDrive policy checkpoint (.pt) to ONNX format. -The exported ONNX model accepts an observation vector (see in torch.py for the exact layout), -plus LSTM hidden states, and produces action logits, +The exported ONNX model accepts an observation vector (see +observation_spec() in torch.py for the exact layout and feature +descriptions), plus LSTM hidden states, and produces action logits, a value estimate, and updated LSTM states. Usage: @@ -117,36 +118,27 @@ def export_to_onnx(verify=True): drive_policy = policy if hasattr(drive_policy, "ego_dim"): - # Construct valid dummy observation for Drive policy - # Retrieve needed dimensions - ego_dim = drive_policy.ego_dim - max_partner_objects = drive_policy.max_partner_objects - partner_features = drive_policy.partner_features - max_road_objects = drive_policy.max_road_objects - road_features = drive_policy.road_features - - partner_dim = max_partner_objects * partner_features - road_dim = max_road_objects * road_features - - # Random parts - dummy_ego = torch.randn(batch_size, ego_dim) - dummy_partner = torch.randn(batch_size, partner_dim) - - # Road part: continuous features + categorical feature - road_cont_dim = road_features - 1 - - # (Batch, MaxObjects, Feats-1) - dummy_road_cont = torch.randn(batch_size, max_road_objects, road_cont_dim) - - # (Batch, MaxObjects, 1) - valid categorical values [0, 6] - # Ensure it's 0-6 range. 7 is num_classes. 
- dummy_road_cat = torch.randint(0, 7, (batch_size, max_road_objects, 1)).float() - - # Concatenate and flatten - dummy_road_objs = torch.cat([dummy_road_cont, dummy_road_cat], dim=2) - dummy_road = dummy_road_objs.view(batch_size, -1) - - dummy_obs = torch.cat([dummy_ego, dummy_partner, dummy_road], dim=1) + # Build a physically valid structured observation using binding constants + dummy_obs = drive_policy.build_structured_observation( + dynamics_model=config["env"].get("dynamics_model", "classic"), + reward_conditioning=bool(config["env"].get("reward_conditioning", 0)), + batch_size=batch_size, + ) + + # Print observation spec for reference + spec = drive_policy.observation_spec() + print(f"\nObservation layout: {spec['layout']}") + print(f" Ego: offset={spec['ego']['offset']}, dim={spec['ego']['total_dim']}") + if spec.get("reward_conditioning"): + rc = spec["reward_conditioning"] + print(f" Conditioning: offset={rc['offset']}, dim={rc['total_dim']}") + print( + f" Partners: offset={spec['partners']['offset']}, dim={spec['partners']['total_dim']} ({spec['partners']['count']}x{spec['partners']['features_per_object']})" + ) + print( + f" Road: offset={spec['road_segments']['offset']}, dim={spec['road_segments']['total_dim']} ({spec['road_segments']['count']}x{spec['road_segments']['features_per_object']})" + ) + print(f" Total: {spec['total_dim']}") else: print("Warning: Could not determine Drive policy structure. 
Using random observation.")
        dummy_obs = torch.randn(batch_size, obs_dim)


From 391809e98a780dfca532a8d51af97f1a62689cbb Mon Sep 17 00:00:00 2001 From: mpragnay Date: Thu, 5 Mar 2026 11:15:02 -0500 Subject: [PATCH 2/3] Some code cleanup --- scripts/export_onnx.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index 43b8772cf5..057baae0f4 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -58,24 +58,18 @@ def export_to_onnx(verify=True): args = parser.parse_args() - # Load configuration + # Load environment config = load_config(args.env) - - # Load environment to get observation/action spaces package = config["base"]["package"] module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" env_module = importlib.import_module(module_name) make_env = env_module.env_creator(args.env) - - # Ensure env args/kwargs are correctly passed env_kwargs = config["env"] - vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1) # Initialize Policy print("Initializing Policy...") policy = Drive(vecenv.driver_env, **config["policy"]) - if config["base"]["rnn_name"]: print("Wrapping with LSTM...") policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"]) @@ -106,13 +100,11 @@ def export_to_onnx(verify=True): batch_size = 1 obs_space = vecenv.single_observation_space - # Flatten observation if needed, Drive policy handles flattening internally usually but check vecenv # The LSTMWrapper expects (B, ObsDim) obs_dim = np.prod(obs_space.shape) # Create Dummy Observation if config["base"]["rnn_name"]: - # If wrapped, access the internal Drive policy drive_policy = policy.policy else: drive_policy = policy @@ -201,7 +193,6 @@ def export_to_onnx(verify=True): sess_options.inter_op_num_threads = 1 ort_session = ort.InferenceSession(args.output, sess_options) - # PyTorch output with 
torch.no_grad(): torch_logits, torch_value, torch_h, torch_c = onnx_policy(*dummy_inputs) From 18b7d82161ec61b62b601654963eb15725dfdcdf Mon Sep 17 00:00:00 2001 From: mpragnay Date: Thu, 5 Mar 2026 12:17:19 -0500 Subject: [PATCH 3/3] Output construction script and specs --- pufferlib/ocean/drive/drive.c | 4 +- pufferlib/ocean/torch.py | 189 ++++++++++++++++++++++++++++++++++ scripts/export_onnx.py | 53 +++++++++- 3 files changed, 242 insertions(+), 4 deletions(-) diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index b2252153bd..9563f2a4ae 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -90,13 +90,13 @@ void demo() { .init_steps = conf.init_steps, .init_mode = conf.init_mode, .control_mode = conf.control_mode, - .map_name = "resources/drive/binaries/carla/carla_3D/map_001.bin", + .map_name = "resources/drive/binaries/carla_2D/map_001.bin", .reward_conditioning = 1, }; allocate(&env); c_reset(&env); c_render(&env); - Weights *weights = load_weights("resources/drive/puffer_drive_resampling_speed_lane.bin"); + Weights *weights = load_weights("model_puffer_drive_003000.bin"); DriveNet *net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, env.reward_conditioning); int accel_delta = 1; diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 310a65372e..e6fdcea1c4 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -391,6 +391,195 @@ def observation_spec(self): }, } + def action_spec(self): + """Return structured action/output specification for the discrete joint action space. + + Documents the network output layout, how to decompose the joint action integer + into sub-action indices, and the physical values (with units) each index maps to. + This is the single source of truth for interpreting model outputs on the deployment side. 
+ + The network produces: + logits: (batch, joint_action_size) — unnormalized log-probabilities + value: (batch, 1) — critic estimate (unbounded) + lstm_h: (batch, hidden_size) — LSTM hidden state + lstm_c: (batch, hidden_size) — LSTM cell state + + Deployment post-processing (discrete, deterministic): + joint_action = argmax(logits) + primary_idx = joint_action // num_secondary + secondary_idx = joint_action % num_secondary + physical_primary = PRIMARY_VALUES[primary_idx] + physical_secondary = SECONDARY_VALUES[secondary_idx] + """ + from pufferlib.ocean.drive import binding + + dynamics_model = ( + "jerk" + if (self.ego_dim == binding.EGO_FEATURES_JERK or self.ego_dim == binding.EGO_FEATURES_JERK_CONDITIONING) + else "classic" + ) + is_jerk = dynamics_model == "jerk" + + # --- Action value arrays matching C defines in drive.h --- + if is_jerk: + # JERK_LONG[4] and JERK_LAT[3] + primary_values = [-15.0, -4.0, 0.0, 4.0] + primary_unit = "m/s³" + primary_name = "longitudinal_jerk" + secondary_values = [-4.0, 0.0, 4.0] + secondary_unit = "m/s³" + secondary_name = "lateral_jerk" + else: + # ACCELERATION_VALUES[7] and STEERING_VALUES[13] + primary_values = [-4.0, -2.667, -1.333, 0.0, 1.333, 2.667, 4.0] + primary_unit = "m/s²" + primary_name = "acceleration" + secondary_values = [-1.0, -0.833, -0.667, -0.5, -0.333, -0.167, 0.0, 0.167, 0.333, 0.5, 0.667, 0.833, 1.0] + secondary_unit = "rad" + secondary_name = "steering" + + num_primary = len(primary_values) + num_secondary = len(secondary_values) + joint_size = num_primary * num_secondary + + # Verify consistency with the model's atn_dim + assert sum(self.atn_dim) == joint_size, ( + f"action_spec joint_size={joint_size} != model atn_dim sum={sum(self.atn_dim)}" + ) + + return { + "mode": "discrete", + "dynamics_model": dynamics_model, + "joint_action_size": joint_size, + "decomposition": f"joint_action = {primary_name}_idx * {num_secondary} + {secondary_name}_idx", + "primary": { + "name": primary_name, + "num_actions": 
num_primary, + "values": primary_values, + "unit": primary_unit, + "index_formula": f"joint_action // {num_secondary}", + "desc": f"{primary_name} sub-action index (row of the joint grid)", + }, + "secondary": { + "name": secondary_name, + "num_actions": num_secondary, + "values": secondary_values, + "unit": secondary_unit, + "index_formula": f"joint_action % {num_secondary}", + "desc": f"{secondary_name} sub-action index (column of the joint grid)", + }, + "outputs": { + "logits": { + "shape": f"(batch, {joint_size})", + "desc": "Unnormalized log-probabilities over the joint action space. " + "Apply softmax for probabilities or argmax for greedy action.", + }, + "value": { + "shape": "(batch, 1)", + "desc": "State-value estimate V(s). Unbounded scalar, no activation.", + }, + "lstm_h": { + "shape": f"(batch, {self.hidden_size})", + "desc": "LSTM hidden state to feed back at next timestep.", + }, + "lstm_c": { + "shape": f"(batch, {self.hidden_size})", + "desc": "LSTM cell state to feed back at next timestep.", + }, + }, + "post_processing": { + "deterministic": "joint_action = argmax(logits, dim=-1)", + "stochastic": "joint_action = Categorical(logits=logits).sample()", + "decompose": ( + f"{primary_name}_idx = joint_action // {num_secondary}; " + f"{secondary_name}_idx = joint_action % {num_secondary}" + ), + "lookup": ( + f"physical_{primary_name} = {primary_name}_values[{primary_name}_idx]; " + f"physical_{secondary_name} = {secondary_name}_values[{secondary_name}_idx]" + ), + }, + } + + @staticmethod + def construct_action_output(logits, dynamics_model="classic"): + """Decode raw network logits into physical action values for the discrete joint action space. + + Uses pufferlib.pytorch.sample_logits — the same function used during training — + to ensure identical sampling behaviour. + Mirrors the C-side action decoding in move_dynamics() (drive.h). + This is the output-side counterpart of build_structured_observation(). 
+ + Args: + logits: Raw logits as returned by decode_actions() — a tuple of tensors + for multi-discrete, or a single tensor for discrete. + dynamics_model: "classic" or "jerk" + + Returns: + dict with keys: + joint_action: (batch,) int tensor — flat joint action index + primary_idx: (batch,) int tensor — primary sub-action index + secondary_idx: (batch,) int tensor — secondary sub-action index + primary_physical: (batch,) float tensor — physical primary value + secondary_physical: (batch,) float tensor — physical secondary value + log_prob: (batch,) float tensor — log-probability of the sampled action + entropy: (batch,) float tensor — categorical entropy over joint actions + metadata: dict — action names, units, value tables + """ + import pufferlib.pytorch + + # --- Sample using the exact same function as training --- + action, log_prob, entropy = pufferlib.pytorch.sample_logits(logits) + joint_action = action.squeeze(-1).int() + + # --- Action value tables matching C defines --- + if dynamics_model == "jerk": + primary_values = torch.tensor([-15.0, -4.0, 0.0, 4.0]) + primary_name = "longitudinal_jerk" + primary_unit = "m/s³" + secondary_values = torch.tensor([-4.0, 0.0, 4.0]) + secondary_name = "lateral_jerk" + secondary_unit = "m/s³" + else: + primary_values = torch.tensor([-4.0, -2.667, -1.333, 0.0, 1.333, 2.667, 4.0]) + primary_name = "acceleration" + primary_unit = "m/s²" + secondary_values = torch.tensor( + [-1.0, -0.833, -0.667, -0.5, -0.333, -0.167, 0.0, 0.167, 0.333, 0.5, 0.667, 0.833, 1.0] + ) + secondary_name = "steering" + secondary_unit = "rad" + + num_secondary = len(secondary_values) + + # Decompose joint action: a = primary_idx * num_secondary + secondary_idx + primary_idx = joint_action // num_secondary + secondary_idx = joint_action % num_secondary + + # Look up physical values + primary_physical = primary_values[primary_idx.long()] + secondary_physical = secondary_values[secondary_idx.long()] + + return { + "joint_action": joint_action, + 
"primary_idx": primary_idx, + "secondary_idx": secondary_idx, + f"{primary_name}": primary_physical, + f"{secondary_name}": secondary_physical, + "log_prob": log_prob, + "entropy": entropy, + "metadata": { + "dynamics_model": dynamics_model, + "primary_name": primary_name, + "primary_unit": primary_unit, + "primary_values": primary_values.tolist(), + "secondary_name": secondary_name, + "secondary_unit": secondary_unit, + "secondary_values": secondary_values.tolist(), + "decomposition": f"joint = {primary_name}_idx * {num_secondary} + {secondary_name}_idx", + }, + } + @staticmethod def build_structured_observation(dynamics_model="classic", reward_conditioning=False, batch_size=1): """Build a physically valid dummy observation tensor for export/testing. diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index 057baae0f4..105bac39d8 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -50,7 +50,7 @@ def export_to_onnx(verify=True): parser.add_argument( "--checkpoint", type=str, - default="experiments/puffer_drive_73kbtsi5/model_puffer_drive_000200.pt", + default="model_puffer_drive_003000.pt", help="Path to .pt checkpoint", ) parser.add_argument("--output", type=str, help="Output .onnx file path") @@ -237,7 +237,56 @@ def compare(name, torch_out, ort_out, atol=1e-5): compare("LSTM h", torch_h, ort_h) compare("LSTM c", torch_c, ort_c) - # Export example input and output to .pt files + # --- Construct and save decoded action outputs --- + dynamics_model = config["env"].get("dynamics_model", "classic") + + # Decode from PyTorch logits + action_output = Drive.construct_action_output(torch_logits, dynamics_model=dynamics_model) + + # Print action spec for reference + atn_spec = drive_policy.action_spec() + print(f"\nAction spec ({atn_spec['dynamics_model']}, {atn_spec['mode']}):") + print(f" Joint action size: {atn_spec['joint_action_size']}") + print(f" Decomposition: {atn_spec['decomposition']}") + p = atn_spec["primary"] + s = 
atn_spec["secondary"] + print(f" Primary: {p['name']} ({p['unit']}), {p['num_actions']} values: {p['values']}") + print(f" Secondary: {s['name']} ({s['unit']}), {s['num_actions']} values: {s['values']}") + + # Print decoded action for the dummy input + meta = action_output["metadata"] + print(f"\nDecoded action (categorical sample) for test observation:") + print(f" Joint action index: {action_output['joint_action'].item()}") + print( + f" {meta['primary_name']}_idx: {action_output['primary_idx'].item()}" + f" → {action_output[meta['primary_name']].item():.3f} {meta['primary_unit']}" + ) + print( + f" {meta['secondary_name']}_idx: {action_output['secondary_idx'].item()}" + f" → {action_output[meta['secondary_name']].item():.3f} {meta['secondary_unit']}" + ) + + # Save complete output checkpoint: raw network outputs + decoded actions + output_checkpoint = { + # Raw network outputs + "logits": torch_logits if isinstance(torch_logits, torch.Tensor) else torch.cat(torch_logits, dim=-1), + "value": torch_value, + "lstm_h": torch_h, + "lstm_c": torch_c, + # Decoded discrete actions (categorical sampling, matches training) + "joint_action": action_output["joint_action"], + "primary_idx": action_output["primary_idx"], + "secondary_idx": action_output["secondary_idx"], + f"{meta['primary_name']}": action_output[meta["primary_name"]], + f"{meta['secondary_name']}": action_output[meta["secondary_name"]], + "log_prob": action_output["log_prob"], + "entropy": action_output["entropy"], + # Metadata for the deployment side to reconstruct decoding + "action_metadata": action_output["metadata"], + } + output_path = os.path.join(output_dir, "test_outputs.pt") + torch.save(output_checkpoint, output_path) + print(f"\n✔ Saved output checkpoint (raw + decoded) to {output_path}") if __name__ == "__main__":