diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index f26d54ceba..c1fed926c0 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -20,6 +20,10 @@ - [Evaluation overview](evaluation.md) - [WOSAC](wosac.md) +# Design + +- [Trial mode (`goal_behavior=3`)](trial_mode.md) + # Blog - [PufferDrive 2.0 release](pufferdrive-2.0.md) diff --git a/docs/src/trial_mode.md b/docs/src/trial_mode.md new file mode 100644 index 0000000000..c242fec37f --- /dev/null +++ b/docs/src/trial_mode.md @@ -0,0 +1,352 @@ +# Trial Mode (`goal_behavior=3`) + +Design and contract for the in-context adaptation training mode. + +## Why + +The adaptive ego is a Transformer with a KV cache. We want it to **adapt +across attempts within a single fixed-budget episode** — i.e., use what it +saw in trial 1 to do better in trial 2, etc. That requires: + +1. Multiple goal-reach attempts ("trials") inside one episode. +2. KV cache that **persists across trial boundaries** (so context is + preserved) but **resets at episode boundaries** (so episodes are i.i.d.). +3. PPO/GAE that **stops bootstrap at trial boundaries** (because the + agent's value at $t+1$ is computed post-respawn, from a different + state, and bootstrapping it into the last step of the old trial + contaminates the target). + +These three things — cache reset, GAE bootstrap-stop, episode-vs-trial +distinction — have different gates. The next sections specify each. + +## Terms + +| Term | Meaning | +|---|---| +| **Trial** | One goal-reach attempt. Ends on goal-reach OR `per_trial_timeout` ticks. | +| **Episode** | A sequence of at most `max_trials_per_episode` trials, sharing a single KV cache. | +| **Scenario** | A map. Under `goal_behavior=3`, each episode runs on **one** map (no per-trial map swap). | +| **`terminals[t]`** | 1 ⇔ the *episode* ended at step $t$. Used for both **cache reset** and **GAE bootstrap-stop**. | +| **`truncations[t]`** | 1 ⇔ a *trial* ended at step $t$ but the episode continues. 
Used **only for GAE bootstrap-stop**; cache persists. The Python mirror also sets it on an episode's last trial, where `terminals[t] = 1` triggers the cache reset. | +| **`trial_ended_this_step[i]`** | Per-agent C-side flag, set at every trial boundary (goal-reach or timeout). Mirrored to `truncations` by Python. | +| **Cache reset** | Zero out the Transformer's K/V tensors. Done at episode boundary only. | +| **GAE bootstrap-stop** | Setting $(1-\text{stop}_t) = 0$ in the GAE recursion to prevent $V_{t+1}$ contamination across the boundary. | + +## The two-boundary problem + +Standard PPO has one boundary signal (`dones`). We need two, because the +two distinct things that happen at trial-vs-episode boundaries don't +align: + +| Event | `terminals` | `truncations` | KV cache | GAE bootstrap | +|---|:-:|:-:|:-:|:-:| +| Within-trial step | 0 | 0 | continues | continues | +| **Trial end** (goal or timeout), more trials to go | 0 | **1** | **continues** | **stops** | +| **Episode end** (last trial done) | **1** | **1** | **resets** | **stops** | +| Scenario boundary (gb≠3 only) | 0 | 0 | n/a here | n/a | + +Mnemonic: **`terminals` ⇒ cache reset; (`terminals` OR `truncations`) ⇒ bootstrap stop.** + +## State machine: per agent, per `c_step` + +``` + ┌──────────────────────┐ + │ agent.removed == 1 │ ───── skip (Option D) + └──────────┬───────────┘ + │ no + ▼ + ┌─────── trial_ended? ─────┴───── neither ───→ continue trial + │ (reached || timed_out) + │ + ▼ + trial_ended_this_step[i] = 1 + trial_count++ + │ + ├── trial_count >= max_trials_per_episode ───→ EPISODE END + │ │ + │ ▼ + │ terminals[i] = 1 // Python: cache reset HERE + │ add_log_one_agent(env, i) // flush this agent's metrics + │ agent.removed = 1 // Option D: idle + │ agent.x, agent.y = INVALID // off-grid + │ + └── otherwise ─────────────────────────→ TRIAL END + │ + ▼ + respawn_agent(env, i) // back to start + agent.respawn_timestep = -1 // clear ghost flag (see "Render gates") + agent.trial_start_timestep = env->timestep +``` + +The Python side mirrors `trial_ended_this_step → truncations` after every +`vec_step`. 
So: + +* Trial-end branch → C sets `trial_ended_this_step[i] = 1`. Python sets + `truncations[i] = 1`. `terminals[i] = 0` (it was zeroed at the top of + `step`). +* Episode-end branch → C sets BOTH `trial_ended_this_step[i] = 1` AND + `terminals[i] = 1`. Python sets `truncations[i] = 1`. + +Both signals fire at the last trial end. That's intentional — the cache +reset gate (terminals) and the bootstrap-stop gate (terminals OR +truncations) both want to fire there. + +## PPO / GAE formulation + +Standard GAE (Schulman et al. 2016) with a single `done` signal: + +$$ +\delta_t = r_t + \gamma\,(1-d_t)\,V_{t+1} - V_t +$$ + +$$ +\hat A_t = \delta_t + \gamma\lambda\,(1-d_t)\,\hat A_{t+1} +$$ + +In vanilla PPO, $d_t = \text{terminals}_t$. The $(1-d_t)$ factors zero out +the $V_{t+1}$ bootstrap and the recursive advantage at episode boundaries +(where state $t+1$ is a fresh env reset — no semantic relation to state +$t$). + +**Trial-mode modification.** At every trial boundary (not just episode +boundary), state $t+1$ is the post-respawn state — back at the trajectory +start position with reset velocity. $V_{t+1}$ from that state is **not** a +valid bootstrap for state $t$ (the last step of the old trial, somewhere +else in the map). We define: + +$$ +\text{bootstrap\_stop}_t \;=\; \min\!\bigl(\text{terminals}_t + \text{truncations}_t,\; 1\bigr) +$$ + +and replace $d_t$ in BOTH GAE equations: + +$$ +\delta_t = r_t + \gamma\,(1-\text{bootstrap\_stop}_t)\,V_{t+1} - V_t +$$ + +$$ +\hat A_t = \delta_t + \gamma\lambda\,(1-\text{bootstrap\_stop}_t)\,\hat A_{t+1} +$$ + +This is `pufferlib/pufferl.py`'s `bootstrap_stop = (self.terminals + self.truncations).clamp(max=1.0)`. + +The KV cache reset is **independent**: it gates on `terminals` alone, NOT +on `bootstrap_stop`. Otherwise we'd lose the cross-trial context that is +the entire point of trial mode. 
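As a rough illustration of the modified recursion above, the following pure-Python sketch runs GAE backward over one agent's rollout and uses `bootstrap_stop` in both equations. The function name `gae_with_bootstrap_stop`, its argument layout, and the default `gamma` / `lam` values are assumptions made for this example; it is not the code in `pufferlib/pufferl.py`.

```python
def gae_with_bootstrap_stop(rewards, values, last_value, terminals, truncations,
                            gamma=0.99, lam=0.95):
    """Backward GAE recursion for a single agent (hypothetical helper).

    bootstrap_stop_t = min(terminals_t + truncations_t, 1) replaces the usual
    done flag in both the TD error and the advantage recursion.
    """
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    next_value = last_value  # V_{T}: value of the state after the last stored step
    next_advantage = 0.0
    for t in reversed(range(num_steps)):
        stop = min(terminals[t] + truncations[t], 1)
        keep = 1.0 - stop  # 0 at any trial or episode boundary
        delta = rewards[t] + gamma * keep * next_value - values[t]
        advantages[t] = delta + gamma * lam * keep * next_advantage
        next_value = values[t]
        next_advantage = advantages[t]
    return advantages


# A three-step rollout where the trial ends (truncation) at step 1.
advantages = gae_with_bootstrap_stop(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    last_value=0.5,
    terminals=[0, 0, 0],
    truncations=[0, 1, 0],
    gamma=1.0,
    lam=1.0,
)
```

At step 1 the `keep` factor is zero, so neither the post-respawn value nor the next trial's advantage leaks into the old trial. Setting `terminals` and `truncations` together at the same step gives the same result, because the sum is clamped to 1.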
+ +``` + PPO/GAE bootstrap-stop KV cache reset + ───────────────────── ────────────── +trial boundary YES (truncations[t]=1) no ← preserves context +episode boundary YES (terminals[t]=1) YES (fresh i.i.d. start) +``` + +## Cache reset gate (`pufferl.py`) + +```python +# d = terminals, t = truncations +done_mask = d # gate on terminals only (was: d + t) +self.transformer_context[done_mask.bool()] = 0 +``` + +If we used `d + t`, every trial boundary would wipe the cache, which is exactly +the opposite of what we want. Trial mode breaks without this fix. + +## Option D: idle-after-max_trials + +Naïve trial mode would, after the agent completes `max_trials_per_episode` +trials, immediately reset the env and start a new episode. With Python's +typical rollout of one map per `resample_frequency` ticks, this leads to +**many short episodes on the same map**: the agent overfits to a tiny +subset of maps within a single Python cycle, and gradient updates see the +same map's gradients repeatedly. + +Option D fixes this by **idling the agent after its episode ends**: + +```c +if (e->trial_count >= env->max_trials_per_episode) { + env->terminals[i] = 1; + add_log_one_agent(env, i); + e->removed = 1; + e->x = e->y = INVALID_POSITION; // off-grid + e->vx = e->vy = 0.0f; + // do NOT call c_reset +} +``` + +The agent is invisible to subsequent `c_step`s (the top-of-loop +`if (e->removed) continue;` gates it out). It stays idle until Python's +`_reinit_envs_with_new_maps` fires at the next `resample_frequency` +boundary. That is when the env loads a fresh map and `c_reset` sets +`removed = 0`. + +**Net effect**: 1 episode per resample window, exactly one fresh map per +episode. Map diversity restored. + +## Trial parameter naming + +Under `goal_behavior=3`: + +* `k_scenarios` IS the number of trials per episode. +* `scenario_length` IS the per-trial timeout. + +These are the canonical names, the only two knobs you set. 
The C side +exposes internal fields named `max_trials_per_episode` and +`per_trial_timeout` (legacy: shared with non-trial code paths), and +`AdaptiveDrivingAgent.__init__` unconditionally sets them from +`k_scenarios` / `scenario_length` under gb=3. **There is no override.** +If you want a different trial count, change `k_scenarios`. + +`resample_frequency` is also derived: $k \times L$ (`k_scenarios` × `scenario_length`), the worst-case +episode budget. + +So `--env.k-scenarios 4 --env.goal-behavior 3 --env.scenario-length 201` +gives 4 trials of 201 ticks each, episode budget = 804 ticks, resample +at tick 804. (Pre-Option-A there were two more CLI flags +`--env.max-trials-per-episode` and `--env.per-trial-timeout`; both have been removed.) + +## End-to-end signal flow + +``` + ┌────────────────────┐ + │ C (drive.h) │ trial_count++, trial_ended_this_step[i] = 1 + │ c_step trial loop│ ── if last trial: terminals[i] = 1, removed = 1 + └─────────┬──────────┘ + │ zero-copy NumPy view of trial_ended_this_step (1D u8) + ▼ + ┌────────────────────┐ + │ Python (drive.py) │ truncations[:] = 0 # top of step + │ step() │ vec_step(c_envs) + │ │ truncations[trial_ended_this_step] = 1 # mirror, gb=3 only + │ │ terminals already set by C if episode end + └─────────┬──────────┘ + │ PufferLib SHM (np.bool views) + ▼ + ┌────────────────────┐ + │ pufferl.py │ rollout buffers store BOTH d_t and t_t + │ rollout + GAE │ done_mask = d # cache-reset gate + │ │ bootstrap = (d + t).clamp(max=1) + │ │ δ_t = r_t + γ(1-bootstrap_t) V_{t+1} - V_t + │ │ Â_t = δ_t + γλ(1-bootstrap_t) Â_{t+1} + └─────────┬──────────┘ + ▼ + PPO update +``` + +## Score semantics + +Standard non-trial modes set `score = 1` if the agent reaches the goal "well +enough" in a single scenario (fraction of goals reached above a threshold, +and no collisions). 
Under trial mode each episode has `max_trials` +attempts, so: + +* `goals_reached_this_episode` ∈ {0, 1, …, max_trials} (one increment per + successful trial; gated by `current_goal_reached` to prevent + over-counting within a trial) +* `frac = goals_reached_this_episode / max_trials_per_episode` +* threshold $\tau$ ladder by $k$: + * $k = 2$: $\tau = 0.5$ (both trials must succeed for $\text{frac} > 0.5$) + * $k \in \{3, 4\}$: $\tau = 0.8$ + * $k \geq 5$: $\tau = 0.9$ +* `score = 1` iff `frac > τ AND !collided_in_episode` + +## Render gates + +`respawn_agent()` is shared between `goal_behavior=0` (RESPAWN, with +intentional ghost-fade post-respawn) and the trial-mode mid-episode +respawn (no ghost — agent should be fully visible immediately for +trial 2..K). The function sets `respawn_timestep = env->timestep`, and +**seven** downstream gates use `respawn_timestep != -1` as a "ghosted" +marker: + +| Location | Effect when active | +|---|---| +| `drive.h:1327` | Skip self-side collision check | +| `drive.h:1342` | Skip other-as-target collision check | +| `drive.h:2409` | Force `obs[6] = 1` (post-respawn flag) | +| `drive.h:2455` | Other-car obs (self ghosted) zeroed | +| `drive.h:2457` | Other-car obs (other ghosted) zeroed | +| `drive.h:3482` | Skip 3D mesh draw (visible symptom) | +| `drive.h:3688` | Skip WOSAC track-index overlay | + +In trial mode, after `respawn_agent` in the mid-episode branch we +**must** clear the flag immediately: + +```c +respawn_agent(env, agent_idx); +e->respawn_timestep = -1; // GOAL_TRIAL is NOT a ghost-fade mode +e->trial_start_timestep = env->timestep; +``` + +Pre-fix symptom: trial 1 rendered correctly, trials 2..K appeared empty. + +## Per-trial metrics + +`add_log_one_agent` (in C) is called when an agent's episode ends. 
It +aggregates that single agent's metrics into `env->log` (the vec_log +sink), then resets all per-entity state the agent's next episode would +otherwise inherit (respawn_timestep, current_goal_reached, the +`metrics_array` slots, etc.). + +New per-trial log fields, in addition to the standard episode metrics: + +| Field | Meaning | +|---|---| +| `n_trials_completed` | Trials finished this episode (always equals `max_trials` for ego under Option D) | +| `n_trials_goal_reached` | Of those, how many reached goal | +| `n_trials_timed_out` | Of those, how many timed out | +| `trial_total_length` | Sum of trial lengths (ticks) | +| `trial_mean_length` | `trial_total_length / n_trials_completed` (computed in `add_log`) | +| `trial_goal_reach_rate` | `n_trials_goal_reached / n_trials_completed` | + +These are populated **only** under `goal_behavior=3`. The standard +metrics (score, collision_rate, episode_length, …) still populate via +the same `add_log_one_agent` path. + +The evaluator (`HumanReplayEvaluator`) computes additional per-trial +breakdowns from its own success array: + +| Field | Definition | +|---|---| +| `trial_K_score` | $\Pr$(reached in trial $K$) over the eval rollouts | +| `ada_delta_trial_K_minus_0` | `trial_K_score - trial_0_score` (the in-context adaptation signal) | + +For $K$ = `max_trials_per_episode` = 4 (auto-link from `k_scenarios=4`), +that's `trial_0_score`, …, `trial_3_score` and `ada_delta_trial_{1,2,3}_minus_0`. 
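The two evaluator fields above reduce to per-trial means over a boolean success array. Here is a minimal pure-Python sketch, assuming a nested list laid out as `success[rollout][trial][agent]`. The helper name `per_trial_scores` is hypothetical; the evaluator in this change computes the same quantities on a NumPy array with that layout.

```python
def per_trial_scores(success):
    """Compute trial_K_score and ada_delta_trial_K_minus_0 (hypothetical helper).

    success[r][k][a] is truthy if agent a reached the goal in trial k of rollout r.
    """
    n_trials = len(success[0])
    metrics = {}
    for k in range(n_trials):
        # Pool every (rollout, agent) pair for trial k and take the mean.
        flags = [bool(flag) for rollout in success for flag in rollout[k]]
        metrics[f"trial_{k}_score"] = sum(flags) / len(flags)
    for k in range(1, n_trials):
        # In-context adaptation signal: improvement over the first trial.
        metrics[f"ada_delta_trial_{k}_minus_0"] = (
            metrics[f"trial_{k}_score"] - metrics["trial_0_score"]
        )
    return metrics


# Two rollouts, two trials, two agents.
example = [
    [[True, False], [True, True]],
    [[False, False], [True, False]],
]
metrics = per_trial_scores(example)
```

A positive `ada_delta_trial_K_minus_0` means agents reached the goal more often in trial K than in trial 0, which is the behaviour trial mode is meant to encourage.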
+ +## Test coverage + +| File | What it covers | +|---|---| +| `test_goal_trial.py` | Trial timer fires; episode boundary fires at `trial_count == max_trials`; non-regression for gb∈{0,1,2} | +| `test_trial_ended_buffer.py` | `trial_ended_this_step` Python ↔ C buffer plumbing | +| `test_trial_log_fields.py` | Per-trial Log fields populate | +| `test_trial_standard_metrics.py` | Standard episode metrics still populate via `add_log_one_agent` | +| `test_trial_per_scenario_gate.py` | Per-scenario logic gated off under gb=3 | +| `test_trial_score_semantics.py` | Score uses `max_trials_per_episode` denominator | +| `test_trial_overcounting_fix.py` | `current_goal_reached` gates `goals_reached_this_episode` increments | +| `test_gae_trial_boundary.py` | GAE bootstrap-stop fires on truncations | +| `test_gae_decoupling_integration.py` | End-to-end `trial_ended_this_step → truncations` mirror | +| `test_adaptive_trial_link.py` | Auto-link of `max_trials_per_episode` and `per_trial_timeout` | +| `test_rollout_trial_mode.py` | Rollout `max_steps` / break / info under trial mode | +| `test_evaluator_trial_mode.py` | `HumanReplayEvaluator` emits `trial_K_score` + auto-link case | +| `test_pe_train_eval_consistency.py` | Transformer PE indexing matches between train and eval | +| `test_pos_within_episode.py` | `compute_pos_within_episode` correctness | + +All 53 tests pass on `mohit/trial-episode-redesign` HEAD. 
+ +## Quick reference + +``` +goal_behavior = 3 # the toggle +k_scenarios = 4 # the number of trials, by auto-link +scenario_length = 201 # nuplan trajectories are 201 ticks + # → per_trial_timeout = 201, by auto-link + # → episode budget = 804 ticks + # → resample_frequency = 804, by auto-link +``` + +| Knob | Type | Default | Notes | +|---|---|---|---| +| `--env.goal-behavior` | int | 0 | 3 = trial mode | +| `--env.k-scenarios` | int | 1 | Under gb=3: number of trials per episode | +| `--env.scenario-length` | int | 91 | Under gb=3: per-trial timeout (ticks) | diff --git a/experiments/puffer_drive_2e029h15.pt b/experiments/puffer_drive_2e029h15.pt new file mode 100644 index 0000000000..99be09d278 Binary files /dev/null and b/experiments/puffer_drive_2e029h15.pt differ diff --git a/experiments/puffer_drive_6rauydj2.pt b/experiments/puffer_drive_6rauydj2.pt new file mode 100644 index 0000000000..3054a845e6 Binary files /dev/null and b/experiments/puffer_drive_6rauydj2.pt differ diff --git a/experiments/puffer_drive_m2ygolog.pt b/experiments/puffer_drive_m2ygolog.pt new file mode 100644 index 0000000000..d5b11e87e2 Binary files /dev/null and b/experiments/puffer_drive_m2ygolog.pt differ diff --git a/experiments/puffer_drive_miku2puk.pt b/experiments/puffer_drive_miku2puk.pt new file mode 100644 index 0000000000..97ab2dde05 Binary files /dev/null and b/experiments/puffer_drive_miku2puk.pt differ diff --git a/pufferlib/config/ocean/adaptive.ini b/pufferlib/config/ocean/adaptive.ini index fc347130b8..68d2d0246d 100644 --- a/pufferlib/config/ocean/adaptive.ini +++ b/pufferlib/config/ocean/adaptive.ini @@ -52,7 +52,11 @@ reward_vel_align = 1.0 goal_radius = 2.0 ; Max target speed in m/s for the agent to maintain towards the goal goal_speed = 100.0 -; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +; What to do when the goal is reached. 
Options: 0:"respawn", 1:"generate_new_goals", 2:"stop", 3:"trial" +; Under 3 (trial): k_scenarios = number of trials, scenario_length = per-trial timeout. +; The C side still exposes `max_trials_per_episode` and `per_trial_timeout` for +; tests that want fine-grained control; runtime path overrides them from +; k_scenarios / scenario_length in AdaptiveDrivingAgent.__init__. goal_behavior = 0 ; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. ; Large numbers will select a goal point further away from the agent's current position. diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 366ee25ac3..0b100e73ef 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -49,7 +49,8 @@ reward_vel_align = 1.0 goal_radius = 2.0 ; Max target speed in m/s for the agent to maintain towards the goal goal_speed = 100.0 -; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop", 3:"trial" +; Under 3 (trial), k_scenarios = number of trials, scenario_length = per-trial timeout. goal_behavior = 0 ; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. ; Large numbers will select a goal point further away from the agent's current position. diff --git a/pufferlib/models.py b/pufferlib/models.py index 72121e458e..543da5be06 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -255,19 +255,15 @@ def __init__( else: self.input_projection = nn.Identity() - # Sinusoidal positional embedding (Vaswani et al.) — non-trainable. - # Switched from learnable PE so the transformer has temporal - # structure from initialization rather than having to learn it - # from gradients. 
Slot-tied: PE[i] is added when writing to - # cache slot i, identical for both forward (training) and - # forward_eval (rollout) paths via get_positional_embedding(). + # Sinusoidal positional embeddings (Vaswani et al.) — non-trainable. + # Per-episode reset is applied in forward() (training) so the PE + # indexing matches forward_eval's cache-pos indexing under + # multi-episode-per-row rollouts. pe = torch.zeros(horizon, hidden_size) position = torch.arange(0, horizon, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_size)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - # register_buffer keeps it on the module's device but excludes it - # from .parameters() (no gradient updates). self.register_buffer("positional_embedding", pe.unsqueeze(0)) # Transformer encoder @@ -327,13 +323,33 @@ def get_causal_mask(self, T, device): return mask def get_positional_embedding(self, T, device): - """Get cached positional embedding for length T""" + """Get cached positional embedding for length T.""" cache_key = f"_pos_embed_{T}" if not hasattr(self, cache_key) or getattr(self, cache_key).device != device: pos_embed = self.positional_embedding[:, :T].to(device) setattr(self, cache_key, pos_embed) return getattr(self, cache_key) + @staticmethod + def compute_pos_within_episode(terminals): + """For terminals (B, T) bool/float, return per-slot position within + its episode (resets at slot AFTER each terminal). The convention + matches create_episode_mask: the terminal slot itself belongs to + the OLD episode, and the new episode starts at slot terminal+1. + + Vectorized: shift terminals right by one (so a terminal at slot s + becomes a start-flag at slot s+1), multiply by arange to mark the + position of each episode-start, cummax to propagate the last + seen start position forward, then subtract from arange. 
+ """ + B, T = terminals.shape + device = terminals.device + arange_T = torch.arange(T, device=device, dtype=torch.long).unsqueeze(0).expand(B, T) + shifted = F.pad(terminals[:, :-1], (1, 0)).long() # (B, T) + starts = arange_T * shifted # (B, T) — slot index where new episode begins (else 0) + ep_start = starts.cummax(dim=1).values # (B, T) — most recent episode-start at or before t + return arange_T - ep_start # (B, T) + def create_episode_mask(self, terminals, seq_len): """Episode mask which ensures that you arent attending over episode boundaries. Optimized with cached mask buffers to reduce memory allocation.""" @@ -395,6 +411,10 @@ def init_eval_state(self, batch_size, device, dtype=torch.float32): k_cache=self._make_kv_cache(batch_size, device, dtype), v_cache=self._make_kv_cache(batch_size, device, dtype), transformer_position=torch.zeros(1, dtype=torch.long, device=device), + # B'' garbage_mask: per-agent, per-cache-slot bool. True = the slot + # was written while the agent was off-map (removed=1), so it should + # be excluded from attention to avoid limbo-token pollution. + garbage_mask=torch.zeros(batch_size, self.horizon, dtype=torch.bool, device=device), ) def _prime_kv_cache(self, indices, state): @@ -473,11 +493,17 @@ def reset_eval_state(self, state, done_indices=None): pos = state.get("transformer_position") if pos is not None: pos.zero_() + gm = state.get("garbage_mask") + if gm is not None: + gm.zero_() else: idx = done_indices if not torch.is_tensor(idx): idx = torch.as_tensor(idx, device=k_cache[0].device, dtype=torch.long) self._prime_kv_cache(idx, state) + gm = state.get("garbage_mask") + if gm is not None: + gm[idx] = False def forward_eval(self, observations, state): if _USE_LEGACY_EVAL: @@ -512,15 +538,26 @@ def forward_eval(self, observations, state): slot_t = (pos % self.horizon).long() # (1,) long tensor - # Add the slot's positional embedding (slot-tied, matching the - # legacy rolling-buffer scheme). 
+ # PE indexed by slot_t (pos resets to 0 at episode boundary via + # pufferl.py's done handling, so PE[slot_t] = PE[pos_within_episode]). pos_embed = self.get_positional_embedding(self.horizon, device) # (1, horizon, hidden) pos_embed_slot = pos_embed.index_select(1, slot_t).squeeze(1) # (1, hidden) x = (hidden + pos_embed_slot).unsqueeze(1) # (B, 1, hidden) - # Build (1, 1, 1, horizon) bool mask: True at slots [0, slot_t]. + # B'' garbage_mask (per-agent, per-slot bool). Slots that were written + # while the agent was off-map (removed=1) are excluded from attention + # so the limbo period doesn't pollute the cache. Allocated lazily if + # missing or if batch size changed. + garbage_mask = state.get("garbage_mask") + if garbage_mask is None or garbage_mask.shape != (B, self.horizon): + garbage_mask = torch.zeros(B, self.horizon, dtype=torch.bool, device=device) + + # Build per-agent (B, 1, 1, horizon) bool mask: True at slots in + # [0, slot_t] AND not garbage. Pre-fix this was (1, 1, 1, horizon) + # shared across batch with no garbage exclusion. slots_arange = self._slot_arange(device) - attn_mask = (slots_arange <= slot_t).view(1, 1, 1, self.horizon) + base_mask = (slots_arange <= slot_t).view(1, self.horizon) # (1, horizon) + attn_mask = (base_mask & ~garbage_mask).view(B, 1, 1, self.horizon) H = self.num_heads D = self.head_dim @@ -577,6 +614,21 @@ def forward_eval(self, observations, state): state["transformer_position"] = pos + 1 state["hidden"] = hidden_out + # Mark just-written cache slot as garbage for agents that are off-map + # this step. Next step's attention will exclude these slots. `removed` + # comes from the env's SHM buffer (drive.py self.removed), routed via + # pufferl.py before this forward_eval call. 
+ removed = state.get("removed") + if removed is not None: + r = removed.to(device=device, dtype=torch.bool).view(-1) + if r.shape[0] == B: + # Functional update (no scalar index) so torch.compile traces + # this without a data-dependent graph break. + slots = torch.arange(self.horizon, device=device) + position_mask = slot_t.view(-1, 1) == slots.view(1, -1) # (1, horizon) + garbage_mask = garbage_mask | (r.unsqueeze(1) & position_mask) + state["garbage_mask"] = garbage_mask + logits, values = self.policy.decode_actions(hidden_out) return logits, values @@ -607,7 +659,6 @@ def _forward_eval_legacy(self, observations, state): pos_embed = self.get_positional_embedding(self.horizon, device) context_with_pos = context + pos_embed - causal_mask = self.get_causal_mask(self.horizon, device) output = self.transformer(context_with_pos, mask=causal_mask, is_causal=True) @@ -646,7 +697,23 @@ def forward(self, observations, state): hidden = hidden[:, -T_actual:] T = T_actual - hidden = hidden + self.get_positional_embedding(T, device) + # Per-episode-reset PE: under multi-episode rollouts, training must + # match rollout's PE indexing. Rollout (forward_eval) resets pos to 0 + # at every episode boundary via pufferl.py's done handling, so for + # the same logical step within an episode, PE[pos_within_episode] + # is added. We mirror that here: compute pos_within_episode from + # terminals (cumsum-shifted-by-1 / cummax trick) and gather PE + # per-slot rather than indexing 0..T-1 across the segment. 
+ terminals_for_pe = state.get("terminals") + if terminals_for_pe is not None: + pos_within_ep = self.compute_pos_within_episode(terminals_for_pe) # (B, T) long + pos_within_ep = pos_within_ep.clamp(max=self.horizon - 1) # safety: long-episode guard + # gather PE per (b, t): pe shape (1, horizon, hidden) → (B, T, hidden) + pe_full = self.get_positional_embedding(self.horizon, device) # (1, horizon, hidden) + pe_per_slot = pe_full[0, pos_within_ep] # (B, T, hidden) + hidden = hidden + pe_per_slot.to(hidden.dtype) + else: + hidden = hidden + self.get_positional_embedding(T, device) use_episode_mask = "terminals" in state and state["terminals"] is not None @@ -665,6 +732,27 @@ def forward(self, observations, state): causal_mask = self.get_causal_mask(T, device) episode_mask = self.create_episode_mask(terminals, T) attn_mask = causal_mask.unsqueeze(0) + episode_mask + # Train/eval parity: under gb=3, mask attention to limbo (off-map) + # SOURCE slots. Mirror the eval-time `garbage_mask` behavior so + # the training forward never attends to garbage slots. + # attn_mask[b, t, s] += -inf if removed[b, s] = 1, EXCEPT we + # always leave the diagonal (t == s) open so a limbo query + # has at least one valid source (itself) and the row never + # softmaxes to NaN. This matches the eval path, where + # garbage_mask[a, slot_t] is set AFTER the forward at step t, + # so the current slot is unmasked during its own attention. + removed = state.get("removed") + if removed is not None: + if removed.shape[1] > T: + removed = removed[:, -T:] + neg_inf = torch.tensor(float("-inf"), device=device, dtype=attn_mask.dtype) + zero = torch.tensor(0.0, device=device, dtype=attn_mask.dtype) + limbo_bias = torch.where(removed.unsqueeze(1), neg_inf, zero) # (B, 1, T) + # Materialize over the query axis so we can unmask the diagonal. 
+ limbo_bias = limbo_bias.expand(-1, T, -1).contiguous() # (B, T, T) + diag_t = torch.arange(T, device=device) + limbo_bias[:, diag_t, diag_t] = 0 # leave self-attention open + attn_mask = attn_mask + limbo_bias attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0) if self.training and self.use_checkpointing: hidden = checkpoint( diff --git a/pufferlib/ocean/benchmark/evaluator.py b/pufferlib/ocean/benchmark/evaluator.py index 818a266277..7924675e75 100644 --- a/pufferlib/ocean/benchmark/evaluator.py +++ b/pufferlib/ocean/benchmark/evaluator.py @@ -650,6 +650,18 @@ def rollout(self, args, puffer_env, policy): num_agents = puffer_env.observation_space.shape[0] device = args["train"]["device"] k_scenarios = args["env"].get("k_scenarios", 1) + goal_behavior = int(args["env"].get("goal_behavior", 0)) + # GOAL_TRIAL mode swaps the outer loop from `for scenario` to `for trial` + # (variable-length, ends on goal-reach OR per-trial timeout). + is_trial_mode = goal_behavior == 3 + if is_trial_mode: + # Under gb=3 the env's max_trials/per_trial_timeout are always + # k_scenarios / scenario_length (AdaptiveDrivingAgent enforces). + # Read from driver_env when available, else fall back to the + # adaptive-link formula directly. 
+ driver = getattr(puffer_env, "driver_env", None) + max_trials = int(getattr(driver, "max_trials_per_episode", k_scenarios) or k_scenarios) + per_trial_timeout = int(getattr(driver, "per_trial_timeout", self.sim_steps) or self.sim_steps) is_transformer = hasattr(policy, "horizon") and hasattr(policy, "transformer") is_recurrent = hasattr(policy, "lstm") @@ -689,7 +701,9 @@ def _fresh_state(): cache_reset_per_scenario = os.environ.get("RECOVERY_CACHE_RESET_PER_SCENARIO", "0") == "1" if cache_reset_per_scenario: print("[recovery] CONTROL mode: resetting K/V cache at every scenario boundary", flush=True) - success_arr = np.zeros((num_rollouts, k_scenarios, num_agents), dtype=bool) + # success_arr indexed by (rollout, scenario_or_trial, agent) + n_outer = max_trials if is_trial_mode else k_scenarios + success_arr = np.zeros((num_rollouts, n_outer, num_agents), dtype=bool) for rollout_idx in range(num_rollouts): obs, _ = puffer_env.reset() @@ -698,41 +712,69 @@ def _fresh_state(): scenario_metrics = {} delta_metrics = {} - for scenario in range(k_scenarios): - if scenario > 0 and cache_reset_per_scenario: - state = _fresh_state() - for time_idx in range(self.sim_steps): + if is_trial_mode: + # Trial mode: run up to max_trials * per_trial_timeout ticks. + # Per-agent trial counter advances on trial_ended_this_step. + # Capture trial outcome (reach=reward>thresh) at trial-end. 
+ trial_idx = np.zeros(num_agents, dtype=np.int32) + rollout_complete = np.zeros(num_agents, dtype=bool) + max_steps = max_trials * per_trial_timeout + for time_idx in range(max_steps): with torch.no_grad(): ob_tensor = torch.as_tensor(obs).to(device) logits, value = policy.forward_eval(ob_tensor, state) action, logprob, _ = pufferlib.pytorch.sample_logits(logits) action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) - if isinstance(logits, torch.distributions.Normal): action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) - obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) - - # Mark per-agent success this scenario: a +reward_goal spike - # at any tick == goal reached. In stop-on-goal mode the env - # does NOT set `dones` per agent (the agent just stops moving), - # so we can't gate on dones. The only step-level reward that - # crosses `goal_reward_threshold` is the goal reward itself - # (lane_align is ~0.01/step, so even integrated it can't - # reach 0.5 in one tick). We OR across the scenario so the - # success flag sticks even if subsequent ticks are 0. 
rewards_arr = np.asarray(rewards).reshape(-1) - success_arr[rollout_idx, scenario] |= rewards_arr > goal_reward_threshold - + reached = rewards_arr > goal_reward_threshold + + te = np.asarray(puffer_env.trial_ended_this_step).reshape(-1).astype(bool) + end_idxs = np.where(te & ~rollout_complete)[0] + for a in end_idxs: + ti = int(trial_idx[a]) + if ti < max_trials: + success_arr[rollout_idx, ti, a] = bool(reached[a]) + trial_idx[a] = ti + 1 + if trial_idx[a] >= max_trials: + rollout_complete[a] = True for info_dict in info_list: if not isinstance(info_dict, dict): continue - if "ada_delta_score" in info_dict: - delta_metrics = info_dict - elif any(k.startswith("scenario_") for k in info_dict.keys()): - scenario_metrics.update(info_dict) - elif "score" in info_dict: + if "score" in info_dict: collected_infos.append(info_dict) + if rollout_complete.all(): + break + else: + for scenario in range(k_scenarios): + if scenario > 0 and cache_reset_per_scenario: + state = _fresh_state() + for time_idx in range(self.sim_steps): + with torch.no_grad(): + ob_tensor = torch.as_tensor(obs).to(device) + logits, value = policy.forward_eval(ob_tensor, state) + action, logprob, _ = pufferlib.pytorch.sample_logits(logits) + action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) + + if isinstance(logits, torch.distributions.Normal): + action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) + + obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) + + rewards_arr = np.asarray(rewards).reshape(-1) + success_arr[rollout_idx, scenario] |= rewards_arr > goal_reward_threshold + + for info_dict in info_list: + if not isinstance(info_dict, dict): + continue + if "ada_delta_score" in info_dict: + delta_metrics = info_dict + elif any(k.startswith("scenario_") for k in info_dict.keys()): + scenario_metrics.update(info_dict) + elif "score" in info_dict: + collected_infos.append(info_dict) if collected_infos: rollout_agg = { @@ 
-769,12 +811,24 @@ def _fresh_state(): # Schema: list of {"rollout": int, "agent": int, "s0": int, ..., # "s_{k-1}": int} — one record per (rollout, agent) pair. records = [] + prefix = "t" if is_trial_mode else "s" for r in range(num_rollouts): for a in range(num_agents): rec = {"rollout": int(r), "agent": int(a)} - for s_idx in range(k_scenarios): - rec[f"s{s_idx}"] = int(success_arr[r, s_idx, a]) + for s_idx in range(n_outer): + rec[f"{prefix}{s_idx}"] = int(success_arr[r, s_idx, a]) records.append(rec) final["per_agent_success_log"] = records + # Per-trial aggregate metrics + ada_delta deltas (trial mode only). + # Computed from success_arr to give clean per-trial signal even when + # the env's vec_log path doesn't aggregate per-trial rates. + if is_trial_mode: + for k in range(n_outer): + trial_k_score = float(success_arr[:, k, :].mean()) + final[f"trial_{k}_score"] = trial_k_score + t0 = float(success_arr[:, 0, :].mean()) + for k in range(1, n_outer): + final[f"ada_delta_trial_{k}_minus_0"] = float(success_arr[:, k, :].mean()) - t0 + return final diff --git a/pufferlib/ocean/drive/adaptive.py b/pufferlib/ocean/drive/adaptive.py index 7af77c62db..46d690566e 100644 --- a/pufferlib/ocean/drive/adaptive.py +++ b/pufferlib/ocean/drive/adaptive.py @@ -20,4 +20,17 @@ def __init__(self, **kwargs): kwargs["resample_frequency"] = self.k_scenarios * self.scenario_length self.episode_length = kwargs["resample_frequency"] + # Under GOAL_TRIAL: k_scenarios IS the trial count, scenario_length IS + # per-trial-timeout. No fallback to INI defaults. Tests that need a + # custom trial budget should override k_scenarios + scenario_length + # directly. + if int(kwargs.get("goal_behavior", 0)) == 3: + assert self.k_scenarios <= 8, ( + f"k_scenarios={self.k_scenarios} > 8 not supported under goal_behavior=3 " + f"(trial_k_goal_reached[] is fixed at N_TRIAL_K_SLOTS=8 in drive.h). " + f"Bump that array + N_TRIAL_K_SLOTS or use k_scenarios <= 8." 
+ ) + kwargs["max_trials_per_episode"] = self.k_scenarios + kwargs["per_trial_timeout"] = self.scenario_length + super().__init__(**kwargs) diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 5eaed37b0a..1615f235a9 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -65,6 +65,39 @@ static int my_put(Env *env, PyObject *args, PyObject *kwargs) { return 1; } env->terminals = PyArray_DATA(terminals); + // env->truncations is wired from positional args by env_binding.h's + // env_init handler (zero-copy view of the PufferLib SHM buffer). + + // trial_ended_this_step is OPTIONAL — older callers may not pass it. + // Defaults to NULL; c_step's memset is guarded. + PyObject *trial = PyDict_GetItemString(kwargs, "trial_ended_this_step"); + if (trial != NULL) { + if (!PyObject_TypeCheck(trial, &PyArray_Type)) { + PyErr_SetString(PyExc_TypeError, "trial_ended_this_step must be a NumPy array"); + return 1; + } + PyArrayObject *trial_arr = (PyArrayObject *)trial; + if (!PyArray_ISCONTIGUOUS(trial_arr) || PyArray_NDIM(trial_arr) != 1) { + PyErr_SetString(PyExc_ValueError, "trial_ended_this_step must be 1D contiguous"); + return 1; + } + env->trial_ended_this_step = PyArray_DATA(trial_arr); + } + // removed (per-agent off-map flag, B''). Same pattern as + // trial_ended_this_step: C is the only writer; Python reads. 
+ PyObject *removed_obj = PyDict_GetItemString(kwargs, "removed"); + if (removed_obj != NULL) { + if (!PyObject_TypeCheck(removed_obj, &PyArray_Type)) { + PyErr_SetString(PyExc_TypeError, "removed must be a NumPy array"); + return 1; + } + PyArrayObject *removed_arr = (PyArrayObject *)removed_obj; + if (!PyArray_ISCONTIGUOUS(removed_arr) || PyArray_NDIM(removed_arr) != 1) { + PyErr_SetString(PyExc_ValueError, "removed must be 1D contiguous"); + return 1; + } + env->removed = PyArray_DATA(removed_arr); + } return 0; } @@ -106,6 +139,17 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->reward_vel_align = (float)unpack(kwargs, "reward_vel_align"); env->scenario_length = conf.scenario_length; + // GOAL_TRIAL config (only used when goal_behavior == GOAL_TRIAL). + env->max_trials_per_episode = 2; + env->per_trial_timeout = conf.scenario_length; + if (kwargs && PyDict_GetItemString(kwargs, "max_trials_per_episode")) { + env->max_trials_per_episode = (int)unpack(kwargs, "max_trials_per_episode"); + } + if (kwargs && PyDict_GetItemString(kwargs, "per_trial_timeout")) { + int v = (int)unpack(kwargs, "per_trial_timeout"); + if (v > 0) env->per_trial_timeout = v; // 0 means "use default" (scenario_length) + } + env->termination_mode = conf.termination_mode; env->collision_behavior = conf.collision_behavior; env->offroad_behavior = conf.offroad_behavior; @@ -197,6 +241,37 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->map_name = strdup(map_file); env->init_steps = init_steps; env->timestep = init_steps; + + // trial_ended_this_step is OPTIONAL. NULL is safe (c_step's memset is guarded). 
+ env->trial_ended_this_step = NULL; + PyObject *trial = PyDict_GetItemString(kwargs, "trial_ended_this_step"); + if (trial != NULL) { + if (!PyObject_TypeCheck(trial, &PyArray_Type)) { + PyErr_SetString(PyExc_TypeError, "trial_ended_this_step must be a NumPy array"); + return -1; + } + PyArrayObject *trial_arr = (PyArrayObject *)trial; + if (!PyArray_ISCONTIGUOUS(trial_arr) || PyArray_NDIM(trial_arr) != 1) { + PyErr_SetString(PyExc_ValueError, "trial_ended_this_step must be 1D contiguous"); + return -1; + } + env->trial_ended_this_step = PyArray_DATA(trial_arr); + } + env->removed = NULL; + PyObject *removed_obj = PyDict_GetItemString(kwargs, "removed"); + if (removed_obj != NULL) { + if (!PyObject_TypeCheck(removed_obj, &PyArray_Type)) { + PyErr_SetString(PyExc_TypeError, "removed must be a NumPy array"); + return -1; + } + PyArrayObject *removed_arr = (PyArrayObject *)removed_obj; + if (!PyArray_ISCONTIGUOUS(removed_arr) || PyArray_NDIM(removed_arr) != 1) { + PyErr_SetString(PyExc_ValueError, "removed must be 1D contiguous"); + return -1; + } + env->removed = PyArray_DATA(removed_arr); + } + init(env); return 0; } @@ -217,5 +292,27 @@ static int my_log(PyObject *dict, Log *log) { assign_to_dict(dict, "goals_reached_this_episode", log->goals_reached_this_episode); assign_to_dict(dict, "speed_at_goal", log->speed_at_goal); // assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error); + + // GOAL_TRIAL metrics (zero under other goal_behavior). 
+ assign_to_dict(dict, "n_trials_completed", log->n_trials_completed); + assign_to_dict(dict, "n_trials_goal_reached", log->n_trials_goal_reached); + assign_to_dict(dict, "n_trials_timed_out", log->n_trials_timed_out); + if (log->n_trials_completed > 0.0f) { + assign_to_dict(dict, "trial_mean_length", log->trial_total_length / log->n_trials_completed); + assign_to_dict(dict, "trial_goal_reach_rate", log->n_trials_goal_reached / log->n_trials_completed); + } else { + assign_to_dict(dict, "trial_mean_length", 0.0f); + assign_to_dict(dict, "trial_goal_reach_rate", 0.0f); + } + // Per-trial-index success rate (GOAL_TRIAL only). n_trials_completed is + // the gate: it's only non-zero under GOAL_TRIAL, so gb=0/1/2 won't leak + // these keys into wandb / eval output. + if (log->n_trials_completed > 0.0f) { + char key[32]; + for (int k = 0; k < N_TRIAL_K_SLOTS; k++) { + snprintf(key, sizeof(key), "trial_%d_score", k); + assign_to_dict(dict, key, log->trial_k_goal_reached[k]); + } + } return 0; } diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 895604077f..d1f3ee7e54 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -116,6 +116,7 @@ #define GOAL_RESPAWN 0 #define GOAL_GENERATE_NEW 1 #define GOAL_STOP 2 +#define GOAL_TRIAL 3 // up to max_trials_per_episode trials; ends on goal or per-trial timeout #define PARTNER_FEATURES 7 @@ -197,7 +198,16 @@ struct Log { float avg_goal_weight; float avg_entropy_weight; float avg_discount_weight; + // Per-trial metrics (GOAL_TRIAL only). All zero under other goal_behavior. + float n_trials_completed; + float n_trials_goal_reached; + float n_trials_timed_out; + float trial_total_length; // running sum, divided by n_trials_completed in my_log (binding.c) + // Per-trial-index goal-reach counters. After vec_log normalization, each + // slot IS trial_K_score. k_scenarios > 8 isn't supported for this metric.
+ float trial_k_goal_reached[8]; }; +#define N_TRIAL_K_SLOTS 8 typedef struct Entity Entity; struct Entity { @@ -243,6 +253,9 @@ struct Entity { float goals_reached_this_episode; float goals_sampled_this_episode; int current_goal_reached; + int collided_this_trial; // GOAL_TRIAL only: 1 if any collision/offroad this trial + int trial_count; // GOAL_TRIAL only: trials completed this episode + int trial_start_timestep; // GOAL_TRIAL only: tick when current trial began int active_agent; float cumulative_displacement; int displacement_sample_count; @@ -333,6 +346,14 @@ struct Drive { float *actions; float *rewards; unsigned char *terminals; + unsigned char *trial_ended_this_step; // GOAL_TRIAL: per-agent trial-boundary flag + unsigned char *truncations; // GOAL_TRIAL: trial-end bootstrap-stop signal + unsigned char *removed; // GOAL_TRIAL B'': per-agent off-map flag + // Env-level trial state (GOAL_TRIAL B''). All egos in this env share one + // trial clock; trial-end fires when all egos have removed=1 or timeout. + int env_trial_count; + int env_trial_start_timestep; + int env_episode_ended; // 1 after episode end (Option D); cleared by c_reset Log log; Log *logs; int num_agents; @@ -382,6 +403,9 @@ struct Drive { int init_mode; int control_mode; + int max_trials_per_episode; // GOAL_TRIAL: max trials per episode (default 2) + int per_trial_timeout; // GOAL_TRIAL: ticks per trial (default scenario_length) + // Reward conditioning bool use_rc; float collision_weight_lb; @@ -417,6 +441,91 @@ struct Drive { // "render". }; +// Per-agent variant of add_log used at GOAL_TRIAL episode end. Can't reuse +// add_log because it assumes a synchronized scenario boundary; under gb=3 +// each agent's episode ends at its own trial_count == max_trials. 
+void add_log_one_agent(Drive *env, int i) { + Entity *e = &env->entities[env->active_agent_indices[i]]; + + if (e->is_ego) { + // BUG-FIX: these increments were OUTSIDE this guard before, so + // co-player goal counts contaminated the ego aggregate. Now gated + // on is_ego so env->log.goals_* reflect ego progress only. + env->log.goals_reached_this_episode += e->goals_reached_this_episode; + env->log.goals_sampled_this_episode += e->goals_sampled_this_episode; + + int offroad = env->logs[i].offroad_rate; + int collided = env->logs[i].collision_rate; + env->log.offroad_rate += offroad; + env->log.collision_rate += collided; + env->log.offroad_per_agent += env->logs[i].offroad_per_agent; + env->log.collisions_per_agent += env->logs[i].collisions_per_agent; + env->log.lane_alignment_rate += env->logs[i].lane_alignment_rate; + env->log.speed_at_goal += env->logs[i].speed_at_goal; + env->log.episode_length += env->logs[i].episode_length; + env->log.episode_return += env->logs[i].episode_return; + env->log.active_agent_count += env->active_agent_count; + env->log.expert_static_agent_count += env->expert_static_agent_count; + env->log.static_agent_count += env->static_agent_count; + + // Score is accumulated per-trial in c_step's trial-end loop: + // each clean trial (goal reached + no collision/offroad this trial) + // contributes 1/max_trials_per_episode. So score ∈ [0, 1] per ego. + // dnf_rate keeps episode-level "did not finish all trials cleanly". + float denom = (float)env->max_trials_per_episode; + float frac = (denom > 0.0f) ? 
e->goals_reached_this_episode / denom : 0.0f; + if (!offroad && !collided && frac < 1.0f) env->log.dnf_rate += 1.0f; + env->log.n += 1.0f; + } + + if (e->is_co_player && env->co_player_logs != NULL) { + int co_offroad = env->co_player_logs[i].offroad_rate; + int co_collided = env->co_player_logs[i].collision_rate; + env->co_player_log.offroad_rate += co_offroad; + env->co_player_log.collision_rate += co_collided; + env->co_player_log.offroad_per_agent += env->co_player_logs[i].offroad_per_agent; + env->co_player_log.collisions_per_agent += env->co_player_logs[i].collisions_per_agent; + env->co_player_log.lane_alignment_rate += env->co_player_logs[i].lane_alignment_rate; + env->co_player_log.speed_at_goal += env->co_player_logs[i].speed_at_goal; + env->co_player_log.episode_length += env->co_player_logs[i].episode_length; + env->co_player_log.episode_return += env->co_player_logs[i].episode_return; + + // Same per-trial denominator fix as the ego branch above. + float co_denom = (float)env->max_trials_per_episode; + float co_frac = (co_denom > 0.0f) ? e->goals_reached_this_episode / co_denom : 0.0f; + float co_threshold = 0.99f; + if (env->max_trials_per_episode == 2) co_threshold = 0.5f; + else if (env->max_trials_per_episode < 5) co_threshold = 0.8f; + else co_threshold = 0.9f; + if (co_frac > co_threshold && !co_collided) env->co_player_log.score += 1.0f; + if (!co_offroad && !co_collided && co_frac < 1.0f) env->co_player_log.dnf_rate += 1.0f; + env->co_player_log.n += 1.0f; + } + + // Mirror EVERY per-entity field c_reset clears. c_reset is bypassed under + // gb=3 (no scenario-length early-return); stale state would carry to the + // next episode (e.g. respawn_timestep stuck != -1 hides ego in renders). 
+ env->logs[i] = (Log){0}; + if (env->population_play && env->co_player_logs != NULL) env->co_player_logs[i] = (Log){0}; + e->goals_reached_this_episode = 0.0f; + e->goals_sampled_this_episode = 1.0f; + e->collided_before_goal = 0; + e->current_goal_reached = 0; + e->respawn_timestep = -1; + e->respawn_count = 0; + e->stopped = 0; + // Don't reset `removed`: Option D sets it AFTER this call so the agent + // idles until resample_frequency. c_reset is what clears it. + e->metrics_array[COLLISION_IDX] = 0.0f; + e->metrics_array[OFFROAD_IDX] = 0.0f; + e->metrics_array[REACHED_GOAL_IDX] = 0.0f; + e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; + e->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; + e->metrics_array[LANE_ANGLE_IDX] = 0.0f; + e->current_lane_idx = -1; + e->current_lane_geometry_idx = -1; +} + void add_log(Drive *env) { for (int i = 0; i < env->active_agent_count; i++) { Entity *e = &env->entities[env->active_agent_indices[i]]; @@ -730,6 +839,9 @@ void set_start_position(Drive *env) { e->stopped = 0; e->removed = 0; e->respawn_count = 0; + e->trial_count = 0; + e->trial_start_timestep = 0; + e->collided_this_trial = 0; // Dynamics e->a_long = 0.0f; @@ -1032,6 +1144,20 @@ void set_means(Drive *env) { void move_expert(Drive *env, float *actions, int agent_idx) { Entity *agent = &env->entities[agent_idx]; int t = env->timestep; + // GOAL_TRIAL B'': humans replay on the env's trial clock so they reset to + // frame 0 at every env trial-end. Visual consequence: if humans have + // late-valid windows (enter scene at frame 30+) and trials are short + // (ego reaches goal fast), humans may not appear in those trials — + // that's the data, not a bug. We don't care about what humans do during + // a trial, only about ego-side strict trial-equivalence. + if (env->goal_behavior == GOAL_TRIAL && agent->array_size > 0) { + // Recording frame = init_steps + ticks-since-trial-start. 
Matches what + // set_start_position uses at c_reset (init_steps) and advances from + // there. Every trial begins at the same recording frame as trial 1. + t = env->init_steps + (env->timestep - env->env_trial_start_timestep); + t = t % agent->array_size; + if (t < 0) t += agent->array_size; + } if (t < 0 || t >= agent->array_size) { agent->x = INVALID_POSITION; agent->y = INVALID_POSITION; @@ -1562,6 +1688,10 @@ void compute_agent_metrics(Drive *env, int agent_idx) { collided = VEHICLE_COLLISION; agent->collision_state = collided; + if (collided != 0) { + // GOAL_TRIAL per-trial clean-success tracker. Cleared at trial-end. + agent->collided_this_trial = 1; + } if (collided == VEHICLE_COLLISION) { if (env->collision_behavior == STOP_AGENT && !agent->stopped) { @@ -1955,6 +2085,7 @@ void allocate(Drive *env) { env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float)); env->rewards = (float *)calloc(env->active_agent_count, sizeof(float)); env->terminals = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char)); + env->trial_ended_this_step = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char)); } void free_allocated(Drive *env) { @@ -1962,6 +2093,7 @@ void free_allocated(Drive *env) { free(env->actions); free(env->rewards); free(env->terminals); + free(env->trial_ended_this_step); // Always free weight arrays free(env->collision_weights); @@ -2480,6 +2612,9 @@ void sample_new_goal(Drive *env, int agent_idx) { void c_reset(Drive *env) { env->timestep = env->init_steps; + env->env_trial_count = 0; + env->env_trial_start_timestep = env->timestep; + env->env_episode_ended = 0; set_start_position(env); for (int i = 0; i < env->active_agent_count; i++) { @@ -2515,6 +2650,9 @@ void c_reset(Drive *env) { env->entities[agent_idx].current_lane_geometry_idx = -1; env->entities[agent_idx].stopped = 0; env->entities[agent_idx].removed = 0; + if (env->removed != NULL) env->removed[x] = 0; + env->entities[agent_idx].trial_count 
= 0; + env->entities[agent_idx].trial_start_timestep = env->init_steps; if (env->goal_behavior == GOAL_GENERATE_NEW) { env->entities[agent_idx].goal_position_x = env->entities[agent_idx].init_goal_x; @@ -2530,13 +2668,20 @@ void c_reset(Drive *env) { } void respawn_agent(Drive *env, int agent_idx) { - env->entities[agent_idx].x = env->entities[agent_idx].traj_x[0]; - env->entities[agent_idx].y = env->entities[agent_idx].traj_y[0]; - env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[0]; + // Use the same starting frame as c_reset's set_start_position (init_steps). + // Pre-fix this used traj[0], which broke trial-mode strict equivalence: + // trial 1 starts at traj[init_steps] (via set_start_position), trial 2..K + // would start at traj[0] via respawn_agent. Now both use init_steps. + int step = env->init_steps; + if (step >= env->entities[agent_idx].array_size) step = env->entities[agent_idx].array_size - 1; + if (step < 0) step = 0; + env->entities[agent_idx].x = env->entities[agent_idx].traj_x[step]; + env->entities[agent_idx].y = env->entities[agent_idx].traj_y[step]; + env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[step]; env->entities[agent_idx].heading_x = cosf(env->entities[agent_idx].heading); env->entities[agent_idx].heading_y = sinf(env->entities[agent_idx].heading); - env->entities[agent_idx].vx = env->entities[agent_idx].traj_vx[0]; - env->entities[agent_idx].vy = env->entities[agent_idx].traj_vy[0]; + env->entities[agent_idx].vx = env->entities[agent_idx].traj_vx[step]; + env->entities[agent_idx].vy = env->entities[agent_idx].traj_vy[step]; env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; @@ -2555,11 +2700,26 @@ void respawn_agent(Drive *env, int agent_idx) { env->entities[agent_idx].jerk_long = 0.0f; env->entities[agent_idx].jerk_lat = 0.0f; 
env->entities[agent_idx].steering_angle = 0.0f; + // Allow the next trial (GOAL_TRIAL) to register a fresh goal-reach event. + // Without this, the trial-end gate at the start of c_step's goal-reach + // block (`!current_goal_reached`) stays false forever after the first + // success, suppressing all subsequent trial-end goal_weight rewards. + env->entities[agent_idx].current_goal_reached = 0; } void c_step(Drive *env) { memset(env->rewards, 0, env->active_agent_count * sizeof(float)); memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char)); + if (env->trial_ended_this_step != NULL) { + memset(env->trial_ended_this_step, 0, env->active_agent_count * sizeof(unsigned char)); + } + // C owns truncations under GOAL_TRIAL. Zero at top of step; write 1 at + // each trial boundary inside the GOAL_TRIAL branch. Under non-trial + // modes Python may still write truncations directly (e.g. k_eff + // curriculum), so leave the buffer alone there. + if (env->goal_behavior == GOAL_TRIAL && env->truncations != NULL) { + memset(env->truncations, 0, env->active_agent_count * sizeof(unsigned char)); + } env->timestep++; @@ -2573,7 +2733,8 @@ void c_step(Drive *env) { } } - if (env->timestep == env->scenario_length || (!originals_remaining && env->termination_mode == 1)) { + if (env->goal_behavior != GOAL_TRIAL && + (env->timestep == env->scenario_length || (!originals_remaining && env->termination_mode == 1))) { add_log(env); c_reset(env); return; @@ -2664,6 +2825,16 @@ void c_step(Drive *env) { bool within_distance = distance_to_goal < env->goal_radius; bool within_speed = current_speed <= env->goal_speed; + // Goal-reach block. Invariant: `goals_reached_this_episode` is + // incremented at most ONCE per (agent, trial-or-scenario), gated by + // `current_goal_reached`. 
The flag is cleared by: + // - respawn_agent (GOAL_TRIAL mid-episode respawn, GOAL_RESPAWN's + // ghost-respawn) + // - c_reset (GOAL_STOP / scenario boundary) + // - add_log_one_agent (GOAL_TRIAL episode boundary) + // GOAL_RESPAWN's ghost-reward path (respawn_timestep != -1) does NOT + // increment: that reward fires every step the ghost is in radius, by + // design. Only the FIRST goal-reach pre-ghost counts as a "trial succeeded." if (within_distance && within_speed && !env->entities[agent_idx].current_goal_reached) { if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1) { float scaled_post_respawn_reward = env->reward_goal_post_respawn * env->goal_weights[i]; @@ -2687,7 +2858,7 @@ void c_step(Drive *env) { sample_new_goal(env, agent_idx); env->entities[agent_idx].current_goal_reached = 0; env->entities[agent_idx].goals_reached_this_episode += 1.0f; - } else { // Zero out the velocity so that the agent stops at the goal + } else { // GOAL_STOP or GOAL_TRIAL env->rewards[i] = env->goal_weights[i]; if (is_ego) { @@ -2696,9 +2867,23 @@ void c_step(Drive *env) { env->co_player_logs[i].episode_return = env->goal_weights[i]; } - env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; env->entities[agent_idx].goals_reached_this_episode += 1.0f; + env->entities[agent_idx].current_goal_reached = 1; + + if (env->goal_behavior == GOAL_TRIAL) { + // B'': go off-map, wait for env trial-end (sync reset). + env->entities[agent_idx].removed = 1; + if (env->removed != NULL) env->removed[i] = 1; + env->entities[agent_idx].x = INVALID_POSITION; + env->entities[agent_idx].y = INVALID_POSITION; + env->entities[agent_idx].vx = 0.0f; + env->entities[agent_idx].vy = 0.0f; + } else { + // GOAL_STOP: freeze in place, collidable. 
+ env->entities[agent_idx].stopped = 1; + env->entities[agent_idx].vx = 0.0f; + env->entities[agent_idx].vy = 0.0f; + } } env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f; @@ -2759,6 +2944,81 @@ void c_step(Drive *env) { env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; } } + } else if (env->goal_behavior == GOAL_TRIAL && !env->env_episode_ended) { + // B'': env-level trial. All egos share one clock. Trial-end fires when + // ALL active egos are off-map (removed=1, set by goal-reach branch) OR + // env's per_trial_timeout has elapsed. + int total_egos = 0; + int reached_egos = 0; + for (int i = 0; i < env->active_agent_count; i++) { + Entity *e = &env->entities[env->active_agent_indices[i]]; + if (!e->is_ego) continue; + total_egos++; + if (e->removed) reached_egos++; + } + bool all_egos_done = (total_egos > 0) && (reached_egos == total_egos); + bool env_timeout = (env->timestep - env->env_trial_start_timestep) >= env->per_trial_timeout; + if (all_egos_done || env_timeout) { + int k = env->env_trial_count; // index of the trial that just ended + int trial_len = env->timestep - env->env_trial_start_timestep; + env->env_trial_count++; + bool is_episode_end = (env->env_trial_count >= env->max_trials_per_episode); + + for (int i = 0; i < env->active_agent_count; i++) { + int agent_idx = env->active_agent_indices[i]; + Entity *e = &env->entities[agent_idx]; + + if (env->trial_ended_this_step != NULL) env->trial_ended_this_step[i] = 1; + if (env->truncations != NULL) env->truncations[i] = 1; + + if (e->is_ego) { + env->log.n_trials_completed += 1.0f; + env->log.trial_total_length += (float)trial_len; + if (e->current_goal_reached) { + env->log.n_trials_goal_reached += 1.0f; + if (!e->collided_this_trial) { + // Clean trial success: per-trial score = 1. + // Contributes 1/max_trials_per_episode to episode score. 
+ env->log.score += 1.0f / (float)env->max_trials_per_episode; + if (k >= 0 && k < N_TRIAL_K_SLOTS) + env->log.trial_k_goal_reached[k] += 1.0f; + } + } else { + env->log.n_trials_timed_out += 1.0f; + } + e->collided_this_trial = 0; // reset for the next trial + } + + if (is_episode_end) { + // Option D: idle off-grid until c_reset. + env->terminals[i] = 1; + add_log_one_agent(env, i); + e->removed = 1; + if (env->removed != NULL) env->removed[i] = 1; + e->x = INVALID_POSITION; + e->y = INVALID_POSITION; + e->vx = 0.0f; + e->vy = 0.0f; + } else { + // Trial-end (not episode): per-entity trial-mode flags. + // Full position / velocity / metric reset happens via + // set_start_position below. + e->current_goal_reached = 0; + e->removed = 0; + if (env->removed != NULL) env->removed[i] = 0; + } + } + if (is_episode_end) { + env->env_trial_count = 0; + env->env_episode_ended = 1; + } else { + // Reset ALL entities (active + static) to the same initial + // state as c_reset's set_start_position so the next trial + // is bit-for-bit identical to trial 1. + set_start_position(env); + } + env->env_trial_start_timestep = env->timestep; + } } compute_observations(env); @@ -2840,7 +3100,7 @@ static void start_video_recorder(Client *client, const char *basename) { for (int fd = 3; fd < 256; fd++) { close(fd); } - execlp("ffmpeg", "ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", size_str, "-r", "30", "-i", "-", + execlp("ffmpeg", "ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", size_str, "-r", "10", "-i", "-", "-c:v", "libx264", "-threads", "4", "-pix_fmt", "yuv420p", "-preset", "ultrafast", "-crf", "23", "-loglevel", "error", filename, NULL); fprintf(stderr, "Failed to exec ffmpeg\n"); @@ -3614,8 +3874,18 @@ void c_render_with_mode(Drive *env, int view_mode, int draw_traces, int current_ EndMode3D(); } - // Draw scenario counter overlay (2D text on top of 3D scene) - if (k_scenarios > 1) { + // Draw scenario/trial counter overlay. 
Under gb=3 B'' we read the + // env-level trial counter (per-entity trial_count is no longer + // updated). Clamp to max so the last-tick "just incremented" value + // doesn't show as K+1. + if (env->goal_behavior == GOAL_TRIAL && env->max_trials_per_episode > 1) { + int trial_n = env->env_trial_count + 1; + if (trial_n > env->max_trials_per_episode) trial_n = env->max_trials_per_episode; + char trial_text[64]; + snprintf(trial_text, sizeof(trial_text), "Trial %d / %d", + trial_n, env->max_trials_per_episode); + DrawText(trial_text, 40, 40, 120, WHITE); + } else if (k_scenarios > 1) { char scenario_text[64]; snprintf(scenario_text, sizeof(scenario_text), "Scenario %d / %d", current_scenario + 1, k_scenarios); DrawText(scenario_text, 40, 40, 120, WHITE); diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 357d37649c..6d14c6ee57 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -34,6 +34,8 @@ def __init__( reward_lane_align=0.0, # GIGAFLOW lane alignment reward (0 = disabled) reward_vel_align=1.0, # Velocity alignment coefficient for lane reward goal_behavior=0, + max_trials_per_episode=2, # GOAL_TRIAL only + per_trial_timeout=None, # GOAL_TRIAL only; None → C defaults to scenario_length goal_target_distance=10.0, goal_radius=2.0, goal_speed=20.0, @@ -101,6 +103,8 @@ def __init__( self.goal_radius = goal_radius self.goal_speed = goal_speed self.goal_behavior = goal_behavior + self.max_trials_per_episode = max_trials_per_episode + self.per_trial_timeout = per_trial_timeout self.goal_target_distance = goal_target_distance self.collision_behavior = collision_behavior self.offroad_behavior = offroad_behavior @@ -420,7 +424,27 @@ def __init__( self.co_player_device = torch.device("cpu") self._set_co_player_state() + # B'' off-map flag. C writes 1 when an ego reaches goal mid-trial + # (entity goes off-map); 0 when env trial-end resets the world. 
+ # pufferl reads this via vecenv.removed to mask off-map slots in + # the KV cache attention. Sourced from buf["removed"] when the vec + # backend (Multiprocessing or Serial) allocates SHM for it; falls + # back to a private numpy array for standalone use. + _removed_external = buf["removed"] if (buf is not None and "removed" in buf) else None + super().__init__(buf=buf) + + # Per-trial-boundary flag. C writes 1 at env trial-end under gb=3; + # Python reads. See docs/src/trial_mode.md. + self.trial_ended_this_step = np.zeros(self.num_agents, dtype=bool) + if _removed_external is not None: + assert _removed_external.shape == (self.num_agents,), ( + f"buf['removed'] shape {_removed_external.shape} != ({self.num_agents},)" + ) + self.removed = _removed_external + else: + self.removed = np.zeros(self.num_agents, dtype=bool) + if self.population_play: self.action_space = pufferlib.spaces.joint_space(self.single_action_space, self.num_ego_agents) co_player_atn_space = pufferlib.spaces.joint_space(self.single_action_space, self.num_co_players) @@ -467,6 +491,10 @@ def __init__( goal_radius=goal_radius, goal_speed=goal_speed, goal_behavior=self.goal_behavior, + max_trials_per_episode=self.max_trials_per_episode, + per_trial_timeout=( + int(self.per_trial_timeout) if self.per_trial_timeout is not None else 0 + ), goal_target_distance=self.goal_target_distance, collision_behavior=self.collision_behavior, offroad_behavior=self.offroad_behavior, @@ -500,6 +528,8 @@ def __init__( control_mode=self.control_mode, map_dir=map_dir, render_mode=self._render_mode_int, + trial_ended_this_step=self.trial_ended_this_step[cur:nxt], + removed=self.removed[cur:nxt], ) env_ids.append(env_id) @@ -906,6 +936,10 @@ def _reinit_envs_with_new_maps(self): reward_offroad_collision=self.reward_offroad_collision, goal_radius=self.goal_radius, goal_behavior=self.goal_behavior, + max_trials_per_episode=self.max_trials_per_episode, + per_trial_timeout=( + int(self.per_trial_timeout) if 
self.per_trial_timeout is not None else 0 + ), collision_behavior=self.collision_behavior, offroad_behavior=self.offroad_behavior, reward_goal=self.reward_goal, @@ -943,6 +977,8 @@ def _reinit_envs_with_new_maps(self): control_mode=self.control_mode, map_dir=self.map_dir, render_mode=self._render_mode_int, + trial_ended_this_step=self.trial_ended_this_step[cur:nxt], + removed=self.removed[cur:nxt], ) env_ids.append(env_id) self.c_envs = binding.vectorize(*env_ids) @@ -976,6 +1012,20 @@ def _aggregate_scenario_metrics(self, scenario_infos): return aggregated + def _inject_trial_deltas(self, log): + """Under goal_behavior=3, fill in ada_delta_trial_K_minus_0 keys from + the per-trial-index trial_K_score values the C side just emitted. + Mutates `log` in place. Stops at the first k > 0 whose trial_k_score + key is missing from `log`; a missing trial_0_score is treated as 0.0. + """ + k_max = self.max_trials_per_episode + trial_0 = log.get("trial_0_score", 0.0) + for k in range(1, k_max): + key = f"trial_{k}_score" + if key not in log: + break + log[f"ada_delta_trial_{k}_minus_0"] = log[key] - trial_0 + def _compute_delta_metrics(self): """Compute delta metrics between first and last scenario.""" if len(self.scenario_metrics) < 2: @@ -1009,15 +1059,18 @@ def _compute_delta_metrics(self): def step(self, actions): self.terminals[:] = 0 + # Under gb=3, C owns both `truncations` and `trial_ended_this_step`: + # zeroes them at top of c_step and writes 1 at each trial boundary. + # Under non-trial modes, Python still owns `truncations` (k_eff + # curriculum below writes it directly), so zero here only if non-3.
+ if self.goal_behavior != 3: + self.truncations[:] = 0 self.actions[self.ego_ids] = actions if self.population_play and not self.external_co_player_actions: co_player_actions = self.get_co_player_actions() self.actions[self.co_player_ids] = co_player_actions - # When external_co_player_actions=True, the main process has already - # written co-player actions into self.actions[co_player_ids] via the - # shared-memory action buffer; nothing to do here. binding.vec_step(self.c_envs) if self.reward_only_last_scenario and self.current_scenario != self.k_scenarios - 1: @@ -1031,6 +1084,12 @@ def step(self, actions): if self.tick % self.report_interval == 0: log = binding.vec_log(self.c_envs, self.num_agents) if log: + # Under GOAL_TRIAL: derive ada_delta_trial_K_minus_0 from the + # per-trial-index success rates the C side now emits as + # trial_K_score. Surfaces in wandb every report_interval; no + # need to wait for eval-time HumanReplayEvaluator. + if self.goal_behavior == 3: + self._inject_trial_deltas(log) if self.adaptive_driving_agent: self.current_scenario_infos.append(log) # For training: only report 0-shot (scenario 0) metrics @@ -1051,7 +1110,12 @@ def step(self, actions): info.append(self._pending_k_eff_log) self._pending_k_eff_log = None - if self.tick % self.scenario_length == 0: + # Per-scenario block (gb != 3 only): every scenario_length ticks, + # aggregate per-scenario metrics, rotate partner / maps. Under gb=3 + # trial boundaries are variable-length so this fixed-time block would + # land mid-trial; metrics flow through add_log_one_agent instead. 
+ run_per_scenario_block = self.tick % self.scenario_length == 0 and self.goal_behavior != 3 + if run_per_scenario_block: if self.adaptive_driving_agent and self.current_scenario_infos: scenario_log = self._aggregate_scenario_metrics(self.current_scenario_infos) scenario_log["scenario_id"] = self.current_scenario @@ -1128,11 +1192,20 @@ def step(self, actions): self.truncations[self.ego_ids] = 1 self.terminals[self.ego_ids] = 1 + # Map-rotation boundary. Option D's idle-after-max_trials prevents the + # 1-map-many-short-episodes pathology that motivated rotating on every + # terminals.any() (250ms/call × ~10 calls/sec was infeasible). if self.tick > 0 and self.resample_frequency > 0 and self.tick % self.resample_frequency == 0: + # Force-flush env->log under gb=3 before reinit zeros it. Slow + # agents that didn't finish max_trials this cycle don't bump + # log.n, so the standard vec_log gate may not fire on its own. + if self.goal_behavior == 3: + log = binding.vec_log(self.c_envs, 1) + if log and log.get("n", 0) > 0: + info.append(log) self.tick = 0 will_resample = 1 if will_resample: - # Log deltas before resampling if we're at the end of a cycle if self.adaptive_driving_agent and self.scenario_metrics: delta_metrics = self._compute_delta_metrics() if delta_metrics: diff --git a/pufferlib/ocean/drive/rollout.py b/pufferlib/ocean/drive/rollout.py index aa8d4babf4..91c561c7f7 100644 --- a/pufferlib/ocean/drive/rollout.py +++ b/pufferlib/ocean/drive/rollout.py @@ -112,13 +112,42 @@ def rollout_loop( lstm_c=torch.zeros(num_ego_agents, hidden_size, device=device), ) + # Default max_steps: + # - non-trial: one scenario_length (matches single-episode video budget). + # - GOAL_TRIAL: a full episode budget = max_trials * per_trial_timeout + # (which auto-link sets to k_scenarios * scenario_length). 
Without + # this, the render would cut off after one scenario_length, showing + # only the first trial of an adaptive episode — the whole point of + # trial mode is to see adaptation ACROSS trials in one video. if max_steps is None: - max_steps = getattr(driver, "scenario_length", 91) + goal_behavior = int(getattr(driver, "goal_behavior", 0)) + if goal_behavior == 3: + max_trials = int(getattr(driver, "max_trials_per_episode", 2)) + per_trial = int(getattr(driver, "per_trial_timeout", 0) or 0) + if per_trial <= 0: + per_trial = int(getattr(driver, "scenario_length", 91)) + max_steps = max_trials * per_trial + else: + max_steps = getattr(driver, "scenario_length", 91) + + # Per-trial annotation state. Under GOAL_TRIAL we read driver.trial_ended_this_step + # after each step and bump a per-agent trial counter so the visualizer (or + # downstream caller) knows which trial each frame belongs to. + is_trial_mode = int(getattr(driver, "goal_behavior", 0)) == 3 + n_agents_for_trial = ( + getattr(driver, "num_ego_agents", None) + or env.observation_space.shape[0] + ) + trial_idx = np.zeros(n_agents_for_trial, dtype=np.int32) if is_trial_mode else None + trial_starts = [] # list of (step, agent_idx, new_trial_idx) — useful for video chapter markers + last_print_step = -1 info = [] for step in range(max_steps): - if step % 30 == 0: - print(f"[Python Render] Step {step}/{max_steps}", flush=True) + if step % 30 == 0 and step != last_print_step: + trial_suffix = f" trial=mean_{float(trial_idx.mean()):.1f}" if is_trial_mode else "" + print(f"[Python Render] Step {step}/{max_steps}{trial_suffix}", flush=True) + last_print_step = step # Render BEFORE the step so each frame shows the state the policy was # conditioned on. 
if render_ctx is not None: @@ -152,10 +181,39 @@ def rollout_loop( if isinstance(logits, torch.distributions.Normal): action_np = np.clip(action_np, env.action_space.low, env.action_space.high) - obs, _, _, truncs, info = env.step(action_np) + obs, _, terms, truncs, info = env.step(action_np) + + # Per-trial bookkeeping. trial_ended_this_step is per-agent — when it + # fires we know that agent just started a new trial on the next step, + # so bump its trial_idx. trial_starts collects (step, agent, new_idx) + # tuples that callers can use to overlay trial boundaries on the video. + if is_trial_mode and trial_idx is not None: + te = np.asarray(driver.trial_ended_this_step, dtype=bool) + if population_play and ego_ids is not None: + te = te[ego_ids] if te.shape[0] == env.observation_space.shape[0] else te + te = te[: len(trial_idx)] + if te.any(): + trial_idx[te] += 1 + for a in np.where(te)[0]: + trial_starts.append((step + 1, int(a), int(trial_idx[a]))) + + # Break conditions: + # - non-trial: truncs.all() fires at scenario boundary (env auto-reset path). + # - GOAL_TRIAL: truncs fires on every trial boundary (it now mirrors + # trial_ended_this_step — see drive.py.step). truncs.all() would + # fire whenever all agents end a trial at the same tick, which is + # NOT an episode boundary. Use terminals.all() instead so we + # render the full multi-trial episode. + if is_trial_mode: + if np.asarray(terms).all(): + break + else: + if truncs.all(): + break - # Break when episode ends (truncs.all() is set when the env auto-resets) - if truncs.all(): - break + # Stash trial_starts on the info dict for downstream consumers (renderer + # overlays, video chapter markers). Doesn't change existing info contract. 
+ if is_trial_mode and isinstance(info, list): + info.append({"_trial_starts": trial_starts, "_final_trial_idx": trial_idx.tolist()}) return info diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 343cc96584..a766079d8c 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -128,7 +128,7 @@ static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) { PyErr_SetString(PyExc_ValueError, "Truncations must be 1D"); return NULL; } - // env->truncations = PyArray_DATA(truncations); + env->truncations = PyArray_DATA(truncations); PyObject *seed_arg = PyTuple_GetItem(args, 5); if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) { diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 08e680cd13..68efd77886 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -56,6 +56,38 @@ signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0)) +# ---------------------------------------------------------------------------- +# Trial-mode debug logger. Set PUFFER_TRIAL_DEBUG_FILE=/path/to/log.jsonl to +# capture per-epoch GAE/cache/trial diagnostics as a stream of JSON records. +# No-op otherwise. Schema: +# {"event": "...", "ts": float, "epoch": int, "step": int, ...event-specific fields} +# Events: +# "gae_outer_pre" / "gae_outer_post": stats logged before / after the outer GAE call +# "gae_inner": per-minibatch stats inside the PPO update loop +# "cache_reset": each cache-reset event during rollout (under episode-end) +# "epoch_end": per-epoch summary (loss snapshot, KV-cache health, position range) +# ---------------------------------------------------------------------------- +import json as _json + +_TRIAL_DEBUG_PATH = os.environ.get("PUFFER_TRIAL_DEBUG_FILE", "") +_TRIAL_DEBUG_ENABLED = bool(_TRIAL_DEBUG_PATH) +_TRIAL_DEBUG_FH = None + + +def _trial_debug_log(event, **data): + """Append a JSON line to the trial-debug file.
Cheap when disabled.""" + if not _TRIAL_DEBUG_ENABLED: + return + global _TRIAL_DEBUG_FH + try: + if _TRIAL_DEBUG_FH is None: + _TRIAL_DEBUG_FH = open(_TRIAL_DEBUG_PATH, "a", buffering=1) + rec = {"event": event, "ts": time.time(), **data} + _TRIAL_DEBUG_FH.write(_json.dumps(rec, default=str) + "\n") + except Exception: + pass + + # Assume advantage kernel has been built if CUDA compiler is available ADVANTAGE_CUDA = shutil.which("nvcc") is not None @@ -170,6 +202,11 @@ def __init__(self, config, vecenv, policy, logger=None): self.rewards = torch.zeros(segments, horizon, device=device) self.terminals = torch.zeros(segments, horizon, device=device) self.truncations = torch.zeros(segments, horizon, device=device) + # Per-step per-agent off-map flag (gb=3 B''). Same shape as terminals. + # Used in training to (a) add a garbage-attention mask matching the + # eval-time `garbage_mask`, and (b) gate PPO loss/entropy/value-loss + # so limbo tuples don't contribute gradient. + self.removed_history = torch.zeros(segments, horizon, device=device, dtype=torch.bool) self.ratio = torch.ones(segments, horizon, device=device) self.importance = torch.ones(segments, horizon, device=device) self.ep_lengths = torch.zeros(total_agents, device=device, dtype=torch.int32) @@ -225,6 +262,12 @@ def __init__(self, config, vecenv, policy, logger=None): # prior to that). None initially → first call allocates. self.transformer_k_cache = {i * n: None for i in range(num_chunks)} self.transformer_v_cache = {i * n: None for i in range(num_chunks)} + # B'' garbage_mask: per-agent per-cache-slot bool. The model marks + # current slot True when env.removed[i]=1 (ego off-map). Attention + # then excludes those slots. Lazy-allocated by the model on first + # forward_eval — None here mirrors the k_cache pattern. 
+ self.transformer_garbage_mask = {i * n: None for i in range(num_chunks)} + self.horizon = int(getattr(policy, "horizon", config.get("horizon", 0)) or 0) # Regression detector for the rnn_name plumbing bug — fires once. print( @@ -448,68 +491,48 @@ def _fill_external_co_player_actions(self, full_obs, info, env_id, dones, truncs device = self.config["device"] agents_per_worker = self.vecenv.agents_per_worker - # batch_size > 1 packs multiple workers into one recv; we handle the - # batch_size=1 case (the production setup) here. Generalizing to - # batch_size>1 is straightforward (loop over workers in the batch). - batch_size = self.vecenv.batch_size - if batch_size != 1: - raise NotImplementedError( - f"external_co_player_actions currently only supports batch_size=1; got batch_size={batch_size}." - ) - - # Map env_id back to a worker index so we know which co_player_state - # to use and which row of vecenv.actions to write to. - # For population_play, recv() returns ego-only agent ids - # (vecenv.ego_agent_ids), so divide by the per-worker ego count, not - # the full agent count. ego_agents_per_worker = getattr(self.vecenv, "ego_agents_per_worker", agents_per_worker) - worker_id = int(env_id[0]) // ego_agents_per_worker - # Pull the actual co_player_ids from info (the env knows them). - # Also check for the scenario-boundary cache reset signal. - co_ids = None - reset_cache = False + # Parse per-worker co_ids + reset flags from info (preserves order). + # When vec.batch_size > 1, recv() returns N workers' obs+info stacked, + # so info contains N dicts each with their own _external_co_player_ids. 
+ co_ids_per_worker = [] + reset_flags = [] for item in info: - if isinstance(item, dict): - if "_external_co_player_ids" in item: - co_ids = list(item["_external_co_player_ids"]) - if item.get("_external_reset_co_cache"): - reset_cache = True - if co_ids is None: - raise RuntimeError( - "external_co_player_actions=True but the env did not " - "publish '_external_co_player_ids' in info. Is drive.py up to date?" - ) - if not co_ids: - return # no co-players in this env this step - - # Drop the cache at scenario boundaries — mirrors the per-worker - # OFF path's _reset_co_player_state() which fully reinits state. - # Replacing the dict makes forward_eval lazy-allocate fresh K/V on - # the next call, matching legacy behavior bit-for-bit at scenario - # boundaries. - if reset_cache: - self.co_player_state[worker_id] = {} - - # Slice the co-player observations. When the ego is in oracle mode - # (drive.py `ego_is_oracle=True`) the env's obs is wider than the - # partner policy expects — partner conditioning is appended to ego - # rows only. Strip trailing oracle dims by slicing columns to the - # env's `_c_obs_dim` (the C-side obs width). Defaults to None when - # oracle is off → take the full width as before. + if isinstance(item, dict) and "_external_co_player_ids" in item: + co_ids_per_worker.append(list(item["_external_co_player_ids"])) + reset_flags.append(bool(item.get("_external_reset_co_cache", False))) + n_in_batch = len(co_ids_per_worker) + if n_in_batch == 0: + return + + base_worker_id = int(env_id[0]) // ego_agents_per_worker + worker_ids = [base_worker_id + i for i in range(n_in_batch)] + + for i, w_id in enumerate(worker_ids): + if reset_flags[i]: + self.co_player_state[w_id] = {} + + # Build batched co_obs across workers. cum[] holds per-worker slice + # offsets in the batched tensor for distributing actions afterwards. 
co_obs_width = getattr(self.vecenv.driver_env, "_c_obs_dim", None) - if co_obs_width is None: - co_obs_np = full_obs[co_ids] - else: - co_obs_np = full_obs[co_ids, :co_obs_width] + parts = [] + cum = [0] + for i in range(n_in_batch): + co_ids = co_ids_per_worker[i] + worker_obs = full_obs[i * agents_per_worker : (i + 1) * agents_per_worker] + wco = worker_obs[co_ids] if co_obs_width is None else worker_obs[co_ids, :co_obs_width] + parts.append(wco) + cum.append(cum[-1] + len(co_ids)) + if cum[-1] == 0: + return + + co_obs_np = np.concatenate(parts, axis=0) co_obs = torch.as_tensor(co_obs_np, device=device) if self.co_player_conditioning_dims > 0: - # Pull this worker's conditioning slice from the SHM buffer the - # env wrote at scenario boundaries (or at env init). Insert right - # after the base ego_features — matches drive.py's - # `_add_co_player_conditioning` exactly. - cond_shm = self.vecenv.co_player_conditioning # (num_workers, max_co, cdim) - cond_np = cond_shm[worker_id, : len(co_ids), :] # only the rows we'll use + cond_shm = self.vecenv.co_player_conditioning + cond_parts = [cond_shm[worker_ids[i], : len(co_ids_per_worker[i]), :] for i in range(n_in_batch)] + cond_np = np.concatenate(cond_parts, axis=0) cond = torch.as_tensor(cond_np, device=device, dtype=co_obs.dtype) from pufferlib.ocean.drive import binding as _b @@ -518,25 +541,51 @@ def _fill_external_co_player_actions(self, full_obs, info, env_id, dones, truncs ) co_obs = torch.cat([co_obs[:, :base_ego_dim], cond, co_obs[:, base_ego_dim:]], dim=1) - # NOTE: the OFF (per-worker) path only resets cache at scenario - # boundary or reset(), never for individual done agents. So we - # don't reset per-done here either — it would diverge from OFF. + # Merge per-worker KV caches into one batched cache so the policy + # runs a single forward over all workers' co-players. Caches stay + # per-worker in storage — they're only briefly stacked for the call. 
+ states = [self.co_player_state[w_id] for w_id in worker_ids] + batched_state = {} + if all("k_cache" in s and s["k_cache"] is not None for s in states): + n_layers = len(states[0]["k_cache"]) + batched_state["k_cache"] = [ + torch.cat([s["k_cache"][li] for s in states], dim=0) for li in range(n_layers) + ] + batched_state["v_cache"] = [ + torch.cat([s["v_cache"][li] for s in states], dim=0) for li in range(n_layers) + ] + if all("garbage_mask" in s and s["garbage_mask"] is not None for s in states): + batched_state["garbage_mask"] = torch.cat([s["garbage_mask"] for s in states], dim=0) + if "transformer_position" in states[0]: + batched_state["transformer_position"] = states[0]["transformer_position"] with torch.no_grad(): - logits, _ = self.co_player_policy.forward_eval(co_obs, self.co_player_state[worker_id]) + logits, _ = self.co_player_policy.forward_eval(co_obs, batched_state) + + # Split updated state back to per-worker stores along the batch dim. + for i, w_id in enumerate(worker_ids): + start, end = cum[i], cum[i + 1] + ns = {} + if "k_cache" in batched_state: + ns["k_cache"] = [k[start:end] for k in batched_state["k_cache"]] + ns["v_cache"] = [v[start:end] for v in batched_state["v_cache"]] + if "garbage_mask" in batched_state: + ns["garbage_mask"] = batched_state["garbage_mask"][start:end] + if "transformer_position" in batched_state: + ns["transformer_position"] = batched_state["transformer_position"] + self.co_player_state[w_id] = ns - # Match the per-worker code path: argmax for discrete actions. if isinstance(logits, tuple): co_action = torch.cat([l.argmax(dim=-1, keepdim=True) for l in logits], dim=-1) else: co_action = logits.argmax(dim=-1) - co_action_np = co_action.cpu().numpy().reshape(len(co_ids), -1) + co_action_np = co_action.cpu().numpy().reshape(cum[-1], -1) - # Write directly to the worker's slot in the shared-memory action - # buffer. 
The worker's env.step() will call vec_step using these - # actions because it has external_co_player_actions=True. - co_action_view = self.vecenv.actions[worker_id] # shape (agents_per_worker, *atn_shape) - co_action_view[co_ids] = co_action_np.reshape((len(co_ids),) + co_action_view.shape[1:]) + for i, w_id in enumerate(worker_ids): + co_ids = co_ids_per_worker[i] + start, end = cum[i], cum[i + 1] + co_action_view = self.vecenv.actions[w_id] + co_action_view[co_ids] = co_action_np[start:end].reshape((len(co_ids),) + co_action_view.shape[1:]) def evaluate(self): profile = self.profile @@ -566,6 +615,10 @@ def evaluate(self): self.transformer_v_cache[k] = None self.full_rows = 0 + # Hold autocast active across the whole rollout so torch.compile + # doesn't see autocast GLOBAL_STATE flip between forward_eval calls + # (would recompile every step under reduce-overhead). + self.amp_context.__enter__() while self.full_rows < self.segments: profile("env", epoch) # print(".", end="", flush=True) # Workaround: visible I/O prevents multiprocessing deadlock @@ -607,7 +660,11 @@ def evaluate(self): profile("eval_misc", epoch) env_id = slice(env_id[0], env_id[-1] + 1) - done_mask = d + t # TODO: Handle truncations separately + # KV cache + PE reset gate on `d` (terminals) only. Trial + # boundaries (`t`, truncations) keep the cache so the policy + # adapts across trials within an episode. See + # docs/src/trial_mode.md. + done_mask = d self.global_step += int(mask.sum()) profile("eval_copy", epoch) @@ -617,7 +674,7 @@ def evaluate(self): d = torch.as_tensor(d, device=device) profile("eval_forward", epoch) - with torch.no_grad(), self.amp_context: + with torch.no_grad(): state = dict( reward=r, done=d, @@ -645,6 +702,21 @@ def evaluate(self): # and the policy attends over the full accumulated past. 
state["k_cache"] = self.transformer_k_cache[state_key] state["v_cache"] = self.transformer_v_cache[state_key] + state["garbage_mask"] = self.transformer_garbage_mask[state_key] + # B'' off-map flag. The model uses this to (a) mark the + # current cache slot as garbage in garbage_mask, and + # (b) exclude existing garbage slots from this step's + # attention. Unified flat (num_agents,) view exposed by + # the vec backend: Multiprocessing returns a SHM view + # so worker writes are visible; Serial/native return + # the in-process numpy array. None or all-False if the + # env doesn't expose `removed` (e.g. non-gb=3 modes). + rem_buf = getattr(self.vecenv, "removed", None) + if rem_buf is None: + rem_buf = getattr(self.vecenv.driver_env, "removed", None) + if rem_buf is not None: + rem_np = np.asarray(rem_buf)[env_id] + state["removed"] = torch.as_tensor(rem_np, device=device, dtype=torch.bool) # Note: terminals not needed for eval since we're doing single-step inference # print(".", end="", flush=True) # Prevents multiprocessing deadlock @@ -680,6 +752,7 @@ def evaluate(self): # these if it took the legacy path. self.transformer_k_cache[transformer_key] = state.get("k_cache") self.transformer_v_cache[transformer_key] = state.get("v_cache") + self.transformer_garbage_mask[transformer_key] = state.get("garbage_mask") # Episode-boundary reset. pos is a shared (1,) scalar # across the chunk; cache rows are per-agent. Filter @@ -702,6 +775,22 @@ def evaluate(self): c[valid_indices] = 0 for c in vc: c[valid_indices] = 0 + gm = self.transformer_garbage_mask[transformer_key] + if gm is not None: + gm[valid_indices] = False + if _TRIAL_DEBUG_ENABLED: + # At this point d/t may be torch CUDA tensors + # (converted earlier in this block). Use done_mask + # (still numpy) for the boundary count. 
+ _trial_debug_log( + "cache_reset", + epoch=int(self.epoch), + step=int(self.global_step), + env_id_start=int(env_id.start), + env_id_stop=int(env_id.stop), + n_done=int(len(valid_indices)), + done_mask_sum=int(np.asarray(done_mask).sum()), + ) # Fast path for fully vectorized envs l = self.ep_lengths[env_id.start].item() batch_rows = slice(self.ep_indices[env_id.start].item(), 1 + self.ep_indices[env_id.stop - 1].item()) @@ -715,9 +804,23 @@ def evaluate(self): self.logprobs[batch_rows, l] = logprob self.rewards[batch_rows, l] = r self.terminals[batch_rows, l] = d.float() + # Persist truncations for GAE bootstrap-stop. Stays out of + # state["terminals"] so attention/PE span trial boundaries. + t_tensor = torch.as_tensor(t, device=device).float() + self.truncations[batch_rows, l] = t_tensor self.values[batch_rows, l] = value.flatten() - - # Note: We are not yet handling masks in this version + # Persist per-step `removed` flag for train/eval mask parity. + # During training we (a) add a garbage-attention mask matching + # eval's `garbage_mask`, and (b) gate PPO losses so limbo + # tuples don't contribute gradient. + rem_buf = getattr(self.vecenv, "removed", None) + if rem_buf is None: + rem_buf = getattr(self.vecenv.driver_env, "removed", None) + if rem_buf is not None: + rem_step = torch.as_tensor( + np.asarray(rem_buf)[env_id], device=device, dtype=torch.bool + ) + self.removed_history[batch_rows, l] = rem_step self.ep_lengths[env_id] += 1 # Use appropriate horizon based on model type horizon = ( @@ -749,6 +852,9 @@ def evaluate(self): profile("env", epoch) self.vecenv.send(action) + # Exit the autocast context that wraps the rollout loop. + self.amp_context.__exit__(None, None, None) + profile("eval_misc", epoch) self.free_idx = self.total_agents @@ -802,10 +908,37 @@ def train(self): else: gammas = torch.full((self.segments,), config["gamma"], device=device, dtype=torch.float32) + # GAE bootstrap-stop = terminals ∨ truncations ∨ removed. 
+ # - terminals: episode boundary (full reset) + # - truncations: trial boundary under gb=3 (world resets, KV cache persists) + # - removed: ego is off-map (limbo). V at limbo is computed from + # garbage (INVALID_POSITION) obs; bootstrapping from it would + # poison the prior step's advantage. Treat each limbo slot as + # a value-chain cut. + bootstrap_stop = ( + self.terminals + self.truncations + self.removed_history.float() + ).clamp(max=1.0) + if _TRIAL_DEBUG_ENABLED: + _trial_debug_log( + "gae_outer_pre", + epoch=int(self.epoch), + minibatch=int(mb), + step=int(self.global_step), + terminals_sum=float(self.terminals.sum().item()), + truncations_sum=float(self.truncations.sum().item()), + bootstrap_stop_sum=float(bootstrap_stop.sum().item()), + bootstrap_overlap=float( + torch.minimum(self.terminals, self.truncations).sum().item() + ), + values_mean=float(self.values.mean().item()), + values_std=float(self.values.std().item()), + rewards_mean=float(self.rewards.mean().item()), + rewards_sum=float(self.rewards.sum().item()), + ) advantages = compute_puff_advantage( self.values, self.rewards, - self.terminals, + bootstrap_stop, self.ratio, advantages, gammas, @@ -813,6 +946,20 @@ def train(self): config["vtrace_rho_clip"], config["vtrace_c_clip"], ) + if _TRIAL_DEBUG_ENABLED: + adv_flat = advantages.flatten() + _trial_debug_log( + "gae_outer_post", + epoch=int(self.epoch), + minibatch=int(mb), + step=int(self.global_step), + adv_mean=float(adv_flat.mean().item()), + adv_std=float(adv_flat.std().item()), + adv_min=float(adv_flat.min().item()), + adv_max=float(adv_flat.max().item()), + adv_nan_count=int(torch.isnan(adv_flat).sum().item()), + adv_inf_count=int(torch.isinf(adv_flat).sum().item()), + ) profile("train_copy", epoch) adv = advantages.abs().sum(axis=1) @@ -834,6 +981,7 @@ def train(self): mb_rewards = self.rewards[idx] mb_terminals = self.terminals[idx] mb_truncations = self.truncations[idx] + mb_removed = self.removed_history[idx] # (B, T) bool — 1 = 
limbo step mb_ratio = self.ratio[idx] mb_values = self.values[idx] mb_returns = advantages[idx] + mb_values @@ -861,6 +1009,7 @@ def train(self): state["transformer_context"] = None state["transformer_position"] = None state["terminals"] = mb_terminals # For episode boundary masking + state["removed"] = mb_removed # Train/eval mask parity (gb=3) logits, newvalue = self.policy(mb_obs, state) @@ -884,12 +1033,23 @@ def train(self): newlogprob = newlogprob.reshape(mb_logprobs.shape) logratio = newlogprob - mb_logprobs ratio = logratio.exp() - self.ratio[idx] = ratio.detach() + # Limbo importance ratios are computed from garbage obs / actions + # and would poison the outer GAE's v-trace coefficients on the + # next minibatch. Preserve the existing ratio at limbo positions. + ratio_to_store = ratio.detach() + if mb_removed is not None: + ratio_to_store = torch.where(mb_removed, self.ratio[idx], ratio_to_store) + self.ratio[idx] = ratio_to_store with torch.no_grad(): - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clipfrac = ((ratio - 1.0).abs() > config["clip_coef"]).float().mean() + # Mask limbo steps from diagnostics too so values aren't + # inflated by garbage tuples (mb_removed will be available + # in scope by the time these are reported; safe to reference). 
+ _diag_mask = (~mb_removed).to(logratio.dtype) + _diag_n = _diag_mask.sum().clamp(min=1.0) + old_approx_kl = ((-logratio) * _diag_mask).sum() / _diag_n + approx_kl = (((ratio - 1) - logratio) * _diag_mask).sum() / _diag_n + clipfrac = (((ratio - 1.0).abs() > config["clip_coef"]).float() * _diag_mask).sum() / _diag_n adv = advantages[idx] if hasattr(self.vecenv.driver_env, "discount_conditioned") and self.vecenv.driver_env.discount_conditioned: @@ -897,11 +1057,32 @@ def train(self): else: mb_gammas = torch.full((len(idx),), config["gamma"], device=device, dtype=torch.float32) - # Recompute advantages with new ratios + # Recompute advantages with new ratios — bootstrap-stop is + # terminals OR truncations OR removed (see outer GAE call comment). + mb_bootstrap_stop = ( + mb_terminals + mb_truncations + mb_removed.float() + ).clamp(max=1.0) + if _TRIAL_DEBUG_ENABLED: + _trial_debug_log( + "gae_inner", + epoch=int(self.epoch), + minibatch=int(mb), + step=int(self.global_step), + mb_terminals_sum=float(mb_terminals.sum().item()), + mb_truncations_sum=float(mb_truncations.sum().item()), + mb_bootstrap_sum=float(mb_bootstrap_stop.sum().item()), + ratio_mean=float(ratio.mean().item()), + ratio_min=float(ratio.min().item()), + ratio_max=float(ratio.max().item()), + approx_kl=float(approx_kl.item()), + clipfrac=float(clipfrac.item()), + adv_mean_pre=float(adv.mean().item()), + adv_std_pre=float(adv.std().item()), + ) adv = compute_puff_advantage( mb_values, mb_rewards, - mb_terminals, + mb_bootstrap_stop, ratio, adv, mb_gammas, @@ -913,17 +1094,30 @@ def train(self): adv = mb_prio * (adv - adv.mean()) / (adv.std() + 1e-8) # Losses + # Per-step validity mask: 1 where the agent was ACTIVE (not limbo), + # 0 where removed=1 (off-map). All per-sample losses are weighted + # by this and normalized by the count of valid samples, so limbo + # tuples contribute zero gradient. mb_removed has shape (B, T) + # matching the per-step losses below. 
+ valid_mask = (~mb_removed).to(adv.dtype) # (B, T) + n_valid = valid_mask.sum().clamp(min=1.0) + pg_loss1 = -adv * ratio pg_loss2 = -adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() + pg_loss = (torch.max(pg_loss1, pg_loss2) * valid_mask).sum() / n_valid newvalue = newvalue.view(mb_returns.shape) v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip) v_loss_unclipped = (newvalue - mb_returns) ** 2 v_loss_clipped = (v_clipped - mb_returns) ** 2 - v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean() - - # Entropy-weighted loss if entropy conditioning is enabled + v_loss = 0.5 * (torch.max(v_loss_unclipped, v_loss_clipped) * valid_mask).sum() / n_valid + + # Entropy-weighted loss if entropy conditioning is enabled. + # NOTE: entropy comes back from sample_logits FLAT — shape (B*T,) + # — while valid_mask is (B, T). Flatten valid_mask once for these + # mults so we don't crash on broadcast. + valid_mask_flat = valid_mask.reshape(-1) + n_valid_flat = valid_mask_flat.sum().clamp(min=1.0) if hasattr(self.vecenv.driver_env, "entropy_conditioned") and self.vecenv.driver_env.entropy_conditioned: mb_obs_flat = mb_obs.reshape(-1, mb_obs.shape[-1]) @@ -940,15 +1134,25 @@ def train(self): ent_weights = mb_obs_flat[:, ent_idx] # after ego(7/10) + RC(3) ent_weights = ent_weights.reshape(entropy.shape) - entropy_loss = -(entropy * ent_weights).mean() + entropy_loss = -((entropy * ent_weights) * valid_mask_flat).sum() / n_valid_flat loss = pg_loss + config["vf_coef"] * v_loss + entropy_loss else: - entropy_loss = entropy.mean() + entropy_loss = (entropy * valid_mask_flat).sum() / n_valid_flat loss = pg_loss + config["vf_coef"] * v_loss - config["ent_coef"] * entropy_loss self.amp_context.__enter__() # TODO: AMP needs some debugging - # This breaks vloss clipping? - self.values[idx] = newvalue.detach().float() + # Write back the new value-head output for the next outer GAE. 
+ # CRITICAL: preserve limbo positions — at those slots `newvalue` + # was computed from garbage obs (INVALID_POSITION) and writing + # it back would poison subsequent GAE calls. The old `mb_values` + # at limbo positions is also garbage (also computed from limbo + # obs at rollout time), so neither choice is "right" — but + # keeping the prior value at limbo positions prevents + # mb-by-mb drift across PPO epochs. + new_v = newvalue.detach().float() + if mb_removed is not None: + new_v = torch.where(mb_removed, mb_values.float(), new_v) + self.values[idx] = new_v # Logging profile("train_misc", epoch) @@ -981,6 +1185,44 @@ def train(self): profile.end() logs = None + if _TRIAL_DEBUG_ENABLED: + # Per-epoch summary: cache health, transformer position, loss snapshot. + k_cache = getattr(self, "transformer_k_cache", None) + v_cache = getattr(self, "transformer_v_cache", None) + pos_buf = getattr(self, "transformer_position", None) + cache_stats = {} + if k_cache is not None and isinstance(k_cache, dict) and len(k_cache) > 0: + # k_cache is dict keyed by transformer_key; each value is a list of layer K tensors + some_key = next(iter(k_cache.keys())) + cache_list = k_cache[some_key] + if cache_list is not None and len(cache_list) > 0: + sample = cache_list[0] + cache_stats = dict( + shape=list(sample.shape), + dtype=str(sample.dtype), + norm_mean=float(sample.norm(dim=-1).mean().item()), + nan_count=int(torch.isnan(sample).sum().item()), + inf_count=int(torch.isinf(sample).sum().item()), + ) + pos_stats = {} + if pos_buf is not None and isinstance(pos_buf, dict) and len(pos_buf) > 0: + some_key = next(iter(pos_buf.keys())) + p = pos_buf[some_key] + if p is not None: + pos_stats = dict(min=int(p.min().item()), max=int(p.max().item())) + _trial_debug_log( + "epoch_end", + epoch=int(self.epoch), + step=int(self.global_step), + policy_loss=float(losses.get("policy_loss", 0)), + value_loss=float(losses.get("value_loss", 0)), + entropy=float(losses.get("entropy", 0)), + 
approx_kl=float(losses.get("approx_kl", 0)), + clipfrac=float(losses.get("clipfrac", 0)), + explained_var=float(explained_var.item() if not torch.isnan(torch.tensor(float(explained_var))) else 0.0), + cache=cache_stats, + position=pos_stats, + ) self.epoch += 1 done_training = self.global_step >= config["total_timesteps"] if done_training or self.global_step == 0 or time.time() > self.last_log_time + 0.25: diff --git a/pufferlib/utils.py b/pufferlib/utils.py index 7210f54d06..e732ab41f0 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -36,6 +36,10 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): k_scenarios = env_config.get("k_scenarios", 1) scenario_length = env_config.get("scenario_length", 91) train_horizon = config.get("horizon", scenario_length * k_scenarios) + # Inherit goal_behavior from training. Under gb=3 the eval subprocess + # re-derives max_trials_per_episode and per_trial_timeout from + # k_scenarios + scenario_length, so we don't pass them. + goal_behavior = int(env_config.get("goal_behavior", 0)) cmd = [ sys.executable, @@ -66,11 +70,10 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): str(scenario_length), "--train.horizon", str(train_horizon), - # Match training goal_behavior. The "stop is cleaner" claim was wrong - # in practice — sparse reward under gb=2 produces worse drivers; gb=0 - # respawn captures real efficiency adaptation in scen_1 vs scen_0. + # Inherit training's goal_behavior. Under gb=3 the env will + # re-derive trial config from k_scenarios + scenario_length. 
"--env.goal-behavior", - "0", + str(goal_behavior), "--env.conditioning.type", conditioning_type, "--env.conditioning.collision-weight-lb", @@ -247,6 +250,11 @@ def render_videos(config, policy, logger, epoch, global_step, device="cuda", hum env_kwargs = copy.deepcopy(config.get("env_config", {})) env_kwargs["render_mode"] = 1 # RENDER_HEADLESS + # Route renders to eval.map_dir if set, so test-set videos match + # the eval map distribution (e.g. nuplan_hard) rather than train. + eval_map_dir = config.get("eval", {}).get("map_dir") + if eval_map_dir: + env_kwargs["map_dir"] = eval_map_dir # Render env runs alongside training and has to fit in the same VRAM / # RAM budget — override the training num_agents (often 1024+) down to a # render-sized footprint so we don't OOM on first render call. @@ -280,10 +288,13 @@ def render_videos(config, policy, logger, epoch, global_step, device="cuda", hum } use_rnn = config.get("use_rnn", False) - episode_length = env_kwargs.get("scenario_length", 91) + scenario_length = env_kwargs.get("scenario_length", 91) k_scenarios = env_kwargs.get("k_scenarios", 1) - if k_scenarios > 1: - episode_length = k_scenarios * episode_length + goal_behavior = int(env_kwargs.get("goal_behavior", 0)) + # episode_length = k_scenarios * scenario_length under all goal_behaviors. + # Under gb=3 the auto-link makes this equal to max_trials * per_trial_timeout. 
+ episode_length = scenario_length * k_scenarios if k_scenarios > 1 else scenario_length + episode_label = f"trials{k_scenarios}" if goal_behavior == 3 else f"k{k_scenarios}" mode = "human_replay" if human_replay else ("coplayer" if env_kwargs.get("co_player_enabled") else "baseline") videos_to_log_world = [] @@ -296,7 +307,7 @@ def render_videos(config, policy, logger, epoch, global_step, device="cuda", hum map_ids = getattr(driver, "map_ids", None) map_id = int(map_ids[0]) if map_ids is not None and len(map_ids) > 0 else 0 view = _VIEW_NAMES.get(int(view_mode), "view") - basename = f"epoch_{epoch:06d}_{mode}_k{k_scenarios}_map{map_id:03d}_{view}" + basename = f"epoch_{epoch:06d}_{mode}_{episode_label}_map{map_id:03d}_{view}" # Tell the env to keep raylib + ffmpeg alive across map swaps so # the in-step _reinit_envs_with_new_maps() at scenario boundaries diff --git a/pufferlib/vector.py b/pufferlib/vector.py index deee21d9f9..939e8ba465 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -18,6 +18,12 @@ MAIN = 5 INFO = 6 +# Module-level cache for co-player state_dicts. Render epochs call +# pufferlib.vector.make repeatedly, and reloading the (~50MB) frozen +# co-player checkpoint each time was a major source of memory pressure +# and disk I/O. Keyed by (path, mtime) so a replaced file invalidates. +_CO_PLAYER_STATE_DICT_CACHE = {} + def recv_precheck(vecenv): if vecenv.flag != RECV: @@ -78,6 +84,15 @@ def __init__(self, env_creators, env_args, env_kwargs, num_envs, buf=None, seed= set_buffers(self, buf) + # `removed` is a Drive-specific SHM channel for the B'' off-map flag. + # Pass it through buf so each env's slice is a view into the same + # parent array — needed for pufferl to read a unified view via + # vecenv.removed regardless of backend. 
+ if buf is not None and "removed" in buf: + self.removed = buf["removed"] + else: + self.removed = np.zeros(self.agents_per_batch, dtype=bool) + self.envs = [] ptr = 0 for i in range(num_envs): @@ -89,6 +104,7 @@ def __init__(self, env_creators, env_args, env_kwargs, num_envs, buf=None, seed= truncations=self.truncations[ptr:end], masks=self.masks[ptr:end], actions=self.actions[ptr:end], + removed=self.removed[ptr:end], ) ptr = end seed_i = seed + i if seed is not None else None @@ -223,6 +239,8 @@ def _worker_process( masks=np.ndarray(shape, dtype=bool, buffer=shm["masks"])[worker_idx], actions=atn_arr, ) + if "removed" in shm: + buf["removed"] = np.ndarray(shape, dtype=bool, buffer=shm["removed"])[worker_idx] buf["masks"][:] = True if population_play: @@ -375,6 +393,9 @@ def __init__( terminals=RawArray("b", num_agents), truncateds=RawArray("b", num_agents), masks=RawArray("b", num_agents), + # Drive B'' off-map flag. Workers write per-agent removed bits; + # main process reads via self.removed to drive KV-cache masking. + removed=RawArray("b", num_agents), semaphores=RawArray("c", num_workers), notify=RawArray("b", num_workers), ) @@ -457,10 +478,14 @@ def __init__( terminals=np.ndarray(shape, dtype=bool, buffer=self.shm["terminals"]), truncations=np.ndarray(shape, dtype=bool, buffer=self.shm["truncateds"]), masks=np.ndarray(shape, dtype=bool, buffer=self.shm["masks"]), + removed=np.ndarray(shape, dtype=bool, buffer=self.shm["removed"]), semaphores=np.ndarray(num_workers, dtype=np.uint8, buffer=self.shm["semaphores"]), notify=np.ndarray(num_workers, dtype=bool, buffer=self.shm["notify"]), ) self.buf["semaphores"][:] = MAIN + # Flat (num_agents,) view of the SHM removed buffer. Reading this + # from the main process sees writes from any worker. 
+ self.removed = self.buf["removed"].ravel() from multiprocessing import Pipe, Process @@ -1002,7 +1027,20 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + # Cache the loaded state_dict at module level so render epochs + # don't reload the (~50MB) checkpoint from disk every time + # vector.make is called. The co-player checkpoint is FROZEN + # throughout a run, so re-reading is pure I/O waste. Keyed by + # full path + mtime to invalidate if the file is replaced. + try: + _ckpt_mtime = os.path.getmtime(checkpoint_path) + except OSError: + _ckpt_mtime = 0 + _cache_key = (checkpoint_path, _ckpt_mtime) + state_dict = _CO_PLAYER_STATE_DICT_CACHE.get(_cache_key) + if state_dict is None: + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + _CO_PLAYER_STATE_DICT_CACHE[_cache_key] = state_dict policy.load_state_dict(state_dict, strict=True) if external_coplayer: diff --git a/render.py b/render.py index 4742f0c1f2..6384c55f76 100644 --- a/render.py +++ b/render.py @@ -117,6 +117,11 @@ def build_config(args): config["env"]["num_ego_agents"] = args.num_ego_agents config["env"]["k_scenarios"] = args.k_scenarios config["env"]["scenario_length"] = args.scenario_length + if args.goal_behavior is not None: + config["env"]["goal_behavior"] = int(args.goal_behavior) + # Under gb=3: max_trials_per_episode and per_trial_timeout are derived + # from k_scenarios + scenario_length in AdaptiveDrivingAgent.__init__. + # No separate CLI knobs. 
if args.human_replay: if env_name == "puffer_adaptive_drive": @@ -190,7 +195,13 @@ def render_one(env_name, base_config, view_modes, render_idx, seed, args): mode = mode_tag(args) coplayer_part = "" - max_steps = args.max_steps if args.max_steps is not None else (args.k_scenarios * args.scenario_length) + # Default max_steps = full episode budget = k_scenarios * scenario_length + # under both trial and non-trial modes. Under gb=3 the auto-link makes + # max_trials * per_trial_timeout identical. + if args.max_steps is not None: + max_steps = args.max_steps + else: + max_steps = args.k_scenarios * args.scenario_length os.makedirs(args.output_dir, exist_ok=True) saved = [] @@ -266,6 +277,14 @@ def main(): p.add_argument("--k-scenarios", type=int, default=2, help="Number of scenarios per episode (adaptive)") p.add_argument("--scenario-length", type=int, default=91) + p.add_argument( + "--goal-behavior", + type=int, + default=None, + help="Goal behavior: 0=RESPAWN, 1=GENERATE_NEW, 2=STOP, 3=TRIAL " + "(under TRIAL: k_scenarios = #trials, scenario_length = per-trial timeout). " + "Defaults to whatever the checkpoint was trained with (ini default 0).", + ) p.add_argument( "--max-steps", type=int, default=None, help="Steps per render (default: k_scenarios * scenario_length)" ) diff --git a/scripts/adaptive/cluster_nuplan_transformer_k4_gb3.sh b/scripts/adaptive/cluster_nuplan_transformer_k4_gb3.sh new file mode 100755 index 0000000000..e90a37a3bd --- /dev/null +++ b/scripts/adaptive/cluster_nuplan_transformer_k4_gb3.sh @@ -0,0 +1,119 @@ +#!/bin/bash +#SBATCH --job-name=ada_k4_gb3 +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=256GB +#SBATCH --time=24:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=40 +#SBATCH --gres=gpu:1 +#SBATCH --array=0-11 + +# k=4 gb=3 (GOAL_TRIAL) adaptive sweep: 4 partners × 3 seeds = 12 tasks. 
+# Each array task is ONE partner × ONE seed on ONE GPU with nw=32. +# +# Array indexing: TASK_ID = partner_idx * 3 + seed_idx +# partner_idx ∈ {0..3} → PARTNERS[partner_idx] +# seed_idx ∈ {0..2} → SEEDS[seed_idx] +# +# Submit: sbatch scripts/adaptive/cluster_nuplan_transformer_k4_gb3.sh + +# (label, partner_id, entropy_ub) +PARTNERS=( + "p005 miku2puk 0.05" + "p010 2e029h15 0.10" + "p020 m2ygolog 0.20" + "p050 6rauydj2 0.50" +) +SEEDS=(42 43 44) + +PARTNER_IDX=$((SLURM_ARRAY_TASK_ID / 3)) +SEED_IDX=$((SLURM_ARRAY_TASK_ID % 3)) +read -r LABEL PARTNER_ID ENTROPY_UB <<< "${PARTNERS[$PARTNER_IDX]}" +SEED=${SEEDS[$SEED_IDX]} +COPLAYER_PATH="experiments/puffer_drive_${PARTNER_ID}.pt" + +# Fixed +GAMMA=0.995 +COLLISION_LB=-2 +OFFROAD_LB=-2 +LANE_REWARD=0.025 +DISCOUNT_LB=0.4 +DISCOUNT_UB=1 +NUPLAN_NUM_MAPS=4999 +TOTAL_TIMESTEPS=2000000000 # 2B + +K_SCENARIOS=4 +SCENARIO_LENGTH=201 +HORIZON=$((K_SCENARIOS * SCENARIO_LENGTH)) # 804 + +NUM_WORKERS=32; NUM_ENVS=32 +MINIBATCH_MULTIPLIER=25 # minibatch_size = 25 * 804 = 20100 +MAX_MINIBATCH_SIZE=20100 + +TAG="ada_k4_gb3_${LABEL}_${PARTNER_ID}_s${SEED}_g${GAMMA}_lane${LANE_REWARD}_nw${NUM_WORKERS}" + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + export WANDB_MODE=online + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! 
+ + xvfb-run -a puffer train puffer_adaptive_drive \ + --wandb --wandb-project adaptive_aligned \ + --tag $TAG \ + --policy-architecture Transformer --rnn-name Transformer \ + --train.gamma $GAMMA \ + --train.horizon $HORIZON \ + --train.minibatch-multiplier $MINIBATCH_MULTIPLIER \ + --train.max-minibatch-size $MAX_MINIBATCH_SIZE \ + --train.cpu-offload True \ + --train.checkpoint-interval 10 \ + --train.render-interval 30 \ + --train.seed $SEED \ + --train.total-timesteps $TOTAL_TIMESTEPS \ + --vec.num-workers $NUM_WORKERS --vec.num-envs $NUM_ENVS \ + --env.map-dir resources/drive/binaries/nuplan_201 \ + --env.num-maps $NUPLAN_NUM_MAPS \ + --env.scenario-length $SCENARIO_LENGTH \ + --env.k-scenarios $K_SCENARIOS \ + --env.goal-behavior 3 \ + --env.conditioning.type none \ + --env.reward-lane-align $LANE_REWARD \ + --env.co-player-enabled 1 \ + --env.co-player-policy.policy-path $COPLAYER_PATH \ + --env.co-player-policy.architecture Transformer \ + --env.co-player-policy.transformer.horizon $SCENARIO_LENGTH \ + --env.co-player-policy.conditioning.type all \ + --env.co-player-policy.conditioning.collision-weight-lb $COLLISION_LB \ + --env.co-player-policy.conditioning.collision-weight-ub 0 \ + --env.co-player-policy.conditioning.offroad-weight-lb $OFFROAD_LB \ + --env.co-player-policy.conditioning.offroad-weight-ub 0 \ + --env.co-player-policy.conditioning.entropy-weight-lb 0 \ + --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB \ + --env.co-player-policy.conditioning.discount-weight-lb $DISCOUNT_LB \ + --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \ + --env.external-co-player-actions True \ + --env.map-rand-per-scenario False \ + --eval.map-dir resources/drive/binaries/nuplan_hard \ + --eval.num-maps 64 \ + --eval.human-replay-eval True \ + --eval.human-replay-num-rollouts 40 \ + --eval.human-replay-num-maps 64 \ + --eval.human-replay-num-agents 64 \ + --eval.eval-interval 10 + + kill \$HEARTBEAT_PID + " diff --git 
a/scripts/adaptive/cluster_smoke_test.sh b/scripts/adaptive/cluster_smoke_test.sh new file mode 100755 index 0000000000..a51d003d3b --- /dev/null +++ b/scripts/adaptive/cluster_smoke_test.sh @@ -0,0 +1,71 @@ +#!/bin/bash +#SBATCH --job-name=ada_k4_smoke +#SBATCH --output=/scratch/mmk9418/logs/smoke_%j.out +#SBATCH --error=/scratch/mmk9418/logs/smoke_%j.err +#SBATCH --mem=64GB +#SBATCH --time=00:20:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=8 +#SBATCH --gres=gpu:1 + +# Smoke test for cluster_nuplan_transformer_k4_gb3.sh — runs a tiny single-partner +# training to validate: +# - singularity image + overlay + venv activate +# - C extension built for the cluster arch +# - nuplan_201 + nuplan_hard binaries present +# - partner checkpoint loads +# - co-player + ego forward passes work on GPU +# - wandb sync online +# - first dashboard frame renders without OOM/NaN +# +# Lightweight: nw=4, 2M timesteps (~3-4 min). Submit, watch the log, kill once +# you see Steps advancing.
+ +singularity exec --nv --overlay "$OVERLAY_FILE:ro" "$SINGULARITY_IMAGE" bash -c " + set -e + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + export WANDB_MODE=online + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + + xvfb-run -a puffer train puffer_adaptive_drive \ + --wandb --wandb-project adaptive_aligned \ + --tag smoke_k4_gb3_\$(date +%s) \ + --policy-architecture Transformer --rnn-name Transformer \ + --train.horizon 804 \ + --train.minibatch-multiplier 4 \ + --train.max-minibatch-size 3216 \ + --train.cpu-offload True \ + --train.checkpoint-interval 100 \ + --train.render-interval 100 \ + --train.seed 42 \ + --train.total-timesteps 2000000 \ + --vec.num-workers 4 --vec.num-envs 4 \ + --env.map-dir resources/drive/binaries/nuplan_201 \ + --env.num-maps 4999 \ + --env.scenario-length 201 \ + --env.k-scenarios 4 \ + --env.goal-behavior 3 \ + --env.conditioning.type none \ + --env.reward-lane-align 0.025 \ + --env.co-player-enabled 1 \ + --env.co-player-policy.policy-path experiments/puffer_drive_miku2puk.pt \ + --env.co-player-policy.architecture Transformer \ + --env.co-player-policy.transformer.horizon 201 \ + --env.co-player-policy.conditioning.type all \ + --env.co-player-policy.conditioning.collision-weight-lb -2 \ + --env.co-player-policy.conditioning.collision-weight-ub 0 \ + --env.co-player-policy.conditioning.offroad-weight-lb -2 \ + --env.co-player-policy.conditioning.offroad-weight-ub 0 \ + --env.co-player-policy.conditioning.entropy-weight-lb 0 \ + --env.co-player-policy.conditioning.entropy-weight-ub 0.05 \ + --env.co-player-policy.conditioning.discount-weight-lb 0.4 \ + --env.co-player-policy.conditioning.discount-weight-ub 1 \ + --env.external-co-player-actions True \ + --env.map-rand-per-scenario False \ + --eval.map-dir resources/drive/binaries/nuplan_hard \ + --eval.num-maps 16 \ + --eval.eval-interval 1000 +" diff --git a/scripts/adaptive/nuplan_transformer_local_k2_201_gb3_2partners.sh 
b/scripts/adaptive/nuplan_transformer_local_k2_201_gb3_2partners.sh new file mode 100755 index 0000000000..477e23bd16 --- /dev/null +++ b/scripts/adaptive/nuplan_transformer_local_k2_201_gb3_2partners.sh @@ -0,0 +1,120 @@ +#!/bin/bash +set -e + +# 2 k=2 transformer runs, gb=3 (GOAL_TRIAL) — low-entropy partners only. +# Same config as the 4-partner sweep but bigger nw (more rollout throughput +# per run) since we're only running 2 concurrently. +# +# Train on nuplan_201 (5000 maps), eval + renders on nuplan_hard (540 maps). +# 2B total timesteps each. +# +# Memory math: per-run pinned ≈ nw*2.83 GiB obs + ~10 GiB overhead. +# At nw=24: 2 × (24*2.83 + 10) ≈ 156 GiB — over 132 GiB cap (tight). +# If OOMs, drop to nw=20. +# +# Default GPUs: 0-1. Override: GPUS="0 1" bash