diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock
new file mode 100644
index 0000000000..595f769155
--- /dev/null
+++ b/.claude/scheduled_tasks.lock
@@ -0,0 +1 @@
+{"sessionId":"7bcf951d-92e0-44b9-83b1-dc29b0d69f18","pid":8923,"procStart":"408700719","acquiredAt":1777664169801}
diff --git a/.github/workflows/utest.yml b/.github/workflows/utest.yml
index cfcbdda4b7..db5e5f0e4a 100644
--- a/.github/workflows/utest.yml
+++ b/.github/workflows/utest.yml
@@ -4,7 +4,6 @@ on:
pull_request:
branches: [ main ]
push:
- branches: [ main ]
jobs:
test:
@@ -21,13 +20,9 @@ jobs:
- name: Free up disk space
run: |
- echo "Before cleanup:"
- df -h
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc
sudo apt-get clean
sudo rm -rf ~/.cache/pip /tmp/* /var/tmp/*
- echo "After cleanup:"
- df -h
- name: Install dependencies
run: |
@@ -37,8 +32,7 @@ jobs:
run: python -m pip install -U pip
- name: Install pufferlib
- run: |
- pip install -e .[cpu] --no-cache-dir
+ run: pip install -e .[cpu] --no-cache-dir pytest
env:
TMPDIR: ${{ runner.temp }}/build
PIP_NO_CACHE_DIR: 1
@@ -53,3 +47,11 @@ jobs:
- name: Run python ini parsing tests
run: python tests/test_drive_config.py
+
+ - name: Run pytest suite (conditioning, map_dir, render contracts)
+ run: |
+ pytest \
+ tests/test_drive_conditioning.py \
+ tests/test_map_dir_flow.py \
+ tests/test_render_pipeline.py \
+ -v --tb=short
diff --git a/.gitignore b/.gitignore
index a83a714b7f..71497ea416 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,13 +161,16 @@ pufferlib/ocean/impulse_wars/debug-*/
pufferlib/ocean/impulse_wars/release-*/
pufferlib/ocean/impulse_wars/benchmark/
-
# Ignore data files
data/
-pufferlib/resources/drive/binaries/
+pufferlib/resources/drive/binaries/*
+!pufferlib/resources/drive/binaries/training/
+pufferlib/resources/drive/binaries/training/*
# But keep map_000.bin for the training test
!pufferlib/resources/drive/binaries/map_000.bin
+!pufferlib/resources/drive/binaries/training/map_000.bin
+pufferlib/resources/drive/sanity/sanity_binaries/
# Compiled drive binary in root
/drive
@@ -183,7 +186,12 @@ pufferlib/resources/drive/output_agent.gif pufferlib/resources/drive/output.gif
# Local artifacts and outputs
artifacts/
# Local drive renders
+pufferlib/resources/drive/output*.gif
+emsdk/
+docs/book/*
+!docs/book/assets/
pufferlib/resources/drive/output*.mp4
# Local TODO tracking
TODO.md
+*.mp4
diff --git a/README.md b/README.md
index 4ac61c39eb..06ebf54b84 100644
--- a/README.md
+++ b/README.md
@@ -1,216 +1,350 @@
-# PufferDrive
+# Adaptive Driving Agent
-
+A fork of [PufferDrive](https://github.com/Emerge-Lab/PufferDrive) for training
+**adaptive** driving policies — agents that infer their partners' behavior
+on the fly instead of being told what to expect.
-**PufferDrive is a fast and friendly driving simulator to train and test RL-based models.**
+## What we're doing
-
-
-
-
-
-
-
-
-
-
+The training pipeline is two-stage:
----
+1. **Train a population of co-players.** A single conditioned policy
+ (`puffer_drive`) is trained over a range of conditioning values
+ (entropy weight, discount factor, reward weights). Sampling different
+ conditioning vectors during inference gives us a *population* of behaviors
+ from one set of weights.
+2. **Train an adaptive ego against that population** (`puffer_adaptive_drive`).
+ The adaptive policy is *not* conditioned — it has to figure out what kind
+ of partner it's facing from the observation stream. Each episode contains
+ `k_scenarios` scenarios; the partner is re-sampled at scenario boundaries
+ so the agent has multiple shots to adapt within one episode.
-**Docs**: https://emerge-lab.github.io/PufferDrive
+Everything runs in PufferLib's RL training stack, with `Drive` (the C+raylib
+simulator) for the environment and a PyTorch policy on top.
----
+## Repo layout
-### See our 2.0 release video
-
-
-
-
+```
+Adaptive_Driving_Agent/
+├── render.py # unified Python rendering CLI
+├── pufferlib/
+│ ├── pufferl.py # train/eval entrypoint
+│ ├── utils.py # render_videos, eval subprocess builders
+│ ├── ocean/drive/
+│ │ ├── drive.h # C simulator + raylib renderer
+│ │ ├── drive.py # gymnasium wrapper, RenderView, render()
+│ │ ├── adaptive.py # AdaptiveDrivingAgent subclass
+│ │ ├── rollout.py # arch-agnostic rollout loop (used by render + eval)
+│ │ ├── torch.py # encoder + LSTMWrapper / TransformerWrapper
+│ │ └── binding.{c,h} # Python C-extension
+│ ├── ocean/benchmark/evaluator.py # HumanReplayEvaluator, WOSACEvaluator
+│ └── config/ocean/{drive,adaptive}.ini
+├── scripts/
+│ ├── coplayers/ # train co-players (per dataset × arch)
+│ ├── adaptive/ # train adaptive ego using a co-player ckpt
+│ ├── baselines/ # vanilla (no co-player) for comparison
+│ └── ablations/
+│ ├── all_coplayers.sh # one-shot driver for the full coplayer matrix
+│ └── human_align_ablation.sh # collision/offroad weight grid
+└── tests/ # pytest suite (see CI section)
+```
-## Installation
+## Setup
-Clone the repo
```bash
-https://github.com/Emerge-Lab/PufferDrive.git
-```
+git clone
+cd Adaptive_Driving_Agent
-Make a venv (`uv venv`), activate the venv
-```
-source .venv/bin/activate
-```
-
-Inside the venv, install the dependencies
-```
+uv venv && source .venv/bin/activate
uv pip install -e .
-```
-Compile the C code
-```
+# Build the C extension
python setup.py build_ext --inplace --force
-```
-Run this while your virtual environment is active so the extension is built against the right interpreter.
-To test your setup, you can run
+# Headless rendering deps (skip on a desktop with display)
+sudo apt install ffmpeg xvfb # or:
+conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64 ffmpeg
```
-puffer train puffer_drive
-```
-See also the [puffer docs](https://puffer.ai/docs.html).
+### Map binaries
-## Quick start
+The simulator reads pre-converted map binaries. Two datasets ship out of
+the box:
-Start a training run
```
-puffer train puffer_drive
+resources/drive/binaries/training/ # WOMD (default)
+resources/drive/binaries/nuplan/ # nuPlan
```
-## Dataset
-
-
-Downloading and using data
-
-### Data preparation
-
-To train with PufferDrive, you need to convert JSON files to map binaries. Run the following command with the path to your data folder:
+To convert your own JSON scenes:
```bash
-python pufferlib/ocean/drive/drive.py
+python -c "from pufferlib.ocean.drive.drive import process_all_maps; \
+ process_all_maps('path/to/json/folder', max_maps=5000)"
```
-### Downloading Waymo Data
-
-You can download the WOMD data from Hugging Face in two versions:
+The output lives at `resources/drive/binaries//map_NNN.bin`.
-- **Mini dataset**: [GPUDrive_mini](https://huggingface.co/datasets/EMERGE-lab/GPUDrive_mini) contains 1,000 training files and 300 test/validation files
-- **Medium dataset**: [10,000 files from the training dataset](https://huggingface.co/datasets/daphne-cornelisse/pufferdrive_train)
-- **Large dataset**: [GPUDrive](https://huggingface.co/datasets/EMERGE-lab/GPUDrive) contains 100,000 unique scenes
+## Training
-**Note**: Replace 'GPUDrive_mini' with 'GPUDrive' in your download commands if you want to use the full dataset.
+The CLI is `puffer train `. The env names are:
+- `puffer_drive` — co-player / baseline training
+- `puffer_adaptive_drive` — adaptive ego against frozen co-player
-### Additional Data Sources
+Architecture is selected with `--policy-architecture {Recurrent, Transformer}`.
-For more training data compatible with PufferDrive, see [ScenarioMax](https://github.com/valeoai/ScenarioMax). The GPUDrive data format is fully compatible with PufferDrive.
-
+### 1. Train the co-player population
+Co-players are trained on a single env with conditioning sweeping over the
+desired range. The conditioning at inference time is what makes the
+population diverse.
-## Visualizer
+Per-dataset slurm scripts (4×4 entropy/discount grid, 16 array tasks each):
-
-Dependencies and usage
-
-## Local rendering
-
-To launch an interactive renderer, first build:
-```
-bash scripts/build_ocean.sh drive local
+```bash
+sbatch scripts/coplayers/nuplan_recurrent.sh
+sbatch scripts/coplayers/nuplan_transformer.sh
+sbatch scripts/coplayers/womd_recurrent.sh
+sbatch scripts/coplayers/womd_transformer.sh
```
-then launch:
+Or, to fire all 8 ablation cells (2 datasets × 2 archs × {none, all}) as a
+single slurm array:
+
```bash
-./drive
+sbatch scripts/ablations/all_coplayers.sh # 8 array tasks, one per cell
```
-this will run `demo()` with an existing model checkpoint.
-## Headless server setup
+The matrix is documented in the script's header.
-Run the Raylib visualizer on a headless server and export as .mp4. This will rollout the pre-trained policy in the env.
+Resulting checkpoints land at `experiments/puffer_drive_.pt`.
-### Install dependencies
+### 2. Train the adaptive ego
-```bash
-sudo apt update
-sudo apt install ffmpeg xvfb
-```
+The adaptive scripts wire one co-player checkpoint per array task, then
+train the adaptive ego with `co_player_enabled=True`:
-For HPC (There are no root privileges), so install into the conda environment
```bash
-conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64
-conda install -c conda-forge ffmpeg
+sbatch scripts/adaptive/nuplan_recurrent.sh
+sbatch scripts/adaptive/nuplan_transformer.sh
+sbatch scripts/adaptive/womd_recurrent.sh
+sbatch scripts/adaptive/womd_transformer.sh
```
-- `ffmpeg`: Video processing and conversion
-- `xvfb`: Virtual display for headless environments
-
-### Build and run
+Before submitting, fill in the `ZIPPED_RUNS=(...)` array in each script with
+the co-player checkpoint paths from step 1. Run IDs come from the wandb URLs
+of the co-player runs.
-1. Build the application:
+The key flags are (`--policy-architecture` also accepts `Transformer`; `--env.k-scenarios` sets the number of scenarios per episode):
```bash
-bash scripts/build_ocean.sh visualize local
+puffer train puffer_adaptive_drive \
+ --policy-architecture Recurrent \
+ --env.k-scenarios 2 \
+ --env.co-player-enabled True \
+ --env.co-player-policy.policy-path .pt \
+ --env.co-player-policy.architecture Recurrent \
+ --env.co-player-policy.conditioning.type all \
+ --env.co-player-policy.conditioning.entropy-weight-lb 0 \
+ --env.co-player-policy.conditioning.entropy-weight-ub 0.1 \
+ --env.co-player-policy.conditioning.discount-weight-lb 0.8 \
+ --env.co-player-policy.conditioning.discount-weight-ub 1.0 \
+ --env.map-dir resources/drive/binaries/nuplan \
+ --eval.map-dir resources/drive/binaries/nuplan
```
-2. Run with virtual display:
+### 3. Baselines (no co-player)
+
+For ablation, train the same architectures with `co_player_enabled=False`:
+
```bash
-xvfb-run -s "-screen 0 1280x720x24" ./visualize
+sbatch scripts/baselines/nuplan_recurrent.sh
+sbatch scripts/baselines/nuplan_transformer.sh
+sbatch scripts/baselines/womd_recurrent.sh
+sbatch scripts/baselines/womd_transformer.sh
```
-The `-s` flag sets up a virtual screen at 1280x720 resolution with 24-bit color depth.
+Other agents in the scene replay their logged human trajectories.
----
+## Rendering
-> To force a rebuild, you can delete the cached compiled executable binary using `rm ./visualize`.
+Single Python entrypoint, `render.py`, replaces the old `./visualize` C
+binary and the per-scenario eval scripts. It works for **any** architecture
+because policy inference is done in PyTorch (the C side only handles
+graphics). Architecture is auto-detected from the checkpoint state-dict.
----
+### Modes
-
+| Flags | What it renders |
+|---|---|
+| `--model-path X.pt` | Baseline: ego drives, others follow logs |
+| `--model-path adaptive.pt --co-player-path coplayer.pt` | Adaptive ego vs frozen co-player |
+| `--model-path X.pt --human-replay` | One ego is policy-controlled; everyone else replays human logs |
+### Examples
-## Benchmarks
+```bash
+# Baseline render of a co-player checkpoint
+xvfb-run -s "-screen 0 1280x720x24" python render.py \
+ --model-path experiments/puffer_drive_.pt \
+ --map-dir resources/drive/binaries/training \
+ --conditioning-type all \
+ --num-renders 5
+
+# Adaptive ego rendered against a co-player population
+python render.py \
+ --model-path experiments/puffer_adaptive_drive_.pt \
+ --co-player-path experiments/puffer_drive_.pt \
+ --co-player-conditioning-type all \
+ --k-scenarios 2 \
+ --num-renders 3 \
+ --map-dir resources/drive/binaries/nuplan
+
+# Human-replay: how compatible is the agent with logged humans?
+python render.py \
+ --model-path experiments/puffer_adaptive_drive_.pt \
+ --human-replay \
+ --k-scenarios 2 \
+ --num-renders 5
+
+# All three views (top-down, BEV, third-person) in one call
+python render.py --model-path X.pt --view-mode all
+```
-### Distributional realism
+### Output
-We provide a PufferDrive implementation of the [Waymo Open Sim Agents Challenge (WOSAC)](https://waymo.com/open/challenges/2025/sim-agents/) for fast, easy evaluation of how well your trained agent matches distributional properties of human behavior. See documentation [here](https://emerge-lab.github.io/PufferDrive/wosac/).
+Videos are written to `/renders/`, with descriptive filenames so
+multiple modes on the same map don't collide:
-WOSAC evaluation with random policy:
-```bash
-puffer eval puffer_drive --eval.wosac-realism-eval True
+```
+experiments/puffer_adaptive_drive_/renders/
+ _baseline_k1_map002_sim_state.mp4
+ _human_replay_k2_map002_sim_state.mp4
+ _vs__k2_map002_sim_state.mp4
+ _vs__k2_map002_bev.mp4
+ _vs__k2_map002_persp.mp4
```
-WOSAC evaluation with your checkpoint (must be .pt file):
-```bash
-puffer eval puffer_drive --eval.wosac-realism-eval True --load-model-path .pt
+Different renders pick different maps via a prime stride
+(`--start-seed`, `--seed-stride`).
+
+### During training
+
+`puffer train ... --train.render True` renders every
+`--train.render-interval` epochs through the same code path as `render.py`, writes to
+`/renders/epoch___k_map_.mp4`, and uploads
+to wandb.
+
+## Checkpoint sidecar (info.json)
+
+Every checkpoint dir gets an `info.json` written alongside the `.pt` files
+at save time. It records the training config you can't recover from the
+weights alone — useful for grepping later or for filling in adaptive
+training scripts by hand:
+
+```json
+{
+ "run_id": "1qcm9dc0",
+ "env_name": "puffer_drive",
+ "policy_architecture": "Recurrent",
+ "env": {
+ "map_dir": "resources/drive/binaries/nuplan",
+ "k_scenarios": 1,
+ "scenario_length": 91,
+ "dynamics_model": "classic",
+ "co_player_enabled": false,
+ "conditioning": {
+ "type": "all",
+ "entropy_weight_lb": 0.0, "entropy_weight_ub": 0.1,
+ "discount_weight_lb": 0.8, "discount_weight_ub": 1.0,
+ "..."
+ }
+ }
+}
```
-### Human-compatibility
+This is reference-only — no autoload. When you train an adaptive ego
+against a co-player, you still pass `--env.co-player-policy.policy-path`
+(and the matching architecture / conditioning flags) explicitly so the
+intent stays in the script.
+
+## Evaluation
+
+Two evaluators ship out of the box:
+
+### Human-replay (compatibility with logged humans)
-You may be interested in how compatible your agent is with human partners. For this purpose, we support an eval where your policy only controls the self-driving car (SDC). The rest of the agents in the scene are stepped using the logs. While it is not a perfect eval since the human partners here are static, it will still give you a sense of how closely aligned your agent's behavior is to how people drive. You can run it like this:
```bash
-puffer eval puffer_drive --eval.human-replay-eval True --load-model-path .pt
+puffer eval puffer_drive \
+ --eval.human-replay-eval True \
+ --load-model-path experiments/puffer_drive_.pt \
+ --eval.map-dir resources/drive/binaries/nuplan
```
-## Development
+For adaptive agents, this also reports `ada_delta_*` metrics
+(last_scenario − first_scenario) so you can see how much the policy
+adapted within an episode.
-Documentation and browser demo
+The training loop runs this automatically every `--eval.eval-interval`
+epochs when `--eval.human-replay-eval True`.
-**Docs**
+### WOSAC (distributional realism)
-A browsable documentation site now lives under `docs/` and is configured with mkbooks. To preview locally:
-```
-brew install mdbook
-mdbook serve --open docs
+```bash
+puffer eval puffer_drive \
+ --eval.wosac-realism-eval True \
+ --load-model-path experiments/puffer_drive_.pt
```
-Open the served URL to see a local version of the docs.
-**Interactive demo**
+### Map-dir handling
-To edit the browser demo, follow these steps:
-- Download [emscripten](https://github.com/emscripten-core/emscripten)
-- emscripten install latest
-- Activate: `source emsdk/emsdk_env.sh`
-- Run `bash scripts/build_ocean.sh drive web`
-- This generates a number of `game*` files, move them to `assets/` to include them on the webpage
+`eval.map_dir` defaults to `None` in both `drive.ini` and `adaptive.ini`,
+which means **eval inherits the training `env.map_dir`** instead of silently
+falling back to a different dataset. To use a different dataset for eval,
+set `--eval.map-dir` explicitly.
-
+## Tests + CI
+```bash
+pip install pytest
+pytest tests/test_drive_conditioning.py \
+ tests/test_map_dir_flow.py \
+ tests/test_render_pipeline.py \
+ tests/test_drive_config.py
+```
-## Citation
+The pytest suite covers:
+- **Conditioning shapes** (`tests/test_drive_conditioning.py`):
+ every `conditioning_type` × `dynamics_model` combination produces the
+ right observation dim and value range.
+- **map_dir propagation** (`tests/test_map_dir_flow.py`):
+ `eval()` inherits `env.map_dir` when `eval.map_dir` is unset, and the
+ human-replay / WOSAC subprocess builders forward a concrete
+ `--eval.map-dir` so the child process can't fall back to ini default.
+- **Render contract** (`tests/test_render_pipeline.py`):
+ `set_video_suffix(name)` produces `name.mp4`. The dataclass default is safe.
+
+GitHub Actions:
+- `.github/workflows/install.yml`: pre-commit on every push and PR.
+- `.github/workflows/utest.yml`: pytest suite on every push (any branch) and on PRs targeting `main`.
+- `.github/workflows/render-ci.yml`: end-to-end render smoke (xvfb + ffmpeg).
+- `.github/workflows/train-ci.yml`: short training + scenario tests.
+
+Run pre-commit locally before pushing:
-If you use PufferDrive in your research, please cite:
-```bibtex
-@software{pufferdrive2025github,
- author = {Daphne Cornelisse* and Spencer Cheng* and Pragnay Mandavilli and Julian Hunt and Kevin Joseph and Waël Doulazmi and Aditya Gupta and Eugene Vinitsky},
- title = {{PufferDrive}: A Fast and Friendly Driving Simulator for Training and Evaluating {RL} Agents},
- url = {https://github.com/Emerge-Lab/PufferDrive},
- version = {2.0.0},
- year = {2025},
-}
+```bash
+pre-commit install # one-time
+pre-commit run --all-files
```
+
+## Common pitfalls
+
+- **Checkpoint shape mismatch on render.** Conditioning changes the ego
+ observation dim, so a checkpoint trained with `conditioning.type=all`
+ needs `--conditioning-type all` at render time too.
+- **Eval using the wrong dataset.** Either set `--eval.map-dir` explicitly
+ or rely on inheritance (the default with `eval.map_dir = None`). Don't set
+ `eval.map_dir` to a stale value in your ini.
+- **Adaptive batch sizes.** `minibatch_size` must be ≤ `batch_size` and both
+ must be divisible by `horizon = k_scenarios * scenario_length` for adaptive.
+- **Multiple consecutive renders.** The C extension recreates the raylib
+ window per render — that's expected, ~1s overhead per render.
diff --git a/TODO_paper.md b/TODO_paper.md
new file mode 100644
index 0000000000..f90aaffb93
--- /dev/null
+++ b/TODO_paper.md
@@ -0,0 +1,39 @@
+# Paper TODO
+
+## Pending — resume after recovery eval finishes
+- **Resume the 4 k=3 adaptive runs** (paused on 2026-04-28 to free GPUs 4-7
+ for parallel recovery-metric eval). Latest checkpoints are at epochs
+ 150-170. One-command resume:
+ ```
+ bash /workspace/ADA/scripts/adaptive/nuplan_transformer_local_k3_resume.sh
+ ```
+ The script already auto-locates the latest `model_*.pt` per wandb id and
+ uses the optimized flags (cpu_offload + external_co_player_actions, nw=32).
+
+## In progress
+- **Conditional recovery metric** in `HumanReplayEvaluator`:
+ per-(agent, scenario) success tracking, then `P(success in s_k | failed s_0)`.
+ Surfaces in-context adaptation in the small fraction of nuplan scenes
+ where adaptation actually matters (most scenes are easy and ego trivially
+ succeeds in s_0, washing out the averaged `ada_delta_score`).
+
+## Backlog (after we have a baseline conditional-recovery number)
+
+- **Map rotation per scenario within an episode**.
+ At each scenario boundary, swap the underlying nuplan map so the ego sees
+ a new scene with new humans. KV cache preserved across scenarios so past
+ observations can inform the new scene. Forces more "hard" cases by
+ guaranteeing each scenario is genuinely novel. Implement as a flag:
+ `--env.map-rand-per-scenario {none, eval, train, both}`. Likely requires
+ retraining adaptive agents to handle the within-episode discontinuity.
+
+- **Train-time conditional-recovery loss/weighting**: bias rollout sampling
+ toward "hard" scenarios (rare but informative) once we have per-agent
+ difficulty estimates from offline eval.
+
+- **Cross-co-player generalization eval**: train ego against partners
+ {A, B, C}, eval against held-out partner D. Tests whether learned
+ adaptation generalizes to unseen partner styles.
+
+- **Architecture ablation**: run with KV-cache disabled (or context-length=1)
+ to confirm the cache is what's enabling adaptation (when it works).
diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md
index 68f9b699e2..3c5e1784f2 100644
--- a/docs/src/getting-started.md
+++ b/docs/src/getting-started.md
@@ -25,7 +25,7 @@ python setup.py build_ext --inplace --force
Run this with your virtual environment activated so the compiled extension links against the correct Python.
### When to rebuild the extension
-- Re-run `python setup.py build_ext --inplace --force` after changing any C/Raylib sources in `pufferlib/ocean/drive` (e.g., `drive.c`, `drive.h`, `binding.c`, `visualize.c`) or after pulling upstream changes that touch those files. This regenerates the `binding.cpython-*.so` used by `Drive`.
+- Re-run `python setup.py build_ext --inplace --force` after changing any C/Raylib sources in `pufferlib/ocean/drive` (e.g., `drive.h`, `binding.c`, `binding.h`) or after pulling upstream changes that touch those files. This regenerates the `binding.cpython-*.so` used by `Drive`.
- Pure Python edits (training scripts, docs, data utilities) do not require a rebuild; just restart your Python process.
## Verify the setup
diff --git a/docs/src/simulator.md b/docs/src/simulator.md
index a500840bb1..c52abc045a 100644
--- a/docs/src/simulator.md
+++ b/docs/src/simulator.md
@@ -210,15 +210,14 @@ mid_x, mid_y, length, width, dir_cos, dir_sin, type
## Source files
### C core
-- `drive.h`: Main simulator (stepping, observations, collisions)
-- `drive.c`: Demo and testing
-- `binding.c`: Python interface
-- `visualize.c`: Raylib renderer
-- `drivenet.h`: C inference network
+- `drive.h`: Main simulator (stepping, observations, collisions, raylib renderer)
+- `binding.c` / `binding.h`: Python interface
### Python
-- `drive.py`: Gymnasium wrapper
+- `drive.py`: Gymnasium wrapper, `Drive.render()`, `set_video_suffix()`
+- `rollout.py`: Policy-agnostic rollout loop used by both training renders and `render.py`
- `torch.py`: Neural network (ego/partner/road encoders → actor/critic)
+- `../../../render.py`: Unified rendering CLI (replaces the old `./visualize` binary)
## Neural network
diff --git a/docs/src/visualizer.md b/docs/src/visualizer.md
index c6b73c90ea..f7046b9d02 100644
--- a/docs/src/visualizer.md
+++ b/docs/src/visualizer.md
@@ -1,35 +1,42 @@
# Visualizer
-PufferDrive ships a Raylib-based visualizer for replaying scenes, exporting videos, and debugging policies.
+PufferDrive renders headless mp4s from Python. Policy inference runs in PyTorch
+and graphics are produced via the C/raylib bindings (`vec_render`), so the same
+pipeline works for both LSTM and Transformer policies.
## Dependencies
-Install the minimal system packages for headless render/export:
```bash
sudo apt update
sudo apt install ffmpeg xvfb
```
-On environments without sudo, install them into your conda/venv:
+Without sudo:
```bash
conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64 ffmpeg
```
-## Build
-Compile the visualizer binary from the repo root:
+## Run
+
+The unified entrypoint is `render.py` at the repo root:
```bash
-bash scripts/build_ocean.sh visualize local
-```
+# Baseline: ego drives, others follow logged trajectories
+xvfb-run -s "-screen 0 1280x720x24" python render.py \
+ --model-path experiments/.pt --map-dir resources/drive/binaries/training
-If you need to force a rebuild, remove the cached binary first (`rm ./visualize`).
+# Adaptive ego + frozen co-player population
+python render.py --model-path adaptive.pt --co-player-path coplayer.pt \
+ --co-player-conditioning-type all --k-scenarios 2
-## Run headless
-Launch the visualizer with a virtual display and export an `.mp4`:
+# Human replay: only the SDC is policy-controlled, others = logged
+python render.py --model-path X.pt --human-replay --num-renders 5
-```bash
-xvfb-run -s "-screen 0 1280x720x24" ./visualize
+# Multiple views in one go
+python render.py --model-path X.pt --view-mode all
```
-Adjust the screen size and color depth as needed. The `xvfb-run` wrapper allows Raylib to render without an attached display, which is convenient for servers and CI jobs.
+Architecture is auto-detected from the checkpoint state-dict; override with
+`--policy-architecture {Recurrent,Transformer}`. See `python render.py --help`
+for the full set of conditioning, co-player, and rendering flags.
diff --git a/evaluate_human_logs.py b/evaluate_human_logs.py
deleted file mode 100644
index 83472b6c71..0000000000
--- a/evaluate_human_logs.py
+++ /dev/null
@@ -1,411 +0,0 @@
-import argparse
-import json
-import numpy as np
-import torch
-from tqdm import tqdm
-import pufferlib
-import pufferlib.vector
-from pufferlib.ocean import env_creator
-from pufferlib.ocean.torch import Drive, Recurrent
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-
-def plot_adaptive_metrics(first_metrics, last_metrics, delta_metrics, output_path):
- """
- Plot adaptive metrics showing first scenario (0-shot), last scenario, and delta improvement.
- """
- # Metrics to plot
- metrics_to_plot = {
- "score": "Score",
- "collision_rate": "Collision Rate",
- "offroad_rate": "Offroad Rate",
- "episode_return": "Episode Return",
- }
-
- fig, axes = plt.subplots(2, 2, figsize=(14, 10))
- axes = axes.flatten()
-
- for idx, (metric_key, metric_name) in enumerate(metrics_to_plot.items()):
- ax = axes[idx]
-
- first_val = first_metrics[metric_key]
- last_val = last_metrics[metric_key]
- delta_key = f"ada_delta_{metric_key}"
- delta_pct = delta_metrics[delta_key]
-
- # Create bar chart
- x = np.arange(2)
- bars = ax.bar(x, [first_val, last_val], width=0.6, alpha=0.8)
-
- # Color bars based on improvement
- # For collision/offroad, decrease is good (green), increase is bad (red)
- # For score/return, increase is good (green), decrease is bad (red)
- if metric_key in ["collision_rate", "offroad_rate"]:
- bars[0].set_color("gray")
- bars[1].set_color("green" if delta_pct < 0 else "red")
- else:
- bars[0].set_color("gray")
- bars[1].set_color("green" if delta_pct > 0 else "red")
-
- # Add value labels on bars
- for i, (bar, val) in enumerate(zip(bars, [first_val, last_val])):
- height = bar.get_height()
- ax.text(
- bar.get_x() + bar.get_width() / 2.0,
- height,
- f"{val:.3f}",
- ha="center",
- va="bottom",
- fontsize=10,
- fontweight="bold",
- )
-
- # Add delta percentage annotation
- mid_x = 0.5
- mid_y = max(first_val, last_val) * 0.5
- arrow_props = dict(
- arrowstyle="->",
- lw=2,
- color="green"
- if (delta_pct > 0 and metric_key in ["score", "episode_return"])
- or (delta_pct < 0 and metric_key in ["collision_rate", "offroad_rate"])
- else "red",
- )
-
- ax.annotate(
- f"{delta_pct:+.1f}%",
- xy=(1, last_val),
- xytext=(mid_x, mid_y),
- fontsize=14,
- fontweight="bold",
- ha="center",
- bbox=dict(boxstyle="round,pad=0.5", facecolor="yellow", alpha=0.7),
- arrowprops=arrow_props,
- )
-
- # Formatting
- ax.set_xticks(x)
- ax.set_xticklabels(["First Scenario\n(0-shot)", "Last Scenario\n(Adapted)"], fontsize=11)
- ax.set_ylabel(metric_name, fontsize=12, fontweight="bold")
- ax.set_title(f"{metric_name}", fontsize=13, fontweight="bold")
- ax.grid(axis="y", alpha=0.3, linestyle="--")
- ax.spines["top"].set_visible(False)
- ax.spines["right"].set_visible(False)
-
- plt.tight_layout()
- plt.savefig(output_path.replace(".json", "_adaptive_metrics.png"), dpi=300, bbox_inches="tight")
- print(f"\nAdaptive metrics plot saved to {output_path.replace('.json', '_adaptive_metrics.png')}")
- plt.close()
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--policy-path", type=str, required=True)
- parser.add_argument("--num-maps", type=int, default=10)
- parser.add_argument("--num-rollouts", type=int, default=1000)
- parser.add_argument("--batch-size", type=int, default=32, help="Max parallel rollouts per batch")
- parser.add_argument("--num-workers", type=int, default=16)
- parser.add_argument("--num-agents", type=int, default=64)
- parser.add_argument(
- "--condition-type",
- type=str,
- default="none",
- choices=["none", "reward", "entropy", "discount", "all"],
- help="Conditioning type (none, reward, entropy, discount, all)",
- )
- parser.add_argument("--output", type=str, default="eval_human_logs.json")
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
- parser.add_argument(
- "--max-controlled-agents", type=int, default=-1
- ) ## needs to be 1 if you want human logs, -1 if you want Self Play
- parser.add_argument("--adaptive-driving-agent", type=int, default=0, help="Enable adaptive driving agent")
- parser.add_argument("--k-scenarios", type=int, default=1, help="Number of scenarios (default 1 for non-adaptive)")
- parser.add_argument("--dynamics-model", type=str, default="classic")
- args = parser.parse_args()
-
- num_batches = (args.num_rollouts + args.batch_size - 1) // args.batch_size
-
- print(f"Evaluation Configuration:")
- print(f" Policy: {args.policy_path}")
- print(f" Conditioning: {args.condition_type}")
- print(f" Num maps: {args.num_maps}")
- print(f" Total rollouts: {args.num_rollouts}")
- print(f" Batch size: {args.batch_size}")
- print(f" Num batches: {num_batches}")
- print(f" Num agents per env: {args.num_agents}")
- print(f" Adaptive agent: {bool(args.adaptive_driving_agent)}")
- print(f" K scenarios: {args.k_scenarios}")
- print(f" Output: {args.output}\n")
- print(f" Dynamics Model: {args.dynamics_model}")
-
- # Load policy
- print("Loading policy...")
- env_name = "puffer_adaptive_drive" if args.adaptive_driving_agent else "puffer_drive"
- make_env = env_creator(env_name)
- temp_env = make_env(
- num_agents=64,
- num_maps=args.num_maps,
- scenario_length=91,
- co_player_cond_type=args.condition_type,
- adaptive_driving_agent=args.adaptive_driving_agent,
- k_scenarios=args.k_scenarios,
- dynamics_model=args.dynamics_model,
- )
-
- base_policy = Drive(temp_env, input_size=64, hidden_size=256)
- policy = Recurrent(temp_env, base_policy, input_size=256, hidden_size=256).to(args.device)
- state_dict = torch.load(args.policy_path, map_location=args.device)
- state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
- policy.load_state_dict(state_dict)
- policy.eval()
- temp_env.close()
- print("Policy loaded successfully\n")
-
- # Run evaluation in batches
- all_returns = []
- all_scenario_metrics = [] # Track metrics per scenario for adaptive agents
- all_metrics = []
-
- env_kwargs = {
- "num_agents": args.num_agents,
- "num_maps": args.num_maps,
- "max_controlled_agents": args.max_controlled_agents,
- "report_interval": 1,
- "scenario_length": 91,
- "adaptive_driving_agent": args.adaptive_driving_agent,
- "k_scenarios": args.k_scenarios,
- "dynamics_model": args.dynamics_model,
- "co_player_cond_type": args.condition_type,
- "co_player_cond_entropy_ub": 0.05,
- "co_player_cond_discount_lb": 0.40,
- }
-
- print("Running evaluation...")
- for batch_idx in range(num_batches):
- batch_rollouts = min(args.batch_size, args.num_rollouts - batch_idx * args.batch_size)
-
- print(f"Batch {batch_idx + 1}/{num_batches} ({batch_rollouts} rollouts)")
-
- # Find largest valid num_workers (divisor of batch_rollouts)
- max_workers = min(args.num_workers, batch_rollouts)
- while batch_rollouts % max_workers != 0:
- max_workers -= 1
-
- vecenv = pufferlib.vector.make(
- make_env,
- env_kwargs=env_kwargs,
- backend=pufferlib.vector.Multiprocessing,
- num_envs=batch_rollouts,
- num_workers=max_workers,
- )
-
- obs, _ = vecenv.reset()
- total_agents = obs.shape[0]
-
- state = {
- "lstm_h": torch.zeros(total_agents, policy.hidden_size, device=args.device),
- "lstm_c": torch.zeros(total_agents, policy.hidden_size, device=args.device),
- }
-
- batch_infos = []
- total_reward = np.zeros(total_agents)
- scenario_infos = [] # Track infos per scenario
-
- with torch.no_grad():
- # Run through all scenarios (1 for non-adaptive, k for adaptive)
- for scenario in range(args.k_scenarios):
- scenario_info_list = []
- desc = f" Scenario {scenario + 1}/{args.k_scenarios}" if args.k_scenarios > 1 else " Steps"
- for t in tqdm(range(91), desc=desc, ncols=80, leave=False):
- obs_t = torch.as_tensor(obs, device=args.device)
- logits, _ = policy.forward_eval(obs_t, state)
- action, _, _ = pufferlib.pytorch.sample_logits(logits)
-
- obs, reward, done, trunc, info = vecenv.step(action.cpu().numpy())
- total_reward += reward
-
- if info:
- valid_infos = [inf for inf in info if "score" in inf]
- batch_infos.extend(valid_infos)
- scenario_info_list.extend(valid_infos)
-
- # Store per-scenario infos
- if args.adaptive_driving_agent:
- scenario_infos.append(scenario_info_list)
-
- vecenv.close()
-
- # Aggregate batch metrics
- num_infos = len(batch_infos) or 1
- batch_metrics = {
- "score": sum(info.get("score", 0) for info in batch_infos) / num_infos,
- "collision_rate": sum(info.get("collision_rate", 0) for info in batch_infos) / num_infos,
- "offroad_rate": sum(info.get("offroad_rate", 0) for info in batch_infos) / num_infos,
- "completion_rate": sum(info.get("completion_rate", 0) for info in batch_infos) / num_infos,
- "dnf_rate": sum(info.get("dnf_rate", 0) for info in batch_infos) / num_infos,
- "avg_collisions_per_agent": sum(info.get("avg_collisions_per_agent", 0) for info in batch_infos)
- / num_infos,
- "avg_offroad_per_agent": sum(info.get("avg_offroad_per_agent", 0) for info in batch_infos) / num_infos,
- }
-
- rollout_rewards = total_reward.reshape(batch_rollouts, args.num_agents)
- rollout_returns = rollout_rewards.mean(axis=1)
-
- all_returns.extend(rollout_returns.tolist())
- all_metrics.append(batch_metrics)
-
- # Store scenario-specific metrics for adaptive agents
- if args.adaptive_driving_agent:
- batch_scenario_metrics = []
- for scenario_info_list in scenario_infos:
- num_scenario_infos = len(scenario_info_list) or 1
- scenario_metrics = {
- "score": sum(info.get("score", 0) for info in scenario_info_list) / num_scenario_infos,
- "collision_rate": sum(info.get("collision_rate", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "offroad_rate": sum(info.get("offroad_rate", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "completion_rate": sum(info.get("completion_rate", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "dnf_rate": sum(info.get("dnf_rate", 0) for info in scenario_info_list) / num_scenario_infos,
- "num_goals_reached": sum(info.get("num_goals_reached", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "lane_alignment_rate": sum(info.get("lane_alignment_rate", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "avg_displacement_error": sum(info.get("avg_displacement_error", 0) for info in scenario_info_list)
- / num_scenario_infos,
- "episode_return": sum(info.get("episode_return", 0) for info in scenario_info_list)
- / num_scenario_infos,
- }
- # Compute perf (score without collision before goal)
- scenario_metrics["perf"] = scenario_metrics["score"]
- batch_scenario_metrics.append(scenario_metrics)
-
- all_scenario_metrics.append(batch_scenario_metrics)
-
- # Aggregate across all batches
- all_returns = np.array(all_returns)
-
- # Divide by k_scenarios since we accumulated across all scenarios
- all_returns = all_returns / args.k_scenarios
-
- metrics = {k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0].keys()}
-
- results = {
- "avg_return": float(np.mean(all_returns)),
- "std_return": float(np.std(all_returns)),
- "se_return": float(np.std(all_returns) / np.sqrt(len(all_returns))), # Standard error
- **{k: float(v) for k, v in metrics.items()},
- }
-
- # Compute adaptive delta metrics from scenario metrics
- first_scenario_metrics = None
- last_scenario_metrics = None
-
- if args.adaptive_driving_agent and len(all_scenario_metrics) > 0:
- # Aggregate scenario metrics across all batches
- # all_scenario_metrics is a list of [batch][scenario] metrics
- # We need to compute average for each scenario across all batches
-
- aggregated_scenario_metrics = []
- for scenario_idx in range(args.k_scenarios):
- scenario_metrics_list = [batch[scenario_idx] for batch in all_scenario_metrics]
-
- # Average each metric across batches for this scenario
- avg_scenario_metrics = {}
- for key in scenario_metrics_list[0].keys():
- avg_scenario_metrics[key] = np.mean([m[key] for m in scenario_metrics_list])
-
- aggregated_scenario_metrics.append(avg_scenario_metrics)
-
- # Get first and last scenario metrics
- first_scenario_metrics = aggregated_scenario_metrics[0]
- last_scenario_metrics = aggregated_scenario_metrics[-1]
-
- # Helper function to compute delta percentage
- def compute_delta_percent(first_val, last_val):
- if abs(first_val) < 0.0001:
- return 0.0
- return (last_val - first_val) / first_val * 100.0
-
- # Compute all delta metrics
- results["ada_delta_completion_rate"] = compute_delta_percent(
- first_scenario_metrics["completion_rate"], last_scenario_metrics["completion_rate"]
- )
- results["ada_delta_score"] = compute_delta_percent(
- first_scenario_metrics["score"], last_scenario_metrics["score"]
- )
- results["ada_delta_perf"] = compute_delta_percent(first_scenario_metrics["perf"], last_scenario_metrics["perf"])
- results["ada_delta_collision_rate"] = compute_delta_percent(
- first_scenario_metrics["collision_rate"], last_scenario_metrics["collision_rate"]
- )
- results["ada_delta_offroad_rate"] = compute_delta_percent(
- first_scenario_metrics["offroad_rate"], last_scenario_metrics["offroad_rate"]
- )
- results["ada_delta_num_goals_reached"] = compute_delta_percent(
- first_scenario_metrics["num_goals_reached"], last_scenario_metrics["num_goals_reached"]
- )
- results["ada_delta_dnf_rate"] = compute_delta_percent(
- first_scenario_metrics["dnf_rate"], last_scenario_metrics["dnf_rate"]
- )
- results["ada_delta_lane_alignment_rate"] = compute_delta_percent(
- first_scenario_metrics["lane_alignment_rate"], last_scenario_metrics["lane_alignment_rate"]
- )
- results["ada_delta_avg_displacement_error"] = compute_delta_percent(
- first_scenario_metrics["avg_displacement_error"], last_scenario_metrics["avg_displacement_error"]
- )
- results["ada_delta_episode_return"] = compute_delta_percent(
- first_scenario_metrics["episode_return"], last_scenario_metrics["episode_return"]
- )
-
- # Store first and last scenario values for reporting
- results["first_scenario_score"] = float(first_scenario_metrics["score"])
- results["first_scenario_collision_rate"] = float(first_scenario_metrics["collision_rate"])
- results["first_scenario_offroad_rate"] = float(first_scenario_metrics["offroad_rate"])
- results["first_scenario_episode_return"] = float(first_scenario_metrics["episode_return"])
-
- results["last_scenario_score"] = float(last_scenario_metrics["score"])
- results["last_scenario_collision_rate"] = float(last_scenario_metrics["collision_rate"])
- results["last_scenario_offroad_rate"] = float(last_scenario_metrics["offroad_rate"])
- results["last_scenario_episode_return"] = float(last_scenario_metrics["episode_return"])
-
- with open(args.output, "w") as f:
- json.dump(results, f, indent=2)
-
- print(f"\nResults:")
- print(f" Return: {results['avg_return']:.2f} ± {results['se_return']:.2f} (SE)")
- print(f" Score: {results['score']:.3f}")
- print(f" Completion: {results['completion_rate']:.3f}")
- print(f" Collision: {results['collision_rate']:.3f}")
- print(f" Collision per agent: {results['avg_collisions_per_agent']:.3f}")
- print(f" Offroad: {results['offroad_rate']:.3f}")
-
- if args.adaptive_driving_agent:
- print(f"\n0-Shot Performance (First Scenario):")
- print(f" Score: {results['first_scenario_score']:.3f}")
- print(f" Collision: {results['first_scenario_collision_rate']:.3f}")
- print(f" Offroad: {results['first_scenario_offroad_rate']:.3f}")
- print(f" Return: {results['first_scenario_episode_return']:.2f}")
-
- print(f"\nAdapted Performance (Last Scenario):")
- print(f" Score: {results['last_scenario_score']:.3f}")
- print(f" Collision: {results['last_scenario_collision_rate']:.3f}")
- print(f" Offroad: {results['last_scenario_offroad_rate']:.3f}")
- print(f" Return: {results['last_scenario_episode_return']:.2f}")
-
- print(f"\nAdaptive Metrics (Delta %):")
- print(f" Score: {results['ada_delta_score']:.2f}%")
- print(f" Collision rate: {results['ada_delta_collision_rate']:.2f}%")
- print(f" Offroad rate: {results['ada_delta_offroad_rate']:.2f}%")
- print(f" Episode return: {results['ada_delta_episode_return']:.2f}%")
-
- # Generate visualization
- plot_adaptive_metrics(first_scenario_metrics, last_scenario_metrics, results, args.output)
-
- print(f"\nSaved to {args.output}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/external/pyxodr b/external/pyxodr
new file mode 160000
index 0000000000..cd4b837a65
--- /dev/null
+++ b/external/pyxodr
@@ -0,0 +1 @@
+Subproject commit cd4b837a651d4f10c3c4e77b04a029cac367c64b
diff --git a/notes/nuplan_hard.md b/notes/nuplan_hard.md
new file mode 100644
index 0000000000..272f883189
--- /dev/null
+++ b/notes/nuplan_hard.md
@@ -0,0 +1,150 @@
+# `nuplan_hard` map split
+
+## Why this exists
+
+Default eval (`--eval.map-dir resources/drive/binaries/nuplan_201`, 50 rollouts) shows
+`ada_delta_score ≈ ±0.005` across all configs. Looks like the policy doesn't
+adapt across scenarios.
+
+Diagnosis (from `score_maps_interaction.py` over the 5401 maps in nuplan_201):
+
+| Property | Count |
+|----------|------:|
+| Total maps | 5401 |
+| Maps where SDC has *any* moving vehicle within 5m at any timestep | 2497 (46%) |
+| Maps where SDC has zero such interaction-steps | **2904 (54%)** |
+
+**Over half of all maps have zero ego-other interaction.** On those maps the
+ego just drives straight and `ada_delta` collapses to noise. The signal exists,
+but the average is dominated by maps where adaptation is irrelevant.
+
+## What `nuplan_hard` is
+
+A symlink-only directory at `resources/drive/binaries/nuplan_hard/` containing
+the **top 10% of nuplan_201 maps by SDC-vehicle interaction density** (540 maps,
+sequentially renumbered as `map_001.bin` through `map_540.bin`).
+
+Threshold: maps with `sdc_interaction_steps ≥ 52`. Mean in the hard set: 79
+steps (out of 91-201 timesteps per scenario).
+
+Each `map_NNN.bin` is a symlink pointing back to the original map in
+`nuplan_201/`. The mapping (new bin id ↔ original bin id ↔ scenario_id) is
+recorded in `_manifest.csv` for traceability.
+
+## How to recreate
+
+The split is reproducible. To regenerate from the underlying JSON
+trajectories in `data/nuplan_gpudrive/nuplan/`:
+
+```bash
+# 1. Score every map by SDC interaction-steps (~3 min, multi-process)
+python scripts/score_maps_interaction.py \
+ --data-dir data/nuplan_gpudrive/nuplan \
+ --out /tmp/nuplan_201_hardness_scores.csv
+
+# 2. Build the symlink directory from the top-10% slice
+python scripts/build_nuplan_hard.py \
+ --scores /tmp/nuplan_201_hardness_scores.csv \
+ --source-dir resources/drive/binaries/nuplan_201 \
+ --out-dir resources/drive/binaries/nuplan_hard \
+ --metric sdc_interaction_steps \
+ --top-pct 10
+```
+
+Different threshold or metric? Pass other values to step 2 (`--top-pct 25`,
+`--metric interaction_events`, etc.). The scoring step doesn't have to be
+re-run — its CSV holds both metrics for every map.
+
+## Hardness definition (operational)
+
+For each scenario in the JSON dataset:
+
+1. Filter to vehicle-type agents only (excludes pedestrians, cyclists).
+2. Identify the SDC (the `is_sdc=True` agent — this is the slot the ego
+ policy occupies during human-replay eval).
+3. For each timestep `t`:
+ - Skip if SDC is not valid at `t`, or is parked (speed < 0.5 m/s).
+ - Find any other vehicle that is valid AND moving (speed > 0.5 m/s)
+ AND within Euclidean distance 5 m of the SDC.
+ - If at least one such vehicle exists, count this timestep as an
+ SDC interaction-step.
+4. `sdc_interaction_steps` = total such timesteps for the scenario.
+
+This is purely a property of the map + recorded human trajectories — no
+trained policy is involved. A map is "hard" iff the SDC's logged trajectory
+brings it close to another moving vehicle, often. That's exactly when our
+trained ego (which replaces the SDC at eval time) faces real interactions
+to adapt to.
+
+## Eval against `nuplan_hard`
+
+Stock command (replace checkpoint and GPU):
+
+```bash
+xvfb-run -a puffer eval puffer_adaptive_drive \
+ --load-model-path experiments/puffer_adaptive_drive_/model_..._N.pt \
+ --policy-architecture Transformer --rnn-name Transformer \
+ --train.horizon 402 \
+ --vec.num-workers 1 --vec.num-envs 1 \
+ --env.map-dir resources/drive/binaries/nuplan_hard \
+ --env.num-maps 540 \
+ --env.scenario-length 201 \
+ --env.k-scenarios 2 \
+ --env.conditioning.type none \
+ --env.co-player-enabled 1 \
+ --env.co-player-policy.policy-path experiments/puffer_drive_2e029h15.pt \
+ --env.co-player-policy.architecture Transformer \
+ --env.co-player-policy.transformer.horizon 201 \
+ --env.co-player-policy.conditioning.type all \
+ --env.co-player-policy.conditioning.collision-weight-lb -2 \
+ --env.co-player-policy.conditioning.collision-weight-ub 0 \
+ --env.co-player-policy.conditioning.offroad-weight-lb -2 \
+ --env.co-player-policy.conditioning.offroad-weight-ub 0 \
+ --env.co-player-policy.conditioning.entropy-weight-lb 0 \
+ --env.co-player-policy.conditioning.entropy-weight-ub 0.10 \
+ --env.co-player-policy.conditioning.discount-weight-lb 0.4 \
+ --env.co-player-policy.conditioning.discount-weight-ub 1 \
+ --env.map-rand-per-scenario False \
+ --eval.map-dir resources/drive/binaries/nuplan_hard \
+ --eval.num-maps 540 \
+ --eval.human-replay-eval True \
+ --eval.human-replay-num-rollouts 200 \
+ --eval.human-replay-num-maps 540 \
+ --eval.human-replay-num-agents 540
+```
+
+Note that the `--eval.human-replay-num-{rollouts,maps,agents}` flags are
+required: without them, eval defaults to 100 maps / 100 rollouts, which both
+under-samples and means the rollouts may not all hit the hard split.
+
+## Result on baseline checkpoint
+
+On `puffer_adaptive_drive_4lm6kkh7/model_..._40.pt` (γ=0.995, partner 2e029h15)
+with 200 rollouts:
+
+| Metric | Full set (50 rollouts) | `nuplan_hard` (200 rollouts) |
+|--------|-----------------------:|-----------------------------:|
+| `ada_delta_score` | ±0.005 | **+0.222 ± 0.18** |
+| `ada_delta_collision_rate` | ~0 | -0.045 |
+| `ada_delta_episode_return` | ~0 | +0.449 |
+| `ada_delta_dnf_rate` | ~0 | -0.170 |
+| `scenario_0_score` | — | 0.712 |
+| `scenario_1_score` | — | **0.935** |
+
+The model **does** adapt (s_1 score 22 points higher than s_0 in absolute terms,
+0.712 → 0.935; collisions cut roughly in half). The full-set average diluted the signal by 40-200×.
+
+## Caveats / things to keep in mind
+
+- "Hard" here is a property of the SDC's *recorded* trajectory in the
+ human-replay data. The trained ego is free to take a different path; a
+ defensive policy may avoid the close-proximity moments that scored the map
+ hard in the first place. The split is a *prior* over which maps are likely
+ to involve interaction, not a guarantee.
+- The thresholds (5 m, 0.5 m/s, top 10%) are somewhat arbitrary. Re-running with
+ different thresholds will produce different splits — `_manifest.csv` records
+ the actual threshold used.
+- Eval channel is human-replay (other vehicles = recorded humans, no synthetic
+ partner). The hard-set definition matches this channel — for synthetic-
+ partner eval channels you'd want a different definition (e.g.
+ partner-sensitivity-based).
diff --git a/notes/oracle_partner_conditioning_investigation.md b/notes/oracle_partner_conditioning_investigation.md
new file mode 100644
index 0000000000..c75d29d7c8
--- /dev/null
+++ b/notes/oracle_partner_conditioning_investigation.md
@@ -0,0 +1,384 @@
+# Why doesn't the adaptive ego use its K/V cache? (and the oracle test plan)
+
+Investigation date: 2026-05-01.
+Author: Claude (working with Mohit on the adaptive-driving paper).
+
+---
+
+## Empirical setup that triggered this investigation
+
+We trained two adaptive ego policies against the entropy-sweep partner
+`6rauydj2` (trained on `entropy_ub ∈ [0, 0.5]`):
+
+- **`curr_e0.5`** (wandb `hprfn8dc`) — entropy-curriculum on (ub
+ annealed 0.025 → 0.10 → 0.25 → 0.50 across the first ~120 episodes).
+- **`nocurr_e0.5`** (wandb `7wm1sk5v`) — entropy-curriculum off
+ (ub fixed at 0.50 throughout).
+
+Both trained from scratch to ~epoch 203 (resumed from intermediate
+checkpoints after host OOMs; resume restores optimizer + global_step
+cleanly via `pufferl.py:280`). Same partner, same seed, same map dir,
+same rollout config — **the only varying factor is the curriculum**.
+
+The training-time `ada_delta_score` favored `curr` over `nocurr`. We
+ran three offline analyses on the final checkpoints to understand
+**why** and **whether the policy is actually using the cache for
+adaptation**.
+
+---
+
+## Three offline analyses (all on 300 truly-held-out scenes)
+
+Held-out set: `resources/drive/binaries/nuplan_201_heldout300/` —
+symlinked `map_001..300` → originals `map_5102..5401` (training used
+the first 4999, so these are unseen).
+
+### Analysis 1: head-to-head eval (`puffer eval`, no co-player)
+
+```
+metric nocurr_e05 curr_e05 delta
+─────────────────────────────────────────────────────────────────────
+s0_rate 0.5906 0.6913 +0.1007
+s1_rate 0.5906 0.6946 +0.1040
+ada_delta 0.0000 0.0034 +0.0034
+p_s1_given_s0_pass 0.9830 0.9854 +0.0025
+p_s1_given_s0_fail 0.0246 0.0435 +0.0189
+n_s0_pass 176 206 +30
+n_s0_fail 122 92 -30
+```
+
+Curr is ~10 pp better in absolute success rate in BOTH scenarios on
+held-out scenes. ada_delta is ~0 for both — the curriculum gain is a
+*better policy* gain, not an *adaptation* gain. Recovery rate
+(`p_s1_given_s0_fail`) almost doubled with curriculum (2.5% → 4.4%) but
+both numbers are tiny in absolute terms.
+
+### Analysis 2: attention probe (`scripts/probe_attention.py`)
+
+Cross-scenario attention mass = fraction of attention from s_1 query
+positions to s_0 cache slots. Recorded per (layer, head) over a single
+rollout against `6rauydj2`.
+
+| Layer/Head | curr_e05 | nocurr_e05 |
+|------------|----------|------------|
+| L0H0 | 0.44 | 0.68 |
+| L0H1 | 0.51 | 0.69 |
+| L0H2 | 0.55 | 0.67 |
+| L0H3 | 0.49 | 0.67 |
+| L1H0 | **0.14** | 0.66 |
+| L1H1 | 0.35 | 0.70 |
+| L1H2 | **0.05** | 0.63 |
+| L1H3 | 0.23 | 0.65 |
+
+`curr` attends LESS to past than `nocurr` across every head — opposite
+of the cache-use hypothesis.
+
+### Analysis 3: counterfactual cache (`scripts/counterfactual_cache.py`)
+
+Paired rollouts: same map, same seed, same partner conditioning.
+Condition A preserves the K/V cache from s_0 → s_1; Condition B zeros
+it. Difference isolates "does cache CONTENT matter?"
+
+| Metric | curr_e05 (z) | nocurr_e05 (z) |
+|-------------------------------------|--------------|----------------|
+| Paired s1 lift (preserved − zeroed) | -0.007 (-0.45) | +0.030 (+1.41) |
+| P(s1 \| s0 fail) preserved | 0.30 | 0.14 |
+| P(s1 \| s0 fail) zeroed | 0.30 | 0.10 |
+
+`curr` is unaffected by cache zeroing (z = -0.45). `nocurr` does feel
+it — recovery rate drops ~30% relative when cache is wiped.
+
+### Joint reading
+
+| Question | curr_e0.5 | nocurr_e0.5 |
+|-----------------------------------|-------------------|--------------|
+| Better single-shot driver? | Yes (69% vs 59%) | No |
+| Attends to past more? | No (0.05–0.55) | Yes (0.63+) |
+| Cache content load-bearing? | No (z = -0.45) | Marginally yes (z = +1.41) |
+
+The curriculum produced a stronger driver that **ignores** the cache.
+The non-curriculum policy attends and (slightly) uses the cache, but
+that doesn't translate into better absolute scores.
+
+**Either way, neither policy has meaningful in-context adaptation.**
+
+---
+
+## First-principles teardown: what does adaptation actually require?
+
+For the ego to "adapt to partner P across scenarios" via cache, ALL of
+these must hold:
+
+1. **Partner type is observable** — some signal in ego's obs encodes
+ partner type.
+2. **The encoder preserves it** — the encoder doesn't lose the
+ discriminating info during compression.
+3. **The hidden states get written to cache** — automatic given the
+ architecture.
+4. **The policy queries the cache** — attention pattern looks back.
+5. **The retrieved values still carry the info** — what's stored is
+ still partner-relevant.
+6. **The policy conditions on cache content** — the policy learns to
+ change its action distribution based on cache values.
+
+We have evidence on each link:
+
+- (1) **Weak**: partner conditioning is NOT in ego obs. The ego sees
+ the partner's `(rel_x, rel_y, rel_heading_x, rel_heading_y, width,
+ length, rel_signed_speed)` — instantaneous behavior only. To infer
+ partner type (e.g., `entropy_weight`), the ego must integrate
+ *behavioral variance* over many timesteps. From a single snapshot,
+ partner type is unobservable.
+- (2) **Suspicious**: encoder compresses 1855 → 256. Partner-typing
+ info needs to survive this bottleneck, but the only push for it is
+ RL gradient — which is weak unless the policy is already using
+ partner type info.
+- (3) ✓
+- (4) ✓ (probe confirmed)
+- (5) **Likely the break**: counterfactual cache shows zeroing the
+ cache barely changes s_1 — the values don't carry usable info.
+- (6) **The break**: if the values aren't useful, conditioning on
+ them can't help.
+
+### Three specific architectural concerns
+
+**(a) Observations are in EGO FRAME**
+Every partner feature (`rel_x`, `rel_heading_x`, `rel_signed_speed`)
+is computed relative to the ego's CURRENT pose. As the ego moves, the
+same world-frame partner behavior produces different obs values. The
+K/V cache stores ego-frame snapshots, but the ego frame is
+non-stationary across scenarios. To compare partner behavior across
+the scenario boundary, the policy would have to implicitly invert the
+ego frame change — a hard sub-problem.
+
+**(b) Partner identity is unlabeled**
+Slot K of partner_obs is "the K-th nearby agent in
+`active_agent_indices` iteration order." If 60 partners are within
+visibility, slot 5 could hold partner-id-42 at one step and the same
+agent at the next (the mapping is stable for stationary scenes), but
+there's no identity TAG. Comparing slot K values across timesteps requires
+the policy to verify that it's the same partner, an implicit sub-problem.
+
+**(c) The latent we want to identify is BEHAVIORAL VARIANCE**
+A high-entropy partner samples actions stochastically. Its
+instantaneous position+speed look identical to a low-entropy partner.
+The DIFFERENCE only shows up as "the partner did something
+unexpected" — which requires the policy to have a forward model of
+what the partner *should* be doing under each conditioning value, then
+notice deviations. That's an enormous hidden inference problem, and
+nothing pushes the policy to learn the forward model.
+
+### Conclusion
+
+End-to-end, we're asking the policy to learn (from sparse goal-reach
+RL reward) to:
+1. Track partner identities across timesteps (no ID labels)
+2. Undo ego-frame motion to recover world-frame trajectories
+3. Estimate behavioral variance per partner over many timesteps
+4. Map variance → partner-type embedding
+5. Use that embedding to modulate actions
+
+That's a tower of hard implicit inference problems on top of "drive a
+car well." The optimization just doesn't push hard enough for steps
+1-4, especially when most of the reward can be earned by single-shot
+driving alone. This explains the empirical finding: the better
+single-shot driver (curr) ignores the cache, because the gradient
+toward "be a better driver" is much stronger than the gradient toward
+"learn to read the cache."
+
+---
+
+## What's already in the codebase that could help
+
+`grep -rin oracle …` returns NO hits anywhere in `/workspace/ADA`.
+
+The closest thing is the `[env.conditioning]` section in
+`adaptive.ini` — but this is **per-ego self-conditioning** (each ego
+agent gets its OWN conditioning vector that determines its OWN reward
+function: `reward_collision_weight`, `entropy_weight` for its own
+exploration, etc.). It is plumbed end-to-end:
+
+| Layer | What happens |
+|--------------------|--------------------------------------------------|
+| Python `Drive.__init__` | kwarg `conditioning={}` → `self.entropy_conditioned`, etc. (drive.py:60-167) |
+| Python | Adds `conditioning_dims` to `self.ego_features` so `obs_dim` accounts for the slots (drive.py:177) |
+| C | `env->collision_weights[i]` etc. allocated per active ego agent (drive.h:1841-1844) |
+| C | Samples weights at episode reset (drive.h:2486-2494) |
+| C | Appends weights to each ego's obs in `compute_observations` (drive.h:2294-2302) |
+
+This is NOT what we want, but the slots exist and are reachable. We
+can **hijack them**: keep the obs format the same, but at sample time,
+overwrite the per-ego conditioning weights with this env's partner
+conditioning vector. Each ego in env E sees the same vector its
+partner is using.
+
+---
+
+## Plan: oracle test (give ego the partner's conditioning)
+
+### Goal
+
+A clean diagnostic: if the ego is just *handed* the partner's
+conditioning vector as obs, does the policy then learn to adapt
+(drive differently when entropy_weight is high vs low)?
+
+- **If yes**: the architecture downstream is fine. The bottleneck is
+ inference from behavior. Then we can pursue real fixes (per-partner
+ history features, world-frame obs, partner-attention layer, etc.).
+- **If no**: the bottleneck is downstream — even with explicit
+ partner-type signal, the policy can't condition action on it. Then
+ we need to look at the policy head, the action distribution, or
+ the loss formulation.
+
+### C-side dive: where do the existing conditioning slots come from?
+
+The per-ego conditioning weights `env->{collision,offroad,goal,entropy,discount}_weights[i]` are used in **two places** in `drive.h`:
+
+1. **obs append** (lines 2294-2302) — appended to ego obs after `base_ego_dim`.
+2. **reward computation** (lines 2625, 2637, 2669, 2679, 2691) — the per-step
+ reward at collision / offroad / goal events SCALES with these weights.
+
+This means we **cannot just enable `[env.conditioning].type=all` and
+overwrite the C arrays** — that would also change the ego's reward
+function (e.g., sampling collision_weight ∈ [-1, 0] makes some
+collisions free).
+
+### Pufferl-side dive: more silent uses of these slots
+
+Even worse, **pufferl reads two of these slots from obs to drive
+training dynamics**:
+
+| Slot | Read by | Becomes |
+|------|---------|---------|
+| collision_w, offroad_w, goal_w | (only C-side reward) | per-agent reward scaling |
+| entropy_w (slot at offset 12 with reward+ego) | `pufferl.py:885` | per-agent entropy weight in PPO loss |
+| discount_w (slot at offset 13 with reward+ego) | `pufferl.py:743` | per-agent γ in GAE advantage computation |
+
+The `pufferl` hooks are gated on `entropy_conditioned` /
+`discount_conditioned` flags on the env. If we naively enable
+`[env.conditioning].type=all` to get the obs slots and then override
+them in Python, pufferl will use the **partner's** entropy/discount
+values as the **ego's** PPO hyperparameters — a major silent training
+distortion.
+
+### The complete "no behavior change" recipe
+
+To get the obs slots while preserving the no-conditioning training
+dynamics:
+
+1. **Force `[env.conditioning].type = "all"`** in the env init (from
+ the `ego_is_oracle` flag), so 5 obs slots are allocated and the C
+ `env->*_weights[i]` arrays exist.
+2. **Pin all 5 sampling ranges to lb=ub=default** so the C-side
+ per-agent samples are constant and equal to the default reward
+ weights. Reward computation is identical to a non-conditioned run.
+3. **After C-side setup, in Python, flip
+ `self.entropy_conditioned = False` and
+ `self.discount_conditioned = False`** (keep `reward_conditioned =
+ True` — it has no pufferl-side hook). Pufferl now sees False on
+ these flags and skips the per-agent γ / entropy hooks → uses
+ global defaults.
+4. **In step() / reset(), Python overrides the 5 obs slots in each
+ ego row with `env_conditioning[env_of_ego]`** — the partner's
+ conditioning vector. This is pure obs signal: pufferl no longer
+ reads these slots, the policy just sees them as input.
+5. **Bypass the `NotImplementedError` at line 350** that forbids dual
+ ego+co-player conditioning. The check predates oracle and is not
+ protecting any actual coupling for our setup.
+
+This makes the diff one-flag, no-C, contained:
+- One env kwarg + ini knob (`ego_is_oracle`)
+- Init-time validation + flag fix-up + bypass
+- One step()-time obs override loop
+- One sentence in the launcher to set the deterministic-default lb/ub on the 5 dims
+
+### Implementation (Python obs override, no C changes)
+
+1. **New env kwarg** `ego_is_oracle = False` (default off).
+2. **Adaptive.ini knob** `ego_is_oracle = False` so CLI auto-generates
+ the `--env.ego-is-oracle` flag.
+3. **In `step()`** (drive.py, after `binding.vec_step(self.c_envs)`,
+ before return), if `ego_is_oracle` is on AND `env_conditioning`
+ has slots:
+ - For each env e, for each ego in env e, overwrite the
+ `[base_ego_dim : base_ego_dim + n_partner_dims]` slice of
+ `self.observations[ego_id]` with `self.env_conditioning[e]`.
+ - Same for the post-reset obs (need to call after `_set_env_variables` too).
+4. **In `_add_co_player_conditioning`-style helper**: factor the
+ slot-locating logic so we can reuse it for ego.
+5. **Launcher**: must set `--env.conditioning.type all` AND
+ `--env.conditioning.{collision,offroad,goal,entropy,discount}-weight-{lb,ub}`
+ to default values (so the C-side reward isn't changed):
+ ```
+ --env.conditioning.collision-weight-lb -0.5 --env.conditioning.collision-weight-ub -0.5
+ --env.conditioning.offroad-weight-lb -0.5 --env.conditioning.offroad-weight-ub -0.5
+ --env.conditioning.goal-weight-lb 1.0 --env.conditioning.goal-weight-ub 1.0
+ --env.conditioning.entropy-weight-lb 0.001 --env.conditioning.entropy-weight-ub 0.001
+ --env.conditioning.discount-weight-lb 0.98 --env.conditioning.discount-weight-ub 0.98
+ ```
+ These are the same defaults as the no-conditioning case, so reward
+ semantics are identical to a non-conditioned run. The slots in obs
+ are then allocated but their values are overwritten by Python with
+ the partner's per-env conditioning vector.
+
+### Caveats / things to verify
+
+- **Slot offset & width**: ego's conditioning width = 3 (reward) + 1
+ (entropy) + 1 (discount) = 5 when type="all". Partner's width must
+ match — also "all" → 5. We assert this at init.
+- **The partner conditioning sample is per-env, valid for the whole
+ episode.** Need to inject the same vector EVERY step (overwrite obs
+ slots in step(), not just at reset).
+- **The model checkpoint obs_dim with type=all is 1855** (ego_dim=14
+ = base 9 + cond 5, partner_obs and road_obs unchanged). This requires
+ a fresh training run from scratch. The curr/nocurr e=0.5 checkpoints
+ cannot be loaded into it: their obs_dim is also 1855, but their slots
+ contained their own conditioning, not the partner's, so the policy
+ weights have learned an interpretation of those slots that doesn't
+ transfer.
+
+### What success looks like
+
+If the oracle test works, we expect to see — vs the no-oracle baseline:
+- ada_delta improves significantly (e.g., to +0.05 or more, instead of ~0)
+- Per-condition behavior visible: rollouts at extreme partner
+ conditioning (e.g., entropy=0.5 vs entropy=0.0) should produce
+ visibly different ego policies (different distance kept from
+ partner, different speeds, etc.)
+- Counterfactual: zeroing the partner conditioning slots (not the
+ full cache, just the oracle slots) should hurt s_0 score because
+ the ego is now driving "blind" to partner type
+
+### Run plan once implemented
+
+User decision: **only do the curr experiment** — single oracle run, no
+nocurr-baseline pairing. The comparison is against the existing
+`curr_e0.5` (`hprfn8dc`) which has identical config except no oracle.
+
+1. **Wire the flag** (drive.py + adaptive.ini).
+2. **Smoke test** — first 30 sec of a launch, confirm:
+ - obs slot layout is correct
+ - the partner conditioning is actually being injected (print first
+ few obs slot values, verify they match `env_conditioning[e]`)
+ - training doesn't crash
+3. **Full oracle run** — same config as `curr_e0.5`:
+ - Partner: `6rauydj2` at `e_ub=0.5`
+ - k=2/201, horizon=402, nw=24 nv=24 (or 32/32 if k_eff has freed
+ RAM by then)
+ - Entropy curriculum: ON (mirror `curr_e0.5`)
+ - Ego conditioning: type=all with default-value lb=ub
+ - Ego oracle: ON
+4. **Analyses identical to the curr/nocurr runs**:
+ - Eval on 300 held-out scenes (same `nuplan_201_heldout300`
+ symlinked dir we already built)
+ - Counterfactual cache (zero the cache to see if it matters now
+ that the policy has the oracle slot)
+ - Attention probe (does attention still go to past, or does the
+ policy ignore the past now that oracle gives the answer?)
+
+If the oracle policy adapts strongly (`ada_delta` jumps from ~0 to
+0.05+, attention to past drops, conditioning slot is load-bearing),
+the bottleneck is **inference from behavior** and we move on to fixes
+like per-partner history features. If the oracle policy STILL doesn't
+adapt, the bottleneck is **downstream** (action conditioning) and we
+need to look at the policy head, action distribution, or loss.
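
Sketch of the per-step oracle override described in the notes above (commentary between file diffs, not part of the patch). It assumes `observations` holds one row per ego and that a precomputed `ego_env_index` array maps each ego row to its env; the function name and both arrays are hypothetical, and the real `drive.py` may index `self.observations` differently.

```python
import numpy as np

def inject_partner_conditioning(observations, env_conditioning, ego_env_index, base_ego_dim):
    """Overwrite each ego's conditioning slots with its env's partner vector.

    observations:     (num_egos, obs_dim) float array, modified in place
    env_conditioning: (num_envs, n_partner_dims) per-env partner conditioning
    ego_env_index:    (num_egos,) env index for each ego row (hypothetical)
    base_ego_dim:     offset of the first conditioning slot (9 in the notes)
    """
    n_partner_dims = env_conditioning.shape[1]
    stop = base_ego_dim + n_partner_dims
    if stop > observations.shape[1]:
        raise ValueError("conditioning slots exceed obs width")
    # Gather one conditioning row per ego, then write all egos in one slice assignment.
    observations[:, base_ego_dim:stop] = env_conditioning[ego_env_index]
    return observations


# Toy shapes: 2 envs, 3 egos, obs_dim 20, base_ego_dim 9, 5 conditioning dims.
obs = np.zeros((3, 20), dtype=np.float32)
cond = np.array([[-0.5, -0.5, 1.0, 0.001, 0.98],
                 [-0.1, -0.2, 0.5, 0.5, 0.80]], dtype=np.float32)
ego_env = np.array([0, 0, 1])
inject_partner_conditioning(obs, cond, ego_env, base_ego_dim=9)
```

Calling this after every `vec_step` and after reset would satisfy the "inject EVERY step" caveat, since the C side rewrites the slots on each step.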
diff --git a/notes/paper_plan.md b/notes/paper_plan.md
new file mode 100644
index 0000000000..101806eca7
--- /dev/null
+++ b/notes/paper_plan.md
@@ -0,0 +1,198 @@
+# Adaptive Population Play — paper experiment plan
+
+40-hour execution window. 8 GPUs. Each ego training ≈12h, each
+co-player training ≈4.5h, each hard-map eval ≈4 min (M=10 × 540 maps).
+Budget = 320 GPU-h; T1 ≈ 180, T1+T2 ≈ 300.
+
+## What we already have
+
+- **Adaptive ego baseline** `4lm6kkh7` (γ=0.995, partner `2e029h15`,
+ e_ub=0.10) → `ada_delta = +0.222 ± 0.18` on `nuplan_hard` (200 rollouts).
+- **5 partner-sweep adaptive egos** (one per entropy partner, γ=0.995,
+ lane=0.025) — final ckpts at epoch 152. Hard-map eval (M=10):
+
+ | partner | e_ub | ada_delta ± std |
+ |---------|-----:|----------------:|
+ | miku2puk | 0.05 | **+0.155 ± 0.008** |
+ | 2e029h15 | 0.10 | **+0.154 ± 0.085** |
+ | m2ygolog | 0.20 | -0.010 ± 0.016 |
+ | 6rauydj2 | 0.50 | -0.001 ± 0.008 |
+ | n48teqjs | 1.00 | -0.000 ± 0.003 |
+
+- **15-partner DE-grid in progress** on GPUs 0-4, ETA 22:10 UTC.
+ Grid: `discount_lb ∈ {0.4, 0.6, 0.8} × entropy_ub ∈ {0.001, 0.01, 0.05,
+ 0.1, 0.2}`. Wandb: `ada_coplayer_sweep`.
+- **`nuplan_hard` map split** (top 10% by SDC interaction), recipe
+ recorded in `notes/nuplan_hard.md`.
+- **Fast eval recipe**: M=10, vanilla code, 5390 (map, ego) datapoints/eval,
+ SE on mean ≈0.003.
+
+## Key finding driving the plan
+
+**Adaptation only emerges against deterministic partners (e ≤ 0.10).**
+For e ≥ 0.20 the gap collapses. This means:
+
+- Robustness experiments need to vary the partner distribution along the
+ *deterministic ↔ noisy* axis, not blindly use [0, 1].
+- The "interesting" ada_delta lives in a narrow slice of partner space.
+- We should structure the partner-distribution test to span that slice
+ cleanly.
+
+---
+
+## Tier 1 — must-have for paper
+
+### T1.1 Main result table: 4 ablation conditions vs adaptive-conditioned baseline
+
+Using a single fixed partner (`2e029h15`, e=0.10 — matches the
+`4lm6kkh7` baseline that gave +0.222), train 4 ego variants:
+
+| label | description | training change |
+|-------|-------------|----------------|
+| **A** | Adaptive + conditioned partner | (baseline, already trained) |
+| **B** | Non-adaptive + conditioned partner | reset ego K/V at every scenario boundary (`RECOVERY_CACHE_RESET_PER_SCENARIO=1` env var, or k_scenarios=1) |
+| **C** | Adaptive + fixed-policy partner | partner = single ckpt, ego never sees conditioning slot |
+| **D** | Self-play | both ego and "partner" use the same current policy |
+| **E** | Adaptive + log-replay only | no learned partner, others follow recorded humans |
+
+Cost: 4 trainings × 12h = **48 GPU-h**. Eval: 5 × 4 min on free GPUs.
+
+### T1.2 Scaling along k_scenarios
+
+Train adaptive egos at `k_scenarios ∈ {1, 2, 4}`. k=8 is 48h on its own
+and probably won't fit; defer.
+
+Cost: 12 + 12 + 24 = **48 GPU-h**.
+
+### T1.3 Robustness: in-distribution vs OOD partner distribution
+
+`[TBD — confirm exact ranges]`. Updated based on what we learned about partner space:
+
+- **In-dist baseline**: train ego sampling partner conditioning over
+ `e ∈ [0, 0.20]` (the slice where adaptation happens), test on same range.
+- **OOD test**: train on `e ∈ [0, 0.05]` (very deterministic only), test
+ on `e ∈ [0.10, 0.20]`. Asks: does an ego trained on near-deterministic
+ partners still show ada_delta when faced with a slightly noisier one?
+- **Aggressive vs cautious test**: condition partners with
+ collision_weight_lb=-2 (cautious) vs 0 (aggressive); reuse existing egos.
+
+Cost: 1 new training (the OOD ego) + 2 evals = **12 GPU-h** + **20 min**.
+
+---
+
+## Tier 2 — should-have
+
+### T2.1 Robustness: held-out maps
+
+Eval-only. Build a `nuplan_hard_holdout` split (different 10% of the
+nuplan_201 maps that we did NOT score / use for `nuplan_hard`). Eval the
+5 partner-sweep egos against this set. Cost: ~30 min.
+
+### T2.2 Adversarial co-player
+
+Train a partner with `collision_weight = +2` (rewards collisions); eval
+all egos against it.
+
+Cost: 4.5h partner + 5 evals × 4 min = **5 GPU-h**.
+
+### T2.3 Scaling: number of training maps
+
+Train ego on `nuplan_201` subsets of size 200, 1000, 5000.
+
+Cost: 3 × 12h = **36 GPU-h**. Defer if tight.
+
+### T2.4 Conditioning leave-one-out (5 dims)
+
+Train 5 egos, each with one conditioning dim disabled
+(collision / offroad / goal / entropy / discount).
+
+Cost: 5 × 12h = **60 GPU-h**. Pick 2-3 dims if budget tight.
+
+---
+
+## Tier 3 — defer to post-paper unless time
+
+- LSTM vs Transformer (one extra training)
+- RNN state size sweep
+- Co-player ratio (25 / 50 / 75 %)
+- Diversity of conditioning ranges (wide vs narrow)
+
+---
+
+## Proposed 40h schedule
+
+Times relative to the 09:55 UTC start of this plan. All wall-clock
+estimates assume 12h per ego training and 4.5h per co-player.
+
+### Phase 1 (h0 → h13, overlaps with DE-grid finishing on GPUs 0-4)
+
+Free GPUs: 5, 6, 7. Launch 3 of the 4 T1.1 ablations:
+
+| GPU | run | duration |
+|----:|-----|---------:|
+| 5 | T1.1-B non-adaptive ego vs `2e029h15` | 12h |
+| 6 | T1.1-C adaptive ego vs fixed partner | 12h |
+| 7 | T1.1-D self-play | 12h |
+
+### Phase 2 (h13 → h25, all 8 GPUs free)
+
+DE-grid done at h~12. Launch:
+
+| GPU | run | duration |
+|----:|-----|---------:|
+| 0 | T1.1-E log-replay ego | 12h |
+| 1 | T1.2 scaling k=1 | ~6h |
+| 2 | T1.2 scaling k=4 | ~24h (runs into Phase 3) |
+| 3 | T1.3 OOD ego (train on e ∈ [0, 0.05]) | 12h |
+| 4 | T2.2 adversarial partner (4.5h) → ego | ~16h |
+| 5 | T2.4 leave-one-out: discount | 12h |
+| 6 | T2.4 leave-one-out: entropy | 12h |
+| 7 | T2.4 leave-one-out: collision | 12h |
+
+The T1.2 k=2 run already exists (`4lm6kkh7`); reuse it for the curve.
+
+### Phase 3 (h25 → h40, finishing + evals)
+
+- k=4 finishes around h37
+- All other Phase 2 trainings done by h25
+- Free GPUs run evals (M=10 × 540 maps, ~4 min each)
+- Build all paper plots:
+ - **Fig 1**: main result bar plot (5 conditions, ada_delta ± std)
+ - **Fig 2**: ada_delta vs partner entropy (uses 5 partner egos +
+ DE-grid eval)
+ - **Fig 3**: ada_delta vs k_scenarios
+ - **Fig 4**: in-dist vs OOD partner distribution
+ - **Fig 5**: leave-one-out conditioning ablation
+
+---
+
+## Open questions to resolve before launching
+
+1. **Self-play definition (T1.1-D)**: ego = partner = same live policy
+ on each step? Or ego = current policy, partner = past-snapshot? The
+ former is "true" self-play but the partner is non-stationary; the
+ latter is more like league play. Pick one.
+2. **Non-adaptive definition (T1.1-B)**: do we set `k_scenarios = 1` (so
+ training never gives the ego cross-scenario context) or keep
+ `k_scenarios = 2` and reset the cache at the boundary? The latter
+ isolates "is it the cache that helps?" cleanly; the former is more
+ honest as a "non-adaptive" baseline. Pick one or do both.
+3. **OOD ranges (T1.3)**: my proposal `[0, 0.05]` train → `[0.10, 0.20]`
+ test is one option. Confirm this is the slice you want — alternatives
+ are deterministic→noisy (`[0, 0.1]` → `[0.5, 1.0]`) which we already
+ know breaks, or wide→narrow (`[0, 1]` → `[0.05, 0.2]`).
+4. **Drop k=8?** Yes unless you have a strong opinion. 48h alone.
+
+Answer 1-4 and I'll start writing the launchers.
+
+---
+
+## Quick links
+
+- DE-grid sweep status: `tmux attach -t coplayer_de_grid`,
+ `/tmp/coplayer_de_grid_driver.log`
+- Hard-map eval recipe: `notes/nuplan_hard.md`
+- Latest ego eval results: `/tmp/eval_partner_sweep_m10.log`
+- 5 partner-sweep ckpts:
+ `experiments/puffer_adaptive_drive_{0bmlyvg3,l7c13x1m,0yih6s9k,sa80qcs2,p6q2b2xp}/model_..._000152.pt`
+- Adaptive launcher template: `scripts/adaptive/nuplan_transformer_local_k2_201_partner_sweep.sh`
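
Sanity check on the budget in `notes/paper_plan.md` (commentary between file diffs, not part of the patch). The snippet only sums the per-item GPU-hour costs exactly as written in the plan; the dictionary keys are labels chosen here. T2.1 and the evals are quoted in minutes and are left out.

```python
# Itemized GPU-hour costs copied from the plan's "Cost:" lines.
tier1 = {"T1.1": 48, "T1.2": 48, "T1.3": 12}
tier2 = {"T2.2": 5, "T2.3": 36, "T2.4": 60}

# 8 GPUs over a 40-hour window.
budget = 8 * 40

t1 = sum(tier1.values())
t12 = t1 + sum(tier2.values())

print(f"Tier 1 itemized:     {t1} GPU-h")
print(f"Tier 1 + 2 itemized: {t12} GPU-h")
print(f"Budget:              {budget} GPU-h (headroom {budget - t12})")
```

The itemized sums (108 and 209 GPU-h) come out lower than the "T1 ≈ 180, T1+T2 ≈ 300" quoted at the top of the plan. The plan does not say what accounts for the difference, so it may be worth confirming whether the header figures include reruns, the DE-grid, or other overheads.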
diff --git a/notes/trial_episode_design.md b/notes/trial_episode_design.md
new file mode 100644
index 0000000000..8c89d25d8c
--- /dev/null
+++ b/notes/trial_episode_design.md
@@ -0,0 +1,202 @@
+# Trial-as-episode redesign for adaptive driving
+
+## Goal
+
+Replace the current scenario-based episode structure with **variable-length trials packed into a fixed-budget episode**, where each trial is an independent "fresh map, fresh agent" goal-reach attempt and the K/V cache persists across trials within an episode.
+
+## Decided semantics (from user)
+
+1. **"New trial"**: fresh agent placement (back to start), fresh goal, fresh recorded-human trajectories restarting from t=0. **Same map** within an episode (no map-swap-per-trial). The map_rand_per_scenario / _reinit_envs_with_new_maps machinery is being removed entirely from the codebase as part of this work.
+2. **Episode budget**: fixed at `k * scenario_length` ticks. If trials finish early, fill remaining budget with *another fresh episode* (start trial 0 of a new episode in the same buffer slot, separated by terminal flag).
+3. **K/V cache persistence**:
+ - **Adaptive (treatment)**: cache persists across trials within an episode. Only resets at episode boundary.
+ - **Control**: cache resets at every trial boundary. Configurable knob.
+4. **Full design, not MVP.** OK with 5-7 days of work.
+
+## Code to be DELETED in this refactor
+
+- `map_rand_per_scenario` flag and all its handling sites (drive.py, drive.h, eval scripts)
+- `_reinit_envs_with_new_maps` Python function (drive.py:888-975) — entire function gone
+- The `_render_keep_client_on_swap` / `vec_donate_client` / `vec_adopt_client` plumbing that exists *only* to keep raylib alive across the now-removed reinit
+- The "broken as an ICL probe" warning block in adaptive.ini
+
+The trial-boundary swap is now a fresh, smaller C function (`start_new_trial`) that does ONLY agent-state reset + goal sample, no env tear-down, no terminals write. New goal source: either `sample_new_goal()` (a road-lane point ahead of agent) or another agent's `init_goal` from the same map.
+
+## Key implementation insight (from Plan agent)
+
+`done_mask = terminals + truncations` controls cache reset (pufferl.py:610).
+GAE only uses `terminals` for bootstrap truncation (not truncations).
+
+So we have three distinct events:
+
+| event | terminals | truncations | cache | GAE bootstrap |
+|-------|----------:|------------:|:------|:--------------|
+| within-trial step | 0 | 0 | continues | continues |
+| **trial boundary (adaptive)** | 0 | 0 | continues | continues across |
+| **trial boundary (control)** | 0 | 1 | resets | continues across |
+| episode boundary | 1 | 0 | resets | truncates |
+
+The episode boundary marker lets pufferl pack two episodes into one segment row — `create_episode_mask` will block cross-episode attention via the cumsum-on-terminals episode-id mechanism. Variable-length episodes inside a fixed-size segment work out of the box this way.
+
+## Architecture
+
+### C side (drive.h)
+
+**New goal_behavior mode `GOAL_TRIAL = 3`**:
+- On goal-reach: full goal_weight reward, mark `current_goal_reached=1`, set per-agent `trial_complete=1`. **Do not** stop, do not respawn (yet — Python decides).
+- A new per-env counter `current_trial` (uses existing `current_scenario`-style logic).
+
+**New per-env flag accessible from Python**: `trial_ended_this_step[env_idx]`. Set to 1 in the C step when:
+- All ego agents in the env have reached goal, OR
+- The trial-length budget (currently `scenario_length`) has elapsed without reach.
+
+This flag is consumed by Python at the same tick: Python decides whether to start a new trial, whether to set terminals, etc.
+
+**No `_reinit_envs_with_new_maps` involvement.** Trial boundary uses a new lighter C function `start_new_trial(env, agent_idx)`:
+- Reset agent position via `set_start_position` for that agent only
+- Sample new goal via `sample_new_goal()` (or pick from another agent's `init_goal`)
+- Reset per-agent metrics (`current_goal_reached=0, collided_before_goal=0, stopped=0, removed=0`)
+- Reset agent's per-trial Log fields
+- DO NOT reset `env->timestep` — the experts/partners continue along their recorded trajectory
+- DO NOT touch other agents
+
+Episode boundary (`tick == episode_budget`): pufferl handles this via the standard `done_mask` path. We set `terminals[ego_ids]=1` from Python, the env is fully reset by `binding.vec_reset` on the next `puffer_env.reset()` call. No special C function needed; the existing `c_reset` path does it.
+
+**Per-trial Log fields** (in `Log` struct):
+- `trials_completed_this_episode` (int)
+- `trials_attempted_this_episode` (int)
+- `mean_time_to_goal_per_trial` (float, running mean)
+- `per_trial_succeeded[MAX_TRIALS]` (fixed-size array, MAX_TRIALS=8)
+- `per_trial_collided_before_goal[MAX_TRIALS]` (fixed-size array)
+
+vec_log aggregation accumulates these across envs.
+
+### Python side (drive.py + adaptive.py)
+
+**`AdaptiveDrivingAgent.__init__`**: add knobs
+- `goal_behavior_trial = True` (whether to use new mode)
+- `max_trials_per_episode` (default = `k_scenarios`)
+- `trial_length` (default = `scenario_length`)
+- `trial_cache_reset` (False = adaptive treatment, True = control)
+- `episode_budget = max_trials_per_episode * trial_length` (just an alias)
+
+**`Drive.step` modifications**:
+- Per-step: read `trial_ended_this_step` array from C.
+- For each env where trial ended:
+ - Call C function `start_new_trial(env, ego_idx)` for each ego in that env
+ - Increment `current_trial`
+ - If `trial_cache_reset=True`: set `truncations[ego_ids_in_env]=1` (resets cache via done_mask, does NOT truncate GAE)
+ - Aggregate trial metrics into per-trial dict
+- After all per-env trial-ends processed:
+ - Increment `episode_ticks` counter
+ - If `episode_ticks >= episode_budget` OR `current_trial >= max_trials_per_episode`:
+ - Set `terminals[ego_ids]=1` for ALL envs (episode boundary; pufferl's c_reset path handles the actual env reset on the next call)
+ - Reset per-episode counters
+ - Compute episode-level delta metrics (analogous to current `_compute_delta_metrics`)
+ - Aggregate per-trial metrics list
+
+**`_compute_delta_metrics` generalization**: now per-trial. Returns `trial_N_score` for each completed trial, plus `ada_delta_score = trial_(last)_score - trial_0_score`.
+
+### Eval side (evaluator.py + utils.py)
+
+**`HumanReplayEvaluator.rollout`**: rewrite outer loop to be trial-aware:
+```python
+for episode in range(num_episodes):
+    obs, _ = env.reset()
+    state = _fresh_state()
+    trial_metrics_per_episode = []
+    while True:
+        for tick in range(self.sim_steps):
+            obs, rew, dones, truncs, info = env.step(action)  # action: policy output (elided)
+            ...  # track per-trial success_arr
+            if dones.any():
+                break  # episode ended
+        else:
+            continue  # trial ended but episode budget remaining
+        break
+    ...  # aggregate per-trial metrics
+```
+
+Per-`(rollout, agent, trial_idx)` success matrix instead of per-`(rollout, agent, scenario)`.
+
+`RECOVERY_CACHE_RESET_PER_SCENARIO` → `RECOVERY_CACHE_RESET_PER_TRIAL` env var (control switch; ON = control, OFF = adaptive).
+
+### Pufferl integration
+
+Should require **zero changes**. The recv loop, GAE kernel, and `create_episode_mask` already handle:
+- Multiple episodes per segment row (terminals=1 marks the joins)
+- Cache reset on `done_mask=t+d`
+- GAE truncation only on terminals
+
+Verify with a synthetic test before assuming.
+
+### Config knobs (adaptive.ini)
+
+```ini
+[env]
+goal_behavior = 3 # GOAL_TRIAL mode
+max_trials_per_episode = 4 # how many trials per episode max
+trial_length = 100 # per-trial budget in ticks
+episode_budget = 400 # max total episode ticks (= max_trials * trial_length)
+trial_cache_reset = False # adaptive treatment (cache persists). True = control.
+```
+
+## Implementation milestones
+
+### Milestone 1 — C-side trial mode + Python signal (Day 1-2)
+
+- Add `GOAL_TRIAL=3` constant + new branch in `c_step` goal-reach logic.
+- Add `trial_ended_this_step[env]` field to env struct, exposed via a new binding (`vec_get_trial_ended` or similar) or via observation channel.
+- Add per-trial Log fields.
+- Unit test: instantiate one Drive env with goal_behavior=3, step until agent reaches goal, verify trial_ended flag fires, verify trial_complete count increments.
+
+### Milestone 2 — Python wiring + trial reset (Day 3)
+
+- Expose the new C function `start_new_trial(env, agent_idx)` to Python (no env tear-down, no terminals write).
+- In `Drive.step`, read trial_ended array, call `start_new_trial` for each ego at trial boundary.
+- Increment current_trial, handle episode-budget exhaustion.
+- Set terminals at episode boundary.
+- Track per-trial metrics.
+
+### Milestone 3 — `_compute_delta_metrics` rewrite (Day 4)
+
+- Generalize from per-scenario to per-trial.
+- Emit `trial_N_score`, `trial_N_collision_rate`, etc.
+- Compute `ada_delta_score` as last_trial_score - first_trial_score.
+- Add `mean_trial_score`, `n_trials_completed_per_episode`.
+
+### Milestone 4 — Eval mirror (Day 5)
+
+- Rewrite `HumanReplayEvaluator.rollout` outer loop.
+- Per-(rollout, agent, trial) success array.
+- Mirror trial_cache_reset for control runs.
+- Test on a saved checkpoint.
+
+### Milestone 5 — Integration test + verification (Day 6)
+
+- End-to-end: train one epoch with new mode, check loss not NaN, scores look sensible.
+- Compare GAE flow with `terminals=1` only at episode boundary vs at every trial boundary; verify expected behavior.
+- Log `transformer_position` over an episode to verify cache lifecycle.
+
+### Milestone 6 — A/B run + ship (Day 7-10)
+
+- Train 3 seeds k=2 (now 4 trials × 100 ticks per trial = 400-tick episode budget) gb=3 — adaptive treatment.
+- Train 3 seeds k=2 — control (trial_cache_reset=True).
+- Compare ada_delta_score curves.
+
+## Risks / open questions
+
+1. **Cost of map-swap-per-trial**: `_reinit_envs_with_new_maps` is the slow call (vec_close + 540× env_init). Per-trial cost ~5-10s. With 4 trials/episode and many parallel envs, this could 4× training wall-clock.
+ - Mitigation: could pre-load N maps and just switch index, but env doesn't support that today. Adds engineering.
+
+2. **K/V cache size**: with episode_budget=400 and horizon=400, cache exactly fits. With variable trial counts, the cache might overrun if trials run long. We need to cap episode at horizon to avoid wraparound.
+
+3. **Per-trial Log fixed-size arrays**: MAX_TRIALS=8 is arbitrary. If a user sets max_trials_per_episode=16, it breaks. Needs runtime validation.
+
+4. **Per-env trial counters when num_envs=540**: 540 separate counters. Cheap in C, just need to expose them right.
+
+5. **`map_rand_per_scenario` flag**: existing knob is "broken as ICL probe" because of the terminals=1 write. Our new mode bypasses this. Should we also fix the existing flag, or remove it / mark deprecated?
+
+## Smallest-possible-test before committing to full impl
+
+(Removed — Day 0 used `map_rand_per_scenario`, which we're deleting. Going straight to the full implementation.)
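
Sketch of the flag semantics that `notes/trial_episode_design.md` relies on (commentary between file diffs, not part of the patch): cache reset follows `terminals + truncations`, while the GAE bootstrap cut and the episode id follow `terminals` only. The function name is hypothetical, and the exclusive-cumsum convention for episode ids is an assumption about how `create_episode_mask` behaves, not taken from pufferl.

```python
import numpy as np

def boundary_effects(terminals, truncations):
    """Return (cache_reset, gae_cut, episode_id) for one env's flag sequence."""
    terminals = np.asarray(terminals, dtype=np.int64)
    truncations = np.asarray(truncations, dtype=np.int64)
    # Cache resets wherever either flag is set (done_mask = terminals + truncations).
    cache_reset = (terminals + truncations) > 0
    # The GAE bootstrap is cut only by terminals.
    gae_cut = terminals > 0
    # Exclusive cumsum: the terminal step still belongs to the episode it ends
    # (assumed convention).
    episode_id = np.cumsum(terminals) - terminals
    return cache_reset, gae_cut, episode_id


# 8 ticks: control-style trial boundary at t=2 (truncation only),
# episode boundary at t=5 (terminal).
terminals = [0, 0, 0, 0, 0, 1, 0, 0]
truncations = [0, 0, 1, 0, 0, 0, 0, 0]
cache_reset, gae_cut, episode_id = boundary_effects(terminals, truncations)
print(cache_reset.astype(int), gae_cut.astype(int), episode_id)
```

In this toy sequence the truncation at t=2 resets the cache without cutting GAE, which is the control behavior that the `Drive.step` section describes for `trial_cache_reset=True`. A synthetic test of this shape against the real pufferl code would cover the "verify before assuming" item.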
diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini
index 810de828c0..248eb8d3b3 100644
--- a/pufferlib/config/default.ini
+++ b/pufferlib/config/default.ini
@@ -49,7 +49,7 @@ minibatch_size = 8192
# Accumulate gradients above this size
max_minibatch_size = 32768
-bptt_horizon = 64
+horizon = 64
compile = False
compile_mode = max-autotune-no-cudagraphs
compile_fullgraph = True
@@ -81,7 +81,7 @@ downsample = 10
; mean = 1e8
; scale = time
-; [sweep.train.bptt_horizon]
+; [sweep.train.horizon]
; distribution = uniform_pow2
; min = 16
; max = 64
diff --git a/pufferlib/config/ocean/adaptive.ini b/pufferlib/config/ocean/adaptive.ini
index 3e4eaf60c4..fc347130b8 100644
--- a/pufferlib/config/ocean/adaptive.ini
+++ b/pufferlib/config/ocean/adaptive.ini
@@ -2,22 +2,35 @@
package = ocean
env_name = puffer_adaptive_drive
policy_name = Drive
-rnn_name = Recurrent
+policy_architecture = Transformer
+; Adaptive runs always use the transformer wrapper.
+rnn_name = Transformer
[vec]
-num_workers = 16
-num_envs = 16
+num_workers = 32
+num_envs = 32
batch_size = 1
; backend = Serial
[policy]
-input_size = 64
+input_size = 128
hidden_size = 256
[rnn]
input_size = 256
hidden_size = 256
+[transformer]
+input_size = 256
+hidden_size = 256
+num_layers = 2
+; Number of transformer layers
+num_heads = 4
+; Number of attention heads (must divide hidden_size)
+; Transformer uses `horizon` from [train] section for attention span
+dropout = 0.0
+; Dropout (keep at 0 for RL stability initially)
+
[env]
num_agents = 1024
num_ego_agents = 512
@@ -25,35 +38,99 @@ num_ego_agents = 512
action_type = discrete
; Options: classic, jerk
dynamics_model = classic
-; Number of consecutive scenarios per episode (adaptive-specific)
-k_scenarios = 2
reward_vehicle_collision = -0.5
-reward_offroad_collision = -0.2
-reward_ade = 0.0
+reward_offroad_collision = -0.5
dt = 0.1
reward_goal = 1.0
reward_goal_post_respawn = 0.25
+; With reward conditioning, goal_weight is scaled by this factor after a respawn
+; Lane alignment reward (GIGAFLOW) - set to 0 to disable
+reward_lane_align = 0.0
+; Velocity alignment coefficient for lane reward
+reward_vel_align = 1.0
; Meters around goal to be considered "reached"
goal_radius = 2.0
-; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
+; Max target speed in m/s for the agent to maintain towards the goal
+goal_speed = 100.0
+; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
goal_behavior = 0
+; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals.
+; Large numbers will select a goal point further away from the agent's current position.
+goal_target_distance = 30.0
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
collision_behavior = 0
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
offroad_behavior = 0
-; Number of steps before reset
+; Number of steps in each scenario (constrained by base data)
scenario_length = 91
-; Resample frequency = k_scenarios * scenario_length (adaptive-specific)
-resample_frequency = 182
-num_maps = 1000
-; Which step of the trajectory to initialize the agents at upon reset
+k_scenarios = 2
+termination_mode = 1
+; 0 - terminate at episode_length, 1 - terminate after all agents have been reset
+map_dir = "resources/drive/binaries/training"
+num_maps = 10000
+; Determines which step of the trajectory to initialize the agents at upon reset
init_steps = 0
-; Options: "control_vehicles", "control_agents", "control_tracks_to_predict"
+; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only"
control_mode = "control_vehicles"
; Options: "created_all_valid", "create_only_controlled"
init_mode = "create_all_valid"
; train with co players
-co_player_enabled = 1
+co_player_enabled = True
+; When True, co-player inference happens on the main GPU process (batched
+; across workers) instead of single-thread CPU per worker. ~5-10x speedup
+; on env stepping; required for k>=3 to be tractable on adaptive runs.
+external_co_player_actions = False
+; Re-init C envs with fresh map_ids at every within-episode scenario
+; boundary. BROKEN as an ICL probe — the reinit sets terminals[:]=1
+; which wipes the ego cache and truncates GAE at the boundary.
+map_rand_per_scenario = False
+; When True (only meaningful for k_scenarios > 1, with co_player_enabled): at
+; every scenario boundary, the partner conditioning vector is re-sampled from
+; configured ranges. The partner POLICY weights are unchanged, but its
+; effective behavior shifts per scenario. The ego policy must infer partner
+; type from s_0 observations and use it in s_1. Independent of
+; map_rand_per_scenario; set ONE of them, not both.
+condition_rand_per_scenario = False
+; When True, partner's entropy_weight_ub is annealed up over training in 4
+; stages: 0.05 → 0.20 → 0.50 → 1.0 of the user-passed entropy_ub. Each
+; stage advances every 30 episodes per worker (≈30 epochs at nw=32 nv=32).
+; Other conditioning dims (collision/offroad/discount) sample at their
+; passed ranges throughout. Use with a wide-conditioning partner whose
+; trained range covers up to the final entropy_ub.
+entropy_curriculum_enabled = False
+; When resuming a curriculum run from checkpoint, set this to the original
+; run's last episode count so the curriculum picks up at the right stage
+; instead of restarting from stage 0. Each per-worker counter starts at
+; this value. No-op when entropy_curriculum_enabled = False.
+entropy_curriculum_episodes_start = 0
+; When True, ego K/V cache is reset at SOME within-episode scenario
+; boundaries (controlled by curriculum stage) so the policy first learns
+; to drive single scenarios, then short cross-scenario context, before
+; full K_max-scenario context. Implementation: at boundaries to cut, we
+; set truncations[ego_ids]=1 + terminals[ego_ids]=1, which makes
+; pufferl drop the cache via done_mask=t+d during eval and blocks
+; cross-boundary attention via create_episode_mask during training.
+; Stage schedule: stage 0 = k_eff=1, stage 1 = k_eff=2, stage 2 = K_max.
+; K_max=4 with k_eff∈{1,2,4} produces uniform splits; for other K_max,
+; only k_eff=1 and k_eff=K_max produce uniform stages.
+k_eff_curriculum_enabled = False
+; Episodes per stage in the k_eff curriculum.
+k_eff_curriculum_episodes_per_stage = 30
+; When True, the partner's per-env conditioning vector is appended to
+; the END of every ego's observation as `oracle` slots. The C-side obs
+; format is unchanged (writes into a private buffer); a wider Python
+; buffer holds the C output plus the appended oracle dims. Pufferl's
+; partner-obs path strips the trailing oracle dims via the env's
+; `_c_obs_dim` attribute, so the partner policy still sees its original
+; obs width. Reward, GAE γ, and PPO entropy coefficient unchanged.
+; Requires co_player_policy.conditioning.type != "none".
+ego_is_oracle = False
+
+; If True, zero rewards in scenarios 0..k-2; only the last scenario
+; produces reward signal. Forces credit assignment to flow back from s_{k-1}
+; rewards to actions in earlier scenarios — used to test whether reward
+; shape is the bottleneck for cross-scenario adaptation.
+reward_only_last_scenario = False
[env.conditioning]
@@ -71,17 +148,25 @@ discount_weight_lb = 0.80
discount_weight_ub = 0.98
[env.co_player_policy]
-enabled = True
policy_name = Drive
-rnn_name = Recurrent
-policy_path = "experiments/puffer_drive_ewdjljwd.pt"
-input_size = 64
+; Options: "Recurrent", "Transformer"
+architecture = Recurrent
+policy_path = "pufferlib/resources/drive/policies/varied_discount.pt"
+input_size = 128
hidden_size = 256
[env.co_player_policy.rnn]
input_size = 256
hidden_size = 256
+[env.co_player_policy.transformer]
+input_size = 256
+hidden_size = 256
+num_layers = 2
+num_heads = 4
+horizon = 91
+dropout = 0.0
+
[env.co_player_policy.conditioning]
; Options: "none", "reward", "entropy", "discount", "all"
type = "all"
@@ -96,68 +181,77 @@ entropy_weight_ub = 0.001
discount_weight_lb = 0.80
discount_weight_ub = 0.98
-
[train]
+seed=42
+compile = True
+compile_mode = default
+precision = bfloat16
total_timesteps = 2_000_000_000
-# learning_rate = 0.02
-# gamma = 0.985
anneal_lr = True
-; Needs to be: num_agents * num_workers * BPTT horizon
+; Needs to be: num_agents * num_workers * horizon
batch_size = auto
-; minibatch_size = 745472
-; minibatch_multiplier = 512
-; max_minibatch_size = 745472
-minibatch_size = 372736
-minibatch_multiplier = 256
-max_minibatch_size = 372736
-; BPTT horizon (overridden by pufferl.py for adaptive agents to k_scenarios * scenario_length)
-bptt_horizon = 32
+minibatch_size = 36400
+; 200 * 182 (must be divisible by horizon = k_scenarios * scenario_length)
+max_minibatch_size = 36400
+minibatch_multiplier = 400
+; Sequence length - overridden to k_scenarios * scenario_length for adaptive
+horizon = 91
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_eps = 1e-8
clip_coef = 0.2
-ent_coef = 0.001
+ent_coef = 0.005
gae_lambda = 0.95
-gamma = 0.98
-learning_rate = 0.001
-max_grad_norm = 1
-prio_alpha = 0.8499999999999999
-prio_beta0 = 0.8499999999999999
+; gamma=0.995 → ~135-step credit half-life (was 0.98 → ~34-step). Bumped
+; for k=2/201 (horizon=402) so s_1 rewards can actually flow back to s_0
+; actions; γ=0.98 over 402 steps gave ~0.0003 effective discount → no
+; cross-scenario credit assignment, killing the in-context-learning
+; signal. See notes/oracle_partner_conditioning_investigation.md.
+gamma = 0.995
+learning_rate = 0.003
+; Raised from 0.001 (note: transformers often need a lower LR)
+max_grad_norm = 1.0
+prio_alpha = 0.85
+prio_beta0 = 0.85
update_epochs = 1
-vf_clip_coef = 0.1999999999999999
-vf_coef = 2
+vf_clip_coef = 0.2
+vf_coef = 2.0
vtrace_c_clip = 1
vtrace_rho_clip = 1
-checkpoint_interval = 1000
+checkpoint_interval = 40
# Rendering options
render = True
-render_interval = 1000
+render_interval = 100
; If True, show exactly what the agent sees in agent observation
obs_only = True
; Show grid lines
-show_grid = False
+show_grid = True
; Draws lines from ego agent observed ORUs and road elements to show detection range
show_lasers = False
; Display human xy logs in the background
-show_human_logs = True
-; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None
+show_human_logs = False
+; If True, zoom in on a part of the map. Otherwise, show full map
+zoom_in = True
+; Options: List[str to path], str to path (e.g., "resources/drive/binaries/training/map_001.bin"), None
render_map = none
[eval]
-eval_interval = 1000
+eval_interval = 40
+; Path to dataset used for evaluation. None = inherit env.map_dir from training.
+map_dir = None
+; Evaluation will run on the first num_maps maps in the map_dir directory
+num_maps = 20
backend = PufferEnv
-# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings
+; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings
; If True, enables evaluation on realism metrics each time we save a checkpoint
-wosac_realism_eval = True
+wosac_realism_eval = False
; Number of policy rollouts per scene
wosac_num_rollouts = 32
; When to start the simulation
wosac_init_steps = 10
-; Total number of WOSAC agents to evaluate
-wosac_num_agents = 256
-; Control the tracks to predict
-wosac_control_mode = "control_tracks_to_predict"
-; Initialize from the tracks to predict
+; Control everything valid at init in the scene
+wosac_control_mode = "control_wosac"
+; Create everything valid at init in the scene
wosac_init_mode = "create_all_valid"
; Stop when reaching the goal
wosac_goal_behavior = 2
@@ -168,23 +262,33 @@ wosac_sanity_check = False
wosac_aggregate_results = True
; If True, enable human replay evaluation (pair policy-controlled agent with human replays)
human_replay_eval = True
-; Control only the self-driving car
-human_replay_control_mode = "control_sdc_only"
-; This equals the number of scenarios, since we control one agent in each
-human_replay_num_agents = 64
+; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent)
+human_replay_control_mode = "control_vehicles"
+; Number of agents in human replay evaluation environment.
+; Set equal to num_maps so every scene gets a controllable SDC each rollout
+; (human_replay caps at one controllable agent per scene).
+human_replay_num_agents = 100
+; Number of independent rollouts per eval cycle. With deterministic resets the
+; scene set is fixed across rollouts, so num_rollouts samples policy action
+; stochasticity. Reported metrics are mean across rollouts with `_std`.
+human_replay_num_rollouts = 100
+; Number of maps to use for human replay evaluation
+human_replay_num_maps = 100
+; Number of maps to render for human replay (subset of eval maps)
+human_replay_render_num_maps = 1
-[sweep.env.reward_vehicle_collision]
-distribution = uniform
-min = -0.5
-max = 0.0
-mean = -0.05
+[sweep.train.learning_rate]
+distribution = log_normal
+min = 0.001
+mean = 0.003
+max = 0.005
scale = auto
-[sweep.env.reward_offroad_collision]
-distribution = uniform
-min = -0.5
-max = 0.0
-mean = -0.05
+[sweep.train.ent_coef]
+distribution = log_normal
+min = 0.001
+mean = 0.005
+max = 0.03
scale = auto
[sweep.env.goal_radius]
@@ -194,16 +298,18 @@ max = 20.0
mean = 10.0
scale = auto
-[sweep.env.reward_ade]
-distribution = uniform
-min = -0.1
-max = 0.0
-mean = -0.02
+[sweep.train.gae_lambda]
+distribution = log_normal
+min = 0.95
+mean = 0.98
+max = 0.999
scale = auto
-[sweep.env.reward_goal_post_respawn]
-distribution = uniform
-min = 0.0
-max = 1.0
-mean = 0.5
-scale = auto
+[controlled_exp.train.goal_speed]
+values = [10, 20, 30, 3]
+
+[controlled_exp.train.ent_coef]
+values = [0.001, 0.005, 0.01]
+
+[controlled_exp.train.seed]
+values = [42, 55, 1]
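
Reviewer note on the gamma comment in the [train] section above: the half-life and end-of-horizon discount figures it quotes can be checked with a few lines of arithmetic. This is a minimal sketch; `half_life` and `discount_after` are hypothetical helper names, not part of pufferlib, and the 402-step horizon is taken from the comment itself (k=2 scenarios of 201 steps).

```python
import math


def half_life(gamma: float) -> float:
    """Number of steps n at which gamma**n drops to 0.5."""
    return math.log(0.5) / math.log(gamma)


def discount_after(gamma: float, steps: int) -> float:
    """Weight given to a reward that arrives `steps` steps in the future."""
    return gamma**steps


if __name__ == "__main__":
    horizon = 402  # assumption: k_scenarios=2 * scenario_length=201, per the comment
    for gamma in (0.98, 0.995):
        print(
            f"gamma={gamma}: half-life ~{half_life(gamma):.1f} steps, "
            f"discount over {horizon} steps ~{discount_after(gamma, horizon):.4f}"
        )
```

Under these assumptions the half-lives come out at roughly 34 and 138 steps, and gamma=0.995 leaves a reward 402 steps away with a weight of about 0.13 rather than about 0.0003.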
diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini
index b7bc0d021f..366ee25ac3 100644
--- a/pufferlib/config/ocean/drive.ini
+++ b/pufferlib/config/ocean/drive.ini
@@ -2,22 +2,33 @@
package = ocean
env_name = puffer_drive
policy_name = Drive
-rnn_name = Recurrent
+policy_architecture = Transformer
[vec]
-num_workers = 16
-num_envs = 16
+num_workers = 8
+num_envs = 8
batch_size = 2
; backend = Serial
[policy]
-input_size = 64
+input_size = 128
hidden_size = 256
[rnn]
input_size = 256
hidden_size = 256
+[transformer]
+input_size = 256
+hidden_size = 256
+num_layers = 2
+; Number of transformer layers
+num_heads = 4
+; Number of attention heads (must divide hidden_size)
+; k_scenarios (2) * scenario_length (91) = maximum attention span
+dropout = 0.0
+; Dropout (keep at 0 for RL stability initially)
+
[env]
num_agents = 1024
num_ego_agents = 512
@@ -26,26 +37,36 @@ action_type = discrete
; Options: classic, jerk
dynamics_model = classic
reward_vehicle_collision = -0.5
-reward_offroad_collision = -0.2
-reward_ade = 0.0
+reward_offroad_collision = -0.5
dt = 0.1
reward_goal = 1.0
reward_goal_post_respawn = 0.25 # in case of reward conditioning, we scale the goal_weight by this number for post respawn
+; Lane alignment reward (GIGAFLOW) - set to 0 to disable
+reward_lane_align = 0.0
+; Velocity alignment coefficient for lane reward
+reward_vel_align = 1.0
; Meters around goal to be considered "reached"
goal_radius = 2.0
-; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
+; Max target speed in m/s for the agent to maintain towards the goal
+goal_speed = 100.0
+; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
goal_behavior = 0
+; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals.
+; Large numbers will select a goal point further away from the agent's current position.
+goal_target_distance = 30.0
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
collision_behavior = 0
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
offroad_behavior = 0
-; Number of steps before reset
+; Number of steps before reset
scenario_length = 91
-resample_frequency = 182
+resample_frequency = 910
+termination_mode = 1 # 0 - terminate at episode_length, 1 - terminate after all agents have been reset
+map_dir = "resources/drive/binaries/training"
num_maps = 1000
-; Which step of the trajectory to initialize the agents at upon reset
+; Determines which step of the trajectory to initialize the agents at upon reset
init_steps = 0
-; Options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only"
+; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only"
control_mode = "control_vehicles"
; Options: "created_all_valid", "create_only_controlled"
init_mode = "create_all_valid"
@@ -67,10 +88,9 @@ discount_weight_lb = 0.80
discount_weight_ub = 0.98
[env.co_player_policy]
-enabled = False
policy_name = Drive
rnn_name = Recurrent
-policy_path = "resources/drive/policies/varied_discount.pt"
+policy_path = "pufferlib/resources/drive/policies/varied_discount.pt"
input_size = 64
hidden_size = 256
@@ -93,15 +113,17 @@ discount_weight_lb = 0.98
discount_weight_ub = 0.80
[train]
+seed = 42
total_timesteps = 2_000_000_000
-# learning_rate = 0.02
-# gamma = 0.985
anneal_lr = True
-; Needs to be: num_agents * num_workers * BPTT horizon
+; Needs to be: num_agents * num_workers * horizon
batch_size = auto
-minibatch_size = 32768
-max_minibatch_size = 32768
-bptt_horizon = 32
+minibatch_size = 32760
+; 360 * 91
+max_minibatch_size = 32760
+minibatch_multiplier = 360
+; Sequence length for training - matches scenario_length for full episode context
+horizon = 91
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_eps = 1e-8
@@ -110,44 +132,49 @@ ent_coef = 0.005
gae_lambda = 0.95
gamma = 0.98
learning_rate = 0.003
-max_grad_norm = 1
-prio_alpha = 0.8499999999999999
-prio_beta0 = 0.8499999999999999
+max_grad_norm = 1.0
+prio_alpha = 0.85
+prio_beta0 = 0.85
update_epochs = 1
-vf_clip_coef = 0.1999999999999999
-vf_coef = 2
+vf_clip_coef = 0.2
+vf_coef = 2.0
vtrace_c_clip = 1
vtrace_rho_clip = 1
-checkpoint_interval = 1000
+checkpoint_interval = 100
+context_length = 32
# Rendering options
render = True
-render_interval = 1000
+render_interval = 100
; If True, show exactly what the agent sees in agent observation
obs_only = True
; Show grid lines
-show_grid = False
+show_grid = True
; Draws lines from ego agent observed ORUs and road elements to show detection range
show_lasers = False
; Display human xy logs in the background
-show_human_logs = True
-; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None
+show_human_logs = False
+; If True, zoom in on a part of the map. Otherwise, show full map
+zoom_in = True
+; Options: List[str to path], str to path (e.g., "resources/drive/binaries/training/map_001.bin"), None
render_map = none
[eval]
-eval_interval = 1000
+eval_interval = 100
+; Path to dataset used for evaluation. None = inherit env.map_dir from training.
+map_dir = None
+; Evaluation will run on the first num_maps maps in the map_dir directory
+num_maps = 20
backend = PufferEnv
-# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings
+; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings
; If True, enables evaluation on realism metrics each time we save a checkpoint
wosac_realism_eval = False
; Number of policy rollouts per scene
wosac_num_rollouts = 32
; When to start the simulation
wosac_init_steps = 10
-; Total number of WOSAC agents to evaluate
-wosac_num_agents = 256
-; Control the tracks to predict
-wosac_control_mode = "control_tracks_to_predict"
-; Initialize from the tracks to predict
+; Control everything valid at init in the scene
+wosac_control_mode = "control_wosac"
+; Create everything valid at init in the scene
wosac_init_mode = "create_all_valid"
; Stop when reaching the goal
wosac_goal_behavior = 2
@@ -157,11 +184,17 @@ wosac_sanity_check = False
; Only return aggregate results across all scenes
wosac_aggregate_results = True
; If True, enable human replay evaluation (pair policy-controlled agent with human replays)
-human_replay_eval = False
-; Control only the self-driving car
-human_replay_control_mode = "control_sdc_only"
-; This equals the number of scenarios, since we control one agent in each
-human_replay_num_agents = 64
+human_replay_eval = True
+; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent)
+human_replay_control_mode = "control_vehicles"
+; Number of agents in human replay evaluation environment
+human_replay_num_agents = 32
+; Number of rollouts for human replay evaluation
+human_replay_num_rollouts = 100
+; Number of maps to use for human replay evaluation
+human_replay_num_maps = 100
+; Number of maps to render for human replay (subset of eval maps)
+human_replay_render_num_maps = 1
[sweep.train.learning_rate]
distribution = log_normal
@@ -174,10 +207,9 @@ scale = auto
distribution = log_normal
min = 0.001
mean = 0.005
-max = 0.01
+max = 0.03
scale = auto
-
[sweep.env.goal_radius]
distribution = uniform
min = 2.0
@@ -185,16 +217,18 @@ max = 20.0
mean = 10.0
scale = auto
-[sweep.env.reward_ade]
-distribution = uniform
-min = -0.1
-max = 0.0
-mean = -0.02
+[sweep.train.gae_lambda]
+distribution = log_normal
+min = 0.95
+mean = 0.98
+max = 0.999
scale = auto
-[sweep.env.reward_goal_post_respawn]
-distribution = uniform
-min = 0.0
-max = 1.0
-mean = 0.5
-scale = auto
+[controlled_exp.train.goal_speed]
+values = [10, 20, 30, 3]
+
+[controlled_exp.train.ent_coef]
+values = [0.001, 0.005, 0.01]
+
+[controlled_exp.train.seed]
+values = [42, 55, 1]
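
Reviewer note on the minibatch settings in the two configs above: the comments state that `minibatch_size` must be a whole multiple of `horizon`. A minimal sketch of that check follows; `sequences_per_minibatch` is a hypothetical helper, not a pufferlib function, and it assumes the constraint is plain integer divisibility.

```python
def sequences_per_minibatch(minibatch_size: int, horizon: int) -> int:
    """Return how many full sequences fit in a minibatch, or raise if it does not divide evenly."""
    if horizon <= 0:
        raise ValueError("horizon must be positive")
    if minibatch_size % horizon != 0:
        raise ValueError(
            f"minibatch_size ({minibatch_size}) is not divisible by horizon ({horizon})"
        )
    return minibatch_size // horizon


if __name__ == "__main__":
    print(sequences_per_minibatch(32760, 91))   # drive.ini values
    print(sequences_per_minibatch(36400, 182))  # adaptive case: k_scenarios=2 * scenario_length=91
```

With these numbers, the previous `minibatch_size = 32768` would be rejected at `horizon = 91`, which is consistent with the change to 32760.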
diff --git a/pufferlib/models.py b/pufferlib/models.py
index 0893a9db47..72121e458e 100644
--- a/pufferlib/models.py
+++ b/pufferlib/models.py
@@ -1,3 +1,5 @@
+import os
+
import numpy as np
import torch
@@ -7,6 +9,14 @@
import pufferlib.pytorch
import pufferlib.spaces
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+import math
+
+
+# Set PUFFER_TRANSFORMER_LEGACY_EVAL=1 to fall back to the pre-KV-cache path.
+_USE_LEGACY_EVAL = os.environ.get("PUFFER_TRANSFORMER_LEGACY_EVAL", "0") == "1"
+
class Default(nn.Module):
"""Default PyTorch policy. Flattens obs and applies a linear layer.
@@ -196,6 +206,488 @@ def forward(self, observations, state):
return logits, values
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ env,
+ policy,
+ input_size=128,
+ hidden_size=128,
+ num_layers=4,
+ num_heads=8,
+ horizon=512,
+ dropout=0.0,
+ use_checkpointing=False,
+ ):
+ """Wraps your policy with a Transformer for temporal modeling.
+
+ Args:
+ env: Environment instance
+ policy: Your Drive policy (must have encode_observations and decode_actions)
+ input_size: Size of encoded observations (from policy.encode_observations)
+ hidden_size: Transformer hidden dimension
+ num_layers: Number of transformer layers
+ num_heads: Number of attention heads
+ horizon: Maximum sequence length to attend over
+ dropout: Dropout probability
+ use_checkpointing: Enable gradient checkpointing to save memory (slower training)
+ """
+ super().__init__()
+ self.obs_shape = env.single_observation_space.shape
+ self.policy = policy
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.horizon = horizon
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ if hidden_size % num_heads != 0:
+ raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+ self.head_dim = hidden_size // num_heads
+ self.is_continuous = self.policy.is_continuous
+ self.use_checkpointing = use_checkpointing
+ # Per-slot attention masks for KV-cached streaming inference. Cached
+ # lazily per device to avoid recomputing the same mask each step.
+ self._streaming_mask_cache = {}
+
+ # Project encoded observations to transformer dimension if needed
+ if input_size != hidden_size:
+ self.input_projection = nn.Linear(input_size, hidden_size)
+ else:
+ self.input_projection = nn.Identity()
+
+ # Sinusoidal positional embedding (Vaswani et al.) — non-trainable.
+ # Switched from learnable PE so the transformer has temporal
+ # structure from initialization rather than having to learn it
+ # from gradients. Slot-tied: PE[i] is added when writing to
+ # cache slot i, identical for both forward (training) and
+ # forward_eval (rollout) paths via get_positional_embedding().
+ pe = torch.zeros(horizon, hidden_size)
+ position = torch.arange(0, horizon, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_size))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ # register_buffer keeps it on the module's device but excludes it
+ # from .parameters() (no gradient updates).
+ self.register_buffer("positional_embedding", pe.unsqueeze(0))
+
+ # Transformer encoder
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=hidden_size,
+ nhead=num_heads,
+ dim_feedforward=hidden_size * 2,
+ dropout=dropout,
+ activation="gelu",
+ batch_first=True,
+ norm_first=True, # Pre-LN architecture (more stable)
+ )
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+ # Pre-register causal masks for commonly used sequence lengths
+ for T in [1, 2, 4, 8, 16, 32, 64, 91, 182, 273, 364, 455]:
+ mask = self.create_causal_mask(T, "cpu")
+ self.register_buffer(f"_causal_mask_{T}", mask, persistent=False)
+
+ # Cached masks for episode mask creation (reduces memory allocation)
+ self.register_buffer("_zero_mask", torch.zeros(1), persistent=False)
+ self.register_buffer("_neg_inf_mask", torch.full((1,), float("-inf")), persistent=False)
+
+ # Layer norm for output
+ self.output_norm = nn.LayerNorm(hidden_size)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self):
+ """Orthogonal init for weight matrices (ndim >= 2), zeros for biases."""
+ for name, param in self.named_parameters():
+ if "layer_norm" in name or "layernorm" in name or "output_norm" in name:
+ continue
+ if "bias" in name:
+ nn.init.constant_(param, 0)
+ elif "weight" in name and param.ndim >= 2:
+ nn.init.orthogonal_(param, 1.0)
+
+ def create_causal_mask(self, seq_len, device):
+ """Create causal attention mask"""
+ mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1)
+ return mask
+
+ def get_causal_mask(self, T, device):
+ """Get cached causal mask or create new one"""
+ buffer_name = f"_causal_mask_{T}"
+ if hasattr(self, buffer_name):
+ mask = getattr(self, buffer_name)
+ if mask.device != device:
+ # Move to device and cache
+ mask = mask.to(device)
+ setattr(self, buffer_name, mask)
+ return mask
+ mask = self.create_causal_mask(T, device)
+ self.register_buffer(buffer_name, mask, persistent=False)
+ return mask
+
+ def get_positional_embedding(self, T, device):
+ """Get cached positional embedding for length T"""
+ cache_key = f"_pos_embed_{T}"
+ if not hasattr(self, cache_key) or getattr(self, cache_key).device != device:
+ pos_embed = self.positional_embedding[:, :T].to(device)
+ setattr(self, cache_key, pos_embed)
+ return getattr(self, cache_key)
+
+ def create_episode_mask(self, terminals, seq_len):
+ """Episode mask that prevents attention across episode boundaries.
+ Uses cached scalar mask buffers to reduce memory allocation."""
+ B = terminals.shape[0]
+ device = terminals.device
+
+ # Use cumsum for episode IDs
+ episode_ids = torch.nn.functional.pad(terminals[:, :-1], (1, 0)).cumsum(dim=1)
+
+ # Broadcast comparison: (B, T, T) boolean mask, True within the same episode
+ mask_allow = episode_ids.unsqueeze(2) == episode_ids.unsqueeze(1)
+
+ # Use cached tensors moved to correct device
+ zero_mask = self._zero_mask.to(device) if self._zero_mask.device != device else self._zero_mask
+ neg_inf_mask = self._neg_inf_mask.to(device) if self._neg_inf_mask.device != device else self._neg_inf_mask
+
+ return torch.where(mask_allow, zero_mask, neg_inf_mask)
+
+ # ------------------------------------------------------------------ #
+ # Streaming inference with per-layer KV cache.
+ #
+ # The legacy `_forward_eval_legacy` below recomputes a full transformer
+ # forward over the entire horizon-length context buffer on every step,
+ # then throws away all but one output position. For B=512 co-players,
+ # horizon=91, single-thread CPU, this costs ~3 s per step and dominates
+ # training wallclock.
+ #
+ # The KV-cached path maintains per-layer (K, V) buffers in `state` and
+ # only computes Q/K/V for the new token, attending against the cache.
+ # Numerically equivalent to the legacy path (same rolling buffer +
+ # causal-row semantics, including the post-wrap "self-attention only"
+ # behavior at slot 0 after horizon steps).
+ #
+ # Implementation note: `slot` is kept as a 1-element long tensor (not
+ # `int(pos.item())`) so this method stays compile-friendly. On the GPU
+ # ego policy under torch.compile, `.item()` would cause a Dynamo graph
+ # break and a CUDA sync every call.
+ # ------------------------------------------------------------------ #
+
+ def _slot_arange(self, device):
+ """Return a length-`horizon` arange tensor on `device`, cached per device."""
+ key = (device.type, device.index if device.index is not None else -1)
+ cached = self._streaming_mask_cache.get(key)
+ if cached is not None:
+ return cached
+ arr = torch.arange(self.horizon, device=device)
+ self._streaming_mask_cache[key] = arr
+ return arr
+
+ def _make_kv_cache(self, batch_size, device, dtype):
+ return [
+ torch.zeros(batch_size, self.num_heads, self.horizon, self.head_dim, device=device, dtype=dtype)
+ for _ in range(self.num_layers)
+ ]
+
+ def init_eval_state(self, batch_size, device, dtype=torch.float32):
+ """Allocate a fresh streaming-inference state dict for this policy."""
+ return dict(
+ k_cache=self._make_kv_cache(batch_size, device, dtype),
+ v_cache=self._make_kv_cache(batch_size, device, dtype),
+ transformer_position=torch.zeros(1, dtype=torch.long, device=device),
+ )
+
+ def _prime_kv_cache(self, indices, state):
+ """Prime K/V cache for `indices` to match legacy 'zero hidden context'.
+
+ The legacy reset only zeroed the rolling hidden buffer. Because that
+ buffer is summed with the slot-tied positional embedding inside the
+ transformer, the *effective* K/V at unwritten slots is the K/V you
+ get from running the transformer over an all-zero hidden sequence
+ (i.e. just the pos embeddings, with causal attention propagating
+ through layers). This priming fills our cache with exactly that
+ state, so subsequent forward_eval calls closely match the legacy path.
+ """
+ if isinstance(indices, slice) and indices == slice(None):
+ n_idx = state["k_cache"][0].shape[0]
+ elif torch.is_tensor(indices):
+ n_idx = int(indices.shape[0])
+ else:
+ n_idx = len(indices)
+ if n_idx == 0:
+ return
+
+ H, D = self.num_heads, self.head_dim
+ T = self.horizon
+ device = state["k_cache"][0].device
+ dtype = state["k_cache"][0].dtype
+
+ pos_embed = self.get_positional_embedding(T, device).to(dtype) # (1, T, hidden)
+ layer_input = pos_embed.expand(n_idx, T, self.hidden_size).contiguous()
+ causal_mask = self.get_causal_mask(T, device)
+
+ with torch.no_grad():
+ for li, layer in enumerate(self.transformer.layers):
+ attn = layer.self_attn
+ x_norm = layer.norm1(layer_input)
+ qkv = F.linear(x_norm, attn.in_proj_weight, attn.in_proj_bias)
+ q, k, v = qkv.chunk(3, dim=-1)
+ q = q.view(n_idx, T, H, D).transpose(1, 2)
+ k = k.view(n_idx, T, H, D).transpose(1, 2)
+ v = v.view(n_idx, T, H, D).transpose(1, 2)
+
+ # Defensive cast: under autocast or mixed-precision, k/v can
+ # come out in a different dtype than the cache; PyTorch refuses
+ # cross-dtype `index_put`. Match the cache's dtype.
+ state["k_cache"][li][indices] = k.to(state["k_cache"][li].dtype)
+ state["v_cache"][li][indices] = v.to(state["v_cache"][li].dtype)
+
+ attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)
+ attn_out = attn_out.transpose(1, 2).reshape(n_idx, T, self.hidden_size)
+ attn_out = F.linear(attn_out, attn.out_proj.weight, attn.out_proj.bias)
+ x = layer_input + attn_out
+ x_norm2 = layer.norm2(x)
+ ffn_h = layer.activation(F.linear(x_norm2, layer.linear1.weight, layer.linear1.bias))
+ ffn_out = F.linear(ffn_h, layer.linear2.weight, layer.linear2.bias)
+ layer_input = x + ffn_out
+
+ def reset_eval_state(self, state, done_indices=None):
+ """Reset KV cache (and step counter) for done agents.
+
+ - done_indices=None: full reset. K/V zeroed, step counter cleared.
+ (Equivalent to allocating a fresh state.)
+ - done_indices=tensor/array of agent indices: re-prime those rows
+ to mirror the legacy "zero hidden buffer" behavior. The shared
+ step counter is intentionally NOT reset in that case (matches
+ legacy, which only touched per-row context).
+ """
+ k_cache = state.get("k_cache")
+ v_cache = state.get("v_cache")
+ if k_cache is None or v_cache is None:
+ return
+ if done_indices is None:
+ for c in k_cache:
+ c.zero_()
+ for c in v_cache:
+ c.zero_()
+ pos = state.get("transformer_position")
+ if pos is not None:
+ pos.zero_()
+ else:
+ idx = done_indices
+ if not torch.is_tensor(idx):
+ idx = torch.as_tensor(idx, device=k_cache[0].device, dtype=torch.long)
+ self._prime_kv_cache(idx, state)
+
+ def forward_eval(self, observations, state):
+ if _USE_LEGACY_EVAL:
+ # Escape hatch for benchmarking / safety net: set
+ # PUFFER_TRANSFORMER_LEGACY_EVAL=1 in the environment to bypass
+ # the KV-cached path and use the original full-context forward.
+ return self._forward_eval_legacy(observations, state)
+ B = observations.shape[0]
+ device = observations.device
+
+ hidden = self.policy.encode_observations(observations, state=state)
+ hidden = self.input_projection(hidden)
+ # hidden: (B, hidden_size)
+
+ # Fetch or lazily allocate the KV cache. We re-allocate if the shape
+ # changes (e.g. batch size differs across calls) and cast the cache if
+ # its dtype mismatches the input (mixed-precision boundary).
+ k_cache = state.get("k_cache")
+ v_cache = state.get("v_cache")
+ need_alloc = (
+ k_cache is None or v_cache is None or k_cache[0].shape[0] != B or k_cache[0].shape[2] != self.horizon
+ )
+ if need_alloc:
+ k_cache = self._make_kv_cache(B, device, hidden.dtype)
+ v_cache = self._make_kv_cache(B, device, hidden.dtype)
+ pos = torch.zeros(1, dtype=torch.long, device=device)
+ else:
+ pos = state.get("transformer_position", torch.zeros(1, dtype=torch.long, device=device))
+ if k_cache[0].dtype != hidden.dtype:
+ k_cache = [c.to(hidden.dtype) for c in k_cache]
+ v_cache = [c.to(hidden.dtype) for c in v_cache]
+
+ slot_t = (pos % self.horizon).long() # (1,) long tensor
+
+ # Add the slot's positional embedding (slot-tied, matching the
+ # legacy rolling-buffer scheme).
+ pos_embed = self.get_positional_embedding(self.horizon, device) # (1, horizon, hidden)
+ pos_embed_slot = pos_embed.index_select(1, slot_t).squeeze(1) # (1, hidden)
+ x = (hidden + pos_embed_slot).unsqueeze(1) # (B, 1, hidden)
+
+ # Build (1, 1, 1, horizon) bool mask: True at slots [0, slot_t].
+ slots_arange = self._slot_arange(device)
+ attn_mask = (slots_arange <= slot_t).view(1, 1, 1, self.horizon)
+ H = self.num_heads
+ D = self.head_dim
+
+ for li, layer in enumerate(self.transformer.layers):
+ attn = layer.self_attn
+
+ x_norm = layer.norm1(x)
+ qkv = F.linear(x_norm, attn.in_proj_weight, attn.in_proj_bias)
+ q, k, v = qkv.chunk(3, dim=-1)
+ q = q.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D)
+ k = k.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D)
+ v = v.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D)
+
+ # index_copy_ on dim=2 writes one slot using a tensor index, which
+ # avoids the .item() sync that would force a Dynamo graph break.
+ k_cache[li].index_copy_(2, slot_t, k)
+ v_cache[li].index_copy_(2, slot_t, v)
+
+ if state.get("_probe_attention", False):
+ # Manual softmax attention so we can stash the weights. SDPA's
+ # functional form doesn't return weights. Math is identical to
+ # the SDPA call below but we capture (B, H, 1, horizon) weights
+ # per layer per step into state["_attn_weights"].
+ scale = 1.0 / math.sqrt(D)
+ logits = torch.matmul(q, k_cache[li].transpose(-2, -1)) * scale # (B, H, 1, horizon)
+ logits = logits.masked_fill(~attn_mask, float("-inf"))
+ weights = F.softmax(logits, dim=-1) # (B, H, 1, horizon)
+ attn_out = torch.matmul(weights, v_cache[li]) # (B, H, 1, D)
+ state.setdefault("_attn_weights", []).append(
+ {"layer": li, "slot": int(slot_t.item()), "weights": weights.detach().cpu()}
+ )
+ else:
+ attn_out = F.scaled_dot_product_attention(
+ q,
+ k_cache[li],
+ v_cache[li],
+ attn_mask=attn_mask,
+ is_causal=False,
+ )
+ attn_out = attn_out.transpose(1, 2).reshape(B, 1, self.hidden_size)
+ attn_out = F.linear(attn_out, attn.out_proj.weight, attn.out_proj.bias)
+ x = x + attn_out
+
+ x_norm2 = layer.norm2(x)
+ ffn_h = layer.activation(F.linear(x_norm2, layer.linear1.weight, layer.linear1.bias))
+ ffn_out = F.linear(ffn_h, layer.linear2.weight, layer.linear2.bias)
+ x = x + ffn_out
+
+ x = self.output_norm(x)
+ hidden_out = x.squeeze(1)
+
+ state["k_cache"] = k_cache
+ state["v_cache"] = v_cache
+ state["transformer_position"] = pos + 1
+ state["hidden"] = hidden_out
+
+ logits, values = self.policy.decode_actions(hidden_out)
+ return logits, values
+
+ def _forward_eval_legacy(self, observations, state):
+ """Original full-context forward. Kept for equivalence testing only."""
+ B = observations.shape[0]
+ device = observations.device
+
+ hidden = self.policy.encode_observations(observations, state=state)
+ hidden = self.input_projection(hidden)
+
+ if "transformer_context" not in state or state["transformer_context"] is None:
+ context = torch.zeros(B, self.horizon, self.hidden_size, device=device, dtype=hidden.dtype)
+ pos = torch.zeros(1, dtype=torch.long, device=device)
+ else:
+ context = state["transformer_context"]
+ pos = state.get("transformer_position", torch.zeros(1, dtype=torch.long, device=device))
+
+ if context.shape[-1] != self.hidden_size or context.shape[0] != B or context.shape[1] != self.horizon:
+ context = torch.zeros(B, self.horizon, self.hidden_size, device=device, dtype=hidden.dtype)
+ pos = torch.zeros(1, dtype=torch.long, device=device)
+ if context.dtype != hidden.dtype:
+ context = context.to(hidden.dtype)
+
+ write_idx = (pos % self.horizon).long()
+ context[:, write_idx, :] = hidden.unsqueeze(1)
+ pos = pos + 1
+
+ pos_embed = self.get_positional_embedding(self.horizon, device)
+ context_with_pos = context + pos_embed
+
+ causal_mask = self.get_causal_mask(self.horizon, device)
+
+ output = self.transformer(context_with_pos, mask=causal_mask, is_causal=True)
+ output = self.output_norm(output)
+
+ read_idx = ((pos - 1) % self.horizon).long()
+ hidden_out = output[:, read_idx, :].squeeze(1)
+
+ state["transformer_context"] = context
+ state["transformer_position"] = pos
+ state["hidden"] = hidden_out
+
+ logits, values = self.policy.decode_actions(hidden_out)
+ return logits, values
+
+ def forward(self, observations, state):
+ x = observations
+ device = x.device
+
+ if x.ndim == len(self.obs_shape) + 1:
+ B, T = x.shape[0], 1
+ elif x.ndim == len(self.obs_shape) + 2:
+ B, T = x.shape[:2]
+ else:
+ raise ValueError(f"Invalid input tensor shape: {x.shape}")
+
+ x_flat = x.view(B * T, *self.obs_shape)
+ hidden = self.policy.encode_observations(x_flat, state)
+
+ hidden = hidden.view(B, T, self.input_size)
+ hidden = self.input_projection(hidden)
+
+ # Keep only the most recent `horizon` steps when the sequence is longer
+ T_actual = min(T, self.horizon) # Python int, fine
+ if T_actual < T:
+ hidden = hidden[:, -T_actual:]
+ T = T_actual
+
+ hidden = hidden + self.get_positional_embedding(T, device)
+
+ use_episode_mask = "terminals" in state and state["terminals"] is not None
+
+ if not use_episode_mask:
+ causal_mask = self.get_causal_mask(T, device)
+ if self.training and self.use_checkpointing:
+ hidden = checkpoint(
+ lambda h, m: self.transformer(h, mask=m, is_causal=True), hidden, causal_mask, use_reentrant=False
+ )
+ else:
+ hidden = self.transformer(hidden, mask=causal_mask, is_causal=True)
+ else:
+ terminals = state["terminals"]
+ if terminals.shape[1] > T:
+ terminals = terminals[:, -T:]
+ causal_mask = self.get_causal_mask(T, device)
+ episode_mask = self.create_episode_mask(terminals, T)
+ attn_mask = causal_mask.unsqueeze(0) + episode_mask
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
+ if self.training and self.use_checkpointing:
+ hidden = checkpoint(
+ lambda h, m: self.transformer(h, mask=m, is_causal=False), hidden, attn_mask, use_reentrant=False
+ )
+ else:
+ hidden = self.transformer(hidden, mask=attn_mask, is_causal=False)
+
+ hidden = self.output_norm(hidden)
+ flat_hidden = hidden.contiguous().view(B * T, self.hidden_size)
+
+ logits, values = self.policy.decode_actions(flat_hidden)
+ values = values.view(B, T)
+
+ # Use Python int for context_len - no sync
+ context_len = min(T, self.horizon)
+ state["hidden"] = hidden
+ state["transformer_context"] = hidden[:, -context_len:].detach()
+ state["transformer_position"] = torch.full((B,), context_len - 1, dtype=torch.long, device=device)
+
+ return logits, values
+
+
class Convolutional(nn.Module):
def __init__(
self,
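
Reviewer note on `create_episode_mask` in the models.py hunk above: the episode-ID trick (shift the terminal flags right by one, then take a cumulative sum) can be illustrated without torch. This is a dependency-free sketch for a single sequence; `episode_ids` and `allowed_attention` are hypothetical names, and booleans stand in for the 0 / -inf additive mask used in the patch.

```python
from itertools import accumulate


def episode_ids(terminals):
    """Assign an episode index to each step.

    A terminal at step t ends the episode *at* t, so the new episode starts
    at t + 1: shift the flags right by one before the cumulative sum.
    """
    shifted = [0] + list(terminals[:-1])
    return list(accumulate(shifted))


def allowed_attention(terminals):
    """allowed[i][j] is True when step i may attend to step j.

    Combines the causal constraint (j <= i) with the same-episode constraint.
    """
    ids = episode_ids(terminals)
    steps = len(ids)
    return [[j <= i and ids[i] == ids[j] for j in range(steps)] for i in range(steps)]


if __name__ == "__main__":
    terminals = [0, 0, 1, 0, 0]  # episode ends at step 2
    print(episode_ids(terminals))
    for row in allowed_attention(terminals):
        print("".join("x" if ok else "." for ok in row))
```

In the printed grid, steps 3 and 4 see only each other, while steps 0 to 2 see only the first episode.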
diff --git a/pufferlib/ocean/benchmark/evaluator.py b/pufferlib/ocean/benchmark/evaluator.py
index 383384e623..818a266277 100644
--- a/pufferlib/ocean/benchmark/evaluator.py
+++ b/pufferlib/ocean/benchmark/evaluator.py
@@ -625,53 +625,156 @@ class HumanReplayEvaluator:
def __init__(self, config: Dict):
self.config = config
- self.sim_steps = 91 - self.config["env"]["init_steps"]
+ k_scenarios = self.config["env"].get("k_scenarios", 1)
+ scenario_length = self.config["env"].get("scenario_length", 91)
+ init_steps = self.config["env"].get("init_steps", 0)
+ self.sim_steps = scenario_length - init_steps
def rollout(self, args, puffer_env, policy):
"""Roll out policy in env with human replays. Store statistics.
- In human replay mode, only the SDC (self-driving car) is controlled by the policy
- while all other agents replay their human trajectories. This tests how compatible
- the policy is with (static) human partners.
+ In human replay mode, only the SDC is controlled by the policy while
+ all other agents replay their human trajectories. This tests how
+ compatible the policy is with static human partners.
- Args:
- args: Config dict with train settings (device, use_rnn, etc.)
- puffer_env: PufferLib environment wrapper
- policy: Trained policy to evaluate
-
- Returns:
- dict: Aggregated metrics including:
- - avg_collisions_per_agent: Average collisions per agent
- - avg_offroad_per_agent: Average offroad events per agent
+ Runs `num_rollouts` independent rollouts (env is reset between each,
+ which resamples the map/agent slice). Reports the per-key mean across
+ rollouts, plus `<key>_std` when more than one rollout reports that key,
+ so variance stays visible once the per-batch score saturates.
"""
import numpy as np
import torch
import pufferlib
+ num_rollouts = int(args.get("eval", {}).get("human_replay_num_rollouts", 1) or 1)
num_agents = puffer_env.observation_space.shape[0]
device = args["train"]["device"]
+ k_scenarios = args["env"].get("k_scenarios", 1)
- obs, info = puffer_env.reset()
- state = {}
- if args["train"]["use_rnn"]:
- state = dict(
- lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device),
- lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device),
- )
-
- for time_idx in range(self.sim_steps):
- # Step policy
- with torch.no_grad():
- ob_tensor = torch.as_tensor(obs).to(device)
- logits, value = policy.forward_eval(ob_tensor, state)
- action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
- action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape)
-
- if isinstance(logits, torch.distributions.Normal):
- action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high)
+ is_transformer = hasattr(policy, "horizon") and hasattr(policy, "transformer")
+ is_recurrent = hasattr(policy, "lstm")
- obs, rewards, dones, truncs, info_list = puffer_env.step(action_np)
-
- if len(info_list) > 0: # Happens at the end of episode
- results = info_list[0]
- return results
+ def _fresh_state():
+ if is_recurrent:
+ return dict(
+ lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device),
+ lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device),
+ )
+ if is_transformer:
+ return dict(
+ transformer_context=torch.zeros(num_agents, policy.horizon, policy.hidden_size, device=device),
+ transformer_position=torch.zeros(1, dtype=torch.long, device=device),
+ )
+ return {}
+
+ per_rollout_aggregates = []
+ per_rollout_scenario = []
+ per_rollout_delta = []
+
+ # Per-(rollout, scenario, agent) success tracking. An agent is marked
+ # "successful in scenario s" when any single-step reward in s exceeds
+ # `goal_reward_threshold`; this is the goal-reach signal in stop-on-goal
+ # eval (goal_behavior=2). Agents that collide / go offroad receive
+ # negative reward, and agents that time out never see the goal reward.
+ # The 0.5 threshold is conservative: reward_goal defaults to 1.0 and the
+ # only other positive per-step reward is reward_lane_align (0.01-ish),
+ # so lane reward alone can't reach 0.5 in a single tick.
+ goal_reward_threshold = float(args.get("eval", {}).get("recovery_goal_reward_threshold", 0.5))
+ # CONTROL: when env var RECOVERY_CACHE_RESET_PER_SCENARIO=1, reset
+ # the policy's K/V cache (= "_fresh_state") at every scenario
+ # boundary. This kills any cross-scenario context the Transformer
+ # would have used, isolating "is the cache helping?" from "is the
+ # per-scenario obs alone enough?". Env var because pufferl's
+ # argparser doesn't auto-create new --eval.* flags.
+ cache_reset_per_scenario = os.environ.get("RECOVERY_CACHE_RESET_PER_SCENARIO", "0") == "1"
+ if cache_reset_per_scenario:
+ print("[recovery] CONTROL mode: resetting K/V cache at every scenario boundary", flush=True)
+ success_arr = np.zeros((num_rollouts, k_scenarios, num_agents), dtype=bool)
+
+ for rollout_idx in range(num_rollouts):
+ obs, _ = puffer_env.reset()
+ state = _fresh_state()
+ collected_infos = []
+ scenario_metrics = {}
+ delta_metrics = {}
+
+ for scenario in range(k_scenarios):
+ if scenario > 0 and cache_reset_per_scenario:
+ state = _fresh_state()
+ for time_idx in range(self.sim_steps):
+ with torch.no_grad():
+ ob_tensor = torch.as_tensor(obs).to(device)
+ logits, value = policy.forward_eval(ob_tensor, state)
+ action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
+ action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape)
+
+ if isinstance(logits, torch.distributions.Normal):
+ action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high)
+
+ obs, rewards, dones, truncs, info_list = puffer_env.step(action_np)
+
+ # Mark per-agent success this scenario: a +reward_goal spike
+ # at any tick means the goal was reached. In stop-on-goal mode
+ # the env does NOT set `dones` per agent (the agent just stops
+ # moving), so we can't gate on dones. The only step-level reward
+ # that crosses `goal_reward_threshold` is the goal reward itself
+ # (lane_align is ~0.01/step, far below 0.5 in any one tick).
+ # We OR across the scenario so the success flag sticks even if
+ # subsequent ticks are 0.
+ rewards_arr = np.asarray(rewards).reshape(-1)
+ success_arr[rollout_idx, scenario] |= rewards_arr > goal_reward_threshold
+
+ for info_dict in info_list:
+ if not isinstance(info_dict, dict):
+ continue
+ if "ada_delta_score" in info_dict:
+ delta_metrics = info_dict
+ elif any(k.startswith("scenario_") for k in info_dict.keys()):
+ scenario_metrics.update(info_dict)
+ elif "score" in info_dict:
+ collected_infos.append(info_dict)
+
+ if collected_infos:
+ rollout_agg = {
+ k: float(np.mean([d.get(k, 0) for d in collected_infos])) for k in collected_infos[0].keys()
+ }
+ else:
+ rollout_agg = {}
+ per_rollout_aggregates.append(rollout_agg)
+ per_rollout_scenario.append(scenario_metrics)
+ per_rollout_delta.append(delta_metrics)
+
+ # Mean + std across rollouts (std only meaningful for >1 rollout)
+ final = {}
+ for dicts in (per_rollout_aggregates, per_rollout_scenario, per_rollout_delta):
+ keys = {k for d in dicts for k in d.keys()}
+ for k in keys:
+ vals = [float(d[k]) for d in dicts if k in d]
+ if not vals:
+ continue
+ final[k] = float(np.mean(vals))
+ if len(vals) > 1:
+ final[f"{k}_std"] = float(np.std(vals, ddof=0))
+
+ final["n_rollouts"] = num_rollouts
+ final["n_agents_per_rollout"] = num_agents
+ final["n_total_evals"] = num_rollouts * num_agents
+
+ # ----- Raw per-(rollout, agent, scenario) success log -----
+ # No fancy aggregation here. We dump the full success grid as a flat
+ # list of records, one per (rollout, agent), with success bools per
+ # scenario. Any conditional rate (e.g., P(succeed s_k | fail s_0)) is
+ # a one-liner over this data downstream.
+ #
+ # Schema: list of {"rollout": int, "agent": int, "s0": int, ...,
+ # "s{k-1}": int}, one record per (rollout, agent) pair.
+ records = []
+ for r in range(num_rollouts):
+ for a in range(num_agents):
+ rec = {"rollout": int(r), "agent": int(a)}
+ for s_idx in range(k_scenarios):
+ rec[f"s{s_idx}"] = int(success_arr[r, s_idx, a])
+ records.append(rec)
+ final["per_agent_success_log"] = records
+
+ return final
diff --git a/pufferlib/ocean/drive/README.md b/pufferlib/ocean/drive/README.md
deleted file mode 100644
index a37907b206..0000000000
--- a/pufferlib/ocean/drive/README.md
+++ /dev/null
@@ -1,108 +0,0 @@
-# PufferDrive
-
-This readme contains several important assumptions and definions about the `PufferDrive` environment.
-
-## Agent initialization and control
-
-### `init_mode`
-
-Determines which agents are **created** in the environment.
-
-| Option | Description |
-| ------------------------ | ---------------------------------------------------------------------------- |
-| `create_all_valid` | Create all entities valid at initialization (`traj_valid[init_steps] == 1`). |
-| `create_only_controlled` | Create only those agents that are controlled by the policy. |
-
-### `control_mode`
-
-Determines which created agents are **controlled** by the policy.
-
-| Option | Description |
-| ----------------------------------------- | ------------------------------------------------------------------------------------------------- |
-| `control_vehicles` (default) | Control only valid **vehicles** (not experts, beyond `MIN_DISTANCE_TO_GOAL`, under `MAX_AGENTS`). |
-| `control_agents` | Control all valid **agent types** (vehicles, cyclists, pedestrians). |
-| `control_tracks_to_predict` *(WOMD only)* | Control agents listed in the `tracks_to_predict` metadata. |
-
-
-## Termination conditions (`done`)
-
-Episodes are never truncated before reaching `episode_len`. The `goal_behavior` argument controls agent behavior after reaching a goal early:
-
-* **`goal_behavior=0` (default):** Agents respawn at their initial position after reaching their goal (last valid log position).
-* **`goal_behavior=1`:** Agents receive new goals indefinitely after reaching each goal.
-* **`goal_behavior=2`:** Agents stop after reaching their goal.
-
-## Logged performance metrics
-
-We record multiple performance metrics during training, aggregated over all *active agents* (alive and controlled). Key metrics include:
-
-- `score`: Goals reached cleanly (goal was achieved without collision or going off-road)
-- `collision_rate`: Binary flag (0 or 1) if agent hit another vehicle.
-- `offroad_rate`: Binary flag (0 or 1) if agent left road bounds.
-- `completion_rate`: Whether the agent reached its goal in this episode (even if it collided or went off-road).
-
-
-### Metric aggregation
-
-The `num_agents` parameter in `drive.ini` defines the total number of agents used to collect experience.
-At runtime, **Puffer** uses `num_maps` to create enough environments to populate the buffer with `num_agents`, distributing them evenly across `num_envs`.
-
-Because agents are respawned immediately after reaching their goal, they remain active throughout the episode.
-
-At the end of each episode (i.e., when `timestep == TRAJECTORY_LENGTH`), metrics are logged once via:
-
-```C
-if (env->timestep == TRAJECTORY_LENGTH) {
- add_log(env);
- c_reset(env);
- return;
-}
-```
-
-Metrics are normalized and aggregated in `vec_log` (`pufferlib/ocean/env_binding.h`). They are averaged over all active agents across all environments. For example, the aggregated collision rate is computed as:
-
-$$
-r^{agg}_{\text{collision}} = \frac{\mathbb{I}[\text{collided in episode}]}{N}
-$$
-
-where $N$ is the number of controlled agents.
-This value represents the fraction of agents that collided at least once during the episode. So, cases **A** and **B** below would yield identical off-road and collision rates:
-
-
-
-Since these metrics do not capture *multiple* events per agent, we additionally log the **average number of collision and off-road events per episode**. This is computed as:
-
-$$
-c^{avg}_{\text{collision}} = \frac{\text{total number of collision events across all agents and environments}}{N}
-$$
-
-where $N$ is the total number of controlled agents.
-For example, an `avg_collisions_per_agent` value of 4 indicates that, on average, each agent collides four times per episode.
-
-### Effect of respawning on metrics
-
-By default, agents are reset to their initial position when they reach their goal before the episode ends. Upon respawn, `respawn_timestep` is updated from `-1` to the current step index.
-
-This raises the question: **how does repeated respawning affect aggregated metrics?**
-
-To begin, note that the environment is a bit different before and after respawn. After an agent respawns, all other agents are "removed" from the environment. As a result, collisions with other agents cannot occur post-respawn.
-
-This effectively transforms the scenario into a single-agent environment, simplifying the task since the agent no longer needs to coordinate with others.
-
-
-
-#### `score`
-
-Consider an episode of 91 steps where an agent is initialized relatively close to the goal position and reaches its goal three times:
-
-1. **First attempt:** reaches the goal without collisions
-2. **Second attempt:** reaches the goal without collisions
-3. **Third attempt:** reaches the goal but goes off-road along the way
-
-
-
-The highlighted trajectory shows the first attempt. In this case, the recorded score is `0.0` — a single off-road event invalidates the score for the entire episode. This behavior is desired: the score metric is unforgiving.
-
-#### `offroad_rate` and `collision_rate`
-
-Same logic holds as above.
diff --git a/pufferlib/ocean/drive/adaptive.py b/pufferlib/ocean/drive/adaptive.py
index 184ce41577..7af77c62db 100644
--- a/pufferlib/ocean/drive/adaptive.py
+++ b/pufferlib/ocean/drive/adaptive.py
@@ -12,7 +12,12 @@ def __init__(self, **kwargs):
kwargs["ini_file"] = "pufferlib/config/ocean/adaptive.ini"
kwargs["adaptive_driving_agent"] = True
+ # Human replay mode: disable co-players, use human trajectories for other agents
+ human_replay_mode = kwargs.pop("human_replay_mode", False)
+ if human_replay_mode:
+ kwargs["co_player_enabled"] = False
+
kwargs["resample_frequency"] = self.k_scenarios * self.scenario_length
self.episode_length = kwargs["resample_frequency"]
- # print(f"resample frequency is ", kwargs["resample_frequency"], flush=True)
+
super().__init__(**kwargs)
diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c
index 60c03814b7..5eaed37b0a 100644
--- a/pufferlib/ocean/drive/binding.c
+++ b/pufferlib/ocean/drive/binding.c
@@ -1,27 +1,29 @@
#define Env Drive
#define MY_SHARED
#define MY_PUT
+
+#include <unistd.h>
#include "binding.h"
-static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
- PyObject* obs = PyDict_GetItemString(kwargs, "observations");
+static int my_put(Env *env, PyObject *args, PyObject *kwargs) {
+ PyObject *obs = PyDict_GetItemString(kwargs, "observations");
if (!PyObject_TypeCheck(obs, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array");
return 1;
}
- PyArrayObject* observations = (PyArrayObject*)obs;
+ PyArrayObject *observations = (PyArrayObject *)obs;
if (!PyArray_ISCONTIGUOUS(observations)) {
PyErr_SetString(PyExc_ValueError, "Observations must be contiguous");
return 1;
}
env->observations = PyArray_DATA(observations);
- PyObject* act = PyDict_GetItemString(kwargs, "actions");
+ PyObject *act = PyDict_GetItemString(kwargs, "actions");
if (!PyObject_TypeCheck(act, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array");
return 1;
}
- PyArrayObject* actions = (PyArrayObject*)act;
+ PyArrayObject *actions = (PyArrayObject *)act;
if (!PyArray_ISCONTIGUOUS(actions)) {
PyErr_SetString(PyExc_ValueError, "Actions must be contiguous");
return 1;
@@ -32,12 +34,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
return 1;
}
- PyObject* rew = PyDict_GetItemString(kwargs, "rewards");
+ PyObject *rew = PyDict_GetItemString(kwargs, "rewards");
if (!PyObject_TypeCheck(rew, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array");
return 1;
}
- PyArrayObject* rewards = (PyArrayObject*)rew;
+ PyArrayObject *rewards = (PyArrayObject *)rew;
if (!PyArray_ISCONTIGUOUS(rewards)) {
PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous");
return 1;
@@ -48,12 +50,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
}
env->rewards = PyArray_DATA(rewards);
- PyObject* term = PyDict_GetItemString(kwargs, "terminals");
+ PyObject *term = PyDict_GetItemString(kwargs, "terminals");
if (!PyObject_TypeCheck(term, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array");
return 1;
}
- PyArrayObject* terminals = (PyArrayObject*)term;
+ PyArrayObject *terminals = (PyArrayObject *)term;
if (!PyArray_ISCONTIGUOUS(terminals)) {
PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous");
return 1;
@@ -66,23 +68,21 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
return 0;
}
-static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
int population_play = unpack(kwargs, "population_play");
- if (population_play){
- return my_shared_population_play(self, args, kwargs);
- }
- else{
- return my_shared_self_play( self, args, kwargs);
+ if (population_play) {
+ return my_shared_population_play(self, args, kwargs);
+ } else {
+ return my_shared_self_play(self, args, kwargs);
}
-
}
-static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
+static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->human_agent_idx = unpack(kwargs, "human_agent_idx");
env->ini_file = unpack_str(kwargs, "ini_file");
env_init_config conf = {0};
- if(ini_parse(env->ini_file, handler, &conf) < 0) {
+ if (ini_parse(env->ini_file, handler, &conf) < 0) {
printf("Error while loading %s", env->ini_file);
}
if (kwargs && PyDict_GetItemString(kwargs, "scenario_length")) {
@@ -95,15 +95,18 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
env->action_type = conf.action_type;
env->dynamics_model = conf.dynamics_model;
if (PyDict_GetItemString(kwargs, "dynamics_model")) {
- char* dynamics_str = unpack_str(kwargs, "dynamics_model");
+ char *dynamics_str = unpack_str(kwargs, "dynamics_model");
env->dynamics_model = (strcmp(dynamics_str, "jerk") == 0) ? JERK : CLASSIC;
}
env->reward_vehicle_collision = conf.reward_vehicle_collision;
env->reward_offroad_collision = conf.reward_offroad_collision;
env->reward_goal = conf.reward_goal;
env->reward_goal_post_respawn = conf.reward_goal_post_respawn;
- env->reward_ade = conf.reward_ade;
+ env->reward_lane_align = (float)unpack(kwargs, "reward_lane_align");
+ env->reward_vel_align = (float)unpack(kwargs, "reward_vel_align");
env->scenario_length = conf.scenario_length;
+
+ env->termination_mode = conf.termination_mode;
env->collision_behavior = conf.collision_behavior;
env->offroad_behavior = conf.offroad_behavior;
env->max_controlled_agents = unpack(kwargs, "max_controlled_agents");
@@ -127,10 +130,10 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
if (env->population_play) {
env->num_co_players = unpack(kwargs, "num_co_players");
- double* co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players);
+ double *co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players);
if (co_player_ids_d != NULL && env->num_co_players > 0) {
- env->co_player_ids = (int*)malloc(env->num_co_players * sizeof(int));
+ env->co_player_ids = (int *)malloc(env->num_co_players * sizeof(int));
if (env->co_player_ids == NULL) {
fprintf(stderr, "Error: Failed to allocate memory for co_player_ids\n");
free(co_player_ids_d);
@@ -152,9 +155,9 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
// Handle ego agents - always as an array
env->num_ego_agents = unpack(kwargs, "num_ego_agents");
if (env->num_ego_agents > 0) {
- double* ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents);
+ double *ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents);
if (ego_agent_ids_d != NULL) {
- env->ego_agent_ids = (int*)malloc(env->num_ego_agents * sizeof(int));
+ env->ego_agent_ids = (int *)malloc(env->num_ego_agents * sizeof(int));
for (int i = 0; i < env->num_ego_agents; i++) {
env->ego_agent_ids[i] = (int)ego_agent_ids_d[i];
}
@@ -173,16 +176,23 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
env->ego_agent_ids = NULL;
}
-
env->init_mode = (int)unpack(kwargs, "init_mode");
env->control_mode = (int)unpack(kwargs, "control_mode");
+ // Render mode: 0=RENDER_OFF, 1=RENDER_HEADLESS, 2=RENDER_WINDOW
+ env->render_mode = RENDER_OFF; // Default to off
+ if (kwargs && PyDict_GetItemString(kwargs, "render_mode")) {
+ env->render_mode = (int)unpack(kwargs, "render_mode");
+ }
env->goal_behavior = (int)unpack(kwargs, "goal_behavior");
+ env->goal_target_distance = (float)unpack(kwargs, "goal_target_distance");
env->goal_radius = (float)unpack(kwargs, "goal_radius");
+ env->goal_speed = (float)unpack(kwargs, "goal_speed");
+ char *map_dir = unpack_str(kwargs, "map_dir");
int map_id = unpack(kwargs, "map_id");
int max_agents = unpack(kwargs, "max_agents");
int init_steps = unpack(kwargs, "init_steps");
- char map_file[100];
- sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id);
+ char map_file[512];
+ snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id);
env->num_agents = max_agents;
env->map_name = strdup(map_file);
env->init_steps = init_steps;
@@ -191,18 +201,21 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
return 0;
}
-static int my_log(PyObject* dict, Log* log) {
+static int my_log(PyObject *dict, Log *log) {
assign_to_dict(dict, "n", log->n);
+ assign_to_dict(dict, "score", log->score);
assign_to_dict(dict, "offroad_rate", log->offroad_rate);
- assign_to_dict(dict, "episode_length", log->episode_length);
assign_to_dict(dict, "collision_rate", log->collision_rate);
+ assign_to_dict(dict, "episode_length", log->episode_length);
assign_to_dict(dict, "episode_return", log->episode_return);
assign_to_dict(dict, "dnf_rate", log->dnf_rate);
- assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error);
assign_to_dict(dict, "completion_rate", log->completion_rate);
assign_to_dict(dict, "lane_alignment_rate", log->lane_alignment_rate);
- assign_to_dict(dict, "score", log->score);
- assign_to_dict(dict, "avg_offroad_per_agent", log->avg_offroad_per_agent);
- assign_to_dict(dict, "avg_collisions_per_agent", log->avg_collisions_per_agent);
+ assign_to_dict(dict, "offroad_per_agent", log->offroad_per_agent);
+ assign_to_dict(dict, "collisions_per_agent", log->collisions_per_agent);
+ assign_to_dict(dict, "goals_sampled_this_episode", log->goals_sampled_this_episode);
+ assign_to_dict(dict, "goals_reached_this_episode", log->goals_reached_this_episode);
+ assign_to_dict(dict, "speed_at_goal", log->speed_at_goal);
+ // assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error);
return 0;
}
diff --git a/pufferlib/ocean/drive/binding.h b/pufferlib/ocean/drive/binding.h
index b5ec9ed65d..492dde2f14 100644
--- a/pufferlib/ocean/drive/binding.h
+++ b/pufferlib/ocean/drive/binding.h
@@ -1,57 +1,75 @@
#include "drive.h"
#include "../env_binding.h"
-static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *my_shared_self_play(PyObject *self, PyObject *args, PyObject *kwargs) {
+ char *map_dir = unpack_str(kwargs, "map_dir");
int num_agents = unpack(kwargs, "num_agents");
int num_maps = unpack(kwargs, "num_maps");
int init_mode = unpack(kwargs, "init_mode");
int control_mode = unpack(kwargs, "control_mode");
int init_steps = unpack(kwargs, "init_steps");
+ int goal_behavior = unpack(kwargs, "goal_behavior");
+ float goal_target_distance = unpack(kwargs, "goal_target_distance");
+ int use_all_maps = unpack(kwargs, "use_all_maps");
int max_controlled_agents = unpack(kwargs, "max_controlled_agents");
- clock_gettime(CLOCK_REALTIME, &ts);
- srand(ts.tv_nsec);
+ int map_seed = PyDict_GetItemString(kwargs, "map_seed") ? unpack(kwargs, "map_seed") : -1;
+ printf("Generating environments for %d agents using %s maps from %s, num maps %d \n", num_agents,
+ use_all_maps ? "all" : "random", map_dir, num_maps);
+ fflush(stdout);
+ // Use provided seed or fall back to time+pid
+ if (map_seed >= 0) {
+ srand((unsigned int)map_seed);
+ } else {
+ clock_gettime(CLOCK_REALTIME, &ts);
+ srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid()));
+ }
int total_agent_count = 0;
int env_count = 0;
- int max_envs = num_agents;
+ int max_envs = use_all_maps ? num_maps : num_agents;
+ int map_idx = 0;
int maps_checked = 0;
- PyObject* agent_offsets = PyList_New(max_envs+1);
- PyObject* map_ids = PyList_New(max_envs);
+ PyObject *agent_offsets = PyList_New(max_envs + 1);
+ PyObject *map_ids = PyList_New(max_envs);
// getting env count
- while(total_agent_count < num_agents && env_count < max_envs){
- char map_file[100];
- int map_id = rand() % num_maps;
- Drive* env = calloc(1, sizeof(Drive));
+ while (use_all_maps ? (map_idx < max_envs) : (total_agent_count < num_agents && env_count < max_envs)) {
+ char map_file[512];
+ int map_id = use_all_maps ? map_idx++ : rand() % num_maps;
+ Drive *env = calloc(1, sizeof(Drive));
env->init_mode = init_mode;
+ env->max_controlled_agents = max_controlled_agents;
env->control_mode = control_mode;
env->init_steps = init_steps;
- env->max_controlled_agents = max_controlled_agents;
- sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id);
+ env->goal_behavior = goal_behavior;
+ env->goal_target_distance = goal_target_distance;
+ snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id);
env->entities = load_map_binary(map_file, env);
set_active_agents(env);
// Skip map if it doesn't contain any controllable agents
- if(env->active_agent_count == 0) {
- maps_checked++;
-
- // Safeguard: if we've checked all available maps and found no active agents, raise an error
- if(maps_checked >= num_maps) {
- for(int j=0;j<env->num_entities;j++) {
- free_entity(&env->entities[j]);
+ if (env->active_agent_count == 0) {
+ if (!use_all_maps) {
+ maps_checked++;
+
+ // Safeguard: if we've checked all available maps and found no active agents, raise an error
+ if (maps_checked >= num_maps) {
+ for (int j = 0; j < env->num_entities; j++) {
+ free_entity(&env->entities[j]);
+ }
+ free(env->entities);
+ free(env->active_agent_indices);
+ free(env->static_agent_indices);
+ free(env->expert_static_agent_indices);
+ free(env);
+ Py_DECREF(agent_offsets);
+ Py_DECREF(map_ids);
+ char error_msg[256];
+ sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps);
+ PyErr_SetString(PyExc_ValueError, error_msg);
+ return NULL;
}
- free(env->entities);
- free(env->active_agent_indices);
- free(env->static_agent_indices);
- free(env->expert_static_agent_indices);
- free(env);
- Py_DECREF(agent_offsets);
- Py_DECREF(map_ids);
- char error_msg[256];
- sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps);
- PyErr_SetString(PyExc_ValueError, error_msg);
- return NULL;
}
- for(int j=0;j<env->num_entities;j++) {
+ for (int j = 0; j < env->num_entities; j++) {
free_entity(&env->entities[j]);
}
free(env->entities);
@@ -60,17 +78,17 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k
free(env->expert_static_agent_indices);
free(env);
continue;
- }
+ }
// Store map_id
- PyObject* map_id_obj = PyLong_FromLong(map_id);
+ PyObject *map_id_obj = PyLong_FromLong(map_id);
PyList_SetItem(map_ids, env_count, map_id_obj);
// Store agent offset
- PyObject* offset = PyLong_FromLong(total_agent_count);
+ PyObject *offset = PyLong_FromLong(total_agent_count);
PyList_SetItem(agent_offsets, env_count, offset);
total_agent_count += env->active_agent_count;
env_count++;
- for(int j=0;j<env->num_entities;j++) {
+ for (int j = 0; j < env->num_entities; j++) {
free_entity(&env->entities[j]);
}
free(env->entities);
@@ -79,26 +97,26 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k
free(env->expert_static_agent_indices);
free(env);
}
- //printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, num_agents);
- if(total_agent_count >= num_agents){
+ // printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count,
+ // num_agents);
+ if (!use_all_maps && total_agent_count >= num_agents) {
total_agent_count = num_agents;
}
- PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count);
+ PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count);
PyList_SetItem(agent_offsets, env_count, final_total_agent_count);
- PyObject* final_env_count = PyLong_FromLong(env_count);
+ PyObject *final_env_count = PyLong_FromLong(env_count);
// resize lists
- PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1);
- PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count);
- PyObject* tuple = PyTuple_New(3);
+ PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1);
+ PyObject *resized_map_ids = PyList_GetSlice(map_ids, 0, env_count);
+ PyObject *tuple = PyTuple_New(3);
PyTuple_SetItem(tuple, 0, resized_agent_offsets);
PyTuple_SetItem(tuple, 1, resized_map_ids);
PyTuple_SetItem(tuple, 2, final_env_count);
return tuple;
}
-
-static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_size) {
- PyObject* val = PyDict_GetItemString(kwargs, key);
+static double *unpack_float_array(PyObject *kwargs, char *key, Py_ssize_t *out_size) {
+ PyObject *val = PyDict_GetItemString(kwargs, key);
if (val == NULL) {
char error_msg[100];
snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key);
@@ -123,15 +141,14 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s
return NULL;
}
- double* array = (double*)malloc(size * sizeof(double));
+ double *array = (double *)malloc(size * sizeof(double));
if (array == NULL) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory for float array");
return NULL;
}
-
for (Py_ssize_t i = 0; i < size; i++) {
- PyObject* item = PySequence_GetItem(val, i);
+ PyObject *item = PySequence_GetItem(val, i);
if (item == NULL) {
free(array);
return NULL;
@@ -168,19 +185,20 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s
return array;
}
-
-static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *my_shared_population_play(PyObject *self, PyObject *args, PyObject *kwargs) {
+ char *map_dir = unpack_str(kwargs, "map_dir");
int num_agents = unpack(kwargs, "num_agents");
int num_maps = unpack(kwargs, "num_maps");
int num_ego_agents = unpack(kwargs, "num_ego_agents");
- int init_mode = unpack(kwargs, "init_mode");
+ int init_mode = unpack(kwargs, "init_mode");
int population_play = unpack(kwargs, "population_play");
int control_mode = unpack(kwargs, "control_mode");
int init_steps = unpack(kwargs, "init_steps");
int max_controlled_agents = unpack(kwargs, "max_controlled_agents");
+ int map_seed = PyDict_GetItemString(kwargs, "map_seed") ? unpack(kwargs, "map_seed") : -1;
int max_scenes_per_process = 0;
- PyObject* max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process");
+ PyObject *max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process");
if (max_envs_obj && PyLong_Check(max_envs_obj)) {
long v = PyLong_AsLong(max_envs_obj);
if (v > 0 && v <= INT_MAX) {
@@ -188,17 +206,22 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
}
}
- // Use current time for randomness
- struct timespec ts;
- clock_gettime(CLOCK_REALTIME, &ts);
- srand(ts.tv_nsec);
+ // Use provided seed or fall back to time+pid
+ if (map_seed >= 0) {
+ srand((unsigned int)map_seed);
+ } else {
+ struct timespec ts;
+ clock_gettime(CLOCK_REALTIME, &ts);
+ srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid()));
+ }
int num_coplayers = num_agents - num_ego_agents;
- printf("Creating worlds for %d total agents (%d egos, %d co-players)\n",
- num_agents, num_ego_agents, num_coplayers);
+ // Silenced: noisy during normal training. Re-enable for debug.
+ // printf("Creating worlds for %d total agents (%d egos, %d co-players)\n", num_agents, num_ego_agents,
+ // num_coplayers);
// Create shuffled agent role array (0 = coplayer, 1 = ego)
- int* agent_roles = malloc(num_agents * sizeof(int));
+ int *agent_roles = malloc(num_agents * sizeof(int));
for (int i = 0; i < num_ego_agents; i++) {
agent_roles[i] = 1; // ego
}
@@ -218,23 +241,45 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
int env_count = 0;
int total_egos_assigned = 0;
int total_coplayers_assigned = 0;
+ int agent_role_index = 0; // Track position in agent_roles array
int max_envs = num_agents;
if (max_scenes_per_process > 0 && max_scenes_per_process < max_envs) {
max_envs = max_scenes_per_process;
}
- PyObject* agent_offsets = PyList_New(max_envs + 1);
- PyObject* map_ids = PyList_New(max_envs);
- PyObject* ego_agent_ids = PyList_New(max_envs);
- PyObject* coplayer_ids = PyList_New(max_envs);
+ PyObject *agent_offsets = PyList_New(max_envs + 1);
+ PyObject *map_ids = PyList_New(max_envs);
+ PyObject *ego_agent_ids = PyList_New(max_envs);
+ PyObject *coplayer_ids = PyList_New(max_envs);
+
+ int consecutive_skips = 0; // Safety counter for infinite loop detection
+ int max_consecutive_skips = num_maps * 3; // Allow trying each map multiple times
// Create worlds by randomly sampling maps
while (total_agent_count < num_agents && env_count < max_envs) {
+ // Safety check: if we've skipped too many times in a row, something is wrong
+ if (consecutive_skips > max_consecutive_skips) {
+ fprintf(stderr,
+ "[shared_population_play] ERROR: Too many consecutive skips (%d). "
+ "All maps may have 0 active agents. agent_role_index=%d, total_agent_count=%d\n",
+ consecutive_skips, agent_role_index, total_agent_count);
+
+ Py_DECREF(agent_offsets);
+ Py_DECREF(map_ids);
+ Py_DECREF(ego_agent_ids);
+ Py_DECREF(coplayer_ids);
+ free(agent_roles);
+ PyErr_Format(PyExc_RuntimeError,
+ "shared_population_play: unable to find maps with active agents after %d attempts",
+ consecutive_skips);
+ return NULL;
+ }
+
char map_file[100];
int map_id = rand() % num_maps;
- Drive* env = calloc(1, sizeof(Drive));
- sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id);
+ Drive *env = calloc(1, sizeof(Drive));
+ snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id);
env->entities = load_map_binary(map_file, env);
int remaining_capacity = num_agents - total_agent_count;
@@ -250,15 +295,28 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
set_active_agents(env);
+ // CRITICAL FIX: Skip maps with 0 active agents
+ if (env->active_agent_count == 0) {
+ // Silenced: noisy during normal training. Re-enable for debug.
+ // printf("Skipping map %d (0 active agents)\n", map_id);
+ for (int j = 0; j < env->num_entities; j++) {
+ free_entity(&env->entities[j]);
+ }
+ free(env->entities);
+ free(env->active_agent_indices);
+ free(env->static_agent_indices);
+ free(env->expert_static_agent_indices);
+ free(env);
+ consecutive_skips++;
+ continue;
+ }
+
int next_total = total_agent_count + env->active_agent_count;
if (next_total > num_agents) {
int remaining = num_agents - total_agent_count;
- fprintf(stderr,
- "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n",
- env->active_agent_count,
- remaining,
- map_id);
- for(int j=0; j<env->num_entities; j++) {
+ fprintf(stderr, "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n",
+ env->active_agent_count, remaining, map_id);
+ for (int j = 0; j < env->num_entities; j++) {
free_entity(&env->entities[j]);
}
free(env->entities);
@@ -272,25 +330,22 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
Py_DECREF(coplayer_ids);
free(agent_roles);
PyErr_Format(PyExc_RuntimeError,
- "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d",
- next_total,
- num_agents,
- map_id,
- env->active_agent_count);
+ "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d", next_total,
+ num_agents, map_id, env->active_agent_count);
return NULL;
}
// Store map_id
- PyObject* map_id_obj = PyLong_FromLong(map_id);
+ PyObject *map_id_obj = PyLong_FromLong(map_id);
PyList_SetItem(map_ids, env_count, map_id_obj);
// Store agent offset
- PyObject* offset = PyLong_FromLong(total_agent_count);
+ PyObject *offset = PyLong_FromLong(total_agent_count);
PyList_SetItem(agent_offsets, env_count, offset);
// Create ego and coplayer lists for this world
- PyObject* ego_list = PyList_New(0);
- PyObject* coplayer_list = PyList_New(0);
+ PyObject *ego_list = PyList_New(0);
+ PyObject *coplayer_list = PyList_New(0);
int world_egos = 0;
int world_coplayers = 0;
@@ -298,9 +353,9 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
// Assign agents from the shuffled roles
for (int a = 0; a < env->active_agent_count; a++) {
- PyObject* agent_id = PyLong_FromLong(total_agent_count);
+ PyObject *agent_id = PyLong_FromLong(total_agent_count);
- if (agent_roles[total_agent_count] == 1) {
+ if (agent_roles[agent_role_index] == 1) {
// This agent is an ego
PyList_Append(ego_list, agent_id);
world_egos++;
@@ -314,22 +369,26 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
Py_DECREF(agent_id);
total_agent_count++;
+ agent_role_index++;
}
// Enforce constraint: must have at least 1 ego per world (if egos remain)
if (world_egos == 0 && remaining_egos > 0) {
- fprintf(stderr,
- "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n",
- env_count, remaining_egos);
+ // Silenced: noisy during normal training. Re-enable for debug.
+ // fprintf(
+ // stderr,
+ // "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n",
+ // env_count, remaining_egos);
// Rollback the agent assignments for this world
total_agent_count -= env->active_agent_count;
total_coplayers_assigned -= world_coplayers;
+ agent_role_index -= env->active_agent_count;
Py_DECREF(ego_list);
Py_DECREF(coplayer_list);
- for(int j=0; j<env->num_entities; j++) {
+ for (int j = 0; j < env->num_entities; j++) {
free_entity(&env->entities[j]);
}
free(env->entities);
@@ -337,18 +396,24 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
free(env->static_agent_indices);
free(env->expert_static_agent_indices);
free(env);
+
+ consecutive_skips++;
continue; // Try another map
}
+ // Successfully created a world, reset skip counter
+ consecutive_skips = 0;
+
PyList_SetItem(ego_agent_ids, env_count, ego_list);
PyList_SetItem(coplayer_ids, env_count, coplayer_list);
- printf("World %d (map %d): %d agents (%d egos, %d co-players)\n",
- env_count, map_id, env->active_agent_count, world_egos, world_coplayers);
+ // Silenced: noisy during normal training. Re-enable for debug.
+ // printf("World %d (map %d): %d agents (%d egos, %d co-players)\n", env_count, map_id, env->active_agent_count,
+ // world_egos, world_coplayers);
env_count++;
- for(int j=0; j<env->num_entities; j++) {
+ for (int j = 0; j < env->num_entities; j++) {
free_entity(&env->entities[j]);
}
free(env->entities);
@@ -362,15 +427,15 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
total_agent_count = num_agents;
}
- PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count);
+ PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count);
PyList_SetItem(agent_offsets, env_count, final_total_agent_count);
- PyObject* final_env_count = PyLong_FromLong(env_count);
+ PyObject *final_env_count = PyLong_FromLong(env_count);
// Resize lists
- PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1);
- PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count);
- PyObject* resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count);
- PyObject* resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count);
+ PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1);
+ PyObject *resized_map_ids = PyList_GetSlice(map_ids, 0, env_count);
+ PyObject *resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count);
+ PyObject *resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count);
Py_DECREF(agent_offsets);
Py_DECREF(map_ids);
@@ -381,15 +446,16 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj
free(agent_roles);
// Create a tuple
- PyObject* tuple = PyTuple_New(5);
+ PyObject *tuple = PyTuple_New(5);
PyTuple_SetItem(tuple, 0, resized_agent_offsets);
PyTuple_SetItem(tuple, 1, resized_map_ids);
PyTuple_SetItem(tuple, 2, final_env_count);
PyTuple_SetItem(tuple, 3, resized_ego_ids);
PyTuple_SetItem(tuple, 4, resized_coplayer_ids);
- printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n",
- total_agent_count, env_count, total_egos_assigned, total_coplayers_assigned);
+ // Silenced: noisy during normal training. Re-enable for debug.
+ // printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n", total_agent_count, env_count,
+ // total_egos_assigned, total_coplayers_assigned);
return tuple;
}
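Review note: the bounded-retry guard added to my_shared_population_play above can be sketched in isolation as follows. This is a minimal sketch under stated assumptions, not the binding code itself: fill_worlds and active_agents_for_map are hypothetical names, the agents_per_map array stands in for load_map_binary plus set_active_agents, and clamping to the remaining capacity is a simplification of how the real loop limits agents per world.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical stand-in for loading a map and counting its active agents. */
static int active_agents_for_map(const int *agents_per_map, int map_id) {
    return agents_per_map[map_id];
}

/* Sample maps at random until num_agents slots are filled.
 * Returns the number of worlds created, or -1 once more than
 * 3 * num_maps samples in a row have been skipped. */
static int fill_worlds(const int *agents_per_map, int num_maps, int num_agents, unsigned int seed) {
    assert(num_maps > 0); /* rand() % 0 would be undefined */
    int total = 0;
    int env_count = 0;
    int consecutive_skips = 0;
    int max_consecutive_skips = num_maps * 3;

    srand(seed);
    while (total < num_agents) {
        if (consecutive_skips > max_consecutive_skips) {
            return -1; /* every recent sample was unusable, give up */
        }
        int map_id = rand() % num_maps;
        int active = active_agents_for_map(agents_per_map, map_id);
        if (active == 0) {
            consecutive_skips++; /* map contributes nothing, try another */
            continue;
        }
        int remaining = num_agents - total;
        if (active > remaining) {
            active = remaining; /* simplification: clamp to remaining capacity */
        }
        total += active;
        env_count++;
        consecutive_skips = 0; /* a world was created, reset the counter */
    }
    return env_count;
}
```

Resetting the counter only after a world is created means the limit bounds runs of consecutive failures, not total failures, so a population with a few empty maps still completes while an all-empty map directory fails fast.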
diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c
deleted file mode 100644
index 5e193c0371..0000000000
--- a/pufferlib/ocean/drive/drive.c
+++ /dev/null
@@ -1,171 +0,0 @@
-#include "drive.h"
-#include "drivenet.h"
-#include
-#include "../env_config.h"
-
-// Use this test if the network changes to ensure that the forward pass
-// matches the torch implementation to the 3rd or ideally 4th decimal place
-void test_drivenet() {
- int num_obs = 1848;
- int num_actions = 2;
- int num_agents = 4;
-
- float* observations = calloc(num_agents*num_obs, sizeof(float));
- for (int i = 0; i < num_obs*num_agents; i++) {
- observations[i] = i % 7;
- }
-
- int* actions = calloc(num_agents*num_actions, sizeof(int));
-
- //Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin");
- Weights* weights = load_weights("puffer_drive_weights.bin");
- DriveNet* net = init_drivenet(weights, num_agents, CLASSIC, false, false, false);
-
- forward(net, observations, actions);
- for (int i = 0; i < num_agents*num_actions; i++) {
- printf("idx: %d, action: %d, logits:", i, actions[i]);
- for (int j = 0; j < num_actions; j++) {
- printf(" %.6f", net->actor->output[i*num_actions + j]);
- }
- printf("\n");
- }
- free_drivenet(net);
- free(weights);
-}
-
-void demo() {
- // Read configuration from INI file
- env_init_config conf = {0};
- const char* ini_file = "pufferlib/config/ocean/drive.ini";
- if(ini_parse(ini_file, handler, &conf) < 0) {
- fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file);
- exit(1);
- }
-
- Drive env = {
- .human_agent_idx = 0,
- .dynamics_model = conf.dynamics_model,
- .reward_vehicle_collision = conf.reward_vehicle_collision,
- .reward_offroad_collision = conf.reward_offroad_collision,
- .reward_ade = conf.reward_ade,
- .goal_radius = conf.goal_radius,
- .dt = conf.dt,
- .map_name = "resources/drive/binaries/map_000.bin",
- .init_steps = conf.init_steps,
- .collision_behavior = conf.collision_behavior,
- .offroad_behavior = conf.offroad_behavior,
- };
- allocate(&env);
- c_reset(&env);
- c_render(&env);
- Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin");
- DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, false, false, false);
- //Client* client = make_client(&env);
- int accel_delta = 2;
- int steer_delta = 4;
- while (!WindowShouldClose()) {
- // Handle camera controls
- int (*actions)[2] = (int(*)[2])env.actions;
- forward(net, env.observations, env.actions);
- if (IsKeyDown(KEY_LEFT_SHIFT)) {
- actions[env.human_agent_idx][0] = 3;
- actions[env.human_agent_idx][1] = 6;
- if(IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)){
- actions[env.human_agent_idx][0] += accel_delta;
- // Cap acceleration to maximum of 6
- if(actions[env.human_agent_idx][0] > 6) {
- actions[env.human_agent_idx][0] = 6;
- }
- }
- if(IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)){
- actions[env.human_agent_idx][0] -= accel_delta;
- // Cap acceleration to minimum of 0
- if(actions[env.human_agent_idx][0] < 0) {
- actions[env.human_agent_idx][0] = 0;
- }
- }
- if(IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)){
- actions[env.human_agent_idx][1] += steer_delta;
- // Cap steering to minimum of 0
- if(actions[env.human_agent_idx][1] < 0) {
- actions[env.human_agent_idx][1] = 0;
- }
- }
- if(IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)){
- actions[env.human_agent_idx][1] -= steer_delta;
- // Cap steering to maximum of 12
- if(actions[env.human_agent_idx][1] > 12) {
- actions[env.human_agent_idx][1] = 12;
- }
- }
- if(IsKeyPressed(KEY_TAB)){
- env.human_agent_idx = (env.human_agent_idx + 1) % env.active_agent_count;
- }
- }
- c_step(&env);
- c_render(&env);
- }
-
- close_client(env.client);
- free_allocated(&env);
- free_drivenet(net);
- free(weights);
-}
-
-void performance_test() {
- // Read configuration from INI file
- env_init_config conf = {0};
- const char* ini_file = "pufferlib/config/ocean/drive.ini";
- if(ini_parse(ini_file, handler, &conf) < 0) {
- fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file);
- exit(1);
- }
-
- long test_time = 10;
- Drive env = {
- .human_agent_idx = 0,
- .dynamics_model = conf.dynamics_model,
- .reward_vehicle_collision = conf.reward_vehicle_collision,
- .reward_offroad_collision = conf.reward_offroad_collision,
- .reward_ade = conf.reward_ade,
- .goal_radius = conf.goal_radius,
- .dt = conf.dt,
- .map_name = "resources/drive/binaries/map_000.bin",
- .init_steps = conf.init_steps,
- };
- clock_t start_time, end_time;
- double cpu_time_used;
- start_time = clock();
- allocate(&env);
- c_reset(&env);
- end_time = clock();
- cpu_time_used = ((double) (end_time - start_time)) / CLOCKS_PER_SEC;
- printf("Init time: %f\n", cpu_time_used);
-
- long start = time(NULL);
- int i = 0;
- int (*actions)[2] = (int(*)[2])env.actions;
-
- while (time(NULL) - start < test_time) {
- // Set random actions for all agents
- for(int j = 0; j < env.active_agent_count; j++) {
- int accel = rand() % 7;
- int steer = rand() % 13;
- actions[j][0] = accel; // -1, 0, or 1
- actions[j][1] = steer; // Random steering
- }
-
- c_step(&env);
- i++;
- }
- long end = time(NULL);
- printf("SPS: %ld\n", (i*env.active_agent_count) / (end - start));
- free_allocated(&env);
-}
-
-int main() {
- //performance_test();
- demo();
- //test_drivenet();
- return 0;
-}
diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h
index cb112b9a7e..895604077f 100644
--- a/pufferlib/ocean/drive/drive.h
+++ b/pufferlib/ocean/drive/drive.h
@@ -6,12 +6,26 @@
#include
#include
#include
+#include
+#include
+#include
#include "raylib.h"
#include "raymath.h"
#include "rlgl.h"
#include
+#include
#include "error.h"
+// Render modes
+#define RENDER_OFF 0
+#define RENDER_HEADLESS 1
+#define RENDER_WINDOW 2
+
+// View modes for rendering
+#define VIEW_MODE_SIM_STATE 0 // Full simulation state (top-down orthographic)
+#define VIEW_MODE_BEV_AGENT_OBS 1 // Bird's eye view centered on agent
+#define VIEW_MODE_AGENT_PERSP 2 // Agent perspective (3rd person)
+
// Entity Types
#define NONE 0
#define VEHICLE 1
@@ -37,7 +51,7 @@
// Control modes
#define CONTROL_VEHICLES 0
#define CONTROL_AGENTS 1
-#define CONTROL_TRACKS_TO_PREDICT 2
+#define CONTROL_WOSAC 2
#define CONTROL_SDC_ONLY 3
// Minimum distance to goal position
@@ -60,11 +74,23 @@
#define OFFROAD_IDX 1
#define REACHED_GOAL_IDX 2
#define LANE_ALIGNED_IDX 3
-#define AVG_DISPLACEMENT_ERROR_IDX 4
+#define LANE_DIST_IDX 4
+#define LANE_ANGLE_IDX 5
+#define AVG_DISPLACEMENT_ERROR_IDX 6
+
+// Lane alignment constants (from GIGAFLOW)
+#define LANE_DISTANCE_NORMALIZATION 4.0f
+#define LANE_SELECTION_DISTANCE_WEIGHT 0.7f
+#define LANE_SELECTION_HEADING_WEIGHT 0.3f
+#define LANE_SWITCH_THRESHOLD 0.5f
+#define LANE_ALIGN_COS_THRESHOLD 0.965f // ~15 degrees
+#define MAX_CHECKED_LANES 32
// Grid cell size
#define GRID_CELL_SIZE 5.0f
-#define MAX_ENTITIES_PER_CELL 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2
+#define MAX_ENTITIES_PER_CELL \
+ 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in
+ // gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2
// Max road segment observation entities
#define MAX_ROAD_SEGMENT_OBSERVATIONS 200
@@ -86,34 +112,58 @@
#define STOP_AGENT 1
#define REMOVE_AGENT 2
-//GOAL BEHAVIOUR
+// GOAL BEHAVIOUR
#define GOAL_RESPAWN 0
#define GOAL_GENERATE_NEW 1
#define GOAL_STOP 2
+#define PARTNER_FEATURES 7
+
+#define ROAD_FEATURES 7
+#define ROAD_FEATURES_ONEHOT 13
+#define PARTNER_FEATURES 7
+
+// Ego features depend on dynamics model
+// Classic: goal_x, goal_y, speed, width, length, collision, respawn, lane_dist, lane_angle
+#define EGO_FEATURES_CLASSIC 9
+// Jerk: + steering_angle, a_long, a_lat
+#define EGO_FEATURES_JERK 12
+
// Jerk action space (for JERK dynamics model)
static const float JERK_LONG[4] = {-15.0f, -4.0f, 0.0f, 4.0f};
static const float JERK_LAT[3] = {-4.0f, 0.0f, 4.0f};
// Classic action space (for CLASSIC dynamics model)
static const float ACCELERATION_VALUES[7] = {-4.0000f, -2.6670f, -1.3330f, -0.0000f, 1.3330f, 2.6670f, 4.0000f};
-static const float STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f, 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f};
+static const float STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f,
+ 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f};
static const float offsets[4][2] = {
- {-1, 1}, // top-left
- {1, 1}, // top-right
- {1, -1}, // bottom-right
- {-1, -1} // bottom-left
- };
+ {-1, 1}, // top-left
+ {1, 1}, // top-right
+ {1, -1}, // bottom-right
+ {-1, -1} // bottom-left
+};
static const int collision_offsets[25][2] = {
- {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row
- {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row
- {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center)
- {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row
- {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row
+ {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row
+ {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row
+ {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center)
+ {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row
+ {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row
};
+const Color STONE_GRAY = (Color){80, 80, 80, 255};
+const Color PUFF_RED = (Color){187, 0, 0, 255};
+const Color PUFF_CYAN = (Color){0, 187, 187, 255};
+const Color PUFF_WHITE = (Color){241, 241, 241, 241};
+const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255};
+const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255};
+const Color LIGHTGREEN = (Color){152, 255, 152, 255};
+const Color LIGHTYELLOW = (Color){255, 255, 152, 255};
+const Color LIGHTBLUE = (Color){152, 200, 255, 255};
+const Color SOFT_YELLOW = (Color){245, 245, 220, 255};
+
struct timespec ts;
typedef struct Drive Drive;
@@ -122,20 +172,22 @@ typedef struct Log Log;
typedef struct Graph Graph;
typedef struct AdjListNode AdjListNode;
typedef struct Co_Player_Log Co_Player_Log;
-typedef struct Adaptive_Agent_Log Adaptive_Agent_Log;
struct Log {
float episode_return;
float episode_length;
float score;
+ float goals_reached_this_episode;
+ float goals_sampled_this_episode;
float offroad_rate;
float collision_rate;
- float num_goals_reached;
float completion_rate;
+ float offroad_per_agent;
+ float collisions_per_agent;
float dnf_rate;
float n;
float lane_alignment_rate;
- float avg_displacement_error;
+ float speed_at_goal;
float active_agent_count;
float expert_static_agent_count;
float static_agent_count;
@@ -145,8 +197,6 @@ struct Log {
float avg_goal_weight;
float avg_entropy_weight;
float avg_discount_weight;
- float avg_offroad_per_agent;
- float avg_collisions_per_agent;
};
typedef struct Entity Entity;
@@ -155,14 +205,14 @@ struct Entity {
int type;
int id;
int array_size;
- float* traj_x;
- float* traj_y;
- float* traj_z;
- float* traj_vx;
- float* traj_vy;
- float* traj_vz;
- float* traj_heading;
- int* traj_valid;
+ float *traj_x;
+ float *traj_y;
+ float *traj_z;
+ float *traj_vx;
+ float *traj_vy;
+ float *traj_vz;
+ float *traj_heading;
+ int *traj_valid;
float width;
float length;
float height;
@@ -173,7 +223,8 @@ struct Entity {
float init_goal_y;
int mark_as_expert;
int collision_state;
- float metrics_array[5]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, avg_displacement_error]
+ float metrics_array[7]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, lane_dist, lane_angle,
+ // avg_disp_error]
float x;
float y;
float z;
@@ -184,13 +235,14 @@ struct Entity {
float heading_x;
float heading_y;
int current_lane_idx;
+ int current_lane_geometry_idx;
int valid;
int respawn_timestep;
int respawn_count;
int collided_before_goal;
- int sampled_new_goal;
- int reached_goal_this_episode;
- int num_goals_reached;
+ float goals_reached_this_episode;
+ float goals_sampled_this_episode;
+ int current_goal_reached;
int active_agent;
float cumulative_displacement;
int displacement_sample_count;
@@ -206,12 +258,12 @@ struct Entity {
float steering_angle;
float wheelbase;
- //population play
+ // population play
bool is_ego;
bool is_co_player;
};
-void free_entity(Entity* entity){
+void free_entity(Entity *entity) {
// free trajectory arrays
free(entity->traj_x);
free(entity->traj_y);
@@ -222,22 +274,6 @@ void free_entity(Entity* entity){
free(entity->traj_heading);
free(entity->traj_valid);
}
-struct Co_Player_Log {
- float co_player_episode_return;
- float co_player_episode_length;
- float co_player_perf;
- float co_player_score;
- float co_player_offroad_rate;
- float co_player_collision_rate;
- float co_player_clean_collision_rate;
- float co_player_num_goals_reached;
- float co_player_completion_rate;
- float co_player_dnf_rate;
- float co_player_lane_alignment_rate;
- float co_player_avg_displacement_error;
- float co_player_n;
-};
-
// Utility functions
float compute_delta_percent(float first, float last) {
@@ -247,51 +283,26 @@ float compute_delta_percent(float first, float last) {
return (last - first) / first * 100.0f;
}
-float relative_distance(float a, float b){
+float relative_distance(float a, float b) {
float distance = sqrtf(powf(a - b, 2));
return distance;
}
-float relative_distance_2d(float x1, float y1, float x2, float y2){
+float relative_distance_2d(float x1, float y1, float x2, float y2) {
float dx = x2 - x1;
float dy = y2 - y1;
- float distance = sqrtf(dx*dx + dy*dy);
+ float distance = sqrtf(dx * dx + dy * dy);
return distance;
}
float clip(float value, float min, float max) {
- if (value < min) return min;
- if (value > max) return max;
+ if (value < min)
+ return min;
+ if (value > max)
+ return max;
return value;
}
-float compute_displacement_error(Entity* agent, int timestep) {
- // Check if timestep is within valid range
- if (timestep < 0 || timestep >= agent->array_size) {
- return 0.0f;
- }
-
- // Check if reference trajectory is valid at this timestep
- if (!agent->traj_valid[timestep]) {
- return 0.0f;
- }
-
- // Get reference position at current timestep, skip invalid ones
- float ref_x = agent->traj_x[timestep];
- float ref_y = agent->traj_y[timestep];
-
- if (ref_x == INVALID_POSITION || ref_y == INVALID_POSITION) {
- return 0.0f;
- }
-
- // Compute deltas: Euclidean distance between actual and reference position
- float dx = agent->x - ref_x;
- float dy = agent->y - ref_y;
- float displacement = sqrtf(dx*dx + dy*dy);
-
- return displacement;
-}
-
typedef struct GridMapEntity GridMapEntity;
struct GridMapEntity {
int entity_idx;
@@ -308,64 +319,66 @@ struct GridMap {
int grid_rows;
int cell_size_x;
int cell_size_y;
- int* cell_entities_count; // number of entities in each cell of the GridMap
- GridMapEntity** cells; // list of gridEntities in each cell of the GridMap
-
+ int *cell_entities_count; // number of entities in each cell of the GridMap
+ GridMapEntity **cells; // list of gridEntities in each cell of the GridMap
// Extras/Optimizations
int vision_range;
- int* neighbor_cache_count; // number of entities in each cells neighbor cache
- GridMapEntity** neighbor_cache_entities; // preallocated array to hold neighbor entities
+ int *neighbor_cache_count; // number of entities in each cells neighbor cache
+ GridMapEntity **neighbor_cache_entities; // preallocated array to hold neighbor entities
};
struct Drive {
- Client* client;
- float* observations;
- float* actions;
- float* rewards;
- unsigned char* terminals;
+ Client *client;
+ float *observations;
+ float *actions;
+ float *rewards;
+ unsigned char *terminals;
Log log;
- Log* logs;
+ Log *logs;
int num_agents;
int active_agent_count;
- int* active_agent_indices;
+ int *active_agent_indices;
int action_type;
int human_agent_idx;
- Entity* entities;
- Graph* topology_graph;
+ Entity *entities;
int num_entities;
int num_actors;
int num_objects;
int num_roads;
int static_agent_count;
- int* static_agent_indices;
+ int *static_agent_indices;
int expert_static_agent_count;
- int* expert_static_agent_indices;
+ int *expert_static_agent_indices;
int timestep;
int init_steps;
int dynamics_model;
- GridMap* grid_map;
- int* neighbor_offsets;
+ GridMap *grid_map;
+ int *neighbor_offsets;
int scenario_length;
+ int termination_mode;
float reward_vehicle_collision;
float reward_offroad_collision;
- float reward_ade;
- char* map_name;
+ char *map_name;
float world_mean_x;
float world_mean_y;
float dt;
float reward_goal;
float reward_goal_post_respawn;
+ float reward_lane_align;
+ float reward_vel_align;
float goal_radius;
+ float goal_speed;
int max_controlled_agents;
int logs_capacity;
int goal_behavior;
- char* ini_file;
- char* scenario_id;
+ float goal_target_distance;
+ char *ini_file;
+ char *scenario_id;
int collision_behavior;
int offroad_behavior;
int sdc_track_index;
int num_tracks_to_predict;
- int* tracks_to_predict_indices;
+ int *tracks_to_predict_indices;
int init_mode;
int control_mode;
@@ -377,60 +390,82 @@ struct Drive {
float offroad_weight_ub;
float goal_weight_lb;
float goal_weight_ub;
- float* collision_weights;
- float* offroad_weights;
- float* goal_weights;
+ float *collision_weights;
+ float *offroad_weights;
+ float *goal_weights;
// Entropy conditioning
bool use_ec;
float entropy_weight_lb;
float entropy_weight_ub;
- float* entropy_weights;
+ float *entropy_weights;
// Discount conditioning
bool use_dc;
float discount_weight_lb;
float discount_weight_ub;
- float* discount_weights;
- //fixed population play
- Co_Player_Log co_player_log;
- Co_Player_Log* co_player_logs;
+ float *discount_weights;
+ // fixed population play
+ Log co_player_log;
+ Log *co_player_logs;
int num_co_players;
int num_ego_agents;
- int* co_player_ids;
- int* ego_agent_ids;
+ int *co_player_ids;
+ int *ego_agent_ids;
bool population_play;
-
-
+ // Rendering
+ int render_mode; // RENDER_OFF, RENDER_HEADLESS, or RENDER_WINDOW
+ char video_basename[256]; // Full mp4 basename (without ".mp4") set by Python via vec_set_video_suffix. Defaults to
+ // "render".
};
-void add_log(Drive* env) {
-
+void add_log(Drive *env) {
for (int i = 0; i < env->active_agent_count; i++) {
- Entity* e = &env->entities[env->active_agent_indices[i]];
+ Entity *e = &env->entities[env->active_agent_indices[i]];
- if (e->is_ego) {
- // ALWAYS update regular logs for all ego agents
- if (e->reached_goal_this_episode)
- env->log.completion_rate += 1.0f;
+ // Common metrics for all agents
+ env->log.goals_reached_this_episode += e->goals_reached_this_episode;
+ env->log.goals_sampled_this_episode += e->goals_sampled_this_episode;
+ if (e->is_ego) {
+ // EGO agent logging
int offroad = env->logs[i].offroad_rate;
env->log.offroad_rate += offroad;
int collided = env->logs[i].collision_rate;
env->log.collision_rate += collided;
- int num_goals_reached = env->logs[i].num_goals_reached;
- env->log.num_goals_reached += num_goals_reached;
+ float offroad_per_agent = env->logs[i].offroad_per_agent;
+ env->log.offroad_per_agent += offroad_per_agent;
+ float collisions_per_agent = env->logs[i].collisions_per_agent;
+ env->log.collisions_per_agent += collisions_per_agent;
+
+ float frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode;
+
+ // Calculate threshold based on goals sampled
+ float threshold = 0.99f; // Initial value; always overwritten by the branches below
+ if (e->goals_sampled_this_episode == 2.0f) {
+ threshold = 0.5f; // Require >50% completion for 2 goals (strict comparison below)
+ } else if (e->goals_sampled_this_episode < 5.0f) {
+ threshold = 0.8f; // Require >80% completion for 1, 3, or 4 goals
+ } else {
+ threshold = 0.9f; // Require >90% completion for 5+ goals
+ }
+
+ // Use the "before goal" flag in respawn AND stop modes so that a
+ // post-goal rear-end (which the SDC can't avoid once it has stopped)
+ // does not deny the score award.
+ int collision_occurred = (env->goal_behavior == GOAL_RESPAWN || env->goal_behavior == GOAL_STOP)
+ ? e->collided_before_goal
+ : env->logs[i].collision_rate;
- if (e->reached_goal_this_episode && !e->collided_before_goal) {
+ if (frac_goal_reached > threshold && !collision_occurred) {
env->log.score += 1.0f;
}
- if (!offroad && !collided && !e->reached_goal_this_episode) {
+ if (!offroad && !collided && frac_goal_reached < 1.0f) {
env->log.dnf_rate += 1.0f;
}
int lane_aligned = env->logs[i].lane_alignment_rate;
env->log.lane_alignment_rate += lane_aligned;
- float displacement_error = env->logs[i].avg_displacement_error;
- env->log.avg_displacement_error += displacement_error;
+ env->log.speed_at_goal += env->logs[i].speed_at_goal;
env->log.episode_length += env->logs[i].episode_length;
env->log.episode_return += env->logs[i].episode_return;
@@ -438,81 +473,89 @@ void add_log(Drive* env) {
env->log.expert_static_agent_count += env->expert_static_agent_count;
env->log.static_agent_count += env->static_agent_count;
env->log.n += 1.0f;
-
}
- // Process co-player agents (separate if, not else-if!)
+
if (e->is_co_player && env->co_player_logs != NULL) {
- if (e->reached_goal_this_episode)
- env->co_player_log.co_player_completion_rate += 1.0f;
-
- int co_offroad = env->co_player_logs[i].co_player_offroad_rate;
- env->co_player_log.co_player_offroad_rate += co_offroad;
- int co_collided = env->co_player_logs[i].co_player_collision_rate;
- env->co_player_log.co_player_collision_rate += co_collided;
- int co_num_goals_reached = env->co_player_logs[i].co_player_num_goals_reached;
- env->co_player_log.co_player_num_goals_reached += co_num_goals_reached;
-
- env->co_player_log.co_player_clean_collision_rate +=
- env->co_player_logs[i].co_player_clean_collision_rate;
-
- if (e->reached_goal_this_episode && !e->collided_before_goal) {
- env->co_player_log.co_player_score += 1.0f;
- env->co_player_log.co_player_perf += 1.0f;
+ int co_offroad = env->co_player_logs[i].offroad_rate;
+ env->co_player_log.offroad_rate += co_offroad;
+ int co_collided = env->co_player_logs[i].collision_rate;
+ env->co_player_log.collision_rate += co_collided;
+ float co_offroad_per_agent = env->co_player_logs[i].offroad_per_agent;
+ env->co_player_log.offroad_per_agent += co_offroad_per_agent;
+ float co_collisions_per_agent = env->co_player_logs[i].collisions_per_agent;
+ env->co_player_log.collisions_per_agent += co_collisions_per_agent;
+
+ float co_frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode;
+
+ // Calculate threshold for co-players
+ float co_threshold = 0.99f;
+ if (e->goals_sampled_this_episode == 2.0f) {
+ co_threshold = 0.5f;
+ } else if (e->goals_sampled_this_episode < 5.0f) {
+ co_threshold = 0.8f;
+ } else {
+ co_threshold = 0.9f;
}
- if (!co_offroad && !co_collided && !e->reached_goal_this_episode) {
- env->co_player_log.co_player_dnf_rate += 1.0f;
+ // Same post-goal-collision exemption as the ego score check above.
+ int co_collision_occurred = (env->goal_behavior == GOAL_RESPAWN || env->goal_behavior == GOAL_STOP)
+ ? e->collided_before_goal
+ : env->co_player_logs[i].collision_rate;
+
+ if (co_frac_goal_reached > co_threshold && !co_collision_occurred) {
+ env->co_player_log.score += 1.0f;
+ }
+
+ if (!co_offroad && !co_collided && co_frac_goal_reached < 1.0f) {
+ env->co_player_log.dnf_rate += 1.0f;
}
- int co_lane_aligned = env->co_player_logs[i].co_player_lane_alignment_rate;
- env->co_player_log.co_player_lane_alignment_rate += co_lane_aligned;
- float co_displacement_error = env->co_player_logs[i].co_player_avg_displacement_error;
- env->co_player_log.co_player_avg_displacement_error += co_displacement_error;
- env->co_player_log.co_player_episode_return +=
- env->co_player_logs[i].co_player_episode_return;
- env->co_player_log.co_player_episode_length +=
- env->co_player_logs[i].co_player_episode_length;
+ int co_lane_aligned = env->co_player_logs[i].lane_alignment_rate;
+ env->co_player_log.lane_alignment_rate += co_lane_aligned;
+ env->co_player_log.speed_at_goal += env->co_player_logs[i].speed_at_goal;
+ env->co_player_log.episode_return += env->co_player_logs[i].episode_return;
+ env->co_player_log.episode_length += env->co_player_logs[i].episode_length;
- env->co_player_log.co_player_n += 1.0f;
+ env->co_player_log.n += 1.0f;
}
}
}
struct AdjListNode {
int dest;
- struct AdjListNode* next;
+ struct AdjListNode *next;
};
struct Graph {
int V;
- struct AdjListNode** array;
+ struct AdjListNode **array;
};
// Function to create a new adjacency list node
-struct AdjListNode* newAdjListNode(int dest) {
- struct AdjListNode* newNode = malloc(sizeof(struct AdjListNode));
+struct AdjListNode *newAdjListNode(int dest) {
+ struct AdjListNode *newNode = malloc(sizeof(struct AdjListNode));
newNode->dest = dest;
newNode->next = NULL;
return newNode;
}
// Function to create a graph of V vertices
-struct Graph* createGraph(int V) {
- struct Graph* graph = malloc(sizeof(struct Graph));
+struct Graph *createGraph(int V) {
+ struct Graph *graph = malloc(sizeof(struct Graph));
graph->V = V;
- graph->array = calloc(V, sizeof(struct AdjListNode*));
+ graph->array = calloc(V, sizeof(struct AdjListNode *));
return graph;
}
// Function to get next lanes from a given lane entity index
// Returns the number of next lanes found, fills next_lanes array with entity indices
-int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_lanes) {
+int getNextLanes(struct Graph *graph, int entity_idx, int *next_lanes, int max_lanes) {
if (!graph || entity_idx < 0 || entity_idx >= graph->V) {
return 0;
}
int count = 0;
- struct AdjListNode* node = graph->array[entity_idx];
+ struct AdjListNode *node = graph->array[entity_idx];
while (node && count < max_lanes) {
next_lanes[count] = node->dest;
@@ -524,13 +567,14 @@ int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_l
}
// Function to free the topology graph
-void freeTopologyGraph(struct Graph* graph) {
- if (!graph) return;
+void freeTopologyGraph(struct Graph *graph) {
+ if (!graph)
+ return;
for (int i = 0; i < graph->V; i++) {
- struct AdjListNode* node = graph->array[i];
+ struct AdjListNode *node = graph->array[i];
while (node) {
- struct AdjListNode* temp = node;
+ struct AdjListNode *temp = node;
node = node->next;
free(temp);
}
@@ -540,19 +584,24 @@ void freeTopologyGraph(struct Graph* graph) {
free(graph);
}
-
-Entity* load_map_binary(const char* filename, Drive* env) {
- FILE* file = fopen(filename, "rb");
- if (!file) return NULL;
-
+Entity *load_map_binary(const char *filename, Drive *env) {
+ FILE *file = fopen(filename, "rb");
+ if (!file)
+ return NULL;
// Read sdc_track_index
- fread(&env->sdc_track_index, sizeof(int), 1, file);
+ if (fread(&env->sdc_track_index, sizeof(int), 1, file) != 1) {
+ fclose(file);
+ return NULL;
+ }
// Read tracks_to_predict
- fread(&env->num_tracks_to_predict, sizeof(int), 1, file);
+ if (fread(&env->num_tracks_to_predict, sizeof(int), 1, file) != 1) {
+ fclose(file);
+ return NULL;
+ }
if (env->num_tracks_to_predict > 0) {
- env->tracks_to_predict_indices = (int*)malloc(env->num_tracks_to_predict * sizeof(int));
+ env->tracks_to_predict_indices = (int *)malloc(env->num_tracks_to_predict * sizeof(int));
for (int i = 0; i < env->num_tracks_to_predict; i++) {
fread(&env->tracks_to_predict_indices[i], sizeof(int), 1, file);
@@ -564,25 +613,38 @@ Entity* load_map_binary(const char* filename, Drive* env) {
fread(&env->num_objects, sizeof(int), 1, file);
fread(&env->num_roads, sizeof(int), 1, file);
env->num_entities = env->num_objects + env->num_roads;
- Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity));
+ Entity *entities = (Entity *)malloc(env->num_entities * sizeof(Entity));
for (int i = 0; i < env->num_entities; i++) {
- // Read base entity data
- fread(&entities[i].scenario_id, sizeof(int), 1, file);
- fread(&entities[i].type, sizeof(int), 1, file);
- fread(&entities[i].id, sizeof(int), 1, file);
- fread(&entities[i].array_size, sizeof(int), 1, file);
+ // Read base entity data
+ if (fread(&entities[i].scenario_id, sizeof(int), 1, file) != 1 ||
+ fread(&entities[i].type, sizeof(int), 1, file) != 1 || fread(&entities[i].id, sizeof(int), 1, file) != 1 ||
+ fread(&entities[i].array_size, sizeof(int), 1, file) != 1) {
+ // File truncated - adjust entity count and break
+ env->num_entities = i;
+ env->num_objects = (i < env->num_objects) ? i : env->num_objects;
+ env->num_roads = env->num_entities - env->num_objects;
+ break;
+ }
+ // Validate array_size is reasonable (max 1000 timesteps = 100s at 0.1s dt)
+ if (entities[i].array_size <= 0 || entities[i].array_size > 1000) {
+ env->num_entities = i;
+ env->num_objects = (i < env->num_objects) ? i : env->num_objects;
+ env->num_roads = env->num_entities - env->num_objects;
+ break;
+ }
// Allocate arrays based on type
int size = entities[i].array_size;
- entities[i].traj_x = (float*)malloc(size * sizeof(float));
- entities[i].traj_y = (float*)malloc(size * sizeof(float));
- entities[i].traj_z = (float*)malloc(size * sizeof(float));
- if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type
+ entities[i].traj_x = (float *)malloc(size * sizeof(float));
+ entities[i].traj_y = (float *)malloc(size * sizeof(float));
+ entities[i].traj_z = (float *)malloc(size * sizeof(float));
+ if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN ||
+ entities[i].type == CYCLIST) { // Object type
// Allocate arrays for object-specific data
- entities[i].traj_vx = (float*)malloc(size * sizeof(float));
- entities[i].traj_vy = (float*)malloc(size * sizeof(float));
- entities[i].traj_vz = (float*)malloc(size * sizeof(float));
- entities[i].traj_heading = (float*)malloc(size * sizeof(float));
- entities[i].traj_valid = (int*)malloc(size * sizeof(int));
+ entities[i].traj_vx = (float *)malloc(size * sizeof(float));
+ entities[i].traj_vy = (float *)malloc(size * sizeof(float));
+ entities[i].traj_vz = (float *)malloc(size * sizeof(float));
+ entities[i].traj_heading = (float *)malloc(size * sizeof(float));
+ entities[i].traj_valid = (int *)malloc(size * sizeof(int));
} else {
// Roads don't use these arrays
entities[i].traj_vx = NULL;
@@ -595,7 +657,8 @@ Entity* load_map_binary(const char* filename, Drive* env) {
fread(entities[i].traj_x, sizeof(float), size, file);
fread(entities[i].traj_y, sizeof(float), size, file);
fread(entities[i].traj_z, sizeof(float), size, file);
- if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type
+ if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN ||
+ entities[i].type == CYCLIST) { // Object type
fread(entities[i].traj_vx, sizeof(float), size, file);
fread(entities[i].traj_vy, sizeof(float), size, file);
fread(entities[i].traj_vz, sizeof(float), size, file);
@@ -616,32 +679,31 @@ Entity* load_map_binary(const char* filename, Drive* env) {
return entities;
}
-void set_start_position(Drive* env){
- //InitWindow(800, 600, "GPU Drive");
- //BeginDrawing();
- for(int i = 0; i < env->num_entities; i++){
+void set_start_position(Drive *env) {
+ for (int i = 0; i < env->num_entities; i++) {
int is_active = 0;
- for(int j = 0; j < env->active_agent_count; j++){
- if(env->active_agent_indices[j] == i){
+ for (int j = 0; j < env->active_agent_count; j++) {
+ if (env->active_agent_indices[j] == i) {
is_active = 1;
break;
}
}
- Entity* e = &env->entities[i];
+ Entity *e = &env->entities[i];
// Clamp init_steps to ensure we don't go out of bounds
int step = env->init_steps;
- if (step >= e->array_size) step = e->array_size - 1;
- if (step < 0) step = 0;
+ if (step >= e->array_size)
+ step = e->array_size - 1;
+ if (step < 0)
+ step = 0;
e->x = e->traj_x[step];
e->y = e->traj_y[step];
e->z = e->traj_z[step];
-
- if(e->type > CYCLIST || e->type == 0){
+ if (e->type > CYCLIST || e->type == 0) {
continue;
}
- if(is_active == 0){
+ if (is_active == 0) {
e->vx = 0;
e->vy = 0;
e->vz = 0;
@@ -656,15 +718,16 @@ void set_start_position(Drive* env){
e->heading_y = sinf(e->heading);
e->valid = e->traj_valid[env->init_steps];
e->collision_state = 0;
- e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision
- e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
- e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal
- e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
- e->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; // avg displacement error
- e->cumulative_displacement = 0.0f;
- e->displacement_sample_count = 0;
+ e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision
+ e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
+ e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal
+ e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
+ e->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // far from lane
+ e->metrics_array[LANE_ANGLE_IDX] = 0.0f; // no alignment
+ e->current_lane_idx = -1;
+ e->current_lane_geometry_idx = -1;
e->respawn_timestep = -1;
- e->stopped = 0;
+ e->stopped = 0;
e->removed = 0;
e->respawn_count = 0;
@@ -676,33 +739,35 @@ void set_start_position(Drive* env){
e->steering_angle = 0.0f;
e->wheelbase = 0.6f * e->length;
}
- //EndDrawing();
}
-int getGridIndex(Drive* env, float x1, float y1) {
- if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || env->grid_map->bottom_right_y >= env->grid_map->top_left_y) {
- return -1; // Invalid grid coordinates
+int getGridIndex(Drive *env, float x1, float y1) {
+ if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x ||
+ env->grid_map->bottom_right_y >= env->grid_map->top_left_y) {
+ return -1; // Invalid grid coordinates
}
- float relativeX = x1 - env->grid_map->top_left_x; // Distance from left
- float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom
- int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index
- int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index
+ float relativeX = x1 - env->grid_map->top_left_x; // Distance from left
+ float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom
+ int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index
+ int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index
if (gridX < 0 || gridX >= env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) {
- return -1; // Return -1 for out of bounds
+ return -1; // Return -1 for out of bounds
}
- int index = (gridY*env->grid_map->grid_cols) + gridX;
+ int index = (gridY * env->grid_map->grid_cols) + gridX;
return index;
}
-void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry_idx, int* cell_entities_insert_index){
- if(grid_index == -1){
+void add_entity_to_grid(Drive *env, int grid_index, int entity_idx, int geometry_idx, int *cell_entities_insert_index) {
+ if (grid_index == -1) {
return;
}
int count = cell_entities_insert_index[grid_index];
- if(count >= env->grid_map->cell_entities_count[grid_index]) {
- printf("Error: Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): %d\n", grid_index, count, env->grid_map->cell_entities_count[grid_index]);
+ if (count >= env->grid_map->cell_entities_count[grid_index]) {
+ printf("Error: Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): "
+ "%d\n",
+ grid_index, count, env->grid_map->cell_entities_count[grid_index]);
return;
}
@@ -711,72 +776,9 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry
cell_entities_insert_index[grid_index] = count + 1;
}
-
-void init_topology_graph(Drive* env){
- // Count ROAD_LANE entities
- int road_lane_count = 0;
- for(int i = 0; i < env->num_entities; i++){
- if(env->entities[i].type == ROAD_LANE){
- road_lane_count++;
- }
- }
-
- if(road_lane_count == 0){
- env->topology_graph = NULL;
- return;
- }
-
- // Create graph with all entities as vertices (we'll only use ROAD_LANE indices)
- env->topology_graph = createGraph(env->num_entities);
-
- // Connect ROAD_LANE entities based on geometric connectivity
- for(int i = 0; i < env->num_entities; i++){
- if(env->entities[i].type != ROAD_LANE) continue;
-
- Entity* lane_i = &env->entities[i];
- if(lane_i->array_size < 2) continue; // Need at least 2 points
-
- // Get end point of current lane
- float end_x = lane_i->traj_x[lane_i->array_size - 1];
- float end_y = lane_i->traj_y[lane_i->array_size - 1];
- float end_vector_x = lane_i->traj_x[lane_i->array_size - 1] - lane_i->traj_x[lane_i->array_size - 2];
- float end_vector_y = lane_i->traj_y[lane_i->array_size - 1] - lane_i->traj_y[lane_i->array_size - 2];
- float end_heading = atan2f(end_vector_y, end_vector_x);
-
- // Find lanes that start near this lane's end
- for(int j = 0; j < env->num_entities; j++){
- if(i == j || env->entities[j].type != ROAD_LANE) continue;
-
- Entity* lane_j = &env->entities[j];
- if(lane_j->array_size < 2) continue;
-
- // Get start point of potential next lane
- float start_x = lane_j->traj_x[0];
- float start_y = lane_j->traj_y[0];
- float start_vector_x = lane_j->traj_x[1] - lane_j->traj_x[0];
- float start_vector_y = lane_j->traj_y[1] - lane_j->traj_y[0];
- float start_heading = atan2f(start_vector_y, start_vector_x);
-
- // Check if end of lane_i is close to start of lane_j
- float distance = relative_distance_2d(end_x, end_y, start_x, start_y);
- float heading_diff = fabsf(end_heading - start_heading);
-
- // Lane connectivity thresholds:
- // - 0.01m distance: lanes must connect within 1cm (very strict for clean topology)
- // - 0.1 (~5.7 degrees) heading difference: allow slight curves
- if(distance < 0.01f && heading_diff < 0.1f){
- // Add directed edge from i to j (lane i connects to lane j)
- struct AdjListNode* node = newAdjListNode(j);
- node->next = env->topology_graph->array[i];
- env->topology_graph->array[i] = node;
- }
- }
- }
-}
-
-void init_grid_map(Drive* env){
+void init_grid_map(Drive *env) {
// Allocate memory for the grid map structure
- env->grid_map = (GridMap*)malloc(sizeof(GridMap));
+ env->grid_map = (GridMap *)malloc(sizeof(GridMap));
// Find top left and bottom right points of the map
float top_left_x;
@@ -784,23 +786,29 @@ void init_grid_map(Drive* env){
float bottom_right_x;
float bottom_right_y;
int first_valid_point = 0;
- for(int i = 0; i < env->num_entities; i++){
- if(env->entities[i].type > 3 && env->entities[i].type < 7){
+ for (int i = 0; i < env->num_entities; i++) {
+ if (env->entities[i].type > 3 && env->entities[i].type < 7) {
// Check all points in the trajectory for road elements
- Entity* e = &env->entities[i];
- for(int j = 0; j < e->array_size; j++){
- if(e->traj_x[j] == INVALID_POSITION) continue;
- if(e->traj_y[j] == INVALID_POSITION) continue;
- if(!first_valid_point) {
+ Entity *e = &env->entities[i];
+ for (int j = 0; j < e->array_size; j++) {
+ if (e->traj_x[j] == INVALID_POSITION)
+ continue;
+ if (e->traj_y[j] == INVALID_POSITION)
+ continue;
+ if (!first_valid_point) {
top_left_x = bottom_right_x = e->traj_x[j];
top_left_y = bottom_right_y = e->traj_y[j];
first_valid_point = true;
continue;
}
- if(e->traj_x[j] < top_left_x) top_left_x = e->traj_x[j];
- if(e->traj_x[j] > bottom_right_x) bottom_right_x = e->traj_x[j];
- if(e->traj_y[j] > top_left_y) top_left_y = e->traj_y[j];
- if(e->traj_y[j] < bottom_right_y) bottom_right_y = e->traj_y[j];
+ if (e->traj_x[j] < top_left_x)
+ top_left_x = e->traj_x[j];
+ if (e->traj_x[j] > bottom_right_x)
+ bottom_right_x = e->traj_x[j];
+ if (e->traj_y[j] > top_left_y)
+ top_left_y = e->traj_y[j];
+ if (e->traj_y[j] < bottom_right_y)
+ bottom_right_y = e->traj_y[j];
}
}
}
@@ -817,66 +825,71 @@ void init_grid_map(Drive* env){
float grid_height = top_left_y - bottom_right_y;
env->grid_map->grid_cols = ceil(grid_width / GRID_CELL_SIZE);
env->grid_map->grid_rows = ceil(grid_height / GRID_CELL_SIZE);
- int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows;
- env->grid_map->cells = (GridMapEntity**)calloc(grid_cell_count, sizeof(GridMapEntity*));
- env->grid_map->cell_entities_count = (int*)calloc(grid_cell_count, sizeof(int));
+ int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows;
+ env->grid_map->cells = (GridMapEntity **)calloc(grid_cell_count, sizeof(GridMapEntity *));
+ env->grid_map->cell_entities_count = (int *)calloc(grid_cell_count, sizeof(int));
// Calculate number of entities in each grid cell
- for(int i = 0; i < env->num_entities; i++){
- if(env->entities[i].type > 3 && env->entities[i].type < 7){
- for(int j = 0; j < env->entities[i].array_size - 1; j++){
- float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2;
- float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2;
+ for (int i = 0; i < env->num_entities; i++) {
+ if (env->entities[i].type > 3 && env->entities[i].type < 7) {
+ for (int j = 0; j < env->entities[i].array_size - 1; j++) {
+ float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2;
+ float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2;
int grid_index = getGridIndex(env, x_center, y_center);
- env->grid_map->cell_entities_count[grid_index]++;
+ if (grid_index != -1) {
+ env->grid_map->cell_entities_count[grid_index]++;
+ }
}
}
}
- int cell_entities_insert_index[grid_cell_count]; // Helper array for insertion index
- memset(cell_entities_insert_index, 0, grid_cell_count * sizeof(int));
+ // Use heap allocation instead of VLA to avoid stack overflow on large maps
+ int *cell_entities_insert_index = (int *)calloc(grid_cell_count, sizeof(int));
// Initialize grid cells
- for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){
- env->grid_map->cells[grid_index] = (GridMapEntity*)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity));
+ for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) {
+ env->grid_map->cells[grid_index] =
+ (GridMapEntity *)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity));
}
- for(int i = 0;i<env->num_entities; i++){
- if(env->entities[i].type > 3 && env->entities[i].type < 7){ // NOTE: Only Road Edges, Lines, and Lanes in grid map
- for(int j = 0; j < env->entities[i].array_size - 1; j++){
- float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2;
- float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2;
+ for (int i = 0; i < env->num_entities; i++) {
+ if (env->entities[i].type > 3 &&
+ env->entities[i].type < 7) { // NOTE: Only Road Edges, Lines, and Lanes in grid map
+ for (int j = 0; j < env->entities[i].array_size - 1; j++) {
+ float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2;
+ float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2;
int grid_index = getGridIndex(env, x_center, y_center);
add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index);
}
}
}
+ free(cell_entities_insert_index);
}
-void init_neighbor_offsets(Drive* env) {
+void init_neighbor_offsets(Drive *env) {
// Allocate memory for the offsets
- env->neighbor_offsets = (int*)calloc(env->grid_map->vision_range*env->grid_map->vision_range*2, sizeof(int));
+ env->neighbor_offsets = (int *)calloc(env->grid_map->vision_range * env->grid_map->vision_range * 2, sizeof(int));
// neighbor offsets in a spiral pattern
int dx[] = {1, 0, -1, 0};
int dy[] = {0, 1, 0, -1};
- int x = 0; // Current x offset
- int y = 0; // Current y offset
- int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down)
- int steps_to_take = 1; // Number of steps in current direction
- int steps_taken = 0; // Steps taken in current direction
+ int x = 0; // Current x offset
+ int y = 0; // Current y offset
+ int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down)
+ int steps_to_take = 1; // Number of steps in current direction
+ int steps_taken = 0; // Steps taken in current direction
int segments_completed = 0; // Count of direction segments completed
- int total = 0; // Total offsets added
- int max_offsets = env->grid_map->vision_range*env->grid_map->vision_range;
+ int total = 0; // Total offsets added
+ int max_offsets = env->grid_map->vision_range * env->grid_map->vision_range;
// Start at center (0,0)
int curr_idx = 0;
- env->neighbor_offsets[curr_idx++] = 0; // x offset
- env->neighbor_offsets[curr_idx++] = 0; // y offset
+ env->neighbor_offsets[curr_idx++] = 0; // x offset
+ env->neighbor_offsets[curr_idx++] = 0; // y offset
total++;
// Generate spiral pattern
while (total < max_offsets) {
@@ -884,16 +897,17 @@ void init_neighbor_offsets(Drive* env) {
x += dx[dir];
y += dy[dir];
// Only add if within vision range bounds
- if (abs(x) <= env->grid_map->vision_range/2 && abs(y) <= env->grid_map->vision_range/2) {
+ if (abs(x) <= env->grid_map->vision_range / 2 && abs(y) <= env->grid_map->vision_range / 2) {
env->neighbor_offsets[curr_idx++] = x;
env->neighbor_offsets[curr_idx++] = y;
total++;
}
steps_taken++;
// Check if we need to change direction
- if(steps_taken != steps_to_take) continue;
- steps_taken = 0; // Reset steps taken
- dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down)
+ if (steps_taken != steps_to_take)
+ continue;
+ steps_taken = 0; // Reset steps taken
+ dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down)
segments_completed++;
// Increase step length every two direction changes
if (segments_completed % 2 == 0) {
@@ -902,65 +916,64 @@ void init_neighbor_offsets(Drive* env) {
}
}
-void cache_neighbor_offsets(Drive* env){
+void cache_neighbor_offsets(Drive *env) {
int count = 0;
- int cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows;
- env->grid_map->neighbor_cache_entities = (GridMapEntity**)calloc(cell_count, sizeof(GridMapEntity*));
- env->grid_map->neighbor_cache_count = (int*)calloc(cell_count + 1, sizeof(int));
- for(int i = 0; i < cell_count; i++){
- int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates
+ int cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows;
+ env->grid_map->neighbor_cache_entities = (GridMapEntity **)calloc(cell_count, sizeof(GridMapEntity *));
+ env->grid_map->neighbor_cache_count = (int *)calloc(cell_count + 1, sizeof(int));
+ for (int i = 0; i < cell_count; i++) {
+ int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates
int cell_y = i / env->grid_map->grid_cols;
int current_cell_neighbor_count = 0;
- for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){
- int x = cell_x + env->neighbor_offsets[j*2];
- int y = cell_y + env->neighbor_offsets[j*2+1];
- int grid_index = env->grid_map->grid_cols*y + x;
- if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue;
+ for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) {
+ int x = cell_x + env->neighbor_offsets[j * 2];
+ int y = cell_y + env->neighbor_offsets[j * 2 + 1];
+ int grid_index = env->grid_map->grid_cols * y + x;
+ if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows)
+ continue;
int grid_count = env->grid_map->cell_entities_count[grid_index];
current_cell_neighbor_count += grid_count;
}
env->grid_map->neighbor_cache_count[i] = current_cell_neighbor_count;
count += current_cell_neighbor_count;
- if(current_cell_neighbor_count == 0) {
+ if (current_cell_neighbor_count == 0) {
env->grid_map->neighbor_cache_entities[i] = NULL;
continue;
}
- env->grid_map->neighbor_cache_entities[i] = (GridMapEntity*)calloc(current_cell_neighbor_count, sizeof(GridMapEntity));
+ env->grid_map->neighbor_cache_entities[i] =
+ (GridMapEntity *)calloc(current_cell_neighbor_count, sizeof(GridMapEntity));
}
env->grid_map->neighbor_cache_count[cell_count] = count;
- for(int i = 0; i < cell_count; i ++){
- int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates
+ for (int i = 0; i < cell_count; i++) {
+ int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates
int cell_y = i / env->grid_map->grid_cols;
int base_index = 0;
- for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){
- int x = cell_x + env->neighbor_offsets[j*2];
- int y = cell_y + env->neighbor_offsets[j*2+1];
- int grid_index = env->grid_map->grid_cols*y + x;
- if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue;
+ for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) {
+ int x = cell_x + env->neighbor_offsets[j * 2];
+ int y = cell_y + env->neighbor_offsets[j * 2 + 1];
+ int grid_index = env->grid_map->grid_cols * y + x;
+ if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows)
+ continue;
int grid_count = env->grid_map->cell_entities_count[grid_index];
// Skip if no entities or source is NULL
- if(grid_count == 0 || env->grid_map->cells[grid_index] == NULL) {
+ if (grid_count == 0 || env->grid_map->cells[grid_index] == NULL) {
continue;
}
int src_idx = grid_index;
int dst_idx = base_index;
// Copy grid_count pairs (entity_idx, geometry_idx) at once
- memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx],
- env->grid_map->cells[src_idx],
- grid_count * sizeof(GridMapEntity));
- // for(int k = 0; k < grid_count; k++){
- // env->grid_map->neighbor_cache_entities[i][dst_idx + k] = env->grid_map->cells[src_idx][k];
- // }
+ memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx], env->grid_map->cells[src_idx],
+ grid_count * sizeof(GridMapEntity));
base_index += grid_count;
}
}
}
-int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entities, int max_entities) {
- GridMap* grid_map = env->grid_map;
+int get_neighbor_cache_entities(Drive *env, int cell_idx, GridMapEntity *entities, int max_entities) {
+ GridMap *grid_map = env->grid_map;
if (cell_idx < 0 || cell_idx >= (grid_map->grid_cols * grid_map->grid_rows)) {
return 0; // Invalid cell index
}
@@ -974,14 +987,15 @@ int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entitie
return count;
}
-void set_means(Drive* env) {
+void set_means(Drive *env) {
float mean_x = 0.0f;
float mean_y = 0.0f;
int64_t point_count = 0;
// Compute single mean for all entities (vehicles and roads)
for (int i = 0; i < env->num_entities; i++) {
- if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) {
+ if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN ||
+ env->entities[i].type == CYCLIST) {
for (int j = 0; j < env->entities[i].array_size; j++) {
// Assume a validity flag exists (e.g., valid[j]); adjust if not available
if (env->entities[i].traj_valid[j]) { // Add validity check if applicable
@@ -1001,9 +1015,11 @@ void set_means(Drive* env) {
env->world_mean_x = mean_x;
env->world_mean_y = mean_y;
for (int i = 0; i < env->num_entities; i++) {
- if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST || env->entities[i].type >= 4) {
+ if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN ||
+ env->entities[i].type == CYCLIST || env->entities[i].type >= 4) {
for (int j = 0; j < env->entities[i].array_size; j++) {
- if(env->entities[i].traj_x[j] == INVALID_POSITION) continue;
+ if (env->entities[i].traj_x[j] == INVALID_POSITION)
+ continue;
env->entities[i].traj_x[j] -= mean_x;
env->entities[i].traj_y[j] -= mean_y;
}
@@ -1011,11 +1027,10 @@ void set_means(Drive* env) {
env->entities[i].goal_position_y -= mean_y;
}
}
-
}
-void move_expert(Drive* env, float* actions, int agent_idx){
- Entity* agent = &env->entities[agent_idx];
+void move_expert(Drive *env, float *actions, int agent_idx) {
+ Entity *agent = &env->entities[agent_idx];
int t = env->timestep;
if (t < 0 || t >= agent->array_size) {
agent->x = INVALID_POSITION;
@@ -1058,7 +1073,8 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2])
float cross = dx1 * dy2 - dy1 * dx2;
// If lines are parallel
- if (cross == 0) return false;
+ if (cross == 0)
+ return false;
// Calculate relative vectors between start points
float dx3 = p1[0] - q1[0];
@@ -1072,10 +1088,12 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2])
return (s >= 0 && s <= 1 && t >= 0 && t <= 1);
}
-int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int max_size, const int (*local_offsets)[2], int offset_size) {
+int checkNeighbors(Drive *env, float x, float y, GridMapEntity *entity_list, int max_size,
+ const int (*local_offsets)[2], int offset_size) {
// Get the grid index for the given position (x, y)
int index = getGridIndex(env, x, y);
- if (index == -1) return 0; // Return 0 size if position invalid
+ if (index == -1)
+ return 0; // Return 0 size if position invalid
// Calculate 2D grid coordinates
int cellsX = env->grid_map->grid_cols;
int gridX = index % cellsX;
@@ -1086,7 +1104,8 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int
int nx = gridX + local_offsets[i][0];
int ny = gridY + local_offsets[i][1];
// Ensure the neighbor is within grid bounds
- if(nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) continue;
+ if (nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows)
+ continue;
int neighborIndex = ny * env->grid_map->grid_cols + nx;
int count = env->grid_map->cell_entities_count[neighborIndex];
// Add entities from this cell to the list
@@ -1101,7 +1120,7 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int
return entity_list_count;
}
-int check_aabb_collision(Entity* car1, Entity* car2) {
+int check_aabb_collision(Entity *car1, Entity *car2) {
// Get car corners in world space
float cos1 = car1->heading_x;
float sin1 = car1->heading_y;
@@ -1119,79 +1138,83 @@ int check_aabb_collision(Entity* car1, Entity* car2) {
{car1->x + (half_len1 * cos1 - half_width1 * sin1), car1->y + (half_len1 * sin1 + half_width1 * cos1)},
{car1->x + (half_len1 * cos1 + half_width1 * sin1), car1->y + (half_len1 * sin1 - half_width1 * cos1)},
{car1->x + (-half_len1 * cos1 - half_width1 * sin1), car1->y + (-half_len1 * sin1 + half_width1 * cos1)},
- {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)}
- };
+ {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)}};
// Calculate car2's corners in world space
float car2_corners[4][2] = {
{car2->x + (half_len2 * cos2 - half_width2 * sin2), car2->y + (half_len2 * sin2 + half_width2 * cos2)},
{car2->x + (half_len2 * cos2 + half_width2 * sin2), car2->y + (half_len2 * sin2 - half_width2 * cos2)},
{car2->x + (-half_len2 * cos2 - half_width2 * sin2), car2->y + (-half_len2 * sin2 + half_width2 * cos2)},
- {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)}
- };
+ {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)}};
// Get the axes to check (normalized vectors perpendicular to each edge)
float axes[4][2] = {
- {cos1, sin1}, // Car1's length axis
- {-sin1, cos1}, // Car1's width axis
- {cos2, sin2}, // Car2's length axis
- {-sin2, cos2} // Car2's width axis
+ {cos1, sin1}, // Car1's length axis
+ {-sin1, cos1}, // Car1's width axis
+ {cos2, sin2}, // Car2's length axis
+ {-sin2, cos2} // Car2's width axis
};
// Check each axis
- for(int i = 0; i < 4; i++) {
+ for (int i = 0; i < 4; i++) {
float min1 = INFINITY, max1 = -INFINITY;
float min2 = INFINITY, max2 = -INFINITY;
// Project car1's corners onto the axis
- for(int j = 0; j < 4; j++) {
+ for (int j = 0; j < 4; j++) {
float proj = car1_corners[j][0] * axes[i][0] + car1_corners[j][1] * axes[i][1];
min1 = fminf(min1, proj);
max1 = fmaxf(max1, proj);
}
// Project car2's corners onto the axis
- for(int j = 0; j < 4; j++) {
+ for (int j = 0; j < 4; j++) {
float proj = car2_corners[j][0] * axes[i][0] + car2_corners[j][1] * axes[i][1];
min2 = fminf(min2, proj);
max2 = fmaxf(max2, proj);
}
// If there's a gap on this axis, the boxes don't intersect
- if(max1 < min2 || min1 > max2) {
- return 0; // No collision
+ if (max1 < min2 || min1 > max2) {
+ return 0; // No collision
}
}
// If we get here, there's no separating axis, so the boxes intersect
- return 1; // Collision
+ return 1; // Collision
}
-int collision_check(Drive* env, int agent_idx) {
- Entity* agent = &env->entities[agent_idx];
+int collision_check(Drive *env, int agent_idx) {
+ Entity *agent = &env->entities[agent_idx];
- if(agent->x == INVALID_POSITION ) return -1;
+ if (agent->x == INVALID_POSITION)
+ return -1;
int car_collided_with_index = -1;
- if (agent->respawn_timestep != -1) return car_collided_with_index; // Skip respawning entities
+ if (agent->respawn_timestep != -1)
+ return car_collided_with_index; // Skip respawning entities
- for(int i = 0; i < MAX_AGENTS; i++){
+ for (int i = 0; i < MAX_AGENTS; i++) {
int index = -1;
- if(i < env->active_agent_count){
+ if (i < env->active_agent_count) {
index = env->active_agent_indices[i];
- } else if (i < env->num_actors){
+ } else if (i < env->num_actors) {
index = env->static_agent_indices[i - env->active_agent_count];
}
- if(index == -1) continue;
- if(index == agent_idx) continue;
- Entity* entity = &env->entities[index];
- if (entity->respawn_timestep != -1) continue; // Skip respawning entities
+ if (index == -1)
+ continue;
+ if (index == agent_idx)
+ continue;
+ Entity *entity = &env->entities[index];
+ if (entity->respawn_timestep != -1)
+ continue; // Skip respawning entities
float x1 = entity->x;
float y1 = entity->y;
- float dist = ((x1 - agent->x)*(x1 - agent->x) + (y1 - agent->y)*(y1 - agent->y));
- if(dist > 225.0f) continue;
- if(check_aabb_collision(agent, entity)) {
+ float dist = ((x1 - agent->x) * (x1 - agent->x) + (y1 - agent->y) * (y1 - agent->y));
+ if (dist > 225.0f)
+ continue;
+ if (check_aabb_collision(agent, entity)) {
car_collided_with_index = index;
break;
}
@@ -1200,13 +1223,127 @@ int collision_check(Drive* env, int agent_idx) {
return car_collided_with_index;
}
-int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) {
+// ============================================================================
+// Lane Alignment Helper Functions (from GIGAFLOW / PufferDrive 3.0)
+// ============================================================================
+
+// Normalize heading to [-π, π]
+float normalize_heading(float h) {
+ while (h > M_PI)
+ h -= 2.0f * M_PI;
+ while (h < -M_PI)
+ h += 2.0f * M_PI;
+ return h;
+}
+
+// Compute signed heading difference between agent and lane, normalized to [-π, π]
+float compute_heading_diff(float agent_heading, float lane_heading) {
+ float diff = normalize_heading(agent_heading - lane_heading);
+ return diff;
+}
+
+// Find closest segment on lane polyline, return signed lateral distance
+// Sign = cross(segment dir, point offset): positive is left of the lane if y points up, right if y points down
+float find_closest_segment_on_lane(Entity *lane, float px, float py, int *segment_idx) {
+ if (!lane || lane->array_size < 2) {
+ *segment_idx = 0;
+ return LANE_DISTANCE_NORMALIZATION;
+ }
+
+ float min_dist = LANE_DISTANCE_NORMALIZATION * 10.0f;
+ int best_idx = 0;
+ float best_signed_dist = 0.0f;
+
+ for (int i = 0; i < lane->array_size - 1; i++) {
+ float x1 = lane->traj_x[i];
+ float y1 = lane->traj_y[i];
+ float x2 = lane->traj_x[i + 1];
+ float y2 = lane->traj_y[i + 1];
+
+ float dx = x2 - x1;
+ float dy = y2 - y1;
+ float seg_len_sq = dx * dx + dy * dy;
+
+ if (seg_len_sq < 1e-6f)
+ continue;
+
+ // Project point onto segment
+ float t = ((px - x1) * dx + (py - y1) * dy) / seg_len_sq;
+ t = fmaxf(0.0f, fminf(1.0f, t));
+
+ float closest_x = x1 + t * dx;
+ float closest_y = y1 + t * dy;
+
+ float dist = sqrtf((px - closest_x) * (px - closest_x) + (py - closest_y) * (py - closest_y));
+
+ if (dist < min_dist) {
+ min_dist = dist;
+ best_idx = i;
+
+ // Compute signed distance (cross product gives sign)
+ // cross >= 0: point is left of the lane direction if y points up (right if y points down)
+ float cross = dx * (py - y1) - dy * (px - x1);
+ best_signed_dist = (cross >= 0) ? dist : -dist;
+ }
+ }
+
+ *segment_idx = best_idx;
+ return best_signed_dist;
+}
+
+// Compute lane heading using weighted average of neighboring segments
+float compute_multi_segment_alignment(Entity *lane, int segment_idx) {
+ if (!lane || lane->array_size < 2)
+ return 0.0f;
+
+ // Clamp to valid range
+ if (segment_idx < 0)
+ segment_idx = 0;
+ if (segment_idx >= lane->array_size - 1)
+ segment_idx = lane->array_size - 2;
+
+ float sum_heading = 0.0f;
+ float sum_weight = 0.0f;
+
+ // Consider current segment and neighbors
+ for (int offset = -1; offset <= 1; offset++) {
+ int idx = segment_idx + offset;
+ if (idx < 0 || idx >= lane->array_size - 1)
+ continue;
+
+ float dx = lane->traj_x[idx + 1] - lane->traj_x[idx];
+ float dy = lane->traj_y[idx + 1] - lane->traj_y[idx];
+ float heading = atan2f(dy, dx);
+
+ // Weight by distance from main segment (center segment has highest weight)
+ float weight = (offset == 0) ? 2.0f : 1.0f;
+
+ // Handle angle wrapping for averaging
+ if (sum_weight > 0) {
+ float diff = heading - (sum_heading / sum_weight);
+ if (diff > M_PI)
+ heading -= 2.0f * M_PI;
+ else if (diff < -M_PI)
+ heading += 2.0f * M_PI;
+ }
+
+ sum_heading += weight * heading;
+ sum_weight += weight;
+ }
+
+ return (sum_weight > 0) ? normalize_heading(sum_heading / sum_weight) : 0.0f;
+}
+
+int check_lane_aligned(Entity *car, Entity *lane, int geometry_idx) {
// Validate lane geometry length
- if (!lane || lane->array_size < 2) return 0;
+ if (!lane || lane->array_size < 2)
+ return 0;
// Clamp geometry index to valid segment range [0, array_size-2]
- if (geometry_idx < 0) geometry_idx = 0;
- if (geometry_idx >= lane->array_size - 1) geometry_idx = lane->array_size - 2;
+ if (geometry_idx < 0)
+ geometry_idx = 0;
+ if (geometry_idx >= lane->array_size - 1)
+ geometry_idx = lane->array_size - 2;
// Compute local lane segment heading
float heading_x1, heading_y1;
@@ -1227,26 +1364,31 @@ int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) {
float heading = (heading_1 + heading_2) / 2.0f;
// Normalize to [-pi, pi]
- if (heading > M_PI) heading -= 2.0f * M_PI;
- if (heading < -M_PI) heading += 2.0f * M_PI;
+ if (heading > M_PI)
+ heading -= 2.0f * M_PI;
+ if (heading < -M_PI)
+ heading += 2.0f * M_PI;
// Compute heading difference
float car_heading = car->heading; // radians
float heading_diff = fabsf(car_heading - heading);
- if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff;
+ if (heading_diff > M_PI)
+ heading_diff = 2.0f * M_PI - heading_diff;
// within 15 degrees
return (heading_diff < (M_PI / 12.0f)) ? 1 : 0;
}
-void reset_agent_metrics(Drive* env, int agent_idx){
- Entity* agent = &env->entities[agent_idx];
- agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision
- agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
- agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
- agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f;
+void reset_agent_metrics(Drive *env, int agent_idx) {
+ Entity *agent = &env->entities[agent_idx];
+ agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision
+ agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
+ agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
+ agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // far from lane
+ agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // no alignment
agent->collision_state = 0;
+ agent->current_lane_geometry_idx = -1;
}
float point_to_segment_distance_2d(float px, float py, float x1, float y1, float x2, float y2) {
@@ -1262,8 +1404,10 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float
float t = ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy);
// Clamp t to the segment
- if (t < 0) t = 0;
- else if (t > 1) t = 1;
+ if (t < 0)
+ t = 0;
+ else if (t > 1)
+ t = 1;
// Find the closest point on the segment
float closestX = x1 + t * dx;
@@ -1273,28 +1417,17 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float
return sqrtf((px - closestX) * (px - closestX) + (py - closestY) * (py - closestY));
}
-void compute_agent_metrics(Drive* env, int agent_idx) {
- Entity* agent = &env->entities[agent_idx];
+void compute_agent_metrics(Drive *env, int agent_idx) {
+ Entity *agent = &env->entities[agent_idx];
reset_agent_metrics(env, agent_idx);
- if(agent->x == INVALID_POSITION ) return; // invalid agent position
-
- // Compute displacement error
- float displacement_error = compute_displacement_error(agent, env->timestep);
-
- if (displacement_error > 0.0f) { // Only count valid displacements
- agent->cumulative_displacement += displacement_error;
- agent->displacement_sample_count++;
-
- // Compute running average
- agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] =
- agent->cumulative_displacement / agent->displacement_sample_count;
- }
+ if (agent->x == INVALID_POSITION)
+ return; // invalid agent position
int collided = 0;
- float half_length = agent->length/2.0f;
- float half_width = agent->width/2.0f;
+ float half_length = agent->length / 2.0f;
+ float half_width = agent->width / 2.0f;
float cos_heading = cosf(agent->heading);
float sin_heading = sinf(agent->heading);
float min_distance = (float)INT16_MAX;
@@ -1304,20 +1437,30 @@ void compute_agent_metrics(Drive* env, int agent_idx) {
float corners[4][2];
for (int i = 0; i < 4; i++) {
- corners[i][0] = agent->x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading);
- corners[i][1] = agent->y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading);
+ corners[i][0] =
+ agent->x + (offsets[i][0] * half_length * cos_heading - offsets[i][1] * half_width * sin_heading);
+ corners[i][1] =
+ agent->y + (offsets[i][0] * half_length * sin_heading + offsets[i][1] * half_width * cos_heading);
}
- GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; // Array big enough for all neighboring cells
- int list_size = checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25);
- for (int i = 0; i < list_size ; i++) {
- if(entity_list[i].entity_idx == -1) continue;
- if(entity_list[i].entity_idx == agent_idx) continue;
- Entity* entity;
+ GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25]; // Array big enough for all neighboring cells
+ int list_size =
+ checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL * 25, collision_offsets, 25);
+
+ // Track checked lanes to avoid duplicate processing (reset before loop)
+ int checked_lanes[MAX_CHECKED_LANES];
+ int num_checked = 0;
+
+ for (int i = 0; i < list_size; i++) {
+ if (entity_list[i].entity_idx == -1)
+ continue;
+ if (entity_list[i].entity_idx == agent_idx)
+ continue;
+ Entity *entity;
entity = &env->entities[entity_list[i].entity_idx];
// Check for offroad collision with road edges
- if(entity->type == ROAD_EDGE) {
+ if (entity->type == ROAD_EDGE) {
int geometry_idx = entity_list[i].geometry_idx;
float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]};
float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]};
@@ -1330,90 +1473,158 @@ void compute_agent_metrics(Drive* env, int agent_idx) {
}
}
- if (collided == OFFROAD) break;
+ if (collided == OFFROAD)
+ break;
- // Find closest point on the road centerline to the agent
- if(entity->type == ROAD_LANE) {
+ // Find closest lane using GIGAFLOW-style scoring (distance + heading)
+ if (entity->type == ROAD_LANE) {
int entity_idx = entity_list[i].entity_idx;
- int geometry_idx = entity_list[i].geometry_idx;
- float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]};
- float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]};
+ // Skip if already checked this lane
+ int already_checked = 0;
+ for (int c = 0; c < num_checked; c++) {
+ if (checked_lanes[c] == entity_idx) {
+ already_checked = 1;
+ break;
+ }
+ }
+ if (already_checked)
+ continue;
+ if (num_checked < MAX_CHECKED_LANES)
+ checked_lanes[num_checked++] = entity_idx;
+
+ // Find closest segment on this lane (returns signed distance)
+ int segment_idx;
+ float signed_dist = find_closest_segment_on_lane(entity, agent->x, agent->y, &segment_idx);
+ float abs_dist = fabsf(signed_dist);
+ if (abs_dist > LANE_DISTANCE_NORMALIZATION)
+ continue;
+
+ // Compute lane heading using multi-segment alignment
+ float lane_heading = compute_multi_segment_alignment(entity, segment_idx);
+
+ // Compute heading alignment penalty (0.0 = perfect, 1.0 = opposite)
+ float heading_diff = fabsf(compute_heading_diff(agent->heading, lane_heading));
+ float heading_penalty = heading_diff / M_PI;
- float dist = point_to_segment_distance_2d(agent->x, agent->y, start[0], start[1], end[0], end[1]);
- float heading_diff = fabsf(atan2f(end[1]-start[1], end[0]-start[0]) - agent->heading);
+ // Normalize distance for scoring
+ float distance_penalty = abs_dist / LANE_DISTANCE_NORMALIZATION;
- // Normalize heading difference to [0, pi]
- if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff;
+ // Combined score using defined weights
+ float score =
+ LANE_SELECTION_DISTANCE_WEIGHT * distance_penalty + LANE_SELECTION_HEADING_WEIGHT * heading_penalty;
- // Penalize if heading differs by more than 30 degrees
- if (heading_diff > (M_PI / 6.0f)) dist += 3.0f;
+ // Hysteresis: penalize switching away from current lane
+ if (agent->current_lane_idx != entity_idx && agent->current_lane_idx != -1) {
+ score += LANE_SWITCH_THRESHOLD;
+ }
- if (dist < min_distance) {
- min_distance = dist;
+ // Track best candidate
+ if (score < min_distance) { // Using min_distance to store best score
+ min_distance = score;
closest_lane_entity_idx = entity_idx;
- closest_lane_geometry_idx = geometry_idx;
+ closest_lane_geometry_idx = segment_idx;
}
}
}
- // check if aligned with closest lane and set current lane
- // 4.0m threshold: agents more than 4 meters from any lane are considered off-road
- if (min_distance > 4.0f || closest_lane_entity_idx == -1) {
- agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f;
- agent->current_lane_idx = -1;
- } else {
+ // Update lane metrics using GIGAFLOW Frenet coordinates
+ if (closest_lane_entity_idx != -1) {
+ Entity *lane = &env->entities[closest_lane_entity_idx];
+
+ // Recompute signed distance and heading for best lane
+ int segment_idx;
+ float signed_dist = find_closest_segment_on_lane(lane, agent->x, agent->y, &segment_idx);
+ float lane_heading = compute_multi_segment_alignment(lane, segment_idx);
+ float theta_f = compute_heading_diff(agent->heading, lane_heading);
+
+ // Store lane metrics
agent->current_lane_idx = closest_lane_entity_idx;
+ agent->current_lane_geometry_idx = closest_lane_geometry_idx;
+ agent->metrics_array[LANE_DIST_IDX] = signed_dist;
+ agent->metrics_array[LANE_ANGLE_IDX] = cosf(theta_f);
- int lane_aligned = check_lane_aligned(agent, &env->entities[closest_lane_entity_idx], closest_lane_geometry_idx);
+ // Binary lane aligned flag (within ~15 degrees)
+ int lane_aligned = (fabsf(agent->metrics_array[LANE_ANGLE_IDX]) > LANE_ALIGN_COS_THRESHOLD) ? 1 : 0;
agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned;
+ } else {
+ // Not on any lane
+ agent->current_lane_idx = -1;
+ agent->current_lane_geometry_idx = -1;
+ agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION;
+ agent->metrics_array[LANE_ANGLE_IDX] = 0.0f;
+ agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f;
}
// Check for vehicle collisions
int car_collided_with_index = collision_check(env, agent_idx);
- if (car_collided_with_index != -1) collided = VEHICLE_COLLISION;
+ if (car_collided_with_index != -1)
+ collided = VEHICLE_COLLISION;
agent->collision_state = collided;
+ if (collided == VEHICLE_COLLISION) {
+ if (env->collision_behavior == STOP_AGENT && !agent->stopped) {
+ agent->stopped = 1;
+ agent->vx = agent->vy = 0.0f;
+ } else if (env->collision_behavior == REMOVE_AGENT && !agent->removed) {
+ Entity *agent_collided = &env->entities[car_collided_with_index];
+ agent->removed = 1;
+ agent_collided->removed = 1;
+ agent->x = agent->y = -10000.0f;
+ agent_collided->x = agent_collided->y = -10000.0f;
+ }
+ }
+ if (collided == OFFROAD) {
+ agent->metrics_array[OFFROAD_IDX] = 1.0f;
+ if (env->offroad_behavior == STOP_AGENT && !agent->stopped) {
+ agent->stopped = 1;
+ agent->vx = agent->vy = 0.0f;
+ } else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) {
+ agent->removed = 1;
+ agent->x = agent->y = -10000.0f;
+ }
+ }
+
return;
}
-bool should_control_agent(Drive* env, int agent_idx){
-
+bool should_control_agent(Drive *env, int agent_idx) {
// Check if we have room for more agents or are already at capacity
if (env->active_agent_count >= env->num_agents) {
return false;
}
- Entity* entity = &env->entities[agent_idx];
+ Entity *entity = &env->entities[agent_idx];
- // Shrink agent size for collision checking
- entity->width *= 0.7f; // TODO: Move this somewhere else
+ // TODO: Move this elsewhere or remove
+ entity->width *= 0.7f;
entity->length *= 0.7f;
if (env->control_mode == CONTROL_SDC_ONLY) {
- return (agent_idx == env->sdc_track_index);
+ return agent_idx == env->sdc_track_index;
}
- // Special mode: control only agents in prediction track list
- if (env->control_mode == CONTROL_TRACKS_TO_PREDICT) {
- for (int j = 0; j < env->num_tracks_to_predict; j++) {
- if (env->tracks_to_predict_indices[j] == agent_idx) {
- return true;
- }
- }
- return false;
- }
+ bool is_vehicle = (entity->type == VEHICLE);
+ bool is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST);
+ bool type_is_valid = false;
+
+ switch (env->control_mode) {
+ case CONTROL_WOSAC:
+ // Valid types only, ignore expert flag and goal distance
+ return (is_vehicle || is_ped_or_bike);
+
+ case CONTROL_VEHICLES:
+ type_is_valid = is_vehicle;
+ break;
- // Standard mode: check type, distance to goal, and expert status
- bool type_is_controllable = false;
- if (env->control_mode == CONTROL_VEHICLES) {
- type_is_controllable = (entity->type == VEHICLE);
- } else { // CONTROL_AGENTS mode
- type_is_controllable = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST);
+ default:
+ type_is_valid = (is_vehicle || is_ped_or_bike);
+ break;
}
- if (!type_is_controllable || entity->mark_as_expert) {
+ // Filter invalid types or experts
+ if (!type_is_valid || entity->mark_as_expert) {
return false;
}
@@ -1431,26 +1642,24 @@ bool should_control_agent(Drive* env, int agent_idx){
return distance_to_goal >= MIN_DISTANCE_TO_GOAL;
}
-void set_active_agents(Drive* env){
-
+void set_active_agents(Drive *env) {
// Initialize
- env->active_agent_count = 0; // Policy-controlled agents
- env->static_agent_count = 0; // Non-moving background agents
+ env->active_agent_count = 0; // Policy-controlled agents
+ env->static_agent_count = 0; // Non-moving background agents
env->expert_static_agent_count = 0; // Expert replay agents (non-controlled)
- env->num_actors = 0; // Total agents created
+ env->num_actors = 0; // Total agents created
int active_agent_indices[MAX_AGENTS];
int static_agent_indices[MAX_AGENTS];
int expert_static_agent_indices[MAX_AGENTS];
- if(env->num_agents == 0){
+ if (env->num_agents == 0) {
env->num_agents = MAX_AGENTS;
}
-
// Iterate through entities to find agents to create and/or control
- for(int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++){
+ for (int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++) {
- Entity* entity = &env->entities[i];
+ Entity *entity = &env->entities[i];
// Skip if not valid at initialization
if (entity->traj_valid[env->init_steps] != 1) {
@@ -1460,14 +1669,15 @@ void set_active_agents(Drive* env){
// Determine if entity should be created
bool should_create = false;
if (env->init_mode == INIT_ALL_VALID) {
- should_create = true; // All valid entities
+ should_create = true; // All valid entities
} else if (env->control_mode == CONTROL_VEHICLES) {
should_create = (entity->type == VEHICLE);
- } else { // Control all agents
+ } else { // Control all agents
should_create = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST);
}
- if (!should_create) continue;
+ if (!should_create)
+ continue;
env->num_actors++;
@@ -1475,12 +1685,13 @@ void set_active_agents(Drive* env){
bool is_controlled = false;
is_controlled = should_control_agent(env, i);
- if (is_controlled && env->active_agent_count >= env->max_controlled_agents && env->max_controlled_agents != -1) {
+ if (is_controlled && env->active_agent_count >= env->max_controlled_agents &&
+ env->max_controlled_agents != -1) {
is_controlled = false;
entity->mark_as_expert = 1;
}
- if(is_controlled){
+ if (is_controlled) {
active_agent_indices[env->active_agent_count] = i;
env->active_agent_count++;
env->entities[i].active_agent = 1;
@@ -1488,7 +1699,7 @@ void set_active_agents(Drive* env){
static_agent_indices[env->static_agent_count] = i;
env->static_agent_count++;
env->entities[i].active_agent = 0;
- if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) {
+ if (env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) {
expert_static_agent_indices[env->expert_static_agent_count] = i;
env->expert_static_agent_count++;
env->entities[i].mark_as_expert = 1;
@@ -1497,24 +1708,24 @@ void set_active_agents(Drive* env){
}
// Set up initial active agents
- env->active_agent_indices = (int*)malloc(env->active_agent_count * sizeof(int));
- env->static_agent_indices = (int*)malloc(env->static_agent_count * sizeof(int));
- env->expert_static_agent_indices = (int*)malloc(env->expert_static_agent_count * sizeof(int));
- for(int i=0;i<env->active_agent_count;i++){
+ env->active_agent_indices = (int *)malloc(env->active_agent_count * sizeof(int));
+ env->static_agent_indices = (int *)malloc(env->static_agent_count * sizeof(int));
+ env->expert_static_agent_indices = (int *)malloc(env->expert_static_agent_count * sizeof(int));
+ for (int i = 0; i < env->active_agent_count; i++) {
env->active_agent_indices[i] = active_agent_indices[i];
}
- for(int i=0;i<env->static_agent_count;i++){
+ for (int i = 0; i < env->static_agent_count; i++) {
env->static_agent_indices[i] = static_agent_indices[i];
}
- for(int i=0;i<env->expert_static_agent_count;i++){
+ for (int i = 0; i < env->expert_static_agent_count; i++) {
env->expert_static_agent_indices[i] = expert_static_agent_indices[i];
}
return;
}
-void remove_bad_trajectories(Drive* env){
+void remove_bad_trajectories(Drive *env) {
- if (env->control_mode != CONTROL_TRACKS_TO_PREDICT) {
+ if (env->control_mode != CONTROL_WOSAC) {
return; // Leave all trajectories in WOSAC control mode
}
@@ -1526,22 +1737,23 @@ void remove_bad_trajectories(Drive* env){
collided_with_indices[i] = -1;
}
// move experts through trajectories to check for collisions and remove as illegal agents
- for(int t = 0; t < env->scenario_length; t++){
- for(int i = 0; i < env->active_agent_count; i++){
+ for (int t = 0; t < env->scenario_length; t++) {
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
move_expert(env, env->actions, agent_idx);
}
- for(int i = 0; i < env->expert_static_agent_count; i++){
+ for (int i = 0; i < env->expert_static_agent_count; i++) {
int expert_idx = env->expert_static_agent_indices[i];
- if(env->entities[expert_idx].x == INVALID_POSITION) continue;
+ if (env->entities[expert_idx].x == INVALID_POSITION)
+ continue;
move_expert(env, env->actions, expert_idx);
}
// check collisions
- for(int i = 0; i < env->active_agent_count; i++){
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
env->entities[agent_idx].collision_state = 0;
int collided_with_index = collision_check(env, agent_idx);
- if((collided_with_index >= 0) && collided_agents[i] == 0){
+ if ((collided_with_index >= 0) && collided_agents[i] == 0) {
collided_agents[i] = 1;
collided_with_indices[i] = collided_with_index;
}
@@ -1549,11 +1761,13 @@ void remove_bad_trajectories(Drive* env){
env->timestep++;
}
- for(int i = 0; i< env->active_agent_count; i++){
- if(collided_with_indices[i] == -1) continue;
- for(int j = 0; j < env->static_agent_count; j++){
+ for (int i = 0; i < env->active_agent_count; i++) {
+ if (collided_with_indices[i] == -1)
+ continue;
+ for (int j = 0; j < env->static_agent_count; j++) {
int static_agent_idx = env->static_agent_indices[j];
- if(static_agent_idx != collided_with_indices[i]) continue;
+ if (static_agent_idx != collided_with_indices[i])
+ continue;
env->entities[static_agent_idx].traj_x[0] = INVALID_POSITION;
env->entities[static_agent_idx].traj_y[0] = INVALID_POSITION;
}
@@ -1561,15 +1775,15 @@ void remove_bad_trajectories(Drive* env){
env->timestep = 0;
}
-void init_goal_positions(Drive* env){
- for(int x = 0;x<env->active_agent_count; x++){
+void init_goal_positions(Drive *env) {
+ for (int x = 0; x < env->active_agent_count; x++) {
int agent_idx = env->active_agent_indices[x];
env->entities[agent_idx].init_goal_x = env->entities[agent_idx].goal_position_x;
env->entities[agent_idx].init_goal_y = env->entities[agent_idx].goal_position_y;
}
}
-void assign_ego_and_coplayer_roles(Drive* env) {
+void assign_ego_and_coplayer_roles(Drive *env) {
if (!env->population_play || env->num_ego_agents == 0) {
for (int i = 0; i < env->num_entities; i++) {
if (!env->entities[i].mark_as_expert) {
@@ -1605,17 +1819,13 @@ void assign_ego_and_coplayer_roles(Drive* env) {
}
}
-
-
-
-void init(Drive* env){
+void init(Drive *env) {
env->human_agent_idx = 0;
env->timestep = 0;
env->entities = load_map_binary(env->map_name, env);
set_means(env);
init_grid_map(env);
- if (env->goal_behavior==GOAL_GENERATE_NEW) init_topology_graph(env);
- env->grid_map->vision_range = 21;
+ env->grid_map->vision_range = 21; // TODO: Why is this hardcoded?
init_neighbor_offsets(env);
cache_neighbor_offsets(env);
env->logs_capacity = 0;
@@ -1625,34 +1835,32 @@ void init(Drive* env){
set_start_position(env);
init_goal_positions(env);
assign_ego_and_coplayer_roles(env);
- env->logs = (Log*)calloc(env->active_agent_count, sizeof(Log));
+ env->logs = (Log *)calloc(env->active_agent_count, sizeof(Log));
// Always allocate weight arrays for consistency
- env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+ env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float));
if (env->population_play) {
-
if (env->co_player_logs) {
free(env->co_player_logs);
env->co_player_logs = NULL;
}
if (env->active_agent_count > 0) {
- env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log));
+ env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log));
} else {
env->co_player_logs = NULL;
}
- memset(&env->co_player_log, 0, sizeof(Co_Player_Log));
+ memset(&env->co_player_log, 0, sizeof(Log));
}
if (env->population_play) {
-
if (env->co_player_logs) {
free(env->co_player_logs);
env->co_player_logs = NULL;
@@ -1661,38 +1869,64 @@ void init(Drive* env){
// Always allocate for all active agents, not just co-players
// because we index by x which goes from 0 to active_agent_count-1
if (env->active_agent_count > 0) {
- env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log));
+ env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log));
} else {
env->co_player_logs = NULL;
}
- memset(&env->co_player_log, 0, sizeof(Co_Player_Log));
+ memset(&env->co_player_log, 0, sizeof(Log));
+ }
+}
+
+void close_client(Client *client);
+
+// Render-mode helper: stash env->client across vec_close + re-vectorize so the
+// raylib window + ffmpeg pipe survive a map swap. Needed because raylib's
+// CloseWindow → InitWindow cycle segfaults on LoadModel under our xvfb-headless
+// setup (stale GL state). We keep a single global slot — the renderer is
+// always single-env, single-window. Multi-env render would need an array.
+static Client *g_donated_client = NULL;
+
+void c_donate_client(Drive *env) {
+ if (env->client != NULL) {
+ g_donated_client = env->client;
+ env->client = NULL; // c_close now no-ops the client teardown.
}
}
+void c_adopt_client(Drive *env) {
+ if (g_donated_client != NULL) {
+ env->client = g_donated_client;
+ g_donated_client = NULL;
+ }
+}
-void c_close(Drive* env){
+void c_close(Drive *env) {
+ if (env->client != NULL) {
+ close_client(env->client);
+ env->client = NULL;
+ }
if (env->population_play && env->co_player_logs != NULL) {
free(env->co_player_logs);
free(env->co_player_ids);
free(env->ego_agent_ids);
}
- for(int i = 0; i < env->num_entities; i++){
+ for (int i = 0; i < env->num_entities; i++) {
free_entity(&env->entities[i]);
}
free(env->entities);
free(env->active_agent_indices);
free(env->logs);
// GridMap cleanup
- int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows;
- for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){
+ int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows;
+ for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) {
free(env->grid_map->cells[grid_index]);
}
free(env->grid_map->cells);
free(env->grid_map->cell_entities_count);
free(env->neighbor_offsets);
- for(int i = 0; i < grid_cell_count; i++){
+ for (int i = 0; i < grid_cell_count; i++) {
free(env->grid_map->neighbor_cache_entities[i]);
}
free(env->grid_map->neighbor_cache_entities);
@@ -1700,34 +1934,30 @@ void c_close(Drive* env){
free(env->grid_map);
free(env->static_agent_indices);
free(env->expert_static_agent_indices);
- freeTopologyGraph(env->topology_graph);
- // free(env->map_name);
free(env->ini_file);
-
}
-void allocate(Drive* env){
+void allocate(Drive *env) {
init(env);
- int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7;
+ int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC;
int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0);
int ego_dim = base_ego_dim + conditioning_dims;
// Always allocate weight arrays for consistency
- env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float));
- env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float));
-
- int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS;
- env->observations = (float*)calloc(env->active_agent_count*max_obs, sizeof(float));
- env->actions = (float*)calloc(env->active_agent_count*2, sizeof(float));
- env->rewards = (float*)calloc(env->active_agent_count, sizeof(float));
- env->terminals= (unsigned char*)calloc(env->active_agent_count, sizeof(unsigned char));
-
+ env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float));
+
+ int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS;
+ env->observations = (float *)calloc(env->active_agent_count * max_obs, sizeof(float));
+ env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float));
+ env->rewards = (float *)calloc(env->active_agent_count, sizeof(float));
+ env->terminals = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char));
}
-void free_allocated(Drive* env){
+void free_allocated(Drive *env) {
free(env->observations);
free(env->actions);
free(env->rewards);
@@ -1745,24 +1975,21 @@ void free_allocated(Drive* env){
float clipSpeed(float speed) {
const float maxSpeed = MAX_SPEED;
- if (speed > maxSpeed) return maxSpeed;
- if (speed < -maxSpeed) return -maxSpeed;
+ if (speed > maxSpeed)
+ return maxSpeed;
+ if (speed < -maxSpeed)
+ return -maxSpeed;
return speed;
}
-float normalize_heading(float heading){
- if(heading > M_PI) heading -= 2*M_PI;
- if(heading < -M_PI) heading += 2*M_PI;
- return heading;
-}
+// normalize_heading is defined earlier in this file (around line 1222)
-float normalize_value(float value, float min, float max){
- return (value - min) / (max - min);
-}
+float normalize_value(float value, float min, float max) { return (value - min) / (max - min); }
-void move_dynamics(Drive* env, int action_idx, int agent_idx){
- Entity* agent = &env->entities[agent_idx];
- if (agent->removed) return;
+void move_dynamics(Drive *env, int action_idx, int agent_idx) {
+ Entity *agent = &env->entities[agent_idx];
+ if (agent->removed)
+ return;
if (agent->stopped) {
agent->vx = 0.0f;
@@ -1776,7 +2003,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
float steering = 0.0f;
if (env->action_type == 1) { // continuous
- float (*action_array_f)[2] = (float(*)[2])env->actions;
+ float (*action_array_f)[2] = (float (*)[2])env->actions;
acceleration = action_array_f[action_idx][0];
steering = action_array_f[action_idx][1];
@@ -1784,8 +2011,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
steering *= STEERING_VALUES[12];
} else { // discrete
// Interpret action as a single integer: a = accel_idx * num_steer + steer_idx
- int* action_array = (int*)env->actions;
- int num_accel = sizeof(ACCELERATION_VALUES) / sizeof(ACCELERATION_VALUES[0]);
+ int *action_array = (int *)env->actions;
int num_steer = sizeof(STEERING_VALUES) / sizeof(STEERING_VALUES[0]);
int action_val = action_array[action_idx];
int acceleration_index = action_val / num_steer;
@@ -1801,27 +2027,28 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
float vx = agent->vx;
float vy = agent->vy;
- // Calculate current speed
- float speed = sqrtf(vx*vx + vy*vy);
+ // Calculate current speed (signed based on direction relative to heading)
+ float speed_magnitude = sqrtf(vx * vx + vy * vy);
+ float v_dot_heading = vx * agent->heading_x + vy * agent->heading_y;
+ float signed_speed = copysignf(speed_magnitude, v_dot_heading);
// Update speed with acceleration
- speed = speed + acceleration*env->dt;
- speed = clipSpeed(speed);
-
+ signed_speed = signed_speed + acceleration * env->dt;
+ signed_speed = clipSpeed(signed_speed);
// Compute yaw rate
- float beta = tanh(.5*tanf(steering));
+ float beta = tanh(.5 * tanf(steering));
// New heading
- float yaw_rate = (speed*cosf(beta)*tanf(steering)) / agent->length;
+ float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->length;
// New velocity
- float new_vx = speed*cosf(heading + beta);
- float new_vy = speed*sinf(heading + beta);
+ float new_vx = signed_speed * cosf(heading + beta);
+ float new_vy = signed_speed * sinf(heading + beta);
// Update position
- x = x + (new_vx*env->dt);
- y = y + (new_vy*env->dt);
- heading = heading + yaw_rate*env->dt;
+ x = x + (new_vx * env->dt);
+ y = y + (new_vy * env->dt);
+ heading = heading + yaw_rate * env->dt;
// Apply updates to the agent's state
agent->x = x;
@@ -1836,23 +2063,26 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
// Extract action components
float a_long, a_lat;
if (env->action_type == 1) { // continuous
- float (*action_array_f)[2] = (float(*)[2])env->actions;
+ float (*action_array_f)[2] = (float (*)[2])env->actions;
// Asymmetric scaling for longitudinal jerk to match discrete action space
// Discrete: JERK_LONG = [-15, -4, 0, 4] (more braking than acceleration)
- float a_long_action = action_array_f[action_idx][0]; // [-1, 1]
+ float a_long_action = action_array_f[action_idx][0]; // [-1, 1]
if (a_long_action < 0) {
- a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking)
+ a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking)
} else {
- a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration)
+ a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration)
}
// Symmetric scaling for lateral jerk
a_lat = action_array_f[action_idx][1] * JERK_LAT[2];
} else { // discrete
- int (*action_array)[2] = (int(*)[2])env->actions;
- int a_long_idx = action_array[action_idx][0];
- int a_lat_idx = action_array[action_idx][1];
+ // Interpret action as a single integer: a = long_idx * num_lat + lat_idx
+ int *action_array = (int *)env->actions;
+ int num_lat = sizeof(JERK_LAT) / sizeof(JERK_LAT[0]);
+ int action_val = action_array[action_idx];
+ int a_long_idx = action_val / num_lat;
+ int a_lat_idx = action_val % num_lat;
a_long = JERK_LONG[a_long_idx];
a_lat = JERK_LAT[a_lat_idx];
}
@@ -1876,7 +2106,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
// Calculate new velocity
float v_dot_heading = agent->vx * agent->heading_x + agent->vy * agent->heading_y;
- float signed_v = copysignf(sqrtf(agent->vx*agent->vx + agent->vy*agent->vy), v_dot_heading);
+ float signed_v = copysignf(sqrtf(agent->vx * agent->vx + agent->vy * agent->vy), v_dot_heading);
float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt;
// Make it easy to stop with 0 vel
@@ -1931,10 +2161,23 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){
return;
}
-void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out) {
- for(int i = 0; i < env->active_agent_count; i++){
+static inline int get_track_id_or_placeholder(Drive *env, int agent_idx) {
+ if (env->tracks_to_predict_indices == NULL || env->num_tracks_to_predict == 0) {
+ return -1;
+ }
+ for (int k = 0; k < env->num_tracks_to_predict; k++) {
+ if (env->tracks_to_predict_indices[k] == agent_idx) {
+ return env->tracks_to_predict_indices[k];
+ }
+ }
+ return -1;
+}
+
+void c_get_global_agent_state(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, int *id_out,
+ float *length_out, float *width_out) {
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
- Entity* agent = &env->entities[agent_idx];
+ Entity *agent = &env->entities[agent_idx];
// For WOSAC, we need the original world coordinates, so we add the world means back
x_out[i] = agent->x + env->world_mean_x;
@@ -1945,14 +2188,15 @@ void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_o
}
}
-void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* valid_out, int* id_out, int* scenario_id_out) {
- for(int i = 0; i < env->active_agent_count; i++){
+void c_get_global_ground_truth_trajectories(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out,
+ int *valid_out, int *id_out, int *scenario_id_out) {
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
- Entity* agent = &env->entities[agent_idx];
+ Entity *agent = &env->entities[agent_idx];
id_out[i] = env->tracks_to_predict_indices[i];
scenario_id_out[i] = agent->scenario_id;
- for(int t = env->init_steps; t < agent->array_size; t++){
+ for (int t = env->init_steps; t < agent->array_size; t++) {
int out_idx = i * (agent->array_size - env->init_steps) + (t - env->init_steps);
// Add world means back to get original world coordinates
x_out[out_idx] = agent->traj_x[t] + env->world_mean_x;
@@ -1964,48 +2208,87 @@ void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_o
}
}
-void compute_observations(Drive* env) {
- int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7;
+void c_get_road_edge_counts(Drive *env, int *num_polylines_out, int *total_points_out) {
+ int count = 0, points = 0;
+ for (int i = env->num_objects; i < env->num_entities; i++) {
+ if (env->entities[i].type == ROAD_EDGE) {
+ count++;
+ points += env->entities[i].array_size;
+ }
+ }
+ *num_polylines_out = count;
+ *total_points_out = points;
+}
+
+void c_get_road_edge_polylines(Drive *env, float *x_out, float *y_out, int *lengths_out, int *scenario_ids_out) {
+ int poly_idx = 0, pt_idx = 0;
+ for (int i = env->num_objects; i < env->num_entities; i++) {
+ Entity *e = &env->entities[i];
+ if (e->type == ROAD_EDGE) {
+ lengths_out[poly_idx] = e->array_size;
+ scenario_ids_out[poly_idx] = e->scenario_id;
+ for (int j = 0; j < e->array_size; j++) {
+ x_out[pt_idx] = e->traj_x[j] + env->world_mean_x;
+ y_out[pt_idx] = e->traj_y[j] + env->world_mean_y;
+ pt_idx++;
+ }
+ poly_idx++;
+ }
+ }
+}
+
+void compute_observations(Drive *env) {
+ int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC;
int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0);
int ego_dim = base_ego_dim + conditioning_dims;
- int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS;
+ int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS;
- memset(env->observations, 0, max_obs*env->active_agent_count*sizeof(float));
- float (*observations)[max_obs] = (float(*)[max_obs])env->observations;
+ memset(env->observations, 0, max_obs * env->active_agent_count * sizeof(float));
+ float (*observations)[max_obs] = (float (*)[max_obs])env->observations;
- for(int i = 0; i < env->active_agent_count; i++) {
- float* obs = &observations[i][0];
- Entity* ego_entity = &env->entities[env->active_agent_indices[i]];
- if(ego_entity->type > 3) break;
+ for (int i = 0; i < env->active_agent_count; i++) {
+ float *obs = &observations[i][0];
+ Entity *ego_entity = &env->entities[env->active_agent_indices[i]];
+ if (ego_entity->type > 3)
+ break;
float cos_heading = ego_entity->heading_x;
float sin_heading = ego_entity->heading_y;
- float ego_speed = sqrtf(ego_entity->vx*ego_entity->vx + ego_entity->vy*ego_entity->vy);
+ float speed_magnitude = sqrtf(ego_entity->vx * ego_entity->vx + ego_entity->vy * ego_entity->vy);
+ float v_dot_heading = ego_entity->vx * ego_entity->heading_x + ego_entity->vy * ego_entity->heading_y;
+ float signed_speed = copysignf(speed_magnitude, v_dot_heading);
// Set goal distances
float goal_x = ego_entity->goal_position_x - ego_entity->x;
float goal_y = ego_entity->goal_position_y - ego_entity->y;
// Rotate to ego vehicle's frame
- float rel_goal_x = goal_x*cos_heading + goal_y*sin_heading;
- float rel_goal_y = -goal_x*sin_heading + goal_y*cos_heading;
+ float rel_goal_x = goal_x * cos_heading + goal_y * sin_heading;
+ float rel_goal_y = -goal_x * sin_heading + goal_y * cos_heading;
- obs[0] = rel_goal_x* 0.005f;
- obs[1] = rel_goal_y* 0.005f;
- obs[2] = ego_speed / MAX_SPEED;
+ obs[0] = rel_goal_x * 0.005f;
+ obs[1] = rel_goal_y * 0.005f;
+ obs[2] = signed_speed / MAX_SPEED;
obs[3] = ego_entity->width / MAX_VEH_WIDTH;
obs[4] = ego_entity->length / MAX_VEH_LEN;
obs[5] = (ego_entity->collision_state > 0) ? 1.0f : 0.0f;
obs[6] = (ego_entity->respawn_timestep != -1) ? 1 : 0;
+ // Lane alignment observations (GIGAFLOW Frenet coordinates)
+ float lane_center_dist = ego_entity->metrics_array[LANE_DIST_IDX] / LANE_DISTANCE_NORMALIZATION;
+ lane_center_dist = fmaxf(-1.0f, fminf(1.0f, lane_center_dist)); // Clamp to [-1, 1]
+ obs[7] = lane_center_dist;
+ obs[8] = ego_entity->metrics_array[LANE_ANGLE_IDX]; // cos(theta_f), already in [-1, 1]
+
if (env->dynamics_model == JERK) {
- obs[7] = ego_entity->steering_angle / M_PI;
+ obs[9] = ego_entity->steering_angle / M_PI;
// Asymmetric normalization for a_long to match action space
- obs[8] = (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3];
- obs[9] = ego_entity->a_lat / JERK_LAT[2];
+ obs[10] =
+ (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3];
+ obs[11] = ego_entity->a_lat / JERK_LAT[2];
}
- int obs_idx = (env->dynamics_model == JERK) ? 10 : 7;
+ int obs_idx = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC;
// Add conditioning weights to observations
if (env->use_rc) {
obs[obs_idx++] = env->collision_weights[i];
@@ -2021,86 +2304,97 @@ void compute_observations(Drive* env) {
// Relative Pos of other cars
int cars_seen = 0;
- for(int j = 0; j < MAX_AGENTS; j++) {
+ for (int j = 0; j < MAX_AGENTS; j++) {
int index = -1;
- if(j < env->active_agent_count){
+ if (j < env->active_agent_count) {
index = env->active_agent_indices[j];
- } else if (j < env->num_actors){
+ } else if (j < env->num_actors) {
index = env->static_agent_indices[j - env->active_agent_count];
}
- if(index == -1) continue;
- if(env->entities[index].type > 3) break;
- if(index == env->active_agent_indices[i]) continue; // Skip self, but don't increment obs_idx
- Entity* other_entity = &env->entities[index];
- if(ego_entity->respawn_timestep != -1) continue;
- if(other_entity->respawn_timestep != -1) continue;
+ if (index == -1)
+ continue;
+ if (env->entities[index].type > 3)
+ break;
+ if (index == env->active_agent_indices[i])
+ continue; // Skip self, but don't increment obs_idx
+ Entity *other_entity = &env->entities[index];
+ if (ego_entity->respawn_timestep != -1)
+ continue;
+ if (other_entity->respawn_timestep != -1)
+ continue;
// Store original relative positions
float dx = other_entity->x - ego_entity->x;
float dy = other_entity->y - ego_entity->y;
- float dist = (dx*dx + dy*dy);
- if(dist > 2500.0f) continue;
+ float dist = (dx * dx + dy * dy);
+ if (dist > 2500.0f)
+ continue;
// Rotate to ego vehicle's frame
- float rel_x = dx*cos_heading + dy*sin_heading;
- float rel_y = -dx*sin_heading + dy*cos_heading;
+ float rel_x = dx * cos_heading + dy * sin_heading;
+ float rel_y = -dx * sin_heading + dy * cos_heading;
// Store observations with correct indexing
obs[obs_idx] = rel_x * 0.02f;
- // Add conditioning weights to observations
+ // Add conditioning weights to observations
obs[obs_idx + 1] = rel_y * 0.02f;
obs[obs_idx + 2] = other_entity->width / MAX_VEH_WIDTH;
obs[obs_idx + 3] = other_entity->length / MAX_VEH_LEN;
// relative heading
- float rel_heading_x = other_entity->heading_x * ego_entity->heading_x +
- other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b)
- float rel_heading_y = other_entity->heading_y * ego_entity->heading_x -
- other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
+ float rel_heading_x =
+ other_entity->heading_x * ego_entity->heading_x +
+ other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b)
+ float rel_heading_y =
+ other_entity->heading_y * ego_entity->heading_x -
+ other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
obs[obs_idx + 4] = rel_heading_x;
obs[obs_idx + 5] = rel_heading_y;
- // obs[obs_idx + 4] = cosf(rel_heading) / MAX_ORIENTATION_RAD;
- // obs[obs_idx + 5] = sinf(rel_heading) / MAX_ORIENTATION_RAD;
- // // relative speed
- float other_speed = sqrtf(other_entity->vx*other_entity->vx + other_entity->vy*other_entity->vy);
- obs[obs_idx + 6] = other_speed / MAX_SPEED;
+
+ // relative speed
+ float other_speed_magnitude =
+ sqrtf(other_entity->vx * other_entity->vx + other_entity->vy * other_entity->vy);
+ float other_v_dot_heading =
+ other_entity->vx * other_entity->heading_x + other_entity->vy * other_entity->heading_y;
+ float other_signed_speed = copysignf(other_speed_magnitude, other_v_dot_heading);
+ obs[obs_idx + 6] = other_signed_speed / MAX_SPEED;
cars_seen++;
- obs_idx += 7; // Move to next observation slot
+ obs_idx += 7; // Move to next observation slot
}
int remaining_partner_obs = (MAX_AGENTS - 1 - cars_seen) * 7;
memset(&obs[obs_idx], 0, remaining_partner_obs * sizeof(float));
obs_idx += remaining_partner_obs;
// map observations
- GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25];
+ GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25];
int grid_idx = getGridIndex(env, ego_entity->x, ego_entity->y);
int list_size = get_neighbor_cache_entities(env, grid_idx, entity_list, MAX_ROAD_SEGMENT_OBSERVATIONS);
- for(int k = 0; k < list_size; k++) {
+ for (int k = 0; k < list_size; k++) {
int entity_idx = entity_list[k].entity_idx;
int geometry_idx = entity_list[k].geometry_idx;
// Validate entity_idx before accessing
- if(entity_idx < 0 || entity_idx >= env->num_entities) {
- printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities-1);
+ if (entity_idx < 0 || entity_idx >= env->num_entities) {
+ printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities - 1);
continue;
}
- Entity* entity = &env->entities[entity_idx];
+ Entity *entity = &env->entities[entity_idx];
// Validate geometry_idx before accessing
- if(geometry_idx < 0 || geometry_idx >= entity->array_size) {
- printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n",
- geometry_idx, entity_idx, entity->array_size-1);
+ if (geometry_idx < 0 || geometry_idx >= entity->array_size) {
+ printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", geometry_idx, entity_idx,
+ entity->array_size - 1);
continue;
}
float start_x = entity->traj_x[geometry_idx];
float start_y = entity->traj_y[geometry_idx];
- float end_x = entity->traj_x[geometry_idx+1];
- float end_y = entity->traj_y[geometry_idx+1];
+ float end_x = entity->traj_x[geometry_idx + 1];
+ float end_y = entity->traj_y[geometry_idx + 1];
float mid_x = (start_x + end_x) / 2.0f;
float mid_y = (start_y + end_y) / 2.0f;
float rel_x = mid_x - ego_entity->x;
float rel_y = mid_y - ego_entity->y;
- float x_obs = rel_x*cos_heading + rel_y*sin_heading;
- float y_obs = -rel_x*sin_heading + rel_y*cos_heading;
+ float x_obs = rel_x * cos_heading + rel_y * sin_heading;
+ float y_obs = -rel_x * sin_heading + rel_y * cos_heading;
float length = relative_distance_2d(mid_x, mid_y, end_x, end_y);
float width = 0.1;
// Calculate angle from ego to midpoint (vector from ego to midpoint)
@@ -2108,14 +2402,14 @@ void compute_observations(Drive* env) {
float dy = end_y - mid_y;
float dx_norm = dx;
float dy_norm = dy;
- float hypot = sqrtf(dx*dx + dy*dy);
- if(hypot > 0) {
+ float hypot = sqrtf(dx * dx + dy * dy);
+ if (hypot > 0) {
dx_norm /= hypot;
dy_norm /= hypot;
}
// Compute sin and cos of relative angle directly without atan2f
- float cos_angle = dx_norm*cos_heading + dy_norm*sin_heading;
- float sin_angle = -dx_norm*sin_heading + dy_norm*cos_heading;
+ float cos_angle = dx_norm * cos_heading + dy_norm * sin_heading;
+ float sin_angle = -dx_norm * sin_heading + dy_norm * cos_heading;
obs[obs_idx] = x_obs * 0.02f;
obs[obs_idx + 1] = y_obs * 0.02f;
obs[obs_idx + 2] = length / MAX_ROAD_SEGMENT_LENGTH;
@@ -2131,170 +2425,103 @@ void compute_observations(Drive* env) {
}
}
-static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out_segment_idx, float* out_fraction) {
- int best_idx = -1;
- float best_dist_sq = 1e30f;
-
- for (int i = 1; i < lane->array_size; i++) {
- float x0 = lane->traj_x[i - 1];
- float y0 = lane->traj_y[i - 1];
- float x1 = lane->traj_x[i];
- float y1 = lane->traj_y[i];
- float dx = x1 - x0;
- float dy = y1 - y0;
- float seg_len_sq = dx * dx + dy * dy;
- if (seg_len_sq < 1e-6f) continue;
-
- float to_agent_x = agent->x - x0;
- float to_agent_y = agent->y - y0;
- float t = (to_agent_x * dx + to_agent_y * dy) / seg_len_sq;
- if (t < 0.0f) t = 0.0f;
- else if (t > 1.0f) t = 1.0f;
-
- float proj_x = x0 + t * dx;
- float proj_y = y0 + t * dy;
-
- float rel_x = proj_x - agent->x;
- float rel_y = proj_y - agent->y;
- float forward = rel_x * agent->heading_x + rel_y * agent->heading_y;
- if (forward < 0.0f) continue;
-
- float dist_sq = rel_x * rel_x + rel_y * rel_y;
- if (dist_sq < best_dist_sq) {
- best_dist_sq = dist_sq;
- best_idx = i;
- *out_fraction = t;
- }
- }
-
- if (best_idx != -1) {
- *out_segment_idx = best_idx;
- return 1;
- }
+void sample_new_goal(Drive *env, int agent_idx) {
+ // Samples a new goal position based on the existing road lane points
+ Entity *agent = &env->entities[agent_idx];
+ float best_x = agent->x;
+ float best_y = agent->y;
+ float best_distance_error = 1e30f;
- return 0;
-}
+ // Sample points from all road lanes
+ for (int i = env->num_objects; i < env->num_entities; i++) {
+ if (env->entities[i].type != ROAD_LANE)
+ continue;
-void compute_new_goal(Drive* env, int agent_idx) {
- Entity* agent = &env->entities[agent_idx];
- int current_lane = agent->current_lane_idx;
+ Entity *lane = &env->entities[i];
- if (current_lane == -1) return; // No current lane
+ // Check every point in the lane
+ for (int j = 0; j < lane->array_size; j++) {
+ float point_x = lane->traj_x[j];
+ float point_y = lane->traj_y[j];
- // Target distance: 40m ahead along the lane topology from agent's current position
- float target_distance = 40.0f;
- int current_entity = current_lane;
- Entity* lane = &env->entities[current_entity];
+ // Calculate vector from agent to point
+ float to_point_x = point_x - agent->x;
+ float to_point_y = point_y - agent->y;
- int initial_segment_idx = 1;
- float initial_fraction = 0.0f;
- if (!find_forward_projection_on_lane(lane, agent, &initial_segment_idx, &initial_fraction)) {
- int forward_idx = -1;
- for (int i = 0; i < lane->array_size; i++) {
- float to_point_x = lane->traj_x[i] - agent->x;
- float to_point_y = lane->traj_y[i] - agent->y;
+ // Check if point is ahead of agent
float dot = to_point_x * agent->heading_x + to_point_y * agent->heading_y;
- if (dot > 0.0f) {
- forward_idx = i;
- break;
- }
- }
-
- if (forward_idx == -1) {
- agent->goal_position_x = lane->traj_x[lane->array_size - 1];
- agent->goal_position_y = lane->traj_y[lane->array_size - 1];
- agent->sampled_new_goal = 0;
- return;
- }
-
- initial_segment_idx = forward_idx;
- if (initial_segment_idx == 0) initial_segment_idx = 1;
- initial_fraction = 0.0f;
- }
-
- float remaining_distance = target_distance;
- int first_lane = 1;
-
- // Traverse the topology graph starting from the vehicle's position forward
- while (current_entity != -1) {
- lane = &env->entities[current_entity];
-
- int start_idx = first_lane ? initial_segment_idx : 1;
- // Ensure start_idx is at least 1 to avoid accessing traj_x[i-1] with i=0
- if (start_idx < 1) start_idx = 1;
- first_lane = 0;
+ if (dot <= 0.0f)
+ continue;
- for (int i = start_idx; i < lane->array_size; i++) {
- float prev_x = lane->traj_x[i - 1];
- float prev_y = lane->traj_y[i - 1];
- float next_x = lane->traj_x[i];
- float next_y = lane->traj_y[i];
- float seg_dx = next_x - prev_x;
- float seg_dy = next_y - prev_y;
- float segment_length = relative_distance_2d(prev_x, prev_y, next_x, next_y);
+ // Calculate distance to point
+ float distance = sqrtf(to_point_x * to_point_x + to_point_y * to_point_y);
- if (remaining_distance <= segment_length) {
- agent->goal_position_x = next_x;
- agent->goal_position_y = next_y;
- agent->sampled_new_goal = 0;
- return;
+ // Find point closest to target distance
+ float distance_error = fabsf(distance - env->goal_target_distance);
+ if (distance_error < best_distance_error) {
+ best_distance_error = distance_error;
+ best_x = point_x;
+ best_y = point_y;
}
-
- remaining_distance -= segment_length;
- }
-
- int connected_lanes[5];
- int num_connected = getNextLanes(env->topology_graph, current_entity, connected_lanes, 5);
-
- if (num_connected == 0) {
- agent->goal_position_x = lane->traj_x[lane->array_size - 1];
- agent->goal_position_y = lane->traj_y[lane->array_size - 1];
- agent->sampled_new_goal = 0;
- return; // No further lanes to traverse
}
+ }
- int random_idx = agent_idx % num_connected;
- current_entity = connected_lanes[random_idx];
+ // If no valid goal found, use another agent's initial goal
+ if (best_distance_error >= 1e30f && env->active_agent_count > 1) {
+ int other_idx = env->active_agent_indices[(agent_idx + 1) % env->active_agent_count];
+ best_x = env->entities[other_idx].init_goal_x;
+ best_y = env->entities[other_idx].init_goal_y;
}
+
+ agent->goal_position_x = best_x;
+ agent->goal_position_y = best_y;
+ agent->goals_sampled_this_episode += 1;
}
-void c_reset(Drive* env){
+void c_reset(Drive *env) {
env->timestep = env->init_steps;
set_start_position(env);
- // Initialize all conditioning weights even when no conditioning (lb=ub)
- for(int i = 0; i < env->active_agent_count; i++) {
- env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + env->collision_weight_lb;
- env->offroad_weights[i] = ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb;
- env->goal_weights[i] = ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb;
- env->entropy_weights[i] = ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb;
- env->discount_weights[i] = ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb;
- }
-
- for(int x = 0;x<env->active_agent_count; x++){
+ for (int i = 0; i < env->active_agent_count; i++) {
+ env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) +
+ env->collision_weight_lb;
+ env->offroad_weights[i] =
+ ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb;
+ env->goal_weights[i] =
+ ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb;
+ env->entropy_weights[i] =
+ ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb;
+ env->discount_weights[i] =
+ ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb;
+ }
+
+ for (int x = 0; x < env->active_agent_count; x++) {
env->logs[x] = (Log){0};
int agent_idx = env->active_agent_indices[x];
env->entities[agent_idx].respawn_timestep = -1;
env->entities[agent_idx].respawn_count = 0;
env->entities[agent_idx].collided_before_goal = 0;
- env->entities[agent_idx].reached_goal_this_episode = 0;
+ env->entities[agent_idx].goals_reached_this_episode = 0.0f;
+ // Initialize to 1 because there is one goal in the data file
+ env->entities[agent_idx].goals_sampled_this_episode = 1.0f;
+ env->entities[agent_idx].current_goal_reached = 0;
env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f;
env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f;
env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f;
env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f;
- env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f;
- env->entities[agent_idx].cumulative_displacement = 0.0f;
- env->entities[agent_idx].displacement_sample_count = 0;
- env->entities[agent_idx].stopped = 0;
+ env->entities[agent_idx].metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION;
+ env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX] = 0.0f;
+ env->entities[agent_idx].current_lane_idx = -1;
+ env->entities[agent_idx].current_lane_geometry_idx = -1;
+ env->entities[agent_idx].stopped = 0;
env->entities[agent_idx].removed = 0;
- if (env->goal_behavior==GOAL_GENERATE_NEW) {
+ if (env->goal_behavior == GOAL_GENERATE_NEW) {
env->entities[agent_idx].goal_position_x = env->entities[agent_idx].init_goal_x;
env->entities[agent_idx].goal_position_y = env->entities[agent_idx].init_goal_y;
- env->entities[agent_idx].sampled_new_goal = 0;
}
- if (env->population_play){
- env->co_player_logs[x] = (Co_Player_Log){0};
+ if (env->population_play) {
+ env->co_player_logs[x] = (Log){0};
}
compute_agent_metrics(env, agent_idx);
@@ -2302,7 +2529,7 @@ void c_reset(Drive* env){
compute_observations(env);
}
-void respawn_agent(Drive* env, int agent_idx){
+void respawn_agent(Drive *env, int agent_idx) {
env->entities[agent_idx].x = env->entities[agent_idx].traj_x[0];
env->entities[agent_idx].y = env->entities[agent_idx].traj_y[0];
env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[0];
@@ -2314,10 +2541,13 @@ void respawn_agent(Drive* env, int agent_idx){
env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f;
env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f;
env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f;
- env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f;
- env->entities[agent_idx].cumulative_displacement = 0.0f;
- env->entities[agent_idx].displacement_sample_count = 0;
+ env->entities[agent_idx].metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION;
+ env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX] = 0.0f;
+ env->entities[agent_idx].current_lane_idx = -1;
+ env->entities[agent_idx].current_lane_geometry_idx = -1;
+
env->entities[agent_idx].respawn_timestep = env->timestep;
+ env->entities[agent_idx].collided_before_goal = 0;
env->entities[agent_idx].stopped = 0;
env->entities[agent_idx].removed = 0;
env->entities[agent_idx].a_long = 0.0f;
@@ -2327,43 +2557,55 @@ void respawn_agent(Drive* env, int agent_idx){
env->entities[agent_idx].steering_angle = 0.0f;
}
-void c_step(Drive* env){
+void c_step(Drive *env) {
memset(env->rewards, 0, env->active_agent_count * sizeof(float));
memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char));
env->timestep++;
- if(env->timestep == env->scenario_length){
+
+ int originals_remaining = 0;
+ for (int i = 0; i < env->active_agent_count; i++) {
+ int agent_idx = env->active_agent_indices[i];
+ // Keep flag true if there is at least one agent that has not been respawned yet
+ if (env->entities[agent_idx].respawn_count == 0) {
+ originals_remaining = 1;
+ break;
+ }
+ }
+
+ if (env->timestep == env->scenario_length || (!originals_remaining && env->termination_mode == 1)) {
add_log(env);
c_reset(env);
-
return;
}
// Move static experts
for (int i = 0; i < env->expert_static_agent_count; i++) {
int expert_idx = env->expert_static_agent_indices[i];
- if(env->entities[expert_idx].x == INVALID_POSITION) continue;
+ if (env->entities[expert_idx].x == INVALID_POSITION)
+ continue;
move_expert(env, env->actions, expert_idx);
}
// Process actions for all active agents
- for(int i = 0; i < env->active_agent_count; i++){
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
env->entities[agent_idx].collision_state = 0;
+
move_dynamics(env, i, agent_idx);
// Update logs based on agent type - use i directly as log index
- if(env->entities[agent_idx].is_ego){
+ if (env->entities[agent_idx].is_ego) {
env->logs[i].score = 0.0f;
env->logs[i].episode_length += 1;
- } else if(env->entities[agent_idx].is_co_player){
- env->co_player_logs[i].co_player_score = 0.0f;
- env->co_player_logs[i].co_player_episode_length += 1;
+ } else if (env->entities[agent_idx].is_co_player) {
+ env->co_player_logs[i].score = 0.0f;
+ env->co_player_logs[i].episode_length += 1;
}
}
// Compute metrics and rewards
- for(int i = 0; i < env->active_agent_count; i++){
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
env->entities[agent_idx].collision_state = 0;
@@ -2372,143 +2614,155 @@ void c_step(Drive* env){
int collision_state = env->entities[agent_idx].collision_state;
int is_ego = env->entities[agent_idx].is_ego;
int is_co_player = env->entities[agent_idx].is_co_player;
-
- // Handle collisions - SAME REWARD for both ego and co-players
- if(collision_state > 0){
- if(collision_state == VEHICLE_COLLISION){
+ int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX];
+
+ // Handle collisions - SAME REWARD for both ego and co-players.
+ // Skip for agents already stopped (e.g. after reaching the goal in
+ // GOAL_STOP mode): the policy is no longer driving, so being rear-ended
+ // shouldn't count against its collision/offroad metrics.
+ if (collision_state > 0 && !env->entities[agent_idx].stopped) {
+ if (collision_state == VEHICLE_COLLISION) {
env->rewards[i] = env->collision_weights[i];
- if(is_ego){
+ if (is_ego) {
env->logs[i].episode_return += env->collision_weights[i];
env->logs[i].collision_rate = 1.0f;
- env->logs[i].avg_collisions_per_agent += 1.0f;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return += env->collision_weights[i];
- env->co_player_logs[i].co_player_collision_rate = 1.0f;
+ env->logs[i].collisions_per_agent += 1.0f;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return += env->collision_weights[i];
+ env->co_player_logs[i].collision_rate = 1.0f;
+ env->co_player_logs[i].collisions_per_agent += 1.0f;
}
- } else if(collision_state == OFFROAD){
+ } else if (collision_state == OFFROAD) {
env->rewards[i] = env->offroad_weights[i];
- if(is_ego){
+ if (is_ego) {
env->logs[i].episode_return += env->offroad_weights[i];
env->logs[i].offroad_rate = 1.0f;
- env->logs[i].avg_offroad_per_agent += 1.0f; // ADD THIS
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return += env->offroad_weights[i];
- env->co_player_logs[i].co_player_offroad_rate = 1.0f;
- env->logs[i].avg_offroad_per_agent += 1.0f;
- }
+ env->logs[i].offroad_per_agent += 1.0f;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return += env->offroad_weights[i];
+ env->co_player_logs[i].offroad_rate = 1.0f;
+ env->co_player_logs[i].offroad_per_agent += 1.0f;
+ }
}
- if(!env->entities[agent_idx].reached_goal_this_episode){
+ if (!reached_goal) {
env->entities[agent_idx].collided_before_goal = 1;
}
}
- // Handle goal reward - SAME REWARD for both ego and co-players
- float distance_to_goal = relative_distance_2d(
- env->entities[agent_idx].x,
- env->entities[agent_idx].y,
- env->entities[agent_idx].goal_position_x,
- env->entities[agent_idx].goal_position_y
+ // Handle goal reward: the agent must be near the goal and slow enough (speed check)
+ float distance_to_goal =
+ relative_distance_2d(env->entities[agent_idx].x, env->entities[agent_idx].y,
+ env->entities[agent_idx].goal_position_x, env->entities[agent_idx].goal_position_y);
+
+ float current_speed = sqrtf(env->entities[agent_idx].vx * env->entities[agent_idx].vx +
+ env->entities[agent_idx].vy * env->entities[agent_idx].vy);
- );
+ // Reward the agent if it is within goal_radius of the goal and its speed is at most goal_speed
+ bool within_distance = distance_to_goal < env->goal_radius;
+ bool within_speed = current_speed <= env->goal_speed;
- if(distance_to_goal < env->goal_radius){
- if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1){
+ if (within_distance && within_speed && !env->entities[agent_idx].current_goal_reached) {
+ if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1) {
float scaled_post_respawn_reward = env->reward_goal_post_respawn * env->goal_weights[i];
env->rewards[i] += scaled_post_respawn_reward;
- if(is_ego){
+
+ if (is_ego) {
env->logs[i].episode_return += scaled_post_respawn_reward;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return += scaled_post_respawn_reward;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return += scaled_post_respawn_reward;
}
- } else if (env->goal_behavior == GOAL_GENERATE_NEW) {
+ env->entities[agent_idx].current_goal_reached = 1;
+ } else if (env->goal_behavior == GOAL_GENERATE_NEW && (!env->entities[agent_idx].current_goal_reached)) {
env->rewards[i] += env->goal_weights[i];
- env->entities[agent_idx].sampled_new_goal = 1;
- if(is_ego){
+
+ if (is_ego) {
env->logs[i].episode_return += env->goal_weights[i];
- env->logs[i].num_goals_reached += 1;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return += env->goal_weights[i];
- env->co_player_logs[i].co_player_num_goals_reached += 1;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return += env->goal_weights[i];
}
+
+ sample_new_goal(env, agent_idx);
+ env->entities[agent_idx].current_goal_reached = 0;
+ env->entities[agent_idx].goals_reached_this_episode += 1.0f;
} else { // Zero out the velocity so that the agent stops at the goal
env->rewards[i] = env->goal_weights[i];
- if(is_ego){
+
+ if (is_ego) {
env->logs[i].episode_return = env->goal_weights[i];
- env->logs[i].num_goals_reached = 1;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return = env->goal_weights[i];
- env->co_player_logs[i].co_player_num_goals_reached = 1;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return = env->goal_weights[i];
}
+
env->entities[agent_idx].stopped = 1;
- env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f;
+ env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f;
+ env->entities[agent_idx].goals_reached_this_episode += 1.0f;
}
- env->entities[agent_idx].reached_goal_this_episode = 1;
env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f;
- }
- if(env->entities[agent_idx].sampled_new_goal && env->goal_behavior == GOAL_GENERATE_NEW){
- compute_new_goal(env, agent_idx);
+ if (is_ego) {
+ env->logs[i].speed_at_goal = current_speed;
+ } else if (is_co_player) {
+ env->co_player_logs[i].speed_at_goal = current_speed;
+ }
}
- int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX];
- if(is_ego){
- env->logs[i].lane_alignment_rate = lane_aligned;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_lane_alignment_rate = lane_aligned;
- }
+ // Lane alignment reward (GIGAFLOW formula)
+ // Only apply if reward_lane_align > 0 (disabled by default)
+ if (env->reward_lane_align > 0.0f) {
+ float cos_theta = env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX];
+ float theta_f = acosf(fminf(fmaxf(cos_theta, -1.0f), 1.0f)); // Get |θ_f| from cos
+
+ // GIGAFLOW Rl-align: min(cos,0) + vel_align*min(cos*v,0) + 0.0025*(1-|θ|/(π/2))
+ float against_lane_penalty = fminf(cos_theta, 0.0f); // Negative when >90° off
+ float vel_aligned_penalty = env->reward_vel_align * fminf(cos_theta * current_speed, 0.0f);
+ float alignment_bonus = 0.0025f * (1.0f - theta_f / (M_PI / 2.0f));
- float current_ade = env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX];
- if(current_ade > 0.0f && env->reward_ade != 0.0f){
- float ade_reward = env->reward_ade * current_ade;
- env->rewards[i] += ade_reward;
+ float lane_align_reward =
+ env->reward_lane_align * env->dt * (against_lane_penalty + vel_aligned_penalty + alignment_bonus);
- if(is_ego){
- env->logs[i].episode_return += ade_reward;
- env->logs[i].avg_displacement_error = current_ade;
- } else if(is_co_player){
- env->co_player_logs[i].co_player_episode_return += ade_reward;
- env->co_player_logs[i].co_player_avg_displacement_error = current_ade;
+ env->rewards[i] += lane_align_reward;
+
+ if (is_ego) {
+ env->logs[i].episode_return += lane_align_reward;
+ } else if (is_co_player) {
+ env->co_player_logs[i].episode_return += lane_align_reward;
}
}
+
+ int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX];
+ if (is_ego) {
+ env->logs[i].lane_alignment_rate = lane_aligned;
+ } else if (is_co_player) {
+ env->co_player_logs[i].lane_alignment_rate = lane_aligned;
+ }
}
- if (env->goal_behavior==GOAL_RESPAWN) {
- for(int i = 0; i < env->active_agent_count; i++){
+ if (env->goal_behavior == GOAL_RESPAWN) {
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX];
- if(reached_goal){
+ if (reached_goal) {
respawn_agent(env, agent_idx);
env->entities[agent_idx].respawn_count++;
}
}
- }
- else if (env->goal_behavior==GOAL_STOP) {
- for(int i = 0; i < env->active_agent_count; i++){
+ } else if (env->goal_behavior == GOAL_STOP) {
+ for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX];
- if(reached_goal){
+ if (reached_goal) {
env->entities[agent_idx].stopped = 1;
- env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f;
+ env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f;
}
}
}
-
compute_observations(env);
}
-
-const Color STONE_GRAY = (Color){80, 80, 80, 255};
-const Color PUFF_RED = (Color){187, 0, 0, 255};
-const Color PUFF_CYAN = (Color){0, 187, 187, 255};
-const Color PUFF_WHITE = (Color){241, 241, 241, 241};
-const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255};
-const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255};
-const Color LIGHTGREEN = (Color){152, 255, 152, 255};
-
typedef struct Client Client;
struct Client {
float width;
@@ -2518,18 +2772,118 @@ struct Client {
float camera_zoom;
Camera3D camera;
Model cars[6];
- int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle
+ Model cyclist;
+ Model pedestrian;
+ ModelAnimation *cycle_anim;
+ int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle
Vector3 default_camera_position;
Vector3 default_camera_target;
+ // Video recording state (for headless rendering)
+ int recorder_pipefd[2]; // Pipe to ffmpeg process
+ pid_t recorder_pid; // PID of ffmpeg process
+ // Original map dimensions (for consistent rendering across scenarios)
+ float original_map_width;
+ float original_map_height;
};
-Client* make_client(Drive* env){
- Client* client = (Client*)calloc(1, sizeof(Client));
- client->width = 1280;
- client->height = 704;
- SetConfigFlags(FLAG_MSAA_4X_HINT);
- InitWindow(client->width, client->height, "PufferLib Ray GPU Drive");
- SetTargetFPS(30);
+// Stop the running ffmpeg recorder (if any) and wait for it to finish writing.
+static void stop_video_recorder(Client *client) {
+ if (client->recorder_pipefd[1] >= 0) {
+ close(client->recorder_pipefd[1]);
+ client->recorder_pipefd[1] = -1;
+ }
+ if (client->recorder_pipefd[0] >= 0) {
+ close(client->recorder_pipefd[0]);
+ client->recorder_pipefd[0] = -1;
+ }
+ if (client->recorder_pid > 0) {
+ int status;
+ waitpid(client->recorder_pid, &status, 0);
+ client->recorder_pid = 0;
+ }
+}
+
+// Start (or restart) the ffmpeg recorder. Stops any in-flight recorder first
+// so callers can switch output files mid-life without recreating raylib state.
+static void start_video_recorder(Client *client, const char *basename) {
+ stop_video_recorder(client);
+
+ if (pipe(client->recorder_pipefd) == -1) {
+ fprintf(stderr, "Failed to create pipe for video recording\n");
+ client->recorder_pipefd[0] = -1;
+ client->recorder_pipefd[1] = -1;
+ return;
+ }
+
+ char size_str[64];
+ snprintf(size_str, sizeof(size_str), "%dx%d", (int)client->width, (int)client->height);
+
+ char filename[320];
+ snprintf(filename, sizeof(filename), "%s.mp4", basename && basename[0] ? basename : "render");
+
+ client->recorder_pid = fork();
+ if (client->recorder_pid == -1) {
+ fprintf(stderr, "Failed to fork ffmpeg process\n");
+ close(client->recorder_pipefd[0]);
+ close(client->recorder_pipefd[1]);
+ client->recorder_pipefd[0] = -1;
+ client->recorder_pipefd[1] = -1;
+ client->recorder_pid = 0;
+ return;
+ }
+
+ if (client->recorder_pid == 0) {
+ // Child: run ffmpeg
+ close(client->recorder_pipefd[1]);
+ dup2(client->recorder_pipefd[0], STDIN_FILENO);
+ close(client->recorder_pipefd[0]);
+ for (int fd = 3; fd < 256; fd++) {
+ close(fd);
+ }
+ execlp("ffmpeg", "ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", size_str, "-r", "30", "-i", "-",
+ "-c:v", "libx264", "-threads", "4", "-pix_fmt", "yuv420p", "-preset", "ultrafast", "-crf", "23",
+ "-loglevel", "error", filename, NULL);
+ fprintf(stderr, "Failed to exec ffmpeg\n");
+ _exit(1);
+ }
+
+ close(client->recorder_pipefd[0]); // parent: keep write end only
+ client->recorder_pipefd[0] = -1;
+ fprintf(stderr, "[drive] ffmpeg forked: pid=%d file=%s size=%s\n", (int)client->recorder_pid, filename, size_str);
+}
+
+Client *make_client(Drive *env) {
+ Client *client = (Client *)calloc(1, sizeof(Client));
+ client->recorder_pid = 0;
+ client->recorder_pipefd[0] = -1;
+ client->recorder_pipefd[1] = -1;
+
+ if (env->render_mode == RENDER_WINDOW) {
+ // Interactive window mode
+ client->width = 1280;
+ client->height = 704;
+ SetConfigFlags(FLAG_MSAA_4X_HINT);
+ InitWindow(client->width, client->height, "PufferDrive");
+ SetTargetFPS(30);
+ } else if (env->render_mode == RENDER_HEADLESS) {
+ // Headless rendering mode - hidden window with ffmpeg pipe
+ float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x;
+ float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y;
+ float scale = 6.0f;
+ client->width = (int)roundf(map_width * scale / 2.0f) * 2;
+ client->height = (int)roundf(map_height * scale / 2.0f) * 2;
+ client->original_map_width = map_width;
+ client->original_map_height = map_height;
+
+ SetConfigFlags(FLAG_WINDOW_HIDDEN);
+ SetConfigFlags(FLAG_MSAA_4X_HINT);
+ InitWindow(client->width, client->height, "PufferDrive Headless");
+ SetTargetFPS(6000);
+
+ start_video_recorder(client, env->video_basename);
+ }
+
+ // Load textures and models (for both window and headless modes)
client->puffers = LoadTexture("resources/puffers_128.png");
client->cars[0] = LoadModel("resources/drive/RedCar.glb");
client->cars[1] = LoadModel("resources/drive/WhiteCar.glb");
@@ -2537,26 +2891,21 @@ Client* make_client(Drive* env){
client->cars[3] = LoadModel("resources/drive/YellowCar.glb");
client->cars[4] = LoadModel("resources/drive/GreenCar.glb");
client->cars[5] = LoadModel("resources/drive/GreyCar.glb");
+ client->cyclist = LoadModel("resources/drive/cyclist.glb");
+ client->pedestrian = LoadModel("resources/drive/pedestrian.glb");
+ int animCountCyc = 0;
+ client->cycle_anim = LoadModelAnimations("resources/drive/cyclist.glb", &animCountCyc);
for (int i = 0; i < MAX_AGENTS; i++) {
client->car_assignments[i] = (rand() % 4) + 1;
}
- // Get initial target position from first active agent
- Vector3 target_pos = {
- 0,
- 0, // Y is up
- 1 // Z is depth
- };
- // Set up camera to look at target from above and behind
- client->default_camera_position = (Vector3){
- 0, // Same X as target
- 120.0f, // 20 units above target
- 175.0f // 20 units behind target
- };
+ // Set up camera defaults
+ Vector3 target_pos = {0, 0, 1};
+ client->default_camera_position = (Vector3){0, 120.0f, 175.0f};
client->default_camera_target = target_pos;
client->camera.position = client->default_camera_position;
client->camera.target = client->default_camera_target;
- client->camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; // Y is up
+ client->camera.up = (Vector3){0.0f, -1.0f, 0.0f};
client->camera.fovy = 45.0f;
client->camera.projection = CAMERA_PERSPECTIVE;
client->camera_zoom = 1.0f;
@@ -2564,7 +2913,7 @@ Client* make_client(Drive* env){
}
// Camera control functions
-void handle_camera_controls(Client* client) {
+void handle_camera_controls(Client *client) {
static Vector2 prev_mouse_pos = {0};
static bool is_dragging = false;
float camera_move_speed = 0.5f;
@@ -2581,10 +2930,8 @@ void handle_camera_controls(Client* client) {
if (is_dragging) {
Vector2 current_mouse_pos = GetMousePosition();
- Vector2 delta = {
- (current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed,
- -(current_mouse_pos.y - prev_mouse_pos.y) * camera_move_speed
- };
+ Vector2 delta = {(current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed,
+ -(current_mouse_pos.y - prev_mouse_pos.y) * camera_move_speed};
// Update camera position (only X and Y)
client->camera.position.x += delta.x;
@@ -2602,11 +2949,9 @@ void handle_camera_controls(Client* client) {
if (wheel != 0) {
float zoom_factor = 1.0f - (wheel * 0.1f);
// Calculate the current direction vector from target to position
- Vector3 direction = {
- client->camera.position.x - client->camera.target.x,
- client->camera.position.y - client->camera.target.y,
- client->camera.position.z - client->camera.target.z
- };
+ Vector3 direction = {client->camera.position.x - client->camera.target.x,
+ client->camera.position.y - client->camera.target.y,
+ client->camera.position.z - client->camera.target.z};
// Scale the direction vector by the zoom factor
direction.x *= zoom_factor;
@@ -2620,28 +2965,27 @@ void handle_camera_controls(Client* client) {
}
}
-void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int lasers){
+void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int lasers) {
// Diamond dimensions
- float diamond_height = 3.0f; // Total height of diamond
- float diamond_width = 1.5f; // Width of diamond
- float diamond_z = 8.0f; // Base Z position
+ float diamond_height = 3.0f; // Total height of diamond
+ float diamond_width = 1.5f; // Width of diamond
+ float diamond_z = 8.0f; // Base Z position
// Define diamond points
- Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height/2}; // Top point
- Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height/2}; // Bottom point
- Vector3 front_point = (Vector3){0.0f, diamond_width/2, diamond_z}; // Front point
- Vector3 back_point = (Vector3){0.0f, -diamond_width/2, diamond_z}; // Back point
- Vector3 left_point = (Vector3){-diamond_width/2, 0.0f, diamond_z}; // Left point
- Vector3 right_point = (Vector3){diamond_width/2, 0.0f, diamond_z}; // Right point
+ Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height / 2}; // Top point
+ Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height / 2}; // Bottom point
+ Vector3 front_point = (Vector3){0.0f, diamond_width / 2, diamond_z}; // Front point
+ Vector3 back_point = (Vector3){0.0f, -diamond_width / 2, diamond_z}; // Back point
+ Vector3 left_point = (Vector3){-diamond_width / 2, 0.0f, diamond_z}; // Left point
+ Vector3 right_point = (Vector3){diamond_width / 2, 0.0f, diamond_z}; // Right point
// Draw the diamond faces
// Top pyramid
-
- if(mode ==0){
- DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face
- DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face
- DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face
- DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face
+ if (mode == 0) {
+ DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face
+ DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face
+ DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face
+ DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face
// Bottom pyramid
DrawTriangle3D(bottom_point, right_point, front_point, PUFF_CYAN); // Front-right face
@@ -2649,14 +2993,16 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las
DrawTriangle3D(bottom_point, left_point, back_point, PUFF_CYAN); // Back-left face
DrawTriangle3D(bottom_point, front_point, left_point, PUFF_CYAN); // Front-left face
}
- if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){
+ if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) {
return;
}
- int ego_dim = (env->dynamics_model == JERK) ? 10 : 7;
- int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS;
- float (*observations)[max_obs] = (float(*)[max_obs])env->observations;
- float* agent_obs = &observations[agent_index][0];
+ int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC;
+ int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0);
+ int ego_dim = base_ego_dim + conditioning_dims;
+ int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS;
+ float (*observations)[max_obs] = (float (*)[max_obs])env->observations;
+ float *agent_obs = &observations[agent_index][0];
// self
int active_idx = env->active_agent_indices[agent_index];
float heading_self_x = env->entities[active_idx].heading_x;
@@ -2666,82 +3012,64 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las
// draw goal
float goal_x = agent_obs[0] * 200;
float goal_y = agent_obs[1] * 200;
- if(mode == 0 ){
+ if (mode == 0) {
DrawSphere((Vector3){goal_x, goal_y, 1}, 0.5f, LIGHTGREEN);
- DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f));
+ DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f,
+ Fade(LIGHTGREEN, 0.3f));
}
- if (mode == 1){
- float goal_x_world = px + (goal_x * heading_self_x - goal_y*heading_self_y);
- float goal_y_world = py + (goal_x * heading_self_y + goal_y*heading_self_x);
+ if (mode == 1) {
+ float goal_x_world = px + (goal_x * heading_self_x - goal_y * heading_self_y);
+ float goal_y_world = py + (goal_x * heading_self_y + goal_y * heading_self_x);
DrawSphere((Vector3){goal_x_world, goal_y_world, 1}, 0.5f, LIGHTGREEN);
- DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f));
+ DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f,
+ Fade(LIGHTGREEN, 0.3f));
}
// First draw other agent observations
- int obs_idx = ego_dim; // Start after ego obs
- for(int j = 0; j < MAX_AGENTS - 1; j++) {
- if(agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) {
- obs_idx += 7; // Move to next agent observation
+ int obs_idx = ego_dim; // Start after ego obs
+ for (int j = 0; j < MAX_AGENTS - 1; j++) {
+ if (agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) {
+ obs_idx += 7; // Move to next agent observation
continue;
}
// Draw position of other agents
float x = agent_obs[obs_idx] * 50;
float y = agent_obs[obs_idx + 1] * 50;
- if(lasers && mode == 0){
- DrawLine3D(
- (Vector3){0, 0, 0},
- (Vector3){x, y, 1},
- ORANGE
- );
- }
-
- float partner_x = px + (x*heading_self_x - y*heading_self_y);
- float partner_y = py + (x*heading_self_y + y*heading_self_x);
- if(lasers && mode ==1){
- DrawLine3D(
- (Vector3){px, py, 1},
- (Vector3){partner_x,partner_y,1},
- ORANGE
- );
- }
-
- float half_width = 0.5*agent_obs[obs_idx + 2]*MAX_VEH_WIDTH;
- float half_len = 0.5*agent_obs[obs_idx + 3]*MAX_VEH_LEN;
+ if (lasers && mode == 0) {
+ DrawLine3D((Vector3){0, 0, 0}, (Vector3){x, y, 1}, ORANGE);
+ }
+
+ float partner_x = px + (x * heading_self_x - y * heading_self_y);
+ float partner_y = py + (x * heading_self_y + y * heading_self_x);
+ if (lasers && mode == 1) {
+ DrawLine3D((Vector3){px, py, 1}, (Vector3){partner_x, partner_y, 1}, ORANGE);
+ }
+
+ float half_width = 0.5 * agent_obs[obs_idx + 2] * MAX_VEH_WIDTH;
+ float half_len = 0.5 * agent_obs[obs_idx + 3] * MAX_VEH_LEN;
float theta_x = agent_obs[obs_idx + 4];
float theta_y = agent_obs[obs_idx + 5];
float partner_angle = atan2f(theta_y, theta_x);
float cos_heading = cosf(partner_angle);
float sin_heading = sinf(partner_angle);
Vector3 corners[4] = {
- (Vector3){
- x + (half_len * cos_heading - half_width * sin_heading),
- y + (half_len * sin_heading + half_width * cos_heading),
- 1
- },
- (Vector3){
- x + (half_len * cos_heading + half_width * sin_heading),
- y + (half_len * sin_heading - half_width * cos_heading),
- 1
- },
- (Vector3){
- x + (-half_len * cos_heading + half_width * sin_heading),
- y + (-half_len * sin_heading - half_width * cos_heading),
- 1
- },
- (Vector3){
- x + (-half_len * cos_heading - half_width * sin_heading),
- y + (-half_len * sin_heading + half_width * cos_heading),
- 1
- },
+ (Vector3){x + (half_len * cos_heading - half_width * sin_heading),
+ y + (half_len * sin_heading + half_width * cos_heading), 1},
+ (Vector3){x + (half_len * cos_heading + half_width * sin_heading),
+ y + (half_len * sin_heading - half_width * cos_heading), 1},
+ (Vector3){x + (-half_len * cos_heading + half_width * sin_heading),
+ y + (-half_len * sin_heading - half_width * cos_heading), 1},
+ (Vector3){x + (-half_len * cos_heading - half_width * sin_heading),
+ y + (-half_len * sin_heading + half_width * cos_heading), 1},
};
- if(mode ==0){
+ if (mode == 0) {
for (int j = 0; j < 4; j++) {
- DrawLine3D(corners[j], corners[(j+1)%4], ORANGE);
+ DrawLine3D(corners[j], corners[(j + 1) % 4], ORANGE);
}
}
- if(mode ==1){
+ if (mode == 1) {
Vector3 world_corners[4];
for (int j = 0; j < 4; j++) {
float lx = corners[j].x;
@@ -2752,90 +3080,74 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las
world_corners[j].z = 1;
}
for (int j = 0; j < 4; j++) {
- DrawLine3D(world_corners[j], world_corners[(j+1)%4], ORANGE);
+ DrawLine3D(world_corners[j], world_corners[(j + 1) % 4], ORANGE);
}
}
// draw an arrow above the car pointing in the direction that the partner is going
- float arrow_length = 7.5f;
- float arrow_x = x + arrow_length*cosf(partner_angle);
- float arrow_y = y + arrow_length*sinf(partner_angle);
+ float arrow_length = 2.5f;
+ float arrow_x = x + arrow_length * cosf(partner_angle);
+ float arrow_y = y + arrow_length * sinf(partner_angle);
float arrow_x_world;
float arrow_y_world;
- if(mode ==0){
- DrawLine3D((Vector3){x, y, 1}, (Vector3){arrow_x, arrow_y, 1}, PUFF_WHITE);
+ if (mode == 0) {
+ DrawLine3D((Vector3){x, y, 0.0}, (Vector3){arrow_x, arrow_y, 0.0}, PUFF_WHITE);
}
- if(mode == 1){
- arrow_x_world = px + (arrow_x * heading_self_x - arrow_y*heading_self_y);
- arrow_y_world = py + (arrow_x * heading_self_y + arrow_y*heading_self_x);
+ if (mode == 1) {
+ arrow_x_world = px + (arrow_x * heading_self_x - arrow_y * heading_self_y);
+ arrow_y_world = py + (arrow_x * heading_self_y + arrow_y * heading_self_x);
DrawLine3D((Vector3){partner_x, partner_y, 1}, (Vector3){arrow_x_world, arrow_y_world, 1}, PUFF_WHITE);
}
// Calculate perpendicular offsets for arrow head
- float arrow_size = 2.0f; // Size of the arrow head
+ float arrow_size = 0.3f; // Size of the arrow head
float dx = arrow_x - x;
float dy = arrow_y - y;
- float length = sqrtf(dx*dx + dy*dy);
+ float length = sqrtf(dx * dx + dy * dy);
if (length > 0) {
// Normalize direction vector
dx /= length;
dy /= length;
// Calculate perpendicular vector
-
float perp_x = -dy * arrow_size;
float perp_y = dx * arrow_size;
- float arrow_x_end1 = arrow_x - dx*arrow_size + perp_x;
- float arrow_y_end1 = arrow_y - dy*arrow_size + perp_y;
- float arrow_x_end2 = arrow_x - dx*arrow_size - perp_x;
- float arrow_y_end2 = arrow_y - dy*arrow_size - perp_y;
+ float arrow_x_end1 = arrow_x - dx * arrow_size + perp_x;
+ float arrow_y_end1 = arrow_y - dy * arrow_size + perp_y;
+ float arrow_x_end2 = arrow_x - dx * arrow_size - perp_x;
+ float arrow_y_end2 = arrow_y - dy * arrow_size - perp_y;
// Draw the two lines forming the arrow head
- if(mode ==0){
- DrawLine3D(
- (Vector3){arrow_x, arrow_y, 1},
- (Vector3){arrow_x_end1, arrow_y_end1, 1},
- PUFF_WHITE
- );
- DrawLine3D(
- (Vector3){arrow_x, arrow_y, 1},
- (Vector3){arrow_x_end2, arrow_y_end2, 1},
- PUFF_WHITE
- );
+ if (mode == 0) {
+ DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end1, arrow_y_end1, 0.0}, PUFF_WHITE);
+ DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end2, arrow_y_end2, 0.0}, PUFF_WHITE);
}
- if(mode==1){
- float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1*heading_self_y);
- float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1*heading_self_x);
- float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - arrow_y_end2*heading_self_y);
- float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2*heading_self_x);
- DrawLine3D(
- (Vector3){arrow_x_world, arrow_y_world, 1},
- (Vector3){arrow_x_end1_world, arrow_y_end1_world, 1},
- PUFF_WHITE
- );
- DrawLine3D(
- (Vector3){arrow_x_world, arrow_y_world, 1},
- (Vector3){arrow_x_end2_world, arrow_y_end2_world, 1},
- PUFF_WHITE
- );
-
+ if (mode == 1) {
+ float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1 * heading_self_y);
+ float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1 * heading_self_x);
+ float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - arrow_y_end2 * heading_self_y);
+ float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2 * heading_self_x);
+ DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0},
+ (Vector3){arrow_x_end1_world, arrow_y_end1_world, 0.0}, PUFF_WHITE);
+ DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0},
+ (Vector3){arrow_x_end2_world, arrow_y_end2_world, 0.0}, PUFF_WHITE);
}
}
- obs_idx += 7; // Move to next agent observation (7 values per agent)
+ obs_idx += PARTNER_FEATURES; // Move to next agent observation (7 values per agent)
}
// Then draw map observations
- int map_start_idx = 7 + 7*(MAX_AGENTS - 1); // Start after agent observations
- for(int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities
- int entity_idx = map_start_idx + k*7;
- if(agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0){
+ int map_start_idx = ego_dim + PARTNER_FEATURES * (MAX_AGENTS - 1); // Start after agent observations
+ for (int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities
+ int entity_idx = map_start_idx + k * 7;
+ if (agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0) {
continue;
}
- Color lineColor = BLUE; // Default color
+ Color lineColor = BLUE; // Default color
int entity_type = (int)agent_obs[entity_idx + 6];
// Choose color based on entity type
- if(entity_type+4 != ROAD_EDGE){
+ if (entity_type + 4 != ROAD_EDGE) {
continue;
}
lineColor = PUFF_CYAN;
@@ -2848,88 +3160,60 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las
float segment_length = agent_obs[entity_idx + 2] * MAX_ROAD_SEGMENT_LENGTH;
// Calculate endpoint using the relative angle directly
// Calculate endpoint directly
- float x_start = x_middle - segment_length*cosf(rel_angle);
- float y_start = y_middle - segment_length*sinf(rel_angle);
- float x_end = x_middle + segment_length*cosf(rel_angle);
- float y_end = y_middle + segment_length*sinf(rel_angle);
-
-
- if(lasers && mode ==0){
- DrawLine3D((Vector3){0,0,0}, (Vector3){x_middle, y_middle, 1}, lineColor);
- }
-
- if(mode ==1){
- float x_middle_world = px + (x_middle*heading_self_x - y_middle*heading_self_y);
- float y_middle_world = py + (x_middle*heading_self_y + y_middle*heading_self_x);
- float x_start_world = px + (x_start*heading_self_x - y_start*heading_self_y);
- float y_start_world = py + (x_start*heading_self_y + y_start*heading_self_x);
- float x_end_world = px + (x_end*heading_self_x - y_end*heading_self_y);
- float y_end_world = py + (x_end*heading_self_y + y_end*heading_self_x);
+ float x_start = x_middle - segment_length * cosf(rel_angle);
+ float y_start = y_middle - segment_length * sinf(rel_angle);
+ float x_end = x_middle + segment_length * cosf(rel_angle);
+ float y_end = y_middle + segment_length * sinf(rel_angle);
+
+ if (lasers && mode == 0) {
+ DrawLine3D((Vector3){0, 0, 0}, (Vector3){x_middle, y_middle, 1}, lineColor);
+ }
+
+ if (mode == 1) {
+ float x_middle_world = px + (x_middle * heading_self_x - y_middle * heading_self_y);
+ float y_middle_world = py + (x_middle * heading_self_y + y_middle * heading_self_x);
+ float x_start_world = px + (x_start * heading_self_x - y_start * heading_self_y);
+ float y_start_world = py + (x_start * heading_self_y + y_start * heading_self_x);
+ float x_end_world = px + (x_end * heading_self_x - y_end * heading_self_y);
+ float y_end_world = py + (x_end * heading_self_y + y_end * heading_self_x);
DrawCube((Vector3){x_middle_world, y_middle_world, 1}, 0.5f, 0.5f, 0.5f, lineColor);
DrawLine3D((Vector3){x_start_world, y_start_world, 1}, (Vector3){x_end_world, y_end_world, 1}, BLUE);
- if(lasers) DrawLine3D((Vector3){px,py,1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor);
+ if (lasers)
+ DrawLine3D((Vector3){px, py, 1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor);
}
- if(mode ==0){
+ if (mode == 0) {
DrawCube((Vector3){x_middle, y_middle, 1}, 0.5f, 0.5f, 0.5f, lineColor);
DrawLine3D((Vector3){x_start, y_start, 1}, (Vector3){x_end, y_end, 1}, BLUE);
}
}
}
-void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float end_y){
- Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest
- Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium
+void draw_road_edge(Drive *env, float start_x, float start_y, float end_x, float end_y) {
+ Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest
+ Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium
Color CURB_BOTTOM = (Color){160, 160, 160, 255};
- // Calculate curb dimensions
- float curb_height = 0.5f; // Height of the curb
- float curb_width = 0.3f; // Width/thickness of the curb
- float road_z = 0.2f; // Ensure z-level for roads is below agents
+ // Calculate curb dimensions
+ float curb_height = 0.5f; // Height of the curb
+ float curb_width = 0.3f; // Width/thickness of the curb
+ float road_z = 0.0f; // Ensure z-level for roads is below agents
// Calculate direction vector between start and end
- Vector3 direction = {
- end_x - start_x,
- end_y - start_y,
- 0.0f
- };
+ Vector3 direction = {end_x - start_x, end_y - start_y, 0.0f};
// Calculate length of the segment
float length = sqrtf(direction.x * direction.x + direction.y * direction.y);
// Normalize direction vector
- Vector3 normalized_dir = {
- direction.x / length,
- direction.y / length,
- 0.0f
- };
+ Vector3 normalized_dir = {direction.x / length, direction.y / length, 0.0f};
// Calculate perpendicular vector for width
- Vector3 perpendicular = {
- -normalized_dir.y,
- normalized_dir.x,
- 0.0f
- };
+ Vector3 perpendicular = {-normalized_dir.y, normalized_dir.x, 0.0f};
// Calculate the four bottom corners of the curb
- Vector3 b1 = {
- start_x - perpendicular.x * curb_width/2,
- start_y - perpendicular.y * curb_width/2,
- road_z
- };
- Vector3 b2 = {
- start_x + perpendicular.x * curb_width/2,
- start_y + perpendicular.y * curb_width/2,
- road_z
- };
- Vector3 b3 = {
- end_x + perpendicular.x * curb_width/2,
- end_y + perpendicular.y * curb_width/2,
- road_z
- };
- Vector3 b4 = {
- end_x - perpendicular.x * curb_width/2,
- end_y - perpendicular.y * curb_width/2,
- road_z
- };
+ Vector3 b1 = {start_x - perpendicular.x * curb_width / 2, start_y - perpendicular.y * curb_width / 2, road_z};
+ Vector3 b2 = {start_x + perpendicular.x * curb_width / 2, start_y + perpendicular.y * curb_width / 2, road_z};
+ Vector3 b3 = {end_x + perpendicular.x * curb_width / 2, end_y + perpendicular.y * curb_width / 2, road_z};
+ Vector3 b4 = {end_x - perpendicular.x * curb_width / 2, end_y - perpendicular.y * curb_width / 2, road_z};
// Draw the curb faces
// Bottom face
@@ -2955,56 +3239,58 @@ void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float
DrawTriangle3D(t4, t1, b1, CURB_SIDE);
}
-void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, int show_grid){
- // Draw a grid to help with orientation
- // DrawGrid(20, 1.0f);
- DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN);
- DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, PUFF_CYAN);
- DrawLine3D((Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN);
- DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, PUFF_CYAN);
- for(int i = 0; i < env->num_entities; i++) {
+void draw_scene(Drive *env, Client *client, int mode, int obs_only, int lasers, int show_grid) {
+
+ if (show_grid) {
+ float grid_start_x = env->grid_map->top_left_x;
+ float grid_start_y = env->grid_map->bottom_right_y;
+ for (int i = 0; i < env->grid_map->grid_cols; i++) {
+ for (int j = 0; j < env->grid_map->grid_rows; j++) {
+ float x = grid_start_x + i * GRID_CELL_SIZE;
+ float y = grid_start_y + j * GRID_CELL_SIZE;
+ DrawCubeWires((Vector3){x + GRID_CELL_SIZE / 2, y + GRID_CELL_SIZE / 2, 0.0f}, GRID_CELL_SIZE,
+ GRID_CELL_SIZE, 0.1f, Fade(PUFF_BACKGROUND2, 0.3f));
+ }
+ }
+ }
+
+ // Draw all entities: agents first, then road elements
+ for (int i = 0; i < env->num_entities; i++) {
// Draw objects
- if(env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) {
+ if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN ||
+ env->entities[i].type == CYCLIST) {
// Check if this vehicle is an active agent
bool is_active_agent = false;
bool is_static_agent = false;
int agent_index = -1;
- for(int j = 0; j < env->active_agent_count; j++) {
- if(env->active_agent_indices[j] == i) {
+ for (int j = 0; j < env->active_agent_count; j++) {
+ if (env->active_agent_indices[j] == i) {
is_active_agent = true;
agent_index = j;
break;
}
}
- for(int j = 0; j < env->static_agent_count; j++) {
- if(env->static_agent_indices[j] == i) {
+ for (int j = 0; j < env->static_agent_count; j++) {
+ if (env->static_agent_indices[j] == i) {
is_static_agent = true;
break;
}
}
// HIDE CARS ON RESPAWN - IMPORTANT TO KNOW VISUAL SETTING
- if((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1){
+ if ((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1) {
continue;
}
Vector3 position;
float heading;
- position = (Vector3){
- env->entities[i].x,
- env->entities[i].y,
- 1
- };
+ position = (Vector3){env->entities[i].x, env->entities[i].y, 1.1};
heading = env->entities[i].heading;
// Create size vector
- Vector3 size = {
- env->entities[i].length,
- env->entities[i].width,
- env->entities[i].height
- };
+ Vector3 size = {env->entities[i].length, env->entities[i].width, env->entities[i].height};
bool is_expert = (!is_active_agent) && (env->entities[i].mark_as_expert == 1);
// Save current transform
- if(mode==1){
+ if (mode == 1) {
float cos_heading = env->entities[i].heading_x;
float sin_heading = env->entities[i].heading_y;
@@ -3014,220 +3300,184 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers,
// Calculate the four corners of the collision box
Vector3 corners[4] = {
- (Vector3){
- position.x + (half_len * cos_heading - half_width * sin_heading),
- position.y + (half_len * sin_heading + half_width * cos_heading),
- position.z
- },
-
-
- (Vector3){
- position.x + (half_len * cos_heading + half_width * sin_heading),
- position.y + (half_len * sin_heading - half_width * cos_heading),
- position.z
- },
- (Vector3){
- position.x + (-half_len * cos_heading + half_width * sin_heading),
- position.y + (-half_len * sin_heading - half_width * cos_heading),
- position.z
- },
- (Vector3){
- position.x + (-half_len * cos_heading - half_width * sin_heading),
- position.y + (-half_len * sin_heading + half_width * cos_heading),
- position.z
- },
-
+ (Vector3){position.x + (half_len * cos_heading - half_width * sin_heading),
+ position.y + (half_len * sin_heading + half_width * cos_heading), position.z},
+ (Vector3){position.x + (half_len * cos_heading + half_width * sin_heading),
+ position.y + (half_len * sin_heading - half_width * cos_heading), position.z},
+ (Vector3){position.x + (-half_len * cos_heading + half_width * sin_heading),
+ position.y + (-half_len * sin_heading - half_width * cos_heading), position.z},
+ (Vector3){position.x + (-half_len * cos_heading - half_width * sin_heading),
+ position.y + (-half_len * sin_heading + half_width * cos_heading), position.z},
};
- if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) {
+ if (agent_index == env->human_agent_idx &&
+ !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) {
draw_agent_obs(env, agent_index, mode, obs_only, lasers);
}
- if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){
+
+ if ((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx) {
continue;
}
// --- Draw the car ---
-
- Vector3 carPos = { position.x, position.y, position.z };
- Color car_color = GRAY; // default for static
- if (is_expert) car_color = GOLD; // expert replay
- if (is_active_agent) car_color = BLUE; // policy-controlled
- if (is_active_agent && env->entities[i].collision_state > 0) car_color = RED;
+ Color car_color = GRAY; // default for static
+ if (is_expert)
+ car_color = GOLD; // expert replay
+ if (is_active_agent) {
+ // Distinguish ego (magenta) from co-player (blue) so renders
+ // make it obvious which policy controls which car. Without
+ // this both look identical and behavioral debugging is
+ // ambiguous.
+ if (env->entities[i].is_ego)
+ car_color = MAGENTA;
+ else
+ car_color = BLUE;
+ }
+ if (is_active_agent && env->entities[i].collision_state > 0)
+ car_color = RED;
rlSetLineWidth(3.0f);
for (int j = 0; j < 4; j++) {
- DrawLine3D(corners[j], corners[(j+1)%4], car_color);
+ DrawLine3D(corners[j], corners[(j + 1) % 4], car_color);
}
// --- Draw a heading arrow pointing forward ---
Vector3 arrowStart = position;
- Vector3 arrowEnd = {
- position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car
- position.y + sin_heading * half_len * 1.5f,
- position.z
- };
+ Vector3 arrowEnd = {position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car
+ position.y + sin_heading * half_len * 1.5f, position.z};
DrawLine3D(arrowStart, arrowEnd, car_color);
- DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip
+ DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip
- }
- else {
+ } else { // Agent view
rlPushMatrix();
// Translate to position, rotate around Y axis, then draw
rlTranslatef(position.x, position.y, position.z);
- rlRotatef(heading*RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees
- // Determine color based on status
- Color object_color = PUFF_BACKGROUND2; // fill color unused for model tint
- Color outline_color = PUFF_CYAN; // not used for model tint
- Model car_model = client->cars[5];
- if(is_active_agent){
- car_model = client->cars[client->car_assignments[i %64]];
- }
- if(agent_index == env->human_agent_idx){
- object_color = PUFF_CYAN;
- outline_color = PUFF_WHITE;
- }
- if(is_active_agent && env->entities[i].collision_state > 0) {
- car_model = client->cars[0]; // Collided agent
+ rlRotatef(heading * RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees
+
+ // Select car model (skip index 0)
+ Model car_model = client->cars[(i % 5) + 1]; // Cycles through indices 1-5
+
+ if (agent_index == env->human_agent_idx) {
+ car_model = client->cars[0]; // Ego agent always uses red car
+ } else if (is_active_agent) {
+
+ car_model = client->cars[(i % 5) + 1];
+
+ if (env->entities[i].collision_state > 0) {
+ car_model = client->cars[0]; // Collided agents use red
+ }
}
- // Draw obs for human selected agent
- if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) {
+ // Draw obs for selected agent index
+ if (agent_index == env->human_agent_idx &&
+ (!env->entities[agent_index].metrics_array[REACHED_GOAL_IDX] ||
+ env->goal_behavior == GOAL_GENERATE_NEW || env->goal_behavior == GOAL_STOP)) {
draw_agent_obs(env, agent_index, mode, obs_only, lasers);
}
+
// Draw cube for cars static and active
// Calculate scale factors based on desired size and model dimensions
-
BoundingBox bounds = GetModelBoundingBox(car_model);
- Vector3 model_size = {
- bounds.max.x - bounds.min.x,
- bounds.max.y - bounds.min.y,
- bounds.max.z - bounds.min.z
- };
- Vector3 scale = {
- size.x / model_size.x,
- size.y / model_size.y,
- size.z / model_size.z
- };
- if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){
- rlPopMatrix();
- continue;
+ Vector3 model_size = {bounds.max.x - bounds.min.x, bounds.max.y - bounds.min.y,
+ bounds.max.z - bounds.min.z};
+ Vector3 scale = {size.x / model_size.x, size.y / model_size.y, size.z / model_size.z};
+ // if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){
+ // rlPopMatrix();
+ // continue;
+ // }
+ if (env->entities[i].type == CYCLIST) {
+ scale = (Vector3){0.01, 0.01, 0.01};
+ car_model = client->cyclist;
+ }
+ if (env->entities[i].type == PEDESTRIAN) {
+ scale = (Vector3){2, 2, 2};
+ car_model = client->pedestrian;
}
-
DrawModelEx(car_model, (Vector3){0, 0, 0}, (Vector3){1, 0, 0}, 90.0f, scale, WHITE);
{
- float cos_heading = env->entities[i].heading_x;
- float sin_heading = env->entities[i].heading_y;
float half_len = env->entities[i].length * 0.5f;
float half_width = env->entities[i].width * 0.5f;
Vector3 corners[4] = {
- (Vector3){ 0 + ( half_len * cos_heading - half_width * sin_heading), 0 + ( half_len * sin_heading + half_width * cos_heading), 0 },
- (Vector3){ 0 + ( half_len * cos_heading + half_width * sin_heading), 0 + ( half_len * sin_heading - half_width * cos_heading), 0 },
- (Vector3){ 0 + (-half_len * cos_heading + half_width * sin_heading), 0 + (-half_len * sin_heading - half_width * cos_heading), 0 },
- (Vector3){ 0 + (-half_len * cos_heading - half_width * sin_heading), 0 + (-half_len * sin_heading + half_width * cos_heading), 0 },
+ (Vector3){half_len, -half_width, 0}, // Front-left
+ (Vector3){half_len, half_width, 0}, // Front-right
+ (Vector3){-half_len, half_width, 0}, // Back-right
+ (Vector3){-half_len, -half_width, 0}, // Back-left
};
- Color wire_color = GRAY; // static
- if (!is_active_agent && env->entities[i].mark_as_expert == 1) wire_color = GOLD; // expert replay
- if (is_active_agent) wire_color = BLUE; // policy
- if (is_active_agent && env->entities[i].collision_state > 0) wire_color = RED;
+ Color wire_color = GRAY; // static
+ if (!is_active_agent && env->entities[i].mark_as_expert == 1)
+ wire_color = GOLD; // expert replay
+ if (is_active_agent)
+ wire_color = BLUE; // policy
+ if (is_active_agent && env->entities[i].collision_state > 0)
+ wire_color = RED;
rlSetLineWidth(2.0f);
for (int j = 0; j < 4; j++) {
- DrawLine3D(corners[j], corners[(j+1)%4], wire_color);
+ DrawLine3D(corners[j], corners[(j + 1) % 4], wire_color);
}
}
rlPopMatrix();
}
// FPV Camera Control
- if(IsKeyDown(KEY_SPACE) && env->human_agent_idx== agent_index){
- if(env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]){
- env->human_agent_idx = rand() % env->active_agent_count;
- }
- Vector3 camera_position = (Vector3){
- position.x - (25.0f * cosf(heading)),
- position.y - (25.0f * sinf(heading)),
- position.z + 15
- };
+ if (IsKeyDown(KEY_SPACE) && env->human_agent_idx == agent_index) {
+ Vector3 camera_position = (Vector3){position.x - (25.0f * cosf(heading)),
+ position.y - (25.0f * sinf(heading)), position.z + 15};
- Vector3 camera_target = (Vector3){
- position.x + 40.0f * cosf(heading),
- position.y + 40.0f * sinf(heading),
- position.z - 5.0f
- };
+ Vector3 camera_target = (Vector3){position.x + 40.0f * cosf(heading),
+ position.y + 40.0f * sinf(heading), position.z - 5.0f};
client->camera.position = camera_position;
client->camera.target = camera_target;
client->camera.up = (Vector3){0, 0, 1};
}
- if(IsKeyReleased(KEY_SPACE)){
+ if (IsKeyReleased(KEY_SPACE)) {
client->camera.position = client->default_camera_position;
client->camera.target = client->default_camera_target;
client->camera.up = (Vector3){0, 0, 1};
}
// Draw goal position for active agents
-
- if(!is_active_agent || env->entities[i].valid == 0) {
+ if (!is_active_agent || env->entities[i].valid == 0) {
continue;
}
- if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){
- DrawSphere((Vector3){
- env->entities[i].goal_position_x,
- env->entities[i].goal_position_y,
- 1
- }, 0.5f, DARKGREEN);
-
- DrawCircle3D((Vector3){
- env->entities[i].goal_position_x,
- env->entities[i].goal_position_y,
- 0.1f
- }, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f));
+ if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) {
+ DrawSphere((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 1}, 0.5f,
+ DARKGREEN);
+
+ DrawCircle3D((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 0.1f},
+ env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.9f));
}
}
// Draw road elements
- if(env->entities[i].type <=3 && env->entities[i].type >= 7){
+ if (env->entities[i].type <= 3 || env->entities[i].type >= 7) {
continue;
}
- for(int j = 0; j < env->entities[i].array_size - 1; j++) {
- Vector3 start = {
- env->entities[i].traj_x[j],
- env->entities[i].traj_y[j],
- 1
- };
- Vector3 end = {
- env->entities[i].traj_x[j + 1],
- env->entities[i].traj_y[j + 1],
- 1
- };
+ for (int j = 0; j < env->entities[i].array_size - 1; j++) {
+ Vector3 start = {env->entities[i].traj_x[j], env->entities[i].traj_y[j], 1};
+ Vector3 end = {env->entities[i].traj_x[j + 1], env->entities[i].traj_y[j + 1], 1};
Color lineColor = GRAY;
- if (env->entities[i].type == ROAD_LANE) lineColor = GRAY;
- else if (env->entities[i].type == ROAD_LINE) lineColor = BLUE;
- else if (env->entities[i].type == ROAD_EDGE) lineColor = WHITE;
- else if (env->entities[i].type == DRIVEWAY) lineColor = RED;
- if(env->entities[i].type != ROAD_EDGE){
- continue;
- }
- if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){
- draw_road_edge(env, start.x, start.y, end.x, end.y);
+ if (env->entities[i].type == ROAD_LANE)
+ lineColor = Fade(SOFT_YELLOW, 0.25f);
+ else if (env->entities[i].type == ROAD_LINE)
+ lineColor = WHITE;
+ else if (env->entities[i].type == ROAD_EDGE)
+ lineColor = WHITE;
+ else if (env->entities[i].type == DRIVEWAY)
+ lineColor = RED;
+
+ if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) {
+ if (env->entities[i].type == ROAD_EDGE) {
+ draw_road_edge(env, start.x, start.y, end.x, end.y);
+ } else if (env->entities[i].type == ROAD_LANE || env->entities[i].type == ROAD_LINE) {
+ // Draw road lanes and lines in their per-type color (faded yellow or white)
+ rlSetLineWidth(2.0f);
+ DrawLine3D(start, end, lineColor);
+ }
}
}
}
- if(show_grid) {
- // Draw grid cells using the stored bounds
- float grid_start_x = env->grid_map->top_left_x;
- float grid_start_y = env->grid_map->bottom_right_y;
- for(int i = 0; i < env->grid_map->grid_cols; i++) {
- for(int j = 0; j < env->grid_map->grid_rows; j++) {
- float x = grid_start_x + i*GRID_CELL_SIZE;
- float y = grid_start_y + j*GRID_CELL_SIZE;
- DrawCubeWires(
- (Vector3){x + GRID_CELL_SIZE/2, y + GRID_CELL_SIZE/2, 1},
- GRID_CELL_SIZE, GRID_CELL_SIZE, 0.1f, PUFF_BACKGROUND2);
- }
- }
- }
EndMode3D();
// Draw track indices for the tracks to predict
- if (mode == 1 && env->control_mode == CONTROL_TRACKS_TO_PREDICT) {
- float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x;
+ if (mode == 1 && env->control_mode == CONTROL_WOSAC) {
float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y;
float pixels_per_world_unit = client->height / map_height;
@@ -3242,151 +3492,257 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers,
float raw_x = -env->entities[agent_idx].x * pixels_per_world_unit;
float raw_y = env->entities[agent_idx].y * pixels_per_world_unit;
- int screen_x = (int)raw_x + client->width/2 + 20;
- int screen_y = (int)raw_y + client->height/2 - 25;
+ int screen_x = (int)raw_x + client->width / 2 + 20;
+ int screen_y = (int)raw_y + client->height / 2 - 25;
- if (screen_x >= 0 && screen_x <= client->width &&
- screen_y >= 0 && screen_y <= client->height) {
+ if (screen_x >= 0 && screen_x <= client->width && screen_y >= 0 && screen_y <= client->height) {
char text[32];
snprintf(text, sizeof(text), "%d", womd_track_idx);
int text_width = MeasureText(text, 20);
- DrawText(text, screen_x - text_width/2, screen_y, 20, PUFF_WHITE);
+ DrawText(text, screen_x - text_width / 2, screen_y, 20, PUFF_WHITE);
}
}
}
}
-void saveTopDownImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid){
- // Top-down orthographic camera
- Camera3D camera = {0};
- camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene
- camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin
- camera.up = (Vector3){ 0.0f, -1.0f, 0.0f };
- camera.fovy = map_height;
- camera.projection = CAMERA_ORTHOGRAPHIC;
- Color road = (Color){35, 35, 37, 255};
+// Headless rendering helper: write frame to ffmpeg pipe
+static void write_frame_to_pipe(Client *client) {
+ if (client->recorder_pipefd[1] < 0)
+ return;
- BeginTextureMode(target);
- ClearBackground(road);
- BeginMode3D(camera);
- rlEnableDepthTest();
+ int w = (int)client->width;
+ int h = (int)client->height;
+ unsigned char *screen_data = rlReadScreenPixels(w, h);
+ if (screen_data) {
+ write(client->recorder_pipefd[1], screen_data, w * h * 4);
+ RL_FREE(screen_data);
+ }
+}
+
+// Render with view_mode and draw_traces parameters (for Python-based rendering)
+void c_render_with_mode(Drive *env, int view_mode, int draw_traces, int current_scenario, int k_scenarios) {
+ if (env->client == NULL) {
+ env->client = make_client(env);
+ }
+ if (env->client == NULL)
+ return; // make_client may fail in headless mode
- // Draw log trajectories FIRST (in background at lower Z-level)
- if(log_trajectories){
- for(int i=0; i<env->active_agent_count;i++){
+ Client *client = env->client;
+ Color road = (Color){35, 35, 37, 255};
+
+ if (env->render_mode == RENDER_HEADLESS) {
+ // Headless rendering mode.
+ // Use the CURRENT env's grid bounds for camera fovy. The original code
+ // used client->original_map_height (captured once at make_client) for
+ // "consistent resolution across scenarios", but with map_rand_per_scenario
+ // the new map can have entirely different bounds. Locking to scenario 0's
+ // dims meant new maps got cropped or off-centered, even though the env
+ // state was correct (ego scores stay ~0.99 in s_1; the agents are fine,
+ // just rendered outside the camera view). Read live so the camera always
+ // frames the actual current map.
+ float live_map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y;
+ float render_map_height = live_map_height > 0.0f ? live_map_height : client->original_map_height;
+
+ Camera3D camera = {0};
+
+ if (view_mode == VIEW_MODE_SIM_STATE) {
+ // Top-down orthographic view of full simulation state
+ // fovy uses the live map height computed above so the camera frames the current map
+ camera.position = (Vector3){0.0f, 0.0f, 400.0f};
+ camera.target = (Vector3){0.0f, 0.0f, 0.0f};
+ camera.up = (Vector3){0.0f, -1.0f, 0.0f};
+ camera.projection = CAMERA_ORTHOGRAPHIC;
+ camera.fovy = render_map_height;
+
+ BeginDrawing();
+ ClearBackground(road);
+ BeginMode3D(camera);
+
+ if (draw_traces) {
+ // Draw trajectory traces for all active agents
+ for (int i = 0; i < env->active_agent_count; i++) {
int idx = env->active_agent_indices[i];
- for(int j=0; j<env->entities[idx].array_size;j++){
- float x = env->entities[idx].traj_x[j];
- float y = env->entities[idx].traj_y[j];
- float valid = env->entities[idx].traj_valid[j];
- if(!valid) continue;
- DrawSphere((Vector3){x,y,0.5f}, 0.3f, Fade(LIGHTGREEN, 0.6f));
+ int t_end = env->scenario_length;
+ if (t_end > env->entities[idx].array_size) {
+ t_end = env->entities[idx].array_size;
+ }
+ for (int t = env->init_steps; t < t_end; t++) {
+ if (env->entities[idx].traj_valid[t]) {
+ DrawPoint3D((Vector3){env->entities[idx].traj_x[t], env->entities[idx].traj_y[t], 0.5f},
+ LIGHTBLUE);
+ }
}
}
}
+ draw_scene(env, client, 1, 0, 0, 0);
+ EndMode3D();
+
+ } else if (view_mode == VIEW_MODE_BEV_AGENT_OBS) {
+ // Bird's eye view centered on human-controlled agent
+ int agent_idx = env->active_agent_indices[env->human_agent_idx];
+ Entity *agent = &env->entities[agent_idx];
+
+ camera.position = (Vector3){agent->x, agent->y, 400.0f};
+ camera.target = (Vector3){agent->x, agent->y, 0.0f};
+ camera.up = (Vector3){0.0f, -1.0f, 0.0f};
+ camera.projection = CAMERA_ORTHOGRAPHIC;
+ camera.fovy = env->grid_map->vision_range * GRID_CELL_SIZE * 2.0f;
+
+ BeginDrawing();
+ ClearBackground(road);
+ BeginMode3D(camera);
+ draw_scene(env, client, 1, 1, 0, 0);
+ EndMode3D();
- // Draw current path trajectories SECOND (slightly higher than log trajectories)
- if(trajectories){
- for(int i=0; iactive_agent_indices[env->human_agent_idx];
+ Entity *agent = &env->entities[agent_idx];
- // Draw main scene LAST (on top)
- draw_scene(env, client, 1, obs, lasers, show_grid);
+ camera.position =
+ (Vector3){agent->x - (25.0f * cosf(agent->heading)), agent->y - (25.0f * sinf(agent->heading)), 15.0f};
+ camera.target =
+ (Vector3){agent->x + 40.0f * cosf(agent->heading), agent->y + 40.0f * sinf(agent->heading), 1.0f};
+ camera.up = (Vector3){0.0f, 0.0f, 1.0f};
+ camera.fovy = 60.0f;
+ camera.projection = CAMERA_PERSPECTIVE;
- EndMode3D();
- EndTextureMode();
-
- // save to file
- Image img = LoadImageFromTexture(target.texture);
- ImageFlipVertical(&img);
- ExportImage(img, filename);
- UnloadImage(img);
-}
-
-void saveAgentViewImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs_only, int lasers, int show_grid) {
- // Agent perspective camera following the human agent
- int agent_idx = env->active_agent_indices[env->human_agent_idx];
- Entity* agent = &env->entities[agent_idx];
-
- Camera3D camera = {0};
- // Position camera behind and above the agent
- camera.position = (Vector3){
- agent->x - (25.0f * cosf(agent->heading)),
- agent->y - (25.0f * sinf(agent->heading)),
- 15.0f
- };
- camera.target = (Vector3){
- agent->x + 40.0f * cosf(agent->heading),
- agent->y + 40.0f * sinf(agent->heading),
- 1.0f
- };
- camera.up = (Vector3){ 0.0f, 0.0f, 1.0f };
- camera.fovy = 45.0f;
- camera.projection = CAMERA_PERSPECTIVE;
+ BeginDrawing();
+ ClearBackground(road);
+ BeginMode3D(camera);
+ draw_scene(env, client, 0, 0, 0, 1);
+ EndMode3D();
+ }
- Color road = (Color){35, 35, 37, 255};
+ // Draw scenario counter overlay (2D text on top of 3D scene)
+ if (k_scenarios > 1) {
+ char scenario_text[64];
+ snprintf(scenario_text, sizeof(scenario_text), "Scenario %d / %d", current_scenario + 1, k_scenarios);
+ DrawText(scenario_text, 40, 40, 120, WHITE);
+ }
+
+ EndDrawing();
+
+ // Write frame to ffmpeg pipe
+ write_frame_to_pipe(client);
- BeginTextureMode(target);
+ } else {
+ // Interactive window mode - use default rendering
+ BeginDrawing();
ClearBackground(road);
- BeginMode3D(camera);
- rlEnableDepthTest();
- draw_scene(env, client, 0, obs_only, lasers, show_grid); // mode=0 for agent view
+ BeginMode3D(client->camera);
+ handle_camera_controls(env->client);
+ draw_scene(env, client, 0, 0, 0, 0);
EndMode3D();
- EndTextureMode();
-
- // Save to file
- Image img = LoadImageFromTexture(target.texture);
- ImageFlipVertical(&img);
- ExportImage(img, filename);
- UnloadImage(img);
+ EndDrawing();
+ }
}
-void c_render(Drive* env) {
+// Original c_render for backward compatibility (interactive window mode)
+void c_render(Drive *env) {
if (env->client == NULL) {
env->client = make_client(env);
}
- Client* client = env->client;
+ Client *client = env->client;
BeginDrawing();
Color road = (Color){35, 35, 37, 255};
ClearBackground(road);
BeginMode3D(client->camera);
handle_camera_controls(env->client);
draw_scene(env, client, 0, 0, 0, 0);
+ EndMode3D();
+
+ if (IsKeyPressed(KEY_TAB)) {
+ env->human_agent_idx = (env->human_agent_idx + 1) % env->active_agent_count;
+ }
+
// Draw debug info
- DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)",
- client->camera.position.x,
- client->camera.position.y,
- client->camera.position.z), 10, 10, 20, PUFF_WHITE);
- DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)",
- client->camera.target.x,
- client->camera.target.y,
- client->camera.target.z), 10, 30, 20, PUFF_WHITE);
+ DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", client->camera.position.x, client->camera.position.y,
+ client->camera.position.z),
+ 10, 10, 20, PUFF_WHITE);
+ DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)", client->camera.target.x, client->camera.target.y,
+ client->camera.target.z),
+ 10, 30, 20, PUFF_WHITE);
DrawText(TextFormat("Timestep: %d", env->timestep), 10, 50, 20, PUFF_WHITE);
- // acceleration & steering
+
int human_idx = env->active_agent_indices[env->human_agent_idx];
DrawText(TextFormat("Controlling Agent: %d", env->human_agent_idx), 10, 70, 20, PUFF_WHITE);
DrawText(TextFormat("Agent Index: %d", human_idx), 10, 90, 20, PUFF_WHITE);
+
+ // Display current action values - yellow when controlling, white otherwise
+ Color action_color = IsKeyDown(KEY_LEFT_SHIFT) ? YELLOW : PUFF_WHITE;
+
+ if (env->action_type == 0) { // discrete
+ int *action_array = (int *)env->actions;
+ int action_val = action_array[env->human_agent_idx];
+
+ if (env->dynamics_model == CLASSIC) {
+ int num_steer = 13;
+ int accel_idx = action_val / num_steer;
+ int steer_idx = action_val % num_steer;
+ float accel_value = ACCELERATION_VALUES[accel_idx];
+ float steer_value = STEERING_VALUES[steer_idx];
+
+ DrawText(TextFormat("Acceleration: %.2f m/s^2", accel_value), 10, 110, 20, action_color);
+ DrawText(TextFormat("Steering: %.3f", steer_value), 10, 130, 20, action_color);
+ } else if (env->dynamics_model == JERK) {
+ int num_lat = 3;
+ int jerk_long_idx = action_val / num_lat;
+ int jerk_lat_idx = action_val % num_lat;
+ float jerk_long_value = JERK_LONG[jerk_long_idx];
+ float jerk_lat_value = JERK_LAT[jerk_lat_idx];
+
+ DrawText(TextFormat("Longitudinal Jerk: %.2f m/s^3", jerk_long_value), 10, 110, 20, action_color);
+ DrawText(TextFormat("Lateral Jerk: %.2f m/s^3", jerk_lat_value), 10, 130, 20, action_color);
+ }
+ } else { // continuous
+ float (*action_array_f)[2] = (float (*)[2])env->actions;
+ DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, action_color);
+ DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, action_color);
+ }
+
+ // Show key press status
+ int status_y = 150;
+ if (IsKeyDown(KEY_LEFT_SHIFT)) {
+ DrawText("[shift pressed]", 10, status_y, 20, YELLOW);
+ status_y += 20;
+ }
+ if (IsKeyDown(KEY_SPACE)) {
+ DrawText("[space pressed]", 10, status_y, 20, YELLOW);
+ status_y += 20;
+ }
+ if (IsKeyDown(KEY_LEFT_CONTROL)) {
+ DrawText("[ctrl pressed]", 10, status_y, 20, YELLOW);
+ status_y += 20;
+ }
+
// Controls help
- DrawText("Controls: W/S - Accelerate/Brake, A/D - Steer, 1-4 - Switch Agent",
- 10, client->height - 30, 20, PUFF_WHITE);
- // acceleration & steering
- if (env->action_type == 1) { // continuous (float)
- float (*action_array_f)[2] = (float(*)[2])env->actions;
- DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE);
- DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE);
- } else { // discrete (int)
- int (*action_array)[2] = (int(*)[2])env->actions;
- DrawText(TextFormat("Acceleration: %d", action_array[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE);
- DrawText(TextFormat("Steering: %d", action_array[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE);
- }
- DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, 150, 20, PUFF_WHITE);
- DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, 170, 20, PUFF_WHITE);
+ DrawText("Controls: SHIFT + W/S - Accelerate/Brake, SHIFT + A/D - Steer, TAB - Switch Agent", 10,
+ client->height - 30, 20, PUFF_WHITE);
+
+ DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, status_y, 20, PUFF_WHITE);
+ DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, status_y + 20, 20, PUFF_WHITE);
EndDrawing();
}
-void close_client(Client* client){
+// Set the full mp4 basename (without ".mp4") for headless rendering.
+// If a recorder is already running for this env, it's stopped and a new one
+// is started so multiple basenames can be used in a single env lifetime.
+void set_video_suffix(Drive *env, const char *basename) {
+ if (basename) {
+ strncpy(env->video_basename, basename, sizeof(env->video_basename) - 1);
+ env->video_basename[sizeof(env->video_basename) - 1] = '\0';
+ } else {
+ env->video_basename[0] = '\0';
+ }
+
+ if (env->client != NULL && env->render_mode == RENDER_HEADLESS) {
+ start_video_recorder(env->client, env->video_basename);
+ }
+}
+
+void close_client(Client *client) {
+ stop_video_recorder(client);
for (int i = 0; i < 6; i++) {
UnloadModel(client->cars[i]);
}
diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py
index c3c6be27af..357d37649c 100644
--- a/pufferlib/ocean/drive/drive.py
+++ b/pufferlib/ocean/drive/drive.py
@@ -3,9 +3,20 @@
import json
import struct
import os
+from enum import IntEnum
import pufferlib
from pufferlib.ocean.drive import binding
import torch
+from multiprocessing import Pool, cpu_count
+from tqdm import tqdm
+
+
+class RenderView(IntEnum):
+ """View modes for rendering."""
+
+ FULL_SIM_STATE = 0 # Top-down orthographic view of full simulation
+ BEV_AGENT_OBS = 1 # Bird's eye view centered on agent observation
+ AGENT_PERSPECTIVE = 2 # Third-person chase camera following agent
class Drive(pufferlib.PufferEnv):
@@ -20,13 +31,18 @@ def __init__(
reward_offroad_collision=-0.1,
reward_goal=1.0,
reward_goal_post_respawn=0.5,
- reward_ade=0.0,
+ reward_lane_align=0.0, # GIGAFLOW lane alignment reward (0 = disabled)
+ reward_vel_align=1.0, # Velocity alignment coefficient for lane reward
goal_behavior=0,
+ goal_target_distance=10.0,
goal_radius=2.0,
+ goal_speed=20.0,
collision_behavior=0,
offroad_behavior=0,
dt=0.1,
scenario_length=None,
+ episode_length=None,
+ termination_mode=None,
resample_frequency=91,
num_maps=100,
num_agents=512,
@@ -45,25 +61,62 @@ def __init__(
co_player_enabled=False,
num_ego_agents=512,
co_player_policy={},
+ map_dir="resources/drive/binaries/training",
+ use_all_maps=False,
+ report_all_scenarios=False,
+ map_seed=None,
+ external_co_player_actions=False,
+ worker_idx=0,
+ co_player_conditioning_shm=None,
+ map_rand_per_scenario=False,
+ condition_rand_per_scenario=False,
+ entropy_curriculum_enabled=False,
+ entropy_curriculum_episodes_start=0,
+ k_eff_curriculum_enabled=False,
+ k_eff_curriculum_episodes_per_stage=30,
+ ego_is_oracle=False,
+ reward_only_last_scenario=False,
):
# env
self.dt = dt
+ # Normalize render_mode (None, int, or string) to the binding's integer constant
+ if render_mode is None or render_mode == 0:
+ self._render_mode_int = binding.RENDER_OFF
+ elif render_mode == 1 or render_mode == "headless":
+ self._render_mode_int = binding.RENDER_HEADLESS
+ elif render_mode == 2 or render_mode == "window" or render_mode == "human":
+ self._render_mode_int = binding.RENDER_WINDOW
+ else:
+ self._render_mode_int = binding.RENDER_OFF
self.render_mode = render_mode
+ self.report_all_scenarios = report_all_scenarios
self.num_maps = num_maps
self.report_interval = report_interval
self.reward_vehicle_collision = reward_vehicle_collision
self.reward_offroad_collision = reward_offroad_collision
self.reward_goal = reward_goal
self.reward_goal_post_respawn = reward_goal_post_respawn
+ self.reward_lane_align = reward_lane_align
+ self.reward_vel_align = reward_vel_align
self.goal_radius = goal_radius
+ self.goal_speed = goal_speed
self.goal_behavior = goal_behavior
+ self.goal_target_distance = goal_target_distance
self.collision_behavior = collision_behavior
self.offroad_behavior = offroad_behavior
- self.reward_ade = reward_ade
self.human_agent_idx = human_agent_idx
self.scenario_length = scenario_length
+ self.termination_mode = termination_mode
self.resample_frequency = resample_frequency
self.ini_file = ini_file
+ self.use_all_maps = use_all_maps
+ self.map_seed = map_seed
+
+ if episode_length is not None:
+ self.scenario_length = episode_length
+ # Only set episode_length if not already set (adaptive.py sets it before calling super())
+ if not hasattr(self, "episode_length"):
+ self.episode_length = self.scenario_length
# Adaptive driving agent setup
self.adaptive_driving_agent = int(adaptive_driving_agent)
@@ -119,22 +172,100 @@ def __init__(
self.dynamics_model = dynamics_model
# Observation space calculation
- base_ego_dim = 10 if self.dynamics_model == "jerk" else 7
+ self.ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get(
+ dynamics_model
+ )
+
+ self.ego_features += conditioning_dims
+
+ # Extract observation shapes from constants
+ # These need to be defined in C, since they determine the shape of the arrays
+ self.max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS
+ self.max_partner_objects = binding.MAX_AGENTS - 1
+ self.partner_features = binding.PARTNER_FEATURES
+ self.road_features = binding.ROAD_FEATURES
- partner_features = 7
- road_features = 7
- max_partner_objects = 63
- max_road_objects = 200
self.num_obs = (
- base_ego_dim + conditioning_dims + max_partner_objects * partner_features + max_road_objects * road_features
+ self.ego_features
+ + self.max_partner_objects * self.partner_features
+ + self.max_road_objects * self.road_features
)
-
self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32)
# Co-player policy setup
self.population_play = co_player_enabled
self.num_agents = num_agents
self.num_ego_agents = num_ego_agents if self.population_play else num_agents
+ # When True, co-player actions are filled into self.actions[co_player_ids]
+ # by the *main* process (centralized GPU inference). Worker's step()
+ # then skips the local CPU forward in get_co_player_actions().
+ self.external_co_player_actions = bool(external_co_player_actions)
+ # When True (and adaptive_drive with k>1), at every scenario boundary
+ # we re-init the C envs with FRESH map_ids (and freshly sampled co-
+ # player conditioning). The agents are spawned on a brand-new map
+ # while the EGO POLICY's K/V cache (held in main / pufferl) is NOT
+ # touched — so past-scenario context is the only stable signal that
+ # carries across scenarios. This is the experimental setup that
+ # actually exercises in-context adaptation.
+ self.map_rand_per_scenario = bool(map_rand_per_scenario)
+ # When True (only meaningful for k_scenarios > 1, with co_player_enabled):
+ # at every scenario boundary within an episode, partner conditioning is
+ # re-sampled from the configured ranges. The same partner POLICY weights
+ # are used, but the partner's effective behavior changes per scenario
+ # because conditioning shifts. This gives the ego policy a meaningful
+ # latent variable (partner type) to encode in its K/V cache without
+ # introducing the agent-identity misalignment that map_rand causes.
+ # Independent of map_rand_per_scenario.
+ self.condition_rand_per_scenario = bool(condition_rand_per_scenario)
+ # When True, partner's entropy_weight_ub is annealed up over training:
+ # the user-passed co_player_entropy_weight_ub is treated as the FINAL
+ # value, and a 4-stage schedule scales it 0.05 → 0.20 → 0.50 → 1.0 of
+ # the final, advancing every 30 episodes per worker. The other
+ # conditioning dims (collision/offroad/discount) sample at full range
+ # throughout. Reason: we observed ada_delta peaking early in training
+ # then drifting toward 0 as scores saturate; the curriculum keeps the
+ # task in an informative-difficulty regime for longer.
+ self.entropy_curriculum_enabled = bool(entropy_curriculum_enabled)
+ # When resuming from a checkpoint of a curriculum run, the per-env
+ # episode counter isn't part of the model state — pass the original
+ # run's ending episode count here so the curriculum picks up at the
+ # right stage instead of restarting from stage 0. Each worker still
+ # advances its own counter from this starting value.
+ self._entropy_curriculum_episodes_seen = int(entropy_curriculum_episodes_start)
+ self._pending_entropy_log = None
+ self._entropy_curriculum_final_ub = None # set lazily once we know co_player_entropy_weight_ub
+ # When True, the ego's K/V cache is reset at SOME within-episode
+ # scenario boundaries based on the curriculum stage. K_max is the
+ # configured `k_scenarios`; the curriculum has 3 stages of
+ # `k_eff_curriculum_episodes_per_stage` episodes each:
+ # Stage 0: k_eff=1 (reset at every within-episode boundary)
+ # Stage 1: k_eff=2 (reset at boundaries where current_scenario%2==0)
+ # Stage 2: k_eff=k_max (no within-episode resets)
+ # Implementation: at the boundary that should reset, we set
+ # truncations[ego_ids]=1 and terminals[ego_ids]=1 for the current
+ # step. pufferl picks up done_mask=t+d to reset transformer_position
+ # at eval time; create_episode_mask uses terminals to block
+ # cross-boundary attention during training. K_max=4 with stages
+ # k_eff∈{1,2,4} gives clean splits (boundaries 1,2,3 → reset {all},
+ # {middle only}, {none}). For other K_max, only k_eff=1 and k_max
+ # produce uniform stages.
+ self.k_eff_curriculum_enabled = bool(k_eff_curriculum_enabled)
+ self.k_eff_curriculum_episodes_per_stage = int(k_eff_curriculum_episodes_per_stage)
+ self._k_eff_curriculum_episodes_seen = 0
+ self._pending_k_eff_log = None
+ # When True, _reinit_envs_with_new_maps() donates env[0]->client to a
+ # C-side global before vec_close and re-attaches it to the new env[0]
+ # afterwards. This keeps the raylib window + ffmpeg pipe alive across
+ # the swap (raylib's CloseWindow → InitWindow cycle segfaults under
+ # xvfb), so a single mp4 captures all k scenarios with the maps
+ # rotating mid-stream. Set True only on the single-env render driver.
+ self._render_keep_client_on_swap = False
+ self.worker_idx = int(worker_idx)
+ # SHM view (numpy) of the per-worker conditioning slice. Env writes
+ # sampled conditioning here so the main process can read it before
+ # running the centralized co-player forward. None when conditioning
+ # is disabled or when running the per-worker CPU path.
+ self.co_player_conditioning_shm = co_player_conditioning_shm
# Co-player conditioning setup
self.co_player_conditioning = co_player_policy.get("conditioning")
@@ -156,15 +287,64 @@ def __init__(
self.co_player_discount_weight_lb = self.co_player_conditioning.get("discount_weight_lb", 0.98)
self.co_player_discount_weight_ub = self.co_player_conditioning.get("discount_weight_ub", 0.98)
+ # ----- Ego oracle (NEW, isolated machinery) -----
+ # When True, the ego's obs gets the partner's per-env conditioning
+ # vector appended at the END (after road_obs). Implementation:
+ # 1. Allocate a private `_c_observations` buffer the C side writes
+ # into (sized to the C's expected obs_dim, no oracle slots).
+ # 2. The pufferl-facing `self.observations` buffer is sized
+ # bigger (`+ oracle_dims`); each step we copy the C buffer
+ # into the first part and write `_oracle_obs_per_env[env]`
+ # into the trailing oracle slots for every ego row.
+ # No changes to [env.conditioning], pufferl, or reward — the
+ # oracle slots are pure obs signal that only the policy reads.
+ self.ego_is_oracle = bool(ego_is_oracle)
+ self.reward_only_last_scenario = bool(reward_only_last_scenario)
+ if self.reward_only_last_scenario and not self.adaptive_driving_agent:
+ raise ValueError("reward_only_last_scenario=True requires adaptive_driving_agent=True (k_scenarios > 1).")
+ if self.ego_is_oracle:
+ # Determine the partner's conditioning dim count (== oracle width).
+ ct = self.co_player_condition_type
+ if ct is None or ct == "none":
+ raise ValueError(
+ "ego_is_oracle=True requires co-player conditioning to be "
+ "enabled (co_player_policy.conditioning.type != 'none')."
+ )
+ self._oracle_dims = (
+ (3 if self.co_player_reward_conditioned else 0)
+ + (1 if self.co_player_entropy_conditioned else 0)
+ + (1 if self.co_player_discount_conditioned else 0)
+ )
+ if self._oracle_dims == 0:
+ raise ValueError("ego_is_oracle=True but partner conditioning resolved to 0 dims.")
+ # Grow obs space so pufferl allocates a buffer wide enough to
+ # hold the appended oracle slots (placed AFTER road_obs, at
+ # offset = num_obs - oracle_dims). C still writes the smaller
+ # part into its own private `_c_observations` buffer.
+ self.num_obs += self._oracle_dims
+ self.single_observation_space = gymnasium.spaces.Box(
+ low=-1, high=1, shape=(self.num_obs,), dtype=np.float32
+ )
+ # Per-env oracle vector. Filled at every _set_co_player_conditioning
+ # call from a copy of `self.env_conditioning`. We don't allocate
+ # `_c_observations` or `_ego_env_indices` here — they need
+ # `num_envs`/`num_agents` which aren't known until later.
+ self._oracle_obs_per_env = None
+ self._c_observations = None
+ self._ego_env_indices = None
+ else:
+ self._oracle_dims = 0
+
self.init_steps = init_steps
self.init_mode_str = init_mode
self.control_mode_str = control_mode
+ self.map_dir = map_dir
if self.control_mode_str == "control_vehicles":
self.control_mode = 0
elif self.control_mode_str == "control_agents":
self.control_mode = 1
- elif self.control_mode_str == "control_tracks_to_predict":
+ elif self.control_mode_str == "control_wosac":
self.control_mode = 2
elif self.control_mode_str == "control_sdc_only":
self.control_mode = 3
@@ -188,7 +368,8 @@ def __init__(
# Multi discrete (assume independence)
# self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13])
elif dynamics_model == "jerk":
- self.single_action_space = gymnasium.spaces.MultiDiscrete([4, 3])
+ # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12
+ self.single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3])
else:
raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}")
elif action_type == "continuous":
@@ -198,18 +379,18 @@ def __init__(
self._action_type_flag = 0 if action_type == "discrete" else 1
- # Check if resources directory exists
- binary_path = "resources/drive/binaries/map_000.bin"
+ # Check that the map directory holds map binaries (probe map_001 since some datasets start at 001)
+ binary_path = f"{map_dir}/map_001.bin"
if not os.path.exists(binary_path):
raise FileNotFoundError(
- f"Required directory {binary_path} not found. Please ensure the Drive maps are downloaded and installed correctly per docs."
+ f"Required file {binary_path} not found. Please ensure the Drive maps are downloaded and installed correctly per docs."
)
# Check maps availability
- available_maps = len([name for name in os.listdir("resources/drive/binaries") if name.endswith(".bin")])
+ available_maps = len([name for name in os.listdir(map_dir) if name.endswith(".bin")])
if num_maps > available_maps:
raise ValueError(
- f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to resources/drive/binaries."
+ f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to {map_dir}."
)
if self.population_play:
if self.num_ego_agents > num_agents:
@@ -226,8 +407,18 @@ def __init__(
if self.population_play:
self.co_player_policy_name = co_player_policy.get("policy_name")
self.co_player_rnn_name = co_player_policy.get("rnn_name")
- self.co_player_policy = co_player_policy.get("co_player_policy_func")
- self._set_co_player_state()
+ if self.external_co_player_actions:
+ # Main owns the policy + state on GPU; worker only needs the
+ # action slots (co_player_ids) to be filled via shared memory
+ # before vec_step. Skip the per-worker CPU model entirely.
+ self.co_player_policy = None
+ self.co_player_device = None
+ else:
+ self.co_player_policy = co_player_policy.get("co_player_policy_func")
+ # Co-player runs in forked subprocess - must stay on CPU
+ # (CUDA doesn't work with fork)
+ self.co_player_device = torch.device("cpu")
+ self._set_co_player_state()
super().__init__(buf=buf)
if self.population_play:
@@ -238,12 +429,27 @@ def __init__(
else:
self.co_player_actions = np.zeros(co_player_atn_space.shape, dtype=np.int32)
+ # Allocate the private C-only obs buffer + ego→env index map (oracle
+ # path only). C writes into `_c_observations` (no oracle slots);
+ # we copy + append into `self.observations` (which has oracle slots)
+ # every step/reset. `_oracle_obs_per_env` was filled by
+ # `_set_co_player_conditioning` during the prior `_set_env_variables`
+ # call (which always fires here because oracle requires co-player
+ # conditioning to be on).
+ if self.ego_is_oracle:
+ self._c_obs_dim = self.num_obs - self._oracle_dims
+ self._c_observations = np.zeros((self.num_agents, self._c_obs_dim), dtype=np.float32)
+ self._rebuild_ego_env_indices()
+
env_ids = []
for i in range(self.num_envs):
cur = self.agent_offsets[i]
nxt = self.agent_offsets[i + 1]
+ # Oracle: hand C its own private obs slice (smaller, no oracle
+ # slots). Otherwise C uses the pufferl-provided buffer directly.
+ obs_slice_for_c = self._c_observations[cur:nxt] if self.ego_is_oracle else self.observations[cur:nxt]
env_id = binding.env_init(
- self.observations[cur:nxt],
+ obs_slice_for_c,
self.actions[cur:nxt],
self.rewards[cur:nxt],
self.terminals[cur:nxt],
@@ -256,13 +462,17 @@ def __init__(
reward_offroad_collision=reward_offroad_collision,
reward_goal=reward_goal,
reward_goal_post_respawn=reward_goal_post_respawn,
- reward_ade=reward_ade,
+ reward_lane_align=self.reward_lane_align,
+ reward_vel_align=self.reward_vel_align,
goal_radius=goal_radius,
+ goal_speed=goal_speed,
goal_behavior=self.goal_behavior,
+ goal_target_distance=self.goal_target_distance,
collision_behavior=self.collision_behavior,
offroad_behavior=self.offroad_behavior,
dt=dt,
- scenario_length=(int(scenario_length) if scenario_length is not None else None),
+ scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None),
+ termination_mode=(int(self.termination_mode) if self.termination_mode is not None else 0),
max_controlled_agents=self.max_controlled_agents,
map_id=self.map_ids[i],
max_agents=nxt - cur,
@@ -288,6 +498,8 @@ def __init__(
discount_weight_ub=self.discount_weight_ub,
init_mode=self.init_mode,
control_mode=self.control_mode,
+ map_dir=map_dir,
+ render_mode=self._render_mode_int,
)
env_ids.append(env_id)
@@ -295,15 +507,30 @@ def __init__(
def reset(self, seed=0):
binding.vec_reset(self.c_envs, seed)
+ # Oracle: copy C obs into pufferl buffer + write oracle slots.
+ self._refresh_ego_oracle_obs()
info = []
if self.population_play:
info.append(self.ego_ids)
- self._reset_co_player_state()
+ if self.external_co_player_actions:
+ # Pass the real co_player_ids so main does not have to
+ # guess by complement (which would include padding slots
+ # and pollute the shared KV cache).
+ info.append({"_external_co_player_ids": self.co_player_ids})
+ # Tell main to drop the K/V cache, mirroring the local
+ # `_reset_co_player_state()` call in the else branch. Initial
+ # conditioning was written to SHM in `_set_env_variables` and
+ # stays valid across the episode (the local path doesn't
+ # re-sample at reset either).
+ info.append({"_external_reset_co_cache": True})
+ else:
+ self._reset_co_player_state()
self.tick = 0
return self.observations, info
def _set_env_variables(self):
my_shared_tuple = binding.shared(
+ map_dir=self.map_dir,
num_agents=self.num_agents,
num_maps=self.num_maps,
init_mode=self.init_mode,
@@ -313,6 +540,9 @@ def _set_env_variables(self):
goal_behavior=self.goal_behavior,
population_play=self.population_play,
num_ego_agents=self.num_ego_agents,
+ goal_target_distance=self.goal_target_distance,
+ use_all_maps=self.use_all_maps,
+ map_seed=self.map_seed if self.map_seed is not None else -1,
)
if self.population_play:
@@ -377,7 +607,11 @@ def _set_env_variables(self):
self.agent_offsets, self.map_ids, self.num_envs = my_shared_tuple
self.ego_ids = [i for i in range(self.agent_offsets[-1])]
if len(self.ego_ids) != self.num_agents:
- raise ValueError("mismatch between number of ego agents and number of agents")
+ print(
+ f"Warning: requested {self.num_agents} agents but maps contain {len(self.ego_ids)} valid agents. Adjusting.",
+ flush=True,
+ )
+ self.num_agents = len(self.ego_ids)
self.local_co_player_ids = [[] for i in range(self.num_envs)]
self.local_ego_ids = [[0] for i in range(self.num_envs)]
@@ -388,30 +622,62 @@ def get_co_player_actions(self):
if self.co_player_condition_type != "none":
co_player_obs = self._add_co_player_conditioning(co_player_obs)
- co_player_obs = torch.as_tensor(co_player_obs)
+ # Create the tensor directly on the co-player device (CPU in forked workers)
+ co_player_obs = torch.as_tensor(co_player_obs, device=self.co_player_device)
+ import sys
+
+ sys.stdout.flush() # Prevent multiprocessing deadlock
logits, value = self.co_player_policy.forward_eval(co_player_obs, self.state)
- co_player_action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
+ # Handle multi-discrete actions (logits is a tuple) vs single discrete (logits is tensor)
+ if isinstance(logits, tuple):
+ co_player_action = torch.cat([head.argmax(dim=-1, keepdim=True) for head in logits], dim=-1)
+ else:
+ co_player_action = logits.argmax(dim=-1)
+ # Move actions to host memory as NumPy, shaped to match the co_player_actions buffer
co_player_action = co_player_action.cpu().numpy().reshape(self.co_player_actions.shape)
return co_player_action
def _set_co_player_state(self):
with torch.no_grad():
- self.state = dict(
- lstm_h=torch.zeros(self.num_co_players, self.co_player_policy.hidden_size),
- lstm_c=torch.zeros(self.num_co_players, self.co_player_policy.hidden_size),
- )
+ # Detect if co-player uses Transformer (has horizon) or LSTM
+ self.co_player_is_transformer = hasattr(self.co_player_policy, "horizon")
+
+ if self.co_player_is_transformer:
+ # Transformer co-player uses streaming KV cache (see
+ # TransformerWrapper.forward_eval). The cache is allocated
+ # lazily inside forward_eval the first time it sees a state
+ # without "k_cache". We initialize the position counter here
+ # so reset_eval_state has something to zero on full reset.
+ self.state = dict(
+ transformer_position=torch.zeros(1, dtype=torch.long, device=self.co_player_device),
+ )
+ else:
+ self.state = dict(
+ lstm_h=torch.zeros(
+ self.num_co_players, self.co_player_policy.hidden_size, device=self.co_player_device
+ ),
+ lstm_c=torch.zeros(
+ self.num_co_players, self.co_player_policy.hidden_size, device=self.co_player_device
+ ),
+ )
def _reset_co_player_state(self, done_indices=None):
- """Reset LSTM state for co-players whose episodes ended"""
+ """Reset LSTM/Transformer state for co-players whose episodes ended"""
with torch.no_grad():
if done_indices is None:
# Reset all
self._set_co_player_state()
else:
# Reset only specific co-players
- device = self.state["lstm_h"].device
- self.state["lstm_h"][done_indices] = 0
- self.state["lstm_c"][done_indices] = 0
+ if self.co_player_is_transformer:
+ # Re-prime the KV cache for the done rows so that subsequent
+ # forward_eval calls behave as if those rows had a fresh
+ # zero hidden buffer (matches the original semantics of
+ # `state["transformer_context"][done_indices] = 0`).
+ self.co_player_policy.reset_eval_state(self.state, done_indices=done_indices)
+ else:
+ self.state["lstm_h"][done_indices] = 0
+ self.state["lstm_c"][done_indices] = 0
def _add_co_player_conditioning(self, observations):
"""Add pre-sampled conditioning variables to co-player observations"""
@@ -426,10 +692,96 @@ def _add_co_player_conditioning(self, observations):
if observations.shape[0] != self.total_co_players:
raise ValueError(f"Expected {self.total_co_players} observations, got {observations.shape[0]}")
- return np.concatenate([observations[:, :7], self.cached_conditioning_array, observations[:, 7:]], axis=1)
+ # Use dynamic base_ego_dim based on dynamics model
+ base_ego_dim = binding.EGO_FEATURES_JERK if self.dynamics_model == "jerk" else binding.EGO_FEATURES_CLASSIC
+ return np.concatenate(
+ [observations[:, :base_ego_dim], self.cached_conditioning_array, observations[:, base_ego_dim:]], axis=1
+ )
+
+ def _rebuild_ego_env_indices(self):
+ """Recompute self._ego_env_indices: for the k-th ego in self.ego_ids
+ order, the env index it belongs to. Called at init and after every
+ _set_env_variables (since map re-roll may change per-env ego counts).
+ Oracle path only."""
+ if not self.ego_is_oracle:
+ return
+ ego_env_ids = []
+ if self.population_play:
+ for env_idx, env_egos in enumerate(self.local_ego_ids):
+ ego_env_ids.extend([env_idx] * len(env_egos))
+ else:
+ for env_idx in range(self.num_envs):
+ cur = int(self.agent_offsets[env_idx])
+ nxt = int(self.agent_offsets[env_idx + 1])
+ ego_env_ids.extend([env_idx] * (nxt - cur))
+ self._ego_env_indices = np.asarray(ego_env_ids, dtype=np.int64)
+
+ def _refresh_ego_oracle_obs(self):
+ """Copy C-side obs into the pufferl-facing buffer and write the
+ per-env partner-conditioning vector into the trailing oracle
+ slots for every ego row. No-op when oracle is off."""
+ if not self.ego_is_oracle:
+ return
+ c_dim = self._c_obs_dim
+ # Copy C output into the leading c_obs_dim columns. The two buffers
+ # are allocated separately and have different widths, so this is a
+ # sliced assignment rather than a whole-buffer copy.
+ self.observations[:, :c_dim] = self._c_observations
+ # Append partner conditioning to ego rows only. Non-ego rows keep
+ # whatever was there (zeros from allocation; pufferl filters out
+ # non-ego rows downstream anyway).
+ if len(self.ego_ids) > 0:
+ self.observations[self.ego_ids, c_dim:] = self._oracle_obs_per_env[self._ego_env_indices]
+
+ def _current_k_eff(self):
+ """Effective k for the ego's K/V cache horizon at the current
+ curriculum stage. Returns k_scenarios (i.e. K_max) when the
+ curriculum is disabled or has finished. Stages last
+ `k_eff_curriculum_episodes_per_stage` episodes each:
+ stage 0 → k_eff = 1
+ stage 1 → k_eff = 2
+ stage 2+ → k_eff = K_max
+ """
+ if not self.k_eff_curriculum_enabled:
+ return self.k_scenarios
+ n = self._k_eff_curriculum_episodes_seen
+ s = self.k_eff_curriculum_episodes_per_stage
+ if n < s:
+ return 1
+ elif n < 2 * s:
+ return 2
+ else:
+ return self.k_scenarios
+
+ def _k_eff_should_reset_at_current_boundary(self):
+ """True when the just-crossed scenario boundary should cut the ego
+ K/V cache under the current curriculum stage. Caller must already
+ have ensured this is a within-episode boundary (current_scenario != 0
+ after the modulo increment)."""
+ k_eff = self._current_k_eff()
+ return self.current_scenario % k_eff == 0
def _set_co_player_conditioning(self):
"""Sample and store conditioning values for each environment and update all caches"""
+ # Entropy curriculum: scale entropy_ub based on episodes seen so far.
+ # Schedule: stage 0 = 0.05*final, stage 1 = 0.20*final, stage 2 = 0.50*final,
+ # stage 3 = 1.00*final. Each stage = 30 episodes per worker (≈30 epochs
+ # given ~1 episode/epoch with our nw=32 nv=32 setup).
+ if self.entropy_curriculum_enabled and self.co_player_entropy_conditioned:
+ if self._entropy_curriculum_final_ub is None:
+ self._entropy_curriculum_final_ub = self.co_player_entropy_weight_ub
+ n = self._entropy_curriculum_episodes_seen
+ if n < 30:
+ ratio = 0.05
+ elif n < 60:
+ ratio = 0.20
+ elif n < 90:
+ ratio = 0.50
+ else:
+ ratio = 1.00
+ self.co_player_entropy_weight_ub = ratio * self._entropy_curriculum_final_ub
+ self._entropy_curriculum_episodes_seen += 1
+
# Update co-player counts and indices
self.num_co_players_per_env = np.array([len(ids) for ids in self.local_co_player_ids], dtype=np.int32)
self.total_co_players = self.num_co_players_per_env.sum()
@@ -477,6 +829,133 @@ def _set_co_player_conditioning(self):
else:
self.cached_conditioning_array = np.empty((0, len(conditioning_dims)), dtype=np.float32)
+ # Oracle: sync per-env oracle vector with the freshly sampled
+ # partner conditioning. num_envs can change across
+ # _reinit_envs_with_new_maps (some maps yield no valid agents and
+ # get dropped C-side), so just take a fresh copy with the current
+ # shape rather than a fixed-size in-place write. Width invariant
+ # is checked by the assertion below.
+ if self.ego_is_oracle:
+ assert self.env_conditioning.shape[1] == self._oracle_dims, (
+ f"oracle width mismatch: env_conditioning has "
+ f"{self.env_conditioning.shape[1]} dims, oracle expects {self._oracle_dims}"
+ )
+ self._oracle_obs_per_env = self.env_conditioning.copy()
+
+ # Stash sampled entropy stats for wandb. Index of the entropy column
+ # within env_conditioning depends on which dims are active above:
+ # reward(3) -> [collision, offroad, goal] then entropy then discount
+ if self.co_player_entropy_conditioned and self.env_conditioning.shape[1] > 0:
+ entropy_col = 3 if self.co_player_reward_conditioned else 0
+ sampled = self.env_conditioning[:, entropy_col]
+ self._pending_entropy_log = {
+ "co_player/entropy_weight_ub": float(self.co_player_entropy_weight_ub),
+ "co_player/entropy_sampled_mean": float(sampled.mean()),
+ "co_player/entropy_sampled_min": float(sampled.min()),
+ "co_player/entropy_sampled_max": float(sampled.max()),
+ }
+ if self.entropy_curriculum_enabled:
+ self._pending_entropy_log["co_player/entropy_curriculum_episodes"] = int(
+ self._entropy_curriculum_episodes_seen
+ )
+
+ # Mirror the freshly-sampled conditioning into the shared-memory
+ # buffer so the main process (centralized co-player on GPU) sees the
+ # latest values before its next forward pass. SHM rows beyond
+ # `total_co_players` are left at whatever value they previously held;
+ # main only reads rows for active co-players.
+ if (
+ self.co_player_conditioning_shm is not None
+ and self.cached_conditioning_array.shape[1] > 0
+ and self.total_co_players > 0
+ ):
+ shm = self.co_player_conditioning_shm
+ n = min(self.total_co_players, shm.shape[0])
+ shm[:n, :] = self.cached_conditioning_array[:n, :]
+
+ def _reinit_envs_with_new_maps(self):
+ """Close + recreate C envs with fresh map_ids.
+
+ Called at episode resample boundary and (when
+ `map_rand_per_scenario=True`) at scenario boundaries. Note: this
+ sets `self.terminals[:] = 1`, which pufferl uses to wipe the ego
+ K/V cache — so `map_rand_per_scenario=True` is currently broken
+ as an ICL probe (cache + GAE both truncate at the boundary).
+ """
+ if self._render_keep_client_on_swap:
+ binding.vec_donate_client(self.c_envs)
+ binding.vec_close(self.c_envs)
+ self._set_env_variables()
+ env_ids = []
+ seed = np.random.randint(0, 2**32 - 1)
+ for i in range(self.num_envs):
+ cur = self.agent_offsets[i]
+ nxt = self.agent_offsets[i + 1]
+ obs_slice_for_c = self._c_observations[cur:nxt] if self.ego_is_oracle else self.observations[cur:nxt]
+ env_id = binding.env_init(
+ obs_slice_for_c,
+ self.actions[cur:nxt],
+ self.rewards[cur:nxt],
+ self.terminals[cur:nxt],
+ self.truncations[cur:nxt],
+ seed,
+ action_type=self._action_type_flag,
+ human_agent_idx=self.human_agent_idx,
+ dynamics_model=self.dynamics_model,
+ reward_vehicle_collision=self.reward_vehicle_collision,
+ reward_offroad_collision=self.reward_offroad_collision,
+ goal_radius=self.goal_radius,
+ goal_behavior=self.goal_behavior,
+ collision_behavior=self.collision_behavior,
+ offroad_behavior=self.offroad_behavior,
+ reward_goal=self.reward_goal,
+ reward_goal_post_respawn=self.reward_goal_post_respawn,
+ reward_lane_align=self.reward_lane_align,
+ reward_vel_align=self.reward_vel_align,
+ goal_speed=self.goal_speed,
+ goal_target_distance=self.goal_target_distance,
+ dt=self.dt,
+ scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None),
+ max_controlled_agents=self.max_controlled_agents,
+ map_id=self.map_ids[i],
+ use_rc=self.reward_conditioned,
+ use_ec=self.entropy_conditioned,
+ use_dc=self.discount_conditioned,
+ collision_weight_lb=self.collision_weight_lb,
+ collision_weight_ub=self.collision_weight_ub,
+ offroad_weight_lb=self.offroad_weight_lb,
+ offroad_weight_ub=self.offroad_weight_ub,
+ goal_weight_lb=self.goal_weight_lb,
+ goal_weight_ub=self.goal_weight_ub,
+ entropy_weight_lb=self.entropy_weight_lb,
+ entropy_weight_ub=self.entropy_weight_ub,
+ discount_weight_lb=self.discount_weight_lb,
+ discount_weight_ub=self.discount_weight_ub,
+ max_agents=nxt - cur,
+ ini_file=self.ini_file,
+ population_play=self.population_play,
+ num_co_players=len(self.local_co_player_ids[i]),
+ co_player_ids=self.local_co_player_ids[i],
+ ego_agent_ids=self.local_ego_ids[i],
+ num_ego_agents=len(self.local_ego_ids[i]),
+ init_steps=self.init_steps,
+ init_mode=self.init_mode,
+ control_mode=self.control_mode,
+ map_dir=self.map_dir,
+ render_mode=self._render_mode_int,
+ )
+ env_ids.append(env_id)
+ self.c_envs = binding.vectorize(*env_ids)
+ if self._render_keep_client_on_swap:
+ binding.vec_adopt_client(self.c_envs)
+
+ binding.vec_reset(self.c_envs, seed)
+ # Oracle: per-env ego counts may have shifted with the new map IDs;
+ # rebuild the ego→env index map and refresh obs.
+ self._rebuild_ego_env_indices()
+ self._refresh_ego_oracle_obs()
+ self.terminals[:] = 1
+
def _aggregate_scenario_metrics(self, scenario_infos):
"""Aggregate metrics from all infos collected during a scenario."""
if not scenario_infos:
@@ -526,10 +1005,6 @@ def _compute_delta_metrics(self):
delta_key = f"ada_delta_{metric}"
delta_metrics[delta_key] = last_metrics[metric] - first_metrics[metric]
- # Add a count of how many agents this represents
- if "n" in last_metrics:
- delta_metrics["ada_agent_count"] = last_metrics["n"]
-
return delta_metrics
def step(self, actions):
@@ -537,29 +1012,44 @@ def step(self, actions):
self.actions[self.ego_ids] = actions
- if self.population_play:
+ if self.population_play and not self.external_co_player_actions:
co_player_actions = self.get_co_player_actions()
self.actions[self.co_player_ids] = co_player_actions
+ # When external_co_player_actions=True, the main process has already
+ # written co-player actions into self.actions[co_player_ids] via the
+ # shared-memory action buffer; nothing to do here.
binding.vec_step(self.c_envs)
+ if self.reward_only_last_scenario and self.current_scenario != self.k_scenarios - 1:
+ self.rewards[:] = 0
+ # Oracle: copy C obs into pufferl buffer + write oracle slots.
+ self._refresh_ego_oracle_obs()
self.tick += 1
info = []
if self.tick % self.report_interval == 0:
- log = binding.vec_log(self.c_envs)
+ log = binding.vec_log(self.c_envs, self.num_agents)
if log:
if self.adaptive_driving_agent:
self.current_scenario_infos.append(log)
-
- # Only append to info if we're in the 0th scenario
- if self.current_scenario == 0:
+ # For training: only report 0-shot (scenario 0) metrics
+ # For evaluation: report all scenarios when report_all_scenarios=True
+ if self.current_scenario == 0 or self.report_all_scenarios:
info.append(log)
- print("0th scenario metrics are ", log, flush=True)
else:
# Non-adaptive mode: always append
info.append(log)
- print("Regular metrics are ", log, flush=True)
+
+ # Surface the entropy bound + sampled distribution that
+ # _set_co_player_conditioning stashed at the most recent reset.
+ # Drained on emit so each fresh sampling gets logged exactly once.
+ if self._pending_entropy_log is not None:
+ info.append(self._pending_entropy_log)
+ self._pending_entropy_log = None
+ if self._pending_k_eff_log is not None:
+ info.append(self._pending_k_eff_log)
+ self._pending_k_eff_log = None
if self.tick % self.scenario_length == 0:
if self.adaptive_driving_agent and self.current_scenario_infos:
@@ -567,11 +1057,16 @@ def step(self, actions):
scenario_log["scenario_id"] = self.current_scenario
self.scenario_metrics.append(scenario_log)
+ # Log metrics for all scenarios with scenario-specific prefixes
+ prefixed_log = {
+ f"scenario_{self.current_scenario}_{k}": v for k, v in scenario_log.items() if k != "scenario_id"
+ }
+ info.append(prefixed_log)
+
if self.current_scenario == self.k_scenarios - 1:
delta_metrics = self._compute_delta_metrics()
if delta_metrics:
info.append(delta_metrics)
- print("delta metrics are ", delta_metrics, flush=True)
self.scenario_metrics = []
@@ -579,6 +1074,60 @@ def step(self, actions):
self.current_scenario = (self.current_scenario + 1) % self.k_scenarios
+ # Reset co-player LSTM/Transformer state at scenario boundary so
+ # the partner behaves consistently. The OFF (per-worker) path
+ # does not re-sample conditioning here — it sticks with the
+ # values written to SHM at init/resample — so neither do we.
+ if self.population_play:
+ if self.external_co_player_actions:
+ # Signal main to drop the K/V cache so scenario N+1
+ # starts fresh, mirroring the per-worker OFF path's
+ # `_reset_co_player_state()` here.
+ info.append({"_external_reset_co_cache": True})
+ else:
+ self._reset_co_player_state()
+
+ # MAP ROTATION per scenario: re-init the C envs with new map_ids.
+ # Intent: keep the EGO POLICY's K/V cache (held in main) so the policy
+ # must USE its past-scenario context in a genuinely new scene. Caveat:
+ # _reinit_envs_with_new_maps() sets terminals[:] = 1, which currently wipes that cache.
+ if (
+ self.adaptive_driving_agent
+ and self.map_rand_per_scenario
+ and self.current_scenario != 0 # we just incremented above; 0 means we already wrapped to next episode
+ ):
+ self._reinit_envs_with_new_maps()
+ # PARTNER CONDITIONING ROTATION per scenario: resample the partner's
+ # conditioning vector. The partner POLICY is unchanged but its
+ # effective behavior shifts (e.g. high-entropy stochastic vs
+ # low-entropy deterministic) — the ego must encode partner type
+ # from s_0 observations. Skipped when map_rand fired above
+ # (reinit already re-samples conditioning via _set_env_variables).
+ elif (
+ self.adaptive_driving_agent
+ and self.condition_rand_per_scenario
+ and self.population_play
+ and self.current_scenario != 0
+ and self.co_player_condition_type is not None
+ and self.co_player_condition_type != "none"
+ ):
+ self._set_co_player_conditioning()
+
+ # k_eff curriculum: at within-episode scenario boundaries, decide
+ # whether to cut the ego's K/V cache. Setting truncations=1 (and
+ # terminals=1) at this step makes pufferl drop the cache via
+ # done_mask=t+d, and during training, create_episode_mask blocks
+ # cross-boundary attention. current_scenario != 0 excludes the
+ # episode boundary itself (which is reset by the normal episode-
+ # done logic). Reset rule: cut if current_scenario % k_eff == 0.
+ if (
+ self.adaptive_driving_agent
+ and self.current_scenario != 0
+ and self._k_eff_should_reset_at_current_boundary()
+ ):
+ self.truncations[self.ego_ids] = 1
+ self.terminals[self.ego_ids] = 1
+
if self.tick > 0 and self.resample_frequency > 0 and self.tick % self.resample_frequency == 0:
self.tick = 0
will_resample = 1
@@ -588,73 +1137,26 @@ def step(self, actions):
delta_metrics = self._compute_delta_metrics()
if delta_metrics:
info.append(delta_metrics)
- print("delta metrics 2, are ", delta_metrics, flush=True)
self.scenario_metrics = []
self.current_scenario_infos = []
self.current_scenario = 0
- binding.vec_close(self.c_envs)
- self._set_env_variables()
- env_ids = []
- seed = np.random.randint(0, 2**32 - 1)
- for i in range(self.num_envs):
- cur = self.agent_offsets[i]
- nxt = self.agent_offsets[i + 1]
- env_id = binding.env_init(
- self.observations[cur:nxt],
- self.actions[cur:nxt],
- self.rewards[cur:nxt],
- self.terminals[cur:nxt],
- self.truncations[cur:nxt],
- seed,
- action_type=self._action_type_flag,
- human_agent_idx=self.human_agent_idx,
- dynamics_model=self.dynamics_model,
- reward_vehicle_collision=self.reward_vehicle_collision,
- reward_offroad_collision=self.reward_offroad_collision,
- reward_goal=self.reward_goal,
- reward_goal_post_respawn=self.reward_goal_post_respawn,
- reward_ade=self.reward_ade,
- goal_radius=self.goal_radius,
- goal_behavior=self.goal_behavior,
- collision_behavior=self.collision_behavior,
- offroad_behavior=self.offroad_behavior,
- dt=self.dt,
- scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None),
- max_controlled_agents=self.max_controlled_agents,
- map_id=self.map_ids[i],
- use_rc=self.reward_conditioned,
- use_ec=self.entropy_conditioned,
- use_dc=self.discount_conditioned,
- collision_weight_lb=self.collision_weight_lb,
- collision_weight_ub=self.collision_weight_ub,
- offroad_weight_lb=self.offroad_weight_lb,
- offroad_weight_ub=self.offroad_weight_ub,
- goal_weight_lb=self.goal_weight_lb,
- goal_weight_ub=self.goal_weight_ub,
- entropy_weight_lb=self.entropy_weight_lb,
- entropy_weight_ub=self.entropy_weight_ub,
- discount_weight_lb=self.discount_weight_lb,
- discount_weight_ub=self.discount_weight_ub,
- max_agents=nxt - cur,
- ini_file=self.ini_file,
- population_play=self.population_play,
- num_co_players=len(self.local_co_player_ids[i]),
- co_player_ids=self.local_co_player_ids[i],
- ego_agent_ids=self.local_ego_ids[i],
- num_ego_agents=len(self.local_ego_ids[i]),
- init_steps=self.init_steps,
- init_mode=self.init_mode,
- control_mode=self.control_mode,
- )
- env_ids.append(env_id)
- self.c_envs = binding.vectorize(*env_ids)
-
- binding.vec_reset(self.c_envs, seed)
- self.terminals[:] = 1
+ # Advance k_eff curriculum once per real episode and stash
+ # the new stage's k_eff for wandb. Done before reinit so the
+ # log reflects the stage that the next episode will run at.
+ if self.k_eff_curriculum_enabled:
+ self._k_eff_curriculum_episodes_seen += 1
+ self._pending_k_eff_log = {
+ "ego_curriculum/k_eff": int(self._current_k_eff()),
+ "ego_curriculum/episodes_seen": int(self._k_eff_curriculum_episodes_seen),
+ }
+
+ self._reinit_envs_with_new_maps()
if self.population_play:
info.append(self.ego_ids)
+ if self.external_co_player_actions:
+ info.append({"_external_co_player_ids": self.co_player_ids})
return (self.observations, self.rewards, self.terminals, self.truncations, info)
@@ -662,7 +1164,7 @@ def get_global_agent_state(self):
"""Get current global state of all active agents.
Returns:
- dict with keys 'x', 'y', 'z', 'heading', 'id' containing numpy arrays
+ dict with keys 'x', 'y', 'z', 'heading', 'id', 'length', 'width' containing numpy arrays
of shape (num_active_agents,)
"""
num_agents = self.num_agents
@@ -673,10 +1175,19 @@ def get_global_agent_state(self):
"z": np.zeros(num_agents, dtype=np.float32),
"heading": np.zeros(num_agents, dtype=np.float32),
"id": np.zeros(num_agents, dtype=np.int32),
+ "length": np.zeros(num_agents, dtype=np.float32),
+ "width": np.zeros(num_agents, dtype=np.float32),
}
binding.vec_get_global_agent_state(
- self.c_envs, states["x"], states["y"], states["z"], states["heading"], states["id"]
+ self.c_envs,
+ states["x"],
+ states["y"],
+ states["z"],
+ states["heading"],
+ states["id"],
+ states["length"],
+ states["width"],
)
return states
@@ -715,8 +1226,56 @@ def get_ground_truth_trajectories(self):
return trajectories
- def render(self):
- binding.vec_render(self.c_envs, 0)
+ def get_road_edge_polylines(self):
+ """Get road edge polylines for all scenarios.
+
+ Returns:
+ dict with keys 'x', 'y', 'lengths', 'scenario_id' containing numpy arrays.
+ x, y are flattened point coordinates; lengths indicates points per polyline.
+ """
+ num_polylines, total_points = binding.vec_get_road_edge_counts(self.c_envs)
+
+ polylines = {
+ "x": np.zeros(total_points, dtype=np.float32),
+ "y": np.zeros(total_points, dtype=np.float32),
+ "lengths": np.zeros(num_polylines, dtype=np.int32),
+ "scenario_id": np.zeros(num_polylines, dtype=np.int32),
+ }
+
+ binding.vec_get_road_edge_polylines(
+ self.c_envs,
+ polylines["x"],
+ polylines["y"],
+ polylines["lengths"],
+ polylines["scenario_id"],
+ )
+
+ return polylines
+
+ def render(self, view_mode: int = 0, draw_traces: bool = True, env_id: int = 0):
+ """Render the environment.
+
+ Args:
+ view_mode: View mode for rendering:
+ 0 = VIEW_MODE_SIM_STATE (top-down orthographic)
+ 1 = VIEW_MODE_BEV_AGENT_OBS (bird's eye view centered on agent)
+ 2 = VIEW_MODE_AGENT_PERSP (third-person chase camera)
+ draw_traces: Whether to draw trajectory traces
+ env_id: Which environment to render (default 0)
+ """
+ binding.vec_render(self.c_envs, int(view_mode), draw_traces, env_id, self.current_scenario, self.k_scenarios)
+
+ def set_video_suffix(self, suffix: str, env_id: int = 0):
+ """Set the suffix appended to the mp4 filename for headless rendering.
+
+ Must be called before the first render() call of a rollout.
+ E.g. set_video_suffix("_bev", env_id=0) -> {scenario_id}_bev.mp4
+
+ Args:
+ suffix: Suffix string to append to video filename
+ env_id: Which environment to set suffix for (default 0)
+ """
+ binding.vec_set_video_suffix(self.c_envs, env_id, suffix)
def close(self):
binding.vec_close(self.c_envs)
@@ -727,7 +1286,13 @@ def calculate_area(p1, p2, p3):
return 0.5 * abs((p1["x"] - p3["x"]) * (p2["y"] - p1["y"]) - (p1["x"] - p2["x"]) * (p3["y"] - p1["y"]))
-def simplify_polyline(geometry, polyline_reduction_threshold):
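+def dist(a, b):  # NOTE: squared Euclidean distance (no sqrt); simplify_polyline compares it to max_segment_length as-is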
+ dx = a["x"] - b["x"]
+ dy = a["y"] - b["y"]
+ return dx * dx + dy * dy
+
+
+def simplify_polyline(geometry, polyline_reduction_threshold, max_segment_length):
"""Simplify the given polyline using a method inspired by Visvalingham-Whyatt, optimized for Python."""
num_points = len(geometry)
if num_points < 3:
@@ -756,8 +1321,7 @@ def simplify_polyline(geometry, polyline_reduction_threshold):
point2 = geometry[k_1]
point3 = geometry[k_2]
area = calculate_area(point1, point2, point3)
-
- if area < polyline_reduction_threshold:
+ if area < polyline_reduction_threshold and dist(point1, point3) <= max_segment_length:
skip[k_1] = True
skip_changed = True
k = k_2
@@ -767,9 +1331,28 @@ def simplify_polyline(geometry, polyline_reduction_threshold):
return [geometry[i] for i in range(num_points) if not skip[i]]
-def save_map_binary(map_data, output_file, unique_map_id):
- trajectory_length = 91
- """Saves map data in a binary format readable by C"""
+def _to_int32(v, default=0):
+ """Wrap an arbitrary integer into the signed int32 range using two's-complement
+ semantics so struct.pack('i', ...) cannot overflow. nuPlan IDs and some type
+ fields can exceed 2^31-1; this preserves the low 32 bits the way C would."""
+ try:
+ v = int(v)
+ except (TypeError, ValueError):
+ return default
+ v &= 0xFFFFFFFF
+ if v >= 0x80000000:
+ v -= 0x100000000
+ return v
+
+
+def save_map_binary(map_data, output_file, unique_map_id, trajectory_length=91):
+ """Saves map data in a binary format readable by C.
+
+ `trajectory_length` is how many frames per object/road to write. The
+ C reader is parametric on the per-binary `array_size` header, so any
+ value works. Default 91 matches the legacy WOMD/short-window setup;
+ nuPlan scenes go up to 201 frames, so pass `trajectory_length=201`
+ to capture the full data."""
with open(output_file, "wb") as f:
# Get metadata
metadata = map_data.get("metadata", {})
@@ -777,27 +1360,25 @@ def save_map_binary(map_data, output_file, unique_map_id):
tracks_to_predict = metadata.get("tracks_to_predict", [])
# Write sdc_track_index
- f.write(struct.pack("i", sdc_track_index))
+ f.write(struct.pack("i", _to_int32(sdc_track_index, -1)))
# Write tracks_to_predict info (indices only)
- f.write(struct.pack("i", len(tracks_to_predict)))
+ f.write(struct.pack("i", _to_int32(len(tracks_to_predict))))
for track in tracks_to_predict:
track_index = track.get("track_index", -1)
- f.write(struct.pack("i", track_index))
+ f.write(struct.pack("i", _to_int32(track_index, -1)))
# Count total entities
- print(len(map_data.get("objects", [])))
- print(len(map_data.get("roads", [])))
num_objects = len(map_data.get("objects", []))
num_roads = len(map_data.get("roads", []))
# num_entities = num_objects + num_roads
- f.write(struct.pack("i", num_objects))
- f.write(struct.pack("i", num_roads))
+ f.write(struct.pack("i", _to_int32(num_objects)))
+ f.write(struct.pack("i", _to_int32(num_roads)))
# f.write(struct.pack('i', num_entities))
# Write objects
for obj in map_data.get("objects", []):
# Write unique map id
- f.write(struct.pack("i", unique_map_id))
+ f.write(struct.pack("i", _to_int32(unique_map_id)))
# Write base entity data
obj_type = obj.get("type", 1)
@@ -807,9 +1388,10 @@ def save_map_binary(map_data, output_file, unique_map_id):
obj_type = 2
elif obj_type == "cyclist":
obj_type = 3
- f.write(struct.pack("i", obj_type)) # type
- f.write(struct.pack("i", obj.get("id", 0))) # id
- f.write(struct.pack("i", trajectory_length)) # array_size
+ f.write(struct.pack("i", _to_int32(obj_type))) # type
+ obj_id = obj.get("id", 0)
+ f.write(struct.pack("i", _to_int32(obj_id))) # id
+ f.write(struct.pack("i", _to_int32(trajectory_length))) # array_size
# Write position arrays
positions = obj.get("position", [])
for i in range(trajectory_length):
@@ -842,7 +1424,7 @@ def save_map_binary(map_data, output_file, unique_map_id):
f.write(
struct.pack(
f"{trajectory_length}i",
- *[int(valids[i]) if i < len(valids) else 0 for i in range(trajectory_length)],
+ *[_to_int32(valids[i]) if i < len(valids) else 0 for i in range(trajectory_length)],
)
)
@@ -854,11 +1436,11 @@ def save_map_binary(map_data, output_file, unique_map_id):
f.write(struct.pack("f", float(goal_pos.get("x", 0.0)))) # Get x value
f.write(struct.pack("f", float(goal_pos.get("y", 0.0)))) # Get y value
f.write(struct.pack("f", float(goal_pos.get("z", 0.0)))) # Get z value
- f.write(struct.pack("i", obj.get("mark_as_expert", 0)))
+ f.write(struct.pack("i", _to_int32(obj.get("mark_as_expert", 0))))
# Write roads
for idx, road in enumerate(map_data.get("roads", [])):
- f.write(struct.pack("i", unique_map_id))
+ f.write(struct.pack("i", _to_int32(unique_map_id)))
geometry = road.get("geometry", [])
road_type = road.get("map_element_id", 0)
@@ -869,7 +1451,7 @@ def save_map_binary(map_data, output_file, unique_map_id):
road_type = 15
# breakpoint()
if len(geometry) > 10 and road_type <= 16:
- geometry = simplify_polyline(geometry, 0.1)
+ geometry = simplify_polyline(geometry, 0.1, 250)
size = len(geometry)
# breakpoint()
if road_type >= 0 and road_type <= 3:
@@ -887,9 +1469,10 @@ def save_map_binary(map_data, output_file, unique_map_id):
elif road_type == 20:
road_type = 10
# Write base entity data
- f.write(struct.pack("i", road_type)) # type
- f.write(struct.pack("i", road.get("id", 0))) # id
- f.write(struct.pack("i", size)) # array_size
+ f.write(struct.pack("i", _to_int32(road_type))) # type
+ road_id = road.get("id", 0)
+ f.write(struct.pack("i", _to_int32(road_id))) # id
+ f.write(struct.pack("i", _to_int32(size))) # array_size
# Write position arrays
for coord in ["x", "y", "z"]:
@@ -904,44 +1487,89 @@ def save_map_binary(map_data, output_file, unique_map_id):
f.write(struct.pack("f", float(goal_pos.get("x", 0.0)))) # Get x value
f.write(struct.pack("f", float(goal_pos.get("y", 0.0)))) # Get y value
f.write(struct.pack("f", float(goal_pos.get("z", 0.0)))) # Get z value
- f.write(struct.pack("i", road.get("mark_as_expert", 0)))
+ f.write(struct.pack("i", _to_int32(road.get("mark_as_expert", 0))))
-def load_map(map_name, unique_map_id, binary_output=None):
+def load_map(map_name, unique_map_id, binary_output=None, trajectory_length=91):
"""Loads a JSON map and optionally saves it as binary"""
with open(map_name, "r") as f:
map_data = json.load(f)
if binary_output:
- save_map_binary(map_data, binary_output, unique_map_id)
+ save_map_binary(map_data, binary_output, unique_map_id, trajectory_length=trajectory_length)
+
+
+def _process_single_map(args):
+ """Worker function to process a single map file"""
+ i, map_path, binary_path, trajectory_length = args
+ try:
+ load_map(str(map_path), i, str(binary_path), trajectory_length=trajectory_length)
+ return (i, map_path.name, True, None)
+ except Exception as e:
+ return (i, map_path.name, False, str(e))
+
+
+def process_all_maps(
+ data_folder="data/processed/training",
+ max_maps=50_000,
+ num_workers=None,
+ shuffle=False,
+ trajectory_length=91,
+ output_subdir=None,
+):
+ """Process all maps and save them as binaries using multiprocessing
+
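+ Args:
+ data_folder: Path to the folder containing JSON map files
+ max_maps: Maximum number of maps to process
+ num_workers: Number of parallel workers (defaults to cpu_count())
+ shuffle: If True, shuffle the JSON files before assigning map IDs.
+ This ensures that when max_maps < total, you get a random mix
+ of all source maps instead of the alphabetically first ones.
+ trajectory_length: Frames per object written to each binary
+ (forwarded to save_map_binary)
+ output_subdir: Subdirectory of resources/drive/binaries to write into
+ (defaults to the last path component of data_folder)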
+ """
+ from pathlib import Path
+ import random
+ if num_workers is None:
+ num_workers = cpu_count()
-def process_all_maps():
- """Process all maps and save them as binaries"""
- from pathlib import Path
+ # Path to the input data folder
+ data_dir = Path(data_folder)
+ dataset_name = output_subdir if output_subdir is not None else data_dir.name
# Create the binaries directory if it doesn't exist
- binary_dir = Path("resources/drive/binaries")
+ binary_dir = Path(f"resources/drive/binaries/{dataset_name}")
binary_dir.mkdir(parents=True, exist_ok=True)
- # Path to the training data
- data_dir = Path("data/processed/training")
-
# Get all JSON files in the training directory
json_files = sorted(data_dir.glob("*.json"))
- print(f"Found {len(json_files)} JSON files")
+ if shuffle:
+ json_files = list(json_files)
+ random.shuffle(json_files)
- # Process each JSON file
- for i, map_path in enumerate(json_files[:10000]):
- binary_file = f"map_{i:03d}.bin" # Use zero-padded numbers for consistent sorting
+ # Prepare arguments for parallel processing
+ tasks = []
+ for i, map_path in enumerate(json_files[:max_maps]):
+ binary_file = f"map_{i:03d}.bin"
binary_path = binary_dir / binary_file
+ tasks.append((i, map_path, binary_path, trajectory_length))
+
+ # Process maps in parallel with progress bar
+ with Pool(num_workers) as pool:
+ results = list(
+ tqdm(pool.imap(_process_single_map, tasks), total=len(tasks), desc="Processing maps", unit="map")
+ )
+
+ # Collect statistics
+ successful = sum(1 for _, _, success, _ in results if success)
+ failed = sum(1 for _, _, success, _ in results if not success)
- print(f"Processing {map_path.name} -> {binary_file}")
- # try:
- load_map(str(map_path), i, str(binary_path))
- # except Exception as e:
- # print(f"Error processing {map_path.name}: {e}")
+ if failed > 0:
+ print(f"\nFailed {failed}/{len(results)} files:")
+ for i, name, success, error in results:
+ if not success:
+ print(f" {name}: {error}")
def test_performance(timeout=10, atn_cache=1024, num_agents=1024):
@@ -976,4 +1604,10 @@ def test_performance(timeout=10, atn_cache=1024, num_agents=1024):
if __name__ == "__main__":
# test_performance()
- process_all_maps()
+ # Process the train dataset
+ # process_all_maps(data_folder="/data/processed/training")
+ process_all_maps(data_folder="/workspace/ADA/data/nuplan-gpudrive/nuplan")
+ # Process the validation/test dataset
+ # process_all_maps(data_folder="data/processed/validation")
+ # Process the validation_interactive dataset
+ # process_all_maps(data_folder="data/processed/validation_interactive")
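Reviewer note on the `_to_int32` helper introduced in the hunks above: the sketch below restates its wrapping logic as a standalone function so the behaviour can be checked in isolation. The names `to_int32`, `pack_i32`, and `big_id` are hypothetical and exist only for this illustration, and the explicit little-endian format `"<i"` is an assumption made here so the bytes are platform-independent (the patch itself uses native `"i"`).

```python
import struct


def to_int32(v, default=0):
    """Wrap an integer into the signed 32-bit range (mirrors the patch's `_to_int32`)."""
    try:
        v = int(v)
    except (TypeError, ValueError):
        return default
    v &= 0xFFFFFFFF        # keep only the low 32 bits
    if v >= 0x80000000:    # reinterpret the top bit as the sign bit
        v -= 0x100000000
    return v


def pack_i32(v):
    """Pack a wrapped value as a little-endian int32 (format chosen for this sketch only)."""
    return struct.pack("<i", to_int32(v))


big_id = 2**31 + 7         # hypothetical ID just past INT32_MAX
wrapped = to_int32(big_id)
packed = pack_i32(big_id)
roundtrip = struct.unpack("<i", packed)[0]

# Packing the raw value without wrapping raises struct.error.
try:
    struct.pack("<i", big_id)
    raw_pack_failed = False
except struct.error:
    raw_pack_failed = True
```

One consequence worth keeping in mind: wrapping is lossy, so two distinct source IDs that differ by a multiple of 2^32 map to the same stored value. If ID uniqueness matters downstream, that would need a separate check.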
diff --git a/pufferlib/ocean/drive/drivenet.h b/pufferlib/ocean/drive/drivenet.h
deleted file mode 100644
index 56dbda5212..0000000000
--- a/pufferlib/ocean/drive/drivenet.h
+++ /dev/null
@@ -1,256 +0,0 @@
-#include
-#include "drive.h"
-#include "puffernet.h"
-#include
-#include
-#include
-#include
-#include
-#include
-
-typedef struct DriveNet DriveNet;
-struct DriveNet {
- int num_agents;
- int conditioning_dims;
- int ego_dim;
- float* obs_self;
- float* obs_partner;
- float* obs_road;
- float* partner_linear_output;
- float* road_linear_output;
- float* partner_layernorm_output;
- float* road_layernorm_output;
- float* partner_linear_output_two;
- float* road_linear_output_two;
- Linear* ego_encoder;
- Linear* road_encoder;
- Linear* partner_encoder;
- LayerNorm* ego_layernorm;
- LayerNorm* road_layernorm;
- LayerNorm* partner_layernorm;
- Linear* ego_encoder_two;
- Linear* road_encoder_two;
- Linear* partner_encoder_two;
- MaxDim1* partner_max;
- MaxDim1* road_max;
- CatDim1* cat1;
- CatDim1* cat2;
- GELU* gelu;
- Linear* shared_embedding;
- ReLU* relu;
- LSTM* lstm;
- Linear* actor;
- Linear* value_fn;
- Multidiscrete* multidiscrete;
-};
-
-DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bool use_rc, bool use_ec, bool use_dc) {
- DriveNet* net = calloc(1, sizeof(DriveNet));
- int hidden_size = 256;
- int input_size = 64;
-
- int base_ego_dim = (dynamics_model == JERK) ? 10 : 7;
- net->conditioning_dims = (use_rc ? 3 : 0) + (use_ec ? 1 : 0) + (use_dc ? 1 : 0);
- net->ego_dim = base_ego_dim + net->conditioning_dims;
-
- // Determine action space size based on dynamics model
- int action_size, logit_sizes[2];
- int action_dim;
- if (dynamics_model == CLASSIC) {
- action_size = 7 * 13; // Joint action space
- logit_sizes[0] = 7 * 13;
- action_dim = 1;
- } else { // JERK
- action_size = 7; // 4 + 3
- logit_sizes[0] = 4;
- logit_sizes[1] = 3;
- action_dim = 2;
- }
-
- net->num_agents = num_agents;
-
- net->obs_self = calloc(num_agents*net->ego_dim, sizeof(float));
- net->obs_partner = calloc(num_agents*63*7, sizeof(float)); // 63 objects, 7 features
- net->obs_road = calloc(num_agents*200*13, sizeof(float)); // 200 objects, 13 features
-
- net->partner_linear_output = calloc(num_agents*63*input_size, sizeof(float));
- net->road_linear_output = calloc(num_agents*200*input_size, sizeof(float));
- net->partner_linear_output_two = calloc(num_agents*63*input_size, sizeof(float));
- net->road_linear_output_two = calloc(num_agents*200*input_size, sizeof(float));
- net->partner_layernorm_output = calloc(num_agents*63*input_size, sizeof(float));
- net->road_layernorm_output = calloc(num_agents*200*input_size, sizeof(float));
- net->ego_encoder = make_linear(weights, num_agents, net->ego_dim, input_size);
- net->ego_layernorm = make_layernorm(weights, num_agents, input_size);
- net->ego_encoder_two = make_linear(weights, num_agents, input_size, input_size);
- net->road_encoder = make_linear(weights, num_agents, 13, input_size);
- net->road_layernorm = make_layernorm(weights, num_agents, input_size);
- net->road_encoder_two = make_linear(weights, num_agents, input_size, input_size);
- net->partner_encoder = make_linear(weights, num_agents, 7, input_size);
- net->partner_layernorm = make_layernorm(weights, num_agents, input_size);
- net->partner_encoder_two = make_linear(weights, num_agents, input_size, input_size);
- net->partner_max = make_max_dim1(num_agents, 63, input_size);
- net->road_max = make_max_dim1(num_agents, 200, input_size);
- net->cat1 = make_cat_dim1(num_agents, input_size, input_size);
- net->cat2 = make_cat_dim1(num_agents, input_size + input_size, input_size);
- net->gelu = make_gelu(num_agents, 3*input_size);
- net->shared_embedding = make_linear(weights, num_agents, input_size*3, hidden_size);
- net->relu = make_relu(num_agents, hidden_size);
- net->actor = make_linear(weights, num_agents, hidden_size, action_size);
- net->value_fn = make_linear(weights, num_agents, hidden_size, 1);
- net->lstm = make_lstm(weights, num_agents, hidden_size, 256);
- memset(net->lstm->state_h, 0, num_agents*256*sizeof(float));
- memset(net->lstm->state_c, 0, num_agents*256*sizeof(float));
- net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, action_dim);
- return net;
-}
-
-void free_drivenet(DriveNet* net) {
- free(net->obs_self);
- free(net->obs_partner);
- free(net->obs_road);
- free(net->partner_linear_output);
- free(net->road_linear_output);
- free(net->partner_linear_output_two);
- free(net->road_linear_output_two);
- free(net->partner_layernorm_output);
- free(net->road_layernorm_output);
- free(net->ego_encoder);
- free(net->road_encoder);
- free(net->partner_encoder);
- free(net->ego_layernorm);
- free(net->road_layernorm);
- free(net->partner_layernorm);
- free(net->ego_encoder_two);
- free(net->road_encoder_two);
- free(net->partner_encoder_two);
- free(net->partner_max);
- free(net->road_max);
- free(net->cat1);
- free(net->cat2);
- free(net->gelu);
- free(net->shared_embedding);
- free(net->relu);
- free(net->multidiscrete);
- free(net->actor);
- free(net->value_fn);
- free(net->lstm);
- free(net);
-}
-
-void forward(DriveNet* net, float* observations, int* actions) {
- int ego_dim = net->ego_dim;
-
- // Clear previous observations
- memset(net->obs_self, 0, net->num_agents * ego_dim * sizeof(float));
- memset(net->obs_partner, 0, net->num_agents * 63 * 7 * sizeof(float));
- memset(net->obs_road, 0, net->num_agents * 200 * 13 * sizeof(float));
-
- // Reshape observations into 2D boards and additional features
- float* obs_self = net->obs_self;
- float (*obs_partner)[63][7] = (float (*)[63][7])net->obs_partner;
- float (*obs_road)[200][13] = (float (*)[200][13])net->obs_road;
-
- for (int b = 0; b < net->num_agents; b++) {
- int b_offset = b * (ego_dim + 63*7 + 200*7); // offset for each batch
- int partner_offset = b_offset + ego_dim;
- int road_offset = b_offset + ego_dim + 63*7;
- // Process self observation
- for(int i = 0; i < ego_dim; i++) {
- obs_self[b*ego_dim + i] = observations[b_offset + i];
- }
-
- // Process partner observation
- for(int i = 0; i < 63; i++) {
- for(int j = 0; j < 7; j++) {
- net->obs_partner[b*63*7 + i*7 + j] = observations[partner_offset + i*7 + j];
- }
- }
-
- // Process road observation
- for(int i = 0; i < 200; i++) {
- for(int j = 0; j < 7; j++) {
- net->obs_road[b*200*13 + i*13 + j] = observations[road_offset + i*7 + j];
- }
- for(int j = 0; j < 7; j++) {
- if(j == observations[road_offset+i*7 + 6]) {
- net->obs_road[b*200*13 + i*13 + 6 + j] = 1.0f;
- } else {
- net->obs_road[b*200*13 + i*13 + 6 + j] = 0.0f;
- }
- }
- }
- }
-
- // Forward pass through the network
- linear(net->ego_encoder, net->obs_self);
- layernorm(net->ego_layernorm, net->ego_encoder->output);
- linear(net->ego_encoder_two, net->ego_layernorm->output);
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 63; obj++) {
- // Get the 7 features for this object
- float* obj_features = &net->obs_partner[b*63*7 + obj*7];
- // Apply linear layer to this object
- _linear(obj_features, net->partner_encoder->weights, net->partner_encoder->bias,
- &net->partner_linear_output[b*63*64 + obj*64], 1, 7, 64);
- }
- }
-
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 63; obj++) {
- float* after_first = &net->partner_linear_output[b*63*64 + obj*64];
- _layernorm(after_first, net->partner_layernorm->weights, net->partner_layernorm->bias,
- &net->partner_layernorm_output[b*63*64 + obj*64], 1, 64);
- }
- }
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 63; obj++) {
- // Get the 7 features for this object
- float* obj_features = &net->partner_layernorm_output[b*63*64 + obj*64];
- // Apply linear layer to this object
- _linear(obj_features, net->partner_encoder_two->weights, net->partner_encoder_two->bias,
- &net->partner_linear_output_two[b*63*64 + obj*64], 1, 64, 64);
-
- }
- }
-
- // Process road objects: apply linear to each object individually
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 200; obj++) {
- // Get the 13 features for this object
- float* obj_features = &net->obs_road[b*200*13 + obj*13];
- // Apply linear layer to this object
- _linear(obj_features, net->road_encoder->weights, net->road_encoder->bias,
- &net->road_linear_output[b*200*64 + obj*64], 1, 13, 64);
- }
- }
-
- // Apply layer norm and second linear to each road object
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 200; obj++) {
- float* after_first = &net->road_linear_output[b*200*64 + obj*64];
- _layernorm(after_first, net->road_layernorm->weights, net->road_layernorm->bias,
- &net->road_layernorm_output[b*200*64 + obj*64], 1, 64);
- }
- }
- for (int b = 0; b < net->num_agents; b++) {
- for (int obj = 0; obj < 200; obj++) {
- float* after_first = &net->road_layernorm_output[b*200*64 + obj*64];
- _linear(after_first, net->road_encoder_two->weights, net->road_encoder_two->bias,
- &net->road_linear_output_two[b*200*64 + obj*64], 1, 64, 64);
- }
- }
-
- max_dim1(net->partner_max, net->partner_linear_output_two);
- max_dim1(net->road_max, net->road_linear_output_two);
- cat_dim1(net->cat1, net->ego_encoder_two->output, net->road_max->output);
- cat_dim1(net->cat2, net->cat1->output, net->partner_max->output);
- gelu(net->gelu, net->cat2->output);
- linear(net->shared_embedding, net->gelu->output);
- relu(net->relu, net->shared_embedding->output);
- lstm(net->lstm, net->relu->output);
- linear(net->actor, net->lstm->state_h);
- linear(net->value_fn, net->lstm->state_h);
-
- // Get action by taking argmax of actor output
- softmax_multidiscrete(net->multidiscrete, net->actor->output, actions);
-}
diff --git a/pufferlib/ocean/drive/error.h b/pufferlib/ocean/drive/error.h
index b1eb78e7ed..77ae171bb5 100644
--- a/pufferlib/ocean/drive/error.h
+++ b/pufferlib/ocean/drive/error.h
@@ -18,21 +18,29 @@ typedef enum {
ERROR_UNKNOWN
} ErrorType;
-const char* error_type_to_string(ErrorType type) {
+const char *error_type_to_string(ErrorType type) {
switch (type) {
- case ERROR_NONE: return "No Error";
- case ERROR_NULL_POINTER: return "Null Pointer";
- case ERROR_INVALID_ARGUMENT: return "Invalid Argument";
- case ERROR_OUT_OF_BOUNDS: return "Out of Bounds";
- case ERROR_MEMORY_ALLOCATION: return "Memory Allocation Failed";
- case ERROR_FILE_NOT_FOUND: return "File Not Found";
- case ERROR_INITIALIZATION_FAILED: return "Initialization Failed";
- default: return "Unknown Error";
+ case ERROR_NONE:
+ return "No Error";
+ case ERROR_NULL_POINTER:
+ return "Null Pointer";
+ case ERROR_INVALID_ARGUMENT:
+ return "Invalid Argument";
+ case ERROR_OUT_OF_BOUNDS:
+ return "Out of Bounds";
+ case ERROR_MEMORY_ALLOCATION:
+ return "Memory Allocation Failed";
+ case ERROR_FILE_NOT_FOUND:
+ return "File Not Found";
+ case ERROR_INITIALIZATION_FAILED:
+ return "Initialization Failed";
+ default:
+ return "Unknown Error";
}
}
// Enhanced error function with custom message support
-void raise_error_with_message(ErrorType type, const char* format, ...) {
+void raise_error_with_message(ErrorType type, const char *format, ...) {
printf("Error occurred: %s", error_type_to_string(type));
if (format != NULL) {
@@ -47,35 +55,28 @@ void raise_error_with_message(ErrorType type, const char* format, ...) {
}
// Simple error function (backward compatibility)
-void raise_error(ErrorType type) {
- raise_error_with_message(type, NULL);
-}
+void raise_error(ErrorType type) { raise_error_with_message(type, NULL); }
// Convenience macros for common error patterns
-#define RAISE_FILE_ERROR(path) \
- raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path)
+#define RAISE_FILE_ERROR(path) raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path)
-#define RAISE_BOUNDS_ERROR() \
- raise_error(ERROR_OUT_OF_BOUNDS)
+#define RAISE_BOUNDS_ERROR() raise_error(ERROR_OUT_OF_BOUNDS)
-#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \
+#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \
raise_error_with_message(ERROR_OUT_OF_BOUNDS, "index %d exceeds minimum of %d and maximum %d", index, min, max)
-#define RAISE_NULL_ERROR() \
- raise_error(ERROR_NULL_POINTER)
+#define RAISE_NULL_ERROR() raise_error(ERROR_NULL_POINTER)
-#define RAISE_NULL_ERROR_WITH_NAME(var_name) \
+#define RAISE_NULL_ERROR_WITH_NAME(var_name) \
raise_error_with_message(ERROR_NULL_POINTER, "variable '%s' is null", var_name)
-#define RAISE_MEMORY_ERROR() \
- raise_error(ERROR_MEMORY_ALLOCATION)
+#define RAISE_MEMORY_ERROR() raise_error(ERROR_MEMORY_ALLOCATION)
-#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \
+#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \
raise_error_with_message(ERROR_MEMORY_ALLOCATION, "failed to allocate %zu bytes", size)
-#define RAISE_INVALID_ARG_ERROR() \
- raise_error(ERROR_INVALID_ARGUMENT)
+#define RAISE_INVALID_ARG_ERROR() raise_error(ERROR_INVALID_ARGUMENT)
-#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \
+#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \
raise_error_with_message(ERROR_INVALID_ARGUMENT, "invalid value for '%s': %d", arg_name, value)
#endif
diff --git a/pufferlib/ocean/drive/rollout.py b/pufferlib/ocean/drive/rollout.py
new file mode 100644
index 0000000000..aa8d4babf4
--- /dev/null
+++ b/pufferlib/ocean/drive/rollout.py
@@ -0,0 +1,161 @@
+"""Shared rollout loop for Drive evaluation and rendering.
+
+Single source of truth for the forward-sample-step-break cycle. Used by:
+ - Training renders (periodic evaluation with video logging)
+ - Offline batch rendering
+ - Safe evaluation time rendering
+
+This module enables rendering with ANY policy architecture (LSTM, Transformer, etc.)
+by keeping policy inference in Python/PyTorch while using C bindings for environment
+simulation and graphics rendering.
+"""
+
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Optional
+
+import numpy as np
+import torch
+
+import pufferlib.pytorch
+
+
+class RenderView(IntEnum):
+ """View modes for rendering."""
+
+ FULL_SIM_STATE = 0 # Top-down orthographic view of full simulation
+ BEV_AGENT_OBS = 1 # Bird's eye view centered on agent observation
+ AGENT_PERSPECTIVE = 2 # Third-person chase camera following agent
+
+
+@dataclass
+class RenderContext:
+ """Enables rendering inside rollout_loop.
+
+ Attributes:
+ view_mode: RenderView enum value passed to driver.render().
+ env_id: which sub-env in the vecenv to record from (default 0).
+ draw_traces: whether to draw trajectory traces.
+ video_basename: full mp4 basename (without ".mp4"). Set once before
+ the first render via driver.set_video_suffix. Caller is
+ responsible for making this unique across renders.
+ """
+
+ view_mode: RenderView
+ env_id: int = 0
+ draw_traces: bool = True
+ video_basename: str = "render"
+
+
+def rollout_loop(
+ policy,
+ env,
+ device,
+ use_rnn: bool,
+ max_steps: Optional[int] = None,
+ render_ctx: Optional[RenderContext] = None,
+):
+ """Run a single policy rollout in a Drive vecenv.
+
+ This function handles policy inference in Python/PyTorch, making it work
+ with ANY model architecture (LSTM, Transformer, etc.).
+
+ Args:
+ policy: the policy to run. Caller is responsible for calling .eval().
+ env: a PufferEnv-compatible vecenv wrapping one or more Drive sub-envs.
+ device: torch device for observation / state tensors.
+ use_rnn: whether to allocate and carry LSTM hidden state.
+ max_steps: loop iteration cap. Defaults to env.driver_env.scenario_length (91 if absent).
+ render_ctx: if set, render the specified env/view every step before
+ sampling actions. Filename suffix is applied via set_video_suffix.
+
+ Returns:
+ The last info returned by env.step().
+ """
+ driver = env.driver_env
+
+ # Handle population play mode - only ego agents are controlled by the policy
+ population_play = getattr(driver, "population_play", False)
+ if population_play:
+ num_ego_agents = driver.num_ego_agents
+ ego_ids = driver.ego_ids
+ print(f"[rollout] Population play mode: {num_ego_agents} ego agents, {driver.num_co_players} co-players")
+ else:
+ num_ego_agents = env.observation_space.shape[0]
+ ego_ids = None
+
+ # Set full video basename before the first render call
+ if render_ctx is not None:
+ driver.set_video_suffix(render_ctx.video_basename, env_id=render_ctx.env_id)
+
+ obs, _ = env.reset()
+
+ # Initialize recurrent state based on policy type
+ # Note: state is for EGO agents only, not co-players
+ state = {}
+ if use_rnn:
+ # Check if this is a Transformer or LSTM policy
+ is_transformer = hasattr(policy, "transformer") or hasattr(policy, "horizon")
+
+ if is_transformer:
+ # Transformer handles its own state initialization in forward_eval
+ # when transformer_context is missing, so we just pass empty state
+ state = {}
+ else:
+ # LSTM policy - initialize h and c states for ego agents only
+ if hasattr(policy, "hidden_size"):
+ hidden_size = policy.hidden_size
+ else:
+ hidden_size = 128 # fallback when the policy does not expose hidden_size
+ state = dict(
+ lstm_h=torch.zeros(num_ego_agents, hidden_size, device=device),
+ lstm_c=torch.zeros(num_ego_agents, hidden_size, device=device),
+ )
+
+ if max_steps is None:
+ max_steps = getattr(driver, "scenario_length", 91)
+
+ info = []
+ for step in range(max_steps):
+ if step % 30 == 0:
+ print(f"[Python Render] Step {step}/{max_steps}", flush=True)
+ # Render BEFORE the step so each frame shows the state the policy was
+ # conditioned on.
+ if render_ctx is not None:
+ driver.render(
+ view_mode=render_ctx.view_mode,
+ draw_traces=render_ctx.draw_traces,
+ env_id=render_ctx.env_id,
+ )
+
+ with torch.no_grad():
+ # In population play, only pass ego observations to the policy
+ if population_play:
+ ego_obs = obs[ego_ids]
+ ob_t = torch.as_tensor(ego_obs).to(device)
+ else:
+ ob_t = torch.as_tensor(obs).to(device)
+
+ logits, _ = policy.forward_eval(ob_t, state)
+ action, _, _ = pufferlib.pytorch.sample_logits(logits)
+
+ # Reshape actions to match expected format
+ if population_play:
+ # In population play, policy outputs actions for ego agents only
+ # env.action_space.shape is (total_agents, action_dim), we need (num_ego_agents, action_dim)
+ action_dim = env.action_space.shape[1] if len(env.action_space.shape) > 1 else 1
+ action_np = action.cpu().numpy().reshape(num_ego_agents, action_dim)
+ else:
+ action_np = action.cpu().numpy().reshape(env.action_space.shape)
+
+ # Clip continuous actions to the valid range
+ if isinstance(logits, torch.distributions.Normal):
+ action_np = np.clip(action_np, env.action_space.low, env.action_space.high)
+
+ obs, _, _, truncs, info = env.step(action_np)
+
+ # Break when the episode ends (truncs is all True once the env auto-resets)
+ if truncs.all():
+ break
+
+ return info
diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c
deleted file mode 100644
index 172ddd1a8f..0000000000
--- a/pufferlib/ocean/drive/visualize.c
+++ /dev/null
@@ -1,556 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include "rlgl.h"
-#include
-#include
-#include
-#include
-#include "error.h"
-#include "drivenet.h"
-#include "libgen.h"
-#include "../env_config.h"
-#define TRAJECTORY_LENGTH_DEFAULT 91
-
-typedef struct {
- int pipefd[2];
- pid_t pid;
-} VideoRecorder;
-
-bool OpenVideo(VideoRecorder *recorder, const char *output_filename, int width, int height) {
- if (pipe(recorder->pipefd) == -1) {
- fprintf(stderr, "Failed to create pipe\n");
- return false;
- }
-
- recorder->pid = fork();
- if (recorder->pid == -1) {
- fprintf(stderr, "Failed to fork\n");
- return false;
- }
-
- char size_str[64];
- snprintf(size_str, sizeof(size_str), "%dx%d", width, height);
-
- if (recorder->pid == 0) { // Child process: run ffmpeg
- close(recorder->pipefd[1]);
- dup2(recorder->pipefd[0], STDIN_FILENO);
- close(recorder->pipefd[0]);
- // Close all other file descriptors to prevent leaks
- for (int fd = 3; fd < 256; fd++) {
- close(fd);
- }
- execlp("ffmpeg", "ffmpeg",
- "-y",
- "-f", "rawvideo",
- "-pix_fmt", "rgba",
- "-s", size_str,
- "-r", "30",
- "-i", "-",
- "-c:v", "libx264",
- "-pix_fmt", "yuv420p",
- "-preset", "ultrafast",
- "-crf", "23",
- "-loglevel", "error",
- output_filename,
- NULL);
- TraceLog(LOG_ERROR, "Failed to launch ffmpeg");
- return false;
- }
-
- close(recorder->pipefd[0]); // Close read end in parent
- return true;
-}
-
-void WriteFrame(VideoRecorder *recorder, int width, int height) {
- unsigned char *screen_data = rlReadScreenPixels(width, height);
- write(recorder->pipefd[1], screen_data, width * height * 4 * sizeof(*screen_data));
- RL_FREE(screen_data);
-}
-
-void CloseVideo(VideoRecorder *recorder) {
- close(recorder->pipefd[1]);
- waitpid(recorder->pid, NULL, 0);
-}
-
-void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid, int img_width, int img_height) {
-
- BeginDrawing();
-
- // Top-down orthographic camera
- Camera3D camera = {0};
- camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene
- camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin
- camera.up = (Vector3){ 0.0f, -1.0f, 0.0f };
- camera.fovy = map_height;
- camera.projection = CAMERA_ORTHOGRAPHIC;
-
- client->width = img_width;
- client->height = img_height;
-
- Color road = (Color){35, 35, 37, 255};
- ClearBackground(road);
- BeginMode3D(camera);
- rlEnableDepthTest();
-
- // Draw human replay trajectories if enabled
- if(log_trajectories){
- for(int i=0; i<env->active_agent_count; i++){
- int idx = env->active_agent_indices[i];
- Vector3 prev_point = {0};
- bool has_prev = false;
-
- for(int j = 0; j < env->entities[idx].array_size; j++){
- float x = env->entities[idx].traj_x[j];
- float y = env->entities[idx].traj_y[j];
- float valid = env->entities[idx].traj_valid[j];
-
- if(!valid) {
- has_prev = false;
- continue;
- }
-
- Vector3 curr_point = {x, y, 0.5f};
-
- if(has_prev) {
- DrawLine3D(prev_point, curr_point, Fade(LIGHTGREEN, 0.6f));
- }
-
- prev_point = curr_point;
- has_prev = true;
- }
- }
- }
-
- // Draw agent trajs
- if(trajectories){
- for(int i=0; iactive_agent_indices[env->human_agent_idx];
- Entity* agent = &env->entities[agent_idx];
-
- BeginDrawing();
-
- Camera3D camera = {0};
- // Position camera behind and above the agent
- camera.position = (Vector3){
- agent->x - (25.0f * cosf(agent->heading)),
- agent->y - (25.0f * sinf(agent->heading)),
- 15.0f
- };
- camera.target = (Vector3){
- agent->x + 40.0f * cosf(agent->heading),
- agent->y + 40.0f * sinf(agent->heading),
- 1.0f
- };
- camera.up = (Vector3){ 0.0f, 0.0f, 1.0f };
- camera.fovy = 45.0f;
- camera.projection = CAMERA_PERSPECTIVE;
-
- Color road = (Color){35, 35, 37, 255};
-
- ClearBackground(road);
- BeginMode3D(camera);
- rlEnableDepthTest();
- draw_scene(env, client, 0, obs_only, lasers, show_grid); // mode=0 for agent view
- EndMode3D();
- EndDrawing();
-}
-
-static int run_cmd(const char *cmd) {
- int rc = system(cmd);
- if (rc != 0) {
- fprintf(stderr, "[ffmpeg] command failed (%d): %s\n", rc, cmd);
- }
- return rc;
-}
-
-// Make a high-quality GIF from numbered PNG frames like frame_000.png
-static int make_gif_from_frames(const char *pattern, int fps,
- const char *palette_path,
- const char *out_gif) {
- char cmd[1024];
-
- // 1) Generate palette (no quotes needed for simple filter)
- // NOTE: if your frames start at 000, you don't need -start_number.
- snprintf(cmd, sizeof(cmd),
- "ffmpeg -y -framerate %d -i %s -vf palettegen %s",
- fps, pattern, palette_path);
- if (run_cmd(cmd) != 0) return -1;
-
- // 2) Use palette to encode the GIF
- snprintf(cmd, sizeof(cmd),
- "ffmpeg -y -framerate %d -i %s -i %s -lavfi paletteuse -loop 0 %s",
- fps, pattern, palette_path, out_gif);
- if (run_cmd(cmd) != 0) return -1;
-
- return 0;
-}
-
-int eval_gif(const char* map_name, const char* policy_name, int show_grid, int obs_only, int lasers, int log_trajectories, int frame_skip, float goal_radius, int init_steps, int use_rc, int use_ec, int use_dc, int max_controlled_agents, const char* view_mode, const char* output_topdown, const char* output_agent, int num_maps, int scenario_length_override, int init_mode, int control_mode, int goal_behavior) {
-
- // Parse configuration from INI file
- env_init_config conf = {0}; // Initialize to zero
- const char* ini_file = "pufferlib/config/ocean/drive.ini";
- if(ini_parse(ini_file, handler, &conf) < 0) {
- fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file);
- return -1;
- }
-
- char map_buffer[100];
- if (map_name == NULL) {
- srand(time(NULL));
- int random_map = rand() % num_maps;
- sprintf(map_buffer, "resources/drive/binaries/map_%03d.bin", random_map); // random map file
- map_name = map_buffer;
- }
-
- if (frame_skip <= 0) {
- frame_skip = 1; // Default: render every frame
- }
-
- // Check if map file exists
- FILE* map_file = fopen(map_name, "rb");
- if (map_file == NULL) {
- RAISE_FILE_ERROR(map_name);
- }
- fclose(map_file);
-
- FILE* policy_file = fopen(policy_name, "rb");
- if (policy_file == NULL) {
- RAISE_FILE_ERROR(policy_name);
- }
- fclose(policy_file);
-
- Drive env = {
- .dynamics_model = conf.dynamics_model,
- .reward_vehicle_collision = conf.reward_vehicle_collision,
- .reward_offroad_collision = conf.reward_offroad_collision,
- .reward_ade = conf.reward_ade,
- .goal_radius = goal_radius,
- .dt = conf.dt,
- .map_name = (char*)map_name,
- .init_steps = init_steps,
- .max_controlled_agents = max_controlled_agents,
- .collision_behavior = conf.collision_behavior,
- .offroad_behavior = conf.offroad_behavior,
- .goal_behavior = goal_behavior,
- .init_mode = init_mode,
- .control_mode = control_mode,
- .use_rc = use_rc,
- .use_ec = use_ec,
- .use_dc = use_dc,
- // Conditioning weight bounds (defaults from drive.py)
- .collision_weight_lb = -0.0f,
- .collision_weight_ub = -0.0f,
- .offroad_weight_lb = -0.0f,
- .offroad_weight_ub = -0.0f,
- .goal_weight_lb = 1.0f,
- .goal_weight_ub = 1.0f,
- .entropy_weight_lb = 0.001f,
- .entropy_weight_ub = 0.001f,
- .discount_weight_lb = 0.98f,
- .discount_weight_ub = 0.98f,
- };
-
- env.scenario_length = (scenario_length_override > 0) ? scenario_length_override :
- (conf.scenario_length > 0) ? conf.scenario_length : TRAJECTORY_LENGTH_DEFAULT;
- allocate(&env);
-
- // Set which vehicle to focus on for obs mode
- env.human_agent_idx = 0;
-
- c_reset(&env);
- // Make client for rendering
- Client* client = (Client*)calloc(1, sizeof(Client));
- env.client = client;
-
- SetConfigFlags(FLAG_WINDOW_HIDDEN);
-
- SetTargetFPS(6000);
-
- float map_width = env.grid_map->bottom_right_x - env.grid_map->top_left_x;
- float map_height = env.grid_map->top_left_y - env.grid_map->bottom_right_y;
-
- printf("Map size: %.1fx%.1f\n", map_width, map_height);
- float scale = 6.0f; // Can be used to increase the video quality
-
- // Calculate video width and height; round to nearest even number
- int img_width = (int)roundf(map_width * scale / 2.0f) * 2;
- int img_height = (int)roundf(map_height * scale / 2.0f) * 2;
- InitWindow(img_width, img_height, "Puffer Drive");
- SetConfigFlags(FLAG_MSAA_4X_HINT);
-
- Weights* weights = load_weights(policy_name);
- printf("Active agents in map: %d\n", env.active_agent_count);
- DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, use_rc, use_ec, use_dc);
-
- int frame_count = env.scenario_length > 0 ? env.scenario_length : TRAJECTORY_LENGTH_DEFAULT;
- int log_trajectory = log_trajectories;
- char filename_topdown[256];
- char filename_agent[256];
-
- if (output_topdown != NULL && output_agent != NULL) {
- strcpy(filename_topdown, output_topdown);
- strcpy(filename_agent, output_agent);
- } else {
- char policy_base[256];
- strcpy(policy_base, policy_name);
- *strrchr(policy_base, '.') = '\0';
-
- char map[256];
- strcpy(map, basename((char*)map_name));
- *strrchr(map, '.') = '\0';
-
- // Create video directory if it doesn't exist
- char video_dir[256];
- sprintf(video_dir, "%s/video", policy_base);
- char mkdir_cmd[512];
- snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p \"%s\"", video_dir);
- system(mkdir_cmd);
-
- sprintf(filename_topdown, "%s/video/%s_topdown.mp4", policy_base, map);
- sprintf(filename_agent, "%s/video/%s_agent.mp4", policy_base, map);
- }
-
- bool render_topdown = (strcmp(view_mode, "both") == 0 || strcmp(view_mode, "topdown") == 0);
- bool render_agent = (strcmp(view_mode, "both") == 0 || strcmp(view_mode, "agent") == 0);
-
- printf("Rendering: %s\n", view_mode);
-
- int rendered_frames = 0;
- double startTime = GetTime();
-
- VideoRecorder topdown_recorder, agent_recorder;
-
- if (render_topdown) {
- if (!OpenVideo(&topdown_recorder, filename_topdown, img_width, img_height)) {
- CloseWindow();
- return -1;
- }
- }
-
- if (render_agent) {
- if (!OpenVideo(&agent_recorder, filename_agent, img_width, img_height)) {
- if (render_topdown) CloseVideo(&topdown_recorder);
- CloseWindow();
- return -1;
- }
- }
-
- if (render_topdown) {
- printf("Recording topdown view...\n");
- for(int i = 0; i < frame_count; i++) {
- if (i % frame_skip == 0) {
- renderTopDownView(&env, client, map_height, 0, 0, 0, frame_count, NULL, log_trajectories, show_grid, img_width, img_height);
- WriteFrame(&topdown_recorder, img_width, img_height);
- rendered_frames++;
- }
- int (*actions)[2] = (int(*)[2])env.actions;
- forward(net, env.observations, (int*)env.actions);
- c_step(&env);
- }
-
- }
-
- if (render_agent) {
- c_reset(&env);
- printf("Recording agent view...\n");
- for(int i = 0; i < frame_count; i++) {
- if (i % frame_skip == 0) {
- renderAgentView(&env, client, map_height, obs_only, lasers, show_grid);
- WriteFrame(&agent_recorder, img_width, img_height);
- rendered_frames++;
- }
- int (*actions)[2] = (int(*)[2])env.actions;
- forward(net, env.observations, (int*)env.actions);
- c_step(&env);
- }
- }
-
- double endTime = GetTime();
- double elapsedTime = endTime - startTime;
- double writeFPS = (elapsedTime > 0) ? rendered_frames / elapsedTime : 0;
-
- printf("Wrote %d frames in %.2f seconds (%.2f FPS) to %s \n",
- rendered_frames, elapsedTime, writeFPS, filename_topdown);
-
- if (render_topdown) {
- CloseVideo(&topdown_recorder);
- }
- if (render_agent) {
- CloseVideo(&agent_recorder);
- }
- CloseWindow();
-
- // Clean up resources
- free(client);
- free_allocated(&env);
- free_drivenet(net);
- free(weights);
- return 0;
-}
-
-int main(int argc, char* argv[]) {
- int show_grid = 0;
- int obs_only = 0;
- int lasers = 0;
- int log_trajectories = 1;
- int frame_skip = 1;
- float goal_radius = 2.0f;
- int init_steps = 0;
- const char* map_name = NULL;
- const char* policy_name = "resources/drive/puffer_drive_weights.bin";
- int max_controlled_agents = -1;
- int num_maps = 1;
- int scenario_length_cli = -1;
- int use_rc = 0;
- int use_ec = 0;
- int use_dc = 0;
- int init_mode = 0;
- int control_mode = 0;
- int goal_behavior = 0;
-
- const char* view_mode = "both"; // "both", "topdown", "agent"
- const char* output_topdown = NULL;
- const char* output_agent = NULL;
-
- // Parse command line arguments
- for (int i = 1; i < argc; i++) {
- if (strcmp(argv[i], "--show-grid") == 0) {
- show_grid = 1;
- } else if (strcmp(argv[i], "--obs-only") == 0) {
- obs_only = 1;
- } else if (strcmp(argv[i], "--lasers") == 0) {
- lasers = 1;
- } else if (strcmp(argv[i], "--log-trajectories") == 0) {
- log_trajectories = 1;
- } else if (strcmp(argv[i], "--frame-skip") == 0) {
- if (i + 1 < argc) {
- frame_skip = atoi(argv[i + 1]);
- i++; // Skip the next argument since we consumed it
- if (frame_skip <= 0) {
- frame_skip = 1; // Ensure valid value
- }
- }
- } else if (strcmp(argv[i], "--goal-radius") == 0) {
- if (i + 1 < argc) {
- goal_radius = atof(argv[i + 1]);
- i++;
- if (goal_radius <= 0) {
- goal_radius = 2.0f; // Ensure valid value
- }
- }
- } else if (strcmp(argv[i], "--map-name") == 0) {
- // Check if there's a next argument for the map path
- if (i + 1 < argc) {
- map_name = argv[i + 1];
- i++; // Skip the next argument since we used it as map path
- } else {
- fprintf(stderr, "Error: --map-name option requires a map file path\n");
- return 1;
- }
- } else if (strcmp(argv[i], "--policy-name") == 0) {
- if (i + 1 < argc) {
- policy_name = argv[i + 1];
- i++;
- } else {
- fprintf(stderr, "Error: --policy-name option requires a policy file path\n");
- return 1;
- }
- } else if (strcmp(argv[i], "--view") == 0) {
- if (i + 1 < argc) {
- view_mode = argv[i + 1];
- i++;
- if (strcmp(view_mode, "both") != 0 &&
- strcmp(view_mode, "topdown") != 0 &&
- strcmp(view_mode, "agent") != 0) {
- fprintf(stderr, "Error: --view must be 'both', 'topdown', or 'agent'\n");
- return 1;
- }
- } else {
- fprintf(stderr, "Error: --view option requires a value (both/topdown/agent)\n");
- return 1;
- }
- } else if (strcmp(argv[i], "--output-topdown") == 0) {
- if (i + 1 < argc) {
- output_topdown = argv[i + 1];
- i++;
- }
- } else if (strcmp(argv[i], "--output-agent") == 0) {
- if (i + 1 < argc) {
- output_agent = argv[i + 1];
- i++;
- }
- } else if (strcmp(argv[i], "--init-steps") == 0) {
- if (i + 1 < argc) {
- init_steps = atoi(argv[i + 1]);
- i++;
- if (init_steps < 0) {
- init_steps = 0;
- }
- }
- } else if (strcmp(argv[i], "--init-mode") == 0) {
- if (i + 1 < argc) {
- init_mode = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--control-mode") == 0) {
- if (i + 1 < argc) {
- control_mode = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--max-controlled-agents") == 0) {
- if (i + 1 < argc) {
- max_controlled_agents = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--num-maps") == 0) {
- if (i + 1 < argc) {
- num_maps = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--scenario-length") == 0) {
- if (i + 1 < argc) {
- scenario_length_cli = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--use-rc") == 0) {
- if (i + 1 < argc) {
- use_rc = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--use-ec") == 0) {
- if (i + 1 < argc) {
- use_ec = atoi(argv[i + 1]);
- i++;
- }
- } else if (strcmp(argv[i], "--use-dc") == 0) {
- if (i + 1 < argc) {
- use_dc = atoi(argv[i + 1]);
- }
- } else if (strcmp(argv[i], "--goal-behavior") == 0) {
- if (i + 1 < argc) {
- goal_behavior = atoi(argv[i + 1]);
- i++;
- }
- }
- }
-
- eval_gif(map_name, policy_name, show_grid, obs_only, lasers, log_trajectories, frame_skip, goal_radius, init_steps, use_rc, use_ec, use_dc, max_controlled_agents, view_mode, output_topdown, output_agent, num_maps, scenario_length_cli, init_mode, control_mode, goal_behavior);
- return 0;
-}
diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h
index 6f7fa3c7df..343cc96584 100644
--- a/pufferlib/ocean/env_binding.h
+++ b/pufferlib/ocean/env_binding.h
@@ -3,42 +3,36 @@
#include
// Forward declarations for env-specific functions supplied by user
-static int my_log(PyObject* dict, Log* log);
-static int my_init(Env* env, PyObject* args, PyObject* kwargs);
+static int my_log(PyObject *dict, Log *log);
+static int my_init(Env *env, PyObject *args, PyObject *kwargs);
-static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs);
+static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs);
#ifndef MY_SHARED
-static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) {
- return NULL;
-}
+static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { return NULL; }
#endif
-static PyObject* my_get(PyObject* dict, Env* env);
+static PyObject *my_get(PyObject *dict, Env *env);
#ifndef MY_GET
-static PyObject* my_get(PyObject* dict, Env* env) {
- return NULL;
-}
+static PyObject *my_get(PyObject *dict, Env *env) { return NULL; }
#endif
-static int my_put(Env* env, PyObject* args, PyObject* kwargs);
+static int my_put(Env *env, PyObject *args, PyObject *kwargs);
#ifndef MY_PUT
-static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
- return 0;
-}
+static int my_put(Env *env, PyObject *args, PyObject *kwargs) { return 0; }
#endif
#ifndef MY_METHODS
#define MY_METHODS {NULL, NULL, 0, NULL}
#endif
-static Env* unpack_env(PyObject* args) {
- PyObject* handle_obj = PyTuple_GetItem(args, 0);
+static Env *unpack_env(PyObject *args) {
+ PyObject *handle_obj = PyTuple_GetItem(args, 0);
if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "env_handle must be an integer");
return NULL;
}
- Env* env = (Env*)PyLong_AsVoidPtr(handle_obj);
+ Env *env = (Env *)PyLong_AsVoidPtr(handle_obj);
if (!env) {
PyErr_SetString(PyExc_ValueError, "Invalid env handle");
return NULL;
@@ -48,36 +42,36 @@ static Env* unpack_env(PyObject* args) {
}
// Python function to initialize the environment
-static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyTuple_Size(args) != 6) {
- PyErr_SetString(PyExc_TypeError, "Environment requires 5 arguments");
+ PyErr_SetString(PyExc_TypeError, "Environment requires 6 arguments");
return NULL;
}
- Env* env = (Env*)calloc(1, sizeof(Env));
+ Env *env = (Env *)calloc(1, sizeof(Env));
if (!env) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment");
return NULL;
}
- PyObject* obs = PyTuple_GetItem(args, 0);
+ PyObject *obs = PyTuple_GetItem(args, 0);
if (!PyObject_TypeCheck(obs, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array");
return NULL;
}
- PyArrayObject* observations = (PyArrayObject*)obs;
+ PyArrayObject *observations = (PyArrayObject *)obs;
if (!PyArray_ISCONTIGUOUS(observations)) {
PyErr_SetString(PyExc_ValueError, "Observations must be contiguous");
return NULL;
}
env->observations = PyArray_DATA(observations);
- PyObject* act = PyTuple_GetItem(args, 1);
+ PyObject *act = PyTuple_GetItem(args, 1);
if (!PyObject_TypeCheck(act, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array");
return NULL;
}
- PyArrayObject* actions = (PyArrayObject*)act;
+ PyArrayObject *actions = (PyArrayObject *)act;
if (!PyArray_ISCONTIGUOUS(actions)) {
PyErr_SetString(PyExc_ValueError, "Actions must be contiguous");
return NULL;
@@ -88,12 +82,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
- PyObject* rew = PyTuple_GetItem(args, 2);
+ PyObject *rew = PyTuple_GetItem(args, 2);
if (!PyObject_TypeCheck(rew, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array");
return NULL;
}
- PyArrayObject* rewards = (PyArrayObject*)rew;
+ PyArrayObject *rewards = (PyArrayObject *)rew;
if (!PyArray_ISCONTIGUOUS(rewards)) {
PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous");
return NULL;
@@ -104,12 +98,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
env->rewards = PyArray_DATA(rewards);
- PyObject* term = PyTuple_GetItem(args, 3);
+ PyObject *term = PyTuple_GetItem(args, 3);
if (!PyObject_TypeCheck(term, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array");
return NULL;
}
- PyArrayObject* terminals = (PyArrayObject*)term;
+ PyArrayObject *terminals = (PyArrayObject *)term;
if (!PyArray_ISCONTIGUOUS(terminals)) {
PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous");
return NULL;
@@ -120,12 +114,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
env->terminals = PyArray_DATA(terminals);
- PyObject* trunc = PyTuple_GetItem(args, 4);
+ PyObject *trunc = PyTuple_GetItem(args, 4);
if (!PyObject_TypeCheck(trunc, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array");
return NULL;
}
- PyArrayObject* truncations = (PyArrayObject*)trunc;
+ PyArrayObject *truncations = (PyArrayObject *)trunc;
if (!PyArray_ISCONTIGUOUS(truncations)) {
PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous");
return NULL;
@@ -136,8 +130,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
// env->truncations = PyArray_DATA(truncations);
-
- PyObject* seed_arg = PyTuple_GetItem(args, 5);
+ PyObject *seed_arg = PyTuple_GetItem(args, 5);
if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "seed must be an integer");
return NULL;
@@ -151,11 +144,11 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
if (kwargs == NULL) {
kwargs = PyDict_New();
} else {
- Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it
+ Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it
}
// Add the seed to kwargs
- PyObject* py_seed = PyLong_FromLong(seed);
+ PyObject *py_seed = PyLong_FromLong(seed);
if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) {
PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs");
Py_DECREF(py_seed);
@@ -164,7 +157,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
Py_DECREF(py_seed);
- PyObject* empty_args = PyTuple_New(0);
+ PyObject *empty_args = PyTuple_New(0);
my_init(env, empty_args, kwargs);
Py_DECREF(kwargs);
if (PyErr_Occurred()) {
@@ -175,14 +168,14 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
// Python function to reset the environment
-static PyObject* env_reset(PyObject* self, PyObject* args) {
+static PyObject *env_reset(PyObject *self, PyObject *args) {
if (PyTuple_Size(args) != 2) {
PyErr_SetString(PyExc_TypeError, "env_reset requires 2 arguments");
return NULL;
}
- Env* env = unpack_env(args);
- if (!env){
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
c_reset(env);
@@ -190,15 +183,15 @@ static PyObject* env_reset(PyObject* self, PyObject* args) {
}
// Python function to step the environment
-static PyObject* env_step(PyObject* self, PyObject* args) {
+static PyObject *env_step(PyObject *self, PyObject *args) {
int num_args = PyTuple_Size(args);
if (num_args != 1) {
PyErr_SetString(PyExc_TypeError, "vec_render requires 1 argument");
return NULL;
}
- Env* env = unpack_env(args);
- if (!env){
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
c_step(env);
@@ -206,9 +199,9 @@ static PyObject* env_step(PyObject* self, PyObject* args) {
}
// Python function to step the environment
-static PyObject* env_render(PyObject* self, PyObject* args) {
- Env* env = unpack_env(args);
- if (!env){
+static PyObject *env_render(PyObject *self, PyObject *args) {
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
c_render(env);
@@ -216,9 +209,9 @@ static PyObject* env_render(PyObject* self, PyObject* args) {
}
// Python function to close the environment
-static PyObject* env_close(PyObject* self, PyObject* args) {
- Env* env = unpack_env(args);
- if (!env){
+static PyObject *env_close(PyObject *self, PyObject *args) {
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
c_close(env);
@@ -226,12 +219,12 @@ static PyObject* env_close(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
-static PyObject* env_get(PyObject* self, PyObject* args) {
- Env* env = unpack_env(args);
- if (!env){
+static PyObject *env_get(PyObject *self, PyObject *args) {
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
- PyObject* dict = PyDict_New();
+ PyObject *dict = PyDict_New();
my_get(dict, env);
if (PyErr_Occurred()) {
return NULL;
@@ -239,19 +232,19 @@ static PyObject* env_get(PyObject* self, PyObject* args) {
return dict;
}
-static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *env_put(PyObject *self, PyObject *args, PyObject *kwargs) {
int num_args = PyTuple_Size(args);
if (num_args != 1) {
PyErr_SetString(PyExc_TypeError, "env_put requires 1 positional argument");
return NULL;
}
- Env* env = unpack_env(args);
- if (!env){
+ Env *env = unpack_env(args);
+ if (!env) {
return NULL;
}
- PyObject* empty_args = PyTuple_New(0);
+ PyObject *empty_args = PyTuple_New(0);
my_put(env, empty_args, kwargs);
if (PyErr_Occurred()) {
return NULL;
@@ -261,18 +254,18 @@ static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) {
}
typedef struct {
- Env** envs;
+ Env **envs;
int num_envs;
} VecEnv;
-static VecEnv* unpack_vecenv(PyObject* args) {
- PyObject* handle_obj = PyTuple_GetItem(args, 0);
+static VecEnv *unpack_vecenv(PyObject *args) {
+ PyObject *handle_obj = PyTuple_GetItem(args, 0);
if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "env_handle must be an integer");
return NULL;
}
- VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(handle_obj);
+ VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(handle_obj);
if (!vec) {
PyErr_SetString(PyExc_ValueError, "Missing or invalid vec env handle");
return NULL;
@@ -286,18 +279,18 @@ static VecEnv* unpack_vecenv(PyObject* args) {
return vec;
}
-static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
+static PyObject *vec_init(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyTuple_Size(args) != 7) {
PyErr_SetString(PyExc_TypeError, "vec_init requires 6 arguments");
return NULL;
}
- VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv));
+ VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv));
if (!vec) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env");
return NULL;
}
- PyObject* num_envs_arg = PyTuple_GetItem(args, 5);
+ PyObject *num_envs_arg = PyTuple_GetItem(args, 5);
if (!PyObject_TypeCheck(num_envs_arg, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "num_envs must be an integer");
return NULL;
@@ -308,25 +301,25 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
vec->num_envs = num_envs;
- vec->envs = (Env**)calloc(num_envs, sizeof(Env*));
+ vec->envs = (Env **)calloc(num_envs, sizeof(Env *));
if (!vec->envs) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env");
return NULL;
}
- PyObject* seed_obj = PyTuple_GetItem(args, 6);
+ PyObject *seed_obj = PyTuple_GetItem(args, 6);
if (!PyObject_TypeCheck(seed_obj, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "seed must be an integer");
return NULL;
}
int seed = PyLong_AsLong(seed_obj);
- PyObject* obs = PyTuple_GetItem(args, 0);
+ PyObject *obs = PyTuple_GetItem(args, 0);
if (!PyObject_TypeCheck(obs, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array");
return NULL;
}
- PyArrayObject* observations = (PyArrayObject*)obs;
+ PyArrayObject *observations = (PyArrayObject *)obs;
if (!PyArray_ISCONTIGUOUS(observations)) {
PyErr_SetString(PyExc_ValueError, "Observations must be contiguous");
return NULL;
@@ -336,12 +329,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
- PyObject* act = PyTuple_GetItem(args, 1);
+ PyObject *act = PyTuple_GetItem(args, 1);
if (!PyObject_TypeCheck(act, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array");
return NULL;
}
- PyArrayObject* actions = (PyArrayObject*)act;
+ PyArrayObject *actions = (PyArrayObject *)act;
if (!PyArray_ISCONTIGUOUS(actions)) {
PyErr_SetString(PyExc_ValueError, "Actions must be contiguous");
return NULL;
@@ -351,12 +344,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
- PyObject* rew = PyTuple_GetItem(args, 2);
+ PyObject *rew = PyTuple_GetItem(args, 2);
if (!PyObject_TypeCheck(rew, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array");
return NULL;
}
- PyArrayObject* rewards = (PyArrayObject*)rew;
+ PyArrayObject *rewards = (PyArrayObject *)rew;
if (!PyArray_ISCONTIGUOUS(rewards)) {
PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous");
return NULL;
@@ -366,12 +359,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
- PyObject* term = PyTuple_GetItem(args, 3);
+ PyObject *term = PyTuple_GetItem(args, 3);
if (!PyObject_TypeCheck(term, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array");
return NULL;
}
- PyArrayObject* terminals = (PyArrayObject*)term;
+ PyArrayObject *terminals = (PyArrayObject *)term;
if (!PyArray_ISCONTIGUOUS(terminals)) {
PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous");
return NULL;
@@ -381,12 +374,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return NULL;
}
- PyObject* trunc = PyTuple_GetItem(args, 4);
+ PyObject *trunc = PyTuple_GetItem(args, 4);
if (!PyObject_TypeCheck(trunc, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array");
return NULL;
}
- PyArrayObject* truncations = (PyArrayObject*)trunc;
+ PyArrayObject *truncations = (PyArrayObject *)trunc;
if (!PyArray_ISCONTIGUOUS(truncations)) {
PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous");
return NULL;
@@ -400,11 +393,11 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
if (kwargs == NULL) {
kwargs = PyDict_New();
} else {
- Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it
+ Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it
}
for (int i = 0; i < num_envs; i++) {
- Env* env = (Env*)calloc(1, sizeof(Env));
+ Env *env = (Env *)calloc(1, sizeof(Env));
if (!env) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment");
Py_DECREF(kwargs);
@@ -415,18 +408,18 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
// // Make sure the log is initialized to 0
memset(&env->log, 0, sizeof(Log));
- env->observations = (void*)((char*)PyArray_DATA(observations) + i*PyArray_STRIDE(observations, 0));
- env->actions = (void*)((char*)PyArray_DATA(actions) + i*PyArray_STRIDE(actions, 0));
- env->rewards = (void*)((char*)PyArray_DATA(rewards) + i*PyArray_STRIDE(rewards, 0));
- env->terminals = (void*)((char*)PyArray_DATA(terminals) + i*PyArray_STRIDE(terminals, 0));
+ env->observations = (void *)((char *)PyArray_DATA(observations) + i * PyArray_STRIDE(observations, 0));
+ env->actions = (void *)((char *)PyArray_DATA(actions) + i * PyArray_STRIDE(actions, 0));
+ env->rewards = (void *)((char *)PyArray_DATA(rewards) + i * PyArray_STRIDE(rewards, 0));
+ env->terminals = (void *)((char *)PyArray_DATA(terminals) + i * PyArray_STRIDE(terminals, 0));
// env->truncations = (void*)((char*)PyArray_DATA(truncations) + i*PyArray_STRIDE(truncations, 0));
// Assumes each process has the same number of environments
- int env_seed = i + seed*vec->num_envs;
+ int env_seed = i + seed * vec->num_envs;
srand(env_seed);
// Add the seed to kwargs for this environment
- PyObject* py_seed = PyLong_FromLong(env_seed);
+ PyObject *py_seed = PyLong_FromLong(env_seed);
if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) {
PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs");
Py_DECREF(py_seed);
@@ -435,7 +428,7 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
}
Py_DECREF(py_seed);
- PyObject* empty_args = PyTuple_New(0);
+ PyObject *empty_args = PyTuple_New(0);
my_init(env, empty_args, kwargs);
if (PyErr_Occurred()) {
return NULL;
@@ -446,22 +439,21 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) {
return PyLong_FromVoidPtr(vec);
}
-
// Python function to close the environment
-static PyObject* vectorize(PyObject* self, PyObject* args) {
+static PyObject *vectorize(PyObject *self, PyObject *args) {
int num_envs = PyTuple_Size(args);
if (num_envs == 0) {
PyErr_SetString(PyExc_TypeError, "make_vec requires at least 1 env id");
return NULL;
}
- VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv));
+ VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv));
if (!vec) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env");
return NULL;
}
- vec->envs = (Env**)calloc(num_envs, sizeof(Env*));
+ vec->envs = (Env **)calloc(num_envs, sizeof(Env *));
if (!vec->envs) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env");
return NULL;
@@ -469,29 +461,30 @@ static PyObject* vectorize(PyObject* self, PyObject* args) {
vec->num_envs = num_envs;
for (int i = 0; i < num_envs; i++) {
- PyObject* handle_obj = PyTuple_GetItem(args, i);
+ PyObject *handle_obj = PyTuple_GetItem(args, i);
if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) {
- PyErr_SetString(PyExc_TypeError, "Env ids must be integers. Pass them as separate args with *env_ids, not as a list.");
+ PyErr_SetString(PyExc_TypeError,
+ "Env ids must be integers. Pass them as separate args with *env_ids, not as a list.");
return NULL;
}
- vec->envs[i] = (Env*)PyLong_AsVoidPtr(handle_obj);
+ vec->envs[i] = (Env *)PyLong_AsVoidPtr(handle_obj);
}
return PyLong_FromVoidPtr(vec);
}
-static PyObject* vec_reset(PyObject* self, PyObject* args) {
+static PyObject *vec_reset(PyObject *self, PyObject *args) {
if (PyTuple_Size(args) != 2) {
PyErr_SetString(PyExc_TypeError, "vec_reset requires 2 arguments");
return NULL;
}
- VecEnv* vec = unpack_vecenv(args);
+ VecEnv *vec = unpack_vecenv(args);
if (!vec) {
return NULL;
}
- PyObject* seed_arg = PyTuple_GetItem(args, 1);
+ PyObject *seed_arg = PyTuple_GetItem(args, 1);
if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) {
PyErr_SetString(PyExc_TypeError, "seed must be an integer");
return NULL;
@@ -500,20 +493,20 @@ static PyObject* vec_reset(PyObject* self, PyObject* args) {
for (int i = 0; i < vec->num_envs; i++) {
// Assumes each process has the same number of environments
- srand(i + seed*vec->num_envs);
+ srand(i + seed * vec->num_envs);
c_reset(vec->envs[i]);
}
Py_RETURN_NONE;
}
-static PyObject* vec_step(PyObject* self, PyObject* arg) {
+static PyObject *vec_step(PyObject *self, PyObject *arg) {
int num_args = PyTuple_Size(arg);
if (num_args != 1) {
PyErr_SetString(PyExc_TypeError, "vec_step requires 1 argument");
return NULL;
}
- VecEnv* vec = unpack_vecenv(arg);
+ VecEnv *vec = unpack_vecenv(arg);
if (!vec) {
return NULL;
}
@@ -524,37 +517,72 @@ static PyObject* vec_step(PyObject* self, PyObject* arg) {
Py_RETURN_NONE;
}
-static PyObject* vec_render(PyObject* self, PyObject* args) {
+static PyObject *vec_render(PyObject *self, PyObject *args) {
int num_args = PyTuple_Size(args);
- if (num_args != 2) {
- PyErr_SetString(PyExc_TypeError, "vec_render requires 2 arguments");
+ if (num_args != 6) {
+ PyErr_SetString(PyExc_TypeError, "vec_render requires 6 arguments: (vec_env, view_mode, draw_traces, env_id, "
+ "current_scenario, k_scenarios)");
return NULL;
}
- VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0));
+ VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0));
if (!vec) {
PyErr_SetString(PyExc_ValueError, "Invalid vec_env handle");
return NULL;
}
- PyObject* env_id_arg = PyTuple_GetItem(args, 1);
- if (!PyObject_TypeCheck(env_id_arg, &PyLong_Type)) {
- PyErr_SetString(PyExc_TypeError, "env_id must be an integer");
+ int view_mode = (int)PyLong_AsLong(PyTuple_GetItem(args, 1));
+ int draw_traces = PyObject_IsTrue(PyTuple_GetItem(args, 2));
+ int env_id = (int)PyLong_AsLong(PyTuple_GetItem(args, 3));
+ int current_scenario = (int)PyLong_AsLong(PyTuple_GetItem(args, 4));
+ int k_scenarios = (int)PyLong_AsLong(PyTuple_GetItem(args, 5));
+ if (PyErr_Occurred()) { return NULL; } // an argument failed int/bool conversion
+ if (env_id < 0 || env_id >= vec->num_envs) {
+ PyErr_SetString(PyExc_ValueError, "env_id out of range");
return NULL;
}
- int env_id = PyLong_AsLong(env_id_arg);
- c_render(vec->envs[env_id]);
+ c_render_with_mode(vec->envs[env_id], view_mode, draw_traces, current_scenario, k_scenarios);
Py_RETURN_NONE;
}
-static int assign_to_dict(PyObject* dict, char* key, float value) {
- PyObject* v = PyFloat_FromDouble(value);
+static PyObject *vec_set_video_suffix(PyObject *self, PyObject *args) {
+ int num_args = PyTuple_Size(args);
+ if (num_args != 3) {
+ PyErr_SetString(PyExc_TypeError, "vec_set_video_suffix requires 3 arguments: (vec_env, env_id, suffix)");
+ return NULL;
+ }
+
+ VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0));
+ if (!vec) {
+ PyErr_SetString(PyExc_ValueError, "Invalid vec_env handle");
+ return NULL;
+ }
+
+ int env_id = (int)PyLong_AsLong(PyTuple_GetItem(args, 1));
+ if (env_id < 0 || env_id >= vec->num_envs) {
+ PyErr_SetString(PyExc_ValueError, "env_id out of range");
+ return NULL;
+ }
+
+ PyObject *suffix_obj = PyTuple_GetItem(args, 2);
+ const char *suffix = PyUnicode_AsUTF8(suffix_obj);
+ if (!suffix) {
+ PyErr_SetString(PyExc_TypeError, "suffix must be a string");
+ return NULL;
+ }
+
+ set_video_suffix(vec->envs[env_id], suffix);
+ Py_RETURN_NONE;
+}
+
+static int assign_to_dict(PyObject *dict, char *key, float value) {
+ PyObject *v = PyFloat_FromDouble(value);
if (v == NULL) {
PyErr_SetString(PyExc_TypeError, "Failed to convert log value");
return 1;
}
- if(PyDict_SetItemString(dict, key, v) < 0) {
+ if (PyDict_SetItemString(dict, key, v) < 0) {
PyErr_SetString(PyExc_TypeError, "Failed to set log value");
return 1;
}
@@ -562,99 +590,125 @@ static int assign_to_dict(PyObject* dict, char* key, float value) {
return 0;
}
-static PyObject* vec_log(PyObject* self, PyObject* args) {
- VecEnv* vec = unpack_vecenv(args);
+static PyObject *vec_log(PyObject *self, PyObject *args) {
+ if (PyTuple_Size(args) != 2) {
+ PyErr_SetString(PyExc_TypeError, "vec_log requires 2 arguments");
+ return NULL;
+ }
+ VecEnv *vec = unpack_vecenv(args);
if (!vec) {
return NULL;
}
+ PyObject *num_agents_arg = PyTuple_GetItem(args, 1);
+ float num_agents = (float)PyLong_AsLong(num_agents_arg);
+ if (PyErr_Occurred()) { return NULL; } // num_agents was not an integer
// Iterates over logs one float at a time. Will break
// horribly if Log has non-float data.
Log aggregate = {0};
int num_keys = sizeof(Log) / sizeof(float);
- // Adaptive agent logging variables
- float ada_delta_completion_rate = 0.0f;
- float ada_delta_score = 0.0f;
- float ada_delta_perf = 0.0f;
- float ada_delta_collision_rate = 0.0f;
- float ada_delta_offroad_rate = 0.0f;
- float ada_delta_num_goals_reached = 0.0f;
- float ada_delta_dnf_rate = 0.0f;
- float ada_delta_lane_alignment_rate = 0.0f;
- float ada_delta_avg_displacement_error = 0.0f;
- float ada_delta_episode_return = 0.0f;
- int ada_agent_count = 0;
- int has_co_players = 0; // Flag to check if any env has co-players
- Co_Player_Log co_player_aggregate = {0};
- int num_co_player_keys = sizeof(Co_Player_Log) / sizeof(float);
+ int has_co_players = 0; // Flag to check if any env has co-players
+ Log co_player_aggregate = {0}; // Now using Log struct instead of Co_Player_Log
for (int i = 0; i < vec->num_envs; i++) {
- Env* env = vec->envs[i];
-
+ Env *env = vec->envs[i];
for (int j = 0; j < num_keys; j++) {
- ((float*)&aggregate)[j] += ((float*)&env->log)[j];
- ((float*)&env->log)[j] = 0.0f;
+ ((float *)&aggregate)[j] += ((float *)&env->log)[j];
}
if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) {
has_co_players = 1;
-
- // Aggregate co-player logs
- for (int j = 0; j < num_co_player_keys; j++) {
- ((float*)&co_player_aggregate)[j] += ((float*)&env->co_player_log)[j];
- ((float*)&env->co_player_log)[j] = 0.0f; // Reset after aggregating
+ // Aggregate co-player logs (now same structure as ego logs)
+ for (int j = 0; j < num_keys; j++) {
+ ((float *)&co_player_aggregate)[j] += ((float *)&env->co_player_log)[j];
}
}
+ }
+
+ PyObject *dict = PyDict_New();
+ // Wait until the combined log count (ego plus co-players, when present) reaches num_agents
+ float total_n = aggregate.n + (has_co_players ? co_player_aggregate.n : 0.0f);
+ if (total_n < num_agents) {
+ return dict; // Not enough data yet
}
- PyObject* dict = PyDict_New();
+ // Got enough data. Reset logs and return metrics
+ for (int i = 0; i < vec->num_envs; i++) {
+ Env *env = vec->envs[i];
+ for (int j = 0; j < num_keys; j++) {
+ ((float *)&env->log)[j] = 0.0f;
+ }
- // Average regular logs
- if (aggregate.n > 0.0f) {
- float n = aggregate.n;
- for (int i = 0; i < num_keys; i++) {
- ((float*)&aggregate)[i] /= n;
+ if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) {
+ for (int j = 0; j < num_keys; j++) {
+ ((float *)&env->co_player_log)[j] = 0.0f;
+ }
+ }
+ }
+
+ float n = aggregate.n;
+ // Average across EGO agents only
+ if (n > 0) {
+ for (int i = 0; i < num_keys; i++) {
+ ((float *)&aggregate)[i] /= n;
}
- // User populates dict
- my_log(dict, &aggregate);
- assign_to_dict(dict, "n", n);
+ // Derive completion_rate AFTER averaging so it is not divided by n again (the 1/n factor cancels)
+ float total_goals_reached = aggregate.goals_reached_this_episode;
+ float total_goals_sampled = aggregate.goals_sampled_this_episode;
+ if (total_goals_sampled > 0) {
+ aggregate.completion_rate = total_goals_reached / total_goals_sampled;
+ } else {
+ aggregate.completion_rate = 0.0f;
+ }
}
- if (has_co_players && co_player_aggregate.co_player_n > 0.0f) {
- float co_player_n = co_player_aggregate.co_player_n;
- // Only divide non-zero values to avoid corruption
- for (int i = 0; i < num_co_player_keys; i++) {
- if (((float*)&co_player_aggregate)[i] != 0.0f) {
- ((float*)&co_player_aggregate)[i] /= co_player_n;
- }
+ // User populates dict
+ my_log(dict, &aggregate);
+ assign_to_dict(dict, "n", n);
+
+ // Handle co-player metrics
+ if (has_co_players && co_player_aggregate.n > 0.0f) {
+ float co_player_n = co_player_aggregate.n;
+
+ // Average co-player metrics across CO-PLAYER agents only
+ for (int i = 0; i < num_keys; i++) {
+ ((float *)&co_player_aggregate)[i] /= co_player_n;
}
- // Add co-player metrics directly
- assign_to_dict(dict, "ego_co_player_ratio", aggregate.n / co_player_n);
- assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.co_player_completion_rate);
- assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.co_player_collision_rate);
- assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.co_player_offroad_rate);
- assign_to_dict(dict, "co_player_clean_collision_rate", co_player_aggregate.co_player_clean_collision_rate);
- assign_to_dict(dict, "co_player_num_goals_reached", co_player_aggregate.co_player_num_goals_reached);
- assign_to_dict(dict, "co_player_score", co_player_aggregate.co_player_score);
- assign_to_dict(dict, "co_player_perf", co_player_aggregate.co_player_perf);
- assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.co_player_dnf_rate);
- assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.co_player_episode_length);
- assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.co_player_episode_return);
- assign_to_dict(dict, "co_player_lane_alignment_rate", co_player_aggregate.co_player_lane_alignment_rate);
- assign_to_dict(dict, "co_player_avg_displacement_error", co_player_aggregate.co_player_avg_displacement_error);
+ // Derive the co-player completion rate AFTER averaging so it is not divided by co_player_n again
+ float co_total_goals_reached = co_player_aggregate.goals_reached_this_episode;
+ float co_total_goals_sampled = co_player_aggregate.goals_sampled_this_episode;
+ if (co_total_goals_sampled > 0) {
+ co_player_aggregate.completion_rate = co_total_goals_reached / co_total_goals_sampled;
+ } else {
+ co_player_aggregate.completion_rate = 0.0f;
+ }
+
+ // Add co-player metrics to dict with co_player_ prefix
+ assign_to_dict(dict, "ego_co_player_ratio", n / co_player_n);
+ assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.completion_rate);
+ assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.collision_rate);
+ assign_to_dict(dict, "co_player_collisions_per_agent", co_player_aggregate.collisions_per_agent);
+ assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.offroad_rate);
+ assign_to_dict(dict, "co_player_offroad_per_agent", co_player_aggregate.offroad_per_agent);
+ assign_to_dict(dict, "co_player_score", co_player_aggregate.score);
+ assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.dnf_rate);
+ assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.episode_length);
+ assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.episode_return);
+ assign_to_dict(dict, "co_player_lane_alignment_rate", co_player_aggregate.lane_alignment_rate);
+ assign_to_dict(dict, "co_player_speed_at_goal", co_player_aggregate.speed_at_goal);
+ assign_to_dict(dict, "co_player_goals_reached_this_episode", co_player_aggregate.goals_reached_this_episode);
+ assign_to_dict(dict, "co_player_goals_sampled_this_episode", co_player_aggregate.goals_sampled_this_episode);
assign_to_dict(dict, "co_player_n", co_player_n);
}
-
return dict;
}
-
-static PyObject* vec_close(PyObject* self, PyObject* args) {
- VecEnv* vec = unpack_vecenv(args);
+static PyObject *vec_close(PyObject *self, PyObject *args) {
+ VecEnv *vec = unpack_vecenv(args);
if (!vec) {
return NULL;
}
@@ -668,93 +722,118 @@ static PyObject* vec_close(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
-static PyObject* get_global_agent_state(PyObject* self, PyObject* args) {
- if (PyTuple_Size(args) != 5) {
- PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 5 arguments");
+// Render-mode helpers: stash env[0]->client into a global before vec_close so
+// raylib + ffmpeg pipe survive the map swap, then re-attach to env[0] of the
+// freshly built vec. Single-slot global: render envs are single-env.
+static PyObject *vec_donate_client(PyObject *self, PyObject *args) {
+ VecEnv *vec = unpack_vecenv(args);
+ if (!vec || vec->num_envs == 0) {
+ return vec ? Py_BuildValue("") : NULL; // "" builds None; NULL propagates unpack_vecenv's exception
+ }
+ c_donate_client(vec->envs[0]);
+ Py_RETURN_NONE;
+}
+
+static PyObject *vec_adopt_client(PyObject *self, PyObject *args) {
+ VecEnv *vec = unpack_vecenv(args);
+ if (!vec || vec->num_envs == 0) {
+ return vec ? Py_BuildValue("") : NULL; // "" builds None; NULL propagates unpack_vecenv's exception
+ }
+ c_adopt_client(vec->envs[0]);
+ Py_RETURN_NONE;
+}
+
+static PyObject *get_global_agent_state(PyObject *self, PyObject *args) {
+ if (PyTuple_Size(args) != 8) {
+ PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 8 arguments");
return NULL;
}
- Env* env = unpack_env(args);
+ Env *env = unpack_env(args);
if (!env) {
return NULL;
}
- Drive* drive = (Drive*)env; // Cast to Drive*
+ Drive *drive = (Drive *)env; // Cast to Drive*
// Get the numpy arrays from arguments
- PyObject* x_arr = PyTuple_GetItem(args, 1);
- PyObject* y_arr = PyTuple_GetItem(args, 2);
- PyObject* z_arr = PyTuple_GetItem(args, 3);
- PyObject* heading_arr = PyTuple_GetItem(args, 4);
- PyObject* id_arr = PyTuple_GetItem(args, 5);
-
- if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
- !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
- !PyArray_Check(id_arr)) {
+ PyObject *x_arr = PyTuple_GetItem(args, 1);
+ PyObject *y_arr = PyTuple_GetItem(args, 2);
+ PyObject *z_arr = PyTuple_GetItem(args, 3);
+ PyObject *heading_arr = PyTuple_GetItem(args, 4);
+ PyObject *id_arr = PyTuple_GetItem(args, 5);
+ PyObject *length_arr = PyTuple_GetItem(args, 6);
+ PyObject *width_arr = PyTuple_GetItem(args, 7);
+
+ if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
+ !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) {
PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
return NULL;
}
- float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr);
- float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr);
- float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr);
- float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr);
- int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr);
+ float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr);
+ float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr);
+ float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr);
+ float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr);
+ int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr);
+ float *length_data = (float *)PyArray_DATA((PyArrayObject *)length_arr);
+ float *width_data = (float *)PyArray_DATA((PyArrayObject *)width_arr);
- c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data);
+ c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data, length_data, width_data);
Py_RETURN_NONE;
}
-static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) {
- if (PyTuple_Size(args) != 6) {
- PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 6 arguments");
+static PyObject *vec_get_global_agent_state(PyObject *self, PyObject *args) {
+ if (PyTuple_Size(args) != 8) {
+ PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 8 arguments");
return NULL;
}
- VecEnv* vec = unpack_vecenv(args);
+ VecEnv *vec = unpack_vecenv(args);
if (!vec) {
return NULL;
}
// Get the numpy arrays from arguments
- PyObject* x_arr = PyTuple_GetItem(args, 1);
- PyObject* y_arr = PyTuple_GetItem(args, 2);
- PyObject* z_arr = PyTuple_GetItem(args, 3);
- PyObject* heading_arr = PyTuple_GetItem(args, 4);
- PyObject* id_arr = PyTuple_GetItem(args, 5);
-
- if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
- !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
- !PyArray_Check(id_arr)) {
+ PyObject *x_arr = PyTuple_GetItem(args, 1);
+ PyObject *y_arr = PyTuple_GetItem(args, 2);
+ PyObject *z_arr = PyTuple_GetItem(args, 3);
+ PyObject *heading_arr = PyTuple_GetItem(args, 4);
+ PyObject *id_arr = PyTuple_GetItem(args, 5);
+ PyObject *length_arr = PyTuple_GetItem(args, 6);
+ PyObject *width_arr = PyTuple_GetItem(args, 7);
+
+ if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
+ !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) {
PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
return NULL;
}
- PyArrayObject* x_array = (PyArrayObject*)x_arr;
- PyArrayObject* y_array = (PyArrayObject*)y_arr;
- PyArrayObject* z_array = (PyArrayObject*)z_arr;
- PyArrayObject* heading_array = (PyArrayObject*)heading_arr;
- PyArrayObject* id_array = (PyArrayObject*)id_arr;
+ PyArrayObject *x_array = (PyArrayObject *)x_arr;
+ PyArrayObject *y_array = (PyArrayObject *)y_arr;
+ PyArrayObject *z_array = (PyArrayObject *)z_arr;
+ PyArrayObject *heading_array = (PyArrayObject *)heading_arr;
+ PyArrayObject *id_array = (PyArrayObject *)id_arr;
+ PyArrayObject *length_array = (PyArrayObject *)length_arr;
+ PyArrayObject *width_array = (PyArrayObject *)width_arr;
// Get base pointers to the arrays
- float* x_base = (float*)PyArray_DATA(x_array);
- float* y_base = (float*)PyArray_DATA(y_array);
- float* z_base = (float*)PyArray_DATA(z_array);
- float* heading_base = (float*)PyArray_DATA(heading_array);
- int* id_base = (int*)PyArray_DATA(id_array);
+ float *x_base = (float *)PyArray_DATA(x_array);
+ float *y_base = (float *)PyArray_DATA(y_array);
+ float *z_base = (float *)PyArray_DATA(z_array);
+ float *heading_base = (float *)PyArray_DATA(heading_array);
+ int *id_base = (int *)PyArray_DATA(id_array);
+ float *length_base = (float *)PyArray_DATA(length_array);
+ float *width_base = (float *)PyArray_DATA(width_array);
// Iterate through environments and write to correct offsets
int offset = 0;
for (int i = 0; i < vec->num_envs; i++) {
- Drive* drive = (Drive*)vec->envs[i];
+ Drive *drive = (Drive *)vec->envs[i];
// Write to the arrays at the current offset
- c_get_global_agent_state(drive,
- &x_base[offset],
- &y_base[offset],
- &z_base[offset],
- &heading_base[offset],
- &id_base[offset]);
+ c_get_global_agent_state(drive, &x_base[offset], &y_base[offset], &z_base[offset], &heading_base[offset],
+ &id_base[offset], &length_base[offset], &width_base[offset]);
// Move offset forward by the number of agents in this environment
offset += drive->active_agent_count;
@@ -763,111 +842,105 @@ static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
-static PyObject* get_ground_truth_trajectories(PyObject* self, PyObject* args) {
+static PyObject *get_ground_truth_trajectories(PyObject *self, PyObject *args) {
if (PyTuple_Size(args) != 7) {
PyErr_SetString(PyExc_TypeError, "get_ground_truth_trajectories requires 7 arguments");
return NULL;
}
- Env* env = unpack_env(args);
+ Env *env = unpack_env(args);
if (!env) {
return NULL;
}
- Drive* drive = (Drive*)env;
+ Drive *drive = (Drive *)env;
// Get the numpy arrays from arguments
- PyObject* x_arr = PyTuple_GetItem(args, 1);
- PyObject* y_arr = PyTuple_GetItem(args, 2);
- PyObject* z_arr = PyTuple_GetItem(args, 3);
- PyObject* heading_arr = PyTuple_GetItem(args, 4);
- PyObject* valid_arr = PyTuple_GetItem(args, 5);
- PyObject* id_arr = PyTuple_GetItem(args, 6);
- PyObject* scenario_id_arr = PyTuple_GetItem(args, 7);
-
- if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
- !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
+ PyObject *x_arr = PyTuple_GetItem(args, 1);
+ PyObject *y_arr = PyTuple_GetItem(args, 2);
+ PyObject *z_arr = PyTuple_GetItem(args, 3);
+ PyObject *heading_arr = PyTuple_GetItem(args, 4);
+ PyObject *valid_arr = PyTuple_GetItem(args, 5);
+ PyObject *id_arr = PyTuple_GetItem(args, 6);
+ PyObject *scenario_id_arr = PyTuple_GetItem(args, 7);
+
+ if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
!PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) {
PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
return NULL;
}
- float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr);
- float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr);
- float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr);
- float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr);
- int* valid_data = (int*)PyArray_DATA((PyArrayObject*)valid_arr);
- int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr);
- int* scenario_id_data = (int*)PyArray_DATA((PyArrayObject*)scenario_id_arr);
+ float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr);
+ float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr);
+ float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr);
+ float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr);
+ int *valid_data = (int *)PyArray_DATA((PyArrayObject *)valid_arr);
+ int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr);
+ int *scenario_id_data = (int *)PyArray_DATA((PyArrayObject *)scenario_id_arr);
- c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data, scenario_id_data);
+ c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data,
+ scenario_id_data);
Py_RETURN_NONE;
}
-static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObject* args) {
+static PyObject *vec_get_global_ground_truth_trajectories(PyObject *self, PyObject *args) {
if (PyTuple_Size(args) != 8) {
PyErr_SetString(PyExc_TypeError, "vec_get_global_ground_truth_trajectories requires 8 arguments");
return NULL;
}
- VecEnv* vec = unpack_vecenv(args);
+ VecEnv *vec = unpack_vecenv(args);
if (!vec) {
return NULL;
}
// Get the numpy arrays from arguments
- PyObject* x_arr = PyTuple_GetItem(args, 1);
- PyObject* y_arr = PyTuple_GetItem(args, 2);
- PyObject* z_arr = PyTuple_GetItem(args, 3);
- PyObject* heading_arr = PyTuple_GetItem(args, 4);
- PyObject* valid_arr = PyTuple_GetItem(args, 5);
- PyObject* id_arr = PyTuple_GetItem(args, 6);
- PyObject* scenario_id_arr = PyTuple_GetItem(args, 7);
-
- if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
- !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
+ PyObject *x_arr = PyTuple_GetItem(args, 1);
+ PyObject *y_arr = PyTuple_GetItem(args, 2);
+ PyObject *z_arr = PyTuple_GetItem(args, 3);
+ PyObject *heading_arr = PyTuple_GetItem(args, 4);
+ PyObject *valid_arr = PyTuple_GetItem(args, 5);
+ PyObject *id_arr = PyTuple_GetItem(args, 6);
+ PyObject *scenario_id_arr = PyTuple_GetItem(args, 7);
+
+ if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
!PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) {
PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
return NULL;
}
- PyArrayObject* x_array = (PyArrayObject*)x_arr;
- PyArrayObject* y_array = (PyArrayObject*)y_arr;
- PyArrayObject* z_array = (PyArrayObject*)z_arr;
- PyArrayObject* heading_array = (PyArrayObject*)heading_arr;
- PyArrayObject* valid_array = (PyArrayObject*)valid_arr;
- PyArrayObject* id_array = (PyArrayObject*)id_arr;
- PyArrayObject* scenario_id_array = (PyArrayObject*)scenario_id_arr;
+ PyArrayObject *x_array = (PyArrayObject *)x_arr;
+ PyArrayObject *y_array = (PyArrayObject *)y_arr;
+ PyArrayObject *z_array = (PyArrayObject *)z_arr;
+ PyArrayObject *heading_array = (PyArrayObject *)heading_arr;
+ PyArrayObject *valid_array = (PyArrayObject *)valid_arr;
+ PyArrayObject *id_array = (PyArrayObject *)id_arr;
+ PyArrayObject *scenario_id_array = (PyArrayObject *)scenario_id_arr;
// Get base pointers to the arrays
- float* x_base = (float*)PyArray_DATA(x_array);
- float* y_base = (float*)PyArray_DATA(y_array);
- float* z_base = (float*)PyArray_DATA(z_array);
- float* heading_base = (float*)PyArray_DATA(heading_array);
- int* valid_base = (int*)PyArray_DATA(valid_array);
- int* id_base = (int*)PyArray_DATA(id_array);
- int* scenario_id_base = (int*)PyArray_DATA(scenario_id_array);
+ float *x_base = (float *)PyArray_DATA(x_array);
+ float *y_base = (float *)PyArray_DATA(y_array);
+ float *z_base = (float *)PyArray_DATA(z_array);
+ float *heading_base = (float *)PyArray_DATA(heading_array);
+ int *valid_base = (int *)PyArray_DATA(valid_array);
+ int *id_base = (int *)PyArray_DATA(id_array);
+ int *scenario_id_base = (int *)PyArray_DATA(scenario_id_array);
// Get number of timesteps from array shape
- npy_intp* x_shape = PyArray_DIMS(x_array);
- int num_timesteps = x_shape[1]; // Second dimension for 2D arrays
+ npy_intp *x_shape = PyArray_DIMS(x_array);
+ int num_timesteps = x_shape[1]; // Second dimension for 2D arrays
// Iterate through environments and write to correct offsets
- int agent_offset = 0; // Offset for 1D arrays (id, scenario_id)
- int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid)
+ int agent_offset = 0; // Offset for 1D arrays (id, scenario_id)
+ int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid)
for (int i = 0; i < vec->num_envs; i++) {
- Drive* drive = (Drive*)vec->envs[i];
+ Drive *drive = (Drive *)vec->envs[i];
- c_get_global_ground_truth_trajectories(drive,
- &x_base[traj_offset],
- &y_base[traj_offset],
- &z_base[traj_offset],
- &heading_base[traj_offset],
- &valid_base[traj_offset],
- &id_base[agent_offset],
- &scenario_id_base[agent_offset]);
+ c_get_global_ground_truth_trajectories(drive, &x_base[traj_offset], &y_base[traj_offset], &z_base[traj_offset],
+ &heading_base[traj_offset], &valid_base[traj_offset],
+ &id_base[agent_offset], &scenario_id_base[agent_offset]);
// Move offsets forward
agent_offset += drive->active_agent_count;
@@ -876,8 +949,64 @@ static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObje
Py_RETURN_NONE;
}
-static double unpack(PyObject* kwargs, char* key) {
- PyObject* val = PyDict_GetItemString(kwargs, key);
+
+static PyObject *vec_get_road_edge_counts(PyObject *self, PyObject *args) {
+ VecEnv *vec = unpack_vecenv(args);
+ if (!vec)
+ return NULL;
+
+ int total_polylines = 0, total_points = 0;
+ for (int i = 0; i < vec->num_envs; i++) {
+ Drive *drive = (Drive *)vec->envs[i];
+ int np, tp;
+ c_get_road_edge_counts(drive, &np, &tp);
+ total_polylines += np;
+ total_points += tp;
+ }
+ return Py_BuildValue("(ii)", total_polylines, total_points);
+}
+
+static PyObject *vec_get_road_edge_polylines(PyObject *self, PyObject *args) {
+ if (PyTuple_Size(args) != 5) {
+ PyErr_SetString(PyExc_TypeError, "vec_get_road_edge_polylines requires 5 arguments");
+ return NULL;
+ }
+
+ VecEnv *vec = unpack_vecenv(args);
+ if (!vec)
+ return NULL;
+
+ PyObject *x_arr = PyTuple_GetItem(args, 1);
+ PyObject *y_arr = PyTuple_GetItem(args, 2);
+ PyObject *lengths_arr = PyTuple_GetItem(args, 3);
+ PyObject *scenario_ids_arr = PyTuple_GetItem(args, 4);
+
+ if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(lengths_arr) ||
+ !PyArray_Check(scenario_ids_arr)) {
+ PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
+ return NULL;
+ }
+
+ float *x_base = (float *)PyArray_DATA((PyArrayObject *)x_arr);
+ float *y_base = (float *)PyArray_DATA((PyArrayObject *)y_arr);
+ int *lengths_base = (int *)PyArray_DATA((PyArrayObject *)lengths_arr);
+ int *scenario_ids_base = (int *)PyArray_DATA((PyArrayObject *)scenario_ids_arr);
+
+ int poly_offset = 0, pt_offset = 0;
+ for (int i = 0; i < vec->num_envs; i++) {
+ Drive *drive = (Drive *)vec->envs[i];
+ int np, tp;
+ c_get_road_edge_counts(drive, &np, &tp);
+ c_get_road_edge_polylines(drive, &x_base[pt_offset], &y_base[pt_offset], &lengths_base[poly_offset],
+ &scenario_ids_base[poly_offset]);
+ poly_offset += np;
+ pt_offset += tp;
+ }
+ Py_RETURN_NONE;
+}
+
+static double unpack(PyObject *kwargs, char *key) {
+ PyObject *val = PyDict_GetItemString(kwargs, key);
if (val == NULL) {
char error_msg[100];
snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key);
@@ -904,8 +1033,8 @@ static double unpack(PyObject* kwargs, char* key) {
return 1;
}
-static char* unpack_str(PyObject* kwargs, char* key) {
- PyObject* val = PyDict_GetItemString(kwargs, key);
+static char *unpack_str(PyObject *kwargs, char *key) {
+ PyObject *val = PyDict_GetItemString(kwargs, key);
if (val == NULL) {
char error_msg[100];
snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key);
@@ -918,12 +1047,12 @@ static char* unpack_str(PyObject* kwargs, char* key) {
PyErr_SetString(PyExc_TypeError, error_msg);
return NULL;
}
- const char* str_val = PyUnicode_AsUTF8(val);
+ const char *str_val = PyUnicode_AsUTF8(val);
if (str_val == NULL) {
// PyUnicode_AsUTF8 sets an error on failure
return NULL;
}
- char* ret = strdup(str_val);
+ char *ret = strdup(str_val);
if (ret == NULL) {
PyErr_SetString(PyExc_MemoryError, "strdup failed in unpack_str");
}
@@ -932,7 +1061,8 @@ static char* unpack_str(PyObject* kwargs, char* key) {
// Method table
static PyMethodDef methods[] = {
- {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, "Init environment with observation, action, reward, terminal, truncation arrays"},
+ {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS,
+ "Init environment with observation, action, reward, terminal, truncation arrays"},
{"env_reset", env_reset, METH_VARARGS, "Reset the environment"},
{"env_step", env_step, METH_VARARGS, "Step the environment"},
{"env_render", env_render, METH_VARARGS, "Render the environment"},
@@ -945,26 +1075,56 @@ static PyMethodDef methods[] = {
{"vec_step", vec_step, METH_VARARGS, "Step the vector of environments"},
{"vec_log", vec_log, METH_VARARGS, "Log the vector of environments"},
{"vec_render", vec_render, METH_VARARGS, "Render the vector of environments"},
+ {"vec_set_video_suffix", vec_set_video_suffix, METH_VARARGS, "Set video filename suffix for headless rendering"},
{"vec_close", vec_close, METH_VARARGS, "Close the vector of environments"},
+ {"vec_donate_client", vec_donate_client, METH_VARARGS,
+ "Stash env[0]->client into a global so it survives a subsequent vec_close (render only)"},
+ {"vec_adopt_client", vec_adopt_client, METH_VARARGS,
+ "Re-attach the previously donated client to env[0] of the new vec (render only)"},
{"shared", (PyCFunction)my_shared, METH_VARARGS | METH_KEYWORDS, "Shared state"},
{"get_global_agent_state", get_global_agent_state, METH_VARARGS, "Get global agent state"},
{"vec_get_global_agent_state", vec_get_global_agent_state, METH_VARARGS, "Get agent state from vectorized env"},
{"get_ground_truth_trajectories", get_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories"},
- {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories from vectorized env"},
+ {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS,
+ "Get ground truth trajectories from vectorized env"},
+ {"vec_get_road_edge_counts", vec_get_road_edge_counts, METH_VARARGS,
+ "Get road edge polyline counts from vectorized env"},
+ {"vec_get_road_edge_polylines", vec_get_road_edge_polylines, METH_VARARGS,
+ "Get road edge polylines from vectorized env"},
MY_METHODS,
- {NULL, NULL, 0, NULL}
-};
+ {NULL, NULL, 0, NULL}};
// Module definition
-static PyModuleDef module = {
- PyModuleDef_HEAD_INIT,
- "binding",
- NULL,
- -1,
- methods
-};
+static PyModuleDef module = {PyModuleDef_HEAD_INIT, "binding", NULL, -1, methods};
PyMODINIT_FUNC PyInit_binding(void) {
import_array();
- return PyModule_Create(&module);
+ PyObject *m = PyModule_Create(&module); // named 'm' so it does not shadow the static 'module' definition
+
+ if (m == NULL) {
+ return NULL;
+ }
+
+ // Make constants accessible from Python
+ PyModule_AddIntConstant(m, "MAX_ROAD_SEGMENT_OBSERVATIONS", MAX_ROAD_SEGMENT_OBSERVATIONS);
+ PyModule_AddIntConstant(m, "MAX_AGENTS", MAX_AGENTS);
+ PyModule_AddIntConstant(m, "TRAJECTORY_LENGTH", TRAJECTORY_LENGTH);
+ PyModule_AddIntConstant(m, "MAX_ENTITIES_PER_CELL", MAX_ENTITIES_PER_CELL);
+
+ PyModule_AddIntConstant(m, "ROAD_FEATURES", ROAD_FEATURES);
+ PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES);
+ PyModule_AddIntConstant(m, "EGO_FEATURES_CLASSIC", EGO_FEATURES_CLASSIC);
+ PyModule_AddIntConstant(m, "EGO_FEATURES_JERK", EGO_FEATURES_JERK);
+
+ // Render mode constants
+ PyModule_AddIntConstant(m, "RENDER_OFF", RENDER_OFF);
+ PyModule_AddIntConstant(m, "RENDER_HEADLESS", RENDER_HEADLESS);
+ PyModule_AddIntConstant(m, "RENDER_WINDOW", RENDER_WINDOW);
+
+ // View mode constants
+ PyModule_AddIntConstant(m, "VIEW_MODE_SIM_STATE", VIEW_MODE_SIM_STATE);
+ PyModule_AddIntConstant(m, "VIEW_MODE_BEV_AGENT_OBS", VIEW_MODE_BEV_AGENT_OBS);
+ PyModule_AddIntConstant(m, "VIEW_MODE_AGENT_PERSP", VIEW_MODE_AGENT_PERSP);
+
+ return m;
}
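The two road-edge bindings added above appear to be meant as a pair: `vec_get_road_edge_counts` returns the totals needed to size flat output buffers, and `vec_get_road_edge_polylines` then fills those buffers while advancing a polyline offset and a point offset per environment. The sketch below imitates that packing scheme in pure Python so the offset arithmetic is easy to check. It is not the binding itself: `count_road_edges`, `pack_road_edges`, `unpack_polylines`, and the dictionary layout of each environment are hypothetical, and it assumes one scenario id is written per polyline.

```python
def count_road_edges(envs):
    """Return (total_polylines, total_points) summed over all environments."""
    total_polylines = sum(len(env["polylines"]) for env in envs)
    total_points = sum(len(p) for env in envs for p in env["polylines"])
    return total_polylines, total_points


def pack_road_edges(envs):
    """Pack per-environment polylines into flat buffers using running offsets."""
    total_polylines, total_points = count_road_edges(envs)
    # Buffers are sized up front, as a caller of the binding would do.
    x = [0.0] * total_points
    y = [0.0] * total_points
    lengths = [0] * total_polylines
    scenario_ids = [0] * total_polylines

    poly_offset = 0  # index into lengths / scenario_ids
    pt_offset = 0    # index into x / y
    for env in envs:
        for polyline in env["polylines"]:
            lengths[poly_offset] = len(polyline)
            scenario_ids[poly_offset] = env["scenario_id"]
            for px, py in polyline:
                x[pt_offset] = px
                y[pt_offset] = py
                pt_offset += 1
            poly_offset += 1
    return x, y, lengths, scenario_ids


def unpack_polylines(x, y, lengths):
    """Recover the individual polylines from the flat buffers."""
    out, start = [], 0
    for n in lengths:
        out.append(list(zip(x[start:start + n], y[start:start + n])))
        start += n
    return out
```

Because `lengths` records the point count of every polyline, a consumer can split the flat `x` and `y` buffers again without knowing how many environments contributed to them.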
diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h
index e8400c51c9..35f2806806 100644
--- a/pufferlib/ocean/env_config.h
+++ b/pufferlib/ocean/env_config.h
@@ -6,9 +6,22 @@
#include
#include
+typedef struct {
+ char *type;
+ float reward_offroad_weight_lb;
+ float reward_offroad_weight_ub;
+ float reward_collision_weight_lb;
+ float reward_collision_weight_ub;
+ float reward_goal_weight_lb;
+ float reward_goal_weight_ub;
+ float entropy_weight_lb;
+ float entropy_weight_ub;
+ float discount_weight_lb;
+ float discount_weight_ub;
+} conditioning_config;
+
// Config struct for parsing INI files - contains all environment configuration
-typedef struct
-{
+typedef struct {
int action_type;
int dynamics_model;
float reward_vehicle_collision;
@@ -16,49 +29,56 @@ typedef struct
float reward_goal;
float reward_goal_post_respawn;
float reward_vehicle_collision_post_respawn;
- float reward_ade;
float goal_radius;
+ float goal_speed;
int collision_behavior;
int offroad_behavior;
int spawn_immunity_timer;
float dt;
int goal_behavior;
+ float goal_target_distance;
int scenario_length;
+ int k_scenarios;
+ int termination_mode;
int init_steps;
int init_mode;
int control_mode;
+ char map_dir[256];
+ conditioning_config *conditioning;
+ // Population play settings
+ int co_player_enabled;
+ int num_ego_agents;
+ char co_player_policy_path[256];
+ conditioning_config *co_player_conditioning;
} env_init_config;
// INI file parser handler - parses all environment configuration from drive.ini
-static int handler(
- void* config,
- const char* section,
- const char* name,
- const char* value
-) {
- env_init_config* env_config = (env_init_config*)config;
- #define MATCH(s, n) strcmp(section, s) == 0 && strcmp(name, n) == 0
+static int handler(void *config, const char *section, const char *name, const char *value) {
+ env_init_config *env_config = (env_init_config *)config;
+#define MATCH(s, n) strcmp(section, s) == 0 && strcmp(name, n) == 0
if (MATCH("env", "action_type")) {
- if (strcmp(value, "\"discrete\"") == 0 ||strcmp(value, "discrete") == 0) {
- env_config->action_type = 0; // DISCRETE
+ if (strcmp(value, "\"discrete\"") == 0 || strcmp(value, "discrete") == 0) {
+ env_config->action_type = 0; // DISCRETE
} else if (strcmp(value, "\"continuous\"") == 0 || strcmp(value, "continuous") == 0) {
- env_config->action_type = 1; // CONTINUOUS
+ env_config->action_type = 1; // CONTINUOUS
} else {
printf("Warning: Unknown action_type value '%s', defaulting to DISCRETE\n", value);
- env_config->action_type = 0; // Default to DISCRETE
+ env_config->action_type = 0; // Default to DISCRETE
}
} else if (MATCH("env", "dynamics_model")) {
if (strcmp(value, "\"classic\"") == 0 || strcmp(value, "classic") == 0) {
- env_config->dynamics_model = 0; // CLASSIC
+ env_config->dynamics_model = 0; // CLASSIC
} else if (strcmp(value, "\"jerk\"") == 0 || strcmp(value, "jerk") == 0) {
- env_config->dynamics_model = 1; // JERK
+ env_config->dynamics_model = 1; // JERK
} else {
printf("Warning: Unknown dynamics_model value '%s', defaulting to JERK\n", value);
- env_config->dynamics_model = 1; // Default to JERK
+ env_config->dynamics_model = 1; // Default to JERK
}
} else if (MATCH("env", "goal_behavior")) {
env_config->goal_behavior = atoi(value);
+ } else if (MATCH("env", "goal_target_distance")) {
+ env_config->goal_target_distance = atof(value);
} else if (MATCH("env", "reward_vehicle_collision")) {
env_config->reward_vehicle_collision = atof(value);
} else if (MATCH("env", "reward_offroad_collision")) {
@@ -69,13 +89,13 @@ static int handler(
env_config->reward_goal_post_respawn = atof(value);
} else if (MATCH("env", "reward_vehicle_collision_post_respawn")) {
env_config->reward_vehicle_collision_post_respawn = atof(value);
- } else if (MATCH("env", "reward_ade")) {
- env_config->reward_ade = atof(value);
} else if (MATCH("env", "goal_radius")) {
env_config->goal_radius = atof(value);
- } else if(MATCH("env", "collision_behavior")){
+ } else if (MATCH("env", "goal_speed")) {
+ env_config->goal_speed = atof(value);
+ } else if (MATCH("env", "collision_behavior")) {
env_config->collision_behavior = atoi(value);
- } else if(MATCH("env", "offroad_behavior")){
+ } else if (MATCH("env", "offroad_behavior")) {
env_config->offroad_behavior = atoi(value);
} else if (MATCH("env", "spawn_immunity_timer")) {
env_config->spawn_immunity_timer = atoi(value);
@@ -83,15 +103,171 @@ static int handler(
env_config->dt = atof(value);
} else if (MATCH("env", "scenario_length")) {
env_config->scenario_length = atoi(value);
+ } else if (MATCH("env", "k_scenarios")) {
+ env_config->k_scenarios = atoi(value);
+ } else if (MATCH("env", "termination_mode")) {
+ env_config->termination_mode = atoi(value);
} else if (MATCH("env", "init_steps")) {
env_config->init_steps = atoi(value);
} else if (MATCH("env", "init_mode")) {
env_config->init_mode = atoi(value);
} else if (MATCH("env", "control_mode")) {
env_config->control_mode = atoi(value);
+ } else if (MATCH("env", "map_dir")) {
+ if (sscanf(value, "\"%255[^\"]\"", env_config->map_dir) != 1) {
+ strncpy(env_config->map_dir, value, sizeof(env_config->map_dir) - 1);
+ env_config->map_dir[sizeof(env_config->map_dir) - 1] = '\0';
+ }
+ // printf("Parsed map_dir: '%s'\n", env_config->map_dir);
+ } else if (MATCH("env.conditioning", "type")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ // Remove quotes if present
+ if (value[0] == '"') {
+ size_t len = strlen(value) >= 2 ? strlen(value) - 2 : 0; // -2 for both quotes; avoids underflow on a lone quote
+ env_config->conditioning->type = (char *)malloc(len + 1);
+ strncpy(env_config->conditioning->type, value + 1, len);
+ env_config->conditioning->type[len] = '\0';
+ } else {
+ env_config->conditioning->type = strdup(value);
+ }
+ } else if (MATCH("env.conditioning", "collision_weight_lb")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_collision_weight_lb = atof(value);
+ } else if (MATCH("env.conditioning", "collision_weight_ub")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_collision_weight_ub = atof(value);
+ } else if (MATCH("env.conditioning", "offroad_weight_lb")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_offroad_weight_lb = atof(value);
+ } else if (MATCH("env.conditioning", "offroad_weight_ub")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_offroad_weight_ub = atof(value);
+ } else if (MATCH("env.conditioning", "goal_weight_lb")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_goal_weight_lb = atof(value);
+ } else if (MATCH("env.conditioning", "goal_weight_ub")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->reward_goal_weight_ub = atof(value);
+ } else if (MATCH("env.conditioning", "entropy_weight_lb")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->entropy_weight_lb = atof(value);
+ } else if (MATCH("env.conditioning", "entropy_weight_ub")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->entropy_weight_ub = atof(value);
+ } else if (MATCH("env.conditioning", "discount_weight_lb")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->discount_weight_lb = atof(value);
+ } else if (MATCH("env.conditioning", "discount_weight_ub")) {
+ if (env_config->conditioning == NULL) {
+ env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->conditioning->discount_weight_ub = atof(value);
+ }
+ // Population play settings
+ else if (MATCH("env", "co_player_enabled")) {
+ if (strcmp(value, "True") == 0 || strcmp(value, "true") == 0 || strcmp(value, "1") == 0) {
+ env_config->co_player_enabled = 1;
+ } else {
+ env_config->co_player_enabled = 0;
+ }
+ } else if (MATCH("env", "num_ego_agents")) {
+ env_config->num_ego_agents = atoi(value);
+ }
+ // Co-player policy settings
+ else if (MATCH("env.co_player_policy", "policy_path")) {
+ if (sscanf(value, "\"%255[^\"]\"", env_config->co_player_policy_path) != 1) {
+ strncpy(env_config->co_player_policy_path, value, sizeof(env_config->co_player_policy_path) - 1);
+ env_config->co_player_policy_path[sizeof(env_config->co_player_policy_path) - 1] = '\0';
+ }
+ } else if (MATCH("env.co_player_policy.conditioning", "type")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ if (value[0] == '"') {
+ size_t len = strlen(value) >= 2 ? strlen(value) - 2 : 0; // -2 for both quotes; avoids underflow on a lone quote
+ env_config->co_player_conditioning->type = (char *)malloc(len + 1);
+ strncpy(env_config->co_player_conditioning->type, value + 1, len);
+ env_config->co_player_conditioning->type[len] = '\0';
+ } else {
+ env_config->co_player_conditioning->type = strdup(value);
+ }
+ } else if (MATCH("env.co_player_policy.conditioning", "collision_weight_lb")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_collision_weight_lb = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "collision_weight_ub")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_collision_weight_ub = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "offroad_weight_lb")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_offroad_weight_lb = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "offroad_weight_ub")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_offroad_weight_ub = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "goal_weight_lb")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_goal_weight_lb = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "goal_weight_ub")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->reward_goal_weight_ub = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "entropy_weight_lb")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->entropy_weight_lb = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "entropy_weight_ub")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->entropy_weight_ub = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "discount_weight_lb")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->discount_weight_lb = atof(value);
+ } else if (MATCH("env.co_player_policy.conditioning", "discount_weight_ub")) {
+ if (env_config->co_player_conditioning == NULL) {
+ env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config));
+ }
+ env_config->co_player_conditioning->discount_weight_ub = atof(value);
+ }
+
+ else {
+ return 0; // Unknown section/name, indicate failure to handle
}
- #undef MATCH
+#undef MATCH
return 1;
}
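The handler above reads four INI sections: `env`, `env.conditioning`, `env.co_player_policy`, and `env.co_player_policy.conditioning`. The sketch below shows a configuration in that shape and parses it with Python's standard `configparser`, mirroring the handler's quote stripping and its accepted truthy strings (`True`, `true`, `1`). The section and key names come from the handler; every value in `EXAMPLE_INI` (the paths, the `"reward"` type, the weights) is a made-up placeholder, and `strip_quotes`, `parse_bool`, `load_conditioning`, and `load_config` are hypothetical helpers rather than part of the codebase.

```python
import configparser

# Placeholder values only; section and key names follow the C handler.
EXAMPLE_INI = """
[env]
map_dir = "resources/drive/binaries/training"
co_player_enabled = True
num_ego_agents = 2

[env.conditioning]
type = "reward"
collision_weight_lb = -1.0
collision_weight_ub = 0.0

[env.co_player_policy]
policy_path = "path/to/co_player.pt"
"""


def strip_quotes(value):
    """Drop one pair of surrounding double quotes, if present."""
    if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
        return value[1:-1]
    return value


def parse_bool(value):
    """Accept the same truthy strings as the handler; everything else is False."""
    return value in ("True", "true", "1")


def load_conditioning(parser, section):
    """Return the section as a dict, or None when the section is absent."""
    if not parser.has_section(section):
        return None
    out = {}
    for key, raw in parser[section].items():
        out[key] = strip_quotes(raw) if key == "type" else float(raw)
    return out


def load_config(text):
    parser = configparser.ConfigParser()
    parser.read_string(text)
    env = parser["env"]
    return {
        "map_dir": strip_quotes(env.get("map_dir", "")),
        "co_player_enabled": parse_bool(env.get("co_player_enabled", "0")),
        "num_ego_agents": int(env.get("num_ego_agents", "0")),
        "co_player_policy_path": strip_quotes(
            parser.get("env.co_player_policy", "policy_path", fallback="")
        ),
        "conditioning": load_conditioning(parser, "env.conditioning"),
        "co_player_conditioning": load_conditioning(
            parser, "env.co_player_policy.conditioning"
        ),
    }
```

Returning `None` for a missing conditioning section loosely corresponds to the C side only allocating a `conditioning_config` once a key from that section is seen.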
diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py
index 7aa3daa388..8a2a4c0b39 100644
--- a/pufferlib/ocean/torch.py
+++ b/pufferlib/ocean/torch.py
@@ -10,12 +10,19 @@
Recurrent = pufferlib.models.LSTMWrapper
+Transformer = pufferlib.models.TransformerWrapper
class Drive(nn.Module):
def __init__(self, env, input_size=128, hidden_size=128, **kwargs):
super().__init__()
self.hidden_size = hidden_size
+ self.observation_size = env.single_observation_space.shape[0]
+ self.max_partner_objects = env.max_partner_objects
+ self.partner_features = env.partner_features
+ self.max_road_objects = env.max_road_objects
+ self.road_features = env.road_features
+ self.road_features_after_onehot = env.road_features + 6  # the categorical column becomes a 7-way one-hot: (road_features - 1) + 7
# Conditioning setup
self.use_rc = env.reward_conditioned
@@ -24,25 +31,42 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs):
self.conditioning_dims = (3 if self.use_rc else 0) + (1 if self.use_ec else 0) + (1 if self.use_dc else 0)
# Determine ego dimension from environment's dynamics model
- base_ego_dim = 10 if env.dynamics_model == "jerk" else 7
- self.ego_dim = base_ego_dim + self.conditioning_dims
+ # Use binding constants for ego dimensions (includes lane features)
+ from pufferlib.ocean.drive import binding
+ base_ego_dim = binding.EGO_FEATURES_JERK if env.dynamics_model == "jerk" else binding.EGO_FEATURES_CLASSIC
+ self.ego_dim = base_ego_dim + self.conditioning_dims
+ cond_flags = []
+ if self.use_rc:
+ cond_flags.append("reward(3)")
+ if self.use_ec:
+ cond_flags.append("entropy(1)")
+ if self.use_dc:
+ cond_flags.append("discount(1)")
+ cond_str = "+".join(cond_flags) if cond_flags else "none"
+ print(
+ f"[Drive policy] dynamics={env.dynamics_model} ego_dim={self.ego_dim} "
+ f"(base={base_ego_dim}+cond={self.conditioning_dims}: {cond_str}) "
+ f"obs_dim={self.observation_size} max_partners={self.max_partner_objects} "
+ f"max_road={self.max_road_objects} hidden={hidden_size}",
+ flush=True,
+ )
self.ego_encoder = nn.Sequential(
pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)),
nn.LayerNorm(input_size),
# nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)),
)
- max_road_objects = 13
+
self.road_encoder = nn.Sequential(
- pufferlib.pytorch.layer_init(nn.Linear(max_road_objects, input_size)),
+ pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)),
nn.LayerNorm(input_size),
# nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)),
)
- max_partner_objects = 7
+
self.partner_encoder = nn.Sequential(
- pufferlib.pytorch.layer_init(nn.Linear(max_partner_objects, input_size)),
+ pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)),
nn.LayerNorm(input_size),
# nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)),
@@ -72,17 +96,18 @@ def forward_train(self, x, state=None):
def encode_observations(self, observations, state=None):
ego_dim = self.ego_dim
- partner_dim = 63 * 7
- road_dim = 200 * 7
+ partner_dim = self.max_partner_objects * self.partner_features
+ road_dim = self.max_road_objects * self.road_features
ego_obs = observations[:, :ego_dim]
partner_obs = observations[:, ego_dim : ego_dim + partner_dim]
road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim]
- partner_objects = partner_obs.view(-1, 63, 7)
- road_objects = road_obs.view(-1, 200, 7)
- road_continuous = road_objects[:, :, :6] # First 6 features
- road_categorical = road_objects[:, :, 6]
- road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, 200, 7]
+ partner_objects = partner_obs.view(-1, self.max_partner_objects, self.partner_features)
+
+ road_objects = road_obs.view(-1, self.max_road_objects, self.road_features)
+ road_continuous = road_objects[:, :, : self.road_features - 1]
+ road_categorical = road_objects[:, :, self.road_features - 1]
+ road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, max_road_objects, 7]
road_objects = torch.cat([road_continuous, road_onehot], dim=2)
ego_features = self.ego_encoder(ego_obs)
partner_features, _ = self.partner_encoder(partner_objects).max(dim=1)
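The `encode_observations` change above slices one flat observation vector into ego, partner, and road blocks, then swaps the last road column for a 7-way one-hot. The sketch below reproduces only that index arithmetic in plain Python so the slice boundaries and the `+ 6` in `road_features_after_onehot` can be checked by hand. `observation_layout` and `road_encoder_width` are hypothetical helpers. The example numbers in the usage note are the hard-coded values this diff removes (63 partners, 200 road objects, 7 features each) with an assumed ego dimension of 7; the binding constants may now differ.

```python
def observation_layout(ego_dim, max_partners, partner_features, max_road, road_features):
    """Return (ego, partner, road) slices into the flat observation vector."""
    partner_dim = max_partners * partner_features
    road_dim = max_road * road_features
    ego = slice(0, ego_dim)
    partner = slice(ego_dim, ego_dim + partner_dim)
    road = slice(ego_dim + partner_dim, ego_dim + partner_dim + road_dim)
    return ego, partner, road


def road_encoder_width(road_features, num_classes=7):
    """Per-object input width after one-hot encoding the last road column.

    The continuous columns (road_features - 1) are kept and the single
    categorical column is replaced by num_classes one-hot columns.
    """
    return (road_features - 1) + num_classes
```

With the old constants, `observation_layout(7, 63, 7, 200, 7)` places the partner block at indices 7 to 447 and the road block at 448 to 1847, and `road_encoder_width(7)` gives 13, which matches the input width the road encoder previously hard-coded.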
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 68b2e1ef2a..08e680cd13 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -20,6 +20,7 @@
import configparser
from threading import Thread
from collections import defaultdict, deque
+from pathlib import Path
import numpy as np
import psutil
@@ -34,6 +35,7 @@
import pufferlib.vector
import pufferlib.pytorch
import pufferlib.utils
+import pufferlib.utils
try:
from pufferlib import _C
@@ -74,7 +76,17 @@ def __init__(self, config, vecenv, policy, logger=None):
# Vecenv info
self.adaptive_driving_agent = getattr(vecenv.driver_env, "env_name", None) == "adaptive_drive"
if self.adaptive_driving_agent:
- config["bptt_horizon"] = vecenv.driver_env.episode_length
+ if config.get("policy_architecture", "Recurrent") == "Recurrent":
+ config["bptt_horizon"] = vecenv.driver_env.episode_length
+ if config.get("policy_architecture", "Recurrent") == "Transformer":
+ config["context_length"] = self.context_length = vecenv.driver_env.episode_length
+ config["bptt_horizon"] = (
+ vecenv.driver_env.episode_length
+ )  # bptt_horizon is still read downstream, so it must be set too
+ else:
+ if config.get("policy_architecture", "Recurrent") == "Transformer":
+ self.context_length = config["context_length"]
+ config["bptt_horizon"] = config["context_length"]
vecenv.async_reset(seed)
obs_space = vecenv.single_observation_space
@@ -84,7 +96,10 @@ def __init__(self, config, vecenv, policy, logger=None):
if self.population_play:
total_ego_agents = vecenv.num_ego_agents
agents_for_calc = total_ego_agents
- batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers
+ if config.get("policy_architecture", "Recurrent") == "Recurrent":
+ batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers
+ if config.get("policy_architecture", "Recurrent") == "Transformer":
+ batch_size = vecenv.driver_env.num_ego_agents * config["context_length"] * vecenv.num_workers
config["batch_size"] = batch_size ## this is dynamic and based on ego agents
else:
agents_for_calc = total_agents
@@ -93,17 +108,43 @@ def __init__(self, config, vecenv, policy, logger=None):
self.total_agents = total_agents
# Experience
- if config["batch_size"] == "auto" and config["bptt_horizon"] == "auto":
- raise pufferlib.APIUsageError("Must specify batch_size or bptt_horizon")
+ if (
+ config["batch_size"] == "auto"
+ and config.get("bptt_horizon", "auto") == "auto"
+ and config.get("context_length", "auto") == "auto"
+ ):
+ raise pufferlib.APIUsageError("Must specify batch_size, bptt_horizon, or context_length")
elif config["batch_size"] == "auto":
- config["batch_size"] = agents_for_calc * config["bptt_horizon"]
- elif config["bptt_horizon"] == "auto":
+ if config.get("policy_architecture", "Recurrent") == "Recurrent":
+ config["batch_size"] = agents_for_calc * config["bptt_horizon"]
+ elif config.get("policy_architecture", "Recurrent") == "Transformer":
+ config["batch_size"] = agents_for_calc * config["context_length"]
+ elif (
+ config.get("bptt_horizon", "auto") == "auto"
+ and config.get("policy_architecture", "Recurrent") == "Recurrent"
+ ):
config["bptt_horizon"] = config["batch_size"] // agents_for_calc
+ elif (
+ config.get("context_length", "auto") == "auto"
+ and config.get("policy_architecture", "Recurrent") == "Transformer"
+ ):
+ config["context_length"] = config["batch_size"] // agents_for_calc
batch_size = config["batch_size"]
- horizon = config["bptt_horizon"]
+
+ # Set horizon based on model type
+ if config.get("policy_architecture", "Recurrent") == "Recurrent":
+ horizon = config["bptt_horizon"]
+ elif config.get("policy_architecture", "Recurrent") == "Transformer":
+ horizon = config["context_length"]
+ else:
+ horizon = config.get("bptt_horizon", config.get("context_length", 1))
+
+ config["bptt_horizon"] = horizon # For backward compatibility
+
segments = batch_size // horizon
self.segments = segments
+ self.horizon = horizon
if not self.population_play:
if total_agents > segments:
raise pufferlib.APIUsageError(f"Total agents {total_agents} <= segments {segments}")
@@ -140,11 +181,8 @@ def __init__(self, config, vecenv, policy, logger=None):
self.render = config["render"]
self.render_interval = config["render_interval"]
- if self.render:
- ensure_drive_binary()
-
# LSTM
- if config["use_rnn"]:
+ if config.get("rnn_name", "Recurrent") == "Recurrent":
h = policy.hidden_size
if self.population_play:
n = vecenv.ego_agents_per_batch # Use ego agents per batch
@@ -156,6 +194,47 @@ def __init__(self, config, vecenv, policy, logger=None):
self.lstm_h = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)}
self.lstm_c = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)}
+ # TRANSFORMER
+ if config.get("rnn_name", "Recurrent") == "Transformer":
+ h = policy.hidden_size
+
+ if self.population_play:
+ n = vecenv.ego_agents_per_batch # Use ego agents per batch
+ num_chunks = total_ego_agents // n
+ # Initialize transformer context buffers
+ self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)}
+ self.transformer_position = {
+ i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks)
+ }
+ else:
+ n = vecenv.agents_per_batch
+ num_chunks = total_agents // n
+ # Initialize transformer context buffers
+ self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)}
+ self.transformer_position = {
+ i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks)
+ }
+ # K/V cache persistence for the streaming forward_eval path.
+ # The model lazy-allocates k_cache and v_cache (list of per-layer
+ # tensors) on first call when state.get("k_cache") is None. We
+ # persist them here so the next rollout step finds the cache
+ # already populated with past timesteps' projections — without
+ # this, every step would lazy-allocate fresh empty caches and
+ # the policy would attend only to the current step (silent bug
+ # discovered 2026-05-02; broke all in-context-learning runs
+ # prior to that). None initially → first call allocates.
+ self.transformer_k_cache = {i * n: None for i in range(num_chunks)}
+ self.transformer_v_cache = {i * n: None for i in range(num_chunks)}
+
+ # Regression detector for the rnn_name plumbing bug — fires once.
+ print(
+ f"[VERIFY rnn_name] config.get('rnn_name')={config.get('rnn_name')!r}, "
+ f"policy_architecture={config.get('policy_architecture')!r}, "
+ f"has_lstm_h={hasattr(self, 'lstm_h')}, "
+ f"has_transformer_k_cache={hasattr(self, 'transformer_k_cache')}",
+ flush=True,
+ )
+
# Minibatching & gradient accumulation
if self.adaptive_driving_agent:
minibatch_size = config["minibatch_multiplier"] * horizon
@@ -179,7 +258,7 @@ def __init__(self, config, vecenv, policy, logger=None):
self.minibatch_segments = self.minibatch_size // horizon
if self.minibatch_segments * horizon != self.minibatch_size:
raise pufferlib.APIUsageError(
- f"minibatch_size {self.minibatch_size} must be divisible by bptt_horizon {horizon}"
+ f"minibatch_size {self.minibatch_size} must be divisible by horizon {horizon}"
)
# Torch compile
@@ -187,7 +266,8 @@ def __init__(self, config, vecenv, policy, logger=None):
self.policy = policy
if config["compile"]:
self.policy = torch.compile(policy, mode=config["compile_mode"])
- self.policy.forward_eval = torch.compile(policy, mode=config["compile_mode"])
+ if hasattr(policy, "forward_eval"):
+ self.policy.forward_eval = torch.compile(policy.forward_eval, mode=config["compile_mode"])
pufferlib.pytorch.sample_logits = torch.compile(
pufferlib.pytorch.sample_logits, mode=config["compile_mode"]
)
@@ -217,26 +297,95 @@ def __init__(self, config, vecenv, policy, logger=None):
raise ValueError(f"Unknown optimizer: {config['optimizer']}")
self.optimizer = optimizer
+
+ # ---- Resume optimizer / epoch / global_step from trainer_state.pt ----
+ # When --load-model-path points at a checkpoint that has a sibling
+ # trainer_state.pt (which the trainer writes alongside every model
+ # checkpoint), restore optimizer momentum + counters so the resumed
+ # run continues mid-cosine instead of warm-restarting at peak LR with
+ # cold Adam moments.
+ resume_epoch = 0
+ resume_global_step = 0
+ load_path = config.get("load_model_path")
+ if load_path:
+ state_path = os.path.join(os.path.dirname(load_path), "trainer_state.pt")
+ if os.path.exists(state_path):
+ try:
+ # weights_only=False: trainer_state.pt contains optimizer
+ # state (with class refs), not just tensors.
+ saved = torch.load(state_path, map_location=config["device"], weights_only=False)
+ optimizer.load_state_dict(saved["optimizer_state_dict"])
+ resume_epoch = int(saved.get("update", 0))
+ resume_global_step = int(saved.get("global_step", 0))
+ print(
+ f"[trainer-state] Resumed optimizer state from {state_path}\n"
+ f"[trainer-state] epoch={resume_epoch} global_step={resume_global_step}",
+ flush=True,
+ )
+ except Exception as e:
+ print(f"[trainer-state] WARNING: could not load {state_path}: {e}", flush=True)
+
# Logging
self.logger = logger
if logger is None:
self.logger = NoLogger(config)
if self.population_play:
+ # Under external_co_player_actions, driver_env.co_player_policy is
+ # None (worker doesn't load it); the GPU-bound copy lives on the
+ # vecenv as co_player_policy_func.
+ export_co_player = getattr(vecenv, "co_player_policy_func", None) or vecenv.driver_env.co_player_policy
co_player_path = f"resources/drive/{config['env']}_co_player.bin"
export_args = {"env_name": config["env"], "path": co_player_path, **config}
export(
args=export_args,
env_name=config["env"],
vecenv=vecenv,
- policy=vecenv.driver_env.co_player_policy,
+ policy=export_co_player,
path=co_player_path,
silent=True,
)
- # Learning rate scheduler
+ # ---- Centralized GPU co-player inference (when enabled) ------------
+ self.external_co_player = bool(
+ self.population_play
+ and getattr(vecenv, "co_player_policy_func", None) is not None
+ and config.get("env_config", {}).get("external_co_player_actions", False)
+ )
+ if self.external_co_player:
+ co_policy = vecenv.co_player_policy_func.to(config["device"])
+ co_policy.eval()
+ self.co_player_policy = co_policy
+ self.co_player_conditioning_dims = getattr(vecenv, "co_player_conditioning_dims", 0)
+ # One state dict per worker (each worker holds its own slice of
+ # co_players; the per-worker batch size is num_co_players_per_env).
+ num_co_per_worker = vecenv.driver_env.num_co_players
+ num_workers = vecenv.num_workers
+ self._co_player_num_per_worker = num_co_per_worker
+ # Per-worker state dicts. Start each as an empty dict so that
+ # `forward_eval` lazily allocates the K/V cache on first call with
+ # the correct (obs-derived) dtype — avoiding a cache dtype that
+ # mismatches the layer-output dtype during reset_eval_state's
+ # cache-prime path.
+ self.co_player_state = {w: {} for w in range(num_workers)}
+ print(
+ f"[external co-player] Loaded co-player on {device}; "
+ f"per-worker batch={num_co_per_worker}, conditioning_dims={self.co_player_conditioning_dims}, "
+ f"num_workers={num_workers}",
+ flush=True,
+ )
+
+ # Learning rate scheduler — if resuming, advance to the saved epoch
+ # position so cosine annealing continues smoothly.
epochs = config["total_timesteps"] // config["batch_size"]
- self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ last_epoch_arg = -1
+ if resume_epoch > 0:
+ # CosineAnnealingLR requires `initial_lr` in each param_group when
+ # last_epoch != -1; old optimizer states sometimes lack it.
+ for group in optimizer.param_groups:
+ group.setdefault("initial_lr", config["learning_rate"])
+ last_epoch_arg = resume_epoch - 1 # next .step() lands on resume_epoch
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, last_epoch=last_epoch_arg)
self.total_epochs = epochs
# Automatic mixed precision
@@ -250,9 +399,9 @@ def __init__(self, config, vecenv, policy, logger=None):
# Initializations
self.config = config
self.vecenv = vecenv
- self.epoch = 0
- self.global_step = 0
- self.last_log_step = 0
+ self.epoch = resume_epoch
+ self.global_step = resume_global_step
+ self.last_log_step = resume_global_step
self.last_log_time = time.time()
self.start_time = time.time()
self.utilization = Utilization()
@@ -260,6 +409,7 @@ def __init__(self, config, vecenv, policy, logger=None):
self.stats = defaultdict(list)
self.last_stats = defaultdict(list)
self.losses = {}
+
# Dashboard
self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad)
self.print_dashboard(clear=True)
@@ -275,6 +425,119 @@ def sps(self):
return (self.global_step - self.last_log_step) / (time.time() - self.last_log_time)
+ def _fill_external_co_player_actions(self, full_obs, info, env_id, dones, truncs):
+ """Centralized co-player inference on GPU.
+
+ Workers receive co_player actions via the shared `actions` SHM buffer
+ (filled here) instead of running per-worker CPU forward passes.
+
+ Args:
+ full_obs: numpy obs from vecenv.recv(), shape
+ (num_agents_per_recv_batch, *obs_shape).
+ info: the raw info list from vecenv.recv(). Must contain a dict
+ with key "_external_co_player_ids" giving the actual list of
+ co-player agent indices (worker-local). Computing this from
+ complement-of-ego_ids is wrong: many slots are padding or
+ otherwise inactive, and forwarding garbage obs through them
+ pollutes the shared KV cache.
+ env_id: numpy array of agent indices for the current recv batch.
+ dones, truncs: numpy bool per-agent done/trunc flags.
+ """
+ import numpy as np
+ import torch
+
+ device = self.config["device"]
+ agents_per_worker = self.vecenv.agents_per_worker
+ # batch_size > 1 packs multiple workers into one recv; we handle the
+ # batch_size=1 case (the production setup) here. Generalizing to
+ # batch_size>1 is straightforward (loop over workers in the batch).
+ batch_size = self.vecenv.batch_size
+ if batch_size != 1:
+ raise NotImplementedError(
+ f"external_co_player_actions currently only supports batch_size=1; got batch_size={batch_size}."
+ )
+
+ # Map env_id back to a worker index so we know which co_player_state
+ # to use and which row of vecenv.actions to write to.
+ # For population_play, recv() returns ego-only agent ids
+ # (vecenv.ego_agent_ids), so divide by the per-worker ego count, not
+ # the full agent count.
+ ego_agents_per_worker = getattr(self.vecenv, "ego_agents_per_worker", agents_per_worker)
+ worker_id = int(env_id[0]) // ego_agents_per_worker
+
+ # Pull the actual co_player_ids from info (the env knows them).
+ # Also check for the scenario-boundary cache reset signal.
+ co_ids = None
+ reset_cache = False
+ for item in info:
+ if isinstance(item, dict):
+ if "_external_co_player_ids" in item:
+ co_ids = list(item["_external_co_player_ids"])
+ if item.get("_external_reset_co_cache"):
+ reset_cache = True
+ if co_ids is None:
+ raise RuntimeError(
+ "external_co_player_actions=True but the env did not "
+ "publish '_external_co_player_ids' in info. Is drive.py up to date?"
+ )
+ if not co_ids:
+ return # no co-players in this env this step
+
+ # Drop the cache at scenario boundaries — mirrors the per-worker
+ # OFF path's _reset_co_player_state() which fully reinits state.
+ # Replacing the dict makes forward_eval lazy-allocate fresh K/V on
+ # the next call, matching legacy behavior bit-for-bit at scenario
+ # boundaries.
+ if reset_cache:
+ self.co_player_state[worker_id] = {}
+
+ # Slice the co-player observations. When the ego is in oracle mode
+ # (drive.py `ego_is_oracle=True`) the env's obs is wider than the
+ # partner policy expects — partner conditioning is appended to ego
+ # rows only. Strip trailing oracle dims by slicing columns to the
+ # env's `_c_obs_dim` (the C-side obs width). Defaults to None when
+ # oracle is off → take the full width as before.
+ co_obs_width = getattr(self.vecenv.driver_env, "_c_obs_dim", None)
+ if co_obs_width is None:
+ co_obs_np = full_obs[co_ids]
+ else:
+ co_obs_np = full_obs[co_ids, :co_obs_width]
+ co_obs = torch.as_tensor(co_obs_np, device=device)
+ if self.co_player_conditioning_dims > 0:
+ # Pull this worker's conditioning slice from the SHM buffer the
+ # env wrote at scenario boundaries (or at env init). Insert right
+ # after the base ego_features — matches drive.py's
+ # `_add_co_player_conditioning` exactly.
+ cond_shm = self.vecenv.co_player_conditioning # (num_workers, max_co, cdim)
+ cond_np = cond_shm[worker_id, : len(co_ids), :] # only the rows we'll use
+ cond = torch.as_tensor(cond_np, device=device, dtype=co_obs.dtype)
+ from pufferlib.ocean.drive import binding as _b
+
+ base_ego_dim = (
+ _b.EGO_FEATURES_JERK if self.vecenv.driver_env.dynamics_model == "jerk" else _b.EGO_FEATURES_CLASSIC
+ )
+ co_obs = torch.cat([co_obs[:, :base_ego_dim], cond, co_obs[:, base_ego_dim:]], dim=1)
+
+ # NOTE: the OFF (per-worker) path only resets cache at scenario
+ # boundary or reset(), never for individual done agents. So we
+ # don't reset per-done here either — it would diverge from OFF.
+
+ with torch.no_grad():
+ logits, _ = self.co_player_policy.forward_eval(co_obs, self.co_player_state[worker_id])
+
+ # Match the per-worker code path: argmax for discrete actions.
+ if isinstance(logits, tuple):
+ co_action = torch.cat([logit.argmax(dim=-1, keepdim=True) for logit in logits], dim=-1)
+ else:
+ co_action = logits.argmax(dim=-1)
+ co_action_np = co_action.cpu().numpy().reshape(len(co_ids), -1)
+
+ # Write directly to the worker's slot in the shared-memory action
+ # buffer. The worker's env.step() will call vec_step using these
+ # actions because it has external_co_player_actions=True.
+ co_action_view = self.vecenv.actions[worker_id] # shape (agents_per_worker, *atn_shape)
+ co_action_view[co_ids] = co_action_np.reshape((len(co_ids),) + co_action_view.shape[1:])
+
def evaluate(self):
profile = self.profile
epoch = self.epoch
@@ -284,36 +547,59 @@ def evaluate(self):
config = self.config
device = config["device"]
- if config["use_rnn"]:
+ # Reset hidden states for both RNN and Transformer
+ if config.get("rnn_name", "Recurrent") == "Recurrent":
for k in self.lstm_h:
self.lstm_h[k] = torch.zeros(self.lstm_h[k].shape, device=device)
self.lstm_c[k] = torch.zeros(self.lstm_c[k].shape, device=device)
+ if config.get("rnn_name", "Recurrent") == "Transformer":
+ h = self.policy.hidden_size
+ for k in self.transformer_context:
+ n = self.transformer_context[k].shape[0]
+ # Pre-allocate full buffer instead of empty
+ self.transformer_context[k] = torch.zeros(n, self.horizon, h, device=device)
+ self.transformer_position[k] = torch.zeros(1, dtype=torch.long, device=device)
+ # Drop K/V cache so the model lazy-allocates fresh on
+ # the first forward_eval call of this rollout.
+ self.transformer_k_cache[k] = None
+ self.transformer_v_cache[k] = None
+
self.full_rows = 0
while self.full_rows < self.segments:
profile("env", epoch)
+ # print(".", end="", flush=True) # Workaround: visible I/O prevents multiprocessing deadlock
o, r, d, t, info, env_id, mask = self.vecenv.recv()
# print(f"o shape is {o.shape}", flush = True)
if self.population_play:
batch_size = self.vecenv.batch_size
- ego_ids = info[-1]
+ # Filter info to get only the ego_ids lists (not the metrics dicts)
+ ego_ids_per_env = [item for item in info if isinstance(item, list)]
+
+ if self.external_co_player:
+ # Run co-player forward on GPU before the ego-only slicing
+ # below (we need the FULL obs array to extract co-player obs).
+ self._fill_external_co_player_actions(o, info, env_id, d, t)
if batch_size > 1:
total_agents = len(o)
num_agents_per_env = total_agents // batch_size
- original_shape = o.shape
-
- o = o.reshape(batch_size, num_agents_per_env, *original_shape[1:])
- r = r.reshape(batch_size, num_agents_per_env)
- d = d.reshape(batch_size, num_agents_per_env)
- t = t.reshape(batch_size, num_agents_per_env)
-
- o = o[:, ego_ids].reshape(batch_size * len(ego_ids), *original_shape[1:])
- r = r[:, ego_ids].flatten()
- d = d[:, ego_ids].flatten()
- t = t[:, ego_ids].flatten()
+ # Create flat ego_ids by adding batch offset
+ flat_ego_ids = []
+ for env_idx in range(batch_size):
+ ego_ids = ego_ids_per_env[env_idx]
+ offset = env_idx * num_agents_per_env
+ flat_ego_ids.extend([int(idx) + offset for idx in ego_ids])
+
+ # Simply index with the flat ego_ids
+ o = o[flat_ego_ids]
+ r = r[flat_ego_ids]
+ d = d[flat_ego_ids]
+ t = t[flat_ego_ids]
else:
+ ego_ids = ego_ids_per_env[0] # Single environment
+ ego_ids = [int(idx) for idx in ego_ids] # Convert to int
o = o[ego_ids]
r = r[ego_ids]
d = d[ego_ids]
@@ -326,9 +612,9 @@ def evaluate(self):
profile("eval_copy", epoch)
o = torch.as_tensor(o)
- o_device = o.to(device) # , non_blocking=True)
- r = torch.as_tensor(r).to(device) # , non_blocking=True)
- d = torch.as_tensor(d).to(device) # , non_blocking=True)
+ o_device = o.to(device, non_blocking=True)
+ r = torch.as_tensor(r, device=device)
+ d = torch.as_tensor(d, device=device)
profile("eval_forward", epoch)
with torch.no_grad(), self.amp_context:
@@ -338,29 +624,84 @@ def evaluate(self):
env_id=env_id,
mask=mask,
)
-
- if config["use_rnn"]:
- state["lstm_h"] = self.lstm_h[env_id.start]
- state["lstm_c"] = self.lstm_c[env_id.start]
-
+ # Get appropriate batch key for state lookup
+ if self.population_play:
+ batch_size = self.vecenv.ego_agents_per_batch
+ else:
+ batch_size = self.vecenv.agents_per_batch
+ state_key = (env_id.start // batch_size) * batch_size
+
+ if config.get("rnn_name", "Recurrent") == "Recurrent":
+ state["lstm_h"] = self.lstm_h[state_key]
+ state["lstm_c"] = self.lstm_c[state_key]
+
+ if config.get("rnn_name", "Recurrent") == "Transformer":
+ state["transformer_context"] = self.transformer_context[state_key]
+ state["transformer_position"] = self.transformer_position[state_key]
+ # K/V cache for streaming attention. None on the first
+ # call → model lazy-allocates (and resets pos to 0).
+ # Subsequent calls reuse the populated cache, which is
+ # the whole point: each step appends one new K/V slot
+ # and the policy attends over the full accumulated past.
+ state["k_cache"] = self.transformer_k_cache[state_key]
+ state["v_cache"] = self.transformer_v_cache[state_key]
+ # Note: terminals not needed for eval since we're doing single-step inference
+
+ # print(".", end="", flush=True) # Prevents multiprocessing deadlock
logits, value = self.policy.forward_eval(o_device, state)
action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
r = torch.clamp(r, -1, 1)
profile("eval_copy", epoch)
with torch.no_grad():
- if config["use_rnn"]:
- # Use the same lstm_key calculation
+ # Update hidden states after forward pass
+ if config.get("rnn_name", "Recurrent") == "Recurrent":
if self.population_play:
batch_size = self.vecenv.ego_agents_per_batch
else:
batch_size = self.vecenv.agents_per_batch
lstm_key = (env_id.start // batch_size) * batch_size
-
self.lstm_h[lstm_key] = state["lstm_h"]
self.lstm_c[lstm_key] = state["lstm_c"]
+ if config.get("rnn_name", "Recurrent") == "Transformer":
+ if self.population_play:
+ batch_size = self.vecenv.ego_agents_per_batch
+ else:
+ batch_size = self.vecenv.agents_per_batch
+
+ transformer_key = (env_id.start // batch_size) * batch_size
+ self.transformer_context[transformer_key] = state["transformer_context"]
+ self.transformer_position[transformer_key] = state["transformer_position"]
+ # Persist the K/V cache the model just wrote/updated so
+ # the next forward_eval call sees the accumulated past.
+ # state.get(...) is defensive: model may not have set
+ # these if it took the legacy path.
+ self.transformer_k_cache[transformer_key] = state.get("k_cache")
+ self.transformer_v_cache[transformer_key] = state.get("v_cache")
+
+ # Episode-boundary reset. pos is a shared (1,) scalar
+ # across the chunk; cache rows are per-agent. Filter
+ # done indices against the cache's batch dim, not the
+ # pos buffer's (1,) shape.
+ if done_mask.any():
+ done_indices = torch.where(torch.from_numpy(done_mask))[0]
+ if len(done_indices) > 0:
+ batch_start_in_group = env_id.start % batch_size
+ global_indices = batch_start_in_group + done_indices
+ kc = self.transformer_k_cache[transformer_key]
+ vc = self.transformer_v_cache[transformer_key]
+ cache_batch_dim = kc[0].shape[0] if kc is not None else 0
+ valid_mask = global_indices < cache_batch_dim
+ valid_indices = global_indices[valid_mask]
+ if len(valid_indices) > 0:
+ self.transformer_position[transformer_key][:] = 0
+ if kc is not None and vc is not None:
+ for c in kc:
+ c[valid_indices] = 0
+ for c in vc:
+ c[valid_indices] = 0
# Fast path for fully vectorized envs
l = self.ep_lengths[env_id.start].item()
batch_rows = slice(self.ep_indices[env_id.start].item(), 1 + self.ep_indices[env_id.stop - 1].item())
@@ -378,7 +719,13 @@ def evaluate(self):
# Note: We are not yet handling masks in this version
self.ep_lengths[env_id] += 1
- if l + 1 >= config["bptt_horizon"]:
+ # Use appropriate horizon based on model type
+ horizon = (
+ config.get("context_length")
+ if config.get("policy_architecture", "Recurrent") == "Transformer"
+ else config["bptt_horizon"]
+ )
+ if l + 1 >= horizon:
num_full = env_id.stop - env_id.start
self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config["device"]).int()
self.ep_lengths[env_id] = 0
@@ -400,12 +747,17 @@ def evaluate(self):
self.stats[k].append(v)
profile("env", epoch)
-
self.vecenv.send(action)
profile("eval_misc", epoch)
self.free_idx = self.total_agents
- self.ep_indices = torch.arange(self.total_agents, device=device, dtype=torch.int32)
+
+ if self.population_play:
+ total_agents = self.vecenv.num_ego_agents
+ else:
+ total_agents = self.total_agents
+
+ self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32)
self.ep_lengths.zero_()
profile.end()
return self.stats
@@ -438,9 +790,9 @@ def train(self):
hasattr(self.vecenv.driver_env, "dynamics_model")
and self.vecenv.driver_env.dynamics_model == "jerk"
):
- disc_idx = 10 # base ego obs
+ disc_idx = 12 # EGO_FEATURES_JERK (was 10 before lane features)
else:
- disc_idx = 7
+ disc_idx = 9 # EGO_FEATURES_CLASSIC (was 7 before lane features)
if self.vecenv.driver_env.reward_conditioned:
disc_idx += 3
@@ -468,7 +820,15 @@ def train(self):
prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6)
idx = torch.multinomial(prio_probs, self.minibatch_segments)
mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta
- mb_obs = self.observations[idx]
+ # When cpu_offload=True, self.observations lives on CPU but `idx`
+ # is on the training device (GPU). PyTorch refuses cross-device
+ # fancy indexing, so move the index to CPU for the gather, then
+ # ship the resulting minibatch to the device. Buffer was allocated
+ # with pin_memory=True (see __init__) so the H2D copy is fast.
+ if config["cpu_offload"]:
+ mb_obs = self.observations[idx.cpu()].to(device, non_blocking=True)
+ else:
+ mb_obs = self.observations[idx]
mb_actions = self.actions[idx]
mb_logprobs = self.logprobs[idx]
mb_rewards = self.rewards[idx]
@@ -480,17 +840,45 @@ def train(self):
mb_advantages = advantages[idx]
profile("train_forward", epoch)
- if not config["use_rnn"]:
+
+ # Handle observation reshaping based on model type
+ if (
+ config.get("rnn_name", "Recurrent") != "Recurrent"
+ and config.get("rnn_name", "Recurrent") != "Transformer"
+ ):
+ # Flatten for non-recurrent models
mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape)
state = dict(
action=mb_actions,
- lstm_h=None,
- lstm_c=None,
)
+ # Add appropriate state based on model type
+ if config.get("rnn_name", "Recurrent") == "Recurrent":
+ state["lstm_h"] = None
+ state["lstm_c"] = None
+ elif config.get("rnn_name", "Recurrent") == "Transformer":
+ state["transformer_context"] = None
+ state["transformer_position"] = None
+ state["terminals"] = mb_terminals # For episode boundary masking
+
logits, newvalue = self.policy(mb_obs, state)
- actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions)
+
+ # Handle action sampling based on model type
+ if (
+ config.get("rnn_name", "Recurrent") == "Recurrent"
+ or config.get("rnn_name", "Recurrent") == "Transformer"
+ ):
+ # If the policy returned a tuple of logits, keep only the first element before sampling
+ if isinstance(logits, tuple):
+ logits = logits[0]
+ actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions)
+ else:
+ # Need to flatten actions for non-recurrent models
+ actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(
+ logits,
+ action=mb_actions.reshape(-1, *mb_actions.shape[2:]) if len(mb_actions.shape) > 2 else mb_actions,
+ )
profile("train_misc", epoch)
newlogprob = newlogprob.reshape(mb_logprobs.shape)
@@ -508,6 +896,8 @@ def train(self):
mb_gammas = gammas[idx]
else:
mb_gammas = torch.full((len(idx),), config["gamma"], device=device, dtype=torch.float32)
+
+ # Recompute advantages with new ratios
adv = compute_puff_advantage(
mb_values,
mb_rewards,
@@ -541,9 +931,9 @@ def train(self):
hasattr(self.vecenv.driver_env, "dynamics_model")
and self.vecenv.driver_env.dynamics_model == "jerk"
):
- ent_idx = 10 # base ego obs
+ ent_idx = 12 # EGO_FEATURES_JERK (was 10 before lane features)
else:
- ent_idx = 7
+ ent_idx = 9 # EGO_FEATURES_CLASSIC (was 7 before lane features)
if self.vecenv.driver_env.reward_conditioned:
ent_idx += 3
@@ -607,30 +997,15 @@ def train(self):
self.msg = f"Checkpoint saved at update {self.epoch}"
if self.render and self.epoch % self.render_interval == 0:
- model_dir = os.path.join(self.config["data_dir"], f"{self.config['env']}_{self.logger.run_id}")
- model_files = glob.glob(os.path.join(model_dir, "model_*.pt"))
-
- if model_files:
- # Take the latest checkpoint
- latest_cpt = max(model_files, key=os.path.getctime)
- bin_path = f"{model_dir}.bin"
-
- # Export to .bin for rendering with raylib
- try:
- export_args = {"env_name": self.config["env"], "load_model_path": latest_cpt, **self.config}
-
- export(
- args=export_args,
- env_name=self.config["env"],
- vecenv=self.vecenv,
- policy=self.uncompiled_policy,
- path=bin_path,
- silent=True,
- )
- pufferlib.utils.render_videos(self.config, self.vecenv, self.logger, self.global_step, bin_path)
-
- except Exception as e:
- print(f"Failed to export model weights: {e}")
+ torch.cuda.empty_cache()
+ pufferlib.utils.render_videos(
+ config=self.config,
+ policy=self.uncompiled_policy,
+ logger=self.logger,
+ epoch=self.epoch,
+ global_step=self.global_step,
+ device=self.config["device"],
+ )
if self.config["eval"]["wosac_realism_eval"] and (
self.epoch % self.config["eval"]["eval_interval"] == 0 or done_training
@@ -641,6 +1016,16 @@ def train(self):
self.epoch % self.config["eval"]["eval_interval"] == 0 or done_training
):
pufferlib.utils.run_human_replay_eval_in_subprocess(self.config, self.logger, self.global_step)
+ torch.cuda.empty_cache()
+ pufferlib.utils.render_videos(
+ config=self.config,
+ policy=self.uncompiled_policy,
+ logger=self.logger,
+ epoch=self.epoch,
+ global_step=self.global_step,
+ device=self.config["device"],
+ human_replay=True,
+ )
def mean_and_log(self):
config = self.config
@@ -716,6 +1101,12 @@ def save_checkpoint(self):
state_path = os.path.join(path, "trainer_state.pt")
torch.save(state, state_path + ".tmp")
os.rename(state_path + ".tmp", state_path)
+
+ # Sidecar metadata: every render/eval can recover the right
+ # conditioning, dataset, and architecture from /info.json
+ # without the user having to re-pass them on the CLI.
+ write_run_info(path, self.config, run_id)
+
return model_path
def print_dashboard(self, clear=False, idx=[0], c1="[cyan]", c2="[white]", b1="[bright_cyan]", b2="[bright_white]"):
@@ -1052,6 +1443,7 @@ def __init__(self, args, load_id=None, resume="allow"):
save_code=False,
resume=resume,
config=args,
+ name=args.get("wandb_name"),
tags=[args["tag"]] if args["tag"] is not None else [],
)
self.wandb = wandb
@@ -1103,11 +1495,26 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
policy = model.to(local_rank)
if args["neptune"]:
- logger = NeptuneLogger(args)
+ logger = NeptuneLogger(args, load_id=args.get("load_id"))
elif args["wandb"]:
- logger = WandbLogger(args)
-
- train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {}))
+ # Pass load_id so the wandb logger resumes the existing run instead
+ # of creating a fresh one. WandbLogger uses resume="allow", so wandb
+ # picks up where the original run left off (history, name, tags).
+ logger = WandbLogger(args, load_id=args.get("load_id"))
+
+ train_config = dict(
+ **args["train"],
+ env=env_name,
+ eval=args.get("eval", {}),
+ env_config=args.get("env", {}),
+ policy_architecture=args.get("policy_architecture", "Recurrent"),
+ # rnn_name lives at args top level — must be explicitly propagated,
+ # else config.get("rnn_name") defaults to "Recurrent" and the
+ # Transformer init/rollout branches in PuffeRL never fire.
+ rnn_name=args.get("rnn_name", args.get("policy_architecture", "Recurrent")),
+ load_model_path=args.get("load_model_path"),
+ load_id=args.get("load_id"),
+ )
pufferl = PuffeRL(train_config, vecenv, policy, logger)
all_logs = []
@@ -1149,15 +1556,24 @@ def eval(env_name, args=None, vecenv=None, policy=None):
wosac_enabled = args["eval"]["wosac_realism_eval"]
human_replay_enabled = args["eval"]["human_replay_eval"]
+ # Honor eval.map_dir only when explicitly set; otherwise inherit the
+ # training env.map_dir so eval doesn't silently switch datasets.
+ eval_map_dir = args["eval"].get("map_dir")
+ if eval_map_dir in (None, "", "None"):
+ eval_map_dir = args["env"].get("map_dir")
+ args["env"]["map_dir"] = eval_map_dir
+ args["eval"]["map_dir"] = eval_map_dir
+ args["env"]["num_maps"] = args["eval"]["num_maps"]
+ args["env"]["use_all_maps"] = True
+ dataset_name = os.path.basename(os.path.normpath(args["env"]["map_dir"]))
if wosac_enabled:
- print(f"Running WOSAC realism evaluation. \n")
+ print(f"Running WOSAC realism evaluation with {dataset_name} dataset. \n")
from pufferlib.ocean.benchmark.evaluator import WOSACEvaluator
backend = args["eval"]["backend"]
assert backend == "PufferEnv" or not wosac_enabled, "WOSAC evaluation only supports PufferEnv backend."
args["vec"] = dict(backend=backend, num_envs=1)
- args["env"]["num_agents"] = args["eval"]["wosac_num_agents"]
args["env"]["init_mode"] = args["eval"]["wosac_init_mode"]
args["env"]["control_mode"] = args["eval"]["wosac_control_mode"]
args["env"]["init_steps"] = args["eval"]["wosac_init_steps"]
@@ -1172,6 +1588,10 @@ def eval(env_name, args=None, vecenv=None, policy=None):
# Collect ground truth trajectories from the dataset
gt_trajectories = evaluator.collect_ground_truth_trajectories(vecenv)
+ print(f"Number of scenarios: {len(np.unique(gt_trajectories['scenario_id']))}")
+ print(f"Number of controlled agents: {gt_trajectories['x'].shape[0]}")
+ print(f"Number of evaluated agents: {np.sum(gt_trajectories['id'] >= 0)}")
+
# Roll out trained policy in the simulator
simulated_trajectories = evaluator.collect_simulated_trajectories(args, vecenv, policy)
@@ -1179,32 +1599,52 @@ def eval(env_name, args=None, vecenv=None, policy=None):
evaluator._quick_sanity_check(gt_trajectories, simulated_trajectories)
# Analyze and compute metrics
+ agent_state = vecenv.driver_env.get_global_agent_state()
+ road_edge_polylines = vecenv.driver_env.get_road_edge_polylines()
results = evaluator.compute_metrics(
- gt_trajectories, simulated_trajectories, args["eval"]["wosac_aggregate_results"]
+ gt_trajectories,
+ simulated_trajectories,
+ agent_state,
+ road_edge_polylines,
+ args["eval"]["wosac_aggregate_results"],
)
if args["eval"]["wosac_aggregate_results"]:
import json
- print("WOSAC_METRICS_START")
+ print("\nWOSAC_METRICS_START")
print(json.dumps(results))
print("WOSAC_METRICS_END")
return results
elif human_replay_enabled:
- print("Running human replay evaluation.\n")
+ print(f"Running human replay evaluation with {dataset_name} dataset.\n")
from pufferlib.ocean.benchmark.evaluator import HumanReplayEvaluator
backend = args["eval"].get("backend", "PufferEnv")
args["vec"] = dict(backend=backend, num_envs=1)
- args["env"]["num_agents"] = args["eval"]["human_replay_num_agents"]
args["env"]["control_mode"] = args["eval"]["human_replay_control_mode"]
- args["env"]["scenario_length"] = 91 # Standard scenario length
+ # episode_length is NOT hardcoded here — inherits scenario_length
+ # from training config (WOMD=91, nuPlan=201).
+ # Human replay: only 1 ego is policy-controlled, others follow logged trajectories
+ args["env"]["co_player_enabled"] = False
+ args["env"]["max_controlled_agents"] = 1
+ # `human_replay_mode` is only accepted by AdaptiveDrivingAgent
+ if "adaptive" in env_name:
+ args["env"]["human_replay_mode"] = True
+ if args["eval"].get("human_replay_num_agents") is not None:
+ args["env"]["num_agents"] = args["eval"]["human_replay_num_agents"]
+ if args["eval"].get("human_replay_num_maps") is not None:
+ args["env"]["num_maps"] = args["eval"]["human_replay_num_maps"]
+ if args["eval"].get("map_dir") not in (None, "", "None"):
+ args["env"]["map_dir"] = args["eval"]["map_dir"]
vecenv = vecenv or load_env(env_name, args)
policy = policy or load_policy(args, vecenv, env_name)
+ print(f"Effective number of scenarios used: {len(vecenv.driver_env.agent_offsets) - 1}")
+
evaluator = HumanReplayEvaluator(args)
# Run rollouts with human replays
@@ -1231,10 +1671,6 @@ def eval(env_name, args=None, vecenv=None, policy=None):
num_agents = vecenv.observation_space.shape[0]
device = args["train"]["device"]
- # Rebuild visualize binary if saving frames (for C-based rendering)
- if args["save_frames"] > 0:
- ensure_drive_binary()
-
state = {}
if args["train"]["use_rnn"]:
state = dict(
@@ -1312,6 +1748,118 @@ def sweep(args=None, env_name=None):
args["train"]["total_timesteps"] = total_timesteps
+def controlled_exp(env_name, args=None):
+ """Run experiments with all combinations of specified parameter values."""
+ import itertools
+ from copy import deepcopy
+
+ args = args or load_config(env_name)
+ if not args["wandb"] and not args["neptune"]:
+ raise pufferlib.APIUsageError("Controlled experiments require either wandb or neptune")
+
+ # Check if controlled_exp config exists
+ if "controlled_exp" not in args:
+ raise pufferlib.APIUsageError("No [controlled_exp.*] sections found in config")
+
+ # Extract parameters from controlled_exp namespace
+ params = {}
+ for section, section_config in args["controlled_exp"].items():
+ if isinstance(section_config, dict):
+ for param, param_config in section_config.items():
+ if isinstance(param_config, dict) and "values" in param_config:
+ params[f"{section}.{param}"] = param_config["values"]
+
+ if not params:
+ raise pufferlib.APIUsageError("No parameters with 'values' lists found in [controlled_exp.*] sections")
+
+ # Generate all combinations
+ keys = list(params.keys())
+ combinations = list(itertools.product(*[params[k] for k in keys]))
+
+ print(f"Running a total of {len(combinations)} experiments with parameters: {keys}")
+
+ # Run each combination
+ for i, combo in enumerate(combinations, 1):
+ exp_args = deepcopy(args)
+
+ # Set parameters
+ for key, value in zip(keys, combo):
+ section, param = key.split(".")
+ exp_args[section][param] = value
+
+ print(f"\nExperiment {i}/{len(combinations)}: {dict(zip(keys, combo))}")
+
+ # Train
+ train(env_name, args=exp_args)
+
+ print(f"\n✓ Completed all {len(combinations)} experiments")
+
+
+def sanity(env_name, args=None):
+ args = args or load_config(env_name)
+ base_dir = Path(__file__).resolve().parent / "resources" / "drive" / "sanity"
+ json_dir = base_dir / "sanity_jsons"
+ binary_dir = base_dir / "sanity_binaries"
+
+ available_maps = {p.stem: p for p in json_dir.glob("*.json")}
+ selected = args.get("sanity_maps")
+ if isinstance(selected, str):
+ selected = [selected]
+
+ if selected:
+ missing = [name for name in selected if name not in available_maps]
+ if missing:
+ raise pufferlib.APIUsageError(f"Unknown sanity maps: {', '.join(sorted(missing))}")
+ chosen = [(name, available_maps[name]) for name in selected]
+ else:
+ chosen = sorted(available_maps.items())
+
+ if not chosen:
+ raise pufferlib.APIUsageError(f"No sanity maps found in {json_dir}")
+
+ from pufferlib.ocean.drive.drive import load_map
+
+ binary_dir.mkdir(parents=True, exist_ok=True)
+ binaries = []
+ for idx, (name, json_path) in enumerate(chosen):
+ output_path = binary_dir / f"{name}.bin"
+ load_map(str(json_path), idx, str(output_path))
+ binaries.append((name, output_path))
+
+ runs = []
+ for name, binary in binaries:
+ map_zero = binary_dir / "map_000.bin"
+ shutil.copy2(binary, map_zero)
+
+ run_args = {
+ **args,
+ "env": {**args["env"], "num_maps": 1, "map_dir": str(binary_dir)},
+ "train": {**args["train"], "render_map": str(map_zero)},
+ }
+ if run_args.get("wandb"):
+ run_args["wandb_name"] = name
+
+ print(f"Running sanity map '{name}' from {binary.name}")
+ run_logs = train(env_name=env_name, args=run_args)
+ runs.append({"map": name, "logs": run_logs})
+
+ print("Sanity checklist:")
+ for entry in runs:
+ name = entry["map"]
+ logs = entry.get("logs") or []
+ final = logs[-1] if logs else {}
+ score = final.get("environment/score")
+ if score is None:
+ status = "unknown (no score)"
+ elif score >= 0.95:
+ status = "✅ Solved"
+ else:
+ status = "❌ Unsolved"
+ print(f" - {name}: {status} (score={score})")
+
+ return runs
+
+
def profile(args=None, env_name=None, vecenv=None, policy=None):
args = load_config()
vecenv = vecenv or load_env(env_name, args)
@@ -1352,31 +1900,63 @@ def export(args=None, env_name=None, vecenv=None, policy=None, path=None, silent
print(f"Saved {len(weights)} weights to {path}")
-def ensure_drive_binary():
- """Delete existing visualize binary and rebuild it. This ensures the
- binary is always up-to-date with the latest code changes.
+def write_run_info(run_dir, config, run_id):
+ """Persist the bits of training config that render/eval need to replay.
+
+ We only record fields that change observation/architecture shape or
+ dataset identity — the things you can't recover from the .pt alone.
+ Existing checkpoints don't have this file; readers must treat it as
+ optional.
"""
- if os.path.exists("./visualize"):
- print("Removing existing visualize binary...")
- try:
- os.remove("./visualize")
- except FileNotFoundError:
- print("Binary not found")
- print("Building visualize binary...")
+ import json
+
+ env_cfg = config.get("env_config", {})
+ info = {
+ "run_id": run_id,
+ "env_name": config.get("env"),
+ "policy_architecture": config.get("policy_architecture"),
+ "rnn_name": config.get("rnn_name"),
+ "env": {
+ "map_dir": env_cfg.get("map_dir"),
+ "num_maps": env_cfg.get("num_maps"),
+ "num_agents": env_cfg.get("num_agents"),
+ "num_ego_agents": env_cfg.get("num_ego_agents"),
+ "k_scenarios": env_cfg.get("k_scenarios", 1),
+ "scenario_length": env_cfg.get("scenario_length", 91),
+ "dynamics_model": env_cfg.get("dynamics_model", "classic"),
+ "co_player_enabled": bool(env_cfg.get("co_player_enabled")),
+ "conditioning": env_cfg.get("conditioning", {}),
+ "co_player_policy": env_cfg.get("co_player_policy", {}),
+ },
+ }
+ info_path = os.path.join(run_dir, "info.json")
try:
- result = subprocess.run(
- ["bash", "scripts/build_ocean.sh", "visualize", "local"], capture_output=True, text=True, timeout=300
- )
-
- if result.returncode == 0:
- print("Successfully built visualize binary")
- else:
- print(f"Build failed: {result.stderr}")
- raise RuntimeError("Failed to build visualize binary for rendering")
- except subprocess.TimeoutExpired:
- raise RuntimeError("Build timed out")
+ with open(info_path + ".tmp", "w") as f:
+ json.dump(info, f, indent=2, default=str)
+ os.replace(info_path + ".tmp", info_path)
except Exception as e:
- raise RuntimeError(f"Build error: {e}")
+ print(f"[info.json] failed to write: {e}")
+
+
+def load_run_info(model_path):
+ """Look up the run_dir/info.json for a given checkpoint path. Returns {} if missing."""
+ import json
+
+ candidates = []
+ parent = os.path.dirname(model_path) or "."
+ candidates.append(os.path.join(parent, "info.json"))
+ # Allow `experiments/puffer_drive_.pt` (flat copy) → look in sibling run dir.
+ base, ext = os.path.splitext(model_path)
+ if ext == ".pt" and os.path.basename(base).startswith("puffer_"):
+ candidates.append(os.path.join(base, "info.json"))
+ for path in candidates:
+ if os.path.exists(path):
+ try:
+ with open(path) as f:
+ return json.load(f)
+ except Exception as e:
+ print(f"[info.json] failed to read {path}: {e}")
+ return {}
def autotune(args=None, env_name=None, vecenv=None, policy=None):
@@ -1402,40 +1982,57 @@ def load_policy(args, vecenv, env_name=""):
env_module = importlib.import_module(module_name)
device = args["train"]["device"]
- policy_cls = getattr(env_module.torch, args["policy_name"])
- policy = policy_cls(vecenv.driver_env, **args["policy"])
-
- rnn_name = args["rnn_name"]
- if rnn_name is not None:
- rnn_cls = getattr(env_module.torch, args["rnn_name"])
- policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"])
-
- policy = policy.to(device)
load_id = args["load_id"]
- if load_id is not None:
+ load_path = args.get("load_model_path")
+ state_dict = None
+ rnn_name = args.get("policy_architecture", "Recurrent")
+
+ if load_path is not None:
+ state_dict = torch.load(load_path, map_location=device)
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ elif load_id is not None:
if args["neptune"]:
path = NeptuneLogger(args, load_id, mode="read-only").download()
elif args["wandb"]:
path = WandbLogger(args, load_id).download()
else:
raise pufferlib.APIUsageError("No run id provided for eval")
-
state_dict = torch.load(path, map_location=device)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
- policy.load_state_dict(state_dict)
- load_path = args["load_model_path"]
- if load_path == "latest":
- load_path = max(glob.glob(f"experiments/{env_name}*.pt"), key=os.path.getctime)
+ # Auto-detect architecture from state_dict keys
+ if state_dict is not None:
+ if "positional_embedding" in state_dict:
+ rnn_name = "Transformer"
+ elif "lstm.weight_ih_l0" in state_dict:
+ rnn_name = "Recurrent"
- if load_path is not None:
- state_dict = torch.load(load_path, map_location=device)
- state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ policy_cls = getattr(env_module.torch, args["policy_name"])
+ policy = policy_cls(vecenv.driver_env, **args["policy"])
+
+ # Handle both RNN and Transformer wrappers via rnn_name
+ if rnn_name == "Transformer":
+ # Load transformer wrapper
+ transformer_cls = getattr(env_module.torch, rnn_name)
+ # For adaptive_driving_agent, use episode_length as horizon (k_scenarios * scenario_length)
+ # Otherwise, use config horizon with fallback to episode_length
+ is_adaptive = getattr(vecenv.driver_env, "env_name", None) == "adaptive_drive"
+ if is_adaptive:
+ args["transformer"]["horizon"] = vecenv.driver_env.episode_length
+ else:
+ args["transformer"]["horizon"] = args["train"].get("horizon", vecenv.driver_env.episode_length)
+ policy = transformer_cls(vecenv.driver_env, policy, **args["transformer"])
+ elif rnn_name is not None:
+ # Load RNN wrapper (Recurrent)
+ rnn_cls = getattr(env_module.torch, rnn_name)
+ policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"])
+
+ policy = policy.to(device)
+
+ # Load the state dict if we have one
+ if state_dict is not None:
policy.load_state_dict(state_dict)
- # state_path = os.path.join(*load_path.split('/')[:-1], 'state.pt')
- # optim_state = torch.load(state_path)['optimizer_state_dict']
- # pufferl.optimizer.load_state_dict(optim_state)
return policy
@@ -1466,7 +2063,7 @@ def load_config(env_name):
parser.add_argument("--neptune-project", type=str, default="ablations")
parser.add_argument("--local-rank", type=int, default=0, help="Used by torchrun for DDP")
parser.add_argument("--tag", type=str, default=None, help="Tag for experiment")
-
+ parser.add_argument("--sanity-maps", nargs="*", default=None, help="Optional list of sanity map base names to run")
args = parser.parse_known_args()[0]
# Load defaults and config
@@ -1517,9 +2114,7 @@ def puffer_type(value):
def main():
- err = (
- "Usage: puffer [train, eval, sweep, autotune, profile, export] [env_name] [optional args]. --help for more info"
- )
+ err = "Usage: puffer [train, eval, sweep, controlled_exp, autotune, profile, export, sanity] [env_name] [optional args]. --help for more info"
if len(sys.argv) < 3:
raise pufferlib.APIUsageError(err)
@@ -1531,12 +2126,16 @@ def main():
eval(env_name=env_name)
elif mode == "sweep":
sweep(env_name=env_name)
+ elif mode == "controlled_exp":
+ controlled_exp(env_name=env_name)
elif mode == "autotune":
autotune(env_name=env_name)
elif mode == "profile":
profile(env_name=env_name)
elif mode == "export":
export(env_name=env_name)
+ elif mode == "sanity":
+ sanity(env_name=env_name)
else:
raise pufferlib.APIUsageError(err)
diff --git a/pufferlib/resources/drive/binaries/map_000.bin b/pufferlib/resources/drive/binaries/map_000.bin
deleted file mode 100644
index 434b98c255..0000000000
Binary files a/pufferlib/resources/drive/binaries/map_000.bin and /dev/null differ
diff --git a/pufferlib/resources/drive/binaries/training/map_000.bin b/pufferlib/resources/drive/binaries/training/map_000.bin
new file mode 100644
index 0000000000..ef87c59af2
Binary files /dev/null and b/pufferlib/resources/drive/binaries/training/map_000.bin differ
diff --git a/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin b/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin
new file mode 100644
index 0000000000..c3840a5351
Binary files /dev/null and b/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin differ
diff --git a/pufferlib/utils.py b/pufferlib/utils.py
index a93c6b2945..7210f54d06 100644
--- a/pufferlib/utils.py
+++ b/pufferlib/utils.py
@@ -7,9 +7,11 @@
def run_human_replay_eval_in_subprocess(config, logger, global_step):
- """
- Run human replay evaluation in a subprocess and log metrics to wandb.
+ """Run human replay evaluation in a subprocess and log metrics to wandb.
+ Routes through `pufferl eval --eval.human-replay-eval True` for both adaptive
+ and non-adaptive agents. The subprocess uses HumanReplayEvaluator, which
+ handles both architectures (LSTM, Transformer) and both agent types.
"""
try:
run_id = logger.run_id
@@ -22,8 +24,19 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step):
latest_cpt = max(model_files, key=os.path.getctime)
- # Prepare evaluation command
- eval_config = config["eval"]
+ env_config = config.get("env_config", {})
+ eval_config = config.get("eval", {})
+ conditioning = env_config.get("conditioning", {})
+ conditioning_type = conditioning.get("type", "none")
+ # Resolve map_dir on the parent so the child can't silently fall back to the ini default
+ map_dir = eval_config.get("map_dir") or env_config.get("map_dir")
+ # Adaptive runs override k_scenarios and the resulting episode length;
+ # the child must use the same values or it will build a model whose
+ # positional_embedding shape doesn't match the trained checkpoint.
+ k_scenarios = env_config.get("k_scenarios", 1)
+ scenario_length = env_config.get("scenario_length", 91)
+ train_horizon = config.get("horizon", scenario_length * k_scenarios)
+
cmd = [
sys.executable,
"-m",
@@ -37,35 +50,77 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step):
"--eval.human-replay-eval",
"True",
"--eval.human-replay-num-agents",
- str(eval_config["human_replay_num_agents"]),
+ str(eval_config.get("human_replay_num_agents", 64)),
+ "--eval.human-replay-num-maps",
+ str(eval_config.get("human_replay_num_maps", 100)),
+ "--eval.human-replay-num-rollouts",
+ str(eval_config.get("human_replay_num_rollouts", 100)),
"--eval.human-replay-control-mode",
- str(eval_config["human_replay_control_mode"]),
+ str(eval_config.get("human_replay_control_mode", "control_vehicles")),
+ *(["--eval.map-dir", str(map_dir)] if map_dir else []),
+ "--eval.num-maps",
+ str(eval_config.get("num_maps", 20)),
+ "--env.k-scenarios",
+ str(k_scenarios),
+ "--env.scenario-length",
+ str(scenario_length),
+ "--train.horizon",
+ str(train_horizon),
+ # Hardcode goal_behavior=0 (respawn) to match training. Stopping at the goal
+ # (gb=2) proved worse in practice: its sparse reward produces worse drivers, while
+ # gb=0 respawn captures real efficiency adaptation in scen_1 vs scen_0.
+ "--env.goal-behavior",
+ "0",
+ "--env.conditioning.type",
+ conditioning_type,
+ "--env.conditioning.collision-weight-lb",
+ str(conditioning.get("collision_weight_lb", -3.0)),
+ "--env.conditioning.collision-weight-ub",
+ str(conditioning.get("collision_weight_ub", -3.0)),
+ "--env.conditioning.offroad-weight-lb",
+ str(conditioning.get("offroad_weight_lb", -1.0)),
+ "--env.conditioning.offroad-weight-ub",
+ str(conditioning.get("offroad_weight_ub", -1.0)),
+ "--env.conditioning.goal-weight-lb",
+ str(conditioning.get("goal_weight_lb", 1.0)),
+ "--env.conditioning.goal-weight-ub",
+ str(conditioning.get("goal_weight_ub", 1.0)),
+ "--env.conditioning.entropy-weight-lb",
+ str(conditioning.get("entropy_weight_lb", 0.001)),
+ "--env.conditioning.entropy-weight-ub",
+ str(conditioning.get("entropy_weight_ub", 0.001)),
+ "--env.conditioning.discount-weight-lb",
+ str(conditioning.get("discount_weight_lb", 0.98)),
+ "--env.conditioning.discount-weight-ub",
+ str(conditioning.get("discount_weight_ub", 0.98)),
]
- # Run human replay evaluation in subprocess
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd())
- if result.returncode == 0:
- # Extract JSON from stdout between markers
- stdout = result.stdout
- if "HUMAN_REPLAY_METRICS_START" in stdout and "HUMAN_REPLAY_METRICS_END" in stdout:
- start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START")
- end = stdout.find("HUMAN_REPLAY_METRICS_END")
- json_str = stdout[start:end].strip()
- human_replay_metrics = json.loads(json_str)
-
- # Log to wandb if available
- if hasattr(logger, "wandb") and logger.wandb:
- logger.wandb.log(
- {
- "eval/human_replay_collision_rate": human_replay_metrics["collision_rate"],
- "eval/human_replay_offroad_rate": human_replay_metrics["offroad_rate"],
- "eval/human_replay_completion_rate": human_replay_metrics["completion_rate"],
- },
- step=global_step,
- )
- else:
- print(f"Human replay evaluation failed with exit code {result.returncode}: {result.stderr}")
+ if result.returncode != 0:
+ print(f"Human replay evaluation failed (exit {result.returncode}): {result.stderr}")
+ return
+
+ stdout = result.stdout
+ if "HUMAN_REPLAY_METRICS_START" not in stdout or "HUMAN_REPLAY_METRICS_END" not in stdout:
+ return
+
+ start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START")
+ end = stdout.find("HUMAN_REPLAY_METRICS_END")
+ metrics = json.loads(stdout[start:end].strip())
+
+ if not (hasattr(logger, "wandb") and logger.wandb):
+ return
+
+ # Forward every metric the evaluator emitted under eval/human_replay_*.
+ # This includes `_std` (variance across rollouts), eval-scale metadata
+ # (n_rollouts, n_agents_per_rollout, n_total_evals), and the full
+ # ada_delta_*/scenario_* family without an explicit allow-list.
+ log_data = {}
+ for k, v in metrics.items():
+ if isinstance(v, (int, float)):
+ log_data[f"eval/human_replay_{k}"] = v
+ logger.wandb.log(log_data, step=global_step)
except subprocess.TimeoutExpired:
print("Human replay evaluation timed out")
@@ -74,18 +129,7 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step):
def run_wosac_eval_in_subprocess(config, logger, global_step):
- """
- Run WOSAC evaluation in a subprocess and log metrics to wandb.
-
- Args:
- config: Configuration dictionary containing data_dir, env, and wosac settings
- logger: Logger object with run_id and optional wandb attribute
- epoch: Current training epoch
- global_step: Current global training step
-
- Returns:
- None. Prints error messages if evaluation fails.
- """
+ """Run WOSAC realism evaluation in a subprocess and log metrics to wandb."""
try:
run_id = logger.run_id
model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}")
@@ -97,8 +141,11 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
latest_cpt = max(model_files, key=os.path.getctime)
- # Prepare evaluation command
+ env_config = config.get("env_config", {})
eval_config = config.get("eval", {})
+ # Forward training map_dir so eval doesn't silently fall back to ini default
+ map_dir = eval_config.get("map_dir") or env_config.get("map_dir")
+
cmd = [
sys.executable,
"-m",
@@ -114,7 +161,7 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
"--eval.wosac-init-mode",
str(eval_config.get("wosac_init_mode", "create_all_valid")),
"--eval.wosac-control-mode",
- str(eval_config.get("wosac_control_mode", "control_tracks_to_predict")),
+ str(eval_config.get("wosac_control_mode", "control_wosac")),
"--eval.wosac-init-steps",
str(eval_config.get("wosac_init_steps", 10)),
"--eval.wosac-goal-behavior",
@@ -125,167 +172,184 @@ def run_wosac_eval_in_subprocess(config, logger, global_step):
str(eval_config.get("wosac_sanity_check", False)),
"--eval.wosac-aggregate-results",
str(eval_config.get("wosac_aggregate_results", True)),
+ *(["--eval.map-dir", str(map_dir)] if map_dir else []),
]
- # Run WOSAC evaluation in subprocess
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd())
- if result.returncode == 0:
- # Extract JSON from stdout between markers
- stdout = result.stdout
- if "WOSAC_METRICS_START" in stdout and "WOSAC_METRICS_END" in stdout:
- start = stdout.find("WOSAC_METRICS_START") + len("WOSAC_METRICS_START")
- end = stdout.find("WOSAC_METRICS_END")
- json_str = stdout[start:end].strip()
- wosac_metrics = json.loads(json_str)
-
- # Log to wandb if available
- if hasattr(logger, "wandb") and logger.wandb:
- logger.wandb.log(
- {
- "eval/wosac_realism_meta_score": wosac_metrics["realism_meta_score"],
- "eval/wosac_ade": wosac_metrics["ade"],
- "eval/wosac_min_ade": wosac_metrics["min_ade"],
- "eval/wosac_total_num_agents": wosac_metrics["total_num_agents"],
- },
- step=global_step,
- )
- else:
- print(f"WOSAC evaluation failed with exit code {result.returncode}: {result.stderr}")
+ if result.returncode != 0:
+ print(f"WOSAC evaluation failed (exit {result.returncode}): {result.stderr}")
+ stderr_lower = result.stderr.lower()
+ if "out of memory" in stderr_lower:
+ print("GPU OOM during WOSAC eval; skipping.")
+ return
+
+ stdout = result.stdout
+ if "WOSAC_METRICS_START" not in stdout or "WOSAC_METRICS_END" not in stdout:
+ return
+
+ start = stdout.find("WOSAC_METRICS_START") + len("WOSAC_METRICS_START")
+ end = stdout.find("WOSAC_METRICS_END")
+ metrics = json.loads(stdout[start:end].strip())
+
+ if hasattr(logger, "wandb") and logger.wandb:
+ logger.wandb.log(
+ {
+ "eval/wosac_realism_meta_score": metrics["realism_meta_score"],
+ "eval/wosac_ade": metrics["ade"],
+ "eval/wosac_min_ade": metrics["min_ade"],
+ "eval/wosac_total_num_agents": metrics["total_num_agents"],
+ },
+ step=global_step,
+ )
except subprocess.TimeoutExpired:
- print("WOSAC evaluation timed out")
+ print("WOSAC evaluation timed out after 600 seconds")
except Exception as e:
- print(f"Failed to run WOSAC evaluation: {e}")
+ print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}")
-def render_videos(config, vecenv, logger, global_step, bin_path):
- """
- Generate and log training videos using C-based rendering.
+_VIEW_NAMES = {0: "sim_state", 1: "bev", 2: "persp"}
+
+
+def render_videos(config, policy, logger, epoch, global_step, device="cuda", human_replay=False):
+ """Generate and log videos via Python rollout (works with any policy architecture).
+
+ Policy inference is in PyTorch; rendering goes through the C bindings
+ (`vec_render`, `vec_set_video_suffix`). Saves under
+ /_/renders/ with names that include the map_id,
+ mode, and view so multiple modes on the same map don't collide.
Args:
- config: Configuration dictionary containing data_dir, env, and render settings
- vecenv: Vectorized environment with driver_env attribute
- logger: Logger object with run_id and optional wandb attribute
- epoch: Current training epoch
- global_step: Current global training step
- bin_path: Path to the exported .bin model weights file
-
- Returns:
- None. Prints error messages if rendering fails.
+ config: PuffeRL flat train config.
+ policy: PyTorch policy (LSTM or Transformer wrapper).
+ logger: Logger with run_id and optional .wandb.
+ epoch: Current training epoch.
+ global_step: Current global step (for wandb step alignment).
+ device: Inference device.
+ human_replay: If True, render in human-replay mode (1 ego, others = logs).
"""
- if not os.path.exists(bin_path):
- print(f"Binary weights file does not exist: {bin_path}")
- return
+ import copy
+ import torch
+ from pufferlib.pufferl import load_env
+ from pufferlib.ocean.drive.rollout import RenderContext, RenderView, rollout_loop
- run_id = logger.run_id
- model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}")
-
- # Now call the C rendering function
try:
- # Create output directory for videos
- video_output_dir = os.path.join(model_dir, "videos")
+ run_id = logger.run_id
+ env_name = config.get("env", "drive")
+ model_dir = os.path.join(config["data_dir"], f"{env_name}_{run_id}")
+ video_output_dir = os.path.join(model_dir, "renders")
os.makedirs(video_output_dir, exist_ok=True)
- # Copy the binary weights to the expected location
- expected_weights_path = "resources/drive/puffer_drive_weights.bin"
- os.makedirs(os.path.dirname(expected_weights_path), exist_ok=True)
- shutil.copy2(bin_path, expected_weights_path)
-
- # TODO: Fix memory leaks so that this is not needed
- # Suppress AddressSanitizer exit code (temp)
- env = os.environ.copy()
- env["ASAN_OPTIONS"] = "exitcode=0"
-
- cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"]
-
- # Add render configurations
- if config["show_grid"]:
- cmd.append("--show-grid")
- if config["obs_only"]:
- cmd.append("--obs-only")
- if config["show_lasers"]:
- cmd.append("--lasers")
- if config["show_human_logs"]:
- cmd.append("--log-trajectories")
- if vecenv.driver_env.goal_radius is not None:
- cmd.extend(["--goal-radius", str(vecenv.driver_env.goal_radius)])
- if vecenv.driver_env.init_steps > 0:
- cmd.extend(["--init-steps", str(vecenv.driver_env.init_steps)])
- if config["render_map"] is not None:
- map_path = config["render_map"]
- if os.path.exists(map_path):
- cmd.extend(["--map-name", map_path])
- if vecenv.driver_env.init_mode is not None:
- cmd.extend(["--init-mode", str(vecenv.driver_env.init_mode)])
- if vecenv.driver_env.control_mode is not None:
- cmd.extend(["--control-mode", str(vecenv.driver_env.control_mode)])
-
- if hasattr(vecenv.driver_env, "reward_conditioned"):
- cmd.extend(["--use-rc", "1" if vecenv.driver_env.reward_conditioned else "0"])
- if hasattr(vecenv.driver_env, "entropy_conditioned"):
- cmd.extend(["--use-ec", "1" if vecenv.driver_env.entropy_conditioned else "0"])
- if hasattr(vecenv.driver_env, "discount_conditioned"):
- cmd.extend(["--use-dc", "1" if vecenv.driver_env.discount_conditioned else "0"])
-
- # Specify output paths for videos
- cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"])
- cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"])
-
- # Add environment configuration
- env_cfg = getattr(vecenv, "driver_env", None)
- if env_cfg is not None:
- n_policy = getattr(env_cfg, "max_controlled_agents", -1)
+ view_modes = config.get("render_view_modes", [RenderView.FULL_SIM_STATE])
+ if isinstance(view_modes, int):
+ view_modes = [view_modes]
+
+ env_kwargs = copy.deepcopy(config.get("env_config", {}))
+ env_kwargs["render_mode"] = 1 # RENDER_HEADLESS
+ # Render env runs alongside training and has to fit in the same VRAM /
+ # RAM budget — override the training num_agents (often 1024+) down to a
+ # render-sized footprint so we don't OOM on first render call.
+ env_kwargs["num_agents"] = min(env_kwargs.get("num_agents", 64), 64)
+ if env_kwargs.get("num_ego_agents") is not None:
+ env_kwargs["num_ego_agents"] = min(env_kwargs["num_ego_agents"], 32)
+ env_kwargs["num_maps"] = min(env_kwargs.get("num_maps", 500), 500)
+ # Force per-env inline co-player inference for the render env. Training
+ # runs with external_co_player_actions=True expect PuffeRL's centralized
+ # GPU inference path to write co-player actions into shared memory each
+ # step. The render env is a single Serial env with no central path
+ # attached — without this override the SHM is never written, co-players
+ # freeze, and training-style renders show stuck partners.
+ env_kwargs["external_co_player_actions"] = False
+
+ if human_replay:
+ env_kwargs["co_player_enabled"] = False
+ env_kwargs["max_controlled_agents"] = 1
+ # Inherit goal_behavior from training config (no override) so the
+ # render matches whatever regime the policy actually trained under.
+ if "adaptive" in env_name:
+ env_kwargs["human_replay_mode"] = True
+
+ # Force Serial backend for render: raylib's GLFW needs DISPLAY, which
+ # xvfb-run sets in the parent process. A Multiprocessing worker would
+ # spawn without DISPLAY and segfault on InitWindow.
+ render_args = {
+ "env": env_kwargs,
+ "vec": {"num_envs": 1, "backend": "Serial"},
+ "package": config.get("package", "ocean"),
+ }
+
+ use_rnn = config.get("use_rnn", False)
+ episode_length = env_kwargs.get("scenario_length", 91)
+ k_scenarios = env_kwargs.get("k_scenarios", 1)
+ if k_scenarios > 1:
+ episode_length = k_scenarios * episode_length
+
+ mode = "human_replay" if human_replay else ("coplayer" if env_kwargs.get("co_player_enabled") else "baseline")
+ videos_to_log_world = []
+ videos_to_log_agent = []
+
+ for view_mode in view_modes:
+ render_env = load_env(env_name, render_args)
try:
- n_policy = int(n_policy)
- except (TypeError, ValueError):
- n_policy = -1
- if n_policy > 0:
- cmd += ["--num-policy-controlled-agents", str(n_policy)]
- if getattr(env_cfg, "num_maps", False):
- cmd.extend(["--num-maps", str(env_cfg.num_maps)])
- if getattr(env_cfg, "scenario_length", None):
- cmd.extend(["--scenario-length", str(env_cfg.scenario_length)])
-
- # Call C code that runs eval_gif() in subprocess
- result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=120, env=env)
-
- vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists(
- "resources/drive/output_agent.mp4"
- )
-
- if result.returncode == 0 or (result.returncode == 1 and vids_exist):
- # Move both generated videos to the model directory
- videos = [
- ("resources/drive/output_topdown.mp4", f"step_{global_step:09d}_topdown.mp4"),
- ("resources/drive/output_agent.mp4", f"step_{global_step:09d}_agent.mp4"),
- ]
-
- for source_vid, target_filename in videos:
- if os.path.exists(source_vid):
- target_gif = os.path.join(video_output_dir, target_filename)
- shutil.move(source_vid, target_gif)
-
- # Log to wandb if available
- if hasattr(logger, "wandb") and logger.wandb:
- import wandb
-
- view_type = "world_state" if "topdown" in target_filename else "agent_view"
- logger.wandb.log(
- {f"render/{view_type}": wandb.Video(target_gif, format="mp4")},
- step=global_step,
- )
+ driver = render_env.driver_env
+ map_ids = getattr(driver, "map_ids", None)
+ map_id = int(map_ids[0]) if map_ids is not None and len(map_ids) > 0 else 0
+ view = _VIEW_NAMES.get(int(view_mode), "view")
+ basename = f"epoch_{epoch:06d}_{mode}_k{k_scenarios}_map{map_id:03d}_{view}"
+
+ # Tell the env to keep raylib + ffmpeg alive across map swaps so
+ # the in-step _reinit_envs_with_new_maps() at scenario boundaries
+ # doesn't kill the render. Single mp4 captures all k scenarios
+ # with the maps rotating mid-stream.
+ if getattr(driver, "map_rand_per_scenario", False):
+ driver._render_keep_client_on_swap = True
+
+ policy.eval()
+ rollout_loop(
+ policy=policy,
+ env=render_env,
+ device=device,
+ use_rnn=use_rnn,
+ max_steps=episode_length,
+ render_ctx=RenderContext(
+ view_mode=view_mode,
+ env_id=0,
+ draw_traces=True,
+ video_basename=basename,
+ ),
+ )
+ finally:
+ render_env.close()
+
+ src = f"{basename}.mp4"
+ if not os.path.exists(src):
+ print(f"render: expected {src} not produced")
+ continue
+
+ target_path = os.path.join(video_output_dir, src)
+ shutil.move(src, target_path)
+
+ if hasattr(logger, "wandb") and logger.wandb:
+ import wandb
+
+ if view == "sim_state":
+ videos_to_log_world.append(wandb.Video(target_path, format="mp4"))
else:
- print(f"Video generation completed but {source_vid} not found")
- else:
- print(f"C rendering failed with exit code {result.returncode}: {result.stdout}")
+ videos_to_log_agent.append(wandb.Video(target_path, format="mp4"))
+
+ if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent):
+ payload = {}
+ world_key = "eval/human_replay_world_view" if human_replay else "render/world_state"
+ agent_key = "eval/human_replay_agent_view" if human_replay else "render/agent_view"
+ if videos_to_log_world:
+ payload[world_key] = videos_to_log_world
+ if videos_to_log_agent:
+ payload[agent_key] = videos_to_log_agent
+ logger.wandb.log(payload, step=global_step)
- except subprocess.TimeoutExpired:
- print("C rendering timed out")
except Exception as e:
- print(f"Failed to generate GIF: {e}")
+ print(f"Failed to render videos: {e}")
+ import traceback
- finally:
- # Clean up bin weights file
- if os.path.exists(expected_weights_path):
- os.remove(expected_weights_path)
+ traceback.print_exc()
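Review note: the render-env overrides in the hunk above can be read as a pure function of the training env config. The sketch below restates them in isolation so the clamping and the human-replay branch are easy to check. `render_env_kwargs` and `RENDER_HEADLESS` are hypothetical names introduced here for illustration; the PR performs these steps inline, and the caps (64 agents, 32 ego agents, 500 maps) are copied from the hunk.

```python
import copy

RENDER_HEADLESS = 1  # assumption: mirrors the literal `render_mode = 1` in the hunk


def render_env_kwargs(train_env_cfg, human_replay=False, env_name="puffer_drive"):
    """Hypothetical helper: derive render-env kwargs from the training env config."""
    kw = copy.deepcopy(train_env_cfg)  # never mutate the training config
    kw["render_mode"] = RENDER_HEADLESS
    # Clamp the footprint so the render env fits beside training in memory.
    kw["num_agents"] = min(kw.get("num_agents", 64), 64)
    if kw.get("num_ego_agents") is not None:
        kw["num_ego_agents"] = min(kw["num_ego_agents"], 32)
    kw["num_maps"] = min(kw.get("num_maps", 500), 500)
    # A single Serial env has no centralized inference path to fill the SHM.
    kw["external_co_player_actions"] = False
    if human_replay:
        kw["co_player_enabled"] = False
        kw["max_controlled_agents"] = 1
        if "adaptive" in env_name:
            kw["human_replay_mode"] = True
    return kw


train_cfg = {
    "num_agents": 1024,
    "num_ego_agents": 512,
    "num_maps": 10000,
    "external_co_player_actions": True,
    "co_player_enabled": True,
}
kw = render_env_kwargs(train_cfg)
print(kw["num_agents"], kw["num_ego_agents"], kw["num_maps"], kw["external_co_player_actions"])
# prints: 64 32 500 False
```

Because the helper deep-copies its input, the training config keeps its original `num_agents` and `external_co_player_actions` values after the call.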
diff --git a/pufferlib/vector.py b/pufferlib/vector.py
index 24ac492405..deee21d9f9 100644
--- a/pufferlib/vector.py
+++ b/pufferlib/vector.py
@@ -104,6 +104,13 @@ def __init__(self, env_creators, env_args, env_kwargs, num_envs, buf=None, seed=
self.initialized = False
self.flag = RESET
+ # Handle population play mode (only ego agents are controlled by the policy)
+ self.population_play = getattr(self.driver_env, "population_play", False)
+ if self.population_play:
+ ego_agents_per_batch = self.driver_env.num_ego_agents * num_envs
+ self.num_ego_agents = ego_agents_per_batch
+ self.ego_action_space = pufferlib.spaces.joint_space(self.single_action_space, ego_agents_per_batch)
+
def _avg_infos(self):
infos = {}
for e in self.infos:
@@ -334,6 +341,9 @@ def __init__(
self.agents_per_batch = driver_env.num_agents * batch_size
agents_per_worker = driver_env.num_agents * envs_per_worker
+ # Persisted on the vecenv so PuffeRL can map per-recv `env_id` back
+ # to a worker index for centralized co-player inference.
+ self.agents_per_worker = agents_per_worker
obs_space = driver_env.single_observation_space
obs_shape = obs_space.shape
self.obs_shape = obs_shape
@@ -398,6 +408,49 @@ def __init__(
self.atn_batch_shape = (self.workers_per_batch, agents_per_worker, *atn_shape)
self.actions = np.ndarray((*shape, *atn_shape), dtype=atn_dtype, buffer=self.shm["actions"])
+ # ---- Centralized GPU co-player conditioning SHM ----
+ # When `external_co_player_actions` is set in env_kwargs AND the
+ # co-player has non-zero conditioning dims, we allocate a per-worker
+ # buffer so the env (which still samples conditioning at scenario
+ # boundaries) can deposit values for the main process to read before
+ # each forward pass. Sized to the worst-case `co_players_per_worker`.
+ env_k0 = env_kwargs[0] if env_kwargs else {}
+ external_coplayer_flag = env_k0.get("external_co_player_actions", False) and env_k0.get(
+ "co_player_enabled", False
+ )
+ co_player_conditioning_dim = 0
+ if external_coplayer_flag:
+ cond = env_k0.get("co_player_policy", {}).get("conditioning", {}) or {}
+ ctype = cond.get("type", "none")
+ co_player_conditioning_dim = (
+ (3 if ctype in ("reward", "all") else 0)
+ + (1 if ctype in ("entropy", "all") else 0)
+ + (1 if ctype in ("discount", "all") else 0)
+ )
+ if self.population_play and co_player_conditioning_dim > 0:
+ co_players_per_worker = agents_per_worker - ego_agents_per_worker
+ self.shm["co_player_conditioning"] = RawArray(
+ "f", num_workers * co_players_per_worker * co_player_conditioning_dim
+ )
+ self.co_player_conditioning = np.ndarray(
+ (num_workers, co_players_per_worker, co_player_conditioning_dim),
+ dtype=np.float32,
+ buffer=self.shm["co_player_conditioning"],
+ )
+ # CRITICAL: pufferlib.vector.make() builds env_kwargs as
+ # `[env_kwargs] * num_envs`, which is a list of N references
+ # to the SAME dict. Mutating env_kwargs[i] modifies all
+ # entries. We have to replace each slot with a per-env
+ # copy before adding worker_idx / SHM-slice entries.
+ for i in range(len(env_kwargs)):
+ w_idx = i // envs_per_worker
+ env_kwargs[i] = {
+ **env_kwargs[i],
+ "worker_idx": w_idx,
+ "co_player_conditioning_shm": self.co_player_conditioning[w_idx],
+ }
+ self._co_player_conditioning_dim = co_player_conditioning_dim
+
self.buf = dict(
observations=np.ndarray((*shape, *obs_shape), dtype=obs_dtype, buffer=self.shm["observations"]),
rewards=np.ndarray(shape, dtype=np.float32, buffer=self.shm["rewards"]),
@@ -831,6 +884,13 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer
# TODO: First step action space check
env_k = env_kwargs[0]
+ # When external_co_player_actions is set, the *main* process owns the
+ # co-player policy on GPU and writes actions into the shared-memory
+ # action buffer at co_player slots before vec_step. Workers don't load
+ # the model and don't need single-thread CPU mode. We still build the
+ # policy here (to ship it to main via vecenv.co_player_policy_func),
+ # but on GPU and not stuffed into env_kwargs for workers.
+ external_coplayer = env_k.get("external_co_player_actions", False) and env_k.get("co_player_enabled", False)
if env_k.get("co_player_enabled", False):
import torch
import os
@@ -838,19 +898,18 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer
import gymnasium
from pufferlib.ocean.torch import Drive
import pufferlib.models
+ from pufferlib.ocean.drive import binding
- dynamics_model = env_k.get("dynamics_model", "jerk")
- # Observation space calculation
- if dynamics_model == "classic":
- ego_features = 7
- elif dynamics_model == "jerk":
- ego_features = 10
+ dynamics_model = env_k.get("dynamics_model", "classic")
+ action_type = env_k.get("action_type", "discrete")
co_player_policy = env_k["co_player_policy"]
input_size = co_player_policy.get("input_size", 256)
hidden_size = co_player_policy.get("hidden_size", 256)
co_player_rnn = co_player_policy.get("rnn", None)
+ co_player_architecture = co_player_policy.get("architecture", "Recurrent")
+ co_player_transformer = co_player_policy.get("transformer", {})
# Get conditioning type from env_k
co_player_conditioning = co_player_policy.get("conditioning")
@@ -858,27 +917,79 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer
reward_conditioned = condition_type in ("reward", "all")
entropy_conditioned = condition_type in ("entropy", "all")
discount_conditioned = condition_type in ("discount", "all")
- # Calculate conditioning dimensions
+
+ if action_type == "discrete":
+ if dynamics_model == "classic":
+ # Joint action space (assume dependence)
+ single_action_space = gymnasium.spaces.MultiDiscrete([7 * 13])
+ # Multi discrete (assume independence)
+ # self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13])
+ elif dynamics_model == "jerk":
+ # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12
+ single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3])
+ else:
+ raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}")
+ elif action_type == "continuous":
+ single_action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+ else:
+ raise ValueError(f"action_type must be 'discrete' or 'continuous'. Got: {action_type}")
+
+ # Observation space calculation
+ ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get(dynamics_model)
+
conditioning_dims = (
(3 if reward_conditioned else 0) + (1 if entropy_conditioned else 0) + (1 if discount_conditioned else 0)
)
- # Base observations + conditioning observations
- num_obs = ego_features + conditioning_dims + 63 * 7 + 200 * 7
- temp_env = SimpleNamespace(
- single_action_space=gymnasium.spaces.MultiDiscrete([7 * 13]),
- single_observation_space=gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32),
+ ego_features += conditioning_dims
+
+ # # Extract observation shapes from constants
+ # # These need to be defined in C, since they determine the shape of the arrays
+ # max_road_objects = 200
+ # max_partner_objects = 63
+ # partner_features = 7
+ # road_features = 7
+
+ # Extract observation shapes from constants
+ # These need to be defined in C, since they determine the shape of the arrays
+ max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS
+ max_partner_objects = binding.MAX_AGENTS - 1
+ partner_features = binding.PARTNER_FEATURES
+ road_features = binding.ROAD_FEATURES
+
+ num_obs = ego_features + max_partner_objects * partner_features + max_road_objects * road_features
+
+ single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32)
+
+ co_player_env = SimpleNamespace(
+ single_action_space=single_action_space,
+ single_observation_space=single_observation_space,
reward_conditioned=reward_conditioned,
entropy_conditioned=entropy_conditioned,
discount_conditioned=discount_conditioned,
dynamics_model=dynamics_model, ## keep these the same I think, multiple dynamics models could get weird
+ max_partner_objects=max_partner_objects,
+ partner_features=partner_features,
+ max_road_objects=max_road_objects,
+ road_features=road_features,
)
- base_policy = Drive(temp_env, input_size=input_size, hidden_size=hidden_size)
+ base_policy = Drive(co_player_env, input_size=input_size, hidden_size=hidden_size)
- if co_player_rnn:
+ if co_player_architecture == "Transformer":
+ policy = pufferlib.models.TransformerWrapper(
+ co_player_env,
+ base_policy,
+ input_size=co_player_transformer.get("input_size", 256),
+ hidden_size=co_player_transformer.get("hidden_size", 256),
+ num_layers=co_player_transformer.get("num_layers", 2),
+ num_heads=co_player_transformer.get("num_heads", 4),
+ horizon=co_player_transformer.get("horizon", 91),
+ dropout=co_player_transformer.get("dropout", 0.0),
+ )
+ elif co_player_rnn:
policy = pufferlib.models.LSTMWrapper(
- temp_env,
+ co_player_env,
base_policy,
input_size=co_player_rnn.get("input_size"),
hidden_size=co_player_rnn.get("hidden_size"),
@@ -891,37 +1002,68 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
- state_dict = torch.load(checkpoint_path, map_location="cpu")
+ state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
policy.load_state_dict(state_dict, strict=True)
- policy.eval()
+ if external_coplayer:
+ # Main owns the co-player on GPU. Don't pin to CPU; don't pass to
+ # workers. We hand the (still-on-CPU) policy to the caller via the
+ # vecenv attribute below — caller will move it to its own device.
+ policy.eval()
+ else:
+ policy = policy.to("cpu") # Ensure all buffers are on CPU for forked subprocesses
+ policy.eval()
print(
- f"Co player policy loaded with {conditioning_dims} conditioning dims (condition_type={condition_type})",
+ f"Co player policy loaded with {conditioning_dims} conditioning dims "
+ f"(condition_type={condition_type}, external={external_coplayer})",
flush=True,
)
- # Store policy and conditioning info in env_k
- env_k["co_player_policy"]["co_player_policy_func"] = policy
- torch.set_num_threads(
- 1
- ) # NOTE this is the only way I could get co-player policies to work inside environment evaluation
- torch.set_num_interop_threads(1)
- import os
+ if not external_coplayer:
+ # Per-worker CPU path (legacy): hand the policy to env_kwargs so
+ # each forked worker sees it via env_k["co_player_policy"][...].
+ env_k["co_player_policy"]["co_player_policy_func"] = policy
- os.environ["OMP_NUM_THREADS"] = "1"
- os.environ["MKL_NUM_THREADS"] = "1"
- os.environ["NUMEXPR_NUM_THREADS"] = "1"
+ # NOTE: Setting threads to 1 is required for co-player policies to work
+ # inside environment evaluation. Higher values cause deadlock.
+ # set_num_interop_threads can only be called once per process; on
+ # subsequent vector-construction calls (e.g. multiple renders) it
+ # raises RuntimeError. Guard the calls so repeated construction is safe.
+ try:
+ torch.set_num_threads(1)
+ except RuntimeError:
+ pass
+ try:
+ torch.set_num_interop_threads(1)
+ except RuntimeError:
+ pass
- # Disable MKL if available
- try:
- torch.backends.mkl.enabled = False
- except:
- pass
+ os.environ["OMP_NUM_THREADS"] = "1"
+ os.environ["MKL_NUM_THREADS"] = "1"
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
- for i in range(len(env_kwargs)):
- env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = policy
+ # Disable MKL if available
+ try:
+ torch.backends.mkl.enabled = False
+ except Exception:
+ pass
- return backend(env_creators, env_args, env_kwargs, num_envs, **kwargs)
+ for i in range(len(env_kwargs)):
+ env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = policy
+ else:
+ # External path: workers should NOT carry the policy. Ensure their
+ # env_kwargs don't accidentally hold a stale reference.
+ for i in range(len(env_kwargs)):
+ env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = None
+
+ vecenv = backend(env_creators, env_args, env_kwargs, num_envs, **kwargs)
+ if env_k.get("co_player_enabled", False) and external_coplayer:
+ # Stash the GPU-bound co-player policy + the per-worker conditioning
+ # dimension on the vecenv so PuffeRL can pick them up in __init__.
+ vecenv.co_player_policy_func = policy
+ vecenv.co_player_conditioning_dims = conditioning_dims
+ vecenv.co_player_condition_type = condition_type
+ return vecenv
def make_seeds(seed, num_envs):
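Review note: the "CRITICAL" comment in the vector.py hunk relies on how Python list repetition works. The standalone sketch below reproduces the aliasing and the per-slot rebinding fix with plain dicts; the values of `num_envs` and `envs_per_worker` are illustrative only, not taken from the PR.

```python
num_envs, envs_per_worker = 4, 2

# `[d] * n` repeats a reference to one dict; it does not make n copies.
env_kwargs = [{"co_player_enabled": True}] * num_envs
env_kwargs[0]["worker_idx"] = 0
aliased = [kw["worker_idx"] for kw in env_kwargs]  # every slot sees the write

# Rebinding each slot to a fresh dict breaks the aliasing.
env_kwargs = [{"co_player_enabled": True}] * num_envs
for i in range(len(env_kwargs)):
    env_kwargs[i] = {**env_kwargs[i], "worker_idx": i // envs_per_worker}
per_env = [kw["worker_idx"] for kw in env_kwargs]

print(aliased)  # [0, 0, 0, 0]
print(per_env)  # [0, 0, 1, 1]
```

Note that `{**d, ...}` is a shallow copy: nested dicts such as `co_player_policy` are still shared between slots after the rebinding, so writes into them (as in the later loops that set `co_player_policy_func`) remain visible to every env.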
diff --git a/pyproject.toml b/pyproject.toml
index 4bcc818849..72a1b34949 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,8 +117,10 @@ metta = [
'hydra-core',
'duckdb',
'raylib>=5.5.0',
- 'metta-common @ git+https://github.com/metta-ai/metta.git@main#subdirectory=common',
- 'metta-mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=mettagrid',
+ # 'metta-common @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=common',
+ # 'metta-mettagrid @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=mettagrid',
+ 'mettagrid'
+
]
microrts = [
diff --git a/render.py b/render.py
new file mode 100644
index 0000000000..4742f0c1f2
--- /dev/null
+++ b/render.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python
+"""Unified Python rendering CLI for PufferDrive.
+
+Replaces the old C `./visualize` binary, `render_test.py`, and the rendering
+side of `evaluate_human_logs.py`. Works with both LSTM and Transformer policies
+on both nuPlan and WOMD datasets.
+
+Modes
+-----
+ Baseline (no co-players):
+ python render.py --model-path X.pt
+ With a frozen co-player population:
+ python render.py --model-path adaptive.pt --co-player-path coplayer.pt
+ Human replay (one ego, others follow logged trajectories):
+ python render.py --model-path X.pt --human-replay
+
+Architecture is auto-detected from the checkpoint state-dict, but can be
+overridden with --policy-architecture.
+"""
+
+import argparse
+import copy
+import glob
+import os
+import shutil
+import sys
+
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from pufferlib.pufferl import load_config, load_env, load_policy
+from pufferlib.ocean.drive.rollout import RenderContext, RenderView, rollout_loop
+
+
+VIEW_MODE_BY_NAME = {
+ "sim_state": RenderView.FULL_SIM_STATE,
+ "bev": RenderView.BEV_AGENT_OBS,
+ "persp": RenderView.AGENT_PERSPECTIVE,
+}
+
+VIEW_NAME = {
+ RenderView.FULL_SIM_STATE: "sim_state",
+ RenderView.BEV_AGENT_OBS: "bev",
+ RenderView.AGENT_PERSPECTIVE: "persp",
+}
+
+
+def model_id_from_path(path):
+ """Short, readable id derived from the checkpoint filename or its run dir."""
+ fname = os.path.splitext(os.path.basename(path))[0]
+ # If the file lives inside a run directory like puffer_drive_/,
+ # prefer that (handles intermediate model_*.pt files).
+ parent = os.path.basename(os.path.dirname(path) or "")
+ candidate = parent if parent.startswith("puffer_") else fname
+ return candidate.replace("puffer_adaptive_drive_", "").replace("puffer_drive_", "")
+
+
+def run_dir_for(model_path):
+ """Locate the experiment directory that owns this checkpoint."""
+ parent = os.path.dirname(model_path)
+ parent_name = os.path.basename(parent or "")
+ if parent_name.startswith("puffer_"):
+ return parent
+ # File is `/puffer_..._.pt`: matching run dir sits next to it.
+ fname = os.path.splitext(os.path.basename(model_path))[0]
+ if fname.startswith("puffer_"):
+ candidate = os.path.join(parent, fname)
+ if os.path.isdir(candidate):
+ return candidate
+ return None
+
+
+def default_output_dir(model_path):
+ """Default to /renders so artifacts live with the experiment."""
+ run_dir = run_dir_for(model_path)
+ if run_dir:
+ return os.path.join(run_dir, "renders")
+ return os.path.join("./render_output", model_id_from_path(model_path))
+
+
+def detect_architecture(model_path, device="cpu"):
+ state_dict = torch.load(model_path, map_location=device, weights_only=True)
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ if "positional_embedding" in state_dict:
+ return "Transformer"
+ if "lstm.weight_ih_l0" in state_dict:
+ return "Recurrent"
+ return None
+
+
+def build_config(args):
+ """Build the env/vec/policy config dict for one render."""
+ if args.adaptive or args.k_scenarios > 1 or args.co_player_path is not None:
+ env_name = "puffer_adaptive_drive"
+ else:
+ env_name = "puffer_drive"
+
+ saved_argv = sys.argv
+ sys.argv = [sys.argv[0]]
+ try:
+ config = load_config(env_name)
+ finally:
+ sys.argv = saved_argv
+
+ arch = args.policy_architecture or detect_architecture(args.model_path) or "Recurrent"
+ config["policy_architecture"] = arch
+ config["rnn_name"] = arch
+ config["use_rnn"] = True
+ config["load_model_path"] = args.model_path
+ config["vec"] = {"backend": "Serial", "num_envs": 1}
+
+ config["env"]["render_mode"] = 1 # RENDER_HEADLESS
+ config["env"]["map_dir"] = args.map_dir
+ config["env"]["num_maps"] = args.num_maps
+ config["env"]["num_agents"] = args.num_agents
+ config["env"]["num_ego_agents"] = args.num_ego_agents
+ config["env"]["k_scenarios"] = args.k_scenarios
+ config["env"]["scenario_length"] = args.scenario_length
+
+ if args.human_replay:
+ if env_name == "puffer_adaptive_drive":
+ config["env"]["human_replay_mode"] = True
+ config["env"]["co_player_enabled"] = False
+ config["env"]["max_controlled_agents"] = 1
+ elif args.co_player_path is not None:
+ config["env"]["co_player_enabled"] = True
+ cpp = config["env"].setdefault("co_player_policy", {})
+ cpp["policy_path"] = args.co_player_path
+ cpp["architecture"] = args.co_player_architecture or detect_architecture(args.co_player_path) or "Recurrent"
+ cpp_cond = cpp.setdefault("conditioning", {})
+ cpp_cond["type"] = args.co_player_conditioning_type
+ cpp_cond["collision_weight_lb"] = args.co_player_collision_weight_lb
+ cpp_cond["collision_weight_ub"] = args.co_player_collision_weight_ub
+ cpp_cond["offroad_weight_lb"] = args.co_player_offroad_weight_lb
+ cpp_cond["offroad_weight_ub"] = args.co_player_offroad_weight_ub
+ cpp_cond["goal_weight_lb"] = args.co_player_goal_weight_lb
+ cpp_cond["goal_weight_ub"] = args.co_player_goal_weight_ub
+ cpp_cond["entropy_weight_lb"] = args.co_player_entropy_weight_lb
+ cpp_cond["entropy_weight_ub"] = args.co_player_entropy_weight_ub
+ cpp_cond["discount_weight_lb"] = args.co_player_discount_weight_lb
+ cpp_cond["discount_weight_ub"] = args.co_player_discount_weight_ub
+ else:
+ config["env"]["co_player_enabled"] = False
+
+ cond = config["env"].setdefault("conditioning", {})
+ cond["type"] = args.conditioning_type
+ cond["collision_weight_lb"] = args.collision_weight_lb
+ cond["collision_weight_ub"] = args.collision_weight_ub
+ cond["offroad_weight_lb"] = args.offroad_weight_lb
+ cond["offroad_weight_ub"] = args.offroad_weight_ub
+ cond["goal_weight_lb"] = args.goal_weight_lb
+ cond["goal_weight_ub"] = args.goal_weight_ub
+ cond["entropy_weight_lb"] = args.entropy_weight_lb
+ cond["entropy_weight_ub"] = args.entropy_weight_ub
+ cond["discount_weight_lb"] = args.discount_weight_lb
+ cond["discount_weight_ub"] = args.discount_weight_ub
+
+ config["train"]["device"] = args.device
+ return env_name, config
+
+
+def mode_tag(args):
+ if args.human_replay:
+ return "human_replay"
+ if args.co_player_path is not None:
+ return "coplayer"
+ return "baseline"
+
+
+def render_one(env_name, base_config, view_modes, render_idx, seed, args):
+ """One render = (one map seed × all view modes), one rollout per view."""
+ print(f"\n[Render {render_idx + 1}/{args.num_renders}] map_seed={seed}")
+ cfg = copy.deepcopy(base_config)
+ cfg["env"]["map_seed"] = seed
+
+ vecenv = load_env(env_name, cfg)
+ try:
+ policy = load_policy(cfg, vecenv, env_name)
+ policy.eval()
+
+ # Pull the actual map id loaded — beats relying on the seed only
+ map_ids = getattr(vecenv.driver_env, "map_ids", None)
+ map_id = int(map_ids[0]) if map_ids is not None and len(map_ids) > 0 else seed
+
+ model_id = model_id_from_path(args.model_path)
+ if args.co_player_path is not None and not args.human_replay:
+ mode = f"vs_{model_id_from_path(args.co_player_path)}"
+ else:
+ mode = mode_tag(args)
+ coplayer_part = ""
+
+ max_steps = args.max_steps if args.max_steps is not None else (args.k_scenarios * args.scenario_length)
+ os.makedirs(args.output_dir, exist_ok=True)
+ saved = []
+
+ for view_mode in view_modes:
+ view = VIEW_NAME[view_mode]
+ basename = f"{model_id}_{mode}{coplayer_part}_k{args.k_scenarios}_map{map_id:03d}_{view}"
+ vecenv.reset(seed=seed)
+ rollout_loop(
+ policy=policy,
+ env=vecenv,
+ device=args.device,
+ use_rnn=True,
+ max_steps=max_steps,
+ render_ctx=RenderContext(
+ view_mode=view_mode,
+ env_id=0,
+ draw_traces=True,
+ video_basename=basename,
+ ),
+ )
+ src = f"{basename}.mp4"
+ if os.path.exists(src):
+ target = os.path.join(args.output_dir, src)
+ shutil.move(src, target)
+ print(f" saved {target}")
+ saved.append(target)
+ else:
+ print(f" WARNING: expected {src} not produced")
+ finally:
+ vecenv.close()
+
+ return saved
+
+
+def main():
+ p = argparse.ArgumentParser(description="Unified Python rendering for PufferDrive")
+ p.add_argument("--model-path", required=True, help="Trained ego policy checkpoint (.pt)")
+ p.add_argument(
+ "--co-player-path", default=None, help="Frozen co-player policy (.pt). Omit for baseline / human-replay."
+ )
+ p.add_argument("--human-replay", action="store_true", help="Render in human-replay mode (one ego, others = log)")
+ p.add_argument("--adaptive", action="store_true", help="Force puffer_adaptive_drive env even when k_scenarios=1")
+
+ p.add_argument(
+ "--policy-architecture",
+ choices=["Recurrent", "Transformer"],
+ default=None,
+ help="Override ego architecture (auto-detected from checkpoint if omitted)",
+ )
+ p.add_argument(
+ "--co-player-architecture",
+ choices=["Recurrent", "Transformer"],
+ default=None,
+ help="Override co-player architecture",
+ )
+
+ p.add_argument(
+ "--map-dir",
+ default="resources/drive/binaries/training",
+ help="Map binary directory (e.g. resources/drive/binaries/nuplan)",
+ )
+ p.add_argument("--num-maps", type=int, default=None, help="Map pool size (default: max(100, num_renders))")
+ p.add_argument("--num-renders", type=int, default=1, help="Number of independent renders (different map seeds)")
+ p.add_argument("--start-seed", type=int, default=1)
+ p.add_argument(
+ "--seed-stride",
+ type=int,
+ default=1009,
+ help="Stride between consecutive map seeds. Spaced apart so adjacent renders pick different maps.",
+ )
+ p.add_argument("--num-agents", type=int, default=64)
+ p.add_argument("--num-ego-agents", type=int, default=32)
+
+ p.add_argument("--k-scenarios", type=int, default=2, help="Number of scenarios per episode (adaptive)")
+ p.add_argument("--scenario-length", type=int, default=91)
+ p.add_argument(
+ "--max-steps", type=int, default=None, help="Steps per render (default: k_scenarios * scenario_length)"
+ )
+
+ p.add_argument("--view-mode", choices=["sim_state", "bev", "persp", "all"], default="sim_state")
+ p.add_argument("--output-dir", default=None, help="Where to write mp4s (default: /renders)")
+ p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+
+ p.add_argument("--conditioning-type", choices=["none", "reward", "entropy", "discount", "all"], default="none")
+ p.add_argument("--collision-weight-lb", type=float, default=-3.0)
+ p.add_argument("--collision-weight-ub", type=float, default=-3.0)
+ p.add_argument("--offroad-weight-lb", type=float, default=-1.0)
+ p.add_argument("--offroad-weight-ub", type=float, default=-1.0)
+ p.add_argument("--goal-weight-lb", type=float, default=1.0)
+ p.add_argument("--goal-weight-ub", type=float, default=1.0)
+ p.add_argument("--entropy-weight-lb", type=float, default=0.001)
+ p.add_argument("--entropy-weight-ub", type=float, default=0.001)
+ p.add_argument("--discount-weight-lb", type=float, default=0.98)
+ p.add_argument("--discount-weight-ub", type=float, default=0.98)
+
+ p.add_argument(
+ "--co-player-conditioning-type", choices=["none", "reward", "entropy", "discount", "all"], default="all"
+ )
+ p.add_argument("--co-player-collision-weight-lb", type=float, default=-1.0)
+ p.add_argument("--co-player-collision-weight-ub", type=float, default=0.0)
+ p.add_argument("--co-player-offroad-weight-lb", type=float, default=-0.4)
+ p.add_argument("--co-player-offroad-weight-ub", type=float, default=0.0)
+ p.add_argument("--co-player-goal-weight-lb", type=float, default=0.0)
+ p.add_argument("--co-player-goal-weight-ub", type=float, default=1.0)
+ p.add_argument("--co-player-entropy-weight-lb", type=float, default=0.0)
+ p.add_argument("--co-player-entropy-weight-ub", type=float, default=0.1)
+ p.add_argument("--co-player-discount-weight-lb", type=float, default=0.8)
+ p.add_argument("--co-player-discount-weight-ub", type=float, default=1.0)
+
+ args = p.parse_args()
+
+ if args.num_maps is None:
+ args.num_maps = max(100, args.num_renders)
+ if args.output_dir is None:
+ args.output_dir = default_output_dir(args.model_path)
+
+ if args.view_mode == "all":
+ view_modes = list(VIEW_MODE_BY_NAME.values())
+ else:
+ view_modes = [VIEW_MODE_BY_NAME[args.view_mode]]
+
+ env_name, config = build_config(args)
+
+ print(f"Env: {env_name}")
+ print(f"Architecture: {config['policy_architecture']}")
+ print(f"Map dir: {args.map_dir}, num_maps: {args.num_maps}, k_scenarios: {args.k_scenarios}")
+ print(f"Mode: {'human_replay' if args.human_replay else ('coplayer' if args.co_player_path else 'baseline')}")
+ print(f"Views: {[vm.name for vm in view_modes]}")
+
+ saved = []
+ for i in range(args.num_renders):
+ seed = args.start_seed + i * args.seed_stride
+ saved.extend(render_one(env_name, config, view_modes, i, seed, args))
+
+ print(f"\nDone. {len(saved)} videos in {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
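Review note: the architecture auto-detection in `detect_architecture` only inspects state-dict key names, so its decision rule can be exercised without a checkpoint on disk. The sketch below assumes an already loaded mapping; `architecture_from_keys` is a hypothetical stand-in written for this note, and the two marker keys are the ones the PR checks.

```python
def architecture_from_keys(state_dict):
    """Hypothetical stand-in for detect_architecture() on an in-memory mapping."""
    # Mirror the PR: drop every "module." substring (e.g. a DataParallel prefix).
    keys = {k.replace("module.", "") for k in state_dict}
    if "positional_embedding" in keys:
        return "Transformer"
    if "lstm.weight_ih_l0" in keys:
        return "Recurrent"
    return None  # unknown layout; the caller chooses the fallback


print(architecture_from_keys({"module.lstm.weight_ih_l0": 0}))  # Recurrent
print(architecture_from_keys({"positional_embedding": 0}))      # Transformer
print(architecture_from_keys({"encoder.weight": 0}) or "Recurrent")  # Recurrent
```

The last line mirrors how `build_config` chains the result: an explicit `--policy-architecture` wins, then detection, then the "Recurrent" default.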
diff --git a/scripts/ablations/all_coplayers.sh b/scripts/ablations/all_coplayers.sh
new file mode 100755
index 0000000000..49b5caf1d2
--- /dev/null
+++ b/scripts/ablations/all_coplayers.sh
@@ -0,0 +1,105 @@
+#!/bin/bash
+#SBATCH --job-name=coplayer_ablation
+#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out
+#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err
+#SBATCH --mem=128GB
+#SBATCH --time=24:00:00
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --account=torch_pr_355_tandon_advanced
+#SBATCH --cpus-per-task=48
+#SBATCH --gres=gpu:1
+#SBATCH --array=0-7
+
+# Train every co-player ablation variant in one slurm array.
+#
+# sbatch scripts/ablations/all_coplayers.sh
+#
+# 8 array tasks = 2 datasets × 2 archs × 2 conditioning types:
+# idx | dataset | architecture | conditioning
+# ----+---------+--------------+-------------
+# 0 | womd | Recurrent | none
+# 1 | womd | Recurrent | all
+# 2 | womd | Transformer | none
+# 3 | womd | Transformer | all
+# 4 | nuplan | Recurrent | none
+# 5 | nuplan | Recurrent | all
+# 6 | nuplan | Transformer | none
+# 7 | nuplan | Transformer | all
+#
+# Resulting checkpoints land at experiments/puffer_drive_.pt.
+# Once they've trained, plug them into scripts/adaptive/*.sh as ZIPPED_RUNS
+# entries.
+
+# Decode array index → (dataset, arch, cond_type)
+RUNS=(
+ "womd Recurrent none"
+ "womd Recurrent all"
+ "womd Transformer none"
+ "womd Transformer all"
+ "nuplan Recurrent none"
+ "nuplan Recurrent all"
+ "nuplan Transformer none"
+ "nuplan Transformer all"
+)
+
+read -r DATASET ARCH COND <<< "${RUNS[$SLURM_ARRAY_TASK_ID]}"
+
+# Dataset → map_dir + num_maps
+case "$DATASET" in
+ nuplan)
+ MAP_DIR="resources/drive/binaries/nuplan"
+ NUM_MAPS=5000
+ ;;
+ womd)
+ MAP_DIR="resources/drive/binaries/training"
+ NUM_MAPS=10000
+ ;;
+ *) echo "unknown dataset $DATASET" >&2; exit 1 ;;
+esac
+
+# Conditioning → puffer flags
+# `all` sweeps entropy 0→0.1 and discount 0.8→1.0 (the same range used in
+# scripts/coplayers/* with the cell at idx 0).
+case "$COND" in
+ none)
+ COND_ARGS="--env.conditioning.type none"
+ ;;
+ all)
+ COND_ARGS="--env.conditioning.type all \
+ --env.conditioning.entropy-weight-lb 0 \
+ --env.conditioning.entropy-weight-ub 0.1 \
+ --env.conditioning.discount-weight-lb 0.8 \
+ --env.conditioning.discount-weight-ub 1.0"
+ ;;
+ *) echo "unknown conditioning $COND" >&2; exit 1 ;;
+esac
+
+TAG="coplayer_${DATASET}_${ARCH,,}_cond-${COND}"
+
+singularity exec --nv \
+ --overlay "$OVERLAY_FILE:ro" \
+ "$SINGULARITY_IMAGE" \
+ bash -c "
+ set -e
+
+ source ~/.bashrc
+ cd /scratch/mmk9418/projects/Adaptive_Driving_Agent
+ source .venv/bin/activate
+
+ nice -n 19 python scripts/gpu_heartbeat.py &
+ HEARTBEAT_PID=\$!
+
+ puffer train puffer_drive --wandb \
+ --wandb-project ada_coplayer_ablation \
+ --tag $TAG \
+ --policy-architecture $ARCH \
+ --rnn-name $ARCH \
+ --env.map-dir $MAP_DIR \
+ --env.num-maps $NUM_MAPS \
+ --eval.map-dir $MAP_DIR \
+ --train.checkpoint-interval 50 \
+ $COND_ARGS
+
+ kill \$HEARTBEAT_PID
+ "
diff --git a/scripts/ablations/human_align_ablation.sh b/scripts/ablations/human_align_ablation.sh
new file mode 100755
index 0000000000..e2aa962f58
--- /dev/null
+++ b/scripts/ablations/human_align_ablation.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+#SBATCH --job-name=human_align_ablation
+#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out
+#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err
+#SBATCH --mem=128GB
+#SBATCH --time=24:00:00
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --account=torch_pr_355_tandon_advanced
+#SBATCH --cpus-per-task=48
+#SBATCH --gres=gpu:1
+#SBATCH --array=0-5
+
+# Human behavior alignment ablation: 6 experiments over a
+# 2 × 3 grid of collision_weight_lb × offroad_weight_lb
+#
+# | Exp | collision_weight_lb | offroad_weight_lb |
+# |-----|---------------------|-------------------|
+# | 0 | -3 | -2.0 |
+# | 1 | -3 | -1.0 |
+# | 2 | -3 | -0.5 |
+# | 3 | -2 | -2.0 |
+# | 4 | -2 | -1.0 |
+# | 5 | -2 | -0.5 |
+
+# Define the grid
+COLLISION_WEIGHTS=(-3 -3 -3 -2 -2 -2)
+OFFROAD_WEIGHTS=(-2.0 -1.0 -0.5 -2.0 -1.0 -0.5)
+
+# Get values for this array task
+COLLISION_LB=${COLLISION_WEIGHTS[$SLURM_ARRAY_TASK_ID]}
+OFFROAD_LB=${OFFROAD_WEIGHTS[$SLURM_ARRAY_TASK_ID]}
+
+# Fixed parameters
+DISCOUNT_UB=0.995
+SEED=42
+NUPLAN_NUM_MAPS=4999
+
+echo "Running experiment $SLURM_ARRAY_TASK_ID: collision_weight_lb=$COLLISION_LB, offroad_weight_lb=$OFFROAD_LB"
+
+singularity exec --nv \
+ --overlay "$OVERLAY_FILE:ro" \
+ "$SINGULARITY_IMAGE" \
+ bash -c "
+ set -e
+
+ source ~/.bashrc
+ cd /scratch/mmk9418/projects/Adaptive_Driving_Agent
+ source .venv/bin/activate
+
+ nice -n 19 python scripts/gpu_heartbeat.py &
+ HEARTBEAT_PID=\$!
+
+ puffer train puffer_drive \
+ --wandb --wandb-project human-align-ablation \
+ --tag ablation_collision${COLLISION_LB}_offroad${OFFROAD_LB} \
+ --env.map-dir resources/drive/binaries/nuplan \
+ --env.num-maps $NUPLAN_NUM_MAPS \
+ --env.conditioning.type all \
+ --env.conditioning.collision-weight-lb $COLLISION_LB \
+ --env.conditioning.collision-weight-ub 0 \
+ --env.conditioning.offroad-weight-lb $OFFROAD_LB \
+ --env.conditioning.offroad-weight-ub 0 \
+ --env.conditioning.discount-weight-lb 0.8 \
+ --env.conditioning.discount-weight-ub $DISCOUNT_UB \
+ --policy-architecture Transformer \
+ --rnn-name Transformer \
+ --train.context-length 91 \
+ --train.horizon 91 \
+ --train.seed $SEED \
+ --eval.map-dir resources/drive/binaries/nuplan
+
+ kill \$HEARTBEAT_PID
+ "
diff --git a/scripts/ablations/human_align_local.sh b/scripts/ablations/human_align_local.sh
new file mode 100755
index 0000000000..72af9454d8
--- /dev/null
+++ b/scripts/ablations/human_align_local.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+set -e
+
+# Local launcher for the human-align ablation: 4 runs, pinned to GPUs 4-7.
+# One tmux window per experiment, each pinned to a single GPU.
+
+COLLISION_WEIGHTS=(-3 -3 -2 -2)
+OFFROAD_WEIGHTS=(-2.0 -0.5 -2.0 -0.5)
+
+DISCOUNT_UB=0.995
+SEED=42
+NUPLAN_NUM_MAPS=4999
+
+SESSION=human_align_redo
+tmux new-session -d -s "$SESSION" -n exp0
+
+for i in "${!COLLISION_WEIGHTS[@]}"; do
+ C=${COLLISION_WEIGHTS[$i]}
+ O=${OFFROAD_WEIGHTS[$i]}
+ # Run i is pinned to GPU i+4 (see CUDA_VISIBLE_DEVICES in CMD below).
+ WIN="exp${i}"
+
+ if [ "$i" -ne 0 ]; then
+ tmux new-window -t "$SESSION" -n "$WIN"
+ fi
+
+ CMD="cd /workspace/ADA && \
+source .venv/bin/activate && \
+export CUDA_VISIBLE_DEVICES=$((i+4)) && \
+echo 'Running exp${i}: collision_weight_lb=$C, offroad_weight_lb=$O on GPU $((i+4))' && \
+xvfb-run -a puffer train puffer_drive \
+ --wandb --wandb-project human-align-ablation \
+ --tag human_ablation_lane_rewards_apr26_redo \
+ --env.map-dir resources/drive/binaries/nuplan \
+ --env.num-maps $NUPLAN_NUM_MAPS \
+ --env.conditioning.type all \
+ --env.conditioning.collision-weight-lb $C \
+ --env.conditioning.collision-weight-ub 0 \
+ --env.conditioning.offroad-weight-lb $O \
+ --env.conditioning.offroad-weight-ub 0 \
+ --env.conditioning.discount-weight-lb 0.8 \
+ --env.conditioning.discount-weight-ub $DISCOUNT_UB \
+ --env.reward-lane-align 0.01 \
+ --env.reward-vel-align 1.0 \
+ --policy-architecture Transformer \
+ --train.context-length 91 \
+ --train.horizon 91 \
+ --train.seed $SEED \
+ --eval.map-dir resources/drive/binaries/nuplan"
+
+ tmux send-keys -t "$SESSION:$WIN" "$CMD" C-m
+done
+
+echo "Launched ${#COLLISION_WEIGHTS[@]} runs in tmux session '$SESSION'."
+echo "Attach: tmux attach -t $SESSION"
+echo "Switch windows: Ctrl-b 0..3 (or Ctrl-b n / Ctrl-b p)"
+echo "Detach: Ctrl-b d"
diff --git a/scripts/adaptive/nuplan_recurrent.sh b/scripts/adaptive/nuplan_recurrent.sh
new file mode 100755
index 0000000000..93136910f6
--- /dev/null
+++ b/scripts/adaptive/nuplan_recurrent.sh
@@ -0,0 +1,81 @@
+#!/bin/bash
+#SBATCH --job-name=adaptive_nuplan_rnn
+#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out
+#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err
+#SBATCH --mem=128GB
+#SBATCH --time=24:00:00
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --account=torch_pr_355_tandon_advanced
+#SBATCH --cpus-per-task=48
+#SBATCH --gres=gpu:1
+#SBATCH --array=0-1,4-15
+
+# Train adaptive agents on NuPlan with Recurrent architecture
+# Uses pre-trained NuPlan Recurrent co-players with varied conditioning
+#
+# PREREQUISITE: Train co-players first with scripts/coplayers/nuplan_recurrent.sh
+# Then update ZIPPED_RUNS with the trained policy paths from wandb
+
+# Co-player policies trained with scripts/coplayers/nuplan_recurrent.sh
+# Each entry: "policy_path entropy_weight_ub discount_weight_lb"
+ZIPPED_RUNS=(
+ "experiments/puffer_drive_mwiatx5g.pt 0.5 0.8"
+ "experiments/puffer_drive_old90mw2.pt 0.1 0.8"
+ "TODO_FAILED 0.01 0.8"
+ "TODO_FAILED 0 0.8"
+
+ "experiments/puffer_drive_hous32qj.pt 0.5 0.6"
+ "experiments/puffer_drive_dhxoorxc.pt 0.1 0.6"
+ "experiments/puffer_drive_0h5radmw.pt 0.01 0.6"
+ "experiments/puffer_drive_elet1zj8.pt 0 0.6"
+
+ "experiments/puffer_drive_a57umusk.pt 0.5 0.4"
+ "experiments/puffer_drive_zozl26ek.pt 0.1 0.4"
+ "experiments/puffer_drive_jco8adma.pt 0.01 0.4"
+ "experiments/puffer_drive_5hhwfmmt.pt 0 0.4"
+
+ "experiments/puffer_drive_a56dgt1x.pt 0.5 0.2"
+ "experiments/puffer_drive_dqzqt7qx.pt 0.1 0.2"
+ "experiments/puffer_drive_qetoozcn.pt 0.01 0.2"
+ "experiments/puffer_drive_q30rjcbu.pt 0 0.2"
+)
+
+read -r COPLAYER_PATH ENTROPY_UB DISCOUNT_LB <<< "${ZIPPED_RUNS[$SLURM_ARRAY_TASK_ID]}"
+
+# Fixed values
+CONDITION_TYPE="all"
+DISCOUNT_UB=1
+ENTROPY_LB=0
+NUPLAN_NUM_MAPS=5000
+
+singularity exec --nv \
+ --overlay "$OVERLAY_FILE:ro" \
+ "$SINGULARITY_IMAGE" \
+ bash -c "
+ set -e
+
+ source ~/.bashrc
+ cd /scratch/mmk9418/projects/Adaptive_Driving_Agent
+ source .venv/bin/activate
+
+ nice -n 19 python scripts/gpu_heartbeat.py &
+ HEARTBEAT_PID=\$!
+
+ puffer train puffer_adaptive_drive --wandb --tag adaptive_nuplan_recurrent_k2 \
+ --env.map-dir resources/drive/binaries/nuplan \
+ --env.num-maps $NUPLAN_NUM_MAPS \
+ --env.conditioning.type none \
+ --env.co-player-enabled 1 \
+ --env.co-player-policy.policy-path $COPLAYER_PATH \
+ --env.co-player-policy.conditioning.type $CONDITION_TYPE \
+ --env.co-player-policy.conditioning.discount-weight-lb $DISCOUNT_LB \
+ --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \
+ --env.co-player-policy.conditioning.entropy-weight-lb $ENTROPY_LB \
+ --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB \
+ --policy-architecture Recurrent \
+ --rnn-name Recurrent \
+ --eval.map-dir resources/drive/binaries/nuplan
+
+ kill \$HEARTBEAT_PID
+ "
diff --git a/scripts/adaptive/nuplan_transformer.sh b/scripts/adaptive/nuplan_transformer.sh
new file mode 100755
index 0000000000..e9c105bc03
--- /dev/null
+++ b/scripts/adaptive/nuplan_transformer.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+#SBATCH --job-name=adaptive_nuplan_tfm
+#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out
+#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err
+#SBATCH --mem=128GB
+#SBATCH --time=48:00:00
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --account=torch_pr_355_tandon_advanced
+#SBATCH --cpus-per-task=48
+#SBATCH --gres=gpu:1
+#SBATCH --array=0-8
+
+# Train adaptive agents on NuPlan with Transformer architecture
+# Uses pre-trained NuPlan Transformer co-players with varied conditioning
+#
+# PREREQUISITE: Train co-players first with scripts/coplayers/nuplan_transformer.sh
+# Then update ZIPPED_RUNS with the trained policy paths from wandb
+
+# Co-player policies trained with scripts/coplayers/nuplan_transformer.sh
+# Each entry: "policy_path entropy_weight_ub discount_weight_lb"
+ZIPPED_RUNS=(
+"experiments/puffer_drive_joqbmi4s.pt 0.01 0.2"
+"experiments/puffer_drive_0h81rtfi.pt 0 0.4"
+# "experiments/puffer_drive_medzmgum.pt 0.1 0.2"
+# "experiments/puffer_drive_f4a8yoi9.pt 0.5 0.2"
+"experiments/puffer_drive_b1yx43w2.pt 0.01 0.4"
+# "experiments/puffer_drive_lv4x8hlt.pt 0.1 0.4"
+# "experiments/puffer_drive_j51yz49e.pt 0.5 0.4"
+"experiments/puffer_drive_iry1wanp.pt 0 0.6"
+"experiments/puffer_drive_kx8bhu3v.pt 0.01 0.6"
+"experiments/puffer_drive_js4mf85k.pt 0.1 0.6"
+# "experiments/puffer_drive_52u5onve.pt 0.5 0.6"
+"experiments/puffer_drive_nu7lkmx4.pt 0 0.8"
+"experiments/puffer_drive_f8dpkpbq.pt 0.01 0.8"
+"experiments/puffer_drive_zxsxu6z7.pt 0.1 0.8"
+# "experiments/puffer_drive_mfahi5bc.pt 0.5 0.8"
+)
+
+read -r COPLAYER_PATH ENTROPY_UB DISCOUNT_LB <<< "${ZIPPED_RUNS[$SLURM_ARRAY_TASK_ID]}"
+
+# Fixed values
+CONDITION_TYPE="all"
+DISCOUNT_UB=1
+ENTROPY_LB=0
+NUPLAN_NUM_MAPS=5000
+
+singularity exec --nv \
+ --overlay "$OVERLAY_FILE:ro" \
+ "$SINGULARITY_IMAGE" \
+ bash -c "
+ set -e
+
+ source ~/.bashrc
+ cd /scratch/mmk9418/projects/Adaptive_Driving_Agent
+ source .venv/bin/activate
+
+ nice -n 19 python scripts/gpu_heartbeat.py &
+ HEARTBEAT_PID=\$!
+
+ puffer train puffer_adaptive_drive --wandb --tag adaptive_nuplan_transformer_4apr \
+ --env.map-dir resources/drive/binaries/nuplan \
+ --env.num-maps $NUPLAN_NUM_MAPS \
+ --env.conditioning.type none \
+ --env.co-player-enabled 1 \
+ --env.co-player-policy.policy-path $COPLAYER_PATH \
+ --env.co-player-policy.architecture Transformer \
+ --env.co-player-policy.conditioning.type $CONDITION_TYPE \
+ --env.co-player-policy.conditioning.discount-weight-lb $DISCOUNT_LB \
+ --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \
+ --env.co-player-policy.conditioning.entropy-weight-lb $ENTROPY_LB \
+ --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB \
+ --policy-architecture Transformer \
+ --rnn-name Transformer \
+ --eval.map-dir resources/drive/binaries/nuplan
+
+ kill \$HEARTBEAT_PID
+ "
diff --git a/scripts/adaptive/nuplan_transformer_local.sh b/scripts/adaptive/nuplan_transformer_local.sh
new file mode 100644
index 0000000000..892af6113b
--- /dev/null
+++ b/scripts/adaptive/nuplan_transformer_local.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+set -e
+
+# Local 4-GPU adaptive launcher (tmux pattern, mirrors human_align_local.sh).
+# Trains 4 adaptive ego agents, one against each of the new local co-player
+# checkpoints (collision_lb=-2, offroad_lb=-2, lane_reward=0.01).
+#
+# Each ego run uses a co-player conditioning band that matches what THAT
+# co-player saw during its own training (so the policy is queried inside the
+# distribution it learned).
+#
+# DEFAULT GPU ASSIGNMENT: 0,1,2,3. Override with: GPUS="4 5 6 7" bash