diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 0000000000..595f769155 --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"7bcf951d-92e0-44b9-83b1-dc29b0d69f18","pid":8923,"procStart":"408700719","acquiredAt":1777664169801} diff --git a/.github/workflows/utest.yml b/.github/workflows/utest.yml index cfcbdda4b7..db5e5f0e4a 100644 --- a/.github/workflows/utest.yml +++ b/.github/workflows/utest.yml @@ -4,7 +4,6 @@ on: pull_request: branches: [ main ] push: - branches: [ main ] jobs: test: @@ -21,13 +20,9 @@ jobs: - name: Free up disk space run: | - echo "Before cleanup:" - df -h sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc sudo apt-get clean sudo rm -rf ~/.cache/pip /tmp/* /var/tmp/* - echo "After cleanup:" - df -h - name: Install dependencies run: | @@ -37,8 +32,7 @@ jobs: run: python -m pip install -U pip - name: Install pufferlib - run: | - pip install -e .[cpu] --no-cache-dir + run: pip install -e .[cpu] --no-cache-dir pytest env: TMPDIR: ${{ runner.temp }}/build PIP_NO_CACHE_DIR: 1 @@ -53,3 +47,11 @@ jobs: - name: Run python ini parsing tests run: python tests/test_drive_config.py + + - name: Run pytest suite (conditioning, map_dir, render contracts) + run: | + pytest \ + tests/test_drive_conditioning.py \ + tests/test_map_dir_flow.py \ + tests/test_render_pipeline.py \ + -v --tb=short diff --git a/.gitignore b/.gitignore index a83a714b7f..71497ea416 100644 --- a/.gitignore +++ b/.gitignore @@ -161,13 +161,16 @@ pufferlib/ocean/impulse_wars/debug-*/ pufferlib/ocean/impulse_wars/release-*/ pufferlib/ocean/impulse_wars/benchmark/ - # Ignore data files data/ -pufferlib/resources/drive/binaries/ +pufferlib/resources/drive/binaries/* +pufferlib/resources/drive/binaries/training/ +pufferlib/resources/drive/binaries/validation/ # But keep map_000.bin for the training test !pufferlib/resources/drive/binaries/map_000.bin +!pufferlib/resources/drive/binaries/training/map_000.bin +pufferlib/resources/drive/sanity/sanity_binaries/ # Compiled drive binary in root /drive @@ -183,7 +186,12 @@ pufferlib/resources/drive/output_agent.gif pufferlib/resources/drive/output.gif # Local artifacts and outputs artifacts/ # Local drive renders +pufferlib/resources/drive/output*.gif +emsdk/ +docs/book/* +!docs/book/assets/ pufferlib/resources/drive/output*.mp4 # Local TODO tracking TODO.md +*.mp4 diff --git a/README.md b/README.md index 4ac61c39eb..06ebf54b84 100644 --- a/README.md +++ b/README.md @@ -1,216 +1,350 @@ -# PufferDrive +# Adaptive Driving Agent - +A fork of [PufferDrive](https://github.com/Emerge-Lab/PufferDrive) for training +**adaptive** driving policies — agents that infer their partners' behavior +on the fly instead of being told what to expect. -**PufferDrive is a fast and friendly driving simulator to train and test RL-based models.** +## What we're doing -
-
-
-
-
-
-
-
-
-
+The training pipeline is two-stage: ---- +1. **Train a population of co-players.** A single conditioned policy + (`puffer_drive`) is trained over a range of conditioning values + (entropy weight, discount factor, reward weights). Sampling different + conditioning vectors during inference gives us a *population* of behaviors + from one set of weights. +2. **Train an adaptive ego against that population** (`puffer_adaptive_drive`). + The adaptive policy is *not* conditioned — it has to figure out what kind + of partner it's facing from the observation stream. Each episode contains + `k_scenarios` scenarios; the partner is re-sampled at scenario boundaries + so the agent has multiple shots to adapt within one episode. -**Docs**: https://emerge-lab.github.io/PufferDrive +Everything runs in PufferLib's RL training stack, with `Drive` (the C+raylib +simulator) for the environment and a PyTorch policy on top. ---- +## Repo layout -### See our 2.0 release video - - - PufferDrive 2.0 - +``` +Adaptive_Driving_Agent/ +├── render.py # unified Python rendering CLI +├── pufferlib/ +│ ├── pufferl.py # train/eval entrypoint +│ ├── utils.py # render_videos, eval subprocess builders +│ ├── ocean/drive/ +│ │ ├── drive.h # C simulator + raylib renderer +│ │ ├── drive.py # gymnasium wrapper, RenderView, render() +│ │ ├── adaptive.py # AdaptiveDrivingAgent subclass +│ │ ├── rollout.py # arch-agnostic rollout loop (used by render + eval) +│ │ ├── torch.py # encoder + LSTMWrapper / TransformerWrapper +│ │ └── binding.{c,h} # Python C-extension +│ ├── ocean/benchmark/evaluator.py # HumanReplayEvaluator, WOSACEvaluator +│ └── config/ocean/{drive,adaptive}.ini +├── scripts/ +│ ├── coplayers/ # train co-players (per dataset × arch) +│ ├── adaptive/ # train adaptive ego using a co-player ckpt +│ ├── baselines/ # vanilla (no co-player) for comparison +│ └── ablations/ +│ ├── all_coplayers.sh # one-shot driver for the full coplayer matrix +│ └── human_align_ablation.sh # collision/offroad weight grid +└── tests/ # pytest suite (see CI section) +``` -## Installation +## Setup -Clone the repo ```bash -https://github.com/Emerge-Lab/PufferDrive.git -``` +git clone +cd Adaptive_Driving_Agent -Make a venv (`uv venv`), activate the venv -``` -source .venv/bin/activate -``` - -Inside the venv, install the dependencies -``` +uv venv && source .venv/bin/activate uv pip install -e . -``` -Compile the C code -``` +# Build the C extension python setup.py build_ext --inplace --force -``` -Run this while your virtual environment is active so the extension is built against the right interpreter. -To test your setup, you can run +# Headless rendering deps (skip on a desktop with display) +sudo apt install ffmpeg xvfb # or: +conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64 ffmpeg ``` -puffer train puffer_drive -``` -See also the [puffer docs](https://puffer.ai/docs.html). +### Map binaries -## Quick start +The simulator reads pre-converted map binaries. Two datasets ship out of +the box: -Start a training run ``` -puffer train puffer_drive +resources/drive/binaries/training/ # WOMD (default) +resources/drive/binaries/nuplan/ # nuPlan ``` -## Dataset - -
-Downloading and using data - -### Data preparation - -To train with PufferDrive, you need to convert JSON files to map binaries. Run the following command with the path to your data folder: +To convert your own JSON scenes: ```bash -python pufferlib/ocean/drive/drive.py +python -c "from pufferlib.ocean.drive.drive import process_all_maps; \ + process_all_maps('path/to/json/folder', max_maps=5000)" ``` -### Downloading Waymo Data - -You can download the WOMD data from Hugging Face in two versions: +The output lives at `resources/drive/binaries//map_NNN.bin`. -- **Mini dataset**: [GPUDrive_mini](https://huggingface.co/datasets/EMERGE-lab/GPUDrive_mini) contains 1,000 training files and 300 test/validation files -- **Medium dataset**: [10,000 files from the training dataset](https://huggingface.co/datasets/daphne-cornelisse/pufferdrive_train) -- **Large dataset**: [GPUDrive](https://huggingface.co/datasets/EMERGE-lab/GPUDrive) contains 100,000 unique scenes +## Training -**Note**: Replace 'GPUDrive_mini' with 'GPUDrive' in your download commands if you want to use the full dataset. +The CLI is `puffer train `. The env names are: +- `puffer_drive` — co-player / baseline training +- `puffer_adaptive_drive` — adaptive ego against frozen co-player -### Additional Data Sources +Architecture is selected with `--policy-architecture {Recurrent, Transformer}`. -For more training data compatible with PufferDrive, see [ScenarioMax](https://github.com/valeoai/ScenarioMax). The GPUDrive data format is fully compatible with PufferDrive. -
+### 1. Train the co-player population +Co-players are trained on a single env with conditioning sweeping over the +desired range. The conditioning at inference time is what makes the +population diverse. -## Visualizer +Per-dataset slurm scripts (4×4 entropy/discount grid, 16 array tasks each): -
-Dependencies and usage - -## Local rendering - -To launch an interactive renderer, first build: -``` -bash scripts/build_ocean.sh drive local +```bash +sbatch scripts/coplayers/nuplan_recurrent.sh +sbatch scripts/coplayers/nuplan_transformer.sh +sbatch scripts/coplayers/womd_recurrent.sh +sbatch scripts/coplayers/womd_transformer.sh ``` -then launch: +Or, to fire all 8 ablation cells (2 datasets × 2 archs × {none, all}) as a +single slurm array: + ```bash -./drive +sbatch scripts/ablations/all_coplayers.sh # 8 array tasks, one per cell ``` -this will run `demo()` with an existing model checkpoint. -## Headless server setup +The matrix is documented in the script's header. -Run the Raylib visualizer on a headless server and export as .mp4. This will rollout the pre-trained policy in the env. +Resulting checkpoints land at `experiments/puffer_drive_.pt`. -### Install dependencies +### 2. Train the adaptive ego -```bash -sudo apt update -sudo apt install ffmpeg xvfb -``` +The adaptive scripts wire one co-player checkpoint per array task, then +train the adaptive ego with `co_player_enabled=True`: -For HPC (There are no root privileges), so install into the conda environment ```bash -conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64 -conda install -c conda-forge ffmpeg +sbatch scripts/adaptive/nuplan_recurrent.sh +sbatch scripts/adaptive/nuplan_transformer.sh +sbatch scripts/adaptive/womd_recurrent.sh +sbatch scripts/adaptive/womd_transformer.sh ``` -- `ffmpeg`: Video processing and conversion -- `xvfb`: Virtual display for headless environments - -### Build and run +Before submitting, fill in the `ZIPPED_RUNS=(...)` array in each script with +the co-player checkpoint paths from step 1. Run IDs come from the wandb URLs +of the co-player runs. -1. Build the application: +The key flags are: ```bash -bash scripts/build_ocean.sh visualize local +puffer train puffer_adaptive_drive \ + --policy-architecture Recurrent \ # or Transformer + --env.k-scenarios 2 \ # scenarios per episode + --env.co-player-enabled True \ + --env.co-player-policy.policy-path .pt \ + --env.co-player-policy.architecture Recurrent \ + --env.co-player-policy.conditioning.type all \ + --env.co-player-policy.conditioning.entropy-weight-lb 0 \ + --env.co-player-policy.conditioning.entropy-weight-ub 0.1 \ + --env.co-player-policy.conditioning.discount-weight-lb 0.8 \ + --env.co-player-policy.conditioning.discount-weight-ub 1.0 \ + --env.map-dir resources/drive/binaries/nuplan \ + --eval.map-dir resources/drive/binaries/nuplan ``` -2. Run with virtual display: +### 3. Baselines (no co-player) + +For ablation, train the same architectures with `co_player_enabled=False`: + ```bash -xvfb-run -s "-screen 0 1280x720x24" ./visualize +sbatch scripts/baselines/nuplan_recurrent.sh +sbatch scripts/baselines/nuplan_transformer.sh +sbatch scripts/baselines/womd_recurrent.sh +sbatch scripts/baselines/womd_transformer.sh ``` -The `-s` flag sets up a virtual screen at 1280x720 resolution with 24-bit color depth. +Other agents in the scene replay their logged human trajectories. ---- +## Rendering -> To force a rebuild, you can delete the cached compiled executable binary using `rm ./visualize`. +Single Python entrypoint, `render.py`, replaces the old `./visualize` C +binary and the per-scenario eval scripts. It works for **any** architecture +because policy inference is done in PyTorch (the C side only handles +graphics). Architecture is auto-detected from the checkpoint state-dict. ---- +### Modes -
+| Flags | What it renders | +|---|---| +| `--model-path X.pt` | Baseline: ego drives, others follow logs | +| `--model-path adaptive.pt --co-player-path coplayer.pt` | Adaptive ego vs frozen co-player | +| `--model-path X.pt --human-replay` | One ego is policy-controlled; everyone else replays human logs | +### Examples -## Benchmarks +```bash +# Baseline render of a co-player checkpoint +xvfb-run -s "-screen 0 1280x720x24" python render.py \ + --model-path experiments/puffer_drive_.pt \ + --map-dir resources/drive/binaries/training \ + --conditioning-type all \ + --num-renders 5 + +# Adaptive ego rendered against a co-player population +python render.py \ + --model-path experiments/puffer_adaptive_drive_.pt \ + --co-player-path experiments/puffer_drive_.pt \ + --co-player-conditioning-type all \ + --k-scenarios 2 \ + --num-renders 3 \ + --map-dir resources/drive/binaries/nuplan + +# Human-replay: how compatible is the agent with logged humans? +python render.py \ + --model-path experiments/puffer_adaptive_drive_.pt \ + --human-replay \ + --k-scenarios 2 \ + --num-renders 5 + +# All three views (top-down, BEV, third-person) in one call +python render.py --model-path X.pt --view-mode all +``` -### Distributional realism +### Output -We provide a PufferDrive implementation of the [Waymo Open Sim Agents Challenge (WOSAC)](https://waymo.com/open/challenges/2025/sim-agents/) for fast, easy evaluation of how well your trained agent matches distributional properties of human behavior. See documentation [here](https://emerge-lab.github.io/PufferDrive/wosac/). +Videos are written to `/renders/`, with descriptive filenames so +multiple modes on the same map don't collide: -WOSAC evaluation with random policy: -```bash -puffer eval puffer_drive --eval.wosac-realism-eval True +``` +experiments/puffer_adaptive_drive_/renders/ + _baseline_k1_map002_sim_state.mp4 + _human_replay_k2_map002_sim_state.mp4 + _vs__k2_map002_sim_state.mp4 + _vs__k2_map002_bev.mp4 + _vs__k2_map002_persp.mp4 ``` -WOSAC evaluation with your checkpoint (must be .pt file): -```bash -puffer eval puffer_drive --eval.wosac-realism-eval True --load-model-path .pt +Different renders pick different maps via a prime stride +(`--start-seed`, `--seed-stride`). + +### During training + +`puffer train ... --train.render True` renders periodically every +`--train.render-interval` epochs through the same path, writes to +`/renders/epoch___k_map_.mp4`, and uploads +to wandb. + +## Checkpoint sidecar (info.json) + +Every checkpoint dir gets an `info.json` written alongside the `.pt` files +at save time. It records the training config you can't recover from the +weights alone — useful for grepping later or for filling in adaptive +training scripts by hand: + +```json +{ + "run_id": "1qcm9dc0", + "env_name": "puffer_drive", + "policy_architecture": "Recurrent", + "env": { + "map_dir": "resources/drive/binaries/nuplan", + "k_scenarios": 1, + "scenario_length": 91, + "dynamics_model": "classic", + "co_player_enabled": false, + "conditioning": { + "type": "all", + "entropy_weight_lb": 0.0, "entropy_weight_ub": 0.1, + "discount_weight_lb": 0.8, "discount_weight_ub": 1.0, + "..." + } + } +} ``` -### Human-compatibility +This is reference-only — no autoload. When you train an adaptive ego +against a co-player, you still pass `--env.co-player-policy.policy-path` +(and the matching architecture / conditioning flags) explicitly so the +intent stays in the script. + +## Evaluation + +Two evaluators ship out of the box: + +### Human-replay (compatibility with logged humans) -You may be interested in how compatible your agent is with human partners. For this purpose, we support an eval where your policy only controls the self-driving car (SDC). The rest of the agents in the scene are stepped using the logs. While it is not a perfect eval since the human partners here are static, it will still give you a sense of how closely aligned your agent's behavior is to how people drive. You can run it like this: ```bash -puffer eval puffer_drive --eval.human-replay-eval True --load-model-path .pt +puffer eval puffer_drive \ + --eval.human-replay-eval True \ + --load-model-path experiments/puffer_drive_.pt \ + --eval.map-dir resources/drive/binaries/nuplan ``` -## Development +For adaptive agents, this also reports `ada_delta_*` metrics +(last_scenario − first_scenario) so you can see how much the policy +adapted within an episode. -
Documentation and browser demo +The training loop runs this automatically every `--eval.eval-interval` +epochs when `--eval.human-replay-eval True`. -**Docs** +### WOSAC (distributional realism) -A browsable documentation site now lives under `docs/` and is configured with mkbooks. To preview locally: -``` -brew install mdbook -mdbook serve --open docs +```bash +puffer eval puffer_drive \ + --eval.wosac-realism-eval True \ + --load-model-path experiments/puffer_drive_.pt ``` -Open the served URL to see a local version of the docs. -**Interactive demo** +### Map-dir handling -To edit the browser demo, follow these steps: -- Download [emscripten](https://github.com/emscripten-core/emscripten) -- emscripten install latest -- Activate: `source emsdk/emsdk_env.sh` -- Run `bash scripts/build_ocean.sh drive web` -- This generates a number of `game*` files, move them to `assets/` to include them on the webpage +`eval.map_dir` defaults to `None` in both `drive.ini` and `adaptive.ini`, +which means **eval inherits the training `env.map_dir`** instead of silently +falling back to a different dataset. To use a different dataset for eval, +set `--eval.map-dir` explicitly. -
+## Tests + CI +```bash +pip install pytest +pytest tests/test_drive_conditioning.py \ + tests/test_map_dir_flow.py \ + tests/test_render_pipeline.py \ + tests/test_drive_config.py +``` -## Citation +The pytest suite covers: +- **Conditioning shapes** (`tests/test_drive_conditioning.py`): + every `conditioning_type` × `dynamics_model` combination produces the + right observation dim and value range. +- **map_dir propagation** (`tests/test_map_dir_flow.py`): + `eval()` inherits `env.map_dir` when `eval.map_dir` is unset, and the + human-replay / WOSAC subprocess builders forward a concrete + `--eval.map-dir` so the child process can't fall back to ini default. +- **Render contract** (`tests/test_render_pipeline.py`): + `set_video_suffix(name)` produces `name.mp4`. The dataclass default is safe. + +GitHub Actions: +- `.github/workflows/install.yml`: pre-commit on every push and PR. +- `.github/workflows/utest.yml`: pytest suite on every push (any branch) and PRs. +- `.github/workflows/render-ci.yml`: end-to-end render smoke (xvfb + ffmpeg). +- `.github/workflows/train-ci.yml`: short training + scenario tests. + +Run pre-commit locally before pushing: -If you use PufferDrive in your research, please cite: -```bibtex -@software{pufferdrive2025github, - author = {Daphne Cornelisse* and Spencer Cheng* and Pragnay Mandavilli and Julian Hunt and Kevin Joseph and Waël Doulazmi and Aditya Gupta and Eugene Vinitsky}, - title = {{PufferDrive}: A Fast and Friendly Driving Simulator for Training and Evaluating {RL} Agents}, - url = {https://github.com/Emerge-Lab/PufferDrive}, - version = {2.0.0}, - year = {2025}, -} +```bash +pre-commit install # one-time +pre-commit run --all-files ``` + +## Common pitfalls + +- **Checkpoint shape mismatch on render.** Conditioning changes the ego + observation dim, so a checkpoint trained with `conditioning.type=all` + needs `--conditioning-type all` at render time too. +- **Eval using the wrong dataset.** Either set `--eval.map-dir` explicitly + or rely on inheritance (the default with `eval.map_dir = None`). Don't set + `eval.map_dir` to a stale value in your ini. +- **Adaptive batch sizes.** `minibatch_size` must be ≥ `batch_size` and both + must be divisible by `horizon = k_scenarios * scenario_length` for adaptive. +- **Multiple consecutive renders.** The C extension recreates the raylib + window per render — that's expected, ~1s overhead per render. diff --git a/TODO_paper.md b/TODO_paper.md new file mode 100644 index 0000000000..f90aaffb93 --- /dev/null +++ b/TODO_paper.md @@ -0,0 +1,39 @@ +# Paper TODO + +## Pending — resume after recovery eval finishes +- **Resume the 4 k=3 adaptive runs** (paused on 2026-04-28 to free GPUs 4-7 + for parallel recovery-metric eval). Latest checkpoints are at epochs + 150-170. One-command resume: + ``` + bash /workspace/ADA/scripts/adaptive/nuplan_transformer_local_k3_resume.sh + ``` + The script already auto-locates the latest `model_*.pt` per wandb id and + uses the optimized flags (cpu_offload + external_co_player_actions, nw=32). + +## In progress +- **Conditional recovery metric** in `HumanReplayEvaluator`: + per-(agent, scenario) success tracking, then `P(success in s_k | failed s_0)`. + Surfaces in-context adaptation in the small fraction of nuplan scenes + where adaptation actually matters (most scenes are easy and ego trivially + succeeds in s_0, washing out the averaged `ada_delta_score`). + +## Backlog (after we have a baseline conditional-recovery number) + +- **Map rotation per scenario within an episode**. + At each scenario boundary, swap the underlying nuplan map so the ego sees + a new scene with new humans. KV cache preserved across scenarios so past + observations can inform the new scene. Forces more "hard" cases by + guaranteeing each scenario is genuinely novel. Implement as a flag: + `--env.map-rand-per-scenario {none, eval, train, both}`. Likely requires + retraining adaptive agents to handle the within-episode discontinuity. + +- **Train-time conditional-recovery loss/weighting**: bias rollout sampling + toward "hard" scenarios (rare but informative) once we have per-agent + difficulty estimates from offline eval. + +- **Cross-co-player generalization eval**: train ego against partners + {A, B, C}, eval against held-out partner D. Tests whether learned + adaptation generalizes to unseen partner styles. + +- **Architecture ablation**: run with KV-cache disabled (or context-length=1) + to confirm the cache is what's enabling adaptation (when it works). diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 68f9b699e2..3c5e1784f2 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -25,7 +25,7 @@ python setup.py build_ext --inplace --force Run this with your virtual environment activated so the compiled extension links against the correct Python. ### When to rebuild the extension -- Re-run `python setup.py build_ext --inplace --force` after changing any C/Raylib sources in `pufferlib/ocean/drive` (e.g., `drive.c`, `drive.h`, `binding.c`, `visualize.c`) or after pulling upstream changes that touch those files. This regenerates the `binding.cpython-*.so` used by `Drive`. +- Re-run `python setup.py build_ext --inplace --force` after changing any C/Raylib sources in `pufferlib/ocean/drive` (e.g., `drive.h`, `binding.c`, `binding.h`) or after pulling upstream changes that touch those files. This regenerates the `binding.cpython-*.so` used by `Drive`. - Pure Python edits (training scripts, docs, data utilities) do not require a rebuild; just restart your Python process. ## Verify the setup diff --git a/docs/src/simulator.md b/docs/src/simulator.md index a500840bb1..c52abc045a 100644 --- a/docs/src/simulator.md +++ b/docs/src/simulator.md @@ -210,15 +210,14 @@ mid_x, mid_y, length, width, dir_cos, dir_sin, type ## Source files ### C core -- `drive.h`: Main simulator (stepping, observations, collisions) -- `drive.c`: Demo and testing -- `binding.c`: Python interface -- `visualize.c`: Raylib renderer -- `drivenet.h`: C inference network +- `drive.h`: Main simulator (stepping, observations, collisions, raylib renderer) +- `binding.c` / `binding.h`: Python interface ### Python -- `drive.py`: Gymnasium wrapper +- `drive.py`: Gymnasium wrapper, `Drive.render()`, `set_video_suffix()` +- `rollout.py`: Policy-agnostic rollout loop used by both training renders and `render.py` - `torch.py`: Neural network (ego/partner/road encoders → actor/critic) +- `../../../render.py`: Unified rendering CLI (replaces the old `./visualize` binary) ## Neural network diff --git a/docs/src/visualizer.md b/docs/src/visualizer.md index c6b73c90ea..f7046b9d02 100644 --- a/docs/src/visualizer.md +++ b/docs/src/visualizer.md @@ -1,35 +1,42 @@ # Visualizer -PufferDrive ships a Raylib-based visualizer for replaying scenes, exporting videos, and debugging policies. +PufferDrive renders headless mp4s from Python. Policy inference runs in PyTorch +and graphics are produced via the C/raylib bindings (`vec_render`), so the same +pipeline works for both LSTM and Transformer policies. ## Dependencies -Install the minimal system packages for headless render/export: ```bash sudo apt update sudo apt install ffmpeg xvfb ``` -On environments without sudo, install them into your conda/venv: +Without sudo: ```bash conda install -c conda-forge xorg-x11-server-xvfb-cos6-x86_64 ffmpeg ``` -## Build -Compile the visualizer binary from the repo root: +## Run + +The unified entrypoint is `render.py` at the repo root: ```bash -bash scripts/build_ocean.sh visualize local -``` +# Baseline: ego drives, others follow logged trajectories +xvfb-run -s "-screen 0 1280x720x24" python render.py \ + --model-path experiments/.pt --map-dir resources/drive/binaries/training -If you need to force a rebuild, remove the cached binary first (`rm ./visualize`). +# Adaptive ego + frozen co-player population +python render.py --model-path adaptive.pt --co-player-path coplayer.pt \ + --co-player-conditioning-type all --k-scenarios 2 -## Run headless -Launch the visualizer with a virtual display and export an `.mp4`: +# Human replay: only the SDC is policy-controlled, others = logged +python render.py --model-path X.pt --human-replay --num-renders 5 -```bash -xvfb-run -s "-screen 0 1280x720x24" ./visualize +# Multiple views in one go +python render.py --model-path X.pt --view-mode all ``` -Adjust the screen size and color depth as needed. The `xvfb-run` wrapper allows Raylib to render without an attached display, which is convenient for servers and CI jobs. +Architecture is auto-detected from the checkpoint state-dict; override with +`--policy-architecture {Recurrent,Transformer}`. See `python render.py --help` +for the full set of conditioning, co-player, and rendering flags. diff --git a/evaluate_human_logs.py b/evaluate_human_logs.py deleted file mode 100644 index 83472b6c71..0000000000 --- a/evaluate_human_logs.py +++ /dev/null @@ -1,411 +0,0 @@ -import argparse -import json -import numpy as np -import torch -from tqdm import tqdm -import pufferlib -import pufferlib.vector -from pufferlib.ocean import env_creator -from pufferlib.ocean.torch import Drive, Recurrent - -import matplotlib.pyplot as plt -import numpy as np - - -def plot_adaptive_metrics(first_metrics, last_metrics, delta_metrics, output_path): - """ - Plot adaptive metrics showing first scenario (0-shot), last scenario, and delta improvement. - """ - # Metrics to plot - metrics_to_plot = { - "score": "Score", - "collision_rate": "Collision Rate", - "offroad_rate": "Offroad Rate", - "episode_return": "Episode Return", - } - - fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - axes = axes.flatten() - - for idx, (metric_key, metric_name) in enumerate(metrics_to_plot.items()): - ax = axes[idx] - - first_val = first_metrics[metric_key] - last_val = last_metrics[metric_key] - delta_key = f"ada_delta_{metric_key}" - delta_pct = delta_metrics[delta_key] - - # Create bar chart - x = np.arange(2) - bars = ax.bar(x, [first_val, last_val], width=0.6, alpha=0.8) - - # Color bars based on improvement - # For collision/offroad, decrease is good (green), increase is bad (red) - # For score/return, increase is good (green), decrease is bad (red) - if metric_key in ["collision_rate", "offroad_rate"]: - bars[0].set_color("gray") - bars[1].set_color("green" if delta_pct < 0 else "red") - else: - bars[0].set_color("gray") - bars[1].set_color("green" if delta_pct > 0 else "red") - - # Add value labels on bars - for i, (bar, val) in enumerate(zip(bars, [first_val, last_val])): - height = bar.get_height() - ax.text( - bar.get_x() + bar.get_width() / 2.0, - height, - f"{val:.3f}", - ha="center", - va="bottom", - fontsize=10, - fontweight="bold", - ) - - # Add delta percentage annotation - mid_x = 0.5 - mid_y = max(first_val, last_val) * 0.5 - arrow_props = dict( - arrowstyle="->", - lw=2, - color="green" - if (delta_pct > 0 and metric_key in ["score", "episode_return"]) - or (delta_pct < 0 and metric_key in ["collision_rate", "offroad_rate"]) - else "red", - ) - - ax.annotate( - f"{delta_pct:+.1f}%", - xy=(1, last_val), - xytext=(mid_x, mid_y), - fontsize=14, - fontweight="bold", - ha="center", - bbox=dict(boxstyle="round,pad=0.5", facecolor="yellow", alpha=0.7), - arrowprops=arrow_props, - ) - - # Formatting - ax.set_xticks(x) - ax.set_xticklabels(["First Scenario\n(0-shot)", "Last Scenario\n(Adapted)"], fontsize=11) - ax.set_ylabel(metric_name, fontsize=12, fontweight="bold") - ax.set_title(f"{metric_name}", fontsize=13, fontweight="bold") - ax.grid(axis="y", alpha=0.3, linestyle="--") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - plt.tight_layout() - plt.savefig(output_path.replace(".json", "_adaptive_metrics.png"), dpi=300, bbox_inches="tight") - print(f"\nAdaptive metrics plot saved to {output_path.replace('.json', '_adaptive_metrics.png')}") - plt.close() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--policy-path", type=str, required=True) - parser.add_argument("--num-maps", type=int, default=10) - parser.add_argument("--num-rollouts", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=32, help="Max parallel rollouts per batch") - parser.add_argument("--num-workers", type=int, default=16) - parser.add_argument("--num-agents", type=int, default=64) - parser.add_argument( - "--condition-type", - type=str, - default="none", - choices=["none", "reward", "entropy", "discount", "all"], - help="Conditioning type (none, reward, entropy, discount, all)", - ) - parser.add_argument("--output", type=str, default="eval_human_logs.json") - parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") - parser.add_argument( - "--max-controlled-agents", type=int, default=-1 - ) ## needs to be 1 if you want human logs, -1 if you want Self Play - parser.add_argument("--adaptive-driving-agent", type=int, default=0, help="Enable adaptive driving agent") - parser.add_argument("--k-scenarios", type=int, default=1, help="Number of scenarios (default 1 for non-adaptive)") - parser.add_argument("--dynamics-model", type=str, default="classic") - args = parser.parse_args() - - num_batches = (args.num_rollouts + args.batch_size - 1) // args.batch_size - - print(f"Evaluation Configuration:") - print(f" Policy: {args.policy_path}") - print(f" Conditioning: {args.condition_type}") - print(f" Num maps: {args.num_maps}") - print(f" Total rollouts: {args.num_rollouts}") - print(f" Batch size: {args.batch_size}") - print(f" Num batches: {num_batches}") - print(f" Num agents per env: {args.num_agents}") - print(f" Adaptive agent: {bool(args.adaptive_driving_agent)}") - print(f" K scenarios: {args.k_scenarios}") - print(f" Output: {args.output}\n") - print(f" Dynamics Model: {args.dynamics_model}") - - # Load policy - print("Loading policy...") - env_name = "puffer_adaptive_drive" if args.adaptive_driving_agent else "puffer_drive" - make_env = env_creator(env_name) - temp_env = make_env( - num_agents=64, - num_maps=args.num_maps, - scenario_length=91, - co_player_cond_type=args.condition_type, - adaptive_driving_agent=args.adaptive_driving_agent, - k_scenarios=args.k_scenarios, - dynamics_model=args.dynamics_model, - ) - - base_policy = Drive(temp_env, input_size=64, hidden_size=256) - policy = Recurrent(temp_env, base_policy, input_size=256, hidden_size=256).to(args.device) - state_dict = torch.load(args.policy_path, map_location=args.device) - state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} - policy.load_state_dict(state_dict) - policy.eval() - temp_env.close() - print("Policy loaded successfully\n") - - # Run evaluation in batches - all_returns = [] - all_scenario_metrics = [] # Track metrics per scenario for adaptive agents - all_metrics = [] - - env_kwargs = { - "num_agents": args.num_agents, - "num_maps": args.num_maps, - "max_controlled_agents": args.max_controlled_agents, - "report_interval": 1, - "scenario_length": 91, - "adaptive_driving_agent": args.adaptive_driving_agent, - "k_scenarios": args.k_scenarios, - "dynamics_model": args.dynamics_model, - "co_player_cond_type": args.condition_type, - "co_player_cond_entropy_ub": 0.05, - "co_player_cond_discount_lb": 0.40, - } - - print("Running evaluation...") - for batch_idx in range(num_batches): - batch_rollouts = min(args.batch_size, args.num_rollouts - batch_idx * args.batch_size) - - print(f"Batch {batch_idx + 1}/{num_batches} ({batch_rollouts} rollouts)") - - # Find largest valid num_workers (divisor of batch_rollouts) - max_workers = min(args.num_workers, batch_rollouts) - while batch_rollouts % max_workers != 0: - max_workers -= 1 - - vecenv = pufferlib.vector.make( - make_env, - env_kwargs=env_kwargs, - backend=pufferlib.vector.Multiprocessing, - num_envs=batch_rollouts, - num_workers=max_workers, - ) - - obs, _ = vecenv.reset() - total_agents = obs.shape[0] - - state = { - "lstm_h": torch.zeros(total_agents, policy.hidden_size, device=args.device), - "lstm_c": torch.zeros(total_agents, policy.hidden_size, device=args.device), - } - - batch_infos = [] - total_reward = np.zeros(total_agents) - scenario_infos = [] # Track infos per scenario - - with torch.no_grad(): - # Run through all scenarios (1 for non-adaptive, k for adaptive) - for scenario in range(args.k_scenarios): - scenario_info_list = [] - desc = f" Scenario {scenario + 1}/{args.k_scenarios}" if args.k_scenarios > 1 else " Steps" - for t in tqdm(range(91), desc=desc, ncols=80, leave=False): - obs_t = torch.as_tensor(obs, device=args.device) - logits, _ = policy.forward_eval(obs_t, state) - action, _, _ = pufferlib.pytorch.sample_logits(logits) - - obs, reward, done, trunc, info = vecenv.step(action.cpu().numpy()) - total_reward += reward - - if info: - valid_infos = [inf for inf in info if "score" in inf] - batch_infos.extend(valid_infos) - scenario_info_list.extend(valid_infos) - - # Store per-scenario infos - if args.adaptive_driving_agent: - scenario_infos.append(scenario_info_list) - - vecenv.close() - - # Aggregate batch metrics - num_infos = len(batch_infos) or 1 - batch_metrics = { - "score": sum(info.get("score", 0) for info in batch_infos) / num_infos, - "collision_rate": sum(info.get("collision_rate", 0) for info in batch_infos) / num_infos, - "offroad_rate": sum(info.get("offroad_rate", 0) for info in batch_infos) / num_infos, - "completion_rate": sum(info.get("completion_rate", 0) for info in batch_infos) / num_infos, - "dnf_rate": sum(info.get("dnf_rate", 0) for info in batch_infos) / num_infos, - "avg_collisions_per_agent": sum(info.get("avg_collisions_per_agent", 0) for info in batch_infos) - / num_infos, - "avg_offroad_per_agent": sum(info.get("avg_offroad_per_agent", 0) for info in batch_infos) / num_infos, - } - - rollout_rewards = total_reward.reshape(batch_rollouts, args.num_agents) - rollout_returns = rollout_rewards.mean(axis=1) - - all_returns.extend(rollout_returns.tolist()) - all_metrics.append(batch_metrics) - - # Store scenario-specific metrics for adaptive agents - if args.adaptive_driving_agent: - batch_scenario_metrics = [] - for scenario_info_list in scenario_infos: - num_scenario_infos = len(scenario_info_list) or 1 - scenario_metrics = { - "score": sum(info.get("score", 0) for info in scenario_info_list) / num_scenario_infos, - "collision_rate": sum(info.get("collision_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "offroad_rate": sum(info.get("offroad_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "completion_rate": sum(info.get("completion_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "dnf_rate": sum(info.get("dnf_rate", 0) for info in scenario_info_list) / num_scenario_infos, - "num_goals_reached": sum(info.get("num_goals_reached", 0) for info in scenario_info_list) - / num_scenario_infos, - "lane_alignment_rate": sum(info.get("lane_alignment_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "avg_displacement_error": sum(info.get("avg_displacement_error", 0) for info in scenario_info_list) - / num_scenario_infos, - "episode_return": sum(info.get("episode_return", 0) for info in scenario_info_list) - / num_scenario_infos, - } - # Compute perf (score without collision before goal) - scenario_metrics["perf"] = scenario_metrics["score"] - batch_scenario_metrics.append(scenario_metrics) - - all_scenario_metrics.append(batch_scenario_metrics) - - # Aggregate across all batches - all_returns = np.array(all_returns) - - # Divide by k_scenarios since we accumulated across all scenarios - all_returns = all_returns / args.k_scenarios - - metrics = {k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0].keys()} - - results = { - "avg_return": float(np.mean(all_returns)), - "std_return": float(np.std(all_returns)), - "se_return": float(np.std(all_returns) / np.sqrt(len(all_returns))), # Standard error - **{k: float(v) for k, v in metrics.items()}, - } - - # Compute adaptive delta metrics from scenario metrics - first_scenario_metrics = None - last_scenario_metrics = None - - if args.adaptive_driving_agent and len(all_scenario_metrics) > 0: - # Aggregate scenario metrics across all batches - # all_scenario_metrics is a list of [batch][scenario] metrics - # We need to compute average for each scenario across all batches - - aggregated_scenario_metrics = [] - for scenario_idx in range(args.k_scenarios): - scenario_metrics_list = [batch[scenario_idx] for batch in all_scenario_metrics] - - # Average each metric across batches for this scenario - avg_scenario_metrics = {} - for key in scenario_metrics_list[0].keys(): - avg_scenario_metrics[key] = np.mean([m[key] for m in scenario_metrics_list]) - - aggregated_scenario_metrics.append(avg_scenario_metrics) - - # Get first and last scenario metrics - first_scenario_metrics = aggregated_scenario_metrics[0] - last_scenario_metrics = aggregated_scenario_metrics[-1] - - # Helper function to compute delta percentage - def compute_delta_percent(first_val, last_val): - if abs(first_val) < 0.0001: - return 0.0 - return (last_val - first_val) / first_val * 100.0 - - # Compute all delta metrics - results["ada_delta_completion_rate"] = compute_delta_percent( - first_scenario_metrics["completion_rate"], last_scenario_metrics["completion_rate"] - ) - results["ada_delta_score"] = compute_delta_percent( - first_scenario_metrics["score"], last_scenario_metrics["score"] - ) - results["ada_delta_perf"] = compute_delta_percent(first_scenario_metrics["perf"], last_scenario_metrics["perf"]) - results["ada_delta_collision_rate"] = compute_delta_percent( - first_scenario_metrics["collision_rate"], last_scenario_metrics["collision_rate"] - ) - results["ada_delta_offroad_rate"] = compute_delta_percent( - first_scenario_metrics["offroad_rate"], last_scenario_metrics["offroad_rate"] - ) - results["ada_delta_num_goals_reached"] = compute_delta_percent( - first_scenario_metrics["num_goals_reached"], last_scenario_metrics["num_goals_reached"] - ) - results["ada_delta_dnf_rate"] = compute_delta_percent( - first_scenario_metrics["dnf_rate"], last_scenario_metrics["dnf_rate"] - ) - results["ada_delta_lane_alignment_rate"] = compute_delta_percent( - first_scenario_metrics["lane_alignment_rate"], last_scenario_metrics["lane_alignment_rate"] - ) - results["ada_delta_avg_displacement_error"] = compute_delta_percent( - first_scenario_metrics["avg_displacement_error"], last_scenario_metrics["avg_displacement_error"] - ) - results["ada_delta_episode_return"] = compute_delta_percent( - first_scenario_metrics["episode_return"], last_scenario_metrics["episode_return"] - ) - - # Store first and last scenario values for reporting - results["first_scenario_score"] = float(first_scenario_metrics["score"]) - results["first_scenario_collision_rate"] = float(first_scenario_metrics["collision_rate"]) - results["first_scenario_offroad_rate"] = float(first_scenario_metrics["offroad_rate"]) - results["first_scenario_episode_return"] = float(first_scenario_metrics["episode_return"]) - - results["last_scenario_score"] = float(last_scenario_metrics["score"]) - results["last_scenario_collision_rate"] = float(last_scenario_metrics["collision_rate"]) - results["last_scenario_offroad_rate"] = float(last_scenario_metrics["offroad_rate"]) - results["last_scenario_episode_return"] = float(last_scenario_metrics["episode_return"]) - - with open(args.output, "w") as f: - json.dump(results, f, indent=2) - - print(f"\nResults:") - print(f" Return: {results['avg_return']:.2f} ± {results['se_return']:.2f} (SE)") - print(f" Score: {results['score']:.3f}") - print(f" Completion: {results['completion_rate']:.3f}") - print(f" Collision: {results['collision_rate']:.3f}") - print(f" Collision per agent: {results['avg_collisions_per_agent']:.3f}") - print(f" Offroad: {results['offroad_rate']:.3f}") - - if args.adaptive_driving_agent: - print(f"\n0-Shot Performance (First Scenario):") - print(f" Score: {results['first_scenario_score']:.3f}") - print(f" Collision: {results['first_scenario_collision_rate']:.3f}") - print(f" Offroad: {results['first_scenario_offroad_rate']:.3f}") - print(f" Return: {results['first_scenario_episode_return']:.2f}") - - print(f"\nAdapted Performance (Last Scenario):") - print(f" Score: {results['last_scenario_score']:.3f}") - print(f" Collision: {results['last_scenario_collision_rate']:.3f}") - print(f" Offroad: {results['last_scenario_offroad_rate']:.3f}") - print(f" Return: {results['last_scenario_episode_return']:.2f}") - - print(f"\nAdaptive Metrics (Delta %):") - print(f" Score: {results['ada_delta_score']:.2f}%") - print(f" Collision rate: {results['ada_delta_collision_rate']:.2f}%") - print(f" Offroad rate: {results['ada_delta_offroad_rate']:.2f}%") - print(f" Episode return: {results['ada_delta_episode_return']:.2f}%") - - # Generate visualization - plot_adaptive_metrics(first_scenario_metrics, last_scenario_metrics, results, args.output) - - print(f"\nSaved to {args.output}") - - -if __name__ == "__main__": - main() diff --git a/external/pyxodr b/external/pyxodr new file mode 160000 index 0000000000..cd4b837a65 --- /dev/null +++ b/external/pyxodr @@ -0,0 +1 @@ +Subproject commit cd4b837a651d4f10c3c4e77b04a029cac367c64b diff --git a/notes/nuplan_hard.md b/notes/nuplan_hard.md new file mode 100644 index 0000000000..272f883189 --- /dev/null +++ b/notes/nuplan_hard.md @@ -0,0 +1,150 @@ +# `nuplan_hard` map split + +## Why this exists + +Default eval (`--eval.map-dir resources/drive/binaries/nuplan_201`, 50 rollouts) shows +`ada_delta_score ≈ ±0.005` across all configs. Looks like the policy doesn't +adapt across scenarios. + +Diagnosis (from `score_maps_interaction.py` over the 5401 maps in nuplan_201): + +| Property | Count | +|----------|------:| +| Total maps | 5401 | +| Maps where SDC has *any* moving vehicle within 5m at any timestep | 2497 (46%) | +| Maps where SDC has zero such interaction-steps | **2904 (54%)** | + +**Over half of all maps have zero ego-other interaction.** On those maps the +ego just drives straight and `ada_delta` collapses to noise. The signal exists, +but the average is dominated by maps where adaptation is irrelevant. + +## What `nuplan_hard` is + +A symlink-only directory at `resources/drive/binaries/nuplan_hard/` containing +the **top 10% of nuplan_201 maps by SDC-vehicle interaction density** (540 maps, +sequentially renumbered as `map_001.bin` through `map_540.bin`). + +Threshold: maps with `sdc_interaction_steps ≥ 52`. Mean in the hard set: 79 +steps (out of 91-201 timesteps per scenario). + +Each `map_NNN.bin` is a symlink pointing back to the original map in +`nuplan_201/`. The mapping (new bin id ↔ original bin id ↔ scenario_id) is +recorded in `_manifest.csv` for traceability. + +## How to recreate + +The split is reproducible. To regenerate from the underlying JSON +trajectories in `data/nuplan_gpudrive/nuplan/`: + +```bash +# 1. Score every map by SDC interaction-steps (~3 min, multi-process) +python scripts/score_maps_interaction.py \ + --data-dir data/nuplan_gpudrive/nuplan \ + --out /tmp/nuplan_201_hardness_scores.csv + +# 2. Build the symlink directory from the top-10% slice +python scripts/build_nuplan_hard.py \ + --scores /tmp/nuplan_201_hardness_scores.csv \ + --source-dir resources/drive/binaries/nuplan_201 \ + --out-dir resources/drive/binaries/nuplan_hard \ + --metric sdc_interaction_steps \ + --top-pct 10 +``` + +Different threshold or metric? Pass other values to step 2 (`--top-pct 25`, +`--metric interaction_events`, etc.). The scoring step doesn't have to be +re-run — its CSV holds both metrics for every map. + +## Hardness definition (operational) + +For each scenario in the JSON dataset: + +1. Filter to vehicle-type agents only (excludes pedestrians, cyclists). +2. Identify the SDC (the `is_sdc=True` agent — this is the slot the ego + policy occupies during human-replay eval). +3. For each timestep `t`: + - Skip if SDC is not valid at `t`, or is parked (speed < 0.5 m/s). + - Find any other vehicle that is valid AND moving (speed > 0.5 m/s) + AND within Euclidean distance 5 m of the SDC. + - If at least one such vehicle exists, count this timestep as an + SDC interaction-step. +4. `sdc_interaction_steps` = total such timesteps for the scenario. + +This is purely a property of the map + recorded human trajectories — no +trained policy is involved. A map is "hard" iff the SDC's logged trajectory +brings it close to another moving vehicle, often. That's exactly when our +trained ego (which replaces the SDC at eval time) faces real interactions +to adapt to. + +## Eval against `nuplan_hard` + +Stock command (replace checkpoint and GPU): + +```bash +xvfb-run -a puffer eval puffer_adaptive_drive \ + --load-model-path experiments/puffer_adaptive_drive_/model_..._N.pt \ + --policy-architecture Transformer --rnn-name Transformer \ + --train.horizon 402 \ + --vec.num-workers 1 --vec.num-envs 1 \ + --env.map-dir resources/drive/binaries/nuplan_hard \ + --env.num-maps 540 \ + --env.scenario-length 201 \ + --env.k-scenarios 2 \ + --env.conditioning.type none \ + --env.co-player-enabled 1 \ + --env.co-player-policy.policy-path experiments/puffer_drive_2e029h15.pt \ + --env.co-player-policy.architecture Transformer \ + --env.co-player-policy.transformer.horizon 201 \ + --env.co-player-policy.conditioning.type all \ + --env.co-player-policy.conditioning.collision-weight-lb -2 \ + --env.co-player-policy.conditioning.collision-weight-ub 0 \ + --env.co-player-policy.conditioning.offroad-weight-lb -2 \ + --env.co-player-policy.conditioning.offroad-weight-ub 0 \ + --env.co-player-policy.conditioning.entropy-weight-lb 0 \ + --env.co-player-policy.conditioning.entropy-weight-ub 0.10 \ + --env.co-player-policy.conditioning.discount-weight-lb 0.4 \ + --env.co-player-policy.conditioning.discount-weight-ub 1 \ + --env.map-rand-per-scenario False \ + --eval.map-dir resources/drive/binaries/nuplan_hard \ + --eval.num-maps 540 \ + --eval.human-replay-eval True \ + --eval.human-replay-num-rollouts 200 \ + --eval.human-replay-num-maps 540 \ + --eval.human-replay-num-agents 540 +``` + +Note `--eval.human-replay-num-{rollouts,maps,agents}` are required: +without them, eval defaults to 100 maps / 100 rollouts which both +under-samples and may not all hit the hard split. + +## Result on baseline checkpoint + +On `puffer_adaptive_drive_4lm6kkh7/model_..._40.pt` (γ=0.995, partner 2e029h15) +with 200 rollouts: + +| Metric | Full set (50 rollouts) | `nuplan_hard` (200 rollouts) | +|--------|-----------------------:|-----------------------------:| +| `ada_delta_score` | ±0.005 | **+0.222 ± 0.18** | +| `ada_delta_collision_rate` | ~0 | -0.045 | +| `ada_delta_episode_return` | ~0 | +0.449 | +| `ada_delta_dnf_rate` | ~0 | -0.170 | +| `scenario_0_score` | — | 0.712 | +| `scenario_1_score` | — | **0.935** | + +The model **does** adapt (s_1 score 22% higher than s_0; collisions cut roughly +in half). The full-set average diluted the signal by 40-200×. + +## Caveats / things to keep in mind + +- "Hard" here is a property of the SDC's *recorded* trajectory in the + human-replay data. The trained ego is free to take a different path; a + defensive policy may avoid the close-proximity moments that scored the map + hard in the first place. The split is a *prior* over which maps are likely + to involve interaction, not a guarantee. +- Threshold (5 m, 0.5 m/s, top 10%) is somewhat arbitrary. Re-running with + different thresholds will produce different splits — `_manifest.csv` records + the actual threshold used. +- Eval channel is human-replay (other vehicles = recorded humans, no synthetic + partner). The hard-set definition matches this channel — for synthetic- + partner eval channels you'd want a different definition (e.g. + partner-sensitivity-based). diff --git a/notes/oracle_partner_conditioning_investigation.md b/notes/oracle_partner_conditioning_investigation.md new file mode 100644 index 0000000000..c75d29d7c8 --- /dev/null +++ b/notes/oracle_partner_conditioning_investigation.md @@ -0,0 +1,384 @@ +# Why doesn't the adaptive ego use its K/V cache? (and the oracle test plan) + +Investigation date: 2026-05-01. +Author: Claude (working with Mohit on the adaptive-driving paper). + +--- + +## Empirical setup that triggered this investigation + +We trained two adaptive ego policies against the entropy-sweep partner +`6rauydj2` (trained on `entropy_ub ∈ [0, 0.5]`): + +- **`curr_e0.5`** (wandb `hprfn8dc`) — entropy-curriculum on (ub + annealed 0.025 → 0.10 → 0.25 → 0.50 across the first ~120 episodes). +- **`nocurr_e0.5`** (wandb `7wm1sk5v`) — entropy-curriculum off + (ub fixed at 0.50 throughout). + +Both trained from scratch to ~epoch 203 (resumed from intermediate +checkpoints after host OOMs; resume restores optimizer + global_step +cleanly via `pufferl.py:280`). Same partner, same seed, same map dir, +same rollout config — **the only varying factor is the curriculum**. + +The training-time `ada_delta_score` favored `curr` over `nocurr`. We +ran three offline analyses on the final checkpoints to understand +**why** and **whether the policy is actually using the cache for +adaptation**. + +--- + +## Three offline analyses (all on 300 truly-held-out scenes) + +Held-out set: `resources/drive/binaries/nuplan_201_heldout300/` — +symlinked `map_001..300` → originals `map_5102..5401` (training used +the first 4999, so these are unseen). + +### Analysis 1: head-to-head eval (`puffer eval`, no co-player) + +``` +metric nocurr_e05 curr_e05 delta +───────────────────────────────────────────────────────────────────── +s0_rate 0.5906 0.6913 +0.1007 +s1_rate 0.5906 0.6946 +0.1040 +ada_delta 0.0000 0.0034 +0.0034 +p_s1_given_s0_pass 0.9830 0.9854 +0.0025 +p_s1_given_s0_fail 0.0246 0.0435 +0.0189 +n_s0_pass 176 206 +30 +n_s0_fail 122 92 -30 +``` + +Curr is ~10 pp better in absolute success rate in BOTH scenarios on +held-out scenes. ada_delta is ~0 for both — the curriculum gain is a +*better policy* gain, not an *adaptation* gain. Recovery rate +(`p_s1_given_s0_fail`) almost doubled with curriculum (2.5% → 4.4%) but +both numbers are tiny in absolute terms. + +### Analysis 2: attention probe (`scripts/probe_attention.py`) + +Cross-scenario attention mass = fraction of attention from s_1 query +positions to s_0 cache slots. Recorded per (layer, head) over a single +rollout against `6rauydj2`. + +| Layer/Head | curr_e05 | nocurr_e05 | +|------------|----------|------------| +| L0H0 | 0.44 | 0.68 | +| L0H1 | 0.51 | 0.69 | +| L0H2 | 0.55 | 0.67 | +| L0H3 | 0.49 | 0.67 | +| L1H0 | **0.14** | 0.66 | +| L1H1 | 0.35 | 0.70 | +| L1H2 | **0.05** | 0.63 | +| L1H3 | 0.23 | 0.65 | + +`curr` attends LESS to past than `nocurr` across every head — opposite +of the cache-use hypothesis. + +### Analysis 3: counterfactual cache (`scripts/counterfactual_cache.py`) + +Paired rollouts: same map, same seed, same partner conditioning. +Condition A preserves the K/V cache from s_0 → s_1; Condition B zeros +it. Difference isolates "does cache CONTENT matter?" + +| Metric | curr_e05 (z) | nocurr_e05 (z) | +|-------------------------------------|--------------|----------------| +| Paired s1 lift (preserved − zeroed) | -0.007 (-0.45) | +0.030 (+1.41) | +| P(s1 \| s0 fail) preserved | 0.30 | 0.14 | +| P(s1 \| s0 fail) zeroed | 0.30 | 0.10 | + +`curr` is unaffected by cache zeroing (z = -0.45). `nocurr` does feel +it — recovery rate drops ~30% relative when cache is wiped. + +### Joint reading + +| Question | curr_e0.5 | nocurr_e0.5 | +|-----------------------------------|-------------------|--------------| +| Better single-shot driver? | Yes (69% vs 59%) | No | +| Attends to past more? | No (0.05–0.55) | Yes (0.63+) | +| Cache content load-bearing? | No (z = -0.45) | Marginally yes (z = +1.41) | + +The curriculum produced a stronger driver that **ignores** the cache. +The non-curriculum policy attends and (slightly) uses the cache, but +that doesn't translate into better absolute scores. + +**Either way, neither policy has meaningful in-context adaptation.** + +--- + +## First-principles teardown: what does adaptation actually require? + +For the ego to "adapt to partner P across scenarios" via cache, ALL of +these must hold: + +1. **Partner type is observable** — some signal in ego's obs encodes + partner type. +2. **The encoder preserves it** — the encoder doesn't lose the + discriminating info during compression. +3. **The hidden states get written to cache** — automatic given the + architecture. +4. **The policy queries the cache** — attention pattern looks back. +5. **The retrieved values still carry the info** — what's stored is + still partner-relevant. +6. **The policy conditions on cache content** — the policy learns to + change its action distribution based on cache values. + +We have evidence on each link: + +- (1) **Weak**: partner conditioning is NOT in ego obs. The ego sees + the partner's `(rel_x, rel_y, rel_heading_x, rel_heading_y, width, + length, rel_signed_speed)` — instantaneous behavior only. To infer + partner type (e.g., `entropy_weight`), the ego must integrate + *behavioral variance* over many timesteps. From a single snapshot, + partner type is unobservable. +- (2) **Suspicious**: encoder compresses 1855 → 256. Partner-typing + info needs to survive this bottleneck, but the only push for it is + RL gradient — which is weak unless the policy is already using + partner type info. +- (3) ✓ +- (4) ✓ (probe confirmed) +- (5) **Likely the break**: counterfactual cache shows zeroing the + cache barely changes s_1 — the values don't carry usable info. +- (6) **The break**: if the values aren't useful, conditioning on + them can't help. + +### Three specific architectural concerns + +**(a) Observations are in EGO FRAME** +Every partner feature (`rel_x`, `rel_heading_x`, `rel_signed_speed`) +is computed relative to the ego's CURRENT pose. As the ego moves, the +same world-frame partner behavior produces different obs values. The +K/V cache stores ego-frame snapshots, but the ego frame is +non-stationary across scenarios. To compare partner behavior across +the scenario boundary, the policy would have to implicitly invert the +ego frame change — a hard sub-problem. + +**(b) Partner identity is unlabeled** +Slot K of partner_obs is "the K-th nearby agent in +`active_agent_indices` iteration order." If 60 partners are within +visibility, slot 5 could be partner-id-42 at one step and same agent +at the next (stable for stationary scenes), but there's no identity +TAG. Comparing slot K values across timesteps requires the policy to +verify that's the same partner — an implicit sub-problem. + +**(c) The latent we want to identify is BEHAVIORAL VARIANCE** +A high-entropy partner samples actions stochastically. Its +instantaneous position+speed look identical to a low-entropy partner. +The DIFFERENCE only shows up as "the partner did something +unexpected" — which requires the policy to have a forward model of +what the partner *should* be doing under each conditioning value, then +notice deviations. That's an enormous hidden inference problem, and +nothing pushes the policy to learn the forward model. + +### Conclusion + +End-to-end, we're asking the policy to learn (from sparse goal-reach +RL reward) to: +1. Track partner identities across timesteps (no ID labels) +2. Undo ego-frame motion to recover world-frame trajectories +3. Estimate behavioral variance per partner over many timesteps +4. Map variance → partner-type embedding +5. Use that embedding to modulate actions + +That's a tower of hard implicit inference problems on top of "drive a +car well." The optimization just doesn't push hard enough for steps +1-4, especially when most of the reward can be earned by single-shot +driving alone. This explains the empirical finding: the better +single-shot driver (curr) ignores the cache, because the gradient +toward "be a better driver" is much stronger than the gradient toward +"learn to read the cache." + +--- + +## What's already in the codebase that could help + +`grep -rin oracle …` returns NO hits anywhere in `/workspace/ADA`. + +The closest thing is the `[env.conditioning]` section in +`adaptive.ini` — but this is **per-ego self-conditioning** (each ego +agent gets its OWN conditioning vector that determines its OWN reward +function: `reward_collision_weight`, `entropy_weight` for its own +exploration, etc.). It is plumbed end-to-end: + +| Layer | What happens | +|--------------------|--------------------------------------------------| +| Python `Drive.__init__` | kwarg `conditioning={}` → `self.entropy_conditioned`, etc. (drive.py:60-167) | +| Adds `conditioning_dims` to `self.ego_features` so `obs_dim` accounts for the slots (drive.py:177) | +| C `env->collision_weights[i]` etc. allocated per active ego agent (drive.h:1841-1844) | +| C samples weights at episode reset (drive.h:2486-2494) | +| C appends weights to each ego's obs in `compute_observations` (drive.h:2294-2302) | + +This is NOT what we want, but the slots exist and are reachable. We +can **hijack them**: keep the obs format the same, but at sample time, +overwrite the per-ego conditioning weights with this env's partner +conditioning vector. Each ego in env E sees the same vector its +partner is using. + +--- + +## Plan: oracle test (give ego the partner's conditioning) + +### Goal + +A clean diagnostic: if the ego is just *handed* the partner's +conditioning vector as obs, does the policy then learn to adapt +(drive differently when entropy_weight is high vs low)? + +- **If yes**: the architecture downstream is fine. The bottleneck is + inference from behavior. Then we can pursue real fixes (per-partner + history features, world-frame obs, partner-attention layer, etc.). +- **If no**: the bottleneck is downstream — even with explicit + partner-type signal, the policy can't condition action on it. Then + we need to look at the policy head, the action distribution, or + the loss formulation. + +### C-side dive: where do the existing conditioning slots come from? + +The per-ego conditioning weights `env->{collision,offroad,goal,entropy,discount}_weights[i]` are used in **two places** in `drive.h`: + +1. **obs append** (lines 2294-2302) — appended to ego obs after `base_ego_dim`. +2. **reward computation** (lines 2625, 2637, 2669, 2679, 2691) — the per-step + reward at collision / offroad / goal events SCALES with these weights. + +This means we **cannot just enable `[env.conditioning].type=all` and +overwrite the C arrays** — that would also change the ego's reward +function (e.g., sampling collision_weight ∈ [-1, 0] makes some +collisions free). + +### Pufferl-side dive: more silent uses of these slots + +Even worse, **pufferl reads two of these slots from obs to drive +training dynamics**: + +| Slot | Read by | Becomes | +|------|---------|---------| +| collision_w, offroad_w, goal_w | (only C-side reward) | per-agent reward scaling | +| entropy_w (slot at offset 12 with reward+ego) | `pufferl.py:885` | per-agent entropy weight in PPO loss | +| discount_w (slot at offset 13 with reward+ego) | `pufferl.py:743` | per-agent γ in GAE advantage computation | + +The `pufferl` hooks are gated on `entropy_conditioned` / +`discount_conditioned` flags on the env. If we naively enable +`[env.conditioning].type=all` to get the obs slots and then override +them in Python, pufferl will use the **partner's** entropy/discount +values as the **ego's** PPO hyperparameters — a major silent training +distortion. + +### The complete "no behavior change" recipe + +To get the obs slots while preserving the no-conditioning training +dynamics: + +1. **Force `[env.conditioning].type = "all"`** in the env init (from + the `ego_is_oracle` flag), so 5 obs slots are allocated and the C + `env->*_weights[i]` arrays exist. +2. **Pin all 5 sampling ranges to lb=ub=default** so the C-side + per-agent samples are constant and equal to the default reward + weights. Reward computation is identical to a non-conditioned run. +3. **After C-side setup, in Python, flip + `self.entropy_conditioned = False` and + `self.discount_conditioned = False`** (keep `reward_conditioned = + True` — it has no pufferl-side hook). Pufferl now sees False on + these flags and skips the per-agent γ / entropy hooks → uses + global defaults. +4. **In step() / reset(), Python overrides the 5 obs slots in each + ego row with `env_conditioning[env_of_ego]`** — the partner's + conditioning vector. This is pure obs signal: pufferl no longer + reads these slots, the policy just sees them as input. +5. **Bypass the `line 350` `NotImplementedError`** that forbids dual + ego+co-player conditioning. The check predates oracle and is not + protecting any actual coupling for our setup. + +This makes the diff one-flag, no-C, contained: +- One env kwarg + ini knob (`ego_is_oracle`) +- Init-time validation + flag fix-up + bypass +- One step()-time obs override loop +- One sentence in the launcher to set the deterministic-default lb/ub on the 5 dims + +### Implementation (Python obs override, no C changes) + +1. **New env kwarg** `ego_is_oracle = False` (default off). +2. **Adaptive.ini knob** `ego_is_oracle = False` so CLI auto-generates + the `--env.ego-is-oracle` flag. +3. **In `step()`** (drive.py, after `binding.vec_step(self.c_envs)`, + before return), if `ego_is_oracle` is on AND `env_conditioning` + has slots: + - For each env e, for each ego in env e, overwrite the + `[base_ego_dim : base_ego_dim + n_partner_dims]` slice of + `self.observations[ego_id]` with `self.env_conditioning[e]`. + - Same for the post-reset obs (need to call after `_set_env_variables` too). +4. **In `_add_co_player_conditioning`-style helper**: factor the + slot-locating logic so we can reuse it for ego. +5. **Launcher**: must set `--env.conditioning.type all` AND + `--env.conditioning.{collision,offroad,goal,entropy,discount}-weight-{lb,ub}` + to default values (so the C-side reward isn't changed): + ``` + --env.conditioning.collision-weight-lb -0.5 --env.conditioning.collision-weight-ub -0.5 + --env.conditioning.offroad-weight-lb -0.5 --env.conditioning.offroad-weight-ub -0.5 + --env.conditioning.goal-weight-lb 1.0 --env.conditioning.goal-weight-ub 1.0 + --env.conditioning.entropy-weight-lb 0.001 --env.conditioning.entropy-weight-ub 0.001 + --env.conditioning.discount-weight-lb 0.98 --env.conditioning.discount-weight-ub 0.98 + ``` + These are the same defaults as the no-conditioning case, so reward + semantics are identical to a non-conditioned run. The slots in obs + are then allocated but their values are overwritten by Python with + the partner's per-env conditioning vector. + +### Caveats / things to verify + +- **Slot offset & width**: ego's conditioning width = 3 (reward) + 1 + (entropy) + 1 (discount) = 5 when type="all". Partner's width must + match — also "all" → 5. We assert this at init. +- **The partner conditioning sample is per-env, valid for the whole + episode.** Need to inject the same vector EVERY step (overwrite obs + slots in step(), not just at reset). +- **The model checkpoint obs_dim with type=all is 1855** (ego_dim=14 + = base 9 + cond 5, partner_obs and road_obs unchanged). Fresh + training run from scratch. Cannot load the curr/nocurr e=0.5 + checkpoints into this — their obs_dim is also 1855 but their slots + contained their own conditioning, not the partner's; the policy + weights have learned an interpretation that doesn't transfer. + Train fresh. + +### What success looks like + +If the oracle test works, we expect to see — vs the no-oracle baseline: +- ada_delta improves significantly (e.g., +0.05+ instead of ~0) +- Per-condition behavior visible: rollouts at extreme partner + conditioning (e.g., entropy=0.5 vs entropy=0.0) should produce + visibly different ego policies (different distance kept from + partner, different speeds, etc.) +- Counterfactual: zeroing the partner conditioning slots (not the + full cache, just the oracle slots) should hurt s_0 score because + the ego is now driving "blind" to partner type + +### Run plan once implemented + +User decision: **only do the curr experiment** — single oracle run, no +nocurr-baseline pairing. The comparison is against the existing +`curr_e0.5` (`hprfn8dc`) which has identical config except no oracle. + +1. **Wire the flag** (drive.py + adaptive.ini). +2. **Smoke test** — first 30 sec of a launch, confirm: + - obs slot layout is correct + - the partner conditioning is actually being injected (print first + few obs slot values, verify they match `env_conditioning[e]`) + - training doesn't crash +3. **Full oracle run** — same config as `curr_e0.5`: + - Partner: `6rauydj2` at `e_ub=0.5` + - k=2/201, horizon=402, nw=24 nv=24 (or 32/32 if k_eff has freed + RAM by then) + - Entropy curriculum: ON (mirror `curr_e0.5`) + - Ego conditioning: type=all with default-value lb=ub + - Ego oracle: ON +4. **Analyses identical to the curr/nocurr runs**: + - Eval on 300 held-out scenes (same `nuplan_201_heldout300` + symlinked dir we already built) + - Counterfactual cache (zero the cache to see if it matters now + that the policy has the oracle slot) + - Attention probe (does attention still go to past, or does the + policy ignore the past now that oracle gives the answer?) + +If the oracle policy adapts strongly (`ada_delta` jumps from ~0 to +0.05+, attention to past drops, conditioning slot is load-bearing), +the bottleneck is **inference from behavior** and we move on to fixes +like per-partner history features. If the oracle policy STILL doesn't +adapt, the bottleneck is **downstream** (action conditioning) and we +need to look at the policy head, action distribution, or loss. diff --git a/notes/paper_plan.md b/notes/paper_plan.md new file mode 100644 index 0000000000..101806eca7 --- /dev/null +++ b/notes/paper_plan.md @@ -0,0 +1,198 @@ +# Adaptive Population Play — paper experiment plan + +40-hour execution window. 8 GPUs. Each ego training ≈12h, each +co-player training ≈4.5h, each hard-map eval ≈4 min (M=10 × 540 maps). +Budget = 320 GPU-h; T1 ≈ 180, T1+T2 ≈ 300. + +## What we already have + +- **Adaptive ego baseline** `4lm6kkh7` (γ=0.995, partner `2e029h15`, + e_ub=0.10) → `ada_delta = +0.222 ± 0.18` on `nuplan_hard` (200 rollouts). +- **5 partner-sweep adaptive egos** (one per entropy partner, γ=0.995, + lane=0.025) — final ckpts at epoch 152. Hard-map eval (M=10): + + | partner | e_ub | ada_delta ± std | + |---------|-----:|----------------:| + | miku2puk | 0.05 | **+0.155 ± 0.008** | + | 2e029h15 | 0.10 | **+0.154 ± 0.085** | + | m2ygolog | 0.20 | -0.010 ± 0.016 | + | 6rauydj2 | 0.50 | -0.001 ± 0.008 | + | n48teqjs | 1.00 | -0.000 ± 0.003 | + +- **15-partner DE-grid in progress** on GPUs 0-4, ETA 22:10 UTC. + Grid: `discount_lb ∈ {0.4, 0.6, 0.8} × entropy_ub ∈ {0.001, 0.01, 0.05, + 0.1, 0.2}`. Wandb: `ada_coplayer_sweep`. +- **`nuplan_hard` map split** (top 10% by SDC interaction), recipe + recorded in `notes/nuplan_hard.md`. +- **Fast eval recipe**: M=10, vanilla code, 5390 (map, ego) datapoints/eval, + SE on mean ≈0.003. + +## Key finding driving the plan + +**Adaptation only emerges against deterministic partners (e ≤ 0.10).** +For e ≥ 0.20 the gap collapses. This means: + +- Robustness experiments need to vary the partner distribution along the + *deterministic ↔ noisy* axis, not blindly use [0, 1]. +- The "interesting" ada_delta lives in a narrow slice of partner space. +- We should structure the partner-distribution test to span that slice + cleanly. + +--- + +## Tier 1 — must-have for paper + +### T1.1 Main result table: 4 ablation conditions vs adaptive-conditioned baseline + +Using a single fixed partner (`2e029h15`, e=0.10 — matches the +`4lm6kkh7` baseline that gave +0.222), train 4 ego variants: + +| label | description | training change | +|-------|-------------|----------------| +| **A** | Adaptive + conditioned partner | (baseline, already trained) | +| **B** | Non-adaptive + conditioned partner | reset ego K/V at every scenario boundary (`RECOVERY_CACHE_RESET_PER_SCENARIO=1` env var, or k_scenarios=1) | +| **C** | Adaptive + fixed-policy partner | partner = single ckpt, ego never sees conditioning slot | +| **D** | Self-play | both ego and "partner" use the same current policy | +| **E** | Adaptive + log-replay only | no learned partner, others follow recorded humans | + +Cost: 4 trainings × 12h = **48 GPU-h**. Eval: 5 × 4 min on free GPUs. + +### T1.2 Scaling along k_scenarios + +Train adaptive egos at `k_scenarios ∈ {1, 2, 4}`. k=8 is 48h on its own +and probably won't fit; defer. + +Cost: 12 + 12 + 24 = **48 GPU-h**. + +### T1.3 Robustness: in-distribution vs OOD partner distribution + +`[TBD — confirm exact ranges]`. Updated based on what we learned about partner space: + +- **In-dist baseline**: train ego sampling partner conditioning over + `e ∈ [0, 0.20]` (the slice where adaptation happens), test on same range. +- **OOD test**: train on `e ∈ [0, 0.05]` (very deterministic only), test + on `e ∈ [0.10, 0.20]`. Asks: does an ego trained on near-deterministic + partners still show ada_delta when faced with a slightly noisier one? +- **Aggressive vs cautious test**: condition partners with + collision_weight_lb=-2 (cautious) vs 0 (aggressive); reuse existing egos. + +Cost: 1 new training (the OOD ego) + 2 evals = **12 GPU-h** + **20 min**. + +--- + +## Tier 2 — should-have + +### T2.1 Robustness: held-out maps + +Eval-only. Build a `nuplan_hard_holdout` split (different 10% of the +nuplan_201 maps that we did NOT score / use for `nuplan_hard`). Eval the +5 partner-sweep egos against this set. Cost: ~30 min. + +### T2.2 Adversarial co-player + +Train a partner with `collision_weight = +2` (rewards collisions); eval +all egos against it. + +Cost: 4.5h partner + 5 evals × 4 min = **5 GPU-h**. + +### T2.3 Scaling: number of training maps + +Train ego on `nuplan_201` subsets of size 200, 1000, 5000. + +Cost: 3 × 12h = **36 GPU-h**. Defer if tight. + +### T2.4 Conditioning leave-one-out (5 dims) + +Train 5 egos, each with one conditioning dim disabled +(collision / offroad / goal / entropy / discount). + +Cost: 5 × 12h = **60 GPU-h**. Pick 2-3 dims if budget tight. + +--- + +## Tier 3 — defer to post-paper unless time + +- LSTM vs Transformer (one extra training) +- RNN state size sweep +- Co-player ratio (25 / 50 / 75 %) +- Diversity of conditioning ranges (wide vs narrow) + +--- + +## Proposed 40h schedule + +Times relative to the 09:55 UTC start of this plan. All wall-clock +estimates assume 12h per ego training and 4.5h per co-player. + +### Phase 1 (h0 → h13, overlaps with DE-grid finishing on GPUs 0-4) + +Free GPUs: 5, 6, 7. Launch 3 of the 4 T1.1 ablations: + +| GPU | run | duration | +|----:|-----|---------:| +| 5 | T1.1-B non-adaptive ego vs `2e029h15` | 12h | +| 6 | T1.1-C adaptive ego vs fixed partner | 12h | +| 7 | T1.1-D self-play | 12h | + +### Phase 2 (h13 → h25, all 8 GPUs free) + +DE-grid done at h~12. Launch: + +| GPU | run | duration | +|----:|-----|---------:| +| 0 | T1.1-E log-replay ego | 12h | +| 1 | T1.2 scaling k=1 | ~6h | +| 2 | T1.2 scaling k=4 | ~24h (runs into Phase 3) | +| 3 | T1.3 OOD ego (train on e ∈ [0, 0.05]) | 12h | +| 4 | T2.2 adversarial partner (4.5h) → ego | ~16h | +| 5 | T2.4 leave-one-out: discount | 12h | +| 6 | T2.4 leave-one-out: entropy | 12h | +| 7 | T2.4 leave-one-out: collision | 12h | + +Run T1.2 k=2 already exists in 4lm6kkh7; reuse for the curve. + +### Phase 3 (h25 → h40, finishing + evals) + +- k=4 finishes around h37 +- All other Phase 2 trainings done by h25 +- Free GPUs run evals (M=10 × 540 maps × ~4 min each) +- Build all paper plots: + - **Fig 1**: main result bar plot (5 conditions, ada_delta ± std) + - **Fig 2**: ada_delta vs partner entropy (uses 5 partner egos + + DE-grid eval) + - **Fig 3**: ada_delta vs k_scenarios + - **Fig 4**: in-dist vs OOD partner distribution + - **Fig 5**: leave-one-out conditioning ablation + +--- + +## Open questions to resolve before launching + +1. **Self-play definition (T1.1-D)**: ego = partner = same live policy + on each step? Or ego = current policy, partner = past-snapshot? The + former is "true" self-play but the partner is non-stationary; the + latter is more like league play. Pick one. +2. **Non-adaptive definition (T1.1-B)**: do we set `k_scenarios = 1` (so + training never gives the ego cross-scenario context) or keep + `k_scenarios = 2` and reset the cache at the boundary? The latter + isolates "is it the cache that helps?" cleanly; the former is more + honest as a "non-adaptive" baseline. Pick one or do both. +3. **OOD ranges (T1.3)**: my proposal `[0, 0.05]` train → `[0.10, 0.20]` + test is one option. Confirm this is the slice you want — alternatives + are deterministic→noisy (`[0, 0.1]` → `[0.5, 1.0]`) which we already + know breaks, or wide→narrow (`[0, 1]` → `[0.05, 0.2]`). +4. **Drop k=8?** Yes unless you have a strong opinion. 48h alone. + +Answer 1-4 and I'll start writing the launchers. + +--- + +## Quick links + +- DE-grid sweep status: `tmux attach -t coplayer_de_grid`, + `/tmp/coplayer_de_grid_driver.log` +- Hard-map eval recipe: `notes/nuplan_hard.md` +- Latest ego eval results: `/tmp/eval_partner_sweep_m10.log` +- 5 partner-sweep ckpts: + `experiments/puffer_adaptive_drive_{0bmlyvg3,l7c13x1m,0yih6s9k,sa80qcs2,p6q2b2xp}/model_..._000152.pt` +- Adaptive launcher template: `scripts/adaptive/nuplan_transformer_local_k2_201_partner_sweep.sh` diff --git a/notes/trial_episode_design.md b/notes/trial_episode_design.md new file mode 100644 index 0000000000..8c89d25d8c --- /dev/null +++ b/notes/trial_episode_design.md @@ -0,0 +1,202 @@ +# Trial-as-episode redesign for adaptive driving + +## Goal + +Replace the current scenario-based episode structure with **variable-length trials packed into a fixed-budget episode**, where each trial is an independent "fresh map, fresh agent" goal-reach attempt and the K/V cache persists across trials within an episode. + +## Decided semantics (from user) + +1. **"New trial"**: fresh agent placement (back to start), fresh goal, fresh recorded-human trajectories restarting from t=0. **Same map** within an episode (no map-swap-per-trial). The map_rand_per_scenario / _reinit_envs_with_new_maps machinery is being removed entirely from the codebase as part of this work. +2. **Episode budget**: fixed at `k * scenario_length` ticks. If trials finish early, fill remaining budget with *another fresh episode* (start trial 0 of a new episode in the same buffer slot, separated by terminal flag). +3. **K/V cache persistence**: + - **Adaptive (treatment)**: cache persists across trials within an episode. Only resets at episode boundary. + - **Control**: cache resets at every trial boundary. Configurable knob. +4. **Full design, not MVP.** OK with 5-7 days of work. + +## Code to be DELETED in this refactor + +- `map_rand_per_scenario` flag and all its handling sites (drive.py, drive.h, eval scripts) +- `_reinit_envs_with_new_maps` Python function (drive.py:888-975) — entire function gone +- The `_render_keep_client_on_swap` / `vec_donate_client` / `vec_adopt_client` plumbing that exists *only* to keep raylib alive across the now-removed reinit +- The "broken as an ICL probe" warning block in adaptive.ini + +The trial-boundary swap is now a fresh, smaller C function (`start_new_trial`) that does ONLY agent-state reset + goal sample, no env tear-down, no terminals write. New goal source: either `sample_new_goal()` (a road-lane point ahead of agent) or another agent's `init_goal` from the same map. + +## Key implementation insight (from Plan agent) + +`done_mask = terminals + truncations` controls cache reset (pufferl.py:610). +GAE only uses `terminals` for bootstrap truncation (not truncations). + +So we have three distinct events: + +| event | terminals | truncations | cache | GAE bootstrap | +|-------|----------:|------------:|:------|:--------------| +| within-trial step | 0 | 0 | continues | continues | +| **trial boundary (adaptive)** | 0 | 0 | continues | continues across | +| **trial boundary (control)** | 1 | 0 | resets | truncates | +| episode boundary | 1 | 0 | resets | truncates | + +The episode boundary marker lets pufferl pack two episodes into one segment row — `create_episode_mask` will block cross-episode attention via the cumsum-on-terminals episode-id mechanism. Variable-length episodes inside a fixed-size segment work out of the box this way. + +## Architecture + +### C side (drive.h) + +**New goal_behavior mode `GOAL_TRIAL = 3`**: +- On goal-reach: full goal_weight reward, mark `current_goal_reached=1`, set per-agent `trial_complete=1`. **Do not** stop, do not respawn (yet — Python decides). +- A new per-env counter `current_trial` (uses existing `current_scenario`-style logic). + +**New per-env flag accessible from Python**: `trial_ended_this_step[env_idx]`. Set to 1 in the C step when: +- All ego agents in the env have reached goal, OR +- The trial-length budget (currently `scenario_length`) has elapsed without reach. + +This flag is consumed by Python at the same tick — Python decides whether to swap maps, whether to set terminals, etc. + +**No `_reinit_envs_with_new_maps` involvement.** Trial boundary uses a new lighter C function `start_new_trial(env, agent_idx)`: +- Reset agent position via `set_start_position` for that agent only +- Sample new goal via `sample_new_goal()` (or pick from another agent's `init_goal`) +- Reset per-agent metrics (`current_goal_reached=0, collided_before_goal=0, stopped=0, removed=0`) +- Reset agent's per-trial Log fields +- DO NOT reset `env->timestep` — the experts/partners continue along their recorded trajectory +- DO NOT touch other agents + +Episode boundary (`tick == episode_budget`): pufferl handles this via the standard `done_mask` path. We set `terminals[ego_ids]=1` from Python, the env is fully reset by `binding.vec_reset` on the next `puffer_env.reset()` call. No special C function needed; the existing `c_reset` path does it. + +**Per-trial Log fields** (in `Log` struct): +- `trials_completed_this_episode` (int) +- `trials_attempted_this_episode` (int) +- `mean_time_to_goal_per_trial` (float, running mean) +- `per_trial_succeeded[MAX_TRIALS]` (fixed-size array, MAX_TRIALS=8) +- `per_trial_collided_before_goal[MAX_TRIALS]` (fixed-size array) + +vec_log aggregation accumulates these across envs. + +### Python side (drive.py + adaptive.py) + +**`AdaptiveDrivingAgent.__init__`**: add knobs +- `goal_behavior_trial = True` (whether to use new mode) +- `max_trials_per_episode` (default = `k_scenarios`) +- `trial_length` (default = `scenario_length`) +- `trial_cache_reset` (False = adaptive treatment, True = control) +- `episode_budget = max_trials_per_episode * trial_length` (just an alias) + +**`Drive.step` modifications**: +- Per-step: read `trial_ended_this_step` array from C. +- For each env where trial ended: + - Call C function `start_new_trial(env, ego_idx)` for each ego in that env + - Increment `current_trial` + - If `trial_cache_reset=True`: set `truncations[ego_ids_in_env]=1` (resets cache via done_mask, does NOT truncate GAE) + - Aggregate trial metrics into per-trial dict +- After all per-env trial-ends processed: + - Increment `episode_ticks` counter + - If `episode_ticks >= episode_budget` OR `current_trial >= max_trials_per_episode`: + - Set `terminals[ego_ids]=1` for ALL envs (episode boundary; pufferl's c_reset path handles the actual env reset on the next call) + - Reset per-episode counters + - Compute episode-level delta metrics (analogous to current `_compute_delta_metrics`) + - Aggregate per-trial metrics list + +**`_compute_delta_metrics` generalization**: now per-trial. Returns `trial_N_score` for each completed trial, plus `ada_delta_score = trial_(last)_score - trial_0_score`. + +### Eval side (evaluator.py + utils.py) + +**`HumanReplayEvaluator.rollout`**: rewrite outer loop to be trial-aware: +```python +for episode in range(num_episodes): + obs, _ = env.reset() + state = _fresh_state() + trial_metrics_per_episode = [] + while True: + for tick in range(self.sim_steps): + obs, rew, dones, truncs, info = env.step(action) + ...track per-trial success_arr... + if dones.any(): + break # episode ended + else: + continue # trial ended but episode budget remaining + break + aggregate per-trial metrics +``` + +Per-`(rollout, agent, trial_idx)` success matrix instead of per-`(rollout, agent, scenario)`. + +`RECOVERY_CACHE_RESET_PER_SCENARIO` → `RECOVERY_CACHE_RESET_PER_TRIAL` env var (control switch; ON = control, OFF = adaptive). + +### Pufferl integration + +Should require **zero changes**. The recv loop, GAE kernel, and `create_episode_mask` already handle: +- Multiple episodes per segment row (terminals=1 marks the joins) +- Cache reset on `done_mask=t+d` +- GAE truncation only on terminals + +Verify with a synthetic test before assuming. + +### Config knobs (adaptive.ini) + +```ini +[env] +goal_behavior = 3 # GOAL_TRIAL mode +max_trials_per_episode = 4 # how many trials per episode max +trial_length = 100 # per-trial budget in ticks +episode_budget = 400 # max total episode ticks (= max_trials * trial_length) +trial_cache_reset = False # adaptive treatment (cache persists). True = control. +``` + +## Implementation milestones + +### Milestone 1 — C-side trial mode + Python signal (Day 1-2) + +- Add `GOAL_TRIAL=3` constant + new branch in `c_step` goal-reach logic. +- Add `trial_ended_this_step[env]` field to env struct, exposed via a new binding (`vec_get_trial_ended` or similar) or via observation channel. +- Add per-trial Log fields. +- Unit test: instantiate one Drive env with goal_behavior=3, step until agent reaches goal, verify trial_ended flag fires, verify trial_complete count increments. + +### Milestone 2 — Python wiring + map swap (Day 3) + +- Refactor `_reinit_envs_with_new_maps` to take `set_terminals=True` param (default True for backwards compat). +- In `Drive.step`, read trial_ended array, call reinit-without-terminals at trial boundary. +- Increment current_trial, handle episode-budget exhaustion. +- Set terminals at episode boundary; integrate with map swap. +- Track per-trial metrics. + +### Milestone 3 — `_compute_delta_metrics` rewrite (Day 4) + +- Generalize from per-scenario to per-trial. +- Emit `trial_N_score`, `trial_N_collision_rate`, etc. +- Compute `ada_delta_score` as last_trial_score - first_trial_score. +- Add `mean_trial_score`, `n_trials_completed_per_episode`. + +### Milestone 4 — Eval mirror (Day 5) + +- Rewrite `HumanReplayEvaluator.rollout` outer loop. +- Per-(rollout, agent, trial) success array. +- Mirror trial_cache_reset for control runs. +- Test on a saved checkpoint. + +### Milestone 5 — Integration test + verification (Day 6) + +- End-to-end: train one epoch with new mode, check loss not NaN, scores look sensible. +- Compare GAE flow with `terminals=1` only at episode boundary vs at every trial boundary; verify expected behavior. +- Log `transformer_position` over an episode to verify cache lifecycle. + +### Milestone 6 — A/B run + ship (Day 7-10) + +- Train 3 seeds k=2 (now 4 trials × 100 ticks per trial = 400-tick episode budget) gb=3 — adaptive treatment. +- Train 3 seeds k=2 — control (trial_cache_reset=True). +- Compare ada_delta_score curves. + +## Risks / open questions + +1. **Cost of map-swap-per-trial**: `_reinit_envs_with_new_maps` is the slow call (vec_close + 540× env_init). Per-trial cost ~5-10s. With 4 trials/episode and many parallel envs, this could 4× training wall-clock. + - Mitigation: could pre-load N maps and just switch index, but env doesn't support that today. Adds engineering. + +2. **K/V cache size**: with episode_budget=400 and horizon=400, cache exactly fits. With variable trial counts, the cache might overrun if trials run long. We need to cap episode at horizon to avoid wraparound. + +3. **Per-trial Log fixed-size arrays**: MAX_TRIALS=8 is arbitrary. If a user sets max_trials_per_episode=16, breaks. Needs runtime validation. + +4. **Per-env trial counters when num_envs=540**: 540 separate counters. Cheap in C, just need to expose them right. + +5. **`map_rand_per_scenario` flag**: existing knob is "broken as ICL probe" because of the terminals=1 write. Our new mode bypasses this. Should we also fix the existing flag, or remove it / mark deprecated? + +## Smallest-possible-test before committing to full impl + +(Removed — Day 0 used `map_rand_per_scenario`, which we're deleting. Going straight to the full implementation.) diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini index 810de828c0..248eb8d3b3 100644 --- a/pufferlib/config/default.ini +++ b/pufferlib/config/default.ini @@ -49,7 +49,7 @@ minibatch_size = 8192 # Accumulate gradients above this size max_minibatch_size = 32768 -bptt_horizon = 64 +horizon = 64 compile = False compile_mode = max-autotune-no-cudagraphs compile_fullgraph = True @@ -81,7 +81,7 @@ downsample = 10 ; mean = 1e8 ; scale = time -; [sweep.train.bptt_horizon] +; [sweep.train.horizon] ; distribution = uniform_pow2 ; min = 16 ; max = 64 diff --git a/pufferlib/config/ocean/adaptive.ini b/pufferlib/config/ocean/adaptive.ini index 3e4eaf60c4..fc347130b8 100644 --- a/pufferlib/config/ocean/adaptive.ini +++ b/pufferlib/config/ocean/adaptive.ini @@ -2,22 +2,35 @@ package = ocean env_name = puffer_adaptive_drive policy_name = Drive -rnn_name = Recurrent +policy_architecture = Transformer +; Adaptive runs always use the transformer wrapper. +rnn_name = Transformer [vec] -num_workers = 16 -num_envs = 16 +num_workers = 32 +num_envs = 32 batch_size = 1 ; backend = Serial [policy] -input_size = 64 +input_size = 128 hidden_size = 256 [rnn] input_size = 256 hidden_size = 256 +[transformer] +input_size = 256 +hidden_size = 256 +num_layers = 2 +; Number of transformer layers +num_heads = 4 +; Number of attention heads (must divide hidden_size) +; Transformer uses `horizon` from [train] section for attention span +dropout = 0.0 +; Dropout (keep at 0 for RL stability initially) + [env] num_agents = 1024 num_ego_agents = 512 @@ -25,35 +38,99 @@ num_ego_agents = 512 action_type = discrete ; Options: classic, jerk dynamics_model = classic -; Number of consecutive scenarios per episode (adaptive-specific) -k_scenarios = 2 reward_vehicle_collision = -0.5 -reward_offroad_collision = -0.2 -reward_ade = 0.0 +reward_offroad_collision = -0.5 dt = 0.1 reward_goal = 1.0 reward_goal_post_respawn = 0.25 +; in case of reward conditioning, we scale the goal_weight by this number for post respawn +; Lane alignment reward (GIGAFLOW) - set to 0 to disable +reward_lane_align = 0.0 +; Velocity alignment coefficient for lane reward +reward_vel_align = 1.0 ; Meters around goal to be considered "reached" goal_radius = 2.0 -; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +; Max target speed in m/s for the agent to maintain towards the goal +goal_speed = 100.0 +; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" goal_behavior = 0 +; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. +; Large numbers will select a goal point further away from the agent's current position. +goal_target_distance = 30.0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove collision_behavior = 0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove offroad_behavior = 0 -; Number of steps before reset +; Number of steps in each scenario (constrained by base data) scenario_length = 91 -; Resample frequency = k_scenarios * scenario_length (adaptive-specific) -resample_frequency = 182 -num_maps = 1000 -; Which step of the trajectory to initialize the agents at upon reset +k_scenarios = 2 +termination_mode = 1 +; 0 - terminate at episode_length, 1 - terminate after all agents have been reset +map_dir = "resources/drive/binaries/training" +num_maps = 10000 +; Determines which step of the trajectory to initialize the agents at upon reset init_steps = 0 -; Options: "control_vehicles", "control_agents", "control_tracks_to_predict" +; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only" control_mode = "control_vehicles" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" ; train with co players -co_player_enabled = 1 +co_player_enabled = True +; When True, co-player inference happens on the main GPU process (batched +; across workers) instead of single-thread CPU per worker. ~5-10x speedup +; on env stepping; required for k>=3 to be tractable on adaptive runs. +external_co_player_actions = False +; Re-init C envs with fresh map_ids at every within-episode scenario +; boundary. BROKEN as an ICL probe — the reinit sets terminals[:]=1 +; which wipes the ego cache and truncates GAE at the boundary. +map_rand_per_scenario = False +; When True (only meaningful for k_scenarios > 1, with co_player_enabled): at +; every scenario boundary, the partner conditioning vector is re-sampled from +; configured ranges. The partner POLICY weights are unchanged, but its +; effective behavior shifts per scenario. The ego policy must infer partner +; type from s_0 observations and use it in s_1. Independent of +; map_rand_per_scenario; set ONE of them, not both. +condition_rand_per_scenario = False +; When True, partner's entropy_weight_ub is annealed up over training in 4 +; stages: 0.05 → 0.20 → 0.50 → 1.0 of the user-passed entropy_ub. Each +; stage advances every 30 episodes per worker (≈30 epochs at nw=32 nv=32). +; Other conditioning dims (collision/offroad/discount) sample at their +; passed ranges throughout. Use with a wide-conditioning partner whose +; trained range covers up to the final entropy_ub. +entropy_curriculum_enabled = False +; When resuming a curriculum run from checkpoint, set this to the original +; run's last episode count so the curriculum picks up at the right stage +; instead of restarting from stage 0. Each per-worker counter starts at +; this value. No-op when entropy_curriculum_enabled = False. +entropy_curriculum_episodes_start = 0 +; When True, ego K/V cache is reset at SOME within-episode scenario +; boundaries (controlled by curriculum stage) so the policy first learns +; to drive single scenarios, then short cross-scenario context, before +; full K_max-scenario context. Implementation: at boundaries to cut, we +; set truncations[ego_ids]=1 + terminals[ego_ids]=1, which makes +; pufferl drop the cache via done_mask=t+d during eval and blocks +; cross-boundary attention via create_episode_mask during training. +; Stage schedule: stage 0 = k_eff=1, stage 1 = k_eff=2, stage 2 = K_max. +; K_max=4 with k_eff∈{1,2,4} produces uniform splits; for other K_max, +; only k_eff=1 and k_max produce uniform stages. +k_eff_curriculum_enabled = False +; Episodes per stage in the k_eff curriculum. +k_eff_curriculum_episodes_per_stage = 30 +; When True, the partner's per-env conditioning vector is appended to +; the END of every ego's observation as `oracle` slots. The C-side obs +; format is unchanged (writes into a private buffer); a wider Python +; buffer holds the C output plus the appended oracle dims. Pufferl's +; partner-obs path strips the trailing oracle dims via the env's +; `_c_obs_dim` attribute, so the partner policy still sees its original +; obs width. Reward, GAE γ, and PPO entropy coefficient unchanged. +; Requires co_player_policy.conditioning.type != "none". +ego_is_oracle = False + +; If True, zero rewards in scenarios 0..k-2; only the last scenario +; produces reward signal. Forces credit assignment to flow back from s_{k-1} +; rewards to actions in earlier scenarios — used to test whether reward +; shape is the bottleneck for cross-scenario adaptation. +reward_only_last_scenario = False [env.conditioning] @@ -71,17 +148,25 @@ discount_weight_lb = 0.80 discount_weight_ub = 0.98 [env.co_player_policy] -enabled = True policy_name = Drive -rnn_name = Recurrent -policy_path = "experiments/puffer_drive_ewdjljwd.pt" -input_size = 64 +; Options: "Recurrent", "Transformer" +architecture = Recurrent +policy_path = "pufferlib/resources/drive/policies/varied_discount.pt" +input_size = 128 hidden_size = 256 [env.co_player_policy.rnn] input_size = 256 hidden_size = 256 +[env.co_player_policy.transformer] +input_size = 256 +hidden_size = 256 +num_layers = 2 +num_heads = 4 +horizon = 91 +dropout = 0.0 + [env.co_player_policy.conditioning] ; Options: "none", "reward", "entropy", "discount", "all" type = "all" @@ -96,68 +181,77 @@ entropy_weight_ub = 0.001 discount_weight_lb = 0.80 discount_weight_ub = 0.98 - [train] +seed=42 +compile = True +compile_mode = default +precision = bfloat16 total_timesteps = 2_000_000_000 -# learning_rate = 0.02 -# gamma = 0.985 anneal_lr = True -; Needs to be: num_agents * num_workers * BPTT horizon +; Needs to be: num_agents * num_workers * horizon batch_size = auto -; minibatch_size = 745472 -; minibatch_multiplier = 512 -; max_minibatch_size = 745472 -minibatch_size = 372736 -minibatch_multiplier = 256 -max_minibatch_size = 372736 -; BPTT horizon (overridden by pufferl.py for adaptive agents to k_scenarios * scenario_length) -bptt_horizon = 32 +minibatch_size = 36400 +; 200 * 182 (must be divisible by horizon = k_scenarios * scenario_length) +max_minibatch_size = 36400 +minibatch_multiplier = 400 +; Sequence length - overridden to k_scenarios * scenario_length for adaptive +horizon = 91 adam_beta1 = 0.9 adam_beta2 = 0.999 adam_eps = 1e-8 clip_coef = 0.2 -ent_coef = 0.001 +ent_coef = 0.005 gae_lambda = 0.95 -gamma = 0.98 -learning_rate = 0.001 -max_grad_norm = 1 -prio_alpha = 0.8499999999999999 -prio_beta0 = 0.8499999999999999 +; gamma=0.995 → ~135-step credit half-life (was 0.98 → ~34-step). Bumped +; for k=2/201 (horizon=402) so s_1 rewards can actually flow back to s_0 +; actions; γ=0.98 over 402 steps gave ~0.0003 effective discount → no +; cross-scenario credit assignment, killing the in-context-learning +; signal. See notes/oracle_partner_conditioning_investigation.md. +gamma = 0.995 +learning_rate = 0.003 +; Reduced from 0.003 (transformers often need lower LR) +max_grad_norm = 1.0 +prio_alpha = 0.85 +prio_beta0 = 0.85 update_epochs = 1 -vf_clip_coef = 0.1999999999999999 -vf_coef = 2 +vf_clip_coef = 0.2 +vf_coef = 2.0 vtrace_c_clip = 1 vtrace_rho_clip = 1 -checkpoint_interval = 1000 +checkpoint_interval = 40 # Rendering options render = True -render_interval = 1000 +render_interval = 100 ; If True, show exactly what the agent sees in agent observation obs_only = True ; Show grid lines -show_grid = False +show_grid = True ; Draws lines from ego agent observed ORUs and road elements to show detection range show_lasers = False ; Display human xy logs in the background -show_human_logs = True -; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None +show_human_logs = False +; If True, zoom in on a part of the map. Otherwise, show full map +zoom_in = True +; Options: List[str to path], str to path (e.g., "resources/drive/training/binaries/map_001.bin"), None render_map = none [eval] -eval_interval = 1000 +eval_interval = 40 +; Path to dataset used for evaluation. None = inherit env.map_dir from training. +map_dir = None +; Evaluation will run on the first num_maps maps in the map_dir directory +num_maps = 20 backend = PufferEnv -# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings +; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings ; If True, enables evaluation on realism metrics each time we save a checkpoint -wosac_realism_eval = True +wosac_realism_eval = False ; Number of policy rollouts per scene wosac_num_rollouts = 32 ; When to start the simulation wosac_init_steps = 10 -; Total number of WOSAC agents to evaluate -wosac_num_agents = 256 -; Control the tracks to predict -wosac_control_mode = "control_tracks_to_predict" -; Initialize from the tracks to predict +; Control everything valid at init in the scene +wosac_control_mode = "control_wosac" +; Create everything in valid at init the scene wosac_init_mode = "create_all_valid" ; Stop when reaching the goal wosac_goal_behavior = 2 @@ -168,23 +262,33 @@ wosac_sanity_check = False wosac_aggregate_results = True ; If True, enable human replay evaluation (pair policy-controlled agent with human replays) human_replay_eval = True -; Control only the self-driving car -human_replay_control_mode = "control_sdc_only" -; This equals the number of scenarios, since we control one agent in each -human_replay_num_agents = 64 +; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent) +human_replay_control_mode = "control_vehicles" +; Number of agents in human replay evaluation environment. +; Set equal to num_maps so every scene gets a controllable SDC each rollout +; (human_replay caps at one controllable agent per scene). +human_replay_num_agents = 100 +; Number of independent rollouts per eval cycle. With deterministic resets the +; scene set is fixed across rollouts, so num_rollouts samples policy action +; stochasticity. Reported metrics are mean across rollouts with `_std`. +human_replay_num_rollouts = 100 +; Number of maps to use for human replay evaluation +human_replay_num_maps = 100 +; Number of maps to render for human replay (subset of eval maps) +human_replay_render_num_maps = 1 -[sweep.env.reward_vehicle_collision] -distribution = uniform -min = -0.5 -max = 0.0 -mean = -0.05 +[sweep.train.learning_rate] +distribution = log_normal +min = 0.001 +mean = 0.003 +max = 0.005 scale = auto -[sweep.env.reward_offroad_collision] -distribution = uniform -min = -0.5 -max = 0.0 -mean = -0.05 +[sweep.train.ent_coef] +distribution = log_normal +min = 0.001 +mean = 0.005 +max = 0.03 scale = auto [sweep.env.goal_radius] @@ -194,16 +298,18 @@ max = 20.0 mean = 10.0 scale = auto -[sweep.env.reward_ade] -distribution = uniform -min = -0.1 -max = 0.0 -mean = -0.02 +[sweep.train.gae_lambda] +distribution = log_normal +min = 0.95 +mean = 0.98 +max = 0.999 scale = auto -[sweep.env.reward_goal_post_respawn] -distribution = uniform -min = 0.0 -max = 1.0 -mean = 0.5 -scale = auto +[controlled_exp.train.goal_speed] +values = [10, 20, 30, 3] + +[controlled_exp.train.ent_coef] +values = [0.001, 0.005, 0.01] + +[controlled_exp.train.seed] +values = [42, 55, 1] diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index b7bc0d021f..366ee25ac3 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -2,22 +2,33 @@ package = ocean env_name = puffer_drive policy_name = Drive -rnn_name = Recurrent +policy_architecture = Transformer [vec] -num_workers = 16 -num_envs = 16 +num_workers = 8 +num_envs = 8 batch_size = 2 ; backend = Serial [policy] -input_size = 64 +input_size = 128 hidden_size = 256 [rnn] input_size = 256 hidden_size = 256 +[transformer] +input_size = 256 +hidden_size = 256 +num_layers = 2 +; Number of transformer layers +num_heads = 4 +; Number of attention heads (must divide hidden_size) +; k_scenarios (2) * scenario_length (91) = maximum attention span +dropout = 0.0 +; Dropout (keep at 0 for RL stability initially) + [env] num_agents = 1024 num_ego_agents = 512 @@ -26,26 +37,36 @@ action_type = discrete ; Options: classic, jerk dynamics_model = classic reward_vehicle_collision = -0.5 -reward_offroad_collision = -0.2 -reward_ade = 0.0 +reward_offroad_collision = -0.5 dt = 0.1 reward_goal = 1.0 reward_goal_post_respawn = 0.25 # in case of reward conditioning, we scale the goal_weight by this number for post respawn +; Lane alignment reward (GIGAFLOW) - set to 0 to disable +reward_lane_align = 0.0 +; Velocity alignment coefficient for lane reward +reward_vel_align = 1.0 ; Meters around goal to be considered "reached" goal_radius = 2.0 -; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +; Max target speed in m/s for the agent to maintain towards the goal +goal_speed = 100.0 +; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" goal_behavior = 0 +; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. +; Large numbers will select a goal point further away from the agent's current position. +goal_target_distance = 30.0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove collision_behavior = 0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove offroad_behavior = 0 -; Number of steps before reset +; Number of steps before scenario_length = 91 -resample_frequency = 182 +resample_frequency = 910 +termination_mode = 1 # 0 - terminate at episode_length, 1 - terminate after all agents have been reset +map_dir = "resources/drive/binaries/training" num_maps = 1000 -; Which step of the trajectory to initialize the agents at upon reset +; Determines which step of the trajectory to initialize the agents at upon reset init_steps = 0 -; Options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only" +; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only" control_mode = "control_vehicles" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" @@ -67,10 +88,9 @@ discount_weight_lb = 0.80 discount_weight_ub = 0.98 [env.co_player_policy] -enabled = False policy_name = Drive rnn_name = Recurrent -policy_path = "resources/drive/policies/varied_discount.pt" +policy_path = "pufferlib/resources/drive/policies/varied_discount.pt" input_size = 64 hidden_size = 256 @@ -93,15 +113,17 @@ discount_weight_lb = 0.98 discount_weight_ub = 0.80 [train] +seed=42 total_timesteps = 2_000_000_000 -# learning_rate = 0.02 -# gamma = 0.985 anneal_lr = True -; Needs to be: num_agents * num_workers * BPTT horizon +; Needs to be: num_agents * num_workers * horizon batch_size = auto -minibatch_size = 32768 -max_minibatch_size = 32768 -bptt_horizon = 32 +minibatch_size = 32760 +; 360 * 91 +max_minibatch_size = 32760 +minibatch_multiplier = 360 +; Sequence length for training - matches scenario_length for full episode context +horizon = 91 adam_beta1 = 0.9 adam_beta2 = 0.999 adam_eps = 1e-8 @@ -110,44 +132,49 @@ ent_coef = 0.005 gae_lambda = 0.95 gamma = 0.98 learning_rate = 0.003 -max_grad_norm = 1 -prio_alpha = 0.8499999999999999 -prio_beta0 = 0.8499999999999999 +max_grad_norm = 1.0 +prio_alpha = 0.85 +prio_beta0 = 0.85 update_epochs = 1 -vf_clip_coef = 0.1999999999999999 -vf_coef = 2 +vf_clip_coef = 0.2 +vf_coef = 2.0 vtrace_c_clip = 1 vtrace_rho_clip = 1 -checkpoint_interval = 1000 +checkpoint_interval = 100 +context_length = 32 # Rendering options render = True -render_interval = 1000 +render_interval = 100 ; If True, show exactly what the agent sees in agent observation obs_only = True ; Show grid lines -show_grid = False +show_grid = True ; Draws lines from ego agent observed ORUs and road elements to show detection range show_lasers = False ; Display human xy logs in the background -show_human_logs = True -; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None +show_human_logs = False +; If True, zoom in on a part of the map. Otherwise, show full map +zoom_in = True +; Options: List[str to path], str to path (e.g., "resources/drive/training/binaries/map_001.bin"), None render_map = none [eval] -eval_interval = 1000 +eval_interval = 100 +; Path to dataset used for evaluation. None = inherit env.map_dir from training. +map_dir = None +; Evaluation will run on the first num_maps maps in the map_dir directory +num_maps = 20 backend = PufferEnv -# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings +; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings ; If True, enables evaluation on realism metrics each time we save a checkpoint wosac_realism_eval = False ; Number of policy rollouts per scene wosac_num_rollouts = 32 ; When to start the simulation wosac_init_steps = 10 -; Total number of WOSAC agents to evaluate -wosac_num_agents = 256 -; Control the tracks to predict -wosac_control_mode = "control_tracks_to_predict" -; Initialize from the tracks to predict +; Control everything valid at init in the scene +wosac_control_mode = "control_wosac" +; Create everything in valid at init the scene wosac_init_mode = "create_all_valid" ; Stop when reaching the goal wosac_goal_behavior = 2 @@ -157,11 +184,17 @@ wosac_sanity_check = False ; Only return aggregate results across all scenes wosac_aggregate_results = True ; If True, enable human replay evaluation (pair policy-controlled agent with human replays) -human_replay_eval = False -; Control only the self-driving car -human_replay_control_mode = "control_sdc_only" -; This equals the number of scenarios, since we control one agent in each -human_replay_num_agents = 64 +human_replay_eval = True +; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent) +human_replay_control_mode = "control_vehicles" +; Number of agents in human replay evaluation environment +human_replay_num_agents = 32 +; Number of rollouts for human replay evaluation +human_replay_num_rollouts = 100 +; Number of maps to use for human replay evaluation +human_replay_num_maps = 100 +; Number of maps to render for human replay (subset of eval maps) +human_replay_render_num_maps = 1 [sweep.train.learning_rate] distribution = log_normal @@ -174,10 +207,9 @@ scale = auto distribution = log_normal min = 0.001 mean = 0.005 -max = 0.01 +max = 0.03 scale = auto - [sweep.env.goal_radius] distribution = uniform min = 2.0 @@ -185,16 +217,18 @@ max = 20.0 mean = 10.0 scale = auto -[sweep.env.reward_ade] -distribution = uniform -min = -0.1 -max = 0.0 -mean = -0.02 +[sweep.train.gae_lambda] +distribution = log_normal +min = 0.95 +mean = 0.98 +max = 0.999 scale = auto -[sweep.env.reward_goal_post_respawn] -distribution = uniform -min = 0.0 -max = 1.0 -mean = 0.5 -scale = auto +[controlled_exp.train.goal_speed] +values = [10, 20, 30, 3] + +[controlled_exp.train.ent_coef] +values = [0.001, 0.005, 0.01] + +[controlled_exp.train.seed] +values = [42, 55, 1] diff --git a/pufferlib/models.py b/pufferlib/models.py index 0893a9db47..72121e458e 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -1,3 +1,5 @@ +import os + import numpy as np import torch @@ -7,6 +9,14 @@ import pufferlib.pytorch import pufferlib.spaces +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +import math + + +# Set PUFFER_TRANSFORMER_LEGACY_EVAL=1 to fall back to the pre-KV-cache path. +_USE_LEGACY_EVAL = os.environ.get("PUFFER_TRANSFORMER_LEGACY_EVAL", "0") == "1" + class Default(nn.Module): """Default PyTorch policy. Flattens obs and applies a linear layer. @@ -196,6 +206,488 @@ def forward(self, observations, state): return logits, values +class TransformerWrapper(nn.Module): # TransformerWrapper + def __init__( + self, + env, + policy, + input_size=128, + hidden_size=128, + num_layers=4, + num_heads=8, + horizon=512, + dropout=0.0, + use_checkpointing=False, + ): + """Wraps your policy with a Transformer for temporal modeling. + + Args: + env: Environment instance + policy: Your Drive policy (must have encode_observations and decode_actions) + input_size: Size of encoded observations (from policy.encode_observations) + hidden_size: Transformer hidden dimension + num_layers: Number of transformer layers + num_heads: Number of attention heads + horizon: Maximum sequence length to attend over + dropout: Dropout probability + use_checkpointing: Enable gradient checkpointing to save memory (slower training) + """ + super().__init__() + self.obs_shape = env.single_observation_space.shape + self.policy = policy + self.input_size = input_size + self.hidden_size = hidden_size + self.horizon = horizon + self.num_layers = num_layers + self.num_heads = num_heads + if hidden_size % num_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})") + self.head_dim = hidden_size // num_heads + self.is_continuous = self.policy.is_continuous + self.use_checkpointing = use_checkpointing + # Per-slot attention masks for KV-cached streaming inference. Cached + # lazily per device to avoid recomputing the same mask each step. + self._streaming_mask_cache = {} + + # Project encoded observations to transformer dimension if needed + if input_size != hidden_size: + self.input_projection = nn.Linear(input_size, hidden_size) + else: + self.input_projection = nn.Identity() + + # Sinusoidal positional embedding (Vaswani et al.) — non-trainable. + # Switched from learnable PE so the transformer has temporal + # structure from initialization rather than having to learn it + # from gradients. Slot-tied: PE[i] is added when writing to + # cache slot i, identical for both forward (training) and + # forward_eval (rollout) paths via get_positional_embedding(). + pe = torch.zeros(horizon, hidden_size) + position = torch.arange(0, horizon, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_size)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # register_buffer keeps it on the module's device but excludes it + # from .parameters() (no gradient updates). + self.register_buffer("positional_embedding", pe.unsqueeze(0)) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=num_heads, + dim_feedforward=hidden_size * 2, + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, # Pre-LN architecture (more stable) + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # create cache for memory context + for T in [1, 2, 4, 8, 16, 32, 64, 91, 182, 273, 364, 455]: + mask = self.create_causal_mask(T, "cpu") + self.register_buffer(f"_causal_mask_{T}", mask, persistent=False) + + # Cached masks for episode mask creation (reduces memory allocation) + self.register_buffer("_zero_mask", torch.zeros(1), persistent=False) + self.register_buffer("_neg_inf_mask", torch.full((1,), float("-inf")), persistent=False) + + # Layer norm for output + self.output_norm = nn.LayerNorm(hidden_size) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize weights similar to GPT-2""" + for name, param in self.named_parameters(): + if "layer_norm" in name or "layernorm" in name or "output_norm" in name: + continue + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name and param.ndim >= 2: + nn.init.orthogonal_(param, 1.0) + + def create_causal_mask(self, seq_len, device): + """Create causal attention mask""" + mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1) + return mask + + def get_causal_mask(self, T, device): + """Get cached causal mask or create new one""" + buffer_name = f"_causal_mask_{T}" + if hasattr(self, buffer_name): + mask = getattr(self, buffer_name) + if mask.device != device: + # Move to device and cache + mask = mask.to(device) + setattr(self, buffer_name, mask) + return mask + mask = self.create_causal_mask(T, device) + self.register_buffer(buffer_name, mask, persistent=False) + return mask + + def get_positional_embedding(self, T, device): + """Get cached positional embedding for length T""" + cache_key = f"_pos_embed_{T}" + if not hasattr(self, cache_key) or getattr(self, cache_key).device != device: + pos_embed = self.positional_embedding[:, :T].to(device) + setattr(self, cache_key, pos_embed) + return getattr(self, cache_key) + + def create_episode_mask(self, terminals, seq_len): + """Episode mask which ensures that you arent attending over episode boundaries. + Optimized with cached mask buffers to reduce memory allocation.""" + B = terminals.shape[0] + device = terminals.device + + # Use cumsum for episode IDs + episode_ids = torch.nn.functional.pad(terminals[:, :-1], (1, 0)).cumsum(dim=1) + + # Avoid full (B, T, T) allocation - use sparse comparison + mask_allow = episode_ids.unsqueeze(2) == episode_ids.unsqueeze(1) + + # Use cached tensors moved to correct device + zero_mask = self._zero_mask.to(device) if self._zero_mask.device != device else self._zero_mask + neg_inf_mask = self._neg_inf_mask.to(device) if self._neg_inf_mask.device != device else self._neg_inf_mask + + return torch.where(mask_allow, zero_mask, neg_inf_mask) + + # ------------------------------------------------------------------ # + # Streaming inference with per-layer KV cache. + # + # The legacy `_forward_eval_legacy` below recomputes a full transformer + # forward over the entire horizon-length context buffer on every step, + # then throws away all but one output position. For B=512 co-players, + # horizon=91, single-thread CPU, this costs ~3 s per step and dominates + # training wallclock. + # + # The KV-cached path maintains per-layer (K, V) buffers in `state` and + # only computes Q/K/V for the new token, attending against the cache. + # Numerically equivalent to the legacy path (same rolling buffer + + # causal-row semantics, including the post-wrap "self-attention only" + # behavior at slot 0 after horizon steps). + # + # Implementation note: `slot` is kept as a 1-element long tensor (not + # `int(pos.item())`) so this method stays compile-friendly. On the GPU + # ego policy under torch.compile, `.item()` would cause a Dynamo graph + # break and a CUDA sync every call. + # ------------------------------------------------------------------ # + + def _slot_arange(self, device): + """Return a length-`horizon` arange tensor on `device`, cached per device.""" + key = (device.type, device.index if device.index is not None else -1) + cached = self._streaming_mask_cache.get(key) + if cached is not None: + return cached + arr = torch.arange(self.horizon, device=device) + self._streaming_mask_cache[key] = arr + return arr + + def _make_kv_cache(self, batch_size, device, dtype): + return [ + torch.zeros(batch_size, self.num_heads, self.horizon, self.head_dim, device=device, dtype=dtype) + for _ in range(self.num_layers) + ] + + def init_eval_state(self, batch_size, device, dtype=torch.float32): + """Allocate a fresh streaming-inference state dict for this policy.""" + return dict( + k_cache=self._make_kv_cache(batch_size, device, dtype), + v_cache=self._make_kv_cache(batch_size, device, dtype), + transformer_position=torch.zeros(1, dtype=torch.long, device=device), + ) + + def _prime_kv_cache(self, indices, state): + """Prime K/V cache for `indices` to match legacy 'zero hidden context'. + + The legacy reset only zeroed the rolling hidden buffer. Because that + buffer is summed with the slot-tied positional embedding inside the + transformer, the *effective* K/V at unwritten slots is the K/V you + get from running the transformer over an all-zero hidden sequence + (i.e. just the pos embeddings, with causal attention propagating + through layers). This priming fills our cache with exactly that + state, so subsequent forward_eval calls match the legacy bit-close. + """ + if isinstance(indices, slice) and indices == slice(None): + n_idx = state["k_cache"][0].shape[0] + elif torch.is_tensor(indices): + n_idx = int(indices.shape[0]) + else: + n_idx = len(indices) + if n_idx == 0: + return + + H, D = self.num_heads, self.head_dim + T = self.horizon + device = state["k_cache"][0].device + dtype = state["k_cache"][0].dtype + + pos_embed = self.get_positional_embedding(T, device).to(dtype) # (1, T, hidden) + layer_input = pos_embed.expand(n_idx, T, self.hidden_size).contiguous() + causal_mask = self.get_causal_mask(T, device) + + with torch.no_grad(): + for li, layer in enumerate(self.transformer.layers): + attn = layer.self_attn + x_norm = layer.norm1(layer_input) + qkv = F.linear(x_norm, attn.in_proj_weight, attn.in_proj_bias) + q, k, v = qkv.chunk(3, dim=-1) + q = q.view(n_idx, T, H, D).transpose(1, 2) + k = k.view(n_idx, T, H, D).transpose(1, 2) + v = v.view(n_idx, T, H, D).transpose(1, 2) + + # Defensive cast: under autocast or mixed-precision, k/v can + # come out in a different dtype than the cache; PyTorch refuses + # cross-dtype `index_put`. Match the cache's dtype. + state["k_cache"][li][indices] = k.to(state["k_cache"][li].dtype) + state["v_cache"][li][indices] = v.to(state["v_cache"][li].dtype) + + attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False) + attn_out = attn_out.transpose(1, 2).reshape(n_idx, T, self.hidden_size) + attn_out = F.linear(attn_out, attn.out_proj.weight, attn.out_proj.bias) + x = layer_input + attn_out + x_norm2 = layer.norm2(x) + ffn_h = layer.activation(F.linear(x_norm2, layer.linear1.weight, layer.linear1.bias)) + ffn_out = F.linear(ffn_h, layer.linear2.weight, layer.linear2.bias) + layer_input = x + ffn_out + + def reset_eval_state(self, state, done_indices=None): + """Reset KV cache (and step counter) for done agents. + + - done_indices=None: full reset. K/V zeroed, step counter cleared. + (Equivalent to allocating a fresh state.) + - done_indices=tensor/array of agent indices: re-prime those rows + to mirror the legacy "zero hidden buffer" behavior. The shared + step counter is intentionally NOT reset in that case (matches + legacy, which only touched per-row context). + """ + k_cache = state.get("k_cache") + v_cache = state.get("v_cache") + if k_cache is None or v_cache is None: + return + if done_indices is None: + for c in k_cache: + c.zero_() + for c in v_cache: + c.zero_() + pos = state.get("transformer_position") + if pos is not None: + pos.zero_() + else: + idx = done_indices + if not torch.is_tensor(idx): + idx = torch.as_tensor(idx, device=k_cache[0].device, dtype=torch.long) + self._prime_kv_cache(idx, state) + + def forward_eval(self, observations, state): + if _USE_LEGACY_EVAL: + # Escape hatch for benchmarking / safety net: set + # PUFFER_TRANSFORMER_LEGACY_EVAL=1 in the environment to bypass + # the KV-cached path and use the original full-context forward. + return self._forward_eval_legacy(observations, state) + B = observations.shape[0] + device = observations.device + + hidden = self.policy.encode_observations(observations, state=state) + hidden = self.input_projection(hidden) + # hidden: (B, hidden_size) + + # Fetch or lazily allocate the KV cache. We re-allocate if shape + # changes (e.g. batch size differs across calls) or dtype mismatches + # the input (mixed-precision boundary). + k_cache = state.get("k_cache") + v_cache = state.get("v_cache") + need_alloc = ( + k_cache is None or v_cache is None or k_cache[0].shape[0] != B or k_cache[0].shape[2] != self.horizon + ) + if need_alloc: + k_cache = self._make_kv_cache(B, device, hidden.dtype) + v_cache = self._make_kv_cache(B, device, hidden.dtype) + pos = torch.zeros(1, dtype=torch.long, device=device) + else: + pos = state.get("transformer_position", torch.zeros(1, dtype=torch.long, device=device)) + if k_cache[0].dtype != hidden.dtype: + k_cache = [c.to(hidden.dtype) for c in k_cache] + v_cache = [c.to(hidden.dtype) for c in v_cache] + + slot_t = (pos % self.horizon).long() # (1,) long tensor + + # Add the slot's positional embedding (slot-tied, matching the + # legacy rolling-buffer scheme). + pos_embed = self.get_positional_embedding(self.horizon, device) # (1, horizon, hidden) + pos_embed_slot = pos_embed.index_select(1, slot_t).squeeze(1) # (1, hidden) + x = (hidden + pos_embed_slot).unsqueeze(1) # (B, 1, hidden) + + # Build (1, 1, 1, horizon) bool mask: True at slots [0, slot_t]. + slots_arange = self._slot_arange(device) + attn_mask = (slots_arange <= slot_t).view(1, 1, 1, self.horizon) + H = self.num_heads + D = self.head_dim + + for li, layer in enumerate(self.transformer.layers): + attn = layer.self_attn + + x_norm = layer.norm1(x) + qkv = F.linear(x_norm, attn.in_proj_weight, attn.in_proj_bias) + q, k, v = qkv.chunk(3, dim=-1) + q = q.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D) + k = k.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D) + v = v.view(B, 1, H, D).transpose(1, 2) # (B, H, 1, D) + + # index_copy_ on dim=2 writes one slot using a tensor index, which + # avoids the .item() sync that would force a Dynamo graph break. + k_cache[li].index_copy_(2, slot_t, k) + v_cache[li].index_copy_(2, slot_t, v) + + if state.get("_probe_attention", False): + # Manual softmax attention so we can stash the weights. SDPA's + # functional form doesn't return weights. Math is identical to + # the SDPA call below but we capture (B, H, 1, horizon) weights + # per layer per step into state["_attn_weights"]. + scale = 1.0 / math.sqrt(D) + logits = torch.matmul(q, k_cache[li].transpose(-2, -1)) * scale # (B, H, 1, horizon) + logits = logits.masked_fill(~attn_mask, float("-inf")) + weights = F.softmax(logits, dim=-1) # (B, H, 1, horizon) + attn_out = torch.matmul(weights, v_cache[li]) # (B, H, 1, D) + state.setdefault("_attn_weights", []).append( + {"layer": li, "slot": int(slot_t.item()), "weights": weights.detach().cpu()} + ) + else: + attn_out = F.scaled_dot_product_attention( + q, + k_cache[li], + v_cache[li], + attn_mask=attn_mask, + is_causal=False, + ) + attn_out = attn_out.transpose(1, 2).reshape(B, 1, self.hidden_size) + attn_out = F.linear(attn_out, attn.out_proj.weight, attn.out_proj.bias) + x = x + attn_out + + x_norm2 = layer.norm2(x) + ffn_h = layer.activation(F.linear(x_norm2, layer.linear1.weight, layer.linear1.bias)) + ffn_out = F.linear(ffn_h, layer.linear2.weight, layer.linear2.bias) + x = x + ffn_out + + x = self.output_norm(x) + hidden_out = x.squeeze(1) + + state["k_cache"] = k_cache + state["v_cache"] = v_cache + state["transformer_position"] = pos + 1 + state["hidden"] = hidden_out + + logits, values = self.policy.decode_actions(hidden_out) + return logits, values + + def _forward_eval_legacy(self, observations, state): + """Original full-context forward. Kept for equivalence testing only.""" + B = observations.shape[0] + device = observations.device + + hidden = self.policy.encode_observations(observations, state=state) + hidden = self.input_projection(hidden) + + if "transformer_context" not in state or state["transformer_context"] is None: + context = torch.zeros(B, self.horizon, self.hidden_size, device=device, dtype=hidden.dtype) + pos = torch.zeros(1, dtype=torch.long, device=device) + else: + context = state["transformer_context"] + pos = state.get("transformer_position", torch.zeros(1, dtype=torch.long, device=device)) + + if context.shape[-1] != self.hidden_size or context.shape[0] != B or context.shape[1] != self.horizon: + context = torch.zeros(B, self.horizon, self.hidden_size, device=device, dtype=hidden.dtype) + pos = torch.zeros(1, dtype=torch.long, device=device) + if context.dtype != hidden.dtype: + context = context.to(hidden.dtype) + + write_idx = (pos % self.horizon).long() + context[:, write_idx, :] = hidden.unsqueeze(1) + pos = pos + 1 + + pos_embed = self.get_positional_embedding(self.horizon, device) + context_with_pos = context + pos_embed + + causal_mask = self.get_causal_mask(self.horizon, device) + + output = self.transformer(context_with_pos, mask=causal_mask, is_causal=True) + output = self.output_norm(output) + + read_idx = ((pos - 1) % self.horizon).long() + hidden_out = output[:, read_idx, :].squeeze(1) + + state["transformer_context"] = context + state["transformer_position"] = pos + state["hidden"] = hidden_out + + logits, values = self.policy.decode_actions(hidden_out) + return logits, values + + def forward(self, observations, state): + x = observations + device = x.device + + if x.ndim == len(self.obs_shape) + 1: + B, T = x.shape[0], 1 + elif x.ndim == len(self.obs_shape) + 2: + B, T = x.shape[:2] + else: + raise ValueError(f"Invalid input tensor shape: {x.shape}") + + x_flat = x.view(B * T, *self.obs_shape) + hidden = self.policy.encode_observations(x_flat, state) + + hidden = hidden.view(B, T, self.input_size) + hidden = self.input_projection(hidden) + + # Remove dynamic truncation - use clamp instead of if + T_actual = min(T, self.horizon) # Python int, fine + if T_actual < T: + hidden = hidden[:, -T_actual:] + T = T_actual + + hidden = hidden + self.get_positional_embedding(T, device) + + use_episode_mask = "terminals" in state and state["terminals"] is not None + + if not use_episode_mask: + causal_mask = self.get_causal_mask(T, device) + if self.training and self.use_checkpointing: + hidden = checkpoint( + lambda h, m: self.transformer(h, mask=m, is_causal=True), hidden, causal_mask, use_reentrant=False + ) + else: + hidden = self.transformer(hidden, mask=causal_mask, is_causal=True) + else: + terminals = state["terminals"] + if terminals.shape[1] > T: + terminals = terminals[:, -T:] + causal_mask = self.get_causal_mask(T, device) + episode_mask = self.create_episode_mask(terminals, T) + attn_mask = causal_mask.unsqueeze(0) + episode_mask + attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0) + if self.training and self.use_checkpointing: + hidden = checkpoint( + lambda h, m: self.transformer(h, mask=m, is_causal=False), hidden, attn_mask, use_reentrant=False + ) + else: + hidden = self.transformer(hidden, mask=attn_mask, is_causal=False) + + hidden = self.output_norm(hidden) + flat_hidden = hidden.contiguous().view(B * T, self.hidden_size) + + logits, values = self.policy.decode_actions(flat_hidden) + values = values.view(B, T) + + # Use Python int for context_len - no sync + context_len = min(T, self.horizon) + state["hidden"] = hidden + state["transformer_context"] = hidden[:, -context_len:].detach() + state["transformer_position"] = torch.full((B,), context_len - 1, dtype=torch.long, device=device) + + return logits, values + + class Convolutional(nn.Module): def __init__( self, diff --git a/pufferlib/ocean/benchmark/evaluator.py b/pufferlib/ocean/benchmark/evaluator.py index 383384e623..818a266277 100644 --- a/pufferlib/ocean/benchmark/evaluator.py +++ b/pufferlib/ocean/benchmark/evaluator.py @@ -625,53 +625,156 @@ class HumanReplayEvaluator: def __init__(self, config: Dict): self.config = config - self.sim_steps = 91 - self.config["env"]["init_steps"] + k_scenarios = self.config["env"].get("k_scenarios", 1) + scenario_length = self.config["env"].get("scenario_length", 91) + init_steps = self.config["env"].get("init_steps", 0) + self.sim_steps = scenario_length - init_steps def rollout(self, args, puffer_env, policy): """Roll out policy in env with human replays. Store statistics. - In human replay mode, only the SDC (self-driving car) is controlled by the policy - while all other agents replay their human trajectories. This tests how compatible - the policy is with (static) human partners. + In human replay mode, only the SDC is controlled by the policy while + all other agents replay their human trajectories. This tests how + compatible the policy is with static human partners. - Args: - args: Config dict with train settings (device, use_rnn, etc.) - puffer_env: PufferLib environment wrapper - policy: Trained policy to evaluate - - Returns: - dict: Aggregated metrics including: - - avg_collisions_per_agent: Average collisions per agent - - avg_offroad_per_agent: Average offroad events per agent + Runs `num_rollouts` independent rollouts (env is reset between each, + which resamples the map/agent slice). Reports per-key mean across + rollouts, plus `_std` so you can see variance once the per-batch + score saturates. """ import numpy as np import torch import pufferlib + num_rollouts = int(args.get("eval", {}).get("human_replay_num_rollouts", 1) or 1) num_agents = puffer_env.observation_space.shape[0] device = args["train"]["device"] + k_scenarios = args["env"].get("k_scenarios", 1) - obs, info = puffer_env.reset() - state = {} - if args["train"]["use_rnn"]: - state = dict( - lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device), - lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device), - ) - - for time_idx in range(self.sim_steps): - # Step policy - with torch.no_grad(): - ob_tensor = torch.as_tensor(obs).to(device) - logits, value = policy.forward_eval(ob_tensor, state) - action, logprob, _ = pufferlib.pytorch.sample_logits(logits) - action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) - - if isinstance(logits, torch.distributions.Normal): - action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) + is_transformer = hasattr(policy, "horizon") and hasattr(policy, "transformer") + is_recurrent = hasattr(policy, "lstm") - obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) - - if len(info_list) > 0: # Happens at the end of episode - results = info_list[0] - return results + def _fresh_state(): + if is_recurrent: + return dict( + lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device), + lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device), + ) + if is_transformer: + return dict( + transformer_context=torch.zeros(num_agents, policy.horizon, policy.hidden_size, device=device), + transformer_position=torch.zeros(1, dtype=torch.long, device=device), + ) + return {} + + per_rollout_aggregates = [] + per_rollout_scenario = [] + per_rollout_delta = [] + + # Per-(rollout, scenario, agent) success tracking. An agent is marked + # "successful in scenario s" when it terminates with a positive reward + # at any step of s — this is the goal-reach signal in stop-on-goal eval + # (goal_behavior=2). Agents that collide / go offroad terminate with + # negative reward; agents that time out stay non-terminal. The 0.5 + # threshold is conservative: reward_goal default is 1.0 and the only + # other positive per-step reward is reward_lane_align (0.01-ish), so + # a single tick can't accumulate to 0.5 from lane reward alone. + goal_reward_threshold = float(args.get("eval", {}).get("recovery_goal_reward_threshold", 0.5)) + # CONTROL: when env var RECOVERY_CACHE_RESET_PER_SCENARIO=1, reset + # the policy's K/V cache (= "_fresh_state") at every scenario + # boundary. This kills any cross-scenario context the Transformer + # would have used, isolating "is the cache helping?" from "is the + # per-scenario obs alone enough?". Env var because pufferl's + # argparser doesn't auto-create new --eval.* flags. + cache_reset_per_scenario = os.environ.get("RECOVERY_CACHE_RESET_PER_SCENARIO", "0") == "1" + if cache_reset_per_scenario: + print("[recovery] CONTROL mode: resetting K/V cache at every scenario boundary", flush=True) + success_arr = np.zeros((num_rollouts, k_scenarios, num_agents), dtype=bool) + + for rollout_idx in range(num_rollouts): + obs, _ = puffer_env.reset() + state = _fresh_state() + collected_infos = [] + scenario_metrics = {} + delta_metrics = {} + + for scenario in range(k_scenarios): + if scenario > 0 and cache_reset_per_scenario: + state = _fresh_state() + for time_idx in range(self.sim_steps): + with torch.no_grad(): + ob_tensor = torch.as_tensor(obs).to(device) + logits, value = policy.forward_eval(ob_tensor, state) + action, logprob, _ = pufferlib.pytorch.sample_logits(logits) + action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) + + if isinstance(logits, torch.distributions.Normal): + action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) + + obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) + + # Mark per-agent success this scenario: a +reward_goal spike + # at any tick == goal reached. In stop-on-goal mode the env + # does NOT set `dones` per agent (the agent just stops moving), + # so we can't gate on dones. The only step-level reward that + # crosses `goal_reward_threshold` is the goal reward itself + # (lane_align is ~0.01/step, so even integrated it can't + # reach 0.5 in one tick). We OR across the scenario so the + # success flag sticks even if subsequent ticks are 0. + rewards_arr = np.asarray(rewards).reshape(-1) + success_arr[rollout_idx, scenario] |= rewards_arr > goal_reward_threshold + + for info_dict in info_list: + if not isinstance(info_dict, dict): + continue + if "ada_delta_score" in info_dict: + delta_metrics = info_dict + elif any(k.startswith("scenario_") for k in info_dict.keys()): + scenario_metrics.update(info_dict) + elif "score" in info_dict: + collected_infos.append(info_dict) + + if collected_infos: + rollout_agg = { + k: float(np.mean([d.get(k, 0) for d in collected_infos])) for k in collected_infos[0].keys() + } + else: + rollout_agg = {} + per_rollout_aggregates.append(rollout_agg) + per_rollout_scenario.append(scenario_metrics) + per_rollout_delta.append(delta_metrics) + + # Mean + std across rollouts (std only meaningful for >1 rollout) + final = {} + for dicts in (per_rollout_aggregates, per_rollout_scenario, per_rollout_delta): + keys = {k for d in dicts for k in d.keys()} + for k in keys: + vals = [float(d[k]) for d in dicts if k in d] + if not vals: + continue + final[k] = float(np.mean(vals)) + if len(vals) > 1: + final[f"{k}_std"] = float(np.std(vals, ddof=0)) + + final["n_rollouts"] = num_rollouts + final["n_agents_per_rollout"] = num_agents + final["n_total_evals"] = num_rollouts * num_agents + + # ----- Raw per-(rollout, agent, scenario) success log ----- + # No fancy aggregation here. We dump the full success grid as a flat + # list of records, one per (rollout, agent), with success bools per + # scenario. Any conditional rate (e.g., P(succeed s_k | fail s_0)) is + # a one-liner over this data downstream. + # + # Schema: list of {"rollout": int, "agent": int, "s0": int, ..., + # "s_{k-1}": int} — one record per (rollout, agent) pair. + records = [] + for r in range(num_rollouts): + for a in range(num_agents): + rec = {"rollout": int(r), "agent": int(a)} + for s_idx in range(k_scenarios): + rec[f"s{s_idx}"] = int(success_arr[r, s_idx, a]) + records.append(rec) + final["per_agent_success_log"] = records + + return final diff --git a/pufferlib/ocean/drive/README.md b/pufferlib/ocean/drive/README.md deleted file mode 100644 index a37907b206..0000000000 --- a/pufferlib/ocean/drive/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# PufferDrive - -This readme contains several important assumptions and definions about the `PufferDrive` environment. - -## Agent initialization and control - -### `init_mode` - -Determines which agents are **created** in the environment. - -| Option | Description | -| ------------------------ | ---------------------------------------------------------------------------- | -| `create_all_valid` | Create all entities valid at initialization (`traj_valid[init_steps] == 1`). | -| `create_only_controlled` | Create only those agents that are controlled by the policy. | - -### `control_mode` - -Determines which created agents are **controlled** by the policy. - -| Option | Description | -| ----------------------------------------- | ------------------------------------------------------------------------------------------------- | -| `control_vehicles` (default) | Control only valid **vehicles** (not experts, beyond `MIN_DISTANCE_TO_GOAL`, under `MAX_AGENTS`). | -| `control_agents` | Control all valid **agent types** (vehicles, cyclists, pedestrians). | -| `control_tracks_to_predict` *(WOMD only)* | Control agents listed in the `tracks_to_predict` metadata. | - - -## Termination conditions (`done`) - -Episodes are never truncated before reaching `episode_len`. The `goal_behavior` argument controls agent behavior after reaching a goal early: - -* **`goal_behavior=0` (default):** Agents respawn at their initial position after reaching their goal (last valid log position). -* **`goal_behavior=1`:** Agents receive new goals indefinitely after reaching each goal. -* **`goal_behavior=2`:** Agents stop after reaching their goal. - -## Logged performance metrics - -We record multiple performance metrics during training, aggregated over all *active agents* (alive and controlled). Key metrics include: - -- `score`: Goals reached cleanly (goal was achieved without collision or going off-road) -- `collision_rate`: Binary flag (0 or 1) if agent hit another vehicle. -- `offroad_rate`: Binary flag (0 or 1) if agent left road bounds. -- `completion_rate`: Whether the agent reached its goal in this episode (even if it collided or went off-road). - - -### Metric aggregation - -The `num_agents` parameter in `drive.ini` defines the total number of agents used to collect experience. -At runtime, **Puffer** uses `num_maps` to create enough environments to populate the buffer with `num_agents`, distributing them evenly across `num_envs`. - -Because agents are respawned immediately after reaching their goal, they remain active throughout the episode. - -At the end of each episode (i.e., when `timestep == TRAJECTORY_LENGTH`), metrics are logged once via: - -```C -if (env->timestep == TRAJECTORY_LENGTH) { - add_log(env); - c_reset(env); - return; -} -``` - -Metrics are normalized and aggregated in `vec_log` (`pufferlib/ocean/env_binding.h`). They are averaged over all active agents across all environments. For example, the aggregated collision rate is computed as: - -$$ -r^{agg}_{\text{collision}} = \frac{\mathbb{I}[\text{collided in episode}]}{N} -$$ - -where $N$ is the number of controlled agents. -This value represents the fraction of agents that collided at least once during the episode. So, cases **A** and **B** below would yield identical off-road and collision rates: - -![alt text](../../resources/drive/examples_a_b.png) - -Since these metrics do not capture *multiple* events per agent, we additionally log the **average number of collision and off-road events per episode**. This is computed as: - -$$ -c^{avg}_{\text{collision}} = \frac{\text{total number of collision events across all agents and environments}}{N} -$$ - -where $N$ is the total number of controlled agents. -For example, an `avg_collisions_per_agent` value of 4 indicates that, on average, each agent collides four times per episode. - -### Effect of respawning on metrics - -By default, agents are reset to their initial position when they reach their goal before the episode ends. Upon respawn, `respawn_timestep` is updated from `-1` to the current step index. - -This raises the question: **how does repeated respawning affect aggregated metrics?** - -To begin, note that the environment is a bit different before and after respawn. After an agent respawns, all other agents are "removed" from the environment. As a result, collisions with other agents cannot occur post-respawn. - -This effectively transforms the scenario into a single-agent environment, simplifying the task since the agent no longer needs to coordinate with others. - -![alt text](../../resources/drive/pre_and_post_respawn.png) - -#### `score` - -Consider an episode of 91 steps where an agent is initialized relatively close to the goal position and reaches its goal three times: - -1. **First attempt:** reaches the goal without collisions -2. **Second attempt:** reaches the goal without collisions -3. **Third attempt:** reaches the goal but goes off-road along the way - -![alt text](../../resources/drive/realistic_collision_event_post_respawn.png) - -The highlighted trajectory shows the first attempt. In this case, the recorded score is `0.0` — a single off-road event invalidates the score for the entire episode. This behavior is desired: the score metric is unforgiving. - -#### `offroad_rate` and `collision_rate` - -Same logic holds as above. diff --git a/pufferlib/ocean/drive/adaptive.py b/pufferlib/ocean/drive/adaptive.py index 184ce41577..7af77c62db 100644 --- a/pufferlib/ocean/drive/adaptive.py +++ b/pufferlib/ocean/drive/adaptive.py @@ -12,7 +12,12 @@ def __init__(self, **kwargs): kwargs["ini_file"] = "pufferlib/config/ocean/adaptive.ini" kwargs["adaptive_driving_agent"] = True + # Human replay mode: disable co-players, use human trajectories for other agents + human_replay_mode = kwargs.pop("human_replay_mode", False) + if human_replay_mode: + kwargs["co_player_enabled"] = False + kwargs["resample_frequency"] = self.k_scenarios * self.scenario_length self.episode_length = kwargs["resample_frequency"] - # print(f"resample frequency is ", kwargs["resample_frequency"], flush=True) + super().__init__(**kwargs) diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 60c03814b7..5eaed37b0a 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -1,27 +1,29 @@ #define Env Drive #define MY_SHARED #define MY_PUT + +#include #include "binding.h" -static int my_put(Env* env, PyObject* args, PyObject* kwargs) { - PyObject* obs = PyDict_GetItemString(kwargs, "observations"); +static int my_put(Env *env, PyObject *args, PyObject *kwargs) { + PyObject *obs = PyDict_GetItemString(kwargs, "observations"); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return 1; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return 1; } env->observations = PyArray_DATA(observations); - PyObject* act = PyDict_GetItemString(kwargs, "actions"); + PyObject *act = PyDict_GetItemString(kwargs, "actions"); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return 1; } - PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return 1; @@ -32,12 +34,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { return 1; } - PyObject* rew = PyDict_GetItemString(kwargs, "rewards"); + PyObject *rew = PyDict_GetItemString(kwargs, "rewards"); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return 1; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return 1; @@ -48,12 +50,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { } env->rewards = PyArray_DATA(rewards); - PyObject* term = PyDict_GetItemString(kwargs, "terminals"); + PyObject *term = PyDict_GetItemString(kwargs, "terminals"); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return 1; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return 1; @@ -66,23 +68,21 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { return 0; } -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { int population_play = unpack(kwargs, "population_play"); - if (population_play){ - return my_shared_population_play(self, args, kwargs); - } - else{ - return my_shared_self_play( self, args, kwargs); + if (population_play) { + return my_shared_population_play(self, args, kwargs); + } else { + return my_shared_self_play(self, args, kwargs); } - } -static int my_init(Env* env, PyObject* args, PyObject* kwargs) { +static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->human_agent_idx = unpack(kwargs, "human_agent_idx"); env->ini_file = unpack_str(kwargs, "ini_file"); env_init_config conf = {0}; - if(ini_parse(env->ini_file, handler, &conf) < 0) { + if (ini_parse(env->ini_file, handler, &conf) < 0) { printf("Error while loading %s", env->ini_file); } if (kwargs && PyDict_GetItemString(kwargs, "scenario_length")) { @@ -95,15 +95,18 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->action_type = conf.action_type; env->dynamics_model = conf.dynamics_model; if (PyDict_GetItemString(kwargs, "dynamics_model")) { - char* dynamics_str = unpack_str(kwargs, "dynamics_model"); + char *dynamics_str = unpack_str(kwargs, "dynamics_model"); env->dynamics_model = (strcmp(dynamics_str, "jerk") == 0) ? JERK : CLASSIC; } env->reward_vehicle_collision = conf.reward_vehicle_collision; env->reward_offroad_collision = conf.reward_offroad_collision; env->reward_goal = conf.reward_goal; env->reward_goal_post_respawn = conf.reward_goal_post_respawn; - env->reward_ade = conf.reward_ade; + env->reward_lane_align = (float)unpack(kwargs, "reward_lane_align"); + env->reward_vel_align = (float)unpack(kwargs, "reward_vel_align"); env->scenario_length = conf.scenario_length; + + env->termination_mode = conf.termination_mode; env->collision_behavior = conf.collision_behavior; env->offroad_behavior = conf.offroad_behavior; env->max_controlled_agents = unpack(kwargs, "max_controlled_agents"); @@ -127,10 +130,10 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { if (env->population_play) { env->num_co_players = unpack(kwargs, "num_co_players"); - double* co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players); + double *co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players); if (co_player_ids_d != NULL && env->num_co_players > 0) { - env->co_player_ids = (int*)malloc(env->num_co_players * sizeof(int)); + env->co_player_ids = (int *)malloc(env->num_co_players * sizeof(int)); if (env->co_player_ids == NULL) { fprintf(stderr, "Error: Failed to allocate memory for co_player_ids\n"); free(co_player_ids_d); @@ -152,9 +155,9 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { // Handle ego agents - always as an array env->num_ego_agents = unpack(kwargs, "num_ego_agents"); if (env->num_ego_agents > 0) { - double* ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents); + double *ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents); if (ego_agent_ids_d != NULL) { - env->ego_agent_ids = (int*)malloc(env->num_ego_agents * sizeof(int)); + env->ego_agent_ids = (int *)malloc(env->num_ego_agents * sizeof(int)); for (int i = 0; i < env->num_ego_agents; i++) { env->ego_agent_ids[i] = (int)ego_agent_ids_d[i]; } @@ -173,16 +176,23 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->ego_agent_ids = NULL; } - env->init_mode = (int)unpack(kwargs, "init_mode"); env->control_mode = (int)unpack(kwargs, "control_mode"); + // Render mode: 0=RENDER_OFF, 1=RENDER_HEADLESS, 2=RENDER_WINDOW + env->render_mode = RENDER_OFF; // Default to off + if (kwargs && PyDict_GetItemString(kwargs, "render_mode")) { + env->render_mode = (int)unpack(kwargs, "render_mode"); + } env->goal_behavior = (int)unpack(kwargs, "goal_behavior"); + env->goal_target_distance = (float)unpack(kwargs, "goal_target_distance"); env->goal_radius = (float)unpack(kwargs, "goal_radius"); + env->goal_speed = (float)unpack(kwargs, "goal_speed"); + char *map_dir = unpack_str(kwargs, "map_dir"); int map_id = unpack(kwargs, "map_id"); int max_agents = unpack(kwargs, "max_agents"); int init_steps = unpack(kwargs, "init_steps"); - char map_file[100]; - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + char map_file[512]; + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->num_agents = max_agents; env->map_name = strdup(map_file); env->init_steps = init_steps; @@ -191,18 +201,21 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { return 0; } -static int my_log(PyObject* dict, Log* log) { +static int my_log(PyObject *dict, Log *log) { assign_to_dict(dict, "n", log->n); + assign_to_dict(dict, "score", log->score); assign_to_dict(dict, "offroad_rate", log->offroad_rate); - assign_to_dict(dict, "episode_length", log->episode_length); assign_to_dict(dict, "collision_rate", log->collision_rate); + assign_to_dict(dict, "episode_length", log->episode_length); assign_to_dict(dict, "episode_return", log->episode_return); assign_to_dict(dict, "dnf_rate", log->dnf_rate); - assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error); assign_to_dict(dict, "completion_rate", log->completion_rate); assign_to_dict(dict, "lane_alignment_rate", log->lane_alignment_rate); - assign_to_dict(dict, "score", log->score); - assign_to_dict(dict, "avg_offroad_per_agent", log->avg_offroad_per_agent); - assign_to_dict(dict, "avg_collisions_per_agent", log->avg_collisions_per_agent); + assign_to_dict(dict, "offroad_per_agent", log->offroad_per_agent); + assign_to_dict(dict, "collisions_per_agent", log->collisions_per_agent); + assign_to_dict(dict, "goals_sampled_this_episode", log->goals_sampled_this_episode); + assign_to_dict(dict, "goals_reached_this_episode", log->goals_reached_this_episode); + assign_to_dict(dict, "speed_at_goal", log->speed_at_goal); + // assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error); return 0; } diff --git a/pufferlib/ocean/drive/binding.h b/pufferlib/ocean/drive/binding.h index b5ec9ed65d..492dde2f14 100644 --- a/pufferlib/ocean/drive/binding.h +++ b/pufferlib/ocean/drive/binding.h @@ -1,57 +1,75 @@ #include "drive.h" #include "../env_binding.h" -static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared_self_play(PyObject *self, PyObject *args, PyObject *kwargs) { + char *map_dir = unpack_str(kwargs, "map_dir"); int num_agents = unpack(kwargs, "num_agents"); int num_maps = unpack(kwargs, "num_maps"); int init_mode = unpack(kwargs, "init_mode"); int control_mode = unpack(kwargs, "control_mode"); int init_steps = unpack(kwargs, "init_steps"); + int goal_behavior = unpack(kwargs, "goal_behavior"); + float goal_target_distance = unpack(kwargs, "goal_target_distance"); + int use_all_maps = unpack(kwargs, "use_all_maps"); int max_controlled_agents = unpack(kwargs, "max_controlled_agents"); - clock_gettime(CLOCK_REALTIME, &ts); - srand(ts.tv_nsec); + int map_seed = PyDict_GetItemString(kwargs, "map_seed") ? unpack(kwargs, "map_seed") : -1; + printf("Generating environments for %d agents using %s maps from %s, num maps %d \n", num_agents, + use_all_maps ? "all" : "random", map_dir, num_maps); + fflush(stdout); + // Use provided seed or fall back to time+pid + if (map_seed >= 0) { + srand((unsigned int)map_seed); + } else { + clock_gettime(CLOCK_REALTIME, &ts); + srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid())); + } int total_agent_count = 0; int env_count = 0; - int max_envs = num_agents; + int max_envs = use_all_maps ? num_maps : num_agents; + int map_idx = 0; int maps_checked = 0; - PyObject* agent_offsets = PyList_New(max_envs+1); - PyObject* map_ids = PyList_New(max_envs); + PyObject *agent_offsets = PyList_New(max_envs + 1); + PyObject *map_ids = PyList_New(max_envs); // getting env count - while(total_agent_count < num_agents && env_count < max_envs){ - char map_file[100]; - int map_id = rand() % num_maps; - Drive* env = calloc(1, sizeof(Drive)); + while (use_all_maps ? map_idx < max_envs : total_agent_count < num_agents && env_count < max_envs) { + char map_file[512]; + int map_id = use_all_maps ? map_idx++ : rand() % num_maps; + Drive *env = calloc(1, sizeof(Drive)); env->init_mode = init_mode; + env->max_controlled_agents = max_controlled_agents; env->control_mode = control_mode; env->init_steps = init_steps; - env->max_controlled_agents = max_controlled_agents; - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + env->goal_behavior = goal_behavior; + env->goal_target_distance = goal_target_distance; + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->entities = load_map_binary(map_file, env); set_active_agents(env); // Skip map if it doesn't contain any controllable agents - if(env->active_agent_count == 0) { - maps_checked++; - - // Safeguard: if we've checked all available maps and found no active agents, raise an error - if(maps_checked >= num_maps) { - for(int j=0;jnum_entities;j++) { - free_entity(&env->entities[j]); + if (env->active_agent_count == 0) { + if (!use_all_maps) { + maps_checked++; + + // Safeguard: if we've checked all available maps and found no active agents, raise an error + if (maps_checked >= num_maps) { + for (int j = 0; j < env->num_entities; j++) { + free_entity(&env->entities[j]); + } + free(env->entities); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + char error_msg[256]; + sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); + PyErr_SetString(PyExc_ValueError, error_msg); + return NULL; } - free(env->entities); - free(env->active_agent_indices); - free(env->static_agent_indices); - free(env->expert_static_agent_indices); - free(env); - Py_DECREF(agent_offsets); - Py_DECREF(map_ids); - char error_msg[256]; - sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); - PyErr_SetString(PyExc_ValueError, error_msg); - return NULL; } - for(int j=0;jnum_entities;j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -60,17 +78,17 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k free(env->expert_static_agent_indices); free(env); continue; - } + } // Store map_id - PyObject* map_id_obj = PyLong_FromLong(map_id); + PyObject *map_id_obj = PyLong_FromLong(map_id); PyList_SetItem(map_ids, env_count, map_id_obj); // Store agent offset - PyObject* offset = PyLong_FromLong(total_agent_count); + PyObject *offset = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, offset); total_agent_count += env->active_agent_count; env_count++; - for(int j=0;jnum_entities;j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -79,26 +97,26 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k free(env->expert_static_agent_indices); free(env); } - //printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, num_agents); - if(total_agent_count >= num_agents){ + // printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, + // num_agents); + if (!use_all_maps && total_agent_count >= num_agents) { total_agent_count = num_agents; } - PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count); + PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, final_total_agent_count); - PyObject* final_env_count = PyLong_FromLong(env_count); + PyObject *final_env_count = PyLong_FromLong(env_count); // resize lists - PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); - PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); - PyObject* tuple = PyTuple_New(3); + PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); + PyObject *resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); + PyObject *tuple = PyTuple_New(3); PyTuple_SetItem(tuple, 0, resized_agent_offsets); PyTuple_SetItem(tuple, 1, resized_map_ids); PyTuple_SetItem(tuple, 2, final_env_count); return tuple; } - -static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_size) { - PyObject* val = PyDict_GetItemString(kwargs, key); +static double *unpack_float_array(PyObject *kwargs, char *key, Py_ssize_t *out_size) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -123,15 +141,14 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s return NULL; } - double* array = (double*)malloc(size * sizeof(double)); + double *array = (double *)malloc(size * sizeof(double)); if (array == NULL) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory for float array"); return NULL; } - for (Py_ssize_t i = 0; i < size; i++) { - PyObject* item = PySequence_GetItem(val, i); + PyObject *item = PySequence_GetItem(val, i); if (item == NULL) { free(array); return NULL; @@ -168,19 +185,20 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s return array; } - -static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared_population_play(PyObject *self, PyObject *args, PyObject *kwargs) { + char *map_dir = unpack_str(kwargs, "map_dir"); int num_agents = unpack(kwargs, "num_agents"); int num_maps = unpack(kwargs, "num_maps"); int num_ego_agents = unpack(kwargs, "num_ego_agents"); - int init_mode = unpack(kwargs, "init_mode"); + int init_mode = unpack(kwargs, "init_mode"); int population_play = unpack(kwargs, "population_play"); int control_mode = unpack(kwargs, "control_mode"); int init_steps = unpack(kwargs, "init_steps"); int max_controlled_agents = unpack(kwargs, "max_controlled_agents"); + int map_seed = PyDict_GetItemString(kwargs, "map_seed") ? unpack(kwargs, "map_seed") : -1; int max_scenes_per_process = 0; - PyObject* max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process"); + PyObject *max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process"); if (max_envs_obj && PyLong_Check(max_envs_obj)) { long v = PyLong_AsLong(max_envs_obj); if (v > 0 && v <= INT_MAX) { @@ -188,17 +206,22 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj } } - // Use current time for randomness - struct timespec ts; - clock_gettime(CLOCK_REALTIME, &ts); - srand(ts.tv_nsec); + // Use provided seed or fall back to time+pid + if (map_seed >= 0) { + srand((unsigned int)map_seed); + } else { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid())); + } int num_coplayers = num_agents - num_ego_agents; - printf("Creating worlds for %d total agents (%d egos, %d co-players)\n", - num_agents, num_ego_agents, num_coplayers); + // Silenced: noisy during normal training. Re-enable for debug. + // printf("Creating worlds for %d total agents (%d egos, %d co-players)\n", num_agents, num_ego_agents, + // num_coplayers); // Create shuffled agent role array (0 = coplayer, 1 = ego) - int* agent_roles = malloc(num_agents * sizeof(int)); + int *agent_roles = malloc(num_agents * sizeof(int)); for (int i = 0; i < num_ego_agents; i++) { agent_roles[i] = 1; // ego } @@ -218,23 +241,45 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj int env_count = 0; int total_egos_assigned = 0; int total_coplayers_assigned = 0; + int agent_role_index = 0; // Track position in agent_roles array int max_envs = num_agents; if (max_scenes_per_process > 0 && max_scenes_per_process < max_envs) { max_envs = max_scenes_per_process; } - PyObject* agent_offsets = PyList_New(max_envs + 1); - PyObject* map_ids = PyList_New(max_envs); - PyObject* ego_agent_ids = PyList_New(max_envs); - PyObject* coplayer_ids = PyList_New(max_envs); + PyObject *agent_offsets = PyList_New(max_envs + 1); + PyObject *map_ids = PyList_New(max_envs); + PyObject *ego_agent_ids = PyList_New(max_envs); + PyObject *coplayer_ids = PyList_New(max_envs); + + int consecutive_skips = 0; // Safety counter for infinite loop detection + int max_consecutive_skips = num_maps * 3; // Allow trying each map multiple times // Create worlds by randomly sampling maps while (total_agent_count < num_agents && env_count < max_envs) { + // Safety check: if we've skipped too many times in a row, something is wrong + if (consecutive_skips > max_consecutive_skips) { + fprintf(stderr, + "[shared_population_play] ERROR: Too many consecutive skips (%d). " + "All maps may have 0 active agents. agent_role_index=%d, total_agent_count=%d\n", + consecutive_skips, agent_role_index, total_agent_count); + + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + Py_DECREF(ego_agent_ids); + Py_DECREF(coplayer_ids); + free(agent_roles); + PyErr_Format(PyExc_RuntimeError, + "shared_population_play: unable to find maps with active agents after %d attempts", + consecutive_skips); + return NULL; + } + char map_file[100]; int map_id = rand() % num_maps; - Drive* env = calloc(1, sizeof(Drive)); - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + Drive *env = calloc(1, sizeof(Drive)); + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->entities = load_map_binary(map_file, env); int remaining_capacity = num_agents - total_agent_count; @@ -250,15 +295,28 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj set_active_agents(env); + // CRITICAL FIX: Skip maps with 0 active agents + if (env->active_agent_count == 0) { + // Silenced: noisy during normal training. Re-enable for debug. + // printf("Skipping map %d (0 active agents)\n", map_id); + for (int j = 0; j < env->num_entities; j++) { + free_entity(&env->entities[j]); + } + free(env->entities); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + consecutive_skips++; + continue; + } + int next_total = total_agent_count + env->active_agent_count; if (next_total > num_agents) { int remaining = num_agents - total_agent_count; - fprintf(stderr, - "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n", - env->active_agent_count, - remaining, - map_id); - for(int j=0; jnum_entities; j++) { + fprintf(stderr, "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n", + env->active_agent_count, remaining, map_id); + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -272,25 +330,22 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj Py_DECREF(coplayer_ids); free(agent_roles); PyErr_Format(PyExc_RuntimeError, - "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d", - next_total, - num_agents, - map_id, - env->active_agent_count); + "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d", next_total, + num_agents, map_id, env->active_agent_count); return NULL; } // Store map_id - PyObject* map_id_obj = PyLong_FromLong(map_id); + PyObject *map_id_obj = PyLong_FromLong(map_id); PyList_SetItem(map_ids, env_count, map_id_obj); // Store agent offset - PyObject* offset = PyLong_FromLong(total_agent_count); + PyObject *offset = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, offset); // Create ego and coplayer lists for this world - PyObject* ego_list = PyList_New(0); - PyObject* coplayer_list = PyList_New(0); + PyObject *ego_list = PyList_New(0); + PyObject *coplayer_list = PyList_New(0); int world_egos = 0; int world_coplayers = 0; @@ -298,9 +353,9 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj // Assign agents from the shuffled roles for (int a = 0; a < env->active_agent_count; a++) { - PyObject* agent_id = PyLong_FromLong(total_agent_count); + PyObject *agent_id = PyLong_FromLong(total_agent_count); - if (agent_roles[total_agent_count] == 1) { + if (agent_roles[agent_role_index] == 1) { // This agent is an ego PyList_Append(ego_list, agent_id); world_egos++; @@ -314,22 +369,26 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj Py_DECREF(agent_id); total_agent_count++; + agent_role_index++; } // Enforce constraint: must have at least 1 ego per world (if egos remain) if (world_egos == 0 && remaining_egos > 0) { - fprintf(stderr, - "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n", - env_count, remaining_egos); + // Silenced: noisy during normal training. Re-enable for debug. + // fprintf( + // stderr, + // "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n", + // env_count, remaining_egos); // Rollback the agent assignments for this world total_agent_count -= env->active_agent_count; total_coplayers_assigned -= world_coplayers; + agent_role_index -= env->active_agent_count; Py_DECREF(ego_list); Py_DECREF(coplayer_list); - for(int j=0; jnum_entities; j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -337,18 +396,24 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env); + + consecutive_skips++; continue; // Try another map } + // Successfully created a world, reset skip counter + consecutive_skips = 0; + PyList_SetItem(ego_agent_ids, env_count, ego_list); PyList_SetItem(coplayer_ids, env_count, coplayer_list); - printf("World %d (map %d): %d agents (%d egos, %d co-players)\n", - env_count, map_id, env->active_agent_count, world_egos, world_coplayers); + // Silenced: noisy during normal training. Re-enable for debug. + // printf("World %d (map %d): %d agents (%d egos, %d co-players)\n", env_count, map_id, env->active_agent_count, + // world_egos, world_coplayers); env_count++; - for(int j=0; jnum_entities; j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -362,15 +427,15 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj total_agent_count = num_agents; } - PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count); + PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, final_total_agent_count); - PyObject* final_env_count = PyLong_FromLong(env_count); + PyObject *final_env_count = PyLong_FromLong(env_count); // Resize lists - PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); - PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); - PyObject* resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count); - PyObject* resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count); + PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); + PyObject *resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); + PyObject *resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count); + PyObject *resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count); Py_DECREF(agent_offsets); Py_DECREF(map_ids); @@ -381,15 +446,16 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj free(agent_roles); // Create a tuple - PyObject* tuple = PyTuple_New(5); + PyObject *tuple = PyTuple_New(5); PyTuple_SetItem(tuple, 0, resized_agent_offsets); PyTuple_SetItem(tuple, 1, resized_map_ids); PyTuple_SetItem(tuple, 2, final_env_count); PyTuple_SetItem(tuple, 3, resized_ego_ids); PyTuple_SetItem(tuple, 4, resized_coplayer_ids); - printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n", - total_agent_count, env_count, total_egos_assigned, total_coplayers_assigned); + // Silenced: noisy during normal training. Re-enable for debug. + // printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n", total_agent_count, env_count, + // total_egos_assigned, total_coplayers_assigned); return tuple; } diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c deleted file mode 100644 index 5e193c0371..0000000000 --- a/pufferlib/ocean/drive/drive.c +++ /dev/null @@ -1,171 +0,0 @@ -#include "drive.h" -#include "drivenet.h" -#include -#include "../env_config.h" - -// Use this test if the network changes to ensure that the forward pass -// matches the torch implementation to the 3rd or ideally 4th decimal place -void test_drivenet() { - int num_obs = 1848; - int num_actions = 2; - int num_agents = 4; - - float* observations = calloc(num_agents*num_obs, sizeof(float)); - for (int i = 0; i < num_obs*num_agents; i++) { - observations[i] = i % 7; - } - - int* actions = calloc(num_agents*num_actions, sizeof(int)); - - //Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); - Weights* weights = load_weights("puffer_drive_weights.bin"); - DriveNet* net = init_drivenet(weights, num_agents, CLASSIC, false, false, false); - - forward(net, observations, actions); - for (int i = 0; i < num_agents*num_actions; i++) { - printf("idx: %d, action: %d, logits:", i, actions[i]); - for (int j = 0; j < num_actions; j++) { - printf(" %.6f", net->actor->output[i*num_actions + j]); - } - printf("\n"); - } - free_drivenet(net); - free(weights); -} - -void demo() { - // Read configuration from INI file - env_init_config conf = {0}; - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { - fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); - exit(1); - } - - Drive env = { - .human_agent_idx = 0, - .dynamics_model = conf.dynamics_model, - .reward_vehicle_collision = conf.reward_vehicle_collision, - .reward_offroad_collision = conf.reward_offroad_collision, - .reward_ade = conf.reward_ade, - .goal_radius = conf.goal_radius, - .dt = conf.dt, - .map_name = "resources/drive/binaries/map_000.bin", - .init_steps = conf.init_steps, - .collision_behavior = conf.collision_behavior, - .offroad_behavior = conf.offroad_behavior, - }; - allocate(&env); - c_reset(&env); - c_render(&env); - Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); - DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, false, false, false); - //Client* client = make_client(&env); - int accel_delta = 2; - int steer_delta = 4; - while (!WindowShouldClose()) { - // Handle camera controls - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, env.actions); - if (IsKeyDown(KEY_LEFT_SHIFT)) { - actions[env.human_agent_idx][0] = 3; - actions[env.human_agent_idx][1] = 6; - if(IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)){ - actions[env.human_agent_idx][0] += accel_delta; - // Cap acceleration to maximum of 6 - if(actions[env.human_agent_idx][0] > 6) { - actions[env.human_agent_idx][0] = 6; - } - } - if(IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)){ - actions[env.human_agent_idx][0] -= accel_delta; - // Cap acceleration to minimum of 0 - if(actions[env.human_agent_idx][0] < 0) { - actions[env.human_agent_idx][0] = 0; - } - } - if(IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)){ - actions[env.human_agent_idx][1] += steer_delta; - // Cap steering to minimum of 0 - if(actions[env.human_agent_idx][1] < 0) { - actions[env.human_agent_idx][1] = 0; - } - } - if(IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)){ - actions[env.human_agent_idx][1] -= steer_delta; - // Cap steering to maximum of 12 - if(actions[env.human_agent_idx][1] > 12) { - actions[env.human_agent_idx][1] = 12; - } - } - if(IsKeyPressed(KEY_TAB)){ - env.human_agent_idx = (env.human_agent_idx + 1) % env.active_agent_count; - } - } - c_step(&env); - c_render(&env); - } - - close_client(env.client); - free_allocated(&env); - free_drivenet(net); - free(weights); -} - -void performance_test() { - // Read configuration from INI file - env_init_config conf = {0}; - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { - fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); - exit(1); - } - - long test_time = 10; - Drive env = { - .human_agent_idx = 0, - .dynamics_model = conf.dynamics_model, - .reward_vehicle_collision = conf.reward_vehicle_collision, - .reward_offroad_collision = conf.reward_offroad_collision, - .reward_ade = conf.reward_ade, - .goal_radius = conf.goal_radius, - .dt = conf.dt, - .map_name = "resources/drive/binaries/map_000.bin", - .init_steps = conf.init_steps, - }; - clock_t start_time, end_time; - double cpu_time_used; - start_time = clock(); - allocate(&env); - c_reset(&env); - end_time = clock(); - cpu_time_used = ((double) (end_time - start_time)) / CLOCKS_PER_SEC; - printf("Init time: %f\n", cpu_time_used); - - long start = time(NULL); - int i = 0; - int (*actions)[2] = (int(*)[2])env.actions; - - while (time(NULL) - start < test_time) { - // Set random actions for all agents - for(int j = 0; j < env.active_agent_count; j++) { - int accel = rand() % 7; - int steer = rand() % 13; - actions[j][0] = accel; // -1, 0, or 1 - actions[j][1] = steer; // Random steering - } - - c_step(&env); - i++; - } - long end = time(NULL); - printf("SPS: %ld\n", (i*env.active_agent_count) / (end - start)); - free_allocated(&env); -} - -int main() { - //performance_test(); - demo(); - //test_drivenet(); - return 0; -} diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index cb112b9a7e..895604077f 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -6,12 +6,26 @@ #include #include #include +#include +#include +#include #include "raylib.h" #include "raymath.h" #include "rlgl.h" #include +#include #include "error.h" +// Render modes +#define RENDER_OFF 0 +#define RENDER_HEADLESS 1 +#define RENDER_WINDOW 2 + +// View modes for rendering +#define VIEW_MODE_SIM_STATE 0 // Full simulation state (top-down orthographic) +#define VIEW_MODE_BEV_AGENT_OBS 1 // Bird's eye view centered on agent +#define VIEW_MODE_AGENT_PERSP 2 // Agent perspective (3rd person) + // Entity Types #define NONE 0 #define VEHICLE 1 @@ -37,7 +51,7 @@ // Control modes #define CONTROL_VEHICLES 0 #define CONTROL_AGENTS 1 -#define CONTROL_TRACKS_TO_PREDICT 2 +#define CONTROL_WOSAC 2 #define CONTROL_SDC_ONLY 3 // Minimum distance to goal position @@ -60,11 +74,23 @@ #define OFFROAD_IDX 1 #define REACHED_GOAL_IDX 2 #define LANE_ALIGNED_IDX 3 -#define AVG_DISPLACEMENT_ERROR_IDX 4 +#define LANE_DIST_IDX 4 +#define LANE_ANGLE_IDX 5 +#define AVG_DISPLACEMENT_ERROR_IDX 6 + +// Lane alignment constants (from GIGAFLOW) +#define LANE_DISTANCE_NORMALIZATION 4.0f +#define LANE_SELECTION_DISTANCE_WEIGHT 0.7f +#define LANE_SELECTION_HEADING_WEIGHT 0.3f +#define LANE_SWITCH_THRESHOLD 0.5f +#define LANE_ALIGN_COS_THRESHOLD 0.965f // ~15 degrees +#define MAX_CHECKED_LANES 32 // Grid cell size #define GRID_CELL_SIZE 5.0f -#define MAX_ENTITIES_PER_CELL 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 +#define MAX_ENTITIES_PER_CELL \ + 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in + // gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 // Max road segment observation entities #define MAX_ROAD_SEGMENT_OBSERVATIONS 200 @@ -86,34 +112,58 @@ #define STOP_AGENT 1 #define REMOVE_AGENT 2 -//GOAL BEHAVIOUR +// GOAL BEHAVIOUR #define GOAL_RESPAWN 0 #define GOAL_GENERATE_NEW 1 #define GOAL_STOP 2 +#define PARTNER_FEATURES 7 + +#define ROAD_FEATURES 7 +#define ROAD_FEATURES_ONEHOT 13 +#define PARTNER_FEATURES 7 + +// Ego features depend on dynamics model +// Classic: goal_x, goal_y, speed, width, length, collision, respawn, lane_dist, lane_angle +#define EGO_FEATURES_CLASSIC 9 +// Jerk: + steering_angle, a_long, a_lat +#define EGO_FEATURES_JERK 12 + // Jerk action space (for JERK dynamics model) static const float JERK_LONG[4] = {-15.0f, -4.0f, 0.0f, 4.0f}; static const float JERK_LAT[3] = {-4.0f, 0.0f, 4.0f}; // Classic action space (for CLASSIC dynamics model) static const float ACCELERATION_VALUES[7] = {-4.0000f, -2.6670f, -1.3330f, -0.0000f, 1.3330f, 2.6670f, 4.0000f}; -static const float STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f, 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f}; +static const float STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f, + 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f}; static const float offsets[4][2] = { - {-1, 1}, // top-left - {1, 1}, // top-right - {1, -1}, // bottom-right - {-1, -1} // bottom-left - }; + {-1, 1}, // top-left + {1, 1}, // top-right + {1, -1}, // bottom-right + {-1, -1} // bottom-left +}; static const int collision_offsets[25][2] = { - {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row - {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row - {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center) - {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row - {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row + {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row + {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row + {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center) + {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row + {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row }; +const Color STONE_GRAY = (Color){80, 80, 80, 255}; +const Color PUFF_RED = (Color){187, 0, 0, 255}; +const Color PUFF_CYAN = (Color){0, 187, 187, 255}; +const Color PUFF_WHITE = (Color){241, 241, 241, 241}; +const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255}; +const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255}; +const Color LIGHTGREEN = (Color){152, 255, 152, 255}; +const Color LIGHTYELLOW = (Color){255, 255, 152, 255}; +const Color LIGHTBLUE = (Color){152, 200, 255, 255}; +const Color SOFT_YELLOW = (Color){245, 245, 220, 255}; + struct timespec ts; typedef struct Drive Drive; @@ -122,20 +172,22 @@ typedef struct Log Log; typedef struct Graph Graph; typedef struct AdjListNode AdjListNode; typedef struct Co_Player_Log Co_Player_Log; -typedef struct Adaptive_Agent_Log Adaptive_Agent_Log; struct Log { float episode_return; float episode_length; float score; + float goals_reached_this_episode; + float goals_sampled_this_episode; float offroad_rate; float collision_rate; - float num_goals_reached; float completion_rate; + float offroad_per_agent; + float collisions_per_agent; float dnf_rate; float n; float lane_alignment_rate; - float avg_displacement_error; + float speed_at_goal; float active_agent_count; float expert_static_agent_count; float static_agent_count; @@ -145,8 +197,6 @@ struct Log { float avg_goal_weight; float avg_entropy_weight; float avg_discount_weight; - float avg_offroad_per_agent; - float avg_collisions_per_agent; }; typedef struct Entity Entity; @@ -155,14 +205,14 @@ struct Entity { int type; int id; int array_size; - float* traj_x; - float* traj_y; - float* traj_z; - float* traj_vx; - float* traj_vy; - float* traj_vz; - float* traj_heading; - int* traj_valid; + float *traj_x; + float *traj_y; + float *traj_z; + float *traj_vx; + float *traj_vy; + float *traj_vz; + float *traj_heading; + int *traj_valid; float width; float length; float height; @@ -173,7 +223,8 @@ struct Entity { float init_goal_y; int mark_as_expert; int collision_state; - float metrics_array[5]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, avg_displacement_error] + float metrics_array[7]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, lane_dist, lane_angle, + // avg_disp_error] float x; float y; float z; @@ -184,13 +235,14 @@ struct Entity { float heading_x; float heading_y; int current_lane_idx; + int current_lane_geometry_idx; int valid; int respawn_timestep; int respawn_count; int collided_before_goal; - int sampled_new_goal; - int reached_goal_this_episode; - int num_goals_reached; + float goals_reached_this_episode; + float goals_sampled_this_episode; + int current_goal_reached; int active_agent; float cumulative_displacement; int displacement_sample_count; @@ -206,12 +258,12 @@ struct Entity { float steering_angle; float wheelbase; - //population play + // population play bool is_ego; bool is_co_player; }; -void free_entity(Entity* entity){ +void free_entity(Entity *entity) { // free trajectory arrays free(entity->traj_x); free(entity->traj_y); @@ -222,22 +274,6 @@ void free_entity(Entity* entity){ free(entity->traj_heading); free(entity->traj_valid); } -struct Co_Player_Log { - float co_player_episode_return; - float co_player_episode_length; - float co_player_perf; - float co_player_score; - float co_player_offroad_rate; - float co_player_collision_rate; - float co_player_clean_collision_rate; - float co_player_num_goals_reached; - float co_player_completion_rate; - float co_player_dnf_rate; - float co_player_lane_alignment_rate; - float co_player_avg_displacement_error; - float co_player_n; -}; - // Utility functions float compute_delta_percent(float first, float last) { @@ -247,51 +283,26 @@ float compute_delta_percent(float first, float last) { return (last - first) / first * 100.0f; } -float relative_distance(float a, float b){ +float relative_distance(float a, float b) { float distance = sqrtf(powf(a - b, 2)); return distance; } -float relative_distance_2d(float x1, float y1, float x2, float y2){ +float relative_distance_2d(float x1, float y1, float x2, float y2) { float dx = x2 - x1; float dy = y2 - y1; - float distance = sqrtf(dx*dx + dy*dy); + float distance = sqrtf(dx * dx + dy * dy); return distance; } float clip(float value, float min, float max) { - if (value < min) return min; - if (value > max) return max; + if (value < min) + return min; + if (value > max) + return max; return value; } -float compute_displacement_error(Entity* agent, int timestep) { - // Check if timestep is within valid range - if (timestep < 0 || timestep >= agent->array_size) { - return 0.0f; - } - - // Check if reference trajectory is valid at this timestep - if (!agent->traj_valid[timestep]) { - return 0.0f; - } - - // Get reference position at current timestep, skip invalid ones - float ref_x = agent->traj_x[timestep]; - float ref_y = agent->traj_y[timestep]; - - if (ref_x == INVALID_POSITION || ref_y == INVALID_POSITION) { - return 0.0f; - } - - // Compute deltas: Euclidean distance between actual and reference position - float dx = agent->x - ref_x; - float dy = agent->y - ref_y; - float displacement = sqrtf(dx*dx + dy*dy); - - return displacement; -} - typedef struct GridMapEntity GridMapEntity; struct GridMapEntity { int entity_idx; @@ -308,64 +319,66 @@ struct GridMap { int grid_rows; int cell_size_x; int cell_size_y; - int* cell_entities_count; // number of entities in each cell of the GridMap - GridMapEntity** cells; // list of gridEntities in each cell of the GridMap - + int *cell_entities_count; // number of entities in each cell of the GridMap + GridMapEntity **cells; // list of gridEntities in each cell of the GridMap // Extras/Optimizations int vision_range; - int* neighbor_cache_count; // number of entities in each cells neighbor cache - GridMapEntity** neighbor_cache_entities; // preallocated array to hold neighbor entities + int *neighbor_cache_count; // number of entities in each cells neighbor cache + GridMapEntity **neighbor_cache_entities; // preallocated array to hold neighbor entities }; struct Drive { - Client* client; - float* observations; - float* actions; - float* rewards; - unsigned char* terminals; + Client *client; + float *observations; + float *actions; + float *rewards; + unsigned char *terminals; Log log; - Log* logs; + Log *logs; int num_agents; int active_agent_count; - int* active_agent_indices; + int *active_agent_indices; int action_type; int human_agent_idx; - Entity* entities; - Graph* topology_graph; + Entity *entities; int num_entities; int num_actors; int num_objects; int num_roads; int static_agent_count; - int* static_agent_indices; + int *static_agent_indices; int expert_static_agent_count; - int* expert_static_agent_indices; + int *expert_static_agent_indices; int timestep; int init_steps; int dynamics_model; - GridMap* grid_map; - int* neighbor_offsets; + GridMap *grid_map; + int *neighbor_offsets; int scenario_length; + int termination_mode; float reward_vehicle_collision; float reward_offroad_collision; - float reward_ade; - char* map_name; + char *map_name; float world_mean_x; float world_mean_y; float dt; float reward_goal; float reward_goal_post_respawn; + float reward_lane_align; + float reward_vel_align; float goal_radius; + float goal_speed; int max_controlled_agents; int logs_capacity; int goal_behavior; - char* ini_file; - char* scenario_id; + float goal_target_distance; + char *ini_file; + char *scenario_id; int collision_behavior; int offroad_behavior; int sdc_track_index; int num_tracks_to_predict; - int* tracks_to_predict_indices; + int *tracks_to_predict_indices; int init_mode; int control_mode; @@ -377,60 +390,82 @@ struct Drive { float offroad_weight_ub; float goal_weight_lb; float goal_weight_ub; - float* collision_weights; - float* offroad_weights; - float* goal_weights; + float *collision_weights; + float *offroad_weights; + float *goal_weights; // Entropy conditioning bool use_ec; float entropy_weight_lb; float entropy_weight_ub; - float* entropy_weights; + float *entropy_weights; // Discount conditioning bool use_dc; float discount_weight_lb; float discount_weight_ub; - float* discount_weights; - //fixed population play - Co_Player_Log co_player_log; - Co_Player_Log* co_player_logs; + float *discount_weights; + // fixed population play + Log co_player_log; + Log *co_player_logs; int num_co_players; int num_ego_agents; - int* co_player_ids; - int* ego_agent_ids; + int *co_player_ids; + int *ego_agent_ids; bool population_play; - - + // Rendering + int render_mode; // RENDER_OFF, RENDER_HEADLESS, or RENDER_WINDOW + char video_basename[256]; // Full mp4 basename (without ".mp4") set by Python via vec_set_video_suffix. Defaults to + // "render". }; -void add_log(Drive* env) { - +void add_log(Drive *env) { for (int i = 0; i < env->active_agent_count; i++) { - Entity* e = &env->entities[env->active_agent_indices[i]]; + Entity *e = &env->entities[env->active_agent_indices[i]]; - if (e->is_ego) { - // ALWAYS update regular logs for all ego agents - if (e->reached_goal_this_episode) - env->log.completion_rate += 1.0f; + // Common metrics for all agents + env->log.goals_reached_this_episode += e->goals_reached_this_episode; + env->log.goals_sampled_this_episode += e->goals_sampled_this_episode; + if (e->is_ego) { + // EGO agent logging int offroad = env->logs[i].offroad_rate; env->log.offroad_rate += offroad; int collided = env->logs[i].collision_rate; env->log.collision_rate += collided; - int num_goals_reached = env->logs[i].num_goals_reached; - env->log.num_goals_reached += num_goals_reached; + float offroad_per_agent = env->logs[i].offroad_per_agent; + env->log.offroad_per_agent += offroad_per_agent; + float collisions_per_agent = env->logs[i].collisions_per_agent; + env->log.collisions_per_agent += collisions_per_agent; + + float frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode; + + // Calculate threshold based on goals sampled + float threshold = 0.99f; // Default threshold for 1 goal + if (e->goals_sampled_this_episode == 2.0f) { + threshold = 0.5f; // Require ≥50% completion for 2 goals + } else if (e->goals_sampled_this_episode < 5.0f) { + threshold = 0.8f; // Require ≥80% completion for 3-4 goals + } else { + threshold = 0.9f; // Require ≥90% completion for 5+ goals + } + + // Use the "before goal" flag in respawn AND stop modes so that a + // post-goal rear-end (which the SDC can't avoid once it has stopped) + // does not deny the score award. + int collision_occurred = (env->goal_behavior == GOAL_RESPAWN || env->goal_behavior == GOAL_STOP) + ? e->collided_before_goal + : env->logs[i].collision_rate; - if (e->reached_goal_this_episode && !e->collided_before_goal) { + if (frac_goal_reached > threshold && !collision_occurred) { env->log.score += 1.0f; } - if (!offroad && !collided && !e->reached_goal_this_episode) { + if (!offroad && !collided && frac_goal_reached < 1.0f) { env->log.dnf_rate += 1.0f; } int lane_aligned = env->logs[i].lane_alignment_rate; env->log.lane_alignment_rate += lane_aligned; - float displacement_error = env->logs[i].avg_displacement_error; - env->log.avg_displacement_error += displacement_error; + env->log.speed_at_goal += env->logs[i].speed_at_goal; env->log.episode_length += env->logs[i].episode_length; env->log.episode_return += env->logs[i].episode_return; @@ -438,81 +473,89 @@ void add_log(Drive* env) { env->log.expert_static_agent_count += env->expert_static_agent_count; env->log.static_agent_count += env->static_agent_count; env->log.n += 1.0f; - } - // Process co-player agents (separate if, not else-if!) + if (e->is_co_player && env->co_player_logs != NULL) { - if (e->reached_goal_this_episode) - env->co_player_log.co_player_completion_rate += 1.0f; - - int co_offroad = env->co_player_logs[i].co_player_offroad_rate; - env->co_player_log.co_player_offroad_rate += co_offroad; - int co_collided = env->co_player_logs[i].co_player_collision_rate; - env->co_player_log.co_player_collision_rate += co_collided; - int co_num_goals_reached = env->co_player_logs[i].co_player_num_goals_reached; - env->co_player_log.co_player_num_goals_reached += co_num_goals_reached; - - env->co_player_log.co_player_clean_collision_rate += - env->co_player_logs[i].co_player_clean_collision_rate; - - if (e->reached_goal_this_episode && !e->collided_before_goal) { - env->co_player_log.co_player_score += 1.0f; - env->co_player_log.co_player_perf += 1.0f; + int co_offroad = env->co_player_logs[i].offroad_rate; + env->co_player_log.offroad_rate += co_offroad; + int co_collided = env->co_player_logs[i].collision_rate; + env->co_player_log.collision_rate += co_collided; + float co_offroad_per_agent = env->co_player_logs[i].offroad_per_agent; + env->co_player_log.offroad_per_agent += co_offroad_per_agent; + float co_collisions_per_agent = env->co_player_logs[i].collisions_per_agent; + env->co_player_log.collisions_per_agent += co_collisions_per_agent; + + float co_frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode; + + // Calculate threshold for co-players + float co_threshold = 0.99f; + if (e->goals_sampled_this_episode == 2.0f) { + co_threshold = 0.5f; + } else if (e->goals_sampled_this_episode < 5.0f) { + co_threshold = 0.8f; + } else { + co_threshold = 0.9f; } - if (!co_offroad && !co_collided && !e->reached_goal_this_episode) { - env->co_player_log.co_player_dnf_rate += 1.0f; + // Same post-goal-collision exemption as the ego score check above. + int co_collision_occurred = (env->goal_behavior == GOAL_RESPAWN || env->goal_behavior == GOAL_STOP) + ? e->collided_before_goal + : env->co_player_logs[i].collision_rate; + + if (co_frac_goal_reached > co_threshold && !co_collision_occurred) { + env->co_player_log.score += 1.0f; + } + + if (!co_offroad && !co_collided && co_frac_goal_reached < 1.0f) { + env->co_player_log.dnf_rate += 1.0f; } - int co_lane_aligned = env->co_player_logs[i].co_player_lane_alignment_rate; - env->co_player_log.co_player_lane_alignment_rate += co_lane_aligned; - float co_displacement_error = env->co_player_logs[i].co_player_avg_displacement_error; - env->co_player_log.co_player_avg_displacement_error += co_displacement_error; - env->co_player_log.co_player_episode_return += - env->co_player_logs[i].co_player_episode_return; - env->co_player_log.co_player_episode_length += - env->co_player_logs[i].co_player_episode_length; + int co_lane_aligned = env->co_player_logs[i].lane_alignment_rate; + env->co_player_log.lane_alignment_rate += co_lane_aligned; + env->co_player_log.speed_at_goal += env->co_player_logs[i].speed_at_goal; + env->co_player_log.episode_return += env->co_player_logs[i].episode_return; + env->co_player_log.episode_length += env->co_player_logs[i].episode_length; - env->co_player_log.co_player_n += 1.0f; + env->co_player_log.n += 1.0f; } } } struct AdjListNode { int dest; - struct AdjListNode* next; + struct AdjListNode *next; }; struct Graph { int V; - struct AdjListNode** array; + struct AdjListNode **array; }; // Function to create a new adjacency list node -struct AdjListNode* newAdjListNode(int dest) { - struct AdjListNode* newNode = malloc(sizeof(struct AdjListNode)); +struct AdjListNode *newAdjListNode(int dest) { + struct AdjListNode *newNode = malloc(sizeof(struct AdjListNode)); newNode->dest = dest; newNode->next = NULL; return newNode; } // Function to create a graph of V vertices -struct Graph* createGraph(int V) { - struct Graph* graph = malloc(sizeof(struct Graph)); +struct Graph *createGraph(int V) { + struct Graph *graph = malloc(sizeof(struct Graph)); graph->V = V; - graph->array = calloc(V, sizeof(struct AdjListNode*)); + graph->array = calloc(V, sizeof(struct AdjListNode *)); return graph; } // Function to get next lanes from a given lane entity index // Returns the number of next lanes found, fills next_lanes array with entity indices -int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_lanes) { +int getNextLanes(struct Graph *graph, int entity_idx, int *next_lanes, int max_lanes) { if (!graph || entity_idx < 0 || entity_idx >= graph->V) { return 0; } int count = 0; - struct AdjListNode* node = graph->array[entity_idx]; + struct AdjListNode *node = graph->array[entity_idx]; while (node && count < max_lanes) { next_lanes[count] = node->dest; @@ -524,13 +567,14 @@ int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_l } // Function to free the topology graph -void freeTopologyGraph(struct Graph* graph) { - if (!graph) return; +void freeTopologyGraph(struct Graph *graph) { + if (!graph) + return; for (int i = 0; i < graph->V; i++) { - struct AdjListNode* node = graph->array[i]; + struct AdjListNode *node = graph->array[i]; while (node) { - struct AdjListNode* temp = node; + struct AdjListNode *temp = node; node = node->next; free(temp); } @@ -540,19 +584,24 @@ void freeTopologyGraph(struct Graph* graph) { free(graph); } - -Entity* load_map_binary(const char* filename, Drive* env) { - FILE* file = fopen(filename, "rb"); - if (!file) return NULL; - +Entity *load_map_binary(const char *filename, Drive *env) { + FILE *file = fopen(filename, "rb"); + if (!file) + return NULL; // Read sdc_track_index - fread(&env->sdc_track_index, sizeof(int), 1, file); + if (fread(&env->sdc_track_index, sizeof(int), 1, file) != 1) { + fclose(file); + return NULL; + } // Read tracks_to_predict - fread(&env->num_tracks_to_predict, sizeof(int), 1, file); + if (fread(&env->num_tracks_to_predict, sizeof(int), 1, file) != 1) { + fclose(file); + return NULL; + } if (env->num_tracks_to_predict > 0) { - env->tracks_to_predict_indices = (int*)malloc(env->num_tracks_to_predict * sizeof(int)); + env->tracks_to_predict_indices = (int *)malloc(env->num_tracks_to_predict * sizeof(int)); for (int i = 0; i < env->num_tracks_to_predict; i++) { fread(&env->tracks_to_predict_indices[i], sizeof(int), 1, file); @@ -564,25 +613,38 @@ Entity* load_map_binary(const char* filename, Drive* env) { fread(&env->num_objects, sizeof(int), 1, file); fread(&env->num_roads, sizeof(int), 1, file); env->num_entities = env->num_objects + env->num_roads; - Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); + Entity *entities = (Entity *)malloc(env->num_entities * sizeof(Entity)); for (int i = 0; i < env->num_entities; i++) { - // Read base entity data - fread(&entities[i].scenario_id, sizeof(int), 1, file); - fread(&entities[i].type, sizeof(int), 1, file); - fread(&entities[i].id, sizeof(int), 1, file); - fread(&entities[i].array_size, sizeof(int), 1, file); + // Read base entity data + if (fread(&entities[i].scenario_id, sizeof(int), 1, file) != 1 || + fread(&entities[i].type, sizeof(int), 1, file) != 1 || fread(&entities[i].id, sizeof(int), 1, file) != 1 || + fread(&entities[i].array_size, sizeof(int), 1, file) != 1) { + // File truncated - adjust entity count and break + env->num_entities = i; + env->num_objects = (i < env->num_objects) ? i : env->num_objects; + env->num_roads = env->num_entities - env->num_objects; + break; + } + // Validate array_size is reasonable (max 1000 timesteps = 100s at 0.1s dt) + if (entities[i].array_size <= 0 || entities[i].array_size > 1000) { + env->num_entities = i; + env->num_objects = (i < env->num_objects) ? i : env->num_objects; + env->num_roads = env->num_entities - env->num_objects; + break; + } // Allocate arrays based on type int size = entities[i].array_size; - entities[i].traj_x = (float*)malloc(size * sizeof(float)); - entities[i].traj_y = (float*)malloc(size * sizeof(float)); - entities[i].traj_z = (float*)malloc(size * sizeof(float)); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + entities[i].traj_x = (float *)malloc(size * sizeof(float)); + entities[i].traj_y = (float *)malloc(size * sizeof(float)); + entities[i].traj_z = (float *)malloc(size * sizeof(float)); + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || + entities[i].type == CYCLIST) { // Object type // Allocate arrays for object-specific data - entities[i].traj_vx = (float*)malloc(size * sizeof(float)); - entities[i].traj_vy = (float*)malloc(size * sizeof(float)); - entities[i].traj_vz = (float*)malloc(size * sizeof(float)); - entities[i].traj_heading = (float*)malloc(size * sizeof(float)); - entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + entities[i].traj_vx = (float *)malloc(size * sizeof(float)); + entities[i].traj_vy = (float *)malloc(size * sizeof(float)); + entities[i].traj_vz = (float *)malloc(size * sizeof(float)); + entities[i].traj_heading = (float *)malloc(size * sizeof(float)); + entities[i].traj_valid = (int *)malloc(size * sizeof(int)); } else { // Roads don't use these arrays entities[i].traj_vx = NULL; @@ -595,7 +657,8 @@ Entity* load_map_binary(const char* filename, Drive* env) { fread(entities[i].traj_x, sizeof(float), size, file); fread(entities[i].traj_y, sizeof(float), size, file); fread(entities[i].traj_z, sizeof(float), size, file); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || + entities[i].type == CYCLIST) { // Object type fread(entities[i].traj_vx, sizeof(float), size, file); fread(entities[i].traj_vy, sizeof(float), size, file); fread(entities[i].traj_vz, sizeof(float), size, file); @@ -616,32 +679,31 @@ Entity* load_map_binary(const char* filename, Drive* env) { return entities; } -void set_start_position(Drive* env){ - //InitWindow(800, 600, "GPU Drive"); - //BeginDrawing(); - for(int i = 0; i < env->num_entities; i++){ +void set_start_position(Drive *env) { + for (int i = 0; i < env->num_entities; i++) { int is_active = 0; - for(int j = 0; j < env->active_agent_count; j++){ - if(env->active_agent_indices[j] == i){ + for (int j = 0; j < env->active_agent_count; j++) { + if (env->active_agent_indices[j] == i) { is_active = 1; break; } } - Entity* e = &env->entities[i]; + Entity *e = &env->entities[i]; // Clamp init_steps to ensure we don't go out of bounds int step = env->init_steps; - if (step >= e->array_size) step = e->array_size - 1; - if (step < 0) step = 0; + if (step >= e->array_size) + step = e->array_size - 1; + if (step < 0) + step = 0; e->x = e->traj_x[step]; e->y = e->traj_y[step]; e->z = e->traj_z[step]; - - if(e->type > CYCLIST || e->type == 0){ + if (e->type > CYCLIST || e->type == 0) { continue; } - if(is_active == 0){ + if (is_active == 0) { e->vx = 0; e->vy = 0; e->vz = 0; @@ -656,15 +718,16 @@ void set_start_position(Drive* env){ e->heading_y = sinf(e->heading); e->valid = e->traj_valid[env->init_steps]; e->collision_state = 0; - e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision - e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad - e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal - e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned - e->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; // avg displacement error - e->cumulative_displacement = 0.0f; - e->displacement_sample_count = 0; + e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision + e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad + e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal + e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned + e->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // far from lane + e->metrics_array[LANE_ANGLE_IDX] = 0.0f; // no alignment + e->current_lane_idx = -1; + e->current_lane_geometry_idx = -1; e->respawn_timestep = -1; - e->stopped = 0; + e->stopped = 0; e->removed = 0; e->respawn_count = 0; @@ -676,33 +739,35 @@ void set_start_position(Drive* env){ e->steering_angle = 0.0f; e->wheelbase = 0.6f * e->length; } - //EndDrawing(); } -int getGridIndex(Drive* env, float x1, float y1) { - if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || env->grid_map->bottom_right_y >= env->grid_map->top_left_y) { - return -1; // Invalid grid coordinates +int getGridIndex(Drive *env, float x1, float y1) { + if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || + env->grid_map->bottom_right_y >= env->grid_map->top_left_y) { + return -1; // Invalid grid coordinates } - float relativeX = x1 - env->grid_map->top_left_x; // Distance from left - float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom - int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index - int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index + float relativeX = x1 - env->grid_map->top_left_x; // Distance from left + float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom + int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index + int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index if (gridX < 0 || gridX >= env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) { - return -1; // Return -1 for out of bounds + return -1; // Return -1 for out of bounds } - int index = (gridY*env->grid_map->grid_cols) + gridX; + int index = (gridY * env->grid_map->grid_cols) + gridX; return index; } -void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry_idx, int* cell_entities_insert_index){ - if(grid_index == -1){ +void add_entity_to_grid(Drive *env, int grid_index, int entity_idx, int geometry_idx, int *cell_entities_insert_index) { + if (grid_index == -1) { return; } int count = cell_entities_insert_index[grid_index]; - if(count >= env->grid_map->cell_entities_count[grid_index]) { - printf("Error: Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): %d\n", grid_index, count, env->grid_map->cell_entities_count[grid_index]); + if (count >= env->grid_map->cell_entities_count[grid_index]) { + printf("Error: Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): " + "%d\n", + grid_index, count, env->grid_map->cell_entities_count[grid_index]); return; } @@ -711,72 +776,9 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry cell_entities_insert_index[grid_index] = count + 1; } - -void init_topology_graph(Drive* env){ - // Count ROAD_LANE entities - int road_lane_count = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type == ROAD_LANE){ - road_lane_count++; - } - } - - if(road_lane_count == 0){ - env->topology_graph = NULL; - return; - } - - // Create graph with all entities as vertices (we'll only use ROAD_LANE indices) - env->topology_graph = createGraph(env->num_entities); - - // Connect ROAD_LANE entities based on geometric connectivity - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type != ROAD_LANE) continue; - - Entity* lane_i = &env->entities[i]; - if(lane_i->array_size < 2) continue; // Need at least 2 points - - // Get end point of current lane - float end_x = lane_i->traj_x[lane_i->array_size - 1]; - float end_y = lane_i->traj_y[lane_i->array_size - 1]; - float end_vector_x = lane_i->traj_x[lane_i->array_size - 1] - lane_i->traj_x[lane_i->array_size - 2]; - float end_vector_y = lane_i->traj_y[lane_i->array_size - 1] - lane_i->traj_y[lane_i->array_size - 2]; - float end_heading = atan2f(end_vector_y, end_vector_x); - - // Find lanes that start near this lane's end - for(int j = 0; j < env->num_entities; j++){ - if(i == j || env->entities[j].type != ROAD_LANE) continue; - - Entity* lane_j = &env->entities[j]; - if(lane_j->array_size < 2) continue; - - // Get start point of potential next lane - float start_x = lane_j->traj_x[0]; - float start_y = lane_j->traj_y[0]; - float start_vector_x = lane_j->traj_x[1] - lane_j->traj_x[0]; - float start_vector_y = lane_j->traj_y[1] - lane_j->traj_y[0]; - float start_heading = atan2f(start_vector_y, start_vector_x); - - // Check if end of lane_i is close to start of lane_j - float distance = relative_distance_2d(end_x, end_y, start_x, start_y); - float heading_diff = fabsf(end_heading - start_heading); - - // Lane connectivity thresholds: - // - 0.01m distance: lanes must connect within 1cm (very strict for clean topology) - // - 0.1 (~5.7 degrees) heading difference: allow slight curves - if(distance < 0.01f && heading_diff < 0.1f){ - // Add directed edge from i to j (lane i connects to lane j) - struct AdjListNode* node = newAdjListNode(j); - node->next = env->topology_graph->array[i]; - env->topology_graph->array[i] = node; - } - } - } -} - -void init_grid_map(Drive* env){ +void init_grid_map(Drive *env) { // Allocate memory for the grid map structure - env->grid_map = (GridMap*)malloc(sizeof(GridMap)); + env->grid_map = (GridMap *)malloc(sizeof(GridMap)); // Find top left and bottom right points of the map float top_left_x; @@ -784,23 +786,29 @@ void init_grid_map(Drive* env){ float bottom_right_x; float bottom_right_y; int first_valid_point = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && env->entities[i].type < 7) { // Check all points in the trajectory for road elements - Entity* e = &env->entities[i]; - for(int j = 0; j < e->array_size; j++){ - if(e->traj_x[j] == INVALID_POSITION) continue; - if(e->traj_y[j] == INVALID_POSITION) continue; - if(!first_valid_point) { + Entity *e = &env->entities[i]; + for (int j = 0; j < e->array_size; j++) { + if (e->traj_x[j] == INVALID_POSITION) + continue; + if (e->traj_y[j] == INVALID_POSITION) + continue; + if (!first_valid_point) { top_left_x = bottom_right_x = e->traj_x[j]; top_left_y = bottom_right_y = e->traj_y[j]; first_valid_point = true; continue; } - if(e->traj_x[j] < top_left_x) top_left_x = e->traj_x[j]; - if(e->traj_x[j] > bottom_right_x) bottom_right_x = e->traj_x[j]; - if(e->traj_y[j] > top_left_y) top_left_y = e->traj_y[j]; - if(e->traj_y[j] < bottom_right_y) bottom_right_y = e->traj_y[j]; + if (e->traj_x[j] < top_left_x) + top_left_x = e->traj_x[j]; + if (e->traj_x[j] > bottom_right_x) + bottom_right_x = e->traj_x[j]; + if (e->traj_y[j] > top_left_y) + top_left_y = e->traj_y[j]; + if (e->traj_y[j] < bottom_right_y) + bottom_right_y = e->traj_y[j]; } } } @@ -817,66 +825,71 @@ void init_grid_map(Drive* env){ float grid_height = top_left_y - bottom_right_y; env->grid_map->grid_cols = ceil(grid_width / GRID_CELL_SIZE); env->grid_map->grid_rows = ceil(grid_height / GRID_CELL_SIZE); - int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - env->grid_map->cells = (GridMapEntity**)calloc(grid_cell_count, sizeof(GridMapEntity*)); - env->grid_map->cell_entities_count = (int*)calloc(grid_cell_count, sizeof(int)); + int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + env->grid_map->cells = (GridMapEntity **)calloc(grid_cell_count, sizeof(GridMapEntity *)); + env->grid_map->cell_entities_count = (int *)calloc(grid_cell_count, sizeof(int)); // Calculate number of entities in each grid cell - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && env->entities[i].type < 7) { + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2; + float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); - env->grid_map->cell_entities_count[grid_index]++; + if (grid_index != -1) { + env->grid_map->cell_entities_count[grid_index]++; + } } } } - int cell_entities_insert_index[grid_cell_count]; // Helper array for insertion index - memset(cell_entities_insert_index, 0, grid_cell_count * sizeof(int)); + // Use heap allocation instead of VLA to avoid stack overflow on large maps + int *cell_entities_insert_index = (int *)calloc(grid_cell_count, sizeof(int)); // Initialize grid cells - for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){ - env->grid_map->cells[grid_index] = (GridMapEntity*)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); + for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) { + env->grid_map->cells[grid_index] = + (GridMapEntity *)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); } - for(int i = 0;inum_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ // NOTE: Only Road Edges, Lines, and Lanes in grid map - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && + env->entities[i].type < 7) { // NOTE: Only Road Edges, Lines, and Lanes in grid map + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2; + float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); } } } + free(cell_entities_insert_index); } -void init_neighbor_offsets(Drive* env) { +void init_neighbor_offsets(Drive *env) { // Allocate memory for the offsets - env->neighbor_offsets = (int*)calloc(env->grid_map->vision_range*env->grid_map->vision_range*2, sizeof(int)); + env->neighbor_offsets = (int *)calloc(env->grid_map->vision_range * env->grid_map->vision_range * 2, sizeof(int)); // neighbor offsets in a spiral pattern int dx[] = {1, 0, -1, 0}; int dy[] = {0, 1, 0, -1}; - int x = 0; // Current x offset - int y = 0; // Current y offset - int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down) - int steps_to_take = 1; // Number of steps in current direction - int steps_taken = 0; // Steps taken in current direction + int x = 0; // Current x offset + int y = 0; // Current y offset + int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down) + int steps_to_take = 1; // Number of steps in current direction + int steps_taken = 0; // Steps taken in current direction int segments_completed = 0; // Count of direction segments completed - int total = 0; // Total offsets added - int max_offsets = env->grid_map->vision_range*env->grid_map->vision_range; + int total = 0; // Total offsets added + int max_offsets = env->grid_map->vision_range * env->grid_map->vision_range; // Start at center (0,0) int curr_idx = 0; - env->neighbor_offsets[curr_idx++] = 0; // x offset - env->neighbor_offsets[curr_idx++] = 0; // y offset + env->neighbor_offsets[curr_idx++] = 0; // x offset + env->neighbor_offsets[curr_idx++] = 0; // y offset total++; // Generate spiral pattern while (total < max_offsets) { @@ -884,16 +897,17 @@ void init_neighbor_offsets(Drive* env) { x += dx[dir]; y += dy[dir]; // Only add if within vision range bounds - if (abs(x) <= env->grid_map->vision_range/2 && abs(y) <= env->grid_map->vision_range/2) { + if (abs(x) <= env->grid_map->vision_range / 2 && abs(y) <= env->grid_map->vision_range / 2) { env->neighbor_offsets[curr_idx++] = x; env->neighbor_offsets[curr_idx++] = y; total++; } steps_taken++; // Check if we need to change direction - if(steps_taken != steps_to_take) continue; - steps_taken = 0; // Reset steps taken - dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down) + if (steps_taken != steps_to_take) + continue; + steps_taken = 0; // Reset steps taken + dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down) segments_completed++; // Increase step length every two direction changes if (segments_completed % 2 == 0) { @@ -902,65 +916,64 @@ void init_neighbor_offsets(Drive* env) { } } -void cache_neighbor_offsets(Drive* env){ +void cache_neighbor_offsets(Drive *env) { int count = 0; - int cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - env->grid_map->neighbor_cache_entities = (GridMapEntity**)calloc(cell_count, sizeof(GridMapEntity*)); - env->grid_map->neighbor_cache_count = (int*)calloc(cell_count + 1, sizeof(int)); - for(int i = 0; i < cell_count; i++){ - int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates + int cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + env->grid_map->neighbor_cache_entities = (GridMapEntity **)calloc(cell_count, sizeof(GridMapEntity *)); + env->grid_map->neighbor_cache_count = (int *)calloc(cell_count + 1, sizeof(int)); + for (int i = 0; i < cell_count; i++) { + int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates int cell_y = i / env->grid_map->grid_cols; int current_cell_neighbor_count = 0; - for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){ - int x = cell_x + env->neighbor_offsets[j*2]; - int y = cell_y + env->neighbor_offsets[j*2+1]; - int grid_index = env->grid_map->grid_cols*y + x; - if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue; + for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) { + int x = cell_x + env->neighbor_offsets[j * 2]; + int y = cell_y + env->neighbor_offsets[j * 2 + 1]; + int grid_index = env->grid_map->grid_cols * y + x; + if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) + continue; int grid_count = env->grid_map->cell_entities_count[grid_index]; current_cell_neighbor_count += grid_count; } env->grid_map->neighbor_cache_count[i] = current_cell_neighbor_count; count += current_cell_neighbor_count; - if(current_cell_neighbor_count == 0) { + if (current_cell_neighbor_count == 0) { env->grid_map->neighbor_cache_entities[i] = NULL; continue; } - env->grid_map->neighbor_cache_entities[i] = (GridMapEntity*)calloc(current_cell_neighbor_count, sizeof(GridMapEntity)); + env->grid_map->neighbor_cache_entities[i] = + (GridMapEntity *)calloc(current_cell_neighbor_count, sizeof(GridMapEntity)); } env->grid_map->neighbor_cache_count[cell_count] = count; - for(int i = 0; i < cell_count; i ++){ - int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates + for (int i = 0; i < cell_count; i++) { + int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates int cell_y = i / env->grid_map->grid_cols; int base_index = 0; - for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){ - int x = cell_x + env->neighbor_offsets[j*2]; - int y = cell_y + env->neighbor_offsets[j*2+1]; - int grid_index = env->grid_map->grid_cols*y + x; - if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue; + for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) { + int x = cell_x + env->neighbor_offsets[j * 2]; + int y = cell_y + env->neighbor_offsets[j * 2 + 1]; + int grid_index = env->grid_map->grid_cols * y + x; + if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) + continue; int grid_count = env->grid_map->cell_entities_count[grid_index]; // Skip if no entities or source is NULL - if(grid_count == 0 || env->grid_map->cells[grid_index] == NULL) { + if (grid_count == 0 || env->grid_map->cells[grid_index] == NULL) { continue; } int src_idx = grid_index; int dst_idx = base_index; // Copy grid_count pairs (entity_idx, geometry_idx) at once - memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx], - env->grid_map->cells[src_idx], - grid_count * sizeof(GridMapEntity)); - // for(int k = 0; k < grid_count; k++){ - // env->grid_map->neighbor_cache_entities[i][dst_idx + k] = env->grid_map->cells[src_idx][k]; - // } + memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx], env->grid_map->cells[src_idx], + grid_count * sizeof(GridMapEntity)); base_index += grid_count; } } } -int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entities, int max_entities) { - GridMap* grid_map = env->grid_map; +int get_neighbor_cache_entities(Drive *env, int cell_idx, GridMapEntity *entities, int max_entities) { + GridMap *grid_map = env->grid_map; if (cell_idx < 0 || cell_idx >= (grid_map->grid_cols * grid_map->grid_rows)) { return 0; // Invalid cell index } @@ -974,14 +987,15 @@ int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entitie return count; } -void set_means(Drive* env) { +void set_means(Drive *env) { float mean_x = 0.0f; float mean_y = 0.0f; int64_t point_count = 0; // Compute single mean for all entities (vehicles and roads) for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST) { for (int j = 0; j < env->entities[i].array_size; j++) { // Assume a validity flag exists (e.g., valid[j]); adjust if not available if (env->entities[i].traj_valid[j]) { // Add validity check if applicable @@ -1001,9 +1015,11 @@ void set_means(Drive* env) { env->world_mean_x = mean_x; env->world_mean_y = mean_y; for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { for (int j = 0; j < env->entities[i].array_size; j++) { - if(env->entities[i].traj_x[j] == INVALID_POSITION) continue; + if (env->entities[i].traj_x[j] == INVALID_POSITION) + continue; env->entities[i].traj_x[j] -= mean_x; env->entities[i].traj_y[j] -= mean_y; } @@ -1011,11 +1027,10 @@ void set_means(Drive* env) { env->entities[i].goal_position_y -= mean_y; } } - } -void move_expert(Drive* env, float* actions, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; +void move_expert(Drive *env, float *actions, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; int t = env->timestep; if (t < 0 || t >= agent->array_size) { agent->x = INVALID_POSITION; @@ -1058,7 +1073,8 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) float cross = dx1 * dy2 - dy1 * dx2; // If lines are parallel - if (cross == 0) return false; + if (cross == 0) + return false; // Calculate relative vectors between start points float dx3 = p1[0] - q1[0]; @@ -1072,10 +1088,12 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) return (s >= 0 && s <= 1 && t >= 0 && t <= 1); } -int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int max_size, const int (*local_offsets)[2], int offset_size) { +int checkNeighbors(Drive *env, float x, float y, GridMapEntity *entity_list, int max_size, + const int (*local_offsets)[2], int offset_size) { // Get the grid index for the given position (x, y) int index = getGridIndex(env, x, y); - if (index == -1) return 0; // Return 0 size if position invalid + if (index == -1) + return 0; // Return 0 size if position invalid // Calculate 2D grid coordinates int cellsX = env->grid_map->grid_cols; int gridX = index % cellsX; @@ -1086,7 +1104,8 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int int nx = gridX + local_offsets[i][0]; int ny = gridY + local_offsets[i][1]; // Ensure the neighbor is within grid bounds - if(nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) continue; + if (nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) + continue; int neighborIndex = ny * env->grid_map->grid_cols + nx; int count = env->grid_map->cell_entities_count[neighborIndex]; // Add entities from this cell to the list @@ -1101,7 +1120,7 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int return entity_list_count; } -int check_aabb_collision(Entity* car1, Entity* car2) { +int check_aabb_collision(Entity *car1, Entity *car2) { // Get car corners in world space float cos1 = car1->heading_x; float sin1 = car1->heading_y; @@ -1119,79 +1138,83 @@ int check_aabb_collision(Entity* car1, Entity* car2) { {car1->x + (half_len1 * cos1 - half_width1 * sin1), car1->y + (half_len1 * sin1 + half_width1 * cos1)}, {car1->x + (half_len1 * cos1 + half_width1 * sin1), car1->y + (half_len1 * sin1 - half_width1 * cos1)}, {car1->x + (-half_len1 * cos1 - half_width1 * sin1), car1->y + (-half_len1 * sin1 + half_width1 * cos1)}, - {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)} - }; + {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)}}; // Calculate car2's corners in world space float car2_corners[4][2] = { {car2->x + (half_len2 * cos2 - half_width2 * sin2), car2->y + (half_len2 * sin2 + half_width2 * cos2)}, {car2->x + (half_len2 * cos2 + half_width2 * sin2), car2->y + (half_len2 * sin2 - half_width2 * cos2)}, {car2->x + (-half_len2 * cos2 - half_width2 * sin2), car2->y + (-half_len2 * sin2 + half_width2 * cos2)}, - {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)} - }; + {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)}}; // Get the axes to check (normalized vectors perpendicular to each edge) float axes[4][2] = { - {cos1, sin1}, // Car1's length axis - {-sin1, cos1}, // Car1's width axis - {cos2, sin2}, // Car2's length axis - {-sin2, cos2} // Car2's width axis + {cos1, sin1}, // Car1's length axis + {-sin1, cos1}, // Car1's width axis + {cos2, sin2}, // Car2's length axis + {-sin2, cos2} // Car2's width axis }; // Check each axis - for(int i = 0; i < 4; i++) { + for (int i = 0; i < 4; i++) { float min1 = INFINITY, max1 = -INFINITY; float min2 = INFINITY, max2 = -INFINITY; // Project car1's corners onto the axis - for(int j = 0; j < 4; j++) { + for (int j = 0; j < 4; j++) { float proj = car1_corners[j][0] * axes[i][0] + car1_corners[j][1] * axes[i][1]; min1 = fminf(min1, proj); max1 = fmaxf(max1, proj); } // Project car2's corners onto the axis - for(int j = 0; j < 4; j++) { + for (int j = 0; j < 4; j++) { float proj = car2_corners[j][0] * axes[i][0] + car2_corners[j][1] * axes[i][1]; min2 = fminf(min2, proj); max2 = fmaxf(max2, proj); } // If there's a gap on this axis, the boxes don't intersect - if(max1 < min2 || min1 > max2) { - return 0; // No collision + if (max1 < min2 || min1 > max2) { + return 0; // No collision } } // If we get here, there's no separating axis, so the boxes intersect - return 1; // Collision + return 1; // Collision } -int collision_check(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; +int collision_check(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; - if(agent->x == INVALID_POSITION ) return -1; + if (agent->x == INVALID_POSITION) + return -1; int car_collided_with_index = -1; - if (agent->respawn_timestep != -1) return car_collided_with_index; // Skip respawning entities + if (agent->respawn_timestep != -1) + return car_collided_with_index; // Skip respawning entities - for(int i = 0; i < MAX_AGENTS; i++){ + for (int i = 0; i < MAX_AGENTS; i++) { int index = -1; - if(i < env->active_agent_count){ + if (i < env->active_agent_count) { index = env->active_agent_indices[i]; - } else if (i < env->num_actors){ + } else if (i < env->num_actors) { index = env->static_agent_indices[i - env->active_agent_count]; } - if(index == -1) continue; - if(index == agent_idx) continue; - Entity* entity = &env->entities[index]; - if (entity->respawn_timestep != -1) continue; // Skip respawning entities + if (index == -1) + continue; + if (index == agent_idx) + continue; + Entity *entity = &env->entities[index]; + if (entity->respawn_timestep != -1) + continue; // Skip respawning entities float x1 = entity->x; float y1 = entity->y; - float dist = ((x1 - agent->x)*(x1 - agent->x) + (y1 - agent->y)*(y1 - agent->y)); - if(dist > 225.0f) continue; - if(check_aabb_collision(agent, entity)) { + float dist = ((x1 - agent->x) * (x1 - agent->x) + (y1 - agent->y) * (y1 - agent->y)); + if (dist > 225.0f) + continue; + if (check_aabb_collision(agent, entity)) { car_collided_with_index = index; break; } @@ -1200,13 +1223,127 @@ int collision_check(Drive* env, int agent_idx) { return car_collided_with_index; } -int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { +// ============================================================================ +// Lane Alignment Helper Functions (from GIGAFLOW / PufferDrive 3.0) +// ============================================================================ + +// Normalize heading to [-π, π] +float normalize_heading(float h) { + while (h > M_PI) + h -= 2.0f * M_PI; + while (h < -M_PI) + h += 2.0f * M_PI; + return h; +} + +// Compute signed heading difference between agent and lane, normalized to [-π, π] +float compute_heading_diff(float agent_heading, float lane_heading) { + float diff = normalize_heading(agent_heading - lane_heading); + return diff; +} + +// Find closest segment on lane polyline, return signed lateral distance +// Positive = right of lane center, Negative = left of lane center +float find_closest_segment_on_lane(Entity *lane, float px, float py, int *segment_idx) { + if (!lane || lane->array_size < 2) { + *segment_idx = 0; + return LANE_DISTANCE_NORMALIZATION; + } + + float min_dist = LANE_DISTANCE_NORMALIZATION * 10.0f; + int best_idx = 0; + float best_signed_dist = 0.0f; + + for (int i = 0; i < lane->array_size - 1; i++) { + float x1 = lane->traj_x[i]; + float y1 = lane->traj_y[i]; + float x2 = lane->traj_x[i + 1]; + float y2 = lane->traj_y[i + 1]; + + float dx = x2 - x1; + float dy = y2 - y1; + float seg_len_sq = dx * dx + dy * dy; + + if (seg_len_sq < 1e-6f) + continue; + + // Project point onto segment + float t = ((px - x1) * dx + (py - y1) * dy) / seg_len_sq; + t = fmaxf(0.0f, fminf(1.0f, t)); + + float closest_x = x1 + t * dx; + float closest_y = y1 + t * dy; + + float dist = sqrtf((px - closest_x) * (px - closest_x) + (py - closest_y) * (py - closest_y)); + + if (dist < min_dist) { + min_dist = dist; + best_idx = i; + + // Compute signed distance (cross product gives sign) + // Positive if point is to the right of the lane direction + float cross = dx * (py - y1) - dy * (px - x1); + best_signed_dist = (cross >= 0) ? dist : -dist; + } + } + + *segment_idx = best_idx; + return best_signed_dist; +} + +// Compute lane heading using weighted average of neighboring segments +float compute_multi_segment_alignment(Entity *lane, int segment_idx) { + if (!lane || lane->array_size < 2) + return 0.0f; + + // Clamp to valid range + if (segment_idx < 0) + segment_idx = 0; + if (segment_idx >= lane->array_size - 1) + segment_idx = lane->array_size - 2; + + float sum_heading = 0.0f; + float sum_weight = 0.0f; + + // Consider current segment and neighbors + for (int offset = -1; offset <= 1; offset++) { + int idx = segment_idx + offset; + if (idx < 0 || idx >= lane->array_size - 1) + continue; + + float dx = lane->traj_x[idx + 1] - lane->traj_x[idx]; + float dy = lane->traj_y[idx + 1] - lane->traj_y[idx]; + float heading = atan2f(dy, dx); + + // Weight by distance from main segment (center segment has highest weight) + float weight = (offset == 0) ? 2.0f : 1.0f; + + // Handle angle wrapping for averaging + if (sum_weight > 0) { + float diff = heading - (sum_heading / sum_weight); + if (diff > M_PI) + heading -= 2.0f * M_PI; + else if (diff < -M_PI) + heading += 2.0f * M_PI; + } + + sum_heading += weight * heading; + sum_weight += weight; + } + + return (sum_weight > 0) ? normalize_heading(sum_heading / sum_weight) : 0.0f; +} + +int check_lane_aligned(Entity *car, Entity *lane, int geometry_idx) { // Validate lane geometry length - if (!lane || lane->array_size < 2) return 0; + if (!lane || lane->array_size < 2) + return 0; // Clamp geometry index to valid segment range [0, array_size-2] - if (geometry_idx < 0) geometry_idx = 0; - if (geometry_idx >= lane->array_size - 1) geometry_idx = lane->array_size - 2; + if (geometry_idx < 0) + geometry_idx = 0; + if (geometry_idx >= lane->array_size - 1) + geometry_idx = lane->array_size - 2; // Compute local lane segment heading float heading_x1, heading_y1; @@ -1227,26 +1364,31 @@ int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { float heading = (heading_1 + heading_2) / 2.0f; // Normalize to [-pi, pi] - if (heading > M_PI) heading -= 2.0f * M_PI; - if (heading < -M_PI) heading += 2.0f * M_PI; + if (heading > M_PI) + heading -= 2.0f * M_PI; + if (heading < -M_PI) + heading += 2.0f * M_PI; // Compute heading difference float car_heading = car->heading; // radians float heading_diff = fabsf(car_heading - heading); - if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; + if (heading_diff > M_PI) + heading_diff = 2.0f * M_PI - heading_diff; // within 15 degrees return (heading_diff < (M_PI / 12.0f)) ? 1 : 0; } -void reset_agent_metrics(Drive* env, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; - agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision - agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad - agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned - agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; +void reset_agent_metrics(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; + agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision + agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad + agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned + agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // far from lane + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // no alignment agent->collision_state = 0; + agent->current_lane_geometry_idx = -1; } float point_to_segment_distance_2d(float px, float py, float x1, float y1, float x2, float y2) { @@ -1262,8 +1404,10 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float float t = ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy); // Clamp t to the segment - if (t < 0) t = 0; - else if (t > 1) t = 1; + if (t < 0) + t = 0; + else if (t > 1) + t = 1; // Find the closest point on the segment float closestX = x1 + t * dx; @@ -1273,28 +1417,17 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float return sqrtf((px - closestX) * (px - closestX) + (py - closestY) * (py - closestY)); } -void compute_agent_metrics(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; +void compute_agent_metrics(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; reset_agent_metrics(env, agent_idx); - if(agent->x == INVALID_POSITION ) return; // invalid agent position - - // Compute displacement error - float displacement_error = compute_displacement_error(agent, env->timestep); - - if (displacement_error > 0.0f) { // Only count valid displacements - agent->cumulative_displacement += displacement_error; - agent->displacement_sample_count++; - - // Compute running average - agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = - agent->cumulative_displacement / agent->displacement_sample_count; - } + if (agent->x == INVALID_POSITION) + return; // invalid agent position int collided = 0; - float half_length = agent->length/2.0f; - float half_width = agent->width/2.0f; + float half_length = agent->length / 2.0f; + float half_width = agent->width / 2.0f; float cos_heading = cosf(agent->heading); float sin_heading = sinf(agent->heading); float min_distance = (float)INT16_MAX; @@ -1304,20 +1437,30 @@ void compute_agent_metrics(Drive* env, int agent_idx) { float corners[4][2]; for (int i = 0; i < 4; i++) { - corners[i][0] = agent->x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading); - corners[i][1] = agent->y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); + corners[i][0] = + agent->x + (offsets[i][0] * half_length * cos_heading - offsets[i][1] * half_width * sin_heading); + corners[i][1] = + agent->y + (offsets[i][0] * half_length * sin_heading + offsets[i][1] * half_width * cos_heading); } - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; // Array big enough for all neighboring cells - int list_size = checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25); - for (int i = 0; i < list_size ; i++) { - if(entity_list[i].entity_idx == -1) continue; - if(entity_list[i].entity_idx == agent_idx) continue; - Entity* entity; + GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25]; // Array big enough for all neighboring cells + int list_size = + checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL * 25, collision_offsets, 25); + + // Track checked lanes to avoid duplicate processing (reset before loop) + int checked_lanes[MAX_CHECKED_LANES]; + int num_checked = 0; + + for (int i = 0; i < list_size; i++) { + if (entity_list[i].entity_idx == -1) + continue; + if (entity_list[i].entity_idx == agent_idx) + continue; + Entity *entity; entity = &env->entities[entity_list[i].entity_idx]; // Check for offroad collision with road edges - if(entity->type == ROAD_EDGE) { + if (entity->type == ROAD_EDGE) { int geometry_idx = entity_list[i].geometry_idx; float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; @@ -1330,90 +1473,158 @@ void compute_agent_metrics(Drive* env, int agent_idx) { } } - if (collided == OFFROAD) break; + if (collided == OFFROAD) + break; - // Find closest point on the road centerline to the agent - if(entity->type == ROAD_LANE) { + // Find closest lane using GIGAFLOW-style scoring (distance + heading) + if (entity->type == ROAD_LANE) { int entity_idx = entity_list[i].entity_idx; - int geometry_idx = entity_list[i].geometry_idx; - float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; - float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; + // Skip if already checked this lane + int already_checked = 0; + for (int c = 0; c < num_checked; c++) { + if (checked_lanes[c] == entity_idx) { + already_checked = 1; + break; + } + } + if (already_checked) + continue; + if (num_checked < MAX_CHECKED_LANES) + checked_lanes[num_checked++] = entity_idx; + + // Find closest segment on this lane (returns signed distance) + int segment_idx; + float signed_dist = find_closest_segment_on_lane(entity, agent->x, agent->y, &segment_idx); + float abs_dist = fabsf(signed_dist); + if (abs_dist > LANE_DISTANCE_NORMALIZATION) + continue; + + // Compute lane heading using multi-segment alignment + float lane_heading = compute_multi_segment_alignment(entity, segment_idx); + + // Compute heading alignment penalty (0.0 = perfect, 1.0 = opposite) + float heading_diff = fabsf(compute_heading_diff(agent->heading, lane_heading)); + float heading_penalty = heading_diff / M_PI; - float dist = point_to_segment_distance_2d(agent->x, agent->y, start[0], start[1], end[0], end[1]); - float heading_diff = fabsf(atan2f(end[1]-start[1], end[0]-start[0]) - agent->heading); + // Normalize distance for scoring + float distance_penalty = abs_dist / LANE_DISTANCE_NORMALIZATION; - // Normalize heading difference to [0, pi] - if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; + // Combined score using defined weights + float score = + LANE_SELECTION_DISTANCE_WEIGHT * distance_penalty + LANE_SELECTION_HEADING_WEIGHT * heading_penalty; - // Penalize if heading differs by more than 30 degrees - if (heading_diff > (M_PI / 6.0f)) dist += 3.0f; + // Hysteresis: penalize switching away from current lane + if (agent->current_lane_idx != entity_idx && agent->current_lane_idx != -1) { + score += LANE_SWITCH_THRESHOLD; + } - if (dist < min_distance) { - min_distance = dist; + // Track best candidate + if (score < min_distance) { // Using min_distance to store best score + min_distance = score; closest_lane_entity_idx = entity_idx; - closest_lane_geometry_idx = geometry_idx; + closest_lane_geometry_idx = segment_idx; } } } - // check if aligned with closest lane and set current lane - // 4.0m threshold: agents more than 4 meters from any lane are considered off-road - if (min_distance > 4.0f || closest_lane_entity_idx == -1) { - agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; - agent->current_lane_idx = -1; - } else { + // Update lane metrics using GIGAFLOW Frenet coordinates + if (closest_lane_entity_idx != -1) { + Entity *lane = &env->entities[closest_lane_entity_idx]; + + // Recompute signed distance and heading for best lane + int segment_idx; + float signed_dist = find_closest_segment_on_lane(lane, agent->x, agent->y, &segment_idx); + float lane_heading = compute_multi_segment_alignment(lane, segment_idx); + float theta_f = compute_heading_diff(agent->heading, lane_heading); + + // Store lane metrics agent->current_lane_idx = closest_lane_entity_idx; + agent->current_lane_geometry_idx = closest_lane_geometry_idx; + agent->metrics_array[LANE_DIST_IDX] = signed_dist; + agent->metrics_array[LANE_ANGLE_IDX] = cosf(theta_f); - int lane_aligned = check_lane_aligned(agent, &env->entities[closest_lane_entity_idx], closest_lane_geometry_idx); + // Binary lane aligned flag (within ~15 degrees) + int lane_aligned = (fabsf(agent->metrics_array[LANE_ANGLE_IDX]) > LANE_ALIGN_COS_THRESHOLD) ? 1 : 0; agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned; + } else { + // Not on any lane + agent->current_lane_idx = -1; + agent->current_lane_geometry_idx = -1; + agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; + agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; + agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; } // Check for vehicle collisions int car_collided_with_index = collision_check(env, agent_idx); - if (car_collided_with_index != -1) collided = VEHICLE_COLLISION; + if (car_collided_with_index != -1) + collided = VEHICLE_COLLISION; agent->collision_state = collided; + if (collided == VEHICLE_COLLISION) { + if (env->collision_behavior == STOP_AGENT && !agent->stopped) { + agent->stopped = 1; + agent->vx = agent->vy = 0.0f; + } else if (env->collision_behavior == REMOVE_AGENT && !agent->removed) { + Entity *agent_collided = &env->entities[car_collided_with_index]; + agent->removed = 1; + agent_collided->removed = 1; + agent->x = agent->y = -10000.0f; + agent_collided->x = agent_collided->y = -10000.0f; + } + } + if (collided == OFFROAD) { + agent->metrics_array[OFFROAD_IDX] = 1.0f; + if (env->offroad_behavior == STOP_AGENT && !agent->stopped) { + agent->stopped = 1; + agent->vx = agent->vy = 0.0f; + } else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) { + agent->removed = 1; + agent->x = agent->y = -10000.0f; + } + } + return; } -bool should_control_agent(Drive* env, int agent_idx){ - +bool should_control_agent(Drive *env, int agent_idx) { // Check if we have room for more agents or are already at capacity if (env->active_agent_count >= env->num_agents) { return false; } - Entity* entity = &env->entities[agent_idx]; + Entity *entity = &env->entities[agent_idx]; - // Shrink agent size for collision checking - entity->width *= 0.7f; // TODO: Move this somewhere else + // TODO: Move this elsewhere or remove + entity->width *= 0.7f; entity->length *= 0.7f; if (env->control_mode == CONTROL_SDC_ONLY) { - return (agent_idx == env->sdc_track_index); + return agent_idx == env->sdc_track_index; } - // Special mode: control only agents in prediction track list - if (env->control_mode == CONTROL_TRACKS_TO_PREDICT) { - for (int j = 0; j < env->num_tracks_to_predict; j++) { - if (env->tracks_to_predict_indices[j] == agent_idx) { - return true; - } - } - return false; - } + bool is_vehicle = (entity->type == VEHICLE); + bool is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST); + bool type_is_valid = false; + + switch (env->control_mode) { + case CONTROL_WOSAC: + // Valid types only, ignore expert flag and goal distance + return (is_vehicle || is_ped_or_bike); + + case CONTROL_VEHICLES: + type_is_valid = is_vehicle; + break; - // Standard mode: check type, distance to goal, and expert status - bool type_is_controllable = false; - if (env->control_mode == CONTROL_VEHICLES) { - type_is_controllable = (entity->type == VEHICLE); - } else { // CONTROL_AGENTS mode - type_is_controllable = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); + default: + type_is_valid = (is_vehicle || is_ped_or_bike); + break; } - if (!type_is_controllable || entity->mark_as_expert) { + // Filter invalid types or experts + if (!type_is_valid || entity->mark_as_expert) { return false; } @@ -1431,26 +1642,24 @@ bool should_control_agent(Drive* env, int agent_idx){ return distance_to_goal >= MIN_DISTANCE_TO_GOAL; } -void set_active_agents(Drive* env){ - +void set_active_agents(Drive *env) { // Initialize - env->active_agent_count = 0; // Policy-controlled agents - env->static_agent_count = 0; // Non-moving background agents + env->active_agent_count = 0; // Policy-controlled agents + env->static_agent_count = 0; // Non-moving background agents env->expert_static_agent_count = 0; // Expert replay agents (non-controlled) - env->num_actors = 0; // Total agents created + env->num_actors = 0; // Total agents created int active_agent_indices[MAX_AGENTS]; int static_agent_indices[MAX_AGENTS]; int expert_static_agent_indices[MAX_AGENTS]; - if(env->num_agents == 0){ + if (env->num_agents == 0) { env->num_agents = MAX_AGENTS; } - // Iterate through entities to find agents to create and/or control - for(int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++){ + for (int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++) { - Entity* entity = &env->entities[i]; + Entity *entity = &env->entities[i]; // Skip if not valid at initialization if (entity->traj_valid[env->init_steps] != 1) { @@ -1460,14 +1669,15 @@ void set_active_agents(Drive* env){ // Determine if entity should be created bool should_create = false; if (env->init_mode == INIT_ALL_VALID) { - should_create = true; // All valid entities + should_create = true; // All valid entities } else if (env->control_mode == CONTROL_VEHICLES) { should_create = (entity->type == VEHICLE); - } else { // Control all agents + } else { // Control all agents should_create = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); } - if (!should_create) continue; + if (!should_create) + continue; env->num_actors++; @@ -1475,12 +1685,13 @@ void set_active_agents(Drive* env){ bool is_controlled = false; is_controlled = should_control_agent(env, i); - if (is_controlled && env->active_agent_count >= env->max_controlled_agents && env->max_controlled_agents != -1) { + if (is_controlled && env->active_agent_count >= env->max_controlled_agents && + env->max_controlled_agents != -1) { is_controlled = false; entity->mark_as_expert = 1; } - if(is_controlled){ + if (is_controlled) { active_agent_indices[env->active_agent_count] = i; env->active_agent_count++; env->entities[i].active_agent = 1; @@ -1488,7 +1699,7 @@ void set_active_agents(Drive* env){ static_agent_indices[env->static_agent_count] = i; env->static_agent_count++; env->entities[i].active_agent = 0; - if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { + if (env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { expert_static_agent_indices[env->expert_static_agent_count] = i; env->expert_static_agent_count++; env->entities[i].mark_as_expert = 1; @@ -1497,24 +1708,24 @@ void set_active_agents(Drive* env){ } // Set up initial active agents - env->active_agent_indices = (int*)malloc(env->active_agent_count * sizeof(int)); - env->static_agent_indices = (int*)malloc(env->static_agent_count * sizeof(int)); - env->expert_static_agent_indices = (int*)malloc(env->expert_static_agent_count * sizeof(int)); - for(int i=0;iactive_agent_count;i++){ + env->active_agent_indices = (int *)malloc(env->active_agent_count * sizeof(int)); + env->static_agent_indices = (int *)malloc(env->static_agent_count * sizeof(int)); + env->expert_static_agent_indices = (int *)malloc(env->expert_static_agent_count * sizeof(int)); + for (int i = 0; i < env->active_agent_count; i++) { env->active_agent_indices[i] = active_agent_indices[i]; } - for(int i=0;istatic_agent_count;i++){ + for (int i = 0; i < env->static_agent_count; i++) { env->static_agent_indices[i] = static_agent_indices[i]; } - for(int i=0;iexpert_static_agent_count;i++){ + for (int i = 0; i < env->expert_static_agent_count; i++) { env->expert_static_agent_indices[i] = expert_static_agent_indices[i]; } return; } -void remove_bad_trajectories(Drive* env){ +void remove_bad_trajectories(Drive *env) { - if (env->control_mode != CONTROL_TRACKS_TO_PREDICT) { + if (env->control_mode != CONTROL_WOSAC) { return; // Leave all trajectories in WOSAC control mode } @@ -1526,22 +1737,23 @@ void remove_bad_trajectories(Drive* env){ collided_with_indices[i] = -1; } // move experts through trajectories to check for collisions and remove as illegal agents - for(int t = 0; t < env->scenario_length; t++){ - for(int i = 0; i < env->active_agent_count; i++){ + for (int t = 0; t < env->scenario_length; t++) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; move_expert(env, env->actions, agent_idx); } - for(int i = 0; i < env->expert_static_agent_count; i++){ + for (int i = 0; i < env->expert_static_agent_count; i++) { int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if (env->entities[expert_idx].x == INVALID_POSITION) + continue; move_expert(env, env->actions, expert_idx); } // check collisions - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; int collided_with_index = collision_check(env, agent_idx); - if((collided_with_index >= 0) && collided_agents[i] == 0){ + if ((collided_with_index >= 0) && collided_agents[i] == 0) { collided_agents[i] = 1; collided_with_indices[i] = collided_with_index; } @@ -1549,11 +1761,13 @@ void remove_bad_trajectories(Drive* env){ env->timestep++; } - for(int i = 0; i< env->active_agent_count; i++){ - if(collided_with_indices[i] == -1) continue; - for(int j = 0; j < env->static_agent_count; j++){ + for (int i = 0; i < env->active_agent_count; i++) { + if (collided_with_indices[i] == -1) + continue; + for (int j = 0; j < env->static_agent_count; j++) { int static_agent_idx = env->static_agent_indices[j]; - if(static_agent_idx != collided_with_indices[i]) continue; + if (static_agent_idx != collided_with_indices[i]) + continue; env->entities[static_agent_idx].traj_x[0] = INVALID_POSITION; env->entities[static_agent_idx].traj_y[0] = INVALID_POSITION; } @@ -1561,15 +1775,15 @@ void remove_bad_trajectories(Drive* env){ env->timestep = 0; } -void init_goal_positions(Drive* env){ - for(int x = 0;xactive_agent_count; x++){ +void init_goal_positions(Drive *env) { + for (int x = 0; x < env->active_agent_count; x++) { int agent_idx = env->active_agent_indices[x]; env->entities[agent_idx].init_goal_x = env->entities[agent_idx].goal_position_x; env->entities[agent_idx].init_goal_y = env->entities[agent_idx].goal_position_y; } } -void assign_ego_and_coplayer_roles(Drive* env) { +void assign_ego_and_coplayer_roles(Drive *env) { if (!env->population_play || env->num_ego_agents == 0) { for (int i = 0; i < env->num_entities; i++) { if (!env->entities[i].mark_as_expert) { @@ -1605,17 +1819,13 @@ void assign_ego_and_coplayer_roles(Drive* env) { } } - - - -void init(Drive* env){ +void init(Drive *env) { env->human_agent_idx = 0; env->timestep = 0; env->entities = load_map_binary(env->map_name, env); set_means(env); init_grid_map(env); - if (env->goal_behavior==GOAL_GENERATE_NEW) init_topology_graph(env); - env->grid_map->vision_range = 21; + env->grid_map->vision_range = 21; // TODO: Why is this hardcoded? init_neighbor_offsets(env); cache_neighbor_offsets(env); env->logs_capacity = 0; @@ -1625,34 +1835,32 @@ void init(Drive* env){ set_start_position(env); init_goal_positions(env); assign_ego_and_coplayer_roles(env); - env->logs = (Log*)calloc(env->active_agent_count, sizeof(Log)); + env->logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); // Always allocate weight arrays for consistency - env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float)); + env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float)); if (env->population_play) { - if (env->co_player_logs) { free(env->co_player_logs); env->co_player_logs = NULL; } if (env->active_agent_count > 0) { - env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log)); + env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); } else { env->co_player_logs = NULL; } - memset(&env->co_player_log, 0, sizeof(Co_Player_Log)); + memset(&env->co_player_log, 0, sizeof(Log)); } if (env->population_play) { - if (env->co_player_logs) { free(env->co_player_logs); env->co_player_logs = NULL; @@ -1661,38 +1869,64 @@ void init(Drive* env){ // Always allocate for all active agents, not just co-players // because we index by x which goes from 0 to active_agent_count-1 if (env->active_agent_count > 0) { - env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log)); + env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); } else { env->co_player_logs = NULL; } - memset(&env->co_player_log, 0, sizeof(Co_Player_Log)); + memset(&env->co_player_log, 0, sizeof(Log)); + } +} + +void close_client(Client *client); + +// Render-mode helper: stash env->client across vec_close + re-vectorize so the +// raylib window + ffmpeg pipe survive a map swap. Needed because raylib's +// CloseWindow → InitWindow cycle segfaults on LoadModel under our xvfb-headless +// setup (stale GL state). We keep a single global slot — the renderer is +// always single-env, single-window. Multi-env render would need an array. +static Client *g_donated_client = NULL; + +void c_donate_client(Drive *env) { + if (env->client != NULL) { + g_donated_client = env->client; + env->client = NULL; // c_close now no-ops the client teardown. } } +void c_adopt_client(Drive *env) { + if (g_donated_client != NULL) { + env->client = g_donated_client; + g_donated_client = NULL; + } +} -void c_close(Drive* env){ +void c_close(Drive *env) { + if (env->client != NULL) { + close_client(env->client); + env->client = NULL; + } if (env->population_play && env->co_player_logs != NULL) { free(env->co_player_logs); free(env->co_player_ids); free(env->ego_agent_ids); } - for(int i = 0; i < env->num_entities; i++){ + for (int i = 0; i < env->num_entities; i++) { free_entity(&env->entities[i]); } free(env->entities); free(env->active_agent_indices); free(env->logs); // GridMap cleanup - int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){ + int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) { free(env->grid_map->cells[grid_index]); } free(env->grid_map->cells); free(env->grid_map->cell_entities_count); free(env->neighbor_offsets); - for(int i = 0; i < grid_cell_count; i++){ + for (int i = 0; i < grid_cell_count; i++) { free(env->grid_map->neighbor_cache_entities[i]); } free(env->grid_map->neighbor_cache_entities); @@ -1700,34 +1934,30 @@ void c_close(Drive* env){ free(env->grid_map); free(env->static_agent_indices); free(env->expert_static_agent_indices); - freeTopologyGraph(env->topology_graph); - // free(env->map_name); free(env->ini_file); - } -void allocate(Drive* env){ +void allocate(Drive *env) { init(env); - int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7; + int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC; int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0); int ego_dim = base_ego_dim + conditioning_dims; // Always allocate weight arrays for consistency - env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; - env->observations = (float*)calloc(env->active_agent_count*max_obs, sizeof(float)); - env->actions = (float*)calloc(env->active_agent_count*2, sizeof(float)); - env->rewards = (float*)calloc(env->active_agent_count, sizeof(float)); - env->terminals= (unsigned char*)calloc(env->active_agent_count, sizeof(unsigned char)); - + env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; + env->observations = (float *)calloc(env->active_agent_count * max_obs, sizeof(float)); + env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float)); + env->rewards = (float *)calloc(env->active_agent_count, sizeof(float)); + env->terminals = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char)); } -void free_allocated(Drive* env){ +void free_allocated(Drive *env) { free(env->observations); free(env->actions); free(env->rewards); @@ -1745,24 +1975,21 @@ void free_allocated(Drive* env){ float clipSpeed(float speed) { const float maxSpeed = MAX_SPEED; - if (speed > maxSpeed) return maxSpeed; - if (speed < -maxSpeed) return -maxSpeed; + if (speed > maxSpeed) + return maxSpeed; + if (speed < -maxSpeed) + return -maxSpeed; return speed; } -float normalize_heading(float heading){ - if(heading > M_PI) heading -= 2*M_PI; - if(heading < -M_PI) heading += 2*M_PI; - return heading; -} +// normalize_heading is defined earlier in this file (around line 1222) -float normalize_value(float value, float min, float max){ - return (value - min) / (max - min); -} +float normalize_value(float value, float min, float max) { return (value - min) / (max - min); } -void move_dynamics(Drive* env, int action_idx, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; - if (agent->removed) return; +void move_dynamics(Drive *env, int action_idx, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; + if (agent->removed) + return; if (agent->stopped) { agent->vx = 0.0f; @@ -1776,7 +2003,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ float steering = 0.0f; if (env->action_type == 1) { // continuous - float (*action_array_f)[2] = (float(*)[2])env->actions; + float (*action_array_f)[2] = (float (*)[2])env->actions; acceleration = action_array_f[action_idx][0]; steering = action_array_f[action_idx][1]; @@ -1784,8 +2011,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ steering *= STEERING_VALUES[12]; } else { // discrete // Interpret action as a single integer: a = accel_idx * num_steer + steer_idx - int* action_array = (int*)env->actions; - int num_accel = sizeof(ACCELERATION_VALUES) / sizeof(ACCELERATION_VALUES[0]); + int *action_array = (int *)env->actions; int num_steer = sizeof(STEERING_VALUES) / sizeof(STEERING_VALUES[0]); int action_val = action_array[action_idx]; int acceleration_index = action_val / num_steer; @@ -1801,27 +2027,28 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ float vx = agent->vx; float vy = agent->vy; - // Calculate current speed - float speed = sqrtf(vx*vx + vy*vy); + // Calculate current speed (signed based on direction relative to heading) + float speed_magnitude = sqrtf(vx * vx + vy * vy); + float v_dot_heading = vx * agent->heading_x + vy * agent->heading_y; + float signed_speed = copysignf(speed_magnitude, v_dot_heading); // Update speed with acceleration - speed = speed + acceleration*env->dt; - speed = clipSpeed(speed); - + signed_speed = signed_speed + acceleration * env->dt; + signed_speed = clipSpeed(signed_speed); // Compute yaw rate - float beta = tanh(.5*tanf(steering)); + float beta = tanh(.5 * tanf(steering)); // New heading - float yaw_rate = (speed*cosf(beta)*tanf(steering)) / agent->length; + float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->length; // New velocity - float new_vx = speed*cosf(heading + beta); - float new_vy = speed*sinf(heading + beta); + float new_vx = signed_speed * cosf(heading + beta); + float new_vy = signed_speed * sinf(heading + beta); // Update position - x = x + (new_vx*env->dt); - y = y + (new_vy*env->dt); - heading = heading + yaw_rate*env->dt; + x = x + (new_vx * env->dt); + y = y + (new_vy * env->dt); + heading = heading + yaw_rate * env->dt; // Apply updates to the agent's state agent->x = x; @@ -1836,23 +2063,26 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ // Extract action components float a_long, a_lat; if (env->action_type == 1) { // continuous - float (*action_array_f)[2] = (float(*)[2])env->actions; + float (*action_array_f)[2] = (float (*)[2])env->actions; // Asymmetric scaling for longitudinal jerk to match discrete action space // Discrete: JERK_LONG = [-15, -4, 0, 4] (more braking than acceleration) - float a_long_action = action_array_f[action_idx][0]; // [-1, 1] + float a_long_action = action_array_f[action_idx][0]; // [-1, 1] if (a_long_action < 0) { - a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking) + a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking) } else { - a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration) + a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration) } // Symmetric scaling for lateral jerk a_lat = action_array_f[action_idx][1] * JERK_LAT[2]; } else { // discrete - int (*action_array)[2] = (int(*)[2])env->actions; - int a_long_idx = action_array[action_idx][0]; - int a_lat_idx = action_array[action_idx][1]; + // Interpret action as a single integer: a = long_idx * num_lat + lat_idx + int *action_array = (int *)env->actions; + int num_lat = sizeof(JERK_LAT) / sizeof(JERK_LAT[0]); + int action_val = action_array[action_idx]; + int a_long_idx = action_val / num_lat; + int a_lat_idx = action_val % num_lat; a_long = JERK_LONG[a_long_idx]; a_lat = JERK_LAT[a_lat_idx]; } @@ -1876,7 +2106,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ // Calculate new velocity float v_dot_heading = agent->vx * agent->heading_x + agent->vy * agent->heading_y; - float signed_v = copysignf(sqrtf(agent->vx*agent->vx + agent->vy*agent->vy), v_dot_heading); + float signed_v = copysignf(sqrtf(agent->vx * agent->vx + agent->vy * agent->vy), v_dot_heading); float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt; // Make it easy to stop with 0 vel @@ -1931,10 +2161,23 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ return; } -void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out) { - for(int i = 0; i < env->active_agent_count; i++){ +static inline int get_track_id_or_placeholder(Drive *env, int agent_idx) { + if (env->tracks_to_predict_indices == NULL || env->num_tracks_to_predict == 0) { + return -1; + } + for (int k = 0; k < env->num_tracks_to_predict; k++) { + if (env->tracks_to_predict_indices[k] == agent_idx) { + return env->tracks_to_predict_indices[k]; + } + } + return -1; +} + +void c_get_global_agent_state(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, int *id_out, + float *length_out, float *width_out) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; + Entity *agent = &env->entities[agent_idx]; // For WOSAC, we need the original world coordinates, so we add the world means back x_out[i] = agent->x + env->world_mean_x; @@ -1945,14 +2188,15 @@ void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_o } } -void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* valid_out, int* id_out, int* scenario_id_out) { - for(int i = 0; i < env->active_agent_count; i++){ +void c_get_global_ground_truth_trajectories(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, + int *valid_out, int *id_out, int *scenario_id_out) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; + Entity *agent = &env->entities[agent_idx]; id_out[i] = env->tracks_to_predict_indices[i]; scenario_id_out[i] = agent->scenario_id; - for(int t = env->init_steps; t < agent->array_size; t++){ + for (int t = env->init_steps; t < agent->array_size; t++) { int out_idx = i * (agent->array_size - env->init_steps) + (t - env->init_steps); // Add world means back to get original world coordinates x_out[out_idx] = agent->traj_x[t] + env->world_mean_x; @@ -1964,48 +2208,87 @@ void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_o } } -void compute_observations(Drive* env) { - int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7; +void c_get_road_edge_counts(Drive *env, int *num_polylines_out, int *total_points_out) { + int count = 0, points = 0; + for (int i = env->num_objects; i < env->num_entities; i++) { + if (env->entities[i].type == ROAD_EDGE) { + count++; + points += env->entities[i].array_size; + } + } + *num_polylines_out = count; + *total_points_out = points; +} + +void c_get_road_edge_polylines(Drive *env, float *x_out, float *y_out, int *lengths_out, int *scenario_ids_out) { + int poly_idx = 0, pt_idx = 0; + for (int i = env->num_objects; i < env->num_entities; i++) { + Entity *e = &env->entities[i]; + if (e->type == ROAD_EDGE) { + lengths_out[poly_idx] = e->array_size; + scenario_ids_out[poly_idx] = e->scenario_id; + for (int j = 0; j < e->array_size; j++) { + x_out[pt_idx] = e->traj_x[j] + env->world_mean_x; + y_out[pt_idx] = e->traj_y[j] + env->world_mean_y; + pt_idx++; + } + poly_idx++; + } + } +} + +void compute_observations(Drive *env) { + int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC; int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0); int ego_dim = base_ego_dim + conditioning_dims; - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; - memset(env->observations, 0, max_obs*env->active_agent_count*sizeof(float)); - float (*observations)[max_obs] = (float(*)[max_obs])env->observations; + memset(env->observations, 0, max_obs * env->active_agent_count * sizeof(float)); + float (*observations)[max_obs] = (float (*)[max_obs])env->observations; - for(int i = 0; i < env->active_agent_count; i++) { - float* obs = &observations[i][0]; - Entity* ego_entity = &env->entities[env->active_agent_indices[i]]; - if(ego_entity->type > 3) break; + for (int i = 0; i < env->active_agent_count; i++) { + float *obs = &observations[i][0]; + Entity *ego_entity = &env->entities[env->active_agent_indices[i]]; + if (ego_entity->type > 3) + break; float cos_heading = ego_entity->heading_x; float sin_heading = ego_entity->heading_y; - float ego_speed = sqrtf(ego_entity->vx*ego_entity->vx + ego_entity->vy*ego_entity->vy); + float speed_magnitude = sqrtf(ego_entity->vx * ego_entity->vx + ego_entity->vy * ego_entity->vy); + float v_dot_heading = ego_entity->vx * ego_entity->heading_x + ego_entity->vy * ego_entity->heading_y; + float signed_speed = copysignf(speed_magnitude, v_dot_heading); // Set goal distances float goal_x = ego_entity->goal_position_x - ego_entity->x; float goal_y = ego_entity->goal_position_y - ego_entity->y; // Rotate to ego vehicle's frame - float rel_goal_x = goal_x*cos_heading + goal_y*sin_heading; - float rel_goal_y = -goal_x*sin_heading + goal_y*cos_heading; + float rel_goal_x = goal_x * cos_heading + goal_y * sin_heading; + float rel_goal_y = -goal_x * sin_heading + goal_y * cos_heading; - obs[0] = rel_goal_x* 0.005f; - obs[1] = rel_goal_y* 0.005f; - obs[2] = ego_speed / MAX_SPEED; + obs[0] = rel_goal_x * 0.005f; + obs[1] = rel_goal_y * 0.005f; + obs[2] = signed_speed / MAX_SPEED; obs[3] = ego_entity->width / MAX_VEH_WIDTH; obs[4] = ego_entity->length / MAX_VEH_LEN; obs[5] = (ego_entity->collision_state > 0) ? 1.0f : 0.0f; obs[6] = (ego_entity->respawn_timestep != -1) ? 1 : 0; + // Lane alignment observations (GIGAFLOW Frenet coordinates) + float lane_center_dist = ego_entity->metrics_array[LANE_DIST_IDX] / LANE_DISTANCE_NORMALIZATION; + lane_center_dist = fmaxf(-1.0f, fminf(1.0f, lane_center_dist)); // Clamp to [-1, 1] + obs[7] = lane_center_dist; + obs[8] = ego_entity->metrics_array[LANE_ANGLE_IDX]; // cos(theta_f), already in [-1, 1] + if (env->dynamics_model == JERK) { - obs[7] = ego_entity->steering_angle / M_PI; + obs[9] = ego_entity->steering_angle / M_PI; // Asymmetric normalization for a_long to match action space - obs[8] = (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3]; - obs[9] = ego_entity->a_lat / JERK_LAT[2]; + obs[10] = + (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3]; + obs[11] = ego_entity->a_lat / JERK_LAT[2]; } - int obs_idx = (env->dynamics_model == JERK) ? 10 : 7; + int obs_idx = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC; // Add conditioning weights to observations if (env->use_rc) { obs[obs_idx++] = env->collision_weights[i]; @@ -2021,86 +2304,97 @@ void compute_observations(Drive* env) { // Relative Pos of other cars int cars_seen = 0; - for(int j = 0; j < MAX_AGENTS; j++) { + for (int j = 0; j < MAX_AGENTS; j++) { int index = -1; - if(j < env->active_agent_count){ + if (j < env->active_agent_count) { index = env->active_agent_indices[j]; - } else if (j < env->num_actors){ + } else if (j < env->num_actors) { index = env->static_agent_indices[j - env->active_agent_count]; } - if(index == -1) continue; - if(env->entities[index].type > 3) break; - if(index == env->active_agent_indices[i]) continue; // Skip self, but don't increment obs_idx - Entity* other_entity = &env->entities[index]; - if(ego_entity->respawn_timestep != -1) continue; - if(other_entity->respawn_timestep != -1) continue; + if (index == -1) + continue; + if (env->entities[index].type > 3) + break; + if (index == env->active_agent_indices[i]) + continue; // Skip self, but don't increment obs_idx + Entity *other_entity = &env->entities[index]; + if (ego_entity->respawn_timestep != -1) + continue; + if (other_entity->respawn_timestep != -1) + continue; // Store original relative positions float dx = other_entity->x - ego_entity->x; float dy = other_entity->y - ego_entity->y; - float dist = (dx*dx + dy*dy); - if(dist > 2500.0f) continue; + float dist = (dx * dx + dy * dy); + if (dist > 2500.0f) + continue; // Rotate to ego vehicle's frame - float rel_x = dx*cos_heading + dy*sin_heading; - float rel_y = -dx*sin_heading + dy*cos_heading; + float rel_x = dx * cos_heading + dy * sin_heading; + float rel_y = -dx * sin_heading + dy * cos_heading; // Store observations with correct indexing obs[obs_idx] = rel_x * 0.02f; - // Add conditioning weights to observations + // Add conditioning weights to observations obs[obs_idx + 1] = rel_y * 0.02f; obs[obs_idx + 2] = other_entity->width / MAX_VEH_WIDTH; obs[obs_idx + 3] = other_entity->length / MAX_VEH_LEN; // relative heading - float rel_heading_x = other_entity->heading_x * ego_entity->heading_x + - other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) - float rel_heading_y = other_entity->heading_y * ego_entity->heading_x - - other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) + float rel_heading_x = + other_entity->heading_x * ego_entity->heading_x + + other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) + float rel_heading_y = + other_entity->heading_y * ego_entity->heading_x - + other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) obs[obs_idx + 4] = rel_heading_x; obs[obs_idx + 5] = rel_heading_y; - // obs[obs_idx + 4] = cosf(rel_heading) / MAX_ORIENTATION_RAD; - // obs[obs_idx + 5] = sinf(rel_heading) / MAX_ORIENTATION_RAD; - // // relative speed - float other_speed = sqrtf(other_entity->vx*other_entity->vx + other_entity->vy*other_entity->vy); - obs[obs_idx + 6] = other_speed / MAX_SPEED; + + // relative speed + float other_speed_magnitude = + sqrtf(other_entity->vx * other_entity->vx + other_entity->vy * other_entity->vy); + float other_v_dot_heading = + other_entity->vx * other_entity->heading_x + other_entity->vy * other_entity->heading_y; + float other_signed_speed = copysignf(other_speed_magnitude, other_v_dot_heading); + obs[obs_idx + 6] = other_signed_speed / MAX_SPEED; cars_seen++; - obs_idx += 7; // Move to next observation slot + obs_idx += 7; // Move to next observation slot } int remaining_partner_obs = (MAX_AGENTS - 1 - cars_seen) * 7; memset(&obs[obs_idx], 0, remaining_partner_obs * sizeof(float)); obs_idx += remaining_partner_obs; // map observations - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; + GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25]; int grid_idx = getGridIndex(env, ego_entity->x, ego_entity->y); int list_size = get_neighbor_cache_entities(env, grid_idx, entity_list, MAX_ROAD_SEGMENT_OBSERVATIONS); - for(int k = 0; k < list_size; k++) { + for (int k = 0; k < list_size; k++) { int entity_idx = entity_list[k].entity_idx; int geometry_idx = entity_list[k].geometry_idx; // Validate entity_idx before accessing - if(entity_idx < 0 || entity_idx >= env->num_entities) { - printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities-1); + if (entity_idx < 0 || entity_idx >= env->num_entities) { + printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities - 1); continue; } - Entity* entity = &env->entities[entity_idx]; + Entity *entity = &env->entities[entity_idx]; // Validate geometry_idx before accessing - if(geometry_idx < 0 || geometry_idx >= entity->array_size) { - printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", - geometry_idx, entity_idx, entity->array_size-1); + if (geometry_idx < 0 || geometry_idx >= entity->array_size) { + printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", geometry_idx, entity_idx, + entity->array_size - 1); continue; } float start_x = entity->traj_x[geometry_idx]; float start_y = entity->traj_y[geometry_idx]; - float end_x = entity->traj_x[geometry_idx+1]; - float end_y = entity->traj_y[geometry_idx+1]; + float end_x = entity->traj_x[geometry_idx + 1]; + float end_y = entity->traj_y[geometry_idx + 1]; float mid_x = (start_x + end_x) / 2.0f; float mid_y = (start_y + end_y) / 2.0f; float rel_x = mid_x - ego_entity->x; float rel_y = mid_y - ego_entity->y; - float x_obs = rel_x*cos_heading + rel_y*sin_heading; - float y_obs = -rel_x*sin_heading + rel_y*cos_heading; + float x_obs = rel_x * cos_heading + rel_y * sin_heading; + float y_obs = -rel_x * sin_heading + rel_y * cos_heading; float length = relative_distance_2d(mid_x, mid_y, end_x, end_y); float width = 0.1; // Calculate angle from ego to midpoint (vector from ego to midpoint) @@ -2108,14 +2402,14 @@ void compute_observations(Drive* env) { float dy = end_y - mid_y; float dx_norm = dx; float dy_norm = dy; - float hypot = sqrtf(dx*dx + dy*dy); - if(hypot > 0) { + float hypot = sqrtf(dx * dx + dy * dy); + if (hypot > 0) { dx_norm /= hypot; dy_norm /= hypot; } // Compute sin and cos of relative angle directly without atan2f - float cos_angle = dx_norm*cos_heading + dy_norm*sin_heading; - float sin_angle = -dx_norm*sin_heading + dy_norm*cos_heading; + float cos_angle = dx_norm * cos_heading + dy_norm * sin_heading; + float sin_angle = -dx_norm * sin_heading + dy_norm * cos_heading; obs[obs_idx] = x_obs * 0.02f; obs[obs_idx + 1] = y_obs * 0.02f; obs[obs_idx + 2] = length / MAX_ROAD_SEGMENT_LENGTH; @@ -2131,170 +2425,103 @@ void compute_observations(Drive* env) { } } -static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out_segment_idx, float* out_fraction) { - int best_idx = -1; - float best_dist_sq = 1e30f; - - for (int i = 1; i < lane->array_size; i++) { - float x0 = lane->traj_x[i - 1]; - float y0 = lane->traj_y[i - 1]; - float x1 = lane->traj_x[i]; - float y1 = lane->traj_y[i]; - float dx = x1 - x0; - float dy = y1 - y0; - float seg_len_sq = dx * dx + dy * dy; - if (seg_len_sq < 1e-6f) continue; - - float to_agent_x = agent->x - x0; - float to_agent_y = agent->y - y0; - float t = (to_agent_x * dx + to_agent_y * dy) / seg_len_sq; - if (t < 0.0f) t = 0.0f; - else if (t > 1.0f) t = 1.0f; - - float proj_x = x0 + t * dx; - float proj_y = y0 + t * dy; - - float rel_x = proj_x - agent->x; - float rel_y = proj_y - agent->y; - float forward = rel_x * agent->heading_x + rel_y * agent->heading_y; - if (forward < 0.0f) continue; - - float dist_sq = rel_x * rel_x + rel_y * rel_y; - if (dist_sq < best_dist_sq) { - best_dist_sq = dist_sq; - best_idx = i; - *out_fraction = t; - } - } - - if (best_idx != -1) { - *out_segment_idx = best_idx; - return 1; - } +void sample_new_goal(Drive *env, int agent_idx) { + // Samples a new goal position based on the existing road lane points + Entity *agent = &env->entities[agent_idx]; + float best_x = agent->x; + float best_y = agent->y; + float best_distance_error = 1e30f; - return 0; -} + // Sample points from all road lanes + for (int i = env->num_objects; i < env->num_entities; i++) { + if (env->entities[i].type != ROAD_LANE) + continue; -void compute_new_goal(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; - int current_lane = agent->current_lane_idx; + Entity *lane = &env->entities[i]; - if (current_lane == -1) return; // No current lane + // Check every point in the lane + for (int j = 0; j < lane->array_size; j++) { + float point_x = lane->traj_x[j]; + float point_y = lane->traj_y[j]; - // Target distance: 40m ahead along the lane topology from agent's current position - float target_distance = 40.0f; - int current_entity = current_lane; - Entity* lane = &env->entities[current_entity]; + // Calculate vector from agent to point + float to_point_x = point_x - agent->x; + float to_point_y = point_y - agent->y; - int initial_segment_idx = 1; - float initial_fraction = 0.0f; - if (!find_forward_projection_on_lane(lane, agent, &initial_segment_idx, &initial_fraction)) { - int forward_idx = -1; - for (int i = 0; i < lane->array_size; i++) { - float to_point_x = lane->traj_x[i] - agent->x; - float to_point_y = lane->traj_y[i] - agent->y; + // Check if point is ahead of agent float dot = to_point_x * agent->heading_x + to_point_y * agent->heading_y; - if (dot > 0.0f) { - forward_idx = i; - break; - } - } - - if (forward_idx == -1) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; - agent->sampled_new_goal = 0; - return; - } - - initial_segment_idx = forward_idx; - if (initial_segment_idx == 0) initial_segment_idx = 1; - initial_fraction = 0.0f; - } - - float remaining_distance = target_distance; - int first_lane = 1; - - // Traverse the topology graph starting from the vehicle's position forward - while (current_entity != -1) { - lane = &env->entities[current_entity]; - - int start_idx = first_lane ? initial_segment_idx : 1; - // Ensure start_idx is at least 1 to avoid accessing traj_x[i-1] with i=0 - if (start_idx < 1) start_idx = 1; - first_lane = 0; + if (dot <= 0.0f) + continue; - for (int i = start_idx; i < lane->array_size; i++) { - float prev_x = lane->traj_x[i - 1]; - float prev_y = lane->traj_y[i - 1]; - float next_x = lane->traj_x[i]; - float next_y = lane->traj_y[i]; - float seg_dx = next_x - prev_x; - float seg_dy = next_y - prev_y; - float segment_length = relative_distance_2d(prev_x, prev_y, next_x, next_y); + // Calculate distance to point + float distance = sqrtf(to_point_x * to_point_x + to_point_y * to_point_y); - if (remaining_distance <= segment_length) { - agent->goal_position_x = next_x; - agent->goal_position_y = next_y; - agent->sampled_new_goal = 0; - return; + // Find point closest to target distance + float distance_error = fabsf(distance - env->goal_target_distance); + if (distance_error < best_distance_error) { + best_distance_error = distance_error; + best_x = point_x; + best_y = point_y; } - - remaining_distance -= segment_length; - } - - int connected_lanes[5]; - int num_connected = getNextLanes(env->topology_graph, current_entity, connected_lanes, 5); - - if (num_connected == 0) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; - agent->sampled_new_goal = 0; - return; // No further lanes to traverse } + } - int random_idx = agent_idx % num_connected; - current_entity = connected_lanes[random_idx]; + // If no valid goal found, use another agent's initial goal + if (best_distance_error >= 1e30f && env->active_agent_count > 1) { + int other_idx = env->active_agent_indices[(agent_idx + 1) % env->active_agent_count]; + best_x = env->entities[other_idx].init_goal_x; + best_y = env->entities[other_idx].init_goal_y; } + + agent->goal_position_x = best_x; + agent->goal_position_y = best_y; + agent->goals_sampled_this_episode += 1; } -void c_reset(Drive* env){ +void c_reset(Drive *env) { env->timestep = env->init_steps; set_start_position(env); - // Initialize all conditioning weights even when no conditioning (lb=ub) - for(int i = 0; i < env->active_agent_count; i++) { - env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + env->collision_weight_lb; - env->offroad_weights[i] = ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb; - env->goal_weights[i] = ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb; - env->entropy_weights[i] = ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb; - env->discount_weights[i] = ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb; - } - - for(int x = 0;xactive_agent_count; x++){ + for (int i = 0; i < env->active_agent_count; i++) { + env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + + env->collision_weight_lb; + env->offroad_weights[i] = + ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb; + env->goal_weights[i] = + ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb; + env->entropy_weights[i] = + ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb; + env->discount_weights[i] = + ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb; + } + + for (int x = 0; x < env->active_agent_count; x++) { env->logs[x] = (Log){0}; int agent_idx = env->active_agent_indices[x]; env->entities[agent_idx].respawn_timestep = -1; env->entities[agent_idx].respawn_count = 0; env->entities[agent_idx].collided_before_goal = 0; - env->entities[agent_idx].reached_goal_this_episode = 0; + env->entities[agent_idx].goals_reached_this_episode = 0.0f; + // Initialize to 1 because there is one goal in the data file + env->entities[agent_idx].goals_sampled_this_episode = 1.0f; + env->entities[agent_idx].current_goal_reached = 0; env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; - env->entities[agent_idx].stopped = 0; + env->entities[agent_idx].metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; + env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX] = 0.0f; + env->entities[agent_idx].current_lane_idx = -1; + env->entities[agent_idx].current_lane_geometry_idx = -1; + env->entities[agent_idx].stopped = 0; env->entities[agent_idx].removed = 0; - if (env->goal_behavior==GOAL_GENERATE_NEW) { + if (env->goal_behavior == GOAL_GENERATE_NEW) { env->entities[agent_idx].goal_position_x = env->entities[agent_idx].init_goal_x; env->entities[agent_idx].goal_position_y = env->entities[agent_idx].init_goal_y; - env->entities[agent_idx].sampled_new_goal = 0; } - if (env->population_play){ - env->co_player_logs[x] = (Co_Player_Log){0}; + if (env->population_play) { + env->co_player_logs[x] = (Log){0}; } compute_agent_metrics(env, agent_idx); @@ -2302,7 +2529,7 @@ void c_reset(Drive* env){ compute_observations(env); } -void respawn_agent(Drive* env, int agent_idx){ +void respawn_agent(Drive *env, int agent_idx) { env->entities[agent_idx].x = env->entities[agent_idx].traj_x[0]; env->entities[agent_idx].y = env->entities[agent_idx].traj_y[0]; env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[0]; @@ -2314,10 +2541,13 @@ void respawn_agent(Drive* env, int agent_idx){ env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; + env->entities[agent_idx].metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; + env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX] = 0.0f; + env->entities[agent_idx].current_lane_idx = -1; + env->entities[agent_idx].current_lane_geometry_idx = -1; + env->entities[agent_idx].respawn_timestep = env->timestep; + env->entities[agent_idx].collided_before_goal = 0; env->entities[agent_idx].stopped = 0; env->entities[agent_idx].removed = 0; env->entities[agent_idx].a_long = 0.0f; @@ -2327,43 +2557,55 @@ void respawn_agent(Drive* env, int agent_idx){ env->entities[agent_idx].steering_angle = 0.0f; } -void c_step(Drive* env){ +void c_step(Drive *env) { memset(env->rewards, 0, env->active_agent_count * sizeof(float)); memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char)); env->timestep++; - if(env->timestep == env->scenario_length){ + + int originals_remaining = 0; + for (int i = 0; i < env->active_agent_count; i++) { + int agent_idx = env->active_agent_indices[i]; + // Keep flag true if there is at least one agent that has not been respawned yet + if (env->entities[agent_idx].respawn_count == 0) { + originals_remaining = 1; + break; + } + } + + if (env->timestep == env->scenario_length || (!originals_remaining && env->termination_mode == 1)) { add_log(env); c_reset(env); - return; } // Move static experts for (int i = 0; i < env->expert_static_agent_count; i++) { int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if (env->entities[expert_idx].x == INVALID_POSITION) + continue; move_expert(env, env->actions, expert_idx); } // Process actions for all active agents - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; + move_dynamics(env, i, agent_idx); // Update logs based on agent type - use i directly as log index - if(env->entities[agent_idx].is_ego){ + if (env->entities[agent_idx].is_ego) { env->logs[i].score = 0.0f; env->logs[i].episode_length += 1; - } else if(env->entities[agent_idx].is_co_player){ - env->co_player_logs[i].co_player_score = 0.0f; - env->co_player_logs[i].co_player_episode_length += 1; + } else if (env->entities[agent_idx].is_co_player) { + env->co_player_logs[i].score = 0.0f; + env->co_player_logs[i].episode_length += 1; } } // Compute metrics and rewards - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; @@ -2372,143 +2614,155 @@ void c_step(Drive* env){ int collision_state = env->entities[agent_idx].collision_state; int is_ego = env->entities[agent_idx].is_ego; int is_co_player = env->entities[agent_idx].is_co_player; - - // Handle collisions - SAME REWARD for both ego and co-players - if(collision_state > 0){ - if(collision_state == VEHICLE_COLLISION){ + int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; + + // Handle collisions - SAME REWARD for both ego and co-players. + // Skip for agents already stopped (e.g. after reaching the goal in + // GOAL_STOP mode): the policy isn't driving anymore, so a rear-end + // shouldn't count against its collision/offroad metrics. + if (collision_state > 0 && !env->entities[agent_idx].stopped) { + if (collision_state == VEHICLE_COLLISION) { env->rewards[i] = env->collision_weights[i]; - if(is_ego){ + if (is_ego) { env->logs[i].episode_return += env->collision_weights[i]; env->logs[i].collision_rate = 1.0f; - env->logs[i].avg_collisions_per_agent += 1.0f; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->collision_weights[i]; - env->co_player_logs[i].co_player_collision_rate = 1.0f; + env->logs[i].collisions_per_agent += 1.0f; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->collision_weights[i]; + env->co_player_logs[i].collision_rate = 1.0f; + env->co_player_logs[i].collisions_per_agent += 1.0f; } - } else if(collision_state == OFFROAD){ + } else if (collision_state == OFFROAD) { env->rewards[i] = env->offroad_weights[i]; - if(is_ego){ + if (is_ego) { env->logs[i].episode_return += env->offroad_weights[i]; env->logs[i].offroad_rate = 1.0f; - env->logs[i].avg_offroad_per_agent += 1.0f; // ADD THIS - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->offroad_weights[i]; - env->co_player_logs[i].co_player_offroad_rate = 1.0f; - env->logs[i].avg_offroad_per_agent += 1.0f; - } + env->logs[i].offroad_per_agent += 1.0f; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->offroad_weights[i]; + env->co_player_logs[i].offroad_rate = 1.0f; + env->co_player_logs[i].offroad_per_agent += 1.0f; + } } - if(!env->entities[agent_idx].reached_goal_this_episode){ + if (!reached_goal) { env->entities[agent_idx].collided_before_goal = 1; } } - // Handle goal reward - SAME REWARD for both ego and co-players - float distance_to_goal = relative_distance_2d( - env->entities[agent_idx].x, - env->entities[agent_idx].y, - env->entities[agent_idx].goal_position_x, - env->entities[agent_idx].goal_position_y + // Handle goal reward - NEW INCOMING LOGIC with speed check + float distance_to_goal = + relative_distance_2d(env->entities[agent_idx].x, env->entities[agent_idx].y, + env->entities[agent_idx].goal_position_x, env->entities[agent_idx].goal_position_y); + + float current_speed = sqrtf(env->entities[agent_idx].vx * env->entities[agent_idx].vx + + env->entities[agent_idx].vy * env->entities[agent_idx].vy); - ); + // Reward agent if it is within X meters of goal and speed is below threshold + bool within_distance = distance_to_goal < env->goal_radius; + bool within_speed = current_speed <= env->goal_speed; - if(distance_to_goal < env->goal_radius){ - if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1){ + if (within_distance && within_speed && !env->entities[agent_idx].current_goal_reached) { + if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1) { float scaled_post_respawn_reward = env->reward_goal_post_respawn * env->goal_weights[i]; env->rewards[i] += scaled_post_respawn_reward; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return += scaled_post_respawn_reward; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += scaled_post_respawn_reward; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += scaled_post_respawn_reward; } - } else if (env->goal_behavior == GOAL_GENERATE_NEW) { + env->entities[agent_idx].current_goal_reached = 1; + } else if (env->goal_behavior == GOAL_GENERATE_NEW && (!env->entities[agent_idx].current_goal_reached)) { env->rewards[i] += env->goal_weights[i]; - env->entities[agent_idx].sampled_new_goal = 1; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return += env->goal_weights[i]; - env->logs[i].num_goals_reached += 1; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->goal_weights[i]; - env->co_player_logs[i].co_player_num_goals_reached += 1; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->goal_weights[i]; } + + sample_new_goal(env, agent_idx); + env->entities[agent_idx].current_goal_reached = 0; + env->entities[agent_idx].goals_reached_this_episode += 1.0f; } else { // Zero out the velocity so that the agent stops at the goal env->rewards[i] = env->goal_weights[i]; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return = env->goal_weights[i]; - env->logs[i].num_goals_reached = 1; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return = env->goal_weights[i]; - env->co_player_logs[i].co_player_num_goals_reached = 1; + } else if (is_co_player) { + env->co_player_logs[i].episode_return = env->goal_weights[i]; } + env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].goals_reached_this_episode += 1.0f; } - env->entities[agent_idx].reached_goal_this_episode = 1; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f; - } - if(env->entities[agent_idx].sampled_new_goal && env->goal_behavior == GOAL_GENERATE_NEW){ - compute_new_goal(env, agent_idx); + if (is_ego) { + env->logs[i].speed_at_goal = current_speed; + } else if (is_co_player) { + env->co_player_logs[i].speed_at_goal = current_speed; + } } - int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX]; - if(is_ego){ - env->logs[i].lane_alignment_rate = lane_aligned; - } else if(is_co_player){ - env->co_player_logs[i].co_player_lane_alignment_rate = lane_aligned; - } + // Lane alignment reward (GIGAFLOW formula) + // Only apply if reward_lane_align > 0 (disabled by default) + if (env->reward_lane_align > 0.0f) { + float cos_theta = env->entities[agent_idx].metrics_array[LANE_ANGLE_IDX]; + float theta_f = acosf(fminf(fmaxf(cos_theta, -1.0f), 1.0f)); // Get |θ_f| from cos + + // GIGAFLOW Rl-align: min(cos,0) + vel_align*min(cos*v,0) + 0.0025*(1-|θ|/(π/2)) + float against_lane_penalty = fminf(cos_theta, 0.0f); // Negative when >90° off + float vel_aligned_penalty = env->reward_vel_align * fminf(cos_theta * current_speed, 0.0f); + float alignment_bonus = 0.0025f * (1.0f - theta_f / (M_PI / 2.0f)); - float current_ade = env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX]; - if(current_ade > 0.0f && env->reward_ade != 0.0f){ - float ade_reward = env->reward_ade * current_ade; - env->rewards[i] += ade_reward; + float lane_align_reward = + env->reward_lane_align * env->dt * (against_lane_penalty + vel_aligned_penalty + alignment_bonus); - if(is_ego){ - env->logs[i].episode_return += ade_reward; - env->logs[i].avg_displacement_error = current_ade; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += ade_reward; - env->co_player_logs[i].co_player_avg_displacement_error = current_ade; + env->rewards[i] += lane_align_reward; + + if (is_ego) { + env->logs[i].episode_return += lane_align_reward; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += lane_align_reward; } } + + int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX]; + if (is_ego) { + env->logs[i].lane_alignment_rate = lane_aligned; + } else if (is_co_player) { + env->co_player_logs[i].lane_alignment_rate = lane_aligned; + } } - if (env->goal_behavior==GOAL_RESPAWN) { - for(int i = 0; i < env->active_agent_count; i++){ + if (env->goal_behavior == GOAL_RESPAWN) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; - if(reached_goal){ + if (reached_goal) { respawn_agent(env, agent_idx); env->entities[agent_idx].respawn_count++; } } - } - else if (env->goal_behavior==GOAL_STOP) { - for(int i = 0; i < env->active_agent_count; i++){ + } else if (env->goal_behavior == GOAL_STOP) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; - if(reached_goal){ + if (reached_goal) { env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; } } } - compute_observations(env); } - -const Color STONE_GRAY = (Color){80, 80, 80, 255}; -const Color PUFF_RED = (Color){187, 0, 0, 255}; -const Color PUFF_CYAN = (Color){0, 187, 187, 255}; -const Color PUFF_WHITE = (Color){241, 241, 241, 241}; -const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255}; -const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255}; -const Color LIGHTGREEN = (Color){152, 255, 152, 255}; - typedef struct Client Client; struct Client { float width; @@ -2518,18 +2772,118 @@ struct Client { float camera_zoom; Camera3D camera; Model cars[6]; - int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle + Model cyclist; + Model pedestrian; + ModelAnimation *cycle_anim; + int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle Vector3 default_camera_position; Vector3 default_camera_target; + // Video recording state (for headless rendering) + int recorder_pipefd[2]; // Pipe to ffmpeg process + pid_t recorder_pid; // PID of ffmpeg process + // Original map dimensions (for consistent rendering across scenarios) + float original_map_width; + float original_map_height; }; -Client* make_client(Drive* env){ - Client* client = (Client*)calloc(1, sizeof(Client)); - client->width = 1280; - client->height = 704; - SetConfigFlags(FLAG_MSAA_4X_HINT); - InitWindow(client->width, client->height, "PufferLib Ray GPU Drive"); - SetTargetFPS(30); +// Stop the running ffmpeg recorder (if any) and wait for it to finish writing. +static void stop_video_recorder(Client *client) { + if (client->recorder_pipefd[1] >= 0) { + close(client->recorder_pipefd[1]); + client->recorder_pipefd[1] = -1; + } + if (client->recorder_pipefd[0] >= 0) { + close(client->recorder_pipefd[0]); + client->recorder_pipefd[0] = -1; + } + if (client->recorder_pid > 0) { + int status; + waitpid(client->recorder_pid, &status, 0); + client->recorder_pid = 0; + } +} + +// Start (or restart) the ffmpeg recorder. Stops any in-flight recorder first +// so callers can switch output files mid-life without recreating raylib state. +static void start_video_recorder(Client *client, const char *basename) { + stop_video_recorder(client); + + if (pipe(client->recorder_pipefd) == -1) { + fprintf(stderr, "Failed to create pipe for video recording\n"); + client->recorder_pipefd[0] = -1; + client->recorder_pipefd[1] = -1; + return; + } + + char size_str[64]; + snprintf(size_str, sizeof(size_str), "%dx%d", (int)client->width, (int)client->height); + + char filename[320]; + snprintf(filename, sizeof(filename), "%s.mp4", basename && basename[0] ? basename : "render"); + + client->recorder_pid = fork(); + if (client->recorder_pid == -1) { + fprintf(stderr, "Failed to fork ffmpeg process\n"); + close(client->recorder_pipefd[0]); + close(client->recorder_pipefd[1]); + client->recorder_pipefd[0] = -1; + client->recorder_pipefd[1] = -1; + client->recorder_pid = 0; + return; + } + + if (client->recorder_pid == 0) { + // Child: run ffmpeg + close(client->recorder_pipefd[1]); + dup2(client->recorder_pipefd[0], STDIN_FILENO); + close(client->recorder_pipefd[0]); + for (int fd = 3; fd < 256; fd++) { + close(fd); + } + execlp("ffmpeg", "ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", size_str, "-r", "30", "-i", "-", + "-c:v", "libx264", "-threads", "4", "-pix_fmt", "yuv420p", "-preset", "ultrafast", "-crf", "23", + "-loglevel", "error", filename, NULL); + fprintf(stderr, "Failed to exec ffmpeg\n"); + _exit(1); + } + + close(client->recorder_pipefd[0]); // parent: keep write end only + client->recorder_pipefd[0] = -1; + fprintf(stderr, "[drive] ffmpeg forked: pid=%d file=%s size=%s\n", client->recorder_pid, filename, size_str); +} + +Client *make_client(Drive *env) { + Client *client = (Client *)calloc(1, sizeof(Client)); + client->recorder_pid = 0; + client->recorder_pipefd[0] = -1; + client->recorder_pipefd[1] = -1; + + if (env->render_mode == RENDER_WINDOW) { + // Interactive window mode + client->width = 1280; + client->height = 704; + SetConfigFlags(FLAG_MSAA_4X_HINT); + InitWindow(client->width, client->height, "PufferDrive"); + SetTargetFPS(30); + } else if (env->render_mode == RENDER_HEADLESS) { + // Headless rendering mode - hidden window with ffmpeg pipe + float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x; + float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y; + float scale = 6.0f; + client->width = (int)roundf(map_width * scale / 2.0f) * 2; + client->height = (int)roundf(map_height * scale / 2.0f) * 2; + client->original_map_width = map_width; + client->original_map_height = map_height; + + SetConfigFlags(FLAG_WINDOW_HIDDEN); + SetConfigFlags(FLAG_MSAA_4X_HINT); + InitWindow(client->width, client->height, "PufferDrive Headless"); + SetTargetFPS(6000); + + start_video_recorder(client, env->video_basename); + } + + // Load textures and models (for both window and headless modes) client->puffers = LoadTexture("resources/puffers_128.png"); client->cars[0] = LoadModel("resources/drive/RedCar.glb"); client->cars[1] = LoadModel("resources/drive/WhiteCar.glb"); @@ -2537,26 +2891,21 @@ Client* make_client(Drive* env){ client->cars[3] = LoadModel("resources/drive/YellowCar.glb"); client->cars[4] = LoadModel("resources/drive/GreenCar.glb"); client->cars[5] = LoadModel("resources/drive/GreyCar.glb"); + client->cyclist = LoadModel("resources/drive/cyclist.glb"); + client->pedestrian = LoadModel("resources/drive/pedestrian.glb"); + int animCountCyc = 0; + client->cycle_anim = LoadModelAnimations("resources/drive/cyclist.glb", &animCountCyc); for (int i = 0; i < MAX_AGENTS; i++) { client->car_assignments[i] = (rand() % 4) + 1; } - // Get initial target position from first active agent - Vector3 target_pos = { - 0, - 0, // Y is up - 1 // Z is depth - }; - // Set up camera to look at target from above and behind - client->default_camera_position = (Vector3){ - 0, // Same X as target - 120.0f, // 20 units above target - 175.0f // 20 units behind target - }; + // Set up camera defaults + Vector3 target_pos = {0, 0, 1}; + client->default_camera_position = (Vector3){0, 120.0f, 175.0f}; client->default_camera_target = target_pos; client->camera.position = client->default_camera_position; client->camera.target = client->default_camera_target; - client->camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; // Y is up + client->camera.up = (Vector3){0.0f, -1.0f, 0.0f}; client->camera.fovy = 45.0f; client->camera.projection = CAMERA_PERSPECTIVE; client->camera_zoom = 1.0f; @@ -2564,7 +2913,7 @@ Client* make_client(Drive* env){ } // Camera control functions -void handle_camera_controls(Client* client) { +void handle_camera_controls(Client *client) { static Vector2 prev_mouse_pos = {0}; static bool is_dragging = false; float camera_move_speed = 0.5f; @@ -2581,10 +2930,8 @@ void handle_camera_controls(Client* client) { if (is_dragging) { Vector2 current_mouse_pos = GetMousePosition(); - Vector2 delta = { - (current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed, - -(current_mouse_pos.y - prev_mouse_pos.y) * camera_move_speed - }; + Vector2 delta = {(current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed, + -(current_mouse_pos.y - prev_mouse_pos.y) * camera_move_speed}; // Update camera position (only X and Y) client->camera.position.x += delta.x; @@ -2602,11 +2949,9 @@ void handle_camera_controls(Client* client) { if (wheel != 0) { float zoom_factor = 1.0f - (wheel * 0.1f); // Calculate the current direction vector from target to position - Vector3 direction = { - client->camera.position.x - client->camera.target.x, - client->camera.position.y - client->camera.target.y, - client->camera.position.z - client->camera.target.z - }; + Vector3 direction = {client->camera.position.x - client->camera.target.x, + client->camera.position.y - client->camera.target.y, + client->camera.position.z - client->camera.target.z}; // Scale the direction vector by the zoom factor direction.x *= zoom_factor; @@ -2620,28 +2965,27 @@ void handle_camera_controls(Client* client) { } } -void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int lasers){ +void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int lasers) { // Diamond dimensions - float diamond_height = 3.0f; // Total height of diamond - float diamond_width = 1.5f; // Width of diamond - float diamond_z = 8.0f; // Base Z position + float diamond_height = 3.0f; // Total height of diamond + float diamond_width = 1.5f; // Width of diamond + float diamond_z = 8.0f; // Base Z position // Define diamond points - Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height/2}; // Top point - Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height/2}; // Bottom point - Vector3 front_point = (Vector3){0.0f, diamond_width/2, diamond_z}; // Front point - Vector3 back_point = (Vector3){0.0f, -diamond_width/2, diamond_z}; // Back point - Vector3 left_point = (Vector3){-diamond_width/2, 0.0f, diamond_z}; // Left point - Vector3 right_point = (Vector3){diamond_width/2, 0.0f, diamond_z}; // Right point + Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height / 2}; // Top point + Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height / 2}; // Bottom point + Vector3 front_point = (Vector3){0.0f, diamond_width / 2, diamond_z}; // Front point + Vector3 back_point = (Vector3){0.0f, -diamond_width / 2, diamond_z}; // Back point + Vector3 left_point = (Vector3){-diamond_width / 2, 0.0f, diamond_z}; // Left point + Vector3 right_point = (Vector3){diamond_width / 2, 0.0f, diamond_z}; // Right point // Draw the diamond faces // Top pyramid - - if(mode ==0){ - DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face - DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face - DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face - DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face + if (mode == 0) { + DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face + DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face + DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face + DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face // Bottom pyramid DrawTriangle3D(bottom_point, right_point, front_point, PUFF_CYAN); // Front-right face @@ -2649,14 +2993,16 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las DrawTriangle3D(bottom_point, left_point, back_point, PUFF_CYAN); // Back-left face DrawTriangle3D(bottom_point, front_point, left_point, PUFF_CYAN); // Front-left face } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { return; } - int ego_dim = (env->dynamics_model == JERK) ? 10 : 7; - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; - float (*observations)[max_obs] = (float(*)[max_obs])env->observations; - float* agent_obs = &observations[agent_index][0]; + int base_ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC; + int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0); + int ego_dim = base_ego_dim + conditioning_dims; + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; + float (*observations)[max_obs] = (float (*)[max_obs])env->observations; + float *agent_obs = &observations[agent_index][0]; // self int active_idx = env->active_agent_indices[agent_index]; float heading_self_x = env->entities[active_idx].heading_x; @@ -2666,82 +3012,64 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las // draw goal float goal_x = agent_obs[0] * 200; float goal_y = agent_obs[1] * 200; - if(mode == 0 ){ + if (mode == 0) { DrawSphere((Vector3){goal_x, goal_y, 1}, 0.5f, LIGHTGREEN); - DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, + Fade(LIGHTGREEN, 0.3f)); } - if (mode == 1){ - float goal_x_world = px + (goal_x * heading_self_x - goal_y*heading_self_y); - float goal_y_world = py + (goal_x * heading_self_y + goal_y*heading_self_x); + if (mode == 1) { + float goal_x_world = px + (goal_x * heading_self_x - goal_y * heading_self_y); + float goal_y_world = py + (goal_x * heading_self_y + goal_y * heading_self_x); DrawSphere((Vector3){goal_x_world, goal_y_world, 1}, 0.5f, LIGHTGREEN); - DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, + Fade(LIGHTGREEN, 0.3f)); } // First draw other agent observations - int obs_idx = ego_dim; // Start after ego obs - for(int j = 0; j < MAX_AGENTS - 1; j++) { - if(agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) { - obs_idx += 7; // Move to next agent observation + int obs_idx = ego_dim; // Start after ego obs + for (int j = 0; j < MAX_AGENTS - 1; j++) { + if (agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) { + obs_idx += 7; // Move to next agent observation continue; } // Draw position of other agents float x = agent_obs[obs_idx] * 50; float y = agent_obs[obs_idx + 1] * 50; - if(lasers && mode == 0){ - DrawLine3D( - (Vector3){0, 0, 0}, - (Vector3){x, y, 1}, - ORANGE - ); - } - - float partner_x = px + (x*heading_self_x - y*heading_self_y); - float partner_y = py + (x*heading_self_y + y*heading_self_x); - if(lasers && mode ==1){ - DrawLine3D( - (Vector3){px, py, 1}, - (Vector3){partner_x,partner_y,1}, - ORANGE - ); - } - - float half_width = 0.5*agent_obs[obs_idx + 2]*MAX_VEH_WIDTH; - float half_len = 0.5*agent_obs[obs_idx + 3]*MAX_VEH_LEN; + if (lasers && mode == 0) { + DrawLine3D((Vector3){0, 0, 0}, (Vector3){x, y, 1}, ORANGE); + } + + float partner_x = px + (x * heading_self_x - y * heading_self_y); + float partner_y = py + (x * heading_self_y + y * heading_self_x); + if (lasers && mode == 1) { + DrawLine3D((Vector3){px, py, 1}, (Vector3){partner_x, partner_y, 1}, ORANGE); + } + + float half_width = 0.5 * agent_obs[obs_idx + 2] * MAX_VEH_WIDTH; + float half_len = 0.5 * agent_obs[obs_idx + 3] * MAX_VEH_LEN; float theta_x = agent_obs[obs_idx + 4]; float theta_y = agent_obs[obs_idx + 5]; float partner_angle = atan2f(theta_y, theta_x); float cos_heading = cosf(partner_angle); float sin_heading = sinf(partner_angle); Vector3 corners[4] = { - (Vector3){ - x + (half_len * cos_heading - half_width * sin_heading), - y + (half_len * sin_heading + half_width * cos_heading), - 1 - }, - (Vector3){ - x + (half_len * cos_heading + half_width * sin_heading), - y + (half_len * sin_heading - half_width * cos_heading), - 1 - }, - (Vector3){ - x + (-half_len * cos_heading + half_width * sin_heading), - y + (-half_len * sin_heading - half_width * cos_heading), - 1 - }, - (Vector3){ - x + (-half_len * cos_heading - half_width * sin_heading), - y + (-half_len * sin_heading + half_width * cos_heading), - 1 - }, + (Vector3){x + (half_len * cos_heading - half_width * sin_heading), + y + (half_len * sin_heading + half_width * cos_heading), 1}, + (Vector3){x + (half_len * cos_heading + half_width * sin_heading), + y + (half_len * sin_heading - half_width * cos_heading), 1}, + (Vector3){x + (-half_len * cos_heading + half_width * sin_heading), + y + (-half_len * sin_heading - half_width * cos_heading), 1}, + (Vector3){x + (-half_len * cos_heading - half_width * sin_heading), + y + (-half_len * sin_heading + half_width * cos_heading), 1}, }; - if(mode ==0){ + if (mode == 0) { for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], ORANGE); + DrawLine3D(corners[j], corners[(j + 1) % 4], ORANGE); } } - if(mode ==1){ + if (mode == 1) { Vector3 world_corners[4]; for (int j = 0; j < 4; j++) { float lx = corners[j].x; @@ -2752,90 +3080,74 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las world_corners[j].z = 1; } for (int j = 0; j < 4; j++) { - DrawLine3D(world_corners[j], world_corners[(j+1)%4], ORANGE); + DrawLine3D(world_corners[j], world_corners[(j + 1) % 4], ORANGE); } } // draw an arrow above the car pointing in the direction that the partner is going - float arrow_length = 7.5f; - float arrow_x = x + arrow_length*cosf(partner_angle); - float arrow_y = y + arrow_length*sinf(partner_angle); + float arrow_length = 2.5f; + float arrow_x = x + arrow_length * cosf(partner_angle); + float arrow_y = y + arrow_length * sinf(partner_angle); float arrow_x_world; float arrow_y_world; - if(mode ==0){ - DrawLine3D((Vector3){x, y, 1}, (Vector3){arrow_x, arrow_y, 1}, PUFF_WHITE); + if (mode == 0) { + DrawLine3D((Vector3){x, y, 0.0}, (Vector3){arrow_x, arrow_y, 0.0}, PUFF_WHITE); } - if(mode == 1){ - arrow_x_world = px + (arrow_x * heading_self_x - arrow_y*heading_self_y); - arrow_y_world = py + (arrow_x * heading_self_y + arrow_y*heading_self_x); + if (mode == 1) { + arrow_x_world = px + (arrow_x * heading_self_x - arrow_y * heading_self_y); + arrow_y_world = py + (arrow_x * heading_self_y + arrow_y * heading_self_x); DrawLine3D((Vector3){partner_x, partner_y, 1}, (Vector3){arrow_x_world, arrow_y_world, 1}, PUFF_WHITE); } // Calculate perpendicular offsets for arrow head - float arrow_size = 2.0f; // Size of the arrow head + float arrow_size = 0.3f; // Size of the arrow head float dx = arrow_x - x; float dy = arrow_y - y; - float length = sqrtf(dx*dx + dy*dy); + float length = sqrtf(dx * dx + dy * dy); if (length > 0) { // Normalize direction vector dx /= length; dy /= length; // Calculate perpendicular vector - float perp_x = -dy * arrow_size; float perp_y = dx * arrow_size; - float arrow_x_end1 = arrow_x - dx*arrow_size + perp_x; - float arrow_y_end1 = arrow_y - dy*arrow_size + perp_y; - float arrow_x_end2 = arrow_x - dx*arrow_size - perp_x; - float arrow_y_end2 = arrow_y - dy*arrow_size - perp_y; + float arrow_x_end1 = arrow_x - dx * arrow_size + perp_x; + float arrow_y_end1 = arrow_y - dy * arrow_size + perp_y; + float arrow_x_end2 = arrow_x - dx * arrow_size - perp_x; + float arrow_y_end2 = arrow_y - dy * arrow_size - perp_y; // Draw the two lines forming the arrow head - if(mode ==0){ - DrawLine3D( - (Vector3){arrow_x, arrow_y, 1}, - (Vector3){arrow_x_end1, arrow_y_end1, 1}, - PUFF_WHITE - ); - DrawLine3D( - (Vector3){arrow_x, arrow_y, 1}, - (Vector3){arrow_x_end2, arrow_y_end2, 1}, - PUFF_WHITE - ); + if (mode == 0) { + DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end1, arrow_y_end1, 0.0}, PUFF_WHITE); + DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end2, arrow_y_end2, 0.0}, PUFF_WHITE); } - if(mode==1){ - float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1*heading_self_y); - float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1*heading_self_x); - float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - arrow_y_end2*heading_self_y); - float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2*heading_self_x); - DrawLine3D( - (Vector3){arrow_x_world, arrow_y_world, 1}, - (Vector3){arrow_x_end1_world, arrow_y_end1_world, 1}, - PUFF_WHITE - ); - DrawLine3D( - (Vector3){arrow_x_world, arrow_y_world, 1}, - (Vector3){arrow_x_end2_world, arrow_y_end2_world, 1}, - PUFF_WHITE - ); - + if (mode == 1) { + float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1 * heading_self_y); + float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1 * heading_self_x); + float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - arrow_y_end2 * heading_self_y); + float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2 * heading_self_x); + DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0}, + (Vector3){arrow_x_end1_world, arrow_y_end1_world, 0.0}, PUFF_WHITE); + DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0}, + (Vector3){arrow_x_end2_world, arrow_y_end2_world, 0.0}, PUFF_WHITE); } } - obs_idx += 7; // Move to next agent observation (7 values per agent) + obs_idx += PARTNER_FEATURES; // Move to next agent observation (7 values per agent) } // Then draw map observations - int map_start_idx = 7 + 7*(MAX_AGENTS - 1); // Start after agent observations - for(int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities - int entity_idx = map_start_idx + k*7; - if(agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0){ + int map_start_idx = ego_dim + PARTNER_FEATURES * (MAX_AGENTS - 1); // Start after agent observations + for (int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities + int entity_idx = map_start_idx + k * 7; + if (agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0) { continue; } - Color lineColor = BLUE; // Default color + Color lineColor = BLUE; // Default color int entity_type = (int)agent_obs[entity_idx + 6]; // Choose color based on entity type - if(entity_type+4 != ROAD_EDGE){ + if (entity_type + 4 != ROAD_EDGE) { continue; } lineColor = PUFF_CYAN; @@ -2848,88 +3160,60 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las float segment_length = agent_obs[entity_idx + 2] * MAX_ROAD_SEGMENT_LENGTH; // Calculate endpoint using the relative angle directly // Calculate endpoint directly - float x_start = x_middle - segment_length*cosf(rel_angle); - float y_start = y_middle - segment_length*sinf(rel_angle); - float x_end = x_middle + segment_length*cosf(rel_angle); - float y_end = y_middle + segment_length*sinf(rel_angle); - - - if(lasers && mode ==0){ - DrawLine3D((Vector3){0,0,0}, (Vector3){x_middle, y_middle, 1}, lineColor); - } - - if(mode ==1){ - float x_middle_world = px + (x_middle*heading_self_x - y_middle*heading_self_y); - float y_middle_world = py + (x_middle*heading_self_y + y_middle*heading_self_x); - float x_start_world = px + (x_start*heading_self_x - y_start*heading_self_y); - float y_start_world = py + (x_start*heading_self_y + y_start*heading_self_x); - float x_end_world = px + (x_end*heading_self_x - y_end*heading_self_y); - float y_end_world = py + (x_end*heading_self_y + y_end*heading_self_x); + float x_start = x_middle - segment_length * cosf(rel_angle); + float y_start = y_middle - segment_length * sinf(rel_angle); + float x_end = x_middle + segment_length * cosf(rel_angle); + float y_end = y_middle + segment_length * sinf(rel_angle); + + if (lasers && mode == 0) { + DrawLine3D((Vector3){0, 0, 0}, (Vector3){x_middle, y_middle, 1}, lineColor); + } + + if (mode == 1) { + float x_middle_world = px + (x_middle * heading_self_x - y_middle * heading_self_y); + float y_middle_world = py + (x_middle * heading_self_y + y_middle * heading_self_x); + float x_start_world = px + (x_start * heading_self_x - y_start * heading_self_y); + float y_start_world = py + (x_start * heading_self_y + y_start * heading_self_x); + float x_end_world = px + (x_end * heading_self_x - y_end * heading_self_y); + float y_end_world = py + (x_end * heading_self_y + y_end * heading_self_x); DrawCube((Vector3){x_middle_world, y_middle_world, 1}, 0.5f, 0.5f, 0.5f, lineColor); DrawLine3D((Vector3){x_start_world, y_start_world, 1}, (Vector3){x_end_world, y_end_world, 1}, BLUE); - if(lasers) DrawLine3D((Vector3){px,py,1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor); + if (lasers) + DrawLine3D((Vector3){px, py, 1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor); } - if(mode ==0){ + if (mode == 0) { DrawCube((Vector3){x_middle, y_middle, 1}, 0.5f, 0.5f, 0.5f, lineColor); DrawLine3D((Vector3){x_start, y_start, 1}, (Vector3){x_end, y_end, 1}, BLUE); } } } -void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float end_y){ - Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest - Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium +void draw_road_edge(Drive *env, float start_x, float start_y, float end_x, float end_y) { + Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest + Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium Color CURB_BOTTOM = (Color){160, 160, 160, 255}; - // Calculate curb dimensions - float curb_height = 0.5f; // Height of the curb - float curb_width = 0.3f; // Width/thickness of the curb - float road_z = 0.2f; // Ensure z-level for roads is below agents + // Calculate curb dimensions + float curb_height = 0.5f; // Height of the curb + float curb_width = 0.3f; // Width/thickness of the curb + float road_z = 0.0f; // Ensure z-level for roads is below agents // Calculate direction vector between start and end - Vector3 direction = { - end_x - start_x, - end_y - start_y, - 0.0f - }; + Vector3 direction = {end_x - start_x, end_y - start_y, 0.0f}; // Calculate length of the segment float length = sqrtf(direction.x * direction.x + direction.y * direction.y); // Normalize direction vector - Vector3 normalized_dir = { - direction.x / length, - direction.y / length, - 0.0f - }; + Vector3 normalized_dir = {direction.x / length, direction.y / length, 0.0f}; // Calculate perpendicular vector for width - Vector3 perpendicular = { - -normalized_dir.y, - normalized_dir.x, - 0.0f - }; + Vector3 perpendicular = {-normalized_dir.y, normalized_dir.x, 0.0f}; // Calculate the four bottom corners of the curb - Vector3 b1 = { - start_x - perpendicular.x * curb_width/2, - start_y - perpendicular.y * curb_width/2, - road_z - }; - Vector3 b2 = { - start_x + perpendicular.x * curb_width/2, - start_y + perpendicular.y * curb_width/2, - road_z - }; - Vector3 b3 = { - end_x + perpendicular.x * curb_width/2, - end_y + perpendicular.y * curb_width/2, - road_z - }; - Vector3 b4 = { - end_x - perpendicular.x * curb_width/2, - end_y - perpendicular.y * curb_width/2, - road_z - }; + Vector3 b1 = {start_x - perpendicular.x * curb_width / 2, start_y - perpendicular.y * curb_width / 2, road_z}; + Vector3 b2 = {start_x + perpendicular.x * curb_width / 2, start_y + perpendicular.y * curb_width / 2, road_z}; + Vector3 b3 = {end_x + perpendicular.x * curb_width / 2, end_y + perpendicular.y * curb_width / 2, road_z}; + Vector3 b4 = {end_x - perpendicular.x * curb_width / 2, end_y - perpendicular.y * curb_width / 2, road_z}; // Draw the curb faces // Bottom face @@ -2955,56 +3239,58 @@ void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float DrawTriangle3D(t4, t1, b1, CURB_SIDE); } -void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, int show_grid){ - // Draw a grid to help with orientation - // DrawGrid(20, 1.0f); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, PUFF_CYAN); - for(int i = 0; i < env->num_entities; i++) { +void draw_scene(Drive *env, Client *client, int mode, int obs_only, int lasers, int show_grid) { + + if (show_grid) { + float grid_start_x = env->grid_map->top_left_x; + float grid_start_y = env->grid_map->bottom_right_y; + for (int i = 0; i < env->grid_map->grid_cols; i++) { + for (int j = 0; j < env->grid_map->grid_rows; j++) { + float x = grid_start_x + i * GRID_CELL_SIZE; + float y = grid_start_y + j * GRID_CELL_SIZE; + DrawCubeWires((Vector3){x + GRID_CELL_SIZE / 2, y + GRID_CELL_SIZE / 2, 0.0f}, GRID_CELL_SIZE, + GRID_CELL_SIZE, 0.1f, Fade(PUFF_BACKGROUND2, 0.3f)); + } + } + } + + // Draw a grid to help with orientation + for (int i = 0; i < env->num_entities; i++) { // Draw objects - if(env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST) { // Check if this vehicle is an active agent bool is_active_agent = false; bool is_static_agent = false; int agent_index = -1; - for(int j = 0; j < env->active_agent_count; j++) { - if(env->active_agent_indices[j] == i) { + for (int j = 0; j < env->active_agent_count; j++) { + if (env->active_agent_indices[j] == i) { is_active_agent = true; agent_index = j; break; } } - for(int j = 0; j < env->static_agent_count; j++) { - if(env->static_agent_indices[j] == i) { + for (int j = 0; j < env->static_agent_count; j++) { + if (env->static_agent_indices[j] == i) { is_static_agent = true; break; } } // HIDE CARS ON RESPAWN - IMPORTANT TO KNOW VISUAL SETTING - if((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1){ + if ((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1) { continue; } Vector3 position; float heading; - position = (Vector3){ - env->entities[i].x, - env->entities[i].y, - 1 - }; + position = (Vector3){env->entities[i].x, env->entities[i].y, 1.1}; heading = env->entities[i].heading; // Create size vector - Vector3 size = { - env->entities[i].length, - env->entities[i].width, - env->entities[i].height - }; + Vector3 size = {env->entities[i].length, env->entities[i].width, env->entities[i].height}; bool is_expert = (!is_active_agent) && (env->entities[i].mark_as_expert == 1); // Save current transform - if(mode==1){ + if (mode == 1) { float cos_heading = env->entities[i].heading_x; float sin_heading = env->entities[i].heading_y; @@ -3014,220 +3300,184 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, // Calculate the four corners of the collision box Vector3 corners[4] = { - (Vector3){ - position.x + (half_len * cos_heading - half_width * sin_heading), - position.y + (half_len * sin_heading + half_width * cos_heading), - position.z - }, - - - (Vector3){ - position.x + (half_len * cos_heading + half_width * sin_heading), - position.y + (half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading + half_width * sin_heading), - position.y + (-half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading - half_width * sin_heading), - position.y + (-half_len * sin_heading + half_width * cos_heading), - position.z - }, - + (Vector3){position.x + (half_len * cos_heading - half_width * sin_heading), + position.y + (half_len * sin_heading + half_width * cos_heading), position.z}, + (Vector3){position.x + (half_len * cos_heading + half_width * sin_heading), + position.y + (half_len * sin_heading - half_width * cos_heading), position.z}, + (Vector3){position.x + (-half_len * cos_heading + half_width * sin_heading), + position.y + (-half_len * sin_heading - half_width * cos_heading), position.z}, + (Vector3){position.x + (-half_len * cos_heading - half_width * sin_heading), + position.y + (-half_len * sin_heading + half_width * cos_heading), position.z}, }; - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { + if (agent_index == env->human_agent_idx && + !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { draw_agent_obs(env, agent_index, mode, obs_only, lasers); } - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ + + if ((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx) { continue; } // --- Draw the car --- - - Vector3 carPos = { position.x, position.y, position.z }; - Color car_color = GRAY; // default for static - if (is_expert) car_color = GOLD; // expert replay - if (is_active_agent) car_color = BLUE; // policy-controlled - if (is_active_agent && env->entities[i].collision_state > 0) car_color = RED; + Color car_color = GRAY; // default for static + if (is_expert) + car_color = GOLD; // expert replay + if (is_active_agent) { + // Distinguish ego (magenta) from co-player (blue) so renders + // make it obvious which policy controls which car. Without + // this both look identical and behavioral debugging is + // ambiguous. + if (env->entities[i].is_ego) + car_color = MAGENTA; + else + car_color = BLUE; + } + if (is_active_agent && env->entities[i].collision_state > 0) + car_color = RED; rlSetLineWidth(3.0f); for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], car_color); + DrawLine3D(corners[j], corners[(j + 1) % 4], car_color); } // --- Draw a heading arrow pointing forward --- Vector3 arrowStart = position; - Vector3 arrowEnd = { - position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car - position.y + sin_heading * half_len * 1.5f, - position.z - }; + Vector3 arrowEnd = {position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car + position.y + sin_heading * half_len * 1.5f, position.z}; DrawLine3D(arrowStart, arrowEnd, car_color); - DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip + DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip - } - else { + } else { // Agent view rlPushMatrix(); // Translate to position, rotate around Y axis, then draw rlTranslatef(position.x, position.y, position.z); - rlRotatef(heading*RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees - // Determine color based on status - Color object_color = PUFF_BACKGROUND2; // fill color unused for model tint - Color outline_color = PUFF_CYAN; // not used for model tint - Model car_model = client->cars[5]; - if(is_active_agent){ - car_model = client->cars[client->car_assignments[i %64]]; - } - if(agent_index == env->human_agent_idx){ - object_color = PUFF_CYAN; - outline_color = PUFF_WHITE; - } - if(is_active_agent && env->entities[i].collision_state > 0) { - car_model = client->cars[0]; // Collided agent + rlRotatef(heading * RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees + + // Select car model (skip index 0) + Model car_model = client->cars[(i % 5) + 1]; // Cycles through indices 1-5 + + if (agent_index == env->human_agent_idx) { + car_model = client->cars[0]; // Ego agent always uses red car + } else if (is_active_agent) { + + car_model = client->cars[(i % 5) + 1]; + + if (env->entities[i].collision_state > 0) { + car_model = client->cars[0]; // Collided agents use red + } } - // Draw obs for human selected agent - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { + // Draw obs for selected agent index + if (agent_index == env->human_agent_idx && + (!env->entities[agent_index].metrics_array[REACHED_GOAL_IDX] || + env->goal_behavior == GOAL_GENERATE_NEW || env->goal_behavior == GOAL_STOP)) { draw_agent_obs(env, agent_index, mode, obs_only, lasers); } + // Draw cube for cars static and active // Calculate scale factors based on desired size and model dimensions - BoundingBox bounds = GetModelBoundingBox(car_model); - Vector3 model_size = { - bounds.max.x - bounds.min.x, - bounds.max.y - bounds.min.y, - bounds.max.z - bounds.min.z - }; - Vector3 scale = { - size.x / model_size.x, - size.y / model_size.y, - size.z / model_size.z - }; - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ - rlPopMatrix(); - continue; + Vector3 model_size = {bounds.max.x - bounds.min.x, bounds.max.y - bounds.min.y, + bounds.max.z - bounds.min.z}; + Vector3 scale = {size.x / model_size.x, size.y / model_size.y, size.z / model_size.z}; + // if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ + // rlPopMatrix(); + // continue; + // } + if (env->entities[i].type == CYCLIST) { + scale = (Vector3){0.01, 0.01, 0.01}; + car_model = client->cyclist; + } + if (env->entities[i].type == PEDESTRIAN) { + scale = (Vector3){2, 2, 2}; + car_model = client->pedestrian; } - DrawModelEx(car_model, (Vector3){0, 0, 0}, (Vector3){1, 0, 0}, 90.0f, scale, WHITE); { - float cos_heading = env->entities[i].heading_x; - float sin_heading = env->entities[i].heading_y; float half_len = env->entities[i].length * 0.5f; float half_width = env->entities[i].width * 0.5f; Vector3 corners[4] = { - (Vector3){ 0 + ( half_len * cos_heading - half_width * sin_heading), 0 + ( half_len * sin_heading + half_width * cos_heading), 0 }, - (Vector3){ 0 + ( half_len * cos_heading + half_width * sin_heading), 0 + ( half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading + half_width * sin_heading), 0 + (-half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading - half_width * sin_heading), 0 + (-half_len * sin_heading + half_width * cos_heading), 0 }, + (Vector3){half_len, -half_width, 0}, // Front-left + (Vector3){half_len, half_width, 0}, // Front-right + (Vector3){-half_len, half_width, 0}, // Back-right + (Vector3){-half_len, -half_width, 0}, // Back-left }; - Color wire_color = GRAY; // static - if (!is_active_agent && env->entities[i].mark_as_expert == 1) wire_color = GOLD; // expert replay - if (is_active_agent) wire_color = BLUE; // policy - if (is_active_agent && env->entities[i].collision_state > 0) wire_color = RED; + Color wire_color = GRAY; // static + if (!is_active_agent && env->entities[i].mark_as_expert == 1) + wire_color = GOLD; // expert replay + if (is_active_agent) + wire_color = BLUE; // policy + if (is_active_agent && env->entities[i].collision_state > 0) + wire_color = RED; rlSetLineWidth(2.0f); for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], wire_color); + DrawLine3D(corners[j], corners[(j + 1) % 4], wire_color); } } rlPopMatrix(); } // FPV Camera Control - if(IsKeyDown(KEY_SPACE) && env->human_agent_idx== agent_index){ - if(env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]){ - env->human_agent_idx = rand() % env->active_agent_count; - } - Vector3 camera_position = (Vector3){ - position.x - (25.0f * cosf(heading)), - position.y - (25.0f * sinf(heading)), - position.z + 15 - }; + if (IsKeyDown(KEY_SPACE) && env->human_agent_idx == agent_index) { + Vector3 camera_position = (Vector3){position.x - (25.0f * cosf(heading)), + position.y - (25.0f * sinf(heading)), position.z + 15}; - Vector3 camera_target = (Vector3){ - position.x + 40.0f * cosf(heading), - position.y + 40.0f * sinf(heading), - position.z - 5.0f - }; + Vector3 camera_target = (Vector3){position.x + 40.0f * cosf(heading), + position.y + 40.0f * sinf(heading), position.z - 5.0f}; client->camera.position = camera_position; client->camera.target = camera_target; client->camera.up = (Vector3){0, 0, 1}; } - if(IsKeyReleased(KEY_SPACE)){ + if (IsKeyReleased(KEY_SPACE)) { client->camera.position = client->default_camera_position; client->camera.target = client->default_camera_target; client->camera.up = (Vector3){0, 0, 1}; } // Draw goal position for active agents - - if(!is_active_agent || env->entities[i].valid == 0) { + if (!is_active_agent || env->entities[i].valid == 0) { continue; } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ - DrawSphere((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 1 - }, 0.5f, DARKGREEN); - - DrawCircle3D((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 0.1f - }, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { + DrawSphere((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 1}, 0.5f, + DARKGREEN); + + DrawCircle3D((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 0.1f}, + env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.9f)); } } // Draw road elements - if(env->entities[i].type <=3 && env->entities[i].type >= 7){ + if (env->entities[i].type <= 3 && env->entities[i].type >= 7) { continue; } - for(int j = 0; j < env->entities[i].array_size - 1; j++) { - Vector3 start = { - env->entities[i].traj_x[j], - env->entities[i].traj_y[j], - 1 - }; - Vector3 end = { - env->entities[i].traj_x[j + 1], - env->entities[i].traj_y[j + 1], - 1 - }; + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + Vector3 start = {env->entities[i].traj_x[j], env->entities[i].traj_y[j], 1}; + Vector3 end = {env->entities[i].traj_x[j + 1], env->entities[i].traj_y[j + 1], 1}; Color lineColor = GRAY; - if (env->entities[i].type == ROAD_LANE) lineColor = GRAY; - else if (env->entities[i].type == ROAD_LINE) lineColor = BLUE; - else if (env->entities[i].type == ROAD_EDGE) lineColor = WHITE; - else if (env->entities[i].type == DRIVEWAY) lineColor = RED; - if(env->entities[i].type != ROAD_EDGE){ - continue; - } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ - draw_road_edge(env, start.x, start.y, end.x, end.y); + if (env->entities[i].type == ROAD_LANE) + lineColor = Fade(SOFT_YELLOW, 0.25f); + else if (env->entities[i].type == ROAD_LINE) + lineColor = WHITE; + else if (env->entities[i].type == ROAD_EDGE) + lineColor = WHITE; + else if (env->entities[i].type == DRIVEWAY) + lineColor = RED; + + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { + if (env->entities[i].type == ROAD_EDGE) { + draw_road_edge(env, start.x, start.y, end.x, end.y); + } else if (env->entities[i].type == ROAD_LANE || env->entities[i].type == ROAD_LINE) { + // Draw road lanes and lines as purple lines + rlSetLineWidth(2.0f); + DrawLine3D(start, end, lineColor); + } } } } - if(show_grid) { - // Draw grid cells using the stored bounds - float grid_start_x = env->grid_map->top_left_x; - float grid_start_y = env->grid_map->bottom_right_y; - for(int i = 0; i < env->grid_map->grid_cols; i++) { - for(int j = 0; j < env->grid_map->grid_rows; j++) { - float x = grid_start_x + i*GRID_CELL_SIZE; - float y = grid_start_y + j*GRID_CELL_SIZE; - DrawCubeWires( - (Vector3){x + GRID_CELL_SIZE/2, y + GRID_CELL_SIZE/2, 1}, - GRID_CELL_SIZE, GRID_CELL_SIZE, 0.1f, PUFF_BACKGROUND2); - } - } - } EndMode3D(); // Draw track indices for the tracks to predict - if (mode == 1 && env->control_mode == CONTROL_TRACKS_TO_PREDICT) { - float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x; + if (mode == 1 && env->control_mode == CONTROL_WOSAC) { float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y; float pixels_per_world_unit = client->height / map_height; @@ -3242,151 +3492,257 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, float raw_x = -env->entities[agent_idx].x * pixels_per_world_unit; float raw_y = env->entities[agent_idx].y * pixels_per_world_unit; - int screen_x = (int)raw_x + client->width/2 + 20; - int screen_y = (int)raw_y + client->height/2 - 25; + int screen_x = (int)raw_x + client->width / 2 + 20; + int screen_y = (int)raw_y + client->height / 2 - 25; - if (screen_x >= 0 && screen_x <= client->width && - screen_y >= 0 && screen_y <= client->height) { + if (screen_x >= 0 && screen_x <= client->width && screen_y >= 0 && screen_y <= client->height) { char text[32]; snprintf(text, sizeof(text), "%d", womd_track_idx); int text_width = MeasureText(text, 20); - DrawText(text, screen_x - text_width/2, screen_y, 20, PUFF_WHITE); + DrawText(text, screen_x - text_width / 2, screen_y, 20, PUFF_WHITE); } } } } -void saveTopDownImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid){ - // Top-down orthographic camera - Camera3D camera = {0}; - camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene - camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin - camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; - camera.fovy = map_height; - camera.projection = CAMERA_ORTHOGRAPHIC; - Color road = (Color){35, 35, 37, 255}; +// Headless rendering helper: write frame to ffmpeg pipe +static void write_frame_to_pipe(Client *client) { + if (client->recorder_pipefd[1] < 0) + return; - BeginTextureMode(target); - ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); + int w = (int)client->width; + int h = (int)client->height; + unsigned char *screen_data = rlReadScreenPixels(w, h); + if (screen_data) { + write(client->recorder_pipefd[1], screen_data, w * h * 4); + RL_FREE(screen_data); + } +} + +// Render with view_mode and draw_traces parameters (for Python-based rendering) +void c_render_with_mode(Drive *env, int view_mode, int draw_traces, int current_scenario, int k_scenarios) { + if (env->client == NULL) { + env->client = make_client(env); + } + if (env->client == NULL) + return; // make_client may fail in headless mode - // Draw log trajectories FIRST (in background at lower Z-level) - if(log_trajectories){ - for(int i=0; iactive_agent_count;i++){ + Client *client = env->client; + Color road = (Color){35, 35, 37, 255}; + + if (env->render_mode == RENDER_HEADLESS) { + // Headless rendering mode. + // Use the CURRENT env's grid bounds for camera fovy. The original code + // used client->original_map_height (captured once at make_client) for + // "consistent resolution across scenarios", but with map_rand_per_scenario + // the new map can have entirely different bounds. Locking to scenario 0's + // dims meant new maps got cropped or off-centered, even though the env + // state was correct (ego scores stay ~0.99 in s_1; the agents are fine, + // just rendered outside the camera view). Read live so the camera always + // frames the actual current map. + float live_map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y; + float render_map_height = live_map_height > 0.0f ? live_map_height : client->original_map_height; + + Camera3D camera = {0}; + + if (view_mode == VIEW_MODE_SIM_STATE) { + // Top-down orthographic view of full simulation state + // Using original_map_height keeps pixels-per-world-unit consistent + camera.position = (Vector3){0.0f, 0.0f, 400.0f}; + camera.target = (Vector3){0.0f, 0.0f, 0.0f}; + camera.up = (Vector3){0.0f, -1.0f, 0.0f}; + camera.projection = CAMERA_ORTHOGRAPHIC; + camera.fovy = render_map_height; + + BeginDrawing(); + ClearBackground(road); + BeginMode3D(camera); + + if (draw_traces) { + // Draw trajectory traces for all active agents + for (int i = 0; i < env->active_agent_count; i++) { int idx = env->active_agent_indices[i]; - for(int j=0; jentities[idx].array_size;j++){ - float x = env->entities[idx].traj_x[j]; - float y = env->entities[idx].traj_y[j]; - float valid = env->entities[idx].traj_valid[j]; - if(!valid) continue; - DrawSphere((Vector3){x,y,0.5f}, 0.3f, Fade(LIGHTGREEN, 0.6f)); + int t_end = env->scenario_length; + if (t_end > env->entities[idx].array_size) { + t_end = env->entities[idx].array_size; + } + for (int t = env->init_steps; t < t_end; t++) { + if (env->entities[idx].traj_valid[t]) { + DrawPoint3D((Vector3){env->entities[idx].traj_x[t], env->entities[idx].traj_y[t], 0.5f}, + LIGHTBLUE); + } } } } + draw_scene(env, client, 1, 0, 0, 0); + EndMode3D(); + + } else if (view_mode == VIEW_MODE_BEV_AGENT_OBS) { + // Bird's eye view centered on human-controlled agent + int agent_idx = env->active_agent_indices[env->human_agent_idx]; + Entity *agent = &env->entities[agent_idx]; + + camera.position = (Vector3){agent->x, agent->y, 400.0f}; + camera.target = (Vector3){agent->x, agent->y, 0.0f}; + camera.up = (Vector3){0.0f, -1.0f, 0.0f}; + camera.projection = CAMERA_ORTHOGRAPHIC; + camera.fovy = env->grid_map->vision_range * GRID_CELL_SIZE * 2.0f; + + BeginDrawing(); + ClearBackground(road); + BeginMode3D(camera); + draw_scene(env, client, 1, 1, 0, 0); + EndMode3D(); - // Draw current path trajectories SECOND (slightly higher than log trajectories) - if(trajectories){ - for(int i=0; iactive_agent_indices[env->human_agent_idx]; + Entity *agent = &env->entities[agent_idx]; - // Draw main scene LAST (on top) - draw_scene(env, client, 1, obs, lasers, show_grid); + camera.position = + (Vector3){agent->x - (25.0f * cosf(agent->heading)), agent->y - (25.0f * sinf(agent->heading)), 15.0f}; + camera.target = + (Vector3){agent->x + 40.0f * cosf(agent->heading), agent->y + 40.0f * sinf(agent->heading), 1.0f}; + camera.up = (Vector3){0.0f, 0.0f, 1.0f}; + camera.fovy = 60.0f; + camera.projection = CAMERA_PERSPECTIVE; - EndMode3D(); - EndTextureMode(); - - // save to file - Image img = LoadImageFromTexture(target.texture); - ImageFlipVertical(&img); - ExportImage(img, filename); - UnloadImage(img); -} - -void saveAgentViewImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs_only, int lasers, int show_grid) { - // Agent perspective camera following the human agent - int agent_idx = env->active_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; - - Camera3D camera = {0}; - // Position camera behind and above the agent - camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), - 15.0f - }; - camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), - 1.0f - }; - camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; - camera.fovy = 45.0f; - camera.projection = CAMERA_PERSPECTIVE; + BeginDrawing(); + ClearBackground(road); + BeginMode3D(camera); + draw_scene(env, client, 0, 0, 0, 1); + EndMode3D(); + } - Color road = (Color){35, 35, 37, 255}; + // Draw scenario counter overlay (2D text on top of 3D scene) + if (k_scenarios > 1) { + char scenario_text[64]; + snprintf(scenario_text, sizeof(scenario_text), "Scenario %d / %d", current_scenario + 1, k_scenarios); + DrawText(scenario_text, 40, 40, 120, WHITE); + } + + EndDrawing(); + + // Write frame to ffmpeg pipe + write_frame_to_pipe(client); - BeginTextureMode(target); + } else { + // Interactive window mode - use default rendering + BeginDrawing(); ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); - draw_scene(env, client, 0, obs_only, lasers, show_grid); // mode=0 for agent view + BeginMode3D(client->camera); + handle_camera_controls(env->client); + draw_scene(env, client, 0, 0, 0, 0); EndMode3D(); - EndTextureMode(); - - // Save to file - Image img = LoadImageFromTexture(target.texture); - ImageFlipVertical(&img); - ExportImage(img, filename); - UnloadImage(img); + EndDrawing(); + } } -void c_render(Drive* env) { +// Original c_render for backward compatibility (interactive window mode) +void c_render(Drive *env) { if (env->client == NULL) { env->client = make_client(env); } - Client* client = env->client; + Client *client = env->client; BeginDrawing(); Color road = (Color){35, 35, 37, 255}; ClearBackground(road); BeginMode3D(client->camera); handle_camera_controls(env->client); draw_scene(env, client, 0, 0, 0, 0); + EndMode3D(); + + if (IsKeyPressed(KEY_TAB)) { + env->human_agent_idx = (env->human_agent_idx + 1) % env->active_agent_count; + } + // Draw debug info - DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", - client->camera.position.x, - client->camera.position.y, - client->camera.position.z), 10, 10, 20, PUFF_WHITE); - DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)", - client->camera.target.x, - client->camera.target.y, - client->camera.target.z), 10, 30, 20, PUFF_WHITE); + DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", client->camera.position.x, client->camera.position.y, + client->camera.position.z), + 10, 10, 20, PUFF_WHITE); + DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)", client->camera.target.x, client->camera.target.y, + client->camera.target.z), + 10, 30, 20, PUFF_WHITE); DrawText(TextFormat("Timestep: %d", env->timestep), 10, 50, 20, PUFF_WHITE); - // acceleration & steering + int human_idx = env->active_agent_indices[env->human_agent_idx]; DrawText(TextFormat("Controlling Agent: %d", env->human_agent_idx), 10, 70, 20, PUFF_WHITE); DrawText(TextFormat("Agent Index: %d", human_idx), 10, 90, 20, PUFF_WHITE); + + // Display current action values - yellow when controlling, white otherwise + Color action_color = IsKeyDown(KEY_LEFT_SHIFT) ? YELLOW : PUFF_WHITE; + + if (env->action_type == 0) { // discrete + int *action_array = (int *)env->actions; + int action_val = action_array[env->human_agent_idx]; + + if (env->dynamics_model == CLASSIC) { + int num_steer = 13; + int accel_idx = action_val / num_steer; + int steer_idx = action_val % num_steer; + float accel_value = ACCELERATION_VALUES[accel_idx]; + float steer_value = STEERING_VALUES[steer_idx]; + + DrawText(TextFormat("Acceleration: %.2f m/s^2", accel_value), 10, 110, 20, action_color); + DrawText(TextFormat("Steering: %.3f", steer_value), 10, 130, 20, action_color); + } else if (env->dynamics_model == JERK) { + int num_lat = 3; + int jerk_long_idx = action_val / num_lat; + int jerk_lat_idx = action_val % num_lat; + float jerk_long_value = JERK_LONG[jerk_long_idx]; + float jerk_lat_value = JERK_LAT[jerk_lat_idx]; + + DrawText(TextFormat("Longitudinal Jerk: %.2f m/s^3", jerk_long_value), 10, 110, 20, action_color); + DrawText(TextFormat("Lateral Jerk: %.2f m/s^3", jerk_lat_value), 10, 130, 20, action_color); + } + } else { // continuous + float (*action_array_f)[2] = (float (*)[2])env->actions; + DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, action_color); + DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, action_color); + } + + // Show key press status + int status_y = 150; + if (IsKeyDown(KEY_LEFT_SHIFT)) { + DrawText("[shift pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + if (IsKeyDown(KEY_SPACE)) { + DrawText("[space pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + if (IsKeyDown(KEY_LEFT_CONTROL)) { + DrawText("[ctrl pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + // Controls help - DrawText("Controls: W/S - Accelerate/Brake, A/D - Steer, 1-4 - Switch Agent", - 10, client->height - 30, 20, PUFF_WHITE); - // acceleration & steering - if (env->action_type == 1) { // continuous (float) - float (*action_array_f)[2] = (float(*)[2])env->actions; - DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE); - DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE); - } else { // discrete (int) - int (*action_array)[2] = (int(*)[2])env->actions; - DrawText(TextFormat("Acceleration: %d", action_array[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE); - DrawText(TextFormat("Steering: %d", action_array[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE); - } - DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, 150, 20, PUFF_WHITE); - DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, 170, 20, PUFF_WHITE); + DrawText("Controls: SHIFT + W/S - Accelerate/Brake, SHIFT + A/D - Steer, TAB - Switch Agent", 10, + client->height - 30, 20, PUFF_WHITE); + + DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, status_y, 20, PUFF_WHITE); + DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, status_y + 20, 20, PUFF_WHITE); EndDrawing(); } -void close_client(Client* client){ +// Set the full mp4 basename (without ".mp4") for headless rendering. +// If a recorder is already running for this env, it's stopped and a new one +// is started so multiple basenames can be used in a single env lifetime. +void set_video_suffix(Drive *env, const char *basename) { + if (basename) { + strncpy(env->video_basename, basename, sizeof(env->video_basename) - 1); + env->video_basename[sizeof(env->video_basename) - 1] = '\0'; + } else { + env->video_basename[0] = '\0'; + } + + if (env->client != NULL && env->render_mode == RENDER_HEADLESS) { + start_video_recorder(env->client, env->video_basename); + } +} + +void close_client(Client *client) { + stop_video_recorder(client); for (int i = 0; i < 6; i++) { UnloadModel(client->cars[i]); } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index c3c6be27af..357d37649c 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -3,9 +3,20 @@ import json import struct import os +from enum import IntEnum import pufferlib from pufferlib.ocean.drive import binding import torch +from multiprocessing import Pool, cpu_count +from tqdm import tqdm + + +class RenderView(IntEnum): + """View modes for rendering.""" + + FULL_SIM_STATE = 0 # Top-down orthographic view of full simulation + BEV_AGENT_OBS = 1 # Bird's eye view centered on agent observation + AGENT_PERSPECTIVE = 2 # Third-person chase camera following agent class Drive(pufferlib.PufferEnv): @@ -20,13 +31,18 @@ def __init__( reward_offroad_collision=-0.1, reward_goal=1.0, reward_goal_post_respawn=0.5, - reward_ade=0.0, + reward_lane_align=0.0, # GIGAFLOW lane alignment reward (0 = disabled) + reward_vel_align=1.0, # Velocity alignment coefficient for lane reward goal_behavior=0, + goal_target_distance=10.0, goal_radius=2.0, + goal_speed=20.0, collision_behavior=0, offroad_behavior=0, dt=0.1, scenario_length=None, + episode_length=None, + termination_mode=None, resample_frequency=91, num_maps=100, num_agents=512, @@ -45,25 +61,62 @@ def __init__( co_player_enabled=False, num_ego_agents=512, co_player_policy={}, + map_dir="resources/drive/binaries/training", + use_all_maps=False, + report_all_scenarios=False, + map_seed=None, + external_co_player_actions=False, + worker_idx=0, + co_player_conditioning_shm=None, + map_rand_per_scenario=False, + condition_rand_per_scenario=False, + entropy_curriculum_enabled=False, + entropy_curriculum_episodes_start=0, + k_eff_curriculum_enabled=False, + k_eff_curriculum_episodes_per_stage=30, + ego_is_oracle=False, + reward_only_last_scenario=False, ): # env self.dt = dt + # Convert render_mode string to integer constant + if render_mode is None or render_mode == 0: + self._render_mode_int = binding.RENDER_OFF + elif render_mode == 1 or render_mode == "headless": + self._render_mode_int = binding.RENDER_HEADLESS + elif render_mode == 2 or render_mode == "window" or render_mode == "human": + self._render_mode_int = binding.RENDER_WINDOW + else: + self._render_mode_int = binding.RENDER_OFF self.render_mode = render_mode + self.report_all_scenarios = report_all_scenarios self.num_maps = num_maps self.report_interval = report_interval self.reward_vehicle_collision = reward_vehicle_collision self.reward_offroad_collision = reward_offroad_collision self.reward_goal = reward_goal self.reward_goal_post_respawn = reward_goal_post_respawn + self.reward_lane_align = reward_lane_align + self.reward_vel_align = reward_vel_align self.goal_radius = goal_radius + self.goal_speed = goal_speed self.goal_behavior = goal_behavior + self.goal_target_distance = goal_target_distance self.collision_behavior = collision_behavior self.offroad_behavior = offroad_behavior - self.reward_ade = reward_ade self.human_agent_idx = human_agent_idx self.scenario_length = scenario_length + self.termination_mode = termination_mode self.resample_frequency = resample_frequency self.ini_file = ini_file + self.use_all_maps = use_all_maps + self.map_seed = map_seed + + if episode_length != None: + self.scenario_length = episode_length + # Only set episode_length if not already set (adaptive.py sets it before calling super()) + if not hasattr(self, "episode_length"): + self.episode_length = self.scenario_length # Adaptive driving agent setup self.adaptive_driving_agent = int(adaptive_driving_agent) @@ -119,22 +172,100 @@ def __init__( self.dynamics_model = dynamics_model # Observation space calculation - base_ego_dim = 10 if self.dynamics_model == "jerk" else 7 + self.ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get( + dynamics_model + ) + + self.ego_features += conditioning_dims + + # Extract observation shapes from constants + # These need to be defined in C, since they determine the shape of the arrays + self.max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS + self.max_partner_objects = binding.MAX_AGENTS - 1 + self.partner_features = binding.PARTNER_FEATURES + self.road_features = binding.ROAD_FEATURES - partner_features = 7 - road_features = 7 - max_partner_objects = 63 - max_road_objects = 200 self.num_obs = ( - base_ego_dim + conditioning_dims + max_partner_objects * partner_features + max_road_objects * road_features + self.ego_features + + self.max_partner_objects * self.partner_features + + self.max_road_objects * self.road_features ) - self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32) # Co-player policy setup self.population_play = co_player_enabled self.num_agents = num_agents self.num_ego_agents = num_ego_agents if self.population_play else num_agents + # When True, co-player actions are filled into self.actions[co_player_ids] + # by the *main* process (centralized GPU inference). Worker's step() + # then skips the local CPU forward in get_co_player_actions(). + self.external_co_player_actions = bool(external_co_player_actions) + # When True (and adaptive_drive with k>1), at every scenario boundary + # we re-init the C envs with FRESH map_ids (and freshly sampled co- + # player conditioning). The agents are spawned on a brand-new map + # while the EGO POLICY's K/V cache (held in main / pufferl) is NOT + # touched — so past-scenario context is the only stable signal that + # carries across scenarios. This is the experimental setup that + # actually exercises in-context adaptation. + self.map_rand_per_scenario = bool(map_rand_per_scenario) + # When True (only meaningful for k_scenarios > 1, with co_player_enabled): + # at every scenario boundary within an episode, partner conditioning is + # re-sampled from the configured ranges. The same partner POLICY weights + # are used, but the partner's effective behavior changes per scenario + # because conditioning shifts. This gives the ego policy a meaningful + # latent variable (partner type) to encode in its K/V cache without + # introducing the agent-identity misalignment that map_rand causes. + # Independent of map_rand_per_scenario. + self.condition_rand_per_scenario = bool(condition_rand_per_scenario) + # When True, partner's entropy_weight_ub is annealed up over training: + # the user-passed co_player_entropy_weight_ub is treated as the FINAL + # value, and a 4-stage schedule scales it 0.05 → 0.20 → 0.50 → 1.0 of + # the final, advancing every 30 episodes per worker. The other + # conditioning dims (collision/offroad/discount) sample at full range + # throughout. Reason: we observed ada_delta peaking early in training + # then drifting toward 0 as scores saturate; the curriculum keeps the + # task in an informative-difficulty regime for longer. + self.entropy_curriculum_enabled = bool(entropy_curriculum_enabled) + # When resuming from a checkpoint of a curriculum run, the per-env + # episode counter isn't part of the model state — pass the original + # run's ending episode count here so the curriculum picks up at the + # right stage instead of restarting from stage 0. Each worker still + # advances its own counter from this starting value. + self._entropy_curriculum_episodes_seen = int(entropy_curriculum_episodes_start) + self._pending_entropy_log = None + self._entropy_curriculum_final_ub = None # set lazily once we know co_player_entropy_weight_ub + # When True, the ego's K/V cache is reset at SOME within-episode + # scenario boundaries based on the curriculum stage. K_max is the + # configured `k_scenarios`; the curriculum has 3 stages of + # `k_eff_curriculum_episodes_per_stage` episodes each: + # Stage 0: k_eff=1 (reset at every within-episode boundary) + # Stage 1: k_eff=2 (reset at boundaries where current_scenario%2==0) + # Stage 2: k_eff=k_max (no within-episode resets) + # Implementation: at the boundary that should reset, we set + # truncations[ego_ids]=1 and terminals[ego_ids]=1 for the current + # step. pufferl picks up done_mask=t+d to reset transformer_position + # at eval time; create_episode_mask uses terminals to block + # cross-boundary attention during training. K_max=4 with stages + # k_eff∈{1,2,4} gives clean splits (boundaries 1,2,3 → reset {all}, + # {middle only}, {none}). For other K_max, only k_eff=1 and k_max + # produce uniform stages. + self.k_eff_curriculum_enabled = bool(k_eff_curriculum_enabled) + self.k_eff_curriculum_episodes_per_stage = int(k_eff_curriculum_episodes_per_stage) + self._k_eff_curriculum_episodes_seen = 0 + self._pending_k_eff_log = None + # When True, _reinit_envs_with_new_maps() donates env[0]->client to a + # C-side global before vec_close and re-attaches it to the new env[0] + # afterwards. This keeps the raylib window + ffmpeg pipe alive across + # the swap (raylib's CloseWindow → InitWindow cycle segfaults under + # xvfb), so a single mp4 captures all k scenarios with the maps + # rotating mid-stream. Set True only on the single-env render driver. + self._render_keep_client_on_swap = False + self.worker_idx = int(worker_idx) + # SHM view (numpy) of the per-worker conditioning slice. Env writes + # sampled conditioning here so the main process can read it before + # running the centralized co-player forward. None when conditioning + # is disabled or when running the per-worker CPU path. + self.co_player_conditioning_shm = co_player_conditioning_shm # Co-player conditioning setup self.co_player_conditioning = co_player_policy.get("conditioning") @@ -156,15 +287,64 @@ def __init__( self.co_player_discount_weight_lb = self.co_player_conditioning.get("discount_weight_lb", 0.98) self.co_player_discount_weight_ub = self.co_player_conditioning.get("discount_weight_ub", 0.98) + # ----- Ego oracle (NEW, isolated machinery) ----- + # When True, the ego's obs gets the partner's per-env conditioning + # vector appended at the END (after road_obs). Implementation: + # 1. Allocate a private `_c_observations` buffer the C side writes + # into (sized to the C's expected obs_dim, no oracle slots). + # 2. The pufferl-facing `self.observations` buffer is sized + # bigger (`+ oracle_dims`); each step we copy the C buffer + # into the first part and write `_oracle_obs_per_env[env]` + # into the trailing oracle slots for every ego row. + # No changes to [env.conditioning], pufferl, or reward — the + # oracle slots are pure obs signal that only the policy reads. + self.ego_is_oracle = bool(ego_is_oracle) + self.reward_only_last_scenario = bool(reward_only_last_scenario) + if self.reward_only_last_scenario and not self.adaptive_driving_agent: + raise ValueError("reward_only_last_scenario=True requires adaptive_driving_agent=True (k_scenarios > 1).") + if self.ego_is_oracle: + # Determine the partner's conditioning dim count (== oracle width). + ct = self.co_player_condition_type + if ct is None or ct == "none": + raise ValueError( + "ego_is_oracle=True requires co-player conditioning to be " + "enabled (co_player_policy.conditioning.type != 'none')." + ) + self._oracle_dims = ( + (3 if self.co_player_reward_conditioned else 0) + + (1 if self.co_player_entropy_conditioned else 0) + + (1 if self.co_player_discount_conditioned else 0) + ) + if self._oracle_dims == 0: + raise ValueError("ego_is_oracle=True but partner conditioning resolved to 0 dims.") + # Grow obs space so pufferl allocates a buffer wide enough to + # hold the appended oracle slots (placed AFTER road_obs, at + # offset = num_obs - oracle_dims). C still writes the smaller + # part into its own private `_c_observations` buffer. + self.num_obs += self._oracle_dims + self.single_observation_space = gymnasium.spaces.Box( + low=-1, high=1, shape=(self.num_obs,), dtype=np.float32 + ) + # Per-env oracle vector. Filled at every _set_co_player_conditioning + # call from a copy of `self.env_conditioning`. We don't allocate + # `_c_observations` or `_ego_env_indices` here — they need + # `num_envs`/`num_agents` which aren't known until later. + self._oracle_obs_per_env = None + self._c_observations = None + self._ego_env_indices = None + else: + self._oracle_dims = 0 + self.init_steps = init_steps self.init_mode_str = init_mode self.control_mode_str = control_mode + self.map_dir = map_dir if self.control_mode_str == "control_vehicles": self.control_mode = 0 elif self.control_mode_str == "control_agents": self.control_mode = 1 - elif self.control_mode_str == "control_tracks_to_predict": + elif self.control_mode_str == "control_wosac": self.control_mode = 2 elif self.control_mode_str == "control_sdc_only": self.control_mode = 3 @@ -188,7 +368,8 @@ def __init__( # Multi discrete (assume independence) # self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13]) elif dynamics_model == "jerk": - self.single_action_space = gymnasium.spaces.MultiDiscrete([4, 3]) + # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12 + self.single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3]) else: raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}") elif action_type == "continuous": @@ -198,18 +379,18 @@ def __init__( self._action_type_flag = 0 if action_type == "discrete" else 1 - # Check if resources directory exists - binary_path = "resources/drive/binaries/map_000.bin" + # Check if resources directory exists (check map_001 since some datasets start at 001) + binary_path = f"{map_dir}/map_001.bin" if not os.path.exists(binary_path): raise FileNotFoundError( - f"Required directory {binary_path} not found. Please ensure the Drive maps are downloaded and installed correctly per docs." + f"Required file {binary_path} not found. Please ensure the Drive maps are downloaded and installed correctly per docs." ) # Check maps availability - available_maps = len([name for name in os.listdir("resources/drive/binaries") if name.endswith(".bin")]) + available_maps = len([name for name in os.listdir(map_dir) if name.endswith(".bin")]) if num_maps > available_maps: raise ValueError( - f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to resources/drive/binaries." + f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to {map_dir}." ) if self.population_play: if self.num_ego_agents > num_agents: @@ -226,8 +407,18 @@ def __init__( if self.population_play: self.co_player_policy_name = co_player_policy.get("policy_name") self.co_player_rnn_name = co_player_policy.get("rnn_name") - self.co_player_policy = co_player_policy.get("co_player_policy_func") - self._set_co_player_state() + if self.external_co_player_actions: + # Main owns the policy + state on GPU; worker only needs the + # action slots (co_player_ids) to be filled via shared memory + # before vec_step. Skip the per-worker CPU model entirely. + self.co_player_policy = None + self.co_player_device = None + else: + self.co_player_policy = co_player_policy.get("co_player_policy_func") + # Co-player runs in forked subprocess - must stay on CPU + # (CUDA doesn't work with fork) + self.co_player_device = torch.device("cpu") + self._set_co_player_state() super().__init__(buf=buf) if self.population_play: @@ -238,12 +429,27 @@ def __init__( else: self.co_player_actions = np.zeros(co_player_atn_space.shape, dtype=np.int32) + # Allocate the private C-only obs buffer + ego→env index map (oracle + # path only). C writes into `_c_observations` (no oracle slots); + # we copy + append into `self.observations` (which has oracle slots) + # every step/reset. `_oracle_obs_per_env` was filled by + # `_set_co_player_conditioning` during the prior `_set_env_variables` + # call (which always fires here because oracle requires co-player + # conditioning to be on). + if self.ego_is_oracle: + self._c_obs_dim = self.num_obs - self._oracle_dims + self._c_observations = np.zeros((self.num_agents, self._c_obs_dim), dtype=np.float32) + self._rebuild_ego_env_indices() + env_ids = [] for i in range(self.num_envs): cur = self.agent_offsets[i] nxt = self.agent_offsets[i + 1] + # Oracle: hand C its own private obs slice (smaller, no oracle + # slots). Otherwise C uses the pufferl-provided buffer directly. + obs_slice_for_c = self._c_observations[cur:nxt] if self.ego_is_oracle else self.observations[cur:nxt] env_id = binding.env_init( - self.observations[cur:nxt], + obs_slice_for_c, self.actions[cur:nxt], self.rewards[cur:nxt], self.terminals[cur:nxt], @@ -256,13 +462,17 @@ def __init__( reward_offroad_collision=reward_offroad_collision, reward_goal=reward_goal, reward_goal_post_respawn=reward_goal_post_respawn, - reward_ade=reward_ade, + reward_lane_align=self.reward_lane_align, + reward_vel_align=self.reward_vel_align, goal_radius=goal_radius, + goal_speed=goal_speed, goal_behavior=self.goal_behavior, + goal_target_distance=self.goal_target_distance, collision_behavior=self.collision_behavior, offroad_behavior=self.offroad_behavior, dt=dt, - scenario_length=(int(scenario_length) if scenario_length is not None else None), + scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None), + termination_mode=(int(self.termination_mode) if self.termination_mode is not None else 0), max_controlled_agents=self.max_controlled_agents, map_id=self.map_ids[i], max_agents=nxt - cur, @@ -288,6 +498,8 @@ def __init__( discount_weight_ub=self.discount_weight_ub, init_mode=self.init_mode, control_mode=self.control_mode, + map_dir=map_dir, + render_mode=self._render_mode_int, ) env_ids.append(env_id) @@ -295,15 +507,30 @@ def __init__( def reset(self, seed=0): binding.vec_reset(self.c_envs, seed) + # Oracle: copy C obs into pufferl buffer + write oracle slots. + self._refresh_ego_oracle_obs() info = [] if self.population_play: info.append(self.ego_ids) - self._reset_co_player_state() + if self.external_co_player_actions: + # Pass the real co_player_ids so main does not have to + # guess by complement (which would include padding slots + # and pollute the shared KV cache). + info.append({"_external_co_player_ids": self.co_player_ids}) + # Tell main to drop the K/V cache, mirroring OFF's + # `_reset_co_player_state()` here. Initial conditioning + # was written to SHM in `_set_env_variables` and stays + # valid across the episode (matching OFF, which doesn't + # re-sample at reset either). + info.append({"_external_reset_co_cache": True}) + else: + self._reset_co_player_state() self.tick = 0 return self.observations, info def _set_env_variables(self): my_shared_tuple = binding.shared( + map_dir=self.map_dir, num_agents=self.num_agents, num_maps=self.num_maps, init_mode=self.init_mode, @@ -313,6 +540,9 @@ def _set_env_variables(self): goal_behavior=self.goal_behavior, population_play=self.population_play, num_ego_agents=self.num_ego_agents, + goal_target_distance=self.goal_target_distance, + use_all_maps=self.use_all_maps, + map_seed=self.map_seed if self.map_seed is not None else -1, ) if self.population_play: @@ -377,7 +607,11 @@ def _set_env_variables(self): self.agent_offsets, self.map_ids, self.num_envs = my_shared_tuple self.ego_ids = [i for i in range(self.agent_offsets[-1])] if len(self.ego_ids) != self.num_agents: - raise ValueError("mismatch between number of ego agents and number of agents") + print( + f"Warning: requested {self.num_agents} agents but maps contain {len(self.ego_ids)} valid agents. Adjusting.", + flush=True, + ) + self.num_agents = len(self.ego_ids) self.local_co_player_ids = [[] for i in range(self.num_envs)] self.local_ego_ids = [[0] for i in range(self.num_envs)] @@ -388,30 +622,62 @@ def get_co_player_actions(self): if self.co_player_condition_type != "none": co_player_obs = self._add_co_player_conditioning(co_player_obs) - co_player_obs = torch.as_tensor(co_player_obs) + # Convert directly to device for GPU acceleration + co_player_obs = torch.as_tensor(co_player_obs, device=self.co_player_device) + import sys + + sys.stdout.flush() # Prevent multiprocessing deadlock logits, value = self.co_player_policy.forward_eval(co_player_obs, self.state) - co_player_action, logprob, _ = pufferlib.pytorch.sample_logits(logits) + # Handle multi-discrete actions (logits is a tuple) vs single discrete (logits is tensor) + if isinstance(logits, tuple): + co_player_action = torch.cat([l.argmax(dim=-1, keepdim=True) for l in logits], dim=-1) + else: + co_player_action = logits.argmax(dim=-1) + # Only this transfer is necessary co_player_action = co_player_action.cpu().numpy().reshape(self.co_player_actions.shape) return co_player_action def _set_co_player_state(self): with torch.no_grad(): - self.state = dict( - lstm_h=torch.zeros(self.num_co_players, self.co_player_policy.hidden_size), - lstm_c=torch.zeros(self.num_co_players, self.co_player_policy.hidden_size), - ) + # Detect if co-player uses Transformer (has horizon) or LSTM + self.co_player_is_transformer = hasattr(self.co_player_policy, "horizon") + + if self.co_player_is_transformer: + # Transformer co-player uses streaming KV cache (see + # TransformerWrapper.forward_eval). The cache is allocated + # lazily inside forward_eval the first time it sees a state + # without "k_cache". We initialize the position counter here + # so reset_eval_state has something to zero on full reset. + self.state = dict( + transformer_position=torch.zeros(1, dtype=torch.long, device=self.co_player_device), + ) + else: + self.state = dict( + lstm_h=torch.zeros( + self.num_co_players, self.co_player_policy.hidden_size, device=self.co_player_device + ), + lstm_c=torch.zeros( + self.num_co_players, self.co_player_policy.hidden_size, device=self.co_player_device + ), + ) def _reset_co_player_state(self, done_indices=None): - """Reset LSTM state for co-players whose episodes ended""" + """Reset LSTM/Transformer state for co-players whose episodes ended""" with torch.no_grad(): if done_indices is None: # Reset all self._set_co_player_state() else: # Reset only specific co-players - device = self.state["lstm_h"].device - self.state["lstm_h"][done_indices] = 0 - self.state["lstm_c"][done_indices] = 0 + if self.co_player_is_transformer: + # Re-prime the KV cache for the done rows so that subsequent + # forward_eval calls behave as if those rows had a fresh + # zero hidden buffer (matches the original semantics of + # `state["transformer_context"][done_indices] = 0`). + self.co_player_policy.reset_eval_state(self.state, done_indices=done_indices) + else: + self.state["lstm_h"][done_indices] = 0 + self.state["lstm_c"][done_indices] = 0 def _add_co_player_conditioning(self, observations): """Add pre-sampled conditioning variables to co-player observations""" @@ -426,10 +692,96 @@ def _add_co_player_conditioning(self, observations): if observations.shape[0] != self.total_co_players: raise ValueError(f"Expected {self.total_co_players} observations, got {observations.shape[0]}") - return np.concatenate([observations[:, :7], self.cached_conditioning_array, observations[:, 7:]], axis=1) + # Use dynamic base_ego_dim based on dynamics model + base_ego_dim = binding.EGO_FEATURES_JERK if self.dynamics_model == "jerk" else binding.EGO_FEATURES_CLASSIC + return np.concatenate( + [observations[:, :base_ego_dim], self.cached_conditioning_array, observations[:, base_ego_dim:]], axis=1 + ) + + def _rebuild_ego_env_indices(self): + """Recompute self._ego_env_indices: for the k-th ego in self.ego_ids + order, the env index it belongs to. Called at init and after every + _set_env_variables (since map re-roll may change per-env ego counts). + Oracle path only.""" + if not self.ego_is_oracle: + return + ego_env_ids = [] + if self.population_play: + for env_idx, env_egos in enumerate(self.local_ego_ids): + ego_env_ids.extend([env_idx] * len(env_egos)) + else: + for env_idx in range(self.num_envs): + cur = int(self.agent_offsets[env_idx]) + nxt = int(self.agent_offsets[env_idx + 1]) + ego_env_ids.extend([env_idx] * (nxt - cur)) + self._ego_env_indices = np.asarray(ego_env_ids, dtype=np.int64) + + def _refresh_ego_oracle_obs(self): + """Copy C-side obs into the pufferl-facing buffer and write the + per-env partner-conditioning vector into the trailing oracle + slots for every ego row. No-op when oracle is off.""" + if not self.ego_is_oracle: + return + c_dim = self._c_obs_dim + # Copy C output into the leading c_obs_dim columns. (Cannot slice + # the assignment to a single np.copyto because the buffers were + # allocated separately; numpy fast path is fine.) + self.observations[:, :c_dim] = self._c_observations + # Append partner conditioning to ego rows only. Non-ego rows keep + # whatever was there (zeros from allocation; pufferl filters out + # non-ego rows downstream anyway). + if len(self.ego_ids) > 0: + self.observations[self.ego_ids, c_dim:] = self._oracle_obs_per_env[self._ego_env_indices] + + def _current_k_eff(self): + """Effective k for the ego's K/V cache horizon at the current + curriculum stage. Returns k_scenarios (i.e. K_max) when the + curriculum is disabled or has finished. Stages last + `k_eff_curriculum_episodes_per_stage` episodes each: + stage 0 → k_eff = 1 + stage 1 → k_eff = 2 + stage 2+ → k_eff = K_max + """ + if not self.k_eff_curriculum_enabled: + return self.k_scenarios + n = self._k_eff_curriculum_episodes_seen + s = self.k_eff_curriculum_episodes_per_stage + if n < s: + return 1 + elif n < 2 * s: + return 2 + else: + return self.k_scenarios + + def _k_eff_should_reset_at_current_boundary(self): + """True when the just-crossed scenario boundary should cut the ego + K/V cache under the current curriculum stage. Caller must already + have ensured this is a within-episode boundary (current_scenario != 0 + after the modulo increment).""" + k_eff = self._current_k_eff() + return self.current_scenario % k_eff == 0 def _set_co_player_conditioning(self): """Sample and store conditioning values for each environment and update all caches""" + # Entropy curriculum: scale entropy_ub based on episodes seen so far. + # Schedule: stage 0 = 0.05*final, stage 1 = 0.20*final, stage 2 = 0.50*final, + # stage 3 = 1.00*final. Each stage = 30 episodes per worker (≈30 epochs + # given ~1 episode/epoch with our nw=32 nv=32 setup). + if self.entropy_curriculum_enabled and self.co_player_entropy_conditioned: + if self._entropy_curriculum_final_ub is None: + self._entropy_curriculum_final_ub = self.co_player_entropy_weight_ub + n = self._entropy_curriculum_episodes_seen + if n < 30: + ratio = 0.05 + elif n < 60: + ratio = 0.20 + elif n < 90: + ratio = 0.50 + else: + ratio = 1.00 + self.co_player_entropy_weight_ub = ratio * self._entropy_curriculum_final_ub + self._entropy_curriculum_episodes_seen += 1 + # Update co-player counts and indices self.num_co_players_per_env = np.array([len(ids) for ids in self.local_co_player_ids], dtype=np.int32) self.total_co_players = self.num_co_players_per_env.sum() @@ -477,6 +829,133 @@ def _set_co_player_conditioning(self): else: self.cached_conditioning_array = np.empty((0, len(conditioning_dims)), dtype=np.float32) + # Oracle: sync per-env oracle vector with the freshly sampled + # partner conditioning. num_envs can change across + # _reinit_envs_with_new_maps (some maps yield no valid agents and + # get dropped C-side), so just take a fresh copy with the current + # shape rather than a fixed-size in-place write. Width invariant + # is checked by the assertion below. + if self.ego_is_oracle: + assert self.env_conditioning.shape[1] == self._oracle_dims, ( + f"oracle width mismatch: env_conditioning has " + f"{self.env_conditioning.shape[1]} dims, oracle expects {self._oracle_dims}" + ) + self._oracle_obs_per_env = self.env_conditioning.copy() + + # Stash sampled entropy stats for wandb. Index of the entropy column + # within env_conditioning depends on which dims are active above: + # reward(3) -> [collision, offroad, goal] then entropy then discount + if self.co_player_entropy_conditioned and self.env_conditioning.shape[1] > 0: + entropy_col = 3 if self.co_player_reward_conditioned else 0 + sampled = self.env_conditioning[:, entropy_col] + self._pending_entropy_log = { + "co_player/entropy_weight_ub": float(self.co_player_entropy_weight_ub), + "co_player/entropy_sampled_mean": float(sampled.mean()), + "co_player/entropy_sampled_min": float(sampled.min()), + "co_player/entropy_sampled_max": float(sampled.max()), + } + if self.entropy_curriculum_enabled: + self._pending_entropy_log["co_player/entropy_curriculum_episodes"] = int( + self._entropy_curriculum_episodes_seen + ) + + # Mirror the freshly-sampled conditioning into the shared-memory + # buffer so the main process (centralized co-player on GPU) sees the + # latest values before its next forward pass. SHM rows beyond + # `total_co_players` are left at whatever value they previously held; + # main only reads rows for active co-players. + if ( + self.co_player_conditioning_shm is not None + and self.cached_conditioning_array.shape[1] > 0 + and self.total_co_players > 0 + ): + shm = self.co_player_conditioning_shm + n = min(self.total_co_players, shm.shape[0]) + shm[:n, :] = self.cached_conditioning_array[:n, :] + + def _reinit_envs_with_new_maps(self): + """Close + recreate C envs with fresh map_ids. + + Called at episode resample boundary and (when + `map_rand_per_scenario=True`) at scenario boundaries. Note: this + sets `self.terminals[:] = 1`, which pufferl uses to wipe the ego + K/V cache — so `map_rand_per_scenario=True` is currently broken + as an ICL probe (cache + GAE both truncate at the boundary). + """ + if self._render_keep_client_on_swap: + binding.vec_donate_client(self.c_envs) + binding.vec_close(self.c_envs) + self._set_env_variables() + env_ids = [] + seed = np.random.randint(0, 2**32 - 1) + for i in range(self.num_envs): + cur = self.agent_offsets[i] + nxt = self.agent_offsets[i + 1] + obs_slice_for_c = self._c_observations[cur:nxt] if self.ego_is_oracle else self.observations[cur:nxt] + env_id = binding.env_init( + obs_slice_for_c, + self.actions[cur:nxt], + self.rewards[cur:nxt], + self.terminals[cur:nxt], + self.truncations[cur:nxt], + seed, + action_type=self._action_type_flag, + human_agent_idx=self.human_agent_idx, + dynamics_model=self.dynamics_model, + reward_vehicle_collision=self.reward_vehicle_collision, + reward_offroad_collision=self.reward_offroad_collision, + goal_radius=self.goal_radius, + goal_behavior=self.goal_behavior, + collision_behavior=self.collision_behavior, + offroad_behavior=self.offroad_behavior, + reward_goal=self.reward_goal, + reward_goal_post_respawn=self.reward_goal_post_respawn, + reward_lane_align=self.reward_lane_align, + reward_vel_align=self.reward_vel_align, + goal_speed=self.goal_speed, + goal_target_distance=self.goal_target_distance, + dt=self.dt, + scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None), + max_controlled_agents=self.max_controlled_agents, + map_id=self.map_ids[i], + use_rc=self.reward_conditioned, + use_ec=self.entropy_conditioned, + use_dc=self.discount_conditioned, + collision_weight_lb=self.collision_weight_lb, + collision_weight_ub=self.collision_weight_ub, + offroad_weight_lb=self.offroad_weight_lb, + offroad_weight_ub=self.offroad_weight_ub, + goal_weight_lb=self.goal_weight_lb, + goal_weight_ub=self.goal_weight_ub, + entropy_weight_lb=self.entropy_weight_lb, + entropy_weight_ub=self.entropy_weight_ub, + discount_weight_lb=self.discount_weight_lb, + discount_weight_ub=self.discount_weight_ub, + max_agents=nxt - cur, + ini_file=self.ini_file, + population_play=self.population_play, + num_co_players=len(self.local_co_player_ids[i]), + co_player_ids=self.local_co_player_ids[i], + ego_agent_ids=self.local_ego_ids[i], + num_ego_agents=len(self.local_ego_ids[i]), + init_steps=self.init_steps, + init_mode=self.init_mode, + control_mode=self.control_mode, + map_dir=self.map_dir, + render_mode=self._render_mode_int, + ) + env_ids.append(env_id) + self.c_envs = binding.vectorize(*env_ids) + if self._render_keep_client_on_swap: + binding.vec_adopt_client(self.c_envs) + + binding.vec_reset(self.c_envs, seed) + # Oracle: per-env ego counts may have shifted with the new map IDs; + # rebuild the ego→env index map and refresh obs. + self._rebuild_ego_env_indices() + self._refresh_ego_oracle_obs() + self.terminals[:] = 1 + def _aggregate_scenario_metrics(self, scenario_infos): """Aggregate metrics from all infos collected during a scenario.""" if not scenario_infos: @@ -526,10 +1005,6 @@ def _compute_delta_metrics(self): delta_key = f"ada_delta_{metric}" delta_metrics[delta_key] = last_metrics[metric] - first_metrics[metric] - # Add a count of how many agents this represents - if "n" in last_metrics: - delta_metrics["ada_agent_count"] = last_metrics["n"] - return delta_metrics def step(self, actions): @@ -537,29 +1012,44 @@ def step(self, actions): self.actions[self.ego_ids] = actions - if self.population_play: + if self.population_play and not self.external_co_player_actions: co_player_actions = self.get_co_player_actions() self.actions[self.co_player_ids] = co_player_actions + # When external_co_player_actions=True, the main process has already + # written co-player actions into self.actions[co_player_ids] via the + # shared-memory action buffer; nothing to do here. binding.vec_step(self.c_envs) + if self.reward_only_last_scenario and self.current_scenario != self.k_scenarios - 1: + self.rewards[:] = 0 + # Oracle: copy C obs into pufferl buffer + write oracle slots. + self._refresh_ego_oracle_obs() self.tick += 1 info = [] if self.tick % self.report_interval == 0: - log = binding.vec_log(self.c_envs) + log = binding.vec_log(self.c_envs, self.num_agents) if log: if self.adaptive_driving_agent: self.current_scenario_infos.append(log) - - # Only append to info if we're in the 0th scenario - if self.current_scenario == 0: + # For training: only report 0-shot (scenario 0) metrics + # For evaluation: report all scenarios when report_all_scenarios=True + if self.current_scenario == 0 or self.report_all_scenarios: info.append(log) - print("0th scenario metrics are ", log, flush=True) else: # Non-adaptive mode: always append info.append(log) - print("Regular metrics are ", log, flush=True) + + # Surface the entropy bound + sampled distribution that + # _set_co_player_conditioning stashed at the most recent reset. + # Drained on emit so each fresh sampling gets logged exactly once. + if self._pending_entropy_log is not None: + info.append(self._pending_entropy_log) + self._pending_entropy_log = None + if self._pending_k_eff_log is not None: + info.append(self._pending_k_eff_log) + self._pending_k_eff_log = None if self.tick % self.scenario_length == 0: if self.adaptive_driving_agent and self.current_scenario_infos: @@ -567,11 +1057,16 @@ def step(self, actions): scenario_log["scenario_id"] = self.current_scenario self.scenario_metrics.append(scenario_log) + # Log metrics for all scenarios with scenario-specific prefixes + prefixed_log = { + f"scenario_{self.current_scenario}_{k}": v for k, v in scenario_log.items() if k != "scenario_id" + } + info.append(prefixed_log) + if self.current_scenario == self.k_scenarios - 1: delta_metrics = self._compute_delta_metrics() if delta_metrics: info.append(delta_metrics) - print("delta metrics are ", delta_metrics, flush=True) self.scenario_metrics = [] @@ -579,6 +1074,60 @@ def step(self, actions): self.current_scenario = (self.current_scenario + 1) % self.k_scenarios + # Reset coplayer LSTM/Transformer state at scenario boundary so + # the partner behaves consistently. The OFF (per-worker) path + # does not re-sample conditioning here — it sticks with the + # values written to SHM at init/resample — so neither do we. + if self.population_play: + if self.external_co_player_actions: + # Signal main to drop the K/V cache so scenario N+1 + # starts fresh, mirroring the per-worker OFF path's + # `_reset_co_player_state()` here. + info.append({"_external_reset_co_cache": True}) + else: + self._reset_co_player_state() + + # MAP ROTATION per scenario: re-init the C envs with new map_ids + # while leaving the EGO POLICY's K/V cache (held in main) alone. + # This forces the policy to actually USE its past-scenario context + # because the current scene is genuinely new. + if ( + self.adaptive_driving_agent + and self.map_rand_per_scenario + and self.current_scenario != 0 # we just incremented above; 0 means we already wrapped to next episode + ): + self._reinit_envs_with_new_maps() + # PARTNER CONDITIONING ROTATION per scenario: resample the partner's + # conditioning vector. The partner POLICY is unchanged but its + # effective behavior shifts (e.g. high-entropy stochastic vs + # low-entropy deterministic) — the ego must encode partner type + # from s_0 observations. Skipped when map_rand fired above + # (reinit already re-samples conditioning via _set_env_variables). + elif ( + self.adaptive_driving_agent + and self.condition_rand_per_scenario + and self.population_play + and self.current_scenario != 0 + and self.co_player_condition_type is not None + and self.co_player_condition_type != "none" + ): + self._set_co_player_conditioning() + + # k_eff curriculum: at within-episode scenario boundaries, decide + # whether to cut the ego's K/V cache. Setting truncations=1 (and + # terminals=1) at this step makes pufferl drop the cache via + # done_mask=t+d, and during training, create_episode_mask blocks + # cross-boundary attention. current_scenario != 0 excludes the + # episode boundary itself (which is reset by the normal episode- + # done logic). Reset rule: cut if current_scenario % k_eff == 0. + if ( + self.adaptive_driving_agent + and self.current_scenario != 0 + and self._k_eff_should_reset_at_current_boundary() + ): + self.truncations[self.ego_ids] = 1 + self.terminals[self.ego_ids] = 1 + if self.tick > 0 and self.resample_frequency > 0 and self.tick % self.resample_frequency == 0: self.tick = 0 will_resample = 1 @@ -588,73 +1137,26 @@ def step(self, actions): delta_metrics = self._compute_delta_metrics() if delta_metrics: info.append(delta_metrics) - print("delta metrics 2, are ", delta_metrics, flush=True) self.scenario_metrics = [] self.current_scenario_infos = [] self.current_scenario = 0 - binding.vec_close(self.c_envs) - self._set_env_variables() - env_ids = [] - seed = np.random.randint(0, 2**32 - 1) - for i in range(self.num_envs): - cur = self.agent_offsets[i] - nxt = self.agent_offsets[i + 1] - env_id = binding.env_init( - self.observations[cur:nxt], - self.actions[cur:nxt], - self.rewards[cur:nxt], - self.terminals[cur:nxt], - self.truncations[cur:nxt], - seed, - action_type=self._action_type_flag, - human_agent_idx=self.human_agent_idx, - dynamics_model=self.dynamics_model, - reward_vehicle_collision=self.reward_vehicle_collision, - reward_offroad_collision=self.reward_offroad_collision, - reward_goal=self.reward_goal, - reward_goal_post_respawn=self.reward_goal_post_respawn, - reward_ade=self.reward_ade, - goal_radius=self.goal_radius, - goal_behavior=self.goal_behavior, - collision_behavior=self.collision_behavior, - offroad_behavior=self.offroad_behavior, - dt=self.dt, - scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None), - max_controlled_agents=self.max_controlled_agents, - map_id=self.map_ids[i], - use_rc=self.reward_conditioned, - use_ec=self.entropy_conditioned, - use_dc=self.discount_conditioned, - collision_weight_lb=self.collision_weight_lb, - collision_weight_ub=self.collision_weight_ub, - offroad_weight_lb=self.offroad_weight_lb, - offroad_weight_ub=self.offroad_weight_ub, - goal_weight_lb=self.goal_weight_lb, - goal_weight_ub=self.goal_weight_ub, - entropy_weight_lb=self.entropy_weight_lb, - entropy_weight_ub=self.entropy_weight_ub, - discount_weight_lb=self.discount_weight_lb, - discount_weight_ub=self.discount_weight_ub, - max_agents=nxt - cur, - ini_file=self.ini_file, - population_play=self.population_play, - num_co_players=len(self.local_co_player_ids[i]), - co_player_ids=self.local_co_player_ids[i], - ego_agent_ids=self.local_ego_ids[i], - num_ego_agents=len(self.local_ego_ids[i]), - init_steps=self.init_steps, - init_mode=self.init_mode, - control_mode=self.control_mode, - ) - env_ids.append(env_id) - self.c_envs = binding.vectorize(*env_ids) - - binding.vec_reset(self.c_envs, seed) - self.terminals[:] = 1 + # Advance k_eff curriculum once per real episode and stash + # the new stage's k_eff for wandb. Done before reinit so the + # log reflects the stage that the next episode will run at. + if self.k_eff_curriculum_enabled: + self._k_eff_curriculum_episodes_seen += 1 + self._pending_k_eff_log = { + "ego_curriculum/k_eff": int(self._current_k_eff()), + "ego_curriculum/episodes_seen": int(self._k_eff_curriculum_episodes_seen), + } + + self._reinit_envs_with_new_maps() if self.population_play: info.append(self.ego_ids) + if self.external_co_player_actions: + info.append({"_external_co_player_ids": self.co_player_ids}) return (self.observations, self.rewards, self.terminals, self.truncations, info) @@ -662,7 +1164,7 @@ def get_global_agent_state(self): """Get current global state of all active agents. Returns: - dict with keys 'x', 'y', 'z', 'heading', 'id' containing numpy arrays + dict with keys 'x', 'y', 'z', 'heading', 'id', 'length', 'width' containing numpy arrays of shape (num_active_agents,) """ num_agents = self.num_agents @@ -673,10 +1175,19 @@ def get_global_agent_state(self): "z": np.zeros(num_agents, dtype=np.float32), "heading": np.zeros(num_agents, dtype=np.float32), "id": np.zeros(num_agents, dtype=np.int32), + "length": np.zeros(num_agents, dtype=np.float32), + "width": np.zeros(num_agents, dtype=np.float32), } binding.vec_get_global_agent_state( - self.c_envs, states["x"], states["y"], states["z"], states["heading"], states["id"] + self.c_envs, + states["x"], + states["y"], + states["z"], + states["heading"], + states["id"], + states["length"], + states["width"], ) return states @@ -715,8 +1226,56 @@ def get_ground_truth_trajectories(self): return trajectories - def render(self): - binding.vec_render(self.c_envs, 0) + def get_road_edge_polylines(self): + """Get road edge polylines for all scenarios. + + Returns: + dict with keys 'x', 'y', 'lengths', 'scenario_id' containing numpy arrays. + x, y are flattened point coordinates; lengths indicates points per polyline. + """ + num_polylines, total_points = binding.vec_get_road_edge_counts(self.c_envs) + + polylines = { + "x": np.zeros(total_points, dtype=np.float32), + "y": np.zeros(total_points, dtype=np.float32), + "lengths": np.zeros(num_polylines, dtype=np.int32), + "scenario_id": np.zeros(num_polylines, dtype=np.int32), + } + + binding.vec_get_road_edge_polylines( + self.c_envs, + polylines["x"], + polylines["y"], + polylines["lengths"], + polylines["scenario_id"], + ) + + return polylines + + def render(self, view_mode: int = 0, draw_traces: bool = True, env_id: int = 0): + """Render the environment. + + Args: + view_mode: View mode for rendering: + 0 = VIEW_MODE_SIM_STATE (top-down orthographic) + 1 = VIEW_MODE_BEV_AGENT_OBS (bird's eye view centered on agent) + 2 = VIEW_MODE_AGENT_PERSP (third-person chase camera) + draw_traces: Whether to draw trajectory traces + env_id: Which environment to render (default 0) + """ + binding.vec_render(self.c_envs, int(view_mode), draw_traces, env_id, self.current_scenario, self.k_scenarios) + + def set_video_suffix(self, suffix: str, env_id: int = 0): + """Set the suffix appended to the mp4 filename for headless rendering. + + Must be called before the first render() call of a rollout. + E.g. set_video_suffix("_bev", env_id=0) -> {scenario_id}_bev.mp4 + + Args: + suffix: Suffix string to append to video filename + env_id: Which environment to set suffix for (default 0) + """ + binding.vec_set_video_suffix(self.c_envs, env_id, suffix) def close(self): binding.vec_close(self.c_envs) @@ -727,7 +1286,13 @@ def calculate_area(p1, p2, p3): return 0.5 * abs((p1["x"] - p3["x"]) * (p2["y"] - p1["y"]) - (p1["x"] - p2["x"]) * (p3["y"] - p1["y"])) -def simplify_polyline(geometry, polyline_reduction_threshold): +def dist(a, b): + dx = a["x"] - b["x"] + dy = a["y"] - b["y"] + return dx * dx + dy * dy + + +def simplify_polyline(geometry, polyline_reduction_threshold, max_segment_length): """Simplify the given polyline using a method inspired by Visvalingham-Whyatt, optimized for Python.""" num_points = len(geometry) if num_points < 3: @@ -756,8 +1321,7 @@ def simplify_polyline(geometry, polyline_reduction_threshold): point2 = geometry[k_1] point3 = geometry[k_2] area = calculate_area(point1, point2, point3) - - if area < polyline_reduction_threshold: + if area < polyline_reduction_threshold and dist(point1, point3) <= max_segment_length: skip[k_1] = True skip_changed = True k = k_2 @@ -767,9 +1331,28 @@ def simplify_polyline(geometry, polyline_reduction_threshold): return [geometry[i] for i in range(num_points) if not skip[i]] -def save_map_binary(map_data, output_file, unique_map_id): - trajectory_length = 91 - """Saves map data in a binary format readable by C""" +def _to_int32(v, default=0): + """Wrap an arbitrary integer into the signed int32 range using two's-complement + semantics so struct.pack('i', ...) cannot overflow. nuPlan IDs and some type + fields can exceed 2^31-1; this preserves the low 32 bits the way C would.""" + try: + v = int(v) + except (TypeError, ValueError): + return default + v &= 0xFFFFFFFF + if v >= 0x80000000: + v -= 0x100000000 + return v + + +def save_map_binary(map_data, output_file, unique_map_id, trajectory_length=91): + """Saves map data in a binary format readable by C. + + `trajectory_length` is how many frames per object/road to write. The + C reader is parametric on the per-binary `array_size` header, so any + value works. Default 91 matches the legacy WOMD/short-window setup; + nuplan scenes go up to 201 frames so pass `trajectory_length=201` + to capture the full data.""" with open(output_file, "wb") as f: # Get metadata metadata = map_data.get("metadata", {}) @@ -777,27 +1360,25 @@ def save_map_binary(map_data, output_file, unique_map_id): tracks_to_predict = metadata.get("tracks_to_predict", []) # Write sdc_track_index - f.write(struct.pack("i", sdc_track_index)) + f.write(struct.pack("i", _to_int32(sdc_track_index, -1))) # Write tracks_to_predict info (indices only) - f.write(struct.pack("i", len(tracks_to_predict))) + f.write(struct.pack("i", _to_int32(len(tracks_to_predict)))) for track in tracks_to_predict: track_index = track.get("track_index", -1) - f.write(struct.pack("i", track_index)) + f.write(struct.pack("i", _to_int32(track_index, -1))) # Count total entities - print(len(map_data.get("objects", []))) - print(len(map_data.get("roads", []))) num_objects = len(map_data.get("objects", [])) num_roads = len(map_data.get("roads", [])) # num_entities = num_objects + num_roads - f.write(struct.pack("i", num_objects)) - f.write(struct.pack("i", num_roads)) + f.write(struct.pack("i", _to_int32(num_objects))) + f.write(struct.pack("i", _to_int32(num_roads))) # f.write(struct.pack('i', num_entities)) # Write objects for obj in map_data.get("objects", []): # Write unique map id - f.write(struct.pack("i", unique_map_id)) + f.write(struct.pack("i", _to_int32(unique_map_id))) # Write base entity data obj_type = obj.get("type", 1) @@ -807,9 +1388,10 @@ def save_map_binary(map_data, output_file, unique_map_id): obj_type = 2 elif obj_type == "cyclist": obj_type = 3 - f.write(struct.pack("i", obj_type)) # type - f.write(struct.pack("i", obj.get("id", 0))) # id - f.write(struct.pack("i", trajectory_length)) # array_size + f.write(struct.pack("i", _to_int32(obj_type))) # type + obj_id = obj.get("id", 0) + f.write(struct.pack("i", _to_int32(obj_id))) # id + f.write(struct.pack("i", _to_int32(trajectory_length))) # array_size # Write position arrays positions = obj.get("position", []) for i in range(trajectory_length): @@ -842,7 +1424,7 @@ def save_map_binary(map_data, output_file, unique_map_id): f.write( struct.pack( f"{trajectory_length}i", - *[int(valids[i]) if i < len(valids) else 0 for i in range(trajectory_length)], + *[_to_int32(valids[i]) if i < len(valids) else 0 for i in range(trajectory_length)], ) ) @@ -854,11 +1436,11 @@ def save_map_binary(map_data, output_file, unique_map_id): f.write(struct.pack("f", float(goal_pos.get("x", 0.0)))) # Get x value f.write(struct.pack("f", float(goal_pos.get("y", 0.0)))) # Get y value f.write(struct.pack("f", float(goal_pos.get("z", 0.0)))) # Get z value - f.write(struct.pack("i", obj.get("mark_as_expert", 0))) + f.write(struct.pack("i", _to_int32(obj.get("mark_as_expert", 0)))) # Write roads for idx, road in enumerate(map_data.get("roads", [])): - f.write(struct.pack("i", unique_map_id)) + f.write(struct.pack("i", _to_int32(unique_map_id))) geometry = road.get("geometry", []) road_type = road.get("map_element_id", 0) @@ -869,7 +1451,7 @@ def save_map_binary(map_data, output_file, unique_map_id): road_type = 15 # breakpoint() if len(geometry) > 10 and road_type <= 16: - geometry = simplify_polyline(geometry, 0.1) + geometry = simplify_polyline(geometry, 0.1, 250) size = len(geometry) # breakpoint() if road_type >= 0 and road_type <= 3: @@ -887,9 +1469,10 @@ def save_map_binary(map_data, output_file, unique_map_id): elif road_type == 20: road_type = 10 # Write base entity data - f.write(struct.pack("i", road_type)) # type - f.write(struct.pack("i", road.get("id", 0))) # id - f.write(struct.pack("i", size)) # array_size + f.write(struct.pack("i", _to_int32(road_type))) # type + road_id = road.get("id", 0) + f.write(struct.pack("i", _to_int32(road_id))) # id + f.write(struct.pack("i", _to_int32(size))) # array_size # Write position arrays for coord in ["x", "y", "z"]: @@ -904,44 +1487,89 @@ def save_map_binary(map_data, output_file, unique_map_id): f.write(struct.pack("f", float(goal_pos.get("x", 0.0)))) # Get x value f.write(struct.pack("f", float(goal_pos.get("y", 0.0)))) # Get y value f.write(struct.pack("f", float(goal_pos.get("z", 0.0)))) # Get z value - f.write(struct.pack("i", road.get("mark_as_expert", 0))) + f.write(struct.pack("i", _to_int32(road.get("mark_as_expert", 0)))) -def load_map(map_name, unique_map_id, binary_output=None): +def load_map(map_name, unique_map_id, binary_output=None, trajectory_length=91): """Loads a JSON map and optionally saves it as binary""" with open(map_name, "r") as f: map_data = json.load(f) if binary_output: - save_map_binary(map_data, binary_output, unique_map_id) + save_map_binary(map_data, binary_output, unique_map_id, trajectory_length=trajectory_length) + + +def _process_single_map(args): + """Worker function to process a single map file""" + i, map_path, binary_path, trajectory_length = args + try: + load_map(str(map_path), i, str(binary_path), trajectory_length=trajectory_length) + return (i, map_path.name, True, None) + except Exception as e: + return (i, map_path.name, False, str(e)) + + +def process_all_maps( + data_folder="data/processed/training", + max_maps=50_000, + num_workers=None, + shuffle=False, + trajectory_length=91, + output_subdir=None, +): + """Process all maps and save them as binaries using multiprocessing + + Args: + data_folder: Path to the folder containing JSON map files + max_maps: Maximum number of maps to process + num_workers: Number of parallel workers (defaults to cpu_count()) + shuffle: If True, shuffle the JSON files before assigning map IDs. + This ensures that when using num_maps < total, you get + a random mix of all source maps instead of alphabetically first ones. + """ + from pathlib import Path + import random + if num_workers is None: + num_workers = cpu_count() -def process_all_maps(): - """Process all maps and save them as binaries""" - from pathlib import Path + # Path to the training data + data_dir = Path(data_folder) + dataset_name = output_subdir if output_subdir is not None else data_dir.name # Create the binaries directory if it doesn't exist - binary_dir = Path("resources/drive/binaries") + binary_dir = Path(f"resources/drive/binaries/{dataset_name}") binary_dir.mkdir(parents=True, exist_ok=True) - # Path to the training data - data_dir = Path("data/processed/training") - # Get all JSON files in the training directory json_files = sorted(data_dir.glob("*.json")) - print(f"Found {len(json_files)} JSON files") + if shuffle: + json_files = list(json_files) + random.shuffle(json_files) - # Process each JSON file - for i, map_path in enumerate(json_files[:10000]): - binary_file = f"map_{i:03d}.bin" # Use zero-padded numbers for consistent sorting + # Prepare arguments for parallel processing + tasks = [] + for i, map_path in enumerate(json_files[:max_maps]): + binary_file = f"map_{i:03d}.bin" binary_path = binary_dir / binary_file + tasks.append((i, map_path, binary_path, trajectory_length)) + + # Process maps in parallel with progress bar + with Pool(num_workers) as pool: + results = list( + tqdm(pool.imap(_process_single_map, tasks), total=len(tasks), desc="Processing maps", unit="map") + ) + + # Collect statistics + successful = sum(1 for _, _, success, _ in results if success) + failed = sum(1 for _, _, success, _ in results if not success) - print(f"Processing {map_path.name} -> {binary_file}") - # try: - load_map(str(map_path), i, str(binary_path)) - # except Exception as e: - # print(f"Error processing {map_path.name}: {e}") + if failed > 0: + print(f"\nFailed {failed}/{len(results)} files:") + for i, name, success, error in results: + if not success: + print(f" {name}: {error}") def test_performance(timeout=10, atn_cache=1024, num_agents=1024): @@ -976,4 +1604,10 @@ def test_performance(timeout=10, atn_cache=1024, num_agents=1024): if __name__ == "__main__": # test_performance() - process_all_maps() + # Process the train dataset + # process_all_maps(data_folder="/data/processed/training") + process_all_maps(data_folder="/workspace/ADA/data/nuplan-gpudrive/nuplan") + # Process the validation/test dataset + # process_all_maps(data_folder="data/processed/validation") + # # Process the validation_interactive dataset + # process_all_maps(data_folder="data/processed/validation_interactive") diff --git a/pufferlib/ocean/drive/drivenet.h b/pufferlib/ocean/drive/drivenet.h deleted file mode 100644 index 56dbda5212..0000000000 --- a/pufferlib/ocean/drive/drivenet.h +++ /dev/null @@ -1,256 +0,0 @@ -#include -#include "drive.h" -#include "puffernet.h" -#include -#include -#include -#include -#include -#include - -typedef struct DriveNet DriveNet; -struct DriveNet { - int num_agents; - int conditioning_dims; - int ego_dim; - float* obs_self; - float* obs_partner; - float* obs_road; - float* partner_linear_output; - float* road_linear_output; - float* partner_layernorm_output; - float* road_layernorm_output; - float* partner_linear_output_two; - float* road_linear_output_two; - Linear* ego_encoder; - Linear* road_encoder; - Linear* partner_encoder; - LayerNorm* ego_layernorm; - LayerNorm* road_layernorm; - LayerNorm* partner_layernorm; - Linear* ego_encoder_two; - Linear* road_encoder_two; - Linear* partner_encoder_two; - MaxDim1* partner_max; - MaxDim1* road_max; - CatDim1* cat1; - CatDim1* cat2; - GELU* gelu; - Linear* shared_embedding; - ReLU* relu; - LSTM* lstm; - Linear* actor; - Linear* value_fn; - Multidiscrete* multidiscrete; -}; - -DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bool use_rc, bool use_ec, bool use_dc) { - DriveNet* net = calloc(1, sizeof(DriveNet)); - int hidden_size = 256; - int input_size = 64; - - int base_ego_dim = (dynamics_model == JERK) ? 10 : 7; - net->conditioning_dims = (use_rc ? 3 : 0) + (use_ec ? 1 : 0) + (use_dc ? 1 : 0); - net->ego_dim = base_ego_dim + net->conditioning_dims; - - // Determine action space size based on dynamics model - int action_size, logit_sizes[2]; - int action_dim; - if (dynamics_model == CLASSIC) { - action_size = 7 * 13; // Joint action space - logit_sizes[0] = 7 * 13; - action_dim = 1; - } else { // JERK - action_size = 7; // 4 + 3 - logit_sizes[0] = 4; - logit_sizes[1] = 3; - action_dim = 2; - } - - net->num_agents = num_agents; - - net->obs_self = calloc(num_agents*net->ego_dim, sizeof(float)); - net->obs_partner = calloc(num_agents*63*7, sizeof(float)); // 63 objects, 7 features - net->obs_road = calloc(num_agents*200*13, sizeof(float)); // 200 objects, 13 features - - net->partner_linear_output = calloc(num_agents*63*input_size, sizeof(float)); - net->road_linear_output = calloc(num_agents*200*input_size, sizeof(float)); - net->partner_linear_output_two = calloc(num_agents*63*input_size, sizeof(float)); - net->road_linear_output_two = calloc(num_agents*200*input_size, sizeof(float)); - net->partner_layernorm_output = calloc(num_agents*63*input_size, sizeof(float)); - net->road_layernorm_output = calloc(num_agents*200*input_size, sizeof(float)); - net->ego_encoder = make_linear(weights, num_agents, net->ego_dim, input_size); - net->ego_layernorm = make_layernorm(weights, num_agents, input_size); - net->ego_encoder_two = make_linear(weights, num_agents, input_size, input_size); - net->road_encoder = make_linear(weights, num_agents, 13, input_size); - net->road_layernorm = make_layernorm(weights, num_agents, input_size); - net->road_encoder_two = make_linear(weights, num_agents, input_size, input_size); - net->partner_encoder = make_linear(weights, num_agents, 7, input_size); - net->partner_layernorm = make_layernorm(weights, num_agents, input_size); - net->partner_encoder_two = make_linear(weights, num_agents, input_size, input_size); - net->partner_max = make_max_dim1(num_agents, 63, input_size); - net->road_max = make_max_dim1(num_agents, 200, input_size); - net->cat1 = make_cat_dim1(num_agents, input_size, input_size); - net->cat2 = make_cat_dim1(num_agents, input_size + input_size, input_size); - net->gelu = make_gelu(num_agents, 3*input_size); - net->shared_embedding = make_linear(weights, num_agents, input_size*3, hidden_size); - net->relu = make_relu(num_agents, hidden_size); - net->actor = make_linear(weights, num_agents, hidden_size, action_size); - net->value_fn = make_linear(weights, num_agents, hidden_size, 1); - net->lstm = make_lstm(weights, num_agents, hidden_size, 256); - memset(net->lstm->state_h, 0, num_agents*256*sizeof(float)); - memset(net->lstm->state_c, 0, num_agents*256*sizeof(float)); - net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, action_dim); - return net; -} - -void free_drivenet(DriveNet* net) { - free(net->obs_self); - free(net->obs_partner); - free(net->obs_road); - free(net->partner_linear_output); - free(net->road_linear_output); - free(net->partner_linear_output_two); - free(net->road_linear_output_two); - free(net->partner_layernorm_output); - free(net->road_layernorm_output); - free(net->ego_encoder); - free(net->road_encoder); - free(net->partner_encoder); - free(net->ego_layernorm); - free(net->road_layernorm); - free(net->partner_layernorm); - free(net->ego_encoder_two); - free(net->road_encoder_two); - free(net->partner_encoder_two); - free(net->partner_max); - free(net->road_max); - free(net->cat1); - free(net->cat2); - free(net->gelu); - free(net->shared_embedding); - free(net->relu); - free(net->multidiscrete); - free(net->actor); - free(net->value_fn); - free(net->lstm); - free(net); -} - -void forward(DriveNet* net, float* observations, int* actions) { - int ego_dim = net->ego_dim; - - // Clear previous observations - memset(net->obs_self, 0, net->num_agents * ego_dim * sizeof(float)); - memset(net->obs_partner, 0, net->num_agents * 63 * 7 * sizeof(float)); - memset(net->obs_road, 0, net->num_agents * 200 * 13 * sizeof(float)); - - // Reshape observations into 2D boards and additional features - float* obs_self = net->obs_self; - float (*obs_partner)[63][7] = (float (*)[63][7])net->obs_partner; - float (*obs_road)[200][13] = (float (*)[200][13])net->obs_road; - - for (int b = 0; b < net->num_agents; b++) { - int b_offset = b * (ego_dim + 63*7 + 200*7); // offset for each batch - int partner_offset = b_offset + ego_dim; - int road_offset = b_offset + ego_dim + 63*7; - // Process self observation - for(int i = 0; i < ego_dim; i++) { - obs_self[b*ego_dim + i] = observations[b_offset + i]; - } - - // Process partner observation - for(int i = 0; i < 63; i++) { - for(int j = 0; j < 7; j++) { - net->obs_partner[b*63*7 + i*7 + j] = observations[partner_offset + i*7 + j]; - } - } - - // Process road observation - for(int i = 0; i < 200; i++) { - for(int j = 0; j < 7; j++) { - net->obs_road[b*200*13 + i*13 + j] = observations[road_offset + i*7 + j]; - } - for(int j = 0; j < 7; j++) { - if(j == observations[road_offset+i*7 + 6]) { - net->obs_road[b*200*13 + i*13 + 6 + j] = 1.0f; - } else { - net->obs_road[b*200*13 + i*13 + 6 + j] = 0.0f; - } - } - } - } - - // Forward pass through the network - linear(net->ego_encoder, net->obs_self); - layernorm(net->ego_layernorm, net->ego_encoder->output); - linear(net->ego_encoder_two, net->ego_layernorm->output); - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { - // Get the 7 features for this object - float* obj_features = &net->obs_partner[b*63*7 + obj*7]; - // Apply linear layer to this object - _linear(obj_features, net->partner_encoder->weights, net->partner_encoder->bias, - &net->partner_linear_output[b*63*64 + obj*64], 1, 7, 64); - } - } - - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { - float* after_first = &net->partner_linear_output[b*63*64 + obj*64]; - _layernorm(after_first, net->partner_layernorm->weights, net->partner_layernorm->bias, - &net->partner_layernorm_output[b*63*64 + obj*64], 1, 64); - } - } - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { - // Get the 7 features for this object - float* obj_features = &net->partner_layernorm_output[b*63*64 + obj*64]; - // Apply linear layer to this object - _linear(obj_features, net->partner_encoder_two->weights, net->partner_encoder_two->bias, - &net->partner_linear_output_two[b*63*64 + obj*64], 1, 64, 64); - - } - } - - // Process road objects: apply linear to each object individually - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { - // Get the 13 features for this object - float* obj_features = &net->obs_road[b*200*13 + obj*13]; - // Apply linear layer to this object - _linear(obj_features, net->road_encoder->weights, net->road_encoder->bias, - &net->road_linear_output[b*200*64 + obj*64], 1, 13, 64); - } - } - - // Apply layer norm and second linear to each road object - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { - float* after_first = &net->road_linear_output[b*200*64 + obj*64]; - _layernorm(after_first, net->road_layernorm->weights, net->road_layernorm->bias, - &net->road_layernorm_output[b*200*64 + obj*64], 1, 64); - } - } - for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { - float* after_first = &net->road_layernorm_output[b*200*64 + obj*64]; - _linear(after_first, net->road_encoder_two->weights, net->road_encoder_two->bias, - &net->road_linear_output_two[b*200*64 + obj*64], 1, 64, 64); - } - } - - max_dim1(net->partner_max, net->partner_linear_output_two); - max_dim1(net->road_max, net->road_linear_output_two); - cat_dim1(net->cat1, net->ego_encoder_two->output, net->road_max->output); - cat_dim1(net->cat2, net->cat1->output, net->partner_max->output); - gelu(net->gelu, net->cat2->output); - linear(net->shared_embedding, net->gelu->output); - relu(net->relu, net->shared_embedding->output); - lstm(net->lstm, net->relu->output); - linear(net->actor, net->lstm->state_h); - linear(net->value_fn, net->lstm->state_h); - - // Get action by taking argmax of actor output - softmax_multidiscrete(net->multidiscrete, net->actor->output, actions); -} diff --git a/pufferlib/ocean/drive/error.h b/pufferlib/ocean/drive/error.h index b1eb78e7ed..77ae171bb5 100644 --- a/pufferlib/ocean/drive/error.h +++ b/pufferlib/ocean/drive/error.h @@ -18,21 +18,29 @@ typedef enum { ERROR_UNKNOWN } ErrorType; -const char* error_type_to_string(ErrorType type) { +const char *error_type_to_string(ErrorType type) { switch (type) { - case ERROR_NONE: return "No Error"; - case ERROR_NULL_POINTER: return "Null Pointer"; - case ERROR_INVALID_ARGUMENT: return "Invalid Argument"; - case ERROR_OUT_OF_BOUNDS: return "Out of Bounds"; - case ERROR_MEMORY_ALLOCATION: return "Memory Allocation Failed"; - case ERROR_FILE_NOT_FOUND: return "File Not Found"; - case ERROR_INITIALIZATION_FAILED: return "Initialization Failed"; - default: return "Unknown Error"; + case ERROR_NONE: + return "No Error"; + case ERROR_NULL_POINTER: + return "Null Pointer"; + case ERROR_INVALID_ARGUMENT: + return "Invalid Argument"; + case ERROR_OUT_OF_BOUNDS: + return "Out of Bounds"; + case ERROR_MEMORY_ALLOCATION: + return "Memory Allocation Failed"; + case ERROR_FILE_NOT_FOUND: + return "File Not Found"; + case ERROR_INITIALIZATION_FAILED: + return "Initialization Failed"; + default: + return "Unknown Error"; } } // Enhanced error function with custom message support -void raise_error_with_message(ErrorType type, const char* format, ...) { +void raise_error_with_message(ErrorType type, const char *format, ...) { printf("Error occurred: %s", error_type_to_string(type)); if (format != NULL) { @@ -47,35 +55,28 @@ void raise_error_with_message(ErrorType type, const char* format, ...) { } // Simple error function (backward compatibility) -void raise_error(ErrorType type) { - raise_error_with_message(type, NULL); -} +void raise_error(ErrorType type) { raise_error_with_message(type, NULL); } // Convenience macros for common error patterns -#define RAISE_FILE_ERROR(path) \ - raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path) +#define RAISE_FILE_ERROR(path) raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path) -#define RAISE_BOUNDS_ERROR() \ - raise_error(ERROR_OUT_OF_BOUNDS) +#define RAISE_BOUNDS_ERROR() raise_error(ERROR_OUT_OF_BOUNDS) -#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \ +#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \ raise_error_with_message(ERROR_OUT_OF_BOUNDS, "index %d exceeds minimum of %d and maximum %d", index, min, max) -#define RAISE_NULL_ERROR() \ - raise_error(ERROR_NULL_POINTER) +#define RAISE_NULL_ERROR() raise_error(ERROR_NULL_POINTER) -#define RAISE_NULL_ERROR_WITH_NAME(var_name) \ +#define RAISE_NULL_ERROR_WITH_NAME(var_name) \ raise_error_with_message(ERROR_NULL_POINTER, "variable '%s' is null", var_name) -#define RAISE_MEMORY_ERROR() \ - raise_error(ERROR_MEMORY_ALLOCATION) +#define RAISE_MEMORY_ERROR() raise_error(ERROR_MEMORY_ALLOCATION) -#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \ +#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \ raise_error_with_message(ERROR_MEMORY_ALLOCATION, "failed to allocate %zu bytes", size) -#define RAISE_INVALID_ARG_ERROR() \ - raise_error(ERROR_INVALID_ARGUMENT) +#define RAISE_INVALID_ARG_ERROR() raise_error(ERROR_INVALID_ARGUMENT) -#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \ +#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \ raise_error_with_message(ERROR_INVALID_ARGUMENT, "invalid value for '%s': %d", arg_name, value) #endif diff --git a/pufferlib/ocean/drive/rollout.py b/pufferlib/ocean/drive/rollout.py new file mode 100644 index 0000000000..aa8d4babf4 --- /dev/null +++ b/pufferlib/ocean/drive/rollout.py @@ -0,0 +1,161 @@ +"""Shared rollout loop for Drive evaluation and rendering. + +Single source of truth for the forward-sample-step-break cycle. Used by: + - Training renders (periodic evaluation with video logging) + - Offline batch rendering + - Safe evaluation time rendering + +This module enables rendering with ANY policy architecture (LSTM, Transformer, etc.) +by keeping policy inference in Python/PyTorch while using C bindings for environment +simulation and graphics rendering. +""" + +from dataclasses import dataclass +from enum import IntEnum +from typing import Optional + +import numpy as np +import torch + +import pufferlib.pytorch + + +class RenderView(IntEnum): + """View modes for rendering.""" + + FULL_SIM_STATE = 0 # Top-down orthographic view of full simulation + BEV_AGENT_OBS = 1 # Bird's eye view centered on agent observation + AGENT_PERSPECTIVE = 2 # Third-person chase camera following agent + + +@dataclass +class RenderContext: + """Enables rendering inside rollout_loop. + + Attributes: + view_mode: RenderView enum value passed to driver.render(). + env_id: which sub-env in the vecenv to record from (default 0). + draw_traces: whether to draw trajectory traces. + video_basename: full mp4 basename (without ".mp4"). Set once before + the first render via driver.set_video_suffix. Caller is + responsible for making this unique across renders. + """ + + view_mode: RenderView + env_id: int = 0 + draw_traces: bool = True + video_basename: str = "render" + + +def rollout_loop( + policy, + env, + device, + use_rnn: bool, + max_steps: Optional[int] = None, + render_ctx: Optional[RenderContext] = None, +): + """Run a single policy rollout in a Drive vecenv. + + This function handles policy inference in Python/PyTorch, making it work + with ANY model architecture (LSTM, Transformer, etc.). + + Args: + policy: the policy to run. Caller is responsible for calling .eval(). + env: a PufferEnv-compatible vecenv wrapping one or more Drive sub-envs. + device: torch device for observation / state tensors. + use_rnn: whether to allocate and carry LSTM hidden state. + max_steps: loop iteration cap. Defaults to env.driver_env.scenario_length. + render_ctx: if set, render the specified env/view every step before + sampling actions. Filename suffix is applied via set_video_suffix. + + Returns: + The last info returned by env.step(). + """ + driver = env.driver_env + + # Handle population play mode - only ego agents are controlled by the policy + population_play = getattr(driver, "population_play", False) + if population_play: + num_ego_agents = driver.num_ego_agents + ego_ids = driver.ego_ids + print(f"[rollout] Population play mode: {num_ego_agents} ego agents, {driver.num_co_players} co-players") + else: + num_ego_agents = env.observation_space.shape[0] + ego_ids = None + + # Set full video basename before the first render call + if render_ctx is not None: + driver.set_video_suffix(render_ctx.video_basename, env_id=render_ctx.env_id) + + obs, _ = env.reset() + + # Initialize recurrent state based on policy type + # Note: state is for EGO agents only, not co-players + state = {} + if use_rnn: + # Check if this is a Transformer or LSTM policy + is_transformer = hasattr(policy, "transformer") or hasattr(policy, "horizon") + + if is_transformer: + # Transformer handles its own state initialization in forward_eval + # when transformer_context is missing, so we just pass empty state + state = {} + else: + # LSTM policy - initialize h and c states for ego agents only + if hasattr(policy, "hidden_size"): + hidden_size = policy.hidden_size + else: + hidden_size = 128 # default + state = dict( + lstm_h=torch.zeros(num_ego_agents, hidden_size, device=device), + lstm_c=torch.zeros(num_ego_agents, hidden_size, device=device), + ) + + if max_steps is None: + max_steps = getattr(driver, "scenario_length", 91) + + info = [] + for step in range(max_steps): + if step % 30 == 0: + print(f"[Python Render] Step {step}/{max_steps}", flush=True) + # Render BEFORE the step so each frame shows the state the policy was + # conditioned on. + if render_ctx is not None: + driver.render( + view_mode=render_ctx.view_mode, + draw_traces=render_ctx.draw_traces, + env_id=render_ctx.env_id, + ) + + with torch.no_grad(): + # In population play, only pass ego observations to the policy + if population_play: + ego_obs = obs[ego_ids] + ob_t = torch.as_tensor(ego_obs).to(device) + else: + ob_t = torch.as_tensor(obs).to(device) + + logits, _ = policy.forward_eval(ob_t, state) + action, _, _ = pufferlib.pytorch.sample_logits(logits) + + # Reshape actions to match expected format + if population_play: + # In population play, policy outputs actions for ego agents only + # env.action_space.shape is (total_agents, action_dim), we need (num_ego_agents, action_dim) + action_dim = env.action_space.shape[1] if len(env.action_space.shape) > 1 else 1 + action_np = action.cpu().numpy().reshape(num_ego_agents, action_dim) + else: + action_np = action.cpu().numpy().reshape(env.action_space.shape) + + # Clip continuous actions to the valid range + if isinstance(logits, torch.distributions.Normal): + action_np = np.clip(action_np, env.action_space.low, env.action_space.high) + + obs, _, _, truncs, info = env.step(action_np) + + # Break when episode ends (truncs.all() is set when the env auto-resets) + if truncs.all(): + break + + return info diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c deleted file mode 100644 index 172ddd1a8f..0000000000 --- a/pufferlib/ocean/drive/visualize.c +++ /dev/null @@ -1,556 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "rlgl.h" -#include -#include -#include -#include -#include "error.h" -#include "drivenet.h" -#include "libgen.h" -#include "../env_config.h" -#define TRAJECTORY_LENGTH_DEFAULT 91 - -typedef struct { - int pipefd[2]; - pid_t pid; -} VideoRecorder; - -bool OpenVideo(VideoRecorder *recorder, const char *output_filename, int width, int height) { - if (pipe(recorder->pipefd) == -1) { - fprintf(stderr, "Failed to create pipe\n"); - return false; - } - - recorder->pid = fork(); - if (recorder->pid == -1) { - fprintf(stderr, "Failed to fork\n"); - return false; - } - - char size_str[64]; - snprintf(size_str, sizeof(size_str), "%dx%d", width, height); - - if (recorder->pid == 0) { // Child process: run ffmpeg - close(recorder->pipefd[1]); - dup2(recorder->pipefd[0], STDIN_FILENO); - close(recorder->pipefd[0]); - // Close all other file descriptors to prevent leaks - for (int fd = 3; fd < 256; fd++) { - close(fd); - } - execlp("ffmpeg", "ffmpeg", - "-y", - "-f", "rawvideo", - "-pix_fmt", "rgba", - "-s", size_str, - "-r", "30", - "-i", "-", - "-c:v", "libx264", - "-pix_fmt", "yuv420p", - "-preset", "ultrafast", - "-crf", "23", - "-loglevel", "error", - output_filename, - NULL); - TraceLog(LOG_ERROR, "Failed to launch ffmpeg"); - return false; - } - - close(recorder->pipefd[0]); // Close read end in parent - return true; -} - -void WriteFrame(VideoRecorder *recorder, int width, int height) { - unsigned char *screen_data = rlReadScreenPixels(width, height); - write(recorder->pipefd[1], screen_data, width * height * 4 * sizeof(*screen_data)); - RL_FREE(screen_data); -} - -void CloseVideo(VideoRecorder *recorder) { - close(recorder->pipefd[1]); - waitpid(recorder->pid, NULL, 0); -} - -void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid, int img_width, int img_height) { - - BeginDrawing(); - - // Top-down orthographic camera - Camera3D camera = {0}; - camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene - camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin - camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; - camera.fovy = map_height; - camera.projection = CAMERA_ORTHOGRAPHIC; - - client->width = img_width; - client->height = img_height; - - Color road = (Color){35, 35, 37, 255}; - ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); - - // Draw human replay trajectories if enabled - if(log_trajectories){ - for(int i=0; iactive_agent_count; i++){ - int idx = env->active_agent_indices[i]; - Vector3 prev_point = {0}; - bool has_prev = false; - - for(int j = 0; j < env->entities[idx].array_size; j++){ - float x = env->entities[idx].traj_x[j]; - float y = env->entities[idx].traj_y[j]; - float valid = env->entities[idx].traj_valid[j]; - - if(!valid) { - has_prev = false; - continue; - } - - Vector3 curr_point = {x, y, 0.5f}; - - if(has_prev) { - DrawLine3D(prev_point, curr_point, Fade(LIGHTGREEN, 0.6f)); - } - - prev_point = curr_point; - has_prev = true; - } - } - } - - // Draw agent trajs - if(trajectories){ - for(int i=0; iactive_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; - - BeginDrawing(); - - Camera3D camera = {0}; - // Position camera behind and above the agent - camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), - 15.0f - }; - camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), - 1.0f - }; - camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; - camera.fovy = 45.0f; - camera.projection = CAMERA_PERSPECTIVE; - - Color road = (Color){35, 35, 37, 255}; - - ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); - draw_scene(env, client, 0, obs_only, lasers, show_grid); // mode=0 for agent view - EndMode3D(); - EndDrawing(); -} - -static int run_cmd(const char *cmd) { - int rc = system(cmd); - if (rc != 0) { - fprintf(stderr, "[ffmpeg] command failed (%d): %s\n", rc, cmd); - } - return rc; -} - -// Make a high-quality GIF from numbered PNG frames like frame_000.png -static int make_gif_from_frames(const char *pattern, int fps, - const char *palette_path, - const char *out_gif) { - char cmd[1024]; - - // 1) Generate palette (no quotes needed for simple filter) - // NOTE: if your frames start at 000, you don't need -start_number. - snprintf(cmd, sizeof(cmd), - "ffmpeg -y -framerate %d -i %s -vf palettegen %s", - fps, pattern, palette_path); - if (run_cmd(cmd) != 0) return -1; - - // 2) Use palette to encode the GIF - snprintf(cmd, sizeof(cmd), - "ffmpeg -y -framerate %d -i %s -i %s -lavfi paletteuse -loop 0 %s", - fps, pattern, palette_path, out_gif); - if (run_cmd(cmd) != 0) return -1; - - return 0; -} - -int eval_gif(const char* map_name, const char* policy_name, int show_grid, int obs_only, int lasers, int log_trajectories, int frame_skip, float goal_radius, int init_steps, int use_rc, int use_ec, int use_dc, int max_controlled_agents, const char* view_mode, const char* output_topdown, const char* output_agent, int num_maps, int scenario_length_override, int init_mode, int control_mode, int goal_behavior) { - - // Parse configuration from INI file - env_init_config conf = {0}; // Initialize to zero - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { - fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); - return -1; - } - - char map_buffer[100]; - if (map_name == NULL) { - srand(time(NULL)); - int random_map = rand() % num_maps; - sprintf(map_buffer, "resources/drive/binaries/map_%03d.bin", random_map); // random map file - map_name = map_buffer; - } - - if (frame_skip <= 0) { - frame_skip = 1; // Default: render every frame - } - - // Check if map file exists - FILE* map_file = fopen(map_name, "rb"); - if (map_file == NULL) { - RAISE_FILE_ERROR(map_name); - } - fclose(map_file); - - FILE* policy_file = fopen(policy_name, "rb"); - if (policy_file == NULL) { - RAISE_FILE_ERROR(policy_name); - } - fclose(policy_file); - - Drive env = { - .dynamics_model = conf.dynamics_model, - .reward_vehicle_collision = conf.reward_vehicle_collision, - .reward_offroad_collision = conf.reward_offroad_collision, - .reward_ade = conf.reward_ade, - .goal_radius = goal_radius, - .dt = conf.dt, - .map_name = (char*)map_name, - .init_steps = init_steps, - .max_controlled_agents = max_controlled_agents, - .collision_behavior = conf.collision_behavior, - .offroad_behavior = conf.offroad_behavior, - .goal_behavior = goal_behavior, - .init_mode = init_mode, - .control_mode = control_mode, - .use_rc = use_rc, - .use_ec = use_ec, - .use_dc = use_dc, - // Conditioning weight bounds (defaults from drive.py) - .collision_weight_lb = -0.0f, - .collision_weight_ub = -0.0f, - .offroad_weight_lb = -0.0f, - .offroad_weight_ub = -0.0f, - .goal_weight_lb = 1.0f, - .goal_weight_ub = 1.0f, - .entropy_weight_lb = 0.001f, - .entropy_weight_ub = 0.001f, - .discount_weight_lb = 0.98f, - .discount_weight_ub = 0.98f, - }; - - env.scenario_length = (scenario_length_override > 0) ? scenario_length_override : - (conf.scenario_length > 0) ? conf.scenario_length : TRAJECTORY_LENGTH_DEFAULT; - allocate(&env); - - // Set which vehicle to focus on for obs mode - env.human_agent_idx = 0; - - c_reset(&env); - // Make client for rendering - Client* client = (Client*)calloc(1, sizeof(Client)); - env.client = client; - - SetConfigFlags(FLAG_WINDOW_HIDDEN); - - SetTargetFPS(6000); - - float map_width = env.grid_map->bottom_right_x - env.grid_map->top_left_x; - float map_height = env.grid_map->top_left_y - env.grid_map->bottom_right_y; - - printf("Map size: %.1fx%.1f\n", map_width, map_height); - float scale = 6.0f; // Can be used to increase the video quality - - // Calculate video width and height; round to nearest even number - int img_width = (int)roundf(map_width * scale / 2.0f) * 2; - int img_height = (int)roundf(map_height * scale / 2.0f) * 2; - InitWindow(img_width, img_height, "Puffer Drive"); - SetConfigFlags(FLAG_MSAA_4X_HINT); - - Weights* weights = load_weights(policy_name); - printf("Active agents in map: %d\n", env.active_agent_count); - DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, use_rc, use_ec, use_dc); - - int frame_count = env.scenario_length > 0 ? env.scenario_length : TRAJECTORY_LENGTH_DEFAULT; - int log_trajectory = log_trajectories; - char filename_topdown[256]; - char filename_agent[256]; - - if (output_topdown != NULL && output_agent != NULL) { - strcpy(filename_topdown, output_topdown); - strcpy(filename_agent, output_agent); - } else { - char policy_base[256]; - strcpy(policy_base, policy_name); - *strrchr(policy_base, '.') = '\0'; - - char map[256]; - strcpy(map, basename((char*)map_name)); - *strrchr(map, '.') = '\0'; - - // Create video directory if it doesn't exist - char video_dir[256]; - sprintf(video_dir, "%s/video", policy_base); - char mkdir_cmd[512]; - snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p \"%s\"", video_dir); - system(mkdir_cmd); - - sprintf(filename_topdown, "%s/video/%s_topdown.mp4", policy_base, map); - sprintf(filename_agent, "%s/video/%s_agent.mp4", policy_base, map); - } - - bool render_topdown = (strcmp(view_mode, "both") == 0 || strcmp(view_mode, "topdown") == 0); - bool render_agent = (strcmp(view_mode, "both") == 0 || strcmp(view_mode, "agent") == 0); - - printf("Rendering: %s\n", view_mode); - - int rendered_frames = 0; - double startTime = GetTime(); - - VideoRecorder topdown_recorder, agent_recorder; - - if (render_topdown) { - if (!OpenVideo(&topdown_recorder, filename_topdown, img_width, img_height)) { - CloseWindow(); - return -1; - } - } - - if (render_agent) { - if (!OpenVideo(&agent_recorder, filename_agent, img_width, img_height)) { - if (render_topdown) CloseVideo(&topdown_recorder); - CloseWindow(); - return -1; - } - } - - if (render_topdown) { - printf("Recording topdown view...\n"); - for(int i = 0; i < frame_count; i++) { - if (i % frame_skip == 0) { - renderTopDownView(&env, client, map_height, 0, 0, 0, frame_count, NULL, log_trajectories, show_grid, img_width, img_height); - WriteFrame(&topdown_recorder, img_width, img_height); - rendered_frames++; - } - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, (int*)env.actions); - c_step(&env); - } - - } - - if (render_agent) { - c_reset(&env); - printf("Recording agent view...\n"); - for(int i = 0; i < frame_count; i++) { - if (i % frame_skip == 0) { - renderAgentView(&env, client, map_height, obs_only, lasers, show_grid); - WriteFrame(&agent_recorder, img_width, img_height); - rendered_frames++; - } - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, (int*)env.actions); - c_step(&env); - } - } - - double endTime = GetTime(); - double elapsedTime = endTime - startTime; - double writeFPS = (elapsedTime > 0) ? rendered_frames / elapsedTime : 0; - - printf("Wrote %d frames in %.2f seconds (%.2f FPS) to %s \n", - rendered_frames, elapsedTime, writeFPS, filename_topdown); - - if (render_topdown) { - CloseVideo(&topdown_recorder); - } - if (render_agent) { - CloseVideo(&agent_recorder); - } - CloseWindow(); - - // Clean up resources - free(client); - free_allocated(&env); - free_drivenet(net); - free(weights); - return 0; -} - -int main(int argc, char* argv[]) { - int show_grid = 0; - int obs_only = 0; - int lasers = 0; - int log_trajectories = 1; - int frame_skip = 1; - float goal_radius = 2.0f; - int init_steps = 0; - const char* map_name = NULL; - const char* policy_name = "resources/drive/puffer_drive_weights.bin"; - int max_controlled_agents = -1; - int num_maps = 1; - int scenario_length_cli = -1; - int use_rc = 0; - int use_ec = 0; - int use_dc = 0; - int init_mode = 0; - int control_mode = 0; - int goal_behavior = 0; - - const char* view_mode = "both"; // "both", "topdown", "agent" - const char* output_topdown = NULL; - const char* output_agent = NULL; - - // Parse command line arguments - for (int i = 1; i < argc; i++) { - if (strcmp(argv[i], "--show-grid") == 0) { - show_grid = 1; - } else if (strcmp(argv[i], "--obs-only") == 0) { - obs_only = 1; - } else if (strcmp(argv[i], "--lasers") == 0) { - lasers = 1; - } else if (strcmp(argv[i], "--log-trajectories") == 0) { - log_trajectories = 1; - } else if (strcmp(argv[i], "--frame-skip") == 0) { - if (i + 1 < argc) { - frame_skip = atoi(argv[i + 1]); - i++; // Skip the next argument since we consumed it - if (frame_skip <= 0) { - frame_skip = 1; // Ensure valid value - } - } - } else if (strcmp(argv[i], "--goal-radius") == 0) { - if (i + 1 < argc) { - goal_radius = atof(argv[i + 1]); - i++; - if (goal_radius <= 0) { - goal_radius = 2.0f; // Ensure valid value - } - } - } else if (strcmp(argv[i], "--map-name") == 0) { - // Check if there's a next argument for the map path - if (i + 1 < argc) { - map_name = argv[i + 1]; - i++; // Skip the next argument since we used it as map path - } else { - fprintf(stderr, "Error: --map-name option requires a map file path\n"); - return 1; - } - } else if (strcmp(argv[i], "--policy-name") == 0) { - if (i + 1 < argc) { - policy_name = argv[i + 1]; - i++; - } else { - fprintf(stderr, "Error: --policy-name option requires a policy file path\n"); - return 1; - } - } else if (strcmp(argv[i], "--view") == 0) { - if (i + 1 < argc) { - view_mode = argv[i + 1]; - i++; - if (strcmp(view_mode, "both") != 0 && - strcmp(view_mode, "topdown") != 0 && - strcmp(view_mode, "agent") != 0) { - fprintf(stderr, "Error: --view must be 'both', 'topdown', or 'agent'\n"); - return 1; - } - } else { - fprintf(stderr, "Error: --view option requires a value (both/topdown/agent)\n"); - return 1; - } - } else if (strcmp(argv[i], "--output-topdown") == 0) { - if (i + 1 < argc) { - output_topdown = argv[i + 1]; - i++; - } - } else if (strcmp(argv[i], "--output-agent") == 0) { - if (i + 1 < argc) { - output_agent = argv[i + 1]; - i++; - } - } else if (strcmp(argv[i], "--init-steps") == 0) { - if (i + 1 < argc) { - init_steps = atoi(argv[i + 1]); - i++; - if (init_steps < 0) { - init_steps = 0; - } - } - } else if (strcmp(argv[i], "--init-mode") == 0) { - if (i + 1 < argc) { - init_mode = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--control-mode") == 0) { - if (i + 1 < argc) { - control_mode = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--max-controlled-agents") == 0) { - if (i + 1 < argc) { - max_controlled_agents = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--num-maps") == 0) { - if (i + 1 < argc) { - num_maps = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--scenario-length") == 0) { - if (i + 1 < argc) { - scenario_length_cli = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-rc") == 0) { - if (i + 1 < argc) { - use_rc = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-ec") == 0) { - if (i + 1 < argc) { - use_ec = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-dc") == 0) { - if (i + 1 < argc) { - use_dc = atoi(argv[i + 1]); - } - } else if (strcmp(argv[i], "--goal-behavior") == 0) { - if (i + 1 < argc) { - goal_behavior = atoi(argv[i + 1]); - i++; - } - } - } - - eval_gif(map_name, policy_name, show_grid, obs_only, lasers, log_trajectories, frame_skip, goal_radius, init_steps, use_rc, use_ec, use_dc, max_controlled_agents, view_mode, output_topdown, output_agent, num_maps, scenario_length_cli, init_mode, control_mode, goal_behavior); - return 0; -} diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 6f7fa3c7df..343cc96584 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -3,42 +3,36 @@ #include // Forward declarations for env-specific functions supplied by user -static int my_log(PyObject* dict, Log* log); -static int my_init(Env* env, PyObject* args, PyObject* kwargs); +static int my_log(PyObject *dict, Log *log); +static int my_init(Env *env, PyObject *args, PyObject *kwargs); -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs); #ifndef MY_SHARED -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { - return NULL; -} +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { return NULL; } #endif -static PyObject* my_get(PyObject* dict, Env* env); +static PyObject *my_get(PyObject *dict, Env *env); #ifndef MY_GET -static PyObject* my_get(PyObject* dict, Env* env) { - return NULL; -} +static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } #endif -static int my_put(Env* env, PyObject* args, PyObject* kwargs); +static int my_put(Env *env, PyObject *args, PyObject *kwargs); #ifndef MY_PUT -static int my_put(Env* env, PyObject* args, PyObject* kwargs) { - return 0; -} +static int my_put(Env *env, PyObject *args, PyObject *kwargs) { return 0; } #endif #ifndef MY_METHODS #define MY_METHODS {NULL, NULL, 0, NULL} #endif -static Env* unpack_env(PyObject* args) { - PyObject* handle_obj = PyTuple_GetItem(args, 0); +static Env *unpack_env(PyObject *args) { + PyObject *handle_obj = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "env_handle must be an integer"); return NULL; } - Env* env = (Env*)PyLong_AsVoidPtr(handle_obj); + Env *env = (Env *)PyLong_AsVoidPtr(handle_obj); if (!env) { PyErr_SetString(PyExc_ValueError, "Invalid env handle"); return NULL; @@ -48,36 +42,36 @@ static Env* unpack_env(PyObject* args) { } // Python function to initialize the environment -static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyTuple_Size(args) != 6) { PyErr_SetString(PyExc_TypeError, "Environment requires 5 arguments"); return NULL; } - Env* env = (Env*)calloc(1, sizeof(Env)); + Env *env = (Env *)calloc(1, sizeof(Env)); if (!env) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment"); return NULL; } - PyObject* obs = PyTuple_GetItem(args, 0); + PyObject *obs = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return NULL; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return NULL; } env->observations = PyArray_DATA(observations); - PyObject* act = PyTuple_GetItem(args, 1); + PyObject *act = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return NULL; } - PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return NULL; @@ -88,12 +82,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* rew = PyTuple_GetItem(args, 2); + PyObject *rew = PyTuple_GetItem(args, 2); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return NULL; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return NULL; @@ -104,12 +98,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } env->rewards = PyArray_DATA(rewards); - PyObject* term = PyTuple_GetItem(args, 3); + PyObject *term = PyTuple_GetItem(args, 3); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return NULL; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return NULL; @@ -120,12 +114,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } env->terminals = PyArray_DATA(terminals); - PyObject* trunc = PyTuple_GetItem(args, 4); + PyObject *trunc = PyTuple_GetItem(args, 4); if (!PyObject_TypeCheck(trunc, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array"); return NULL; } - PyArrayObject* truncations = (PyArrayObject*)trunc; + PyArrayObject *truncations = (PyArrayObject *)trunc; if (!PyArray_ISCONTIGUOUS(truncations)) { PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous"); return NULL; @@ -136,8 +130,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } // env->truncations = PyArray_DATA(truncations); - - PyObject* seed_arg = PyTuple_GetItem(args, 5); + PyObject *seed_arg = PyTuple_GetItem(args, 5); if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; @@ -151,11 +144,11 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { if (kwargs == NULL) { kwargs = PyDict_New(); } else { - Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it + Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it } // Add the seed to kwargs - PyObject* py_seed = PyLong_FromLong(seed); + PyObject *py_seed = PyLong_FromLong(seed); if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) { PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs"); Py_DECREF(py_seed); @@ -164,7 +157,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } Py_DECREF(py_seed); - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_init(env, empty_args, kwargs); Py_DECREF(kwargs); if (PyErr_Occurred()) { @@ -175,14 +168,14 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } // Python function to reset the environment -static PyObject* env_reset(PyObject* self, PyObject* args) { +static PyObject *env_reset(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 2) { PyErr_SetString(PyExc_TypeError, "env_reset requires 2 arguments"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } c_reset(env); @@ -190,15 +183,15 @@ static PyObject* env_reset(PyObject* self, PyObject* args) { } // Python function to step the environment -static PyObject* env_step(PyObject* self, PyObject* args) { +static PyObject *env_step(PyObject *self, PyObject *args) { int num_args = PyTuple_Size(args); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "vec_render requires 1 argument"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } c_step(env); @@ -206,9 +199,9 @@ static PyObject* env_step(PyObject* self, PyObject* args) { } // Python function to step the environment -static PyObject* env_render(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_render(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if (!env) { return NULL; } c_render(env); @@ -216,9 +209,9 @@ static PyObject* env_render(PyObject* self, PyObject* args) { } // Python function to close the environment -static PyObject* env_close(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_close(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if (!env) { return NULL; } c_close(env); @@ -226,12 +219,12 @@ static PyObject* env_close(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* env_get(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_get(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if (!env) { return NULL; } - PyObject* dict = PyDict_New(); + PyObject *dict = PyDict_New(); my_get(dict, env); if (PyErr_Occurred()) { return NULL; @@ -239,19 +232,19 @@ static PyObject* env_get(PyObject* self, PyObject* args) { return dict; } -static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *env_put(PyObject *self, PyObject *args, PyObject *kwargs) { int num_args = PyTuple_Size(args); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "env_put requires 1 positional argument"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_put(env, empty_args, kwargs); if (PyErr_Occurred()) { return NULL; @@ -261,18 +254,18 @@ static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) { } typedef struct { - Env** envs; + Env **envs; int num_envs; } VecEnv; -static VecEnv* unpack_vecenv(PyObject* args) { - PyObject* handle_obj = PyTuple_GetItem(args, 0); +static VecEnv *unpack_vecenv(PyObject *args) { + PyObject *handle_obj = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "env_handle must be an integer"); return NULL; } - VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(handle_obj); + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(handle_obj); if (!vec) { PyErr_SetString(PyExc_ValueError, "Missing or invalid vec env handle"); return NULL; @@ -286,18 +279,18 @@ static VecEnv* unpack_vecenv(PyObject* args) { return vec; } -static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *vec_init(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyTuple_Size(args) != 7) { PyErr_SetString(PyExc_TypeError, "vec_init requires 6 arguments"); return NULL; } - VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv)); + VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv)); if (!vec) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - PyObject* num_envs_arg = PyTuple_GetItem(args, 5); + PyObject *num_envs_arg = PyTuple_GetItem(args, 5); if (!PyObject_TypeCheck(num_envs_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "num_envs must be an integer"); return NULL; @@ -308,25 +301,25 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } vec->num_envs = num_envs; - vec->envs = (Env**)calloc(num_envs, sizeof(Env*)); + vec->envs = (Env **)calloc(num_envs, sizeof(Env *)); if (!vec->envs) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - PyObject* seed_obj = PyTuple_GetItem(args, 6); + PyObject *seed_obj = PyTuple_GetItem(args, 6); if (!PyObject_TypeCheck(seed_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; } int seed = PyLong_AsLong(seed_obj); - PyObject* obs = PyTuple_GetItem(args, 0); + PyObject *obs = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return NULL; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return NULL; @@ -336,12 +329,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* act = PyTuple_GetItem(args, 1); + PyObject *act = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return NULL; } - PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return NULL; @@ -351,12 +344,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* rew = PyTuple_GetItem(args, 2); + PyObject *rew = PyTuple_GetItem(args, 2); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return NULL; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return NULL; @@ -366,12 +359,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* term = PyTuple_GetItem(args, 3); + PyObject *term = PyTuple_GetItem(args, 3); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return NULL; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return NULL; @@ -381,12 +374,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* trunc = PyTuple_GetItem(args, 4); + PyObject *trunc = PyTuple_GetItem(args, 4); if (!PyObject_TypeCheck(trunc, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array"); return NULL; } - PyArrayObject* truncations = (PyArrayObject*)trunc; + PyArrayObject *truncations = (PyArrayObject *)trunc; if (!PyArray_ISCONTIGUOUS(truncations)) { PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous"); return NULL; @@ -400,11 +393,11 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { if (kwargs == NULL) { kwargs = PyDict_New(); } else { - Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it + Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it } for (int i = 0; i < num_envs; i++) { - Env* env = (Env*)calloc(1, sizeof(Env)); + Env *env = (Env *)calloc(1, sizeof(Env)); if (!env) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment"); Py_DECREF(kwargs); @@ -415,18 +408,18 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { // // Make sure the log is initialized to 0 memset(&env->log, 0, sizeof(Log)); - env->observations = (void*)((char*)PyArray_DATA(observations) + i*PyArray_STRIDE(observations, 0)); - env->actions = (void*)((char*)PyArray_DATA(actions) + i*PyArray_STRIDE(actions, 0)); - env->rewards = (void*)((char*)PyArray_DATA(rewards) + i*PyArray_STRIDE(rewards, 0)); - env->terminals = (void*)((char*)PyArray_DATA(terminals) + i*PyArray_STRIDE(terminals, 0)); + env->observations = (void *)((char *)PyArray_DATA(observations) + i * PyArray_STRIDE(observations, 0)); + env->actions = (void *)((char *)PyArray_DATA(actions) + i * PyArray_STRIDE(actions, 0)); + env->rewards = (void *)((char *)PyArray_DATA(rewards) + i * PyArray_STRIDE(rewards, 0)); + env->terminals = (void *)((char *)PyArray_DATA(terminals) + i * PyArray_STRIDE(terminals, 0)); // env->truncations = (void*)((char*)PyArray_DATA(truncations) + i*PyArray_STRIDE(truncations, 0)); // Assumes each process has the same number of environments - int env_seed = i + seed*vec->num_envs; + int env_seed = i + seed * vec->num_envs; srand(env_seed); // Add the seed to kwargs for this environment - PyObject* py_seed = PyLong_FromLong(env_seed); + PyObject *py_seed = PyLong_FromLong(env_seed); if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) { PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs"); Py_DECREF(py_seed); @@ -435,7 +428,7 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { } Py_DECREF(py_seed); - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_init(env, empty_args, kwargs); if (PyErr_Occurred()) { return NULL; @@ -446,22 +439,21 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return PyLong_FromVoidPtr(vec); } - // Python function to close the environment -static PyObject* vectorize(PyObject* self, PyObject* args) { +static PyObject *vectorize(PyObject *self, PyObject *args) { int num_envs = PyTuple_Size(args); if (num_envs == 0) { PyErr_SetString(PyExc_TypeError, "make_vec requires at least 1 env id"); return NULL; } - VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv)); + VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv)); if (!vec) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - vec->envs = (Env**)calloc(num_envs, sizeof(Env*)); + vec->envs = (Env **)calloc(num_envs, sizeof(Env *)); if (!vec->envs) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; @@ -469,29 +461,30 @@ static PyObject* vectorize(PyObject* self, PyObject* args) { vec->num_envs = num_envs; for (int i = 0; i < num_envs; i++) { - PyObject* handle_obj = PyTuple_GetItem(args, i); + PyObject *handle_obj = PyTuple_GetItem(args, i); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { - PyErr_SetString(PyExc_TypeError, "Env ids must be integers. Pass them as separate args with *env_ids, not as a list."); + PyErr_SetString(PyExc_TypeError, + "Env ids must be integers. Pass them as separate args with *env_ids, not as a list."); return NULL; } - vec->envs[i] = (Env*)PyLong_AsVoidPtr(handle_obj); + vec->envs[i] = (Env *)PyLong_AsVoidPtr(handle_obj); } return PyLong_FromVoidPtr(vec); } -static PyObject* vec_reset(PyObject* self, PyObject* args) { +static PyObject *vec_reset(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 2) { PyErr_SetString(PyExc_TypeError, "vec_reset requires 2 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } - PyObject* seed_arg = PyTuple_GetItem(args, 1); + PyObject *seed_arg = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; @@ -500,20 +493,20 @@ static PyObject* vec_reset(PyObject* self, PyObject* args) { for (int i = 0; i < vec->num_envs; i++) { // Assumes each process has the same number of environments - srand(i + seed*vec->num_envs); + srand(i + seed * vec->num_envs); c_reset(vec->envs[i]); } Py_RETURN_NONE; } -static PyObject* vec_step(PyObject* self, PyObject* arg) { +static PyObject *vec_step(PyObject *self, PyObject *arg) { int num_args = PyTuple_Size(arg); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "vec_step requires 1 argument"); return NULL; } - VecEnv* vec = unpack_vecenv(arg); + VecEnv *vec = unpack_vecenv(arg); if (!vec) { return NULL; } @@ -524,37 +517,72 @@ static PyObject* vec_step(PyObject* self, PyObject* arg) { Py_RETURN_NONE; } -static PyObject* vec_render(PyObject* self, PyObject* args) { +static PyObject *vec_render(PyObject *self, PyObject *args) { int num_args = PyTuple_Size(args); - if (num_args != 2) { - PyErr_SetString(PyExc_TypeError, "vec_render requires 2 arguments"); + if (num_args != 6) { + PyErr_SetString(PyExc_TypeError, "vec_render requires 6 arguments: (vec_env, view_mode, draw_traces, env_id, " + "current_scenario, k_scenarios)"); return NULL; } - VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); if (!vec) { PyErr_SetString(PyExc_ValueError, "Invalid vec_env handle"); return NULL; } - PyObject* env_id_arg = PyTuple_GetItem(args, 1); - if (!PyObject_TypeCheck(env_id_arg, &PyLong_Type)) { - PyErr_SetString(PyExc_TypeError, "env_id must be an integer"); + int view_mode = (int)PyLong_AsLong(PyTuple_GetItem(args, 1)); + int draw_traces = PyObject_IsTrue(PyTuple_GetItem(args, 2)); + int env_id = (int)PyLong_AsLong(PyTuple_GetItem(args, 3)); + int current_scenario = (int)PyLong_AsLong(PyTuple_GetItem(args, 4)); + int k_scenarios = (int)PyLong_AsLong(PyTuple_GetItem(args, 5)); + + if (env_id < 0 || env_id >= vec->num_envs) { + PyErr_SetString(PyExc_ValueError, "env_id out of range"); return NULL; } - int env_id = PyLong_AsLong(env_id_arg); - c_render(vec->envs[env_id]); + c_render_with_mode(vec->envs[env_id], view_mode, draw_traces, current_scenario, k_scenarios); Py_RETURN_NONE; } -static int assign_to_dict(PyObject* dict, char* key, float value) { - PyObject* v = PyFloat_FromDouble(value); +static PyObject *vec_set_video_suffix(PyObject *self, PyObject *args) { + int num_args = PyTuple_Size(args); + if (num_args != 3) { + PyErr_SetString(PyExc_TypeError, "vec_set_video_suffix requires 3 arguments: (vec_env, env_id, suffix)"); + return NULL; + } + + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); + if (!vec) { + PyErr_SetString(PyExc_ValueError, "Invalid vec_env handle"); + return NULL; + } + + int env_id = (int)PyLong_AsLong(PyTuple_GetItem(args, 1)); + if (env_id < 0 || env_id >= vec->num_envs) { + PyErr_SetString(PyExc_ValueError, "env_id out of range"); + return NULL; + } + + PyObject *suffix_obj = PyTuple_GetItem(args, 2); + const char *suffix = PyUnicode_AsUTF8(suffix_obj); + if (!suffix) { + PyErr_SetString(PyExc_TypeError, "suffix must be a string"); + return NULL; + } + + set_video_suffix(vec->envs[env_id], suffix); + Py_RETURN_NONE; +} + +static int assign_to_dict(PyObject *dict, char *key, float value) { + PyObject *v = PyFloat_FromDouble(value); if (v == NULL) { PyErr_SetString(PyExc_TypeError, "Failed to convert log value"); return 1; } - if(PyDict_SetItemString(dict, key, v) < 0) { + if (PyDict_SetItemString(dict, key, v) < 0) { PyErr_SetString(PyExc_TypeError, "Failed to set log value"); return 1; } @@ -562,99 +590,125 @@ static int assign_to_dict(PyObject* dict, char* key, float value) { return 0; } -static PyObject* vec_log(PyObject* self, PyObject* args) { - VecEnv* vec = unpack_vecenv(args); +static PyObject *vec_log(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 2) { + PyErr_SetString(PyExc_TypeError, "vec_log requires 2 arguments"); + return NULL; + } + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } + PyObject *num_agents_arg = PyTuple_GetItem(args, 1); + float num_agents = (float)PyLong_AsLong(num_agents_arg); + // Iterates over logs one float at a time. Will break // horribly if Log has non-float data. Log aggregate = {0}; int num_keys = sizeof(Log) / sizeof(float); - // Adaptive agent logging variables - float ada_delta_completion_rate = 0.0f; - float ada_delta_score = 0.0f; - float ada_delta_perf = 0.0f; - float ada_delta_collision_rate = 0.0f; - float ada_delta_offroad_rate = 0.0f; - float ada_delta_num_goals_reached = 0.0f; - float ada_delta_dnf_rate = 0.0f; - float ada_delta_lane_alignment_rate = 0.0f; - float ada_delta_avg_displacement_error = 0.0f; - float ada_delta_episode_return = 0.0f; - int ada_agent_count = 0; - int has_co_players = 0; // Flag to check if any env has co-players - Co_Player_Log co_player_aggregate = {0}; - int num_co_player_keys = sizeof(Co_Player_Log) / sizeof(float); + int has_co_players = 0; // Flag to check if any env has co-players + Log co_player_aggregate = {0}; // Now using Log struct instead of Co_Player_Log for (int i = 0; i < vec->num_envs; i++) { - Env* env = vec->envs[i]; - + Env *env = vec->envs[i]; for (int j = 0; j < num_keys; j++) { - ((float*)&aggregate)[j] += ((float*)&env->log)[j]; - ((float*)&env->log)[j] = 0.0f; + ((float *)&aggregate)[j] += ((float *)&env->log)[j]; } if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) { has_co_players = 1; - - // Aggregate co-player logs - for (int j = 0; j < num_co_player_keys; j++) { - ((float*)&co_player_aggregate)[j] += ((float*)&env->co_player_log)[j]; - ((float*)&env->co_player_log)[j] = 0.0f; // Reset after aggregating + // Aggregate co-player logs (now same structure as ego logs) + for (int j = 0; j < num_keys; j++) { + ((float *)&co_player_aggregate)[j] += ((float *)&env->co_player_log)[j]; } } + } + + PyObject *dict = PyDict_New(); + // Check if we have enough data from EITHER ego agents OR total (ego + co-players) + float total_n = aggregate.n + (has_co_players ? co_player_aggregate.n : 0.0f); + if (total_n < num_agents) { + return dict; // Not enough data yet } - PyObject* dict = PyDict_New(); + // Got enough data. Reset logs and return metrics + for (int i = 0; i < vec->num_envs; i++) { + Env *env = vec->envs[i]; + for (int j = 0; j < num_keys; j++) { + ((float *)&env->log)[j] = 0.0f; + } - // Average regular logs - if (aggregate.n > 0.0f) { - float n = aggregate.n; - for (int i = 0; i < num_keys; i++) { - ((float*)&aggregate)[i] /= n; + if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) { + for (int j = 0; j < num_keys; j++) { + ((float *)&env->co_player_log)[j] = 0.0f; + } + } + } + + float n = aggregate.n; + // Average across EGO agents only + if (n > 0) { + // Compute completion_rate from raw totals BEFORE averaging + float total_goals_reached = aggregate.goals_reached_this_episode; + float total_goals_sampled = aggregate.goals_sampled_this_episode; + if (total_goals_sampled > 0) { + aggregate.completion_rate = total_goals_reached / total_goals_sampled; + } else { + aggregate.completion_rate = 0.0f; } - // User populates dict - my_log(dict, &aggregate); - assign_to_dict(dict, "n", n); + for (int i = 0; i < num_keys; i++) { + ((float *)&aggregate)[i] /= n; + } } - if (has_co_players && co_player_aggregate.co_player_n > 0.0f) { - float co_player_n = co_player_aggregate.co_player_n; - // Only divide non-zero values to avoid corruption - for (int i = 0; i < num_co_player_keys; i++) { - if (((float*)&co_player_aggregate)[i] != 0.0f) { - ((float*)&co_player_aggregate)[i] /= co_player_n; - } + // User populates dict + my_log(dict, &aggregate); + assign_to_dict(dict, "n", n); + + // Handle co-player metrics + if (has_co_players && co_player_aggregate.n > 0.0f) { + float co_player_n = co_player_aggregate.n; + + // Compute co-player completion rate from raw totals BEFORE averaging + float co_total_goals_reached = co_player_aggregate.goals_reached_this_episode; + float co_total_goals_sampled = co_player_aggregate.goals_sampled_this_episode; + if (co_total_goals_sampled > 0) { + co_player_aggregate.completion_rate = co_total_goals_reached / co_total_goals_sampled; + } else { + co_player_aggregate.completion_rate = 0.0f; } - // Add co-player metrics directly - assign_to_dict(dict, "ego_co_player_ratio", aggregate.n / co_player_n); - assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.co_player_completion_rate); - assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.co_player_collision_rate); - assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.co_player_offroad_rate); - assign_to_dict(dict, "co_player_clean_collision_rate", co_player_aggregate.co_player_clean_collision_rate); - assign_to_dict(dict, "co_player_num_goals_reached", co_player_aggregate.co_player_num_goals_reached); - assign_to_dict(dict, "co_player_score", co_player_aggregate.co_player_score); - assign_to_dict(dict, "co_player_perf", co_player_aggregate.co_player_perf); - assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.co_player_dnf_rate); - assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.co_player_episode_length); - assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.co_player_episode_return); - assign_to_dict(dict, "co_player_lane_alignment_rate", co_player_aggregate.co_player_lane_alignment_rate); - assign_to_dict(dict, "co_player_avg_displacement_error", co_player_aggregate.co_player_avg_displacement_error); + // Average co-player metrics across CO-PLAYER agents only + for (int i = 0; i < num_keys; i++) { + ((float *)&co_player_aggregate)[i] /= co_player_n; + } + + // Add co-player metrics to dict with co_player_ prefix + assign_to_dict(dict, "ego_co_player_ratio", n / co_player_n); + assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.completion_rate); + assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.collision_rate); + assign_to_dict(dict, "co_player_collisions_per_agent", co_player_aggregate.collisions_per_agent); + assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.offroad_rate); + assign_to_dict(dict, "co_player_offroad_per_agent", co_player_aggregate.offroad_per_agent); + assign_to_dict(dict, "co_player_score", co_player_aggregate.score); + assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.dnf_rate); + assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.episode_length); + assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.episode_return); + assign_to_dict(dict, "co_player_lane_alignment_rate", co_player_aggregate.lane_alignment_rate); + assign_to_dict(dict, "co_player_speed_at_goal", co_player_aggregate.speed_at_goal); + assign_to_dict(dict, "co_player_goals_reached_this_episode", co_player_aggregate.goals_reached_this_episode); + assign_to_dict(dict, "co_player_goals_sampled_this_episode", co_player_aggregate.goals_sampled_this_episode); assign_to_dict(dict, "co_player_n", co_player_n); } - return dict; } - -static PyObject* vec_close(PyObject* self, PyObject* args) { - VecEnv* vec = unpack_vecenv(args); +static PyObject *vec_close(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } @@ -668,93 +722,118 @@ static PyObject* vec_close(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* get_global_agent_state(PyObject* self, PyObject* args) { - if (PyTuple_Size(args) != 5) { - PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 5 arguments"); +// Render-mode helpers: stash env[0]->client into a global before vec_close so +// raylib + ffmpeg pipe survive the map swap, then re-attach to env[0] of the +// freshly built vec. Single-slot global: render envs are single-env. +static PyObject *vec_donate_client(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); + if (!vec || vec->num_envs == 0) { + Py_RETURN_NONE; + } + c_donate_client(vec->envs[0]); + Py_RETURN_NONE; +} + +static PyObject *vec_adopt_client(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); + if (!vec || vec->num_envs == 0) { + Py_RETURN_NONE; + } + c_adopt_client(vec->envs[0]); + Py_RETURN_NONE; +} + +static PyObject *get_global_agent_state(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 7) { + PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 7 arguments"); return NULL; } - Env* env = unpack_env(args); + Env *env = unpack_env(args); if (!env) { return NULL; } - Drive* drive = (Drive*)env; // Cast to Drive* + Drive *drive = (Drive *)env; // Cast to Drive* // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* id_arr = PyTuple_GetItem(args, 5); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || - !PyArray_Check(id_arr)) { + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *id_arr = PyTuple_GetItem(args, 5); + PyObject *length_arr = PyTuple_GetItem(args, 6); + PyObject *width_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr); - float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr); - float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr); - float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr); - int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr); + float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr); + float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr); + float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr); + int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr); + float *length_data = (float *)PyArray_DATA((PyArrayObject *)length_arr); + float *width_data = (float *)PyArray_DATA((PyArrayObject *)width_arr); - c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data); + c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data, length_data, width_data); Py_RETURN_NONE; } -static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) { - if (PyTuple_Size(args) != 6) { - PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 6 arguments"); +static PyObject *vec_get_global_agent_state(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 8) { + PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 8 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* id_arr = PyTuple_GetItem(args, 5); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || - !PyArray_Check(id_arr)) { + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *id_arr = PyTuple_GetItem(args, 5); + PyObject *length_arr = PyTuple_GetItem(args, 6); + PyObject *width_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - PyArrayObject* x_array = (PyArrayObject*)x_arr; - PyArrayObject* y_array = (PyArrayObject*)y_arr; - PyArrayObject* z_array = (PyArrayObject*)z_arr; - PyArrayObject* heading_array = (PyArrayObject*)heading_arr; - PyArrayObject* id_array = (PyArrayObject*)id_arr; + PyArrayObject *x_array = (PyArrayObject *)x_arr; + PyArrayObject *y_array = (PyArrayObject *)y_arr; + PyArrayObject *z_array = (PyArrayObject *)z_arr; + PyArrayObject *heading_array = (PyArrayObject *)heading_arr; + PyArrayObject *id_array = (PyArrayObject *)id_arr; + PyArrayObject *length_array = (PyArrayObject *)length_arr; + PyArrayObject *width_array = (PyArrayObject *)width_arr; // Get base pointers to the arrays - float* x_base = (float*)PyArray_DATA(x_array); - float* y_base = (float*)PyArray_DATA(y_array); - float* z_base = (float*)PyArray_DATA(z_array); - float* heading_base = (float*)PyArray_DATA(heading_array); - int* id_base = (int*)PyArray_DATA(id_array); + float *x_base = (float *)PyArray_DATA(x_array); + float *y_base = (float *)PyArray_DATA(y_array); + float *z_base = (float *)PyArray_DATA(z_array); + float *heading_base = (float *)PyArray_DATA(heading_array); + int *id_base = (int *)PyArray_DATA(id_array); + float *length_base = (float *)PyArray_DATA(length_array); + float *width_base = (float *)PyArray_DATA(width_array); // Iterate through environments and write to correct offsets int offset = 0; for (int i = 0; i < vec->num_envs; i++) { - Drive* drive = (Drive*)vec->envs[i]; + Drive *drive = (Drive *)vec->envs[i]; // Write to the arrays at the current offset - c_get_global_agent_state(drive, - &x_base[offset], - &y_base[offset], - &z_base[offset], - &heading_base[offset], - &id_base[offset]); + c_get_global_agent_state(drive, &x_base[offset], &y_base[offset], &z_base[offset], &heading_base[offset], + &id_base[offset], &length_base[offset], &width_base[offset]); // Move offset forward by the number of agents in this environment offset += drive->active_agent_count; @@ -763,111 +842,105 @@ static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* get_ground_truth_trajectories(PyObject* self, PyObject* args) { +static PyObject *get_ground_truth_trajectories(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 7) { PyErr_SetString(PyExc_TypeError, "get_ground_truth_trajectories requires 7 arguments"); return NULL; } - Env* env = unpack_env(args); + Env *env = unpack_env(args); if (!env) { return NULL; } - Drive* drive = (Drive*)env; + Drive *drive = (Drive *)env; // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* valid_arr = PyTuple_GetItem(args, 5); - PyObject* id_arr = PyTuple_GetItem(args, 6); - PyObject* scenario_id_arr = PyTuple_GetItem(args, 7); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *valid_arr = PyTuple_GetItem(args, 5); + PyObject *id_arr = PyTuple_GetItem(args, 6); + PyObject *scenario_id_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || !PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr); - float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr); - float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr); - float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr); - int* valid_data = (int*)PyArray_DATA((PyArrayObject*)valid_arr); - int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr); - int* scenario_id_data = (int*)PyArray_DATA((PyArrayObject*)scenario_id_arr); + float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr); + float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr); + float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr); + int *valid_data = (int *)PyArray_DATA((PyArrayObject *)valid_arr); + int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr); + int *scenario_id_data = (int *)PyArray_DATA((PyArrayObject *)scenario_id_arr); - c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data, scenario_id_data); + c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data, + scenario_id_data); Py_RETURN_NONE; } -static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObject* args) { +static PyObject *vec_get_global_ground_truth_trajectories(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 8) { PyErr_SetString(PyExc_TypeError, "vec_get_global_ground_truth_trajectories requires 8 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* valid_arr = PyTuple_GetItem(args, 5); - PyObject* id_arr = PyTuple_GetItem(args, 6); - PyObject* scenario_id_arr = PyTuple_GetItem(args, 7); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *valid_arr = PyTuple_GetItem(args, 5); + PyObject *id_arr = PyTuple_GetItem(args, 6); + PyObject *scenario_id_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || !PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - PyArrayObject* x_array = (PyArrayObject*)x_arr; - PyArrayObject* y_array = (PyArrayObject*)y_arr; - PyArrayObject* z_array = (PyArrayObject*)z_arr; - PyArrayObject* heading_array = (PyArrayObject*)heading_arr; - PyArrayObject* valid_array = (PyArrayObject*)valid_arr; - PyArrayObject* id_array = (PyArrayObject*)id_arr; - PyArrayObject* scenario_id_array = (PyArrayObject*)scenario_id_arr; + PyArrayObject *x_array = (PyArrayObject *)x_arr; + PyArrayObject *y_array = (PyArrayObject *)y_arr; + PyArrayObject *z_array = (PyArrayObject *)z_arr; + PyArrayObject *heading_array = (PyArrayObject *)heading_arr; + PyArrayObject *valid_array = (PyArrayObject *)valid_arr; + PyArrayObject *id_array = (PyArrayObject *)id_arr; + PyArrayObject *scenario_id_array = (PyArrayObject *)scenario_id_arr; // Get base pointers to the arrays - float* x_base = (float*)PyArray_DATA(x_array); - float* y_base = (float*)PyArray_DATA(y_array); - float* z_base = (float*)PyArray_DATA(z_array); - float* heading_base = (float*)PyArray_DATA(heading_array); - int* valid_base = (int*)PyArray_DATA(valid_array); - int* id_base = (int*)PyArray_DATA(id_array); - int* scenario_id_base = (int*)PyArray_DATA(scenario_id_array); + float *x_base = (float *)PyArray_DATA(x_array); + float *y_base = (float *)PyArray_DATA(y_array); + float *z_base = (float *)PyArray_DATA(z_array); + float *heading_base = (float *)PyArray_DATA(heading_array); + int *valid_base = (int *)PyArray_DATA(valid_array); + int *id_base = (int *)PyArray_DATA(id_array); + int *scenario_id_base = (int *)PyArray_DATA(scenario_id_array); // Get number of timesteps from array shape - npy_intp* x_shape = PyArray_DIMS(x_array); - int num_timesteps = x_shape[1]; // Second dimension for 2D arrays + npy_intp *x_shape = PyArray_DIMS(x_array); + int num_timesteps = x_shape[1]; // Second dimension for 2D arrays // Iterate through environments and write to correct offsets - int agent_offset = 0; // Offset for 1D arrays (id, scenario_id) - int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid) + int agent_offset = 0; // Offset for 1D arrays (id, scenario_id) + int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid) for (int i = 0; i < vec->num_envs; i++) { - Drive* drive = (Drive*)vec->envs[i]; + Drive *drive = (Drive *)vec->envs[i]; - c_get_global_ground_truth_trajectories(drive, - &x_base[traj_offset], - &y_base[traj_offset], - &z_base[traj_offset], - &heading_base[traj_offset], - &valid_base[traj_offset], - &id_base[agent_offset], - &scenario_id_base[agent_offset]); + c_get_global_ground_truth_trajectories(drive, &x_base[traj_offset], &y_base[traj_offset], &z_base[traj_offset], + &heading_base[traj_offset], &valid_base[traj_offset], + &id_base[agent_offset], &scenario_id_base[agent_offset]); // Move offsets forward agent_offset += drive->active_agent_count; @@ -876,8 +949,64 @@ static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObje Py_RETURN_NONE; } -static double unpack(PyObject* kwargs, char* key) { - PyObject* val = PyDict_GetItemString(kwargs, key); + +static PyObject *vec_get_road_edge_counts(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); + if (!vec) + return NULL; + + int total_polylines = 0, total_points = 0; + for (int i = 0; i < vec->num_envs; i++) { + Drive *drive = (Drive *)vec->envs[i]; + int np, tp; + c_get_road_edge_counts(drive, &np, &tp); + total_polylines += np; + total_points += tp; + } + return Py_BuildValue("(ii)", total_polylines, total_points); +} + +static PyObject *vec_get_road_edge_polylines(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 5) { + PyErr_SetString(PyExc_TypeError, "vec_get_road_edge_polylines requires 5 arguments"); + return NULL; + } + + VecEnv *vec = unpack_vecenv(args); + if (!vec) + return NULL; + + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *lengths_arr = PyTuple_GetItem(args, 3); + PyObject *scenario_ids_arr = PyTuple_GetItem(args, 4); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(lengths_arr) || + !PyArray_Check(scenario_ids_arr)) { + PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); + return NULL; + } + + float *x_base = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_base = (float *)PyArray_DATA((PyArrayObject *)y_arr); + int *lengths_base = (int *)PyArray_DATA((PyArrayObject *)lengths_arr); + int *scenario_ids_base = (int *)PyArray_DATA((PyArrayObject *)scenario_ids_arr); + + int poly_offset = 0, pt_offset = 0; + for (int i = 0; i < vec->num_envs; i++) { + Drive *drive = (Drive *)vec->envs[i]; + int np, tp; + c_get_road_edge_counts(drive, &np, &tp); + c_get_road_edge_polylines(drive, &x_base[pt_offset], &y_base[pt_offset], &lengths_base[poly_offset], + &scenario_ids_base[poly_offset]); + poly_offset += np; + pt_offset += tp; + } + Py_RETURN_NONE; +} + +static double unpack(PyObject *kwargs, char *key) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -904,8 +1033,8 @@ static double unpack(PyObject* kwargs, char* key) { return 1; } -static char* unpack_str(PyObject* kwargs, char* key) { - PyObject* val = PyDict_GetItemString(kwargs, key); +static char *unpack_str(PyObject *kwargs, char *key) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -918,12 +1047,12 @@ static char* unpack_str(PyObject* kwargs, char* key) { PyErr_SetString(PyExc_TypeError, error_msg); return NULL; } - const char* str_val = PyUnicode_AsUTF8(val); + const char *str_val = PyUnicode_AsUTF8(val); if (str_val == NULL) { // PyUnicode_AsUTF8 sets an error on failure return NULL; } - char* ret = strdup(str_val); + char *ret = strdup(str_val); if (ret == NULL) { PyErr_SetString(PyExc_MemoryError, "strdup failed in unpack_str"); } @@ -932,7 +1061,8 @@ static char* unpack_str(PyObject* kwargs, char* key) { // Method table static PyMethodDef methods[] = { - {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, "Init environment with observation, action, reward, terminal, truncation arrays"}, + {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, + "Init environment with observation, action, reward, terminal, truncation arrays"}, {"env_reset", env_reset, METH_VARARGS, "Reset the environment"}, {"env_step", env_step, METH_VARARGS, "Step the environment"}, {"env_render", env_render, METH_VARARGS, "Render the environment"}, @@ -945,26 +1075,56 @@ static PyMethodDef methods[] = { {"vec_step", vec_step, METH_VARARGS, "Step the vector of environments"}, {"vec_log", vec_log, METH_VARARGS, "Log the vector of environments"}, {"vec_render", vec_render, METH_VARARGS, "Render the vector of environments"}, + {"vec_set_video_suffix", vec_set_video_suffix, METH_VARARGS, "Set video filename suffix for headless rendering"}, {"vec_close", vec_close, METH_VARARGS, "Close the vector of environments"}, + {"vec_donate_client", vec_donate_client, METH_VARARGS, + "Stash env[0]->client into a global so it survives a subsequent vec_close (render only)"}, + {"vec_adopt_client", vec_adopt_client, METH_VARARGS, + "Re-attach the previously donated client to env[0] of the new vec (render only)"}, {"shared", (PyCFunction)my_shared, METH_VARARGS | METH_KEYWORDS, "Shared state"}, {"get_global_agent_state", get_global_agent_state, METH_VARARGS, "Get global agent state"}, {"vec_get_global_agent_state", vec_get_global_agent_state, METH_VARARGS, "Get agent state from vectorized env"}, {"get_ground_truth_trajectories", get_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories"}, - {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories from vectorized env"}, + {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, + "Get ground truth trajectories from vectorized env"}, + {"vec_get_road_edge_counts", vec_get_road_edge_counts, METH_VARARGS, + "Get road edge polyline counts from vectorized env"}, + {"vec_get_road_edge_polylines", vec_get_road_edge_polylines, METH_VARARGS, + "Get road edge polylines from vectorized env"}, MY_METHODS, - {NULL, NULL, 0, NULL} -}; + {NULL, NULL, 0, NULL}}; // Module definition -static PyModuleDef module = { - PyModuleDef_HEAD_INIT, - "binding", - NULL, - -1, - methods -}; +static PyModuleDef module = {PyModuleDef_HEAD_INIT, "binding", NULL, -1, methods}; PyMODINIT_FUNC PyInit_binding(void) { import_array(); - return PyModule_Create(&module); + PyObject *m = PyModule_Create(&module); // Changed variable name from 'module' to 'm' + + if (m == NULL) { + return NULL; + } + + // Make constants accessible from Python + PyModule_AddIntConstant(m, "MAX_ROAD_SEGMENT_OBSERVATIONS", MAX_ROAD_SEGMENT_OBSERVATIONS); + PyModule_AddIntConstant(m, "MAX_AGENTS", MAX_AGENTS); + PyModule_AddIntConstant(m, "TRAJECTORY_LENGTH", TRAJECTORY_LENGTH); + PyModule_AddIntConstant(m, "MAX_ENTITIES_PER_CELL", MAX_ENTITIES_PER_CELL); + + PyModule_AddIntConstant(m, "ROAD_FEATURES", ROAD_FEATURES); + PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES); + PyModule_AddIntConstant(m, "EGO_FEATURES_CLASSIC", EGO_FEATURES_CLASSIC); + PyModule_AddIntConstant(m, "EGO_FEATURES_JERK", EGO_FEATURES_JERK); + + // Render mode constants + PyModule_AddIntConstant(m, "RENDER_OFF", RENDER_OFF); + PyModule_AddIntConstant(m, "RENDER_HEADLESS", RENDER_HEADLESS); + PyModule_AddIntConstant(m, "RENDER_WINDOW", RENDER_WINDOW); + + // View mode constants + PyModule_AddIntConstant(m, "VIEW_MODE_SIM_STATE", VIEW_MODE_SIM_STATE); + PyModule_AddIntConstant(m, "VIEW_MODE_BEV_AGENT_OBS", VIEW_MODE_BEV_AGENT_OBS); + PyModule_AddIntConstant(m, "VIEW_MODE_AGENT_PERSP", VIEW_MODE_AGENT_PERSP); + + return m; } diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index e8400c51c9..35f2806806 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -6,9 +6,22 @@ #include #include +typedef struct { + char *type; + float reward_offroad_weight_lb; + float reward_offroad_weight_ub; + float reward_collision_weight_lb; + float reward_collision_weight_ub; + float reward_goal_weight_lb; + float reward_goal_weight_ub; + float entropy_weight_lb; + float entropy_weight_ub; + float discount_weight_lb; + float discount_weight_ub; +} conditioning_config; + // Config struct for parsing INI files - contains all environment configuration -typedef struct -{ +typedef struct { int action_type; int dynamics_model; float reward_vehicle_collision; @@ -16,49 +29,56 @@ typedef struct float reward_goal; float reward_goal_post_respawn; float reward_vehicle_collision_post_respawn; - float reward_ade; float goal_radius; + float goal_speed; int collision_behavior; int offroad_behavior; int spawn_immunity_timer; float dt; int goal_behavior; + float goal_target_distance; int scenario_length; + int k_scenarios; + int termination_mode; int init_steps; int init_mode; int control_mode; + char map_dir[256]; + conditioning_config *conditioning; + // Population play settings + int co_player_enabled; + int num_ego_agents; + char co_player_policy_path[256]; + conditioning_config *co_player_conditioning; } env_init_config; // INI file parser handler - parses all environment configuration from drive.ini -static int handler( - void* config, - const char* section, - const char* name, - const char* value -) { - env_init_config* env_config = (env_init_config*)config; - #define MATCH(s, n) strcmp(section, s) == 0 && strcmp(name, n) == 0 +static int handler(void *config, const char *section, const char *name, const char *value) { + env_init_config *env_config = (env_init_config *)config; +#define MATCH(s, n) strcmp(section, s) == 0 && strcmp(name, n) == 0 if (MATCH("env", "action_type")) { - if (strcmp(value, "\"discrete\"") == 0 ||strcmp(value, "discrete") == 0) { - env_config->action_type = 0; // DISCRETE + if (strcmp(value, "\"discrete\"") == 0 || strcmp(value, "discrete") == 0) { + env_config->action_type = 0; // DISCRETE } else if (strcmp(value, "\"continuous\"") == 0 || strcmp(value, "continuous") == 0) { - env_config->action_type = 1; // CONTINUOUS + env_config->action_type = 1; // CONTINUOUS } else { printf("Warning: Unknown action_type value '%s', defaulting to DISCRETE\n", value); - env_config->action_type = 0; // Default to DISCRETE + env_config->action_type = 0; // Default to DISCRETE } } else if (MATCH("env", "dynamics_model")) { if (strcmp(value, "\"classic\"") == 0 || strcmp(value, "classic") == 0) { - env_config->dynamics_model = 0; // CLASSIC + env_config->dynamics_model = 0; // CLASSIC } else if (strcmp(value, "\"jerk\"") == 0 || strcmp(value, "jerk") == 0) { - env_config->dynamics_model = 1; // JERK + env_config->dynamics_model = 1; // JERK } else { printf("Warning: Unknown dynamics_model value '%s', defaulting to JERK\n", value); - env_config->dynamics_model = 1; // Default to JERK + env_config->dynamics_model = 1; // Default to JERK } } else if (MATCH("env", "goal_behavior")) { env_config->goal_behavior = atoi(value); + } else if (MATCH("env", "goal_target_distance")) { + env_config->goal_target_distance = atof(value); } else if (MATCH("env", "reward_vehicle_collision")) { env_config->reward_vehicle_collision = atof(value); } else if (MATCH("env", "reward_offroad_collision")) { @@ -69,13 +89,13 @@ static int handler( env_config->reward_goal_post_respawn = atof(value); } else if (MATCH("env", "reward_vehicle_collision_post_respawn")) { env_config->reward_vehicle_collision_post_respawn = atof(value); - } else if (MATCH("env", "reward_ade")) { - env_config->reward_ade = atof(value); } else if (MATCH("env", "goal_radius")) { env_config->goal_radius = atof(value); - } else if(MATCH("env", "collision_behavior")){ + } else if (MATCH("env", "goal_speed")) { + env_config->goal_speed = atof(value); + } else if (MATCH("env", "collision_behavior")) { env_config->collision_behavior = atoi(value); - } else if(MATCH("env", "offroad_behavior")){ + } else if (MATCH("env", "offroad_behavior")) { env_config->offroad_behavior = atoi(value); } else if (MATCH("env", "spawn_immunity_timer")) { env_config->spawn_immunity_timer = atoi(value); @@ -83,15 +103,171 @@ static int handler( env_config->dt = atof(value); } else if (MATCH("env", "scenario_length")) { env_config->scenario_length = atoi(value); + } else if (MATCH("env", "k_scenarios")) { + env_config->k_scenarios = atoi(value); + } else if (MATCH("env", "termination_mode")) { + env_config->termination_mode = atoi(value); } else if (MATCH("env", "init_steps")) { env_config->init_steps = atoi(value); } else if (MATCH("env", "init_mode")) { env_config->init_mode = atoi(value); } else if (MATCH("env", "control_mode")) { env_config->control_mode = atoi(value); + } else if (MATCH("env", "map_dir")) { + if (sscanf(value, "\"%255[^\"]\"", env_config->map_dir) != 1) { + strncpy(env_config->map_dir, value, sizeof(env_config->map_dir) - 1); + env_config->map_dir[sizeof(env_config->map_dir) - 1] = '\0'; + } + // printf("Parsed map_dir: '%s'\n", env_config->map_dir); + } else if (MATCH("env.conditioning", "type")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + // Remove quotes if present + if (value[0] == '"') { + size_t len = strlen(value) - 2; // -2 for both quotes + env_config->conditioning->type = (char *)malloc(len + 1); + strncpy(env_config->conditioning->type, value + 1, len); + env_config->conditioning->type[len] = '\0'; + } else { + env_config->conditioning->type = strdup(value); + } + } else if (MATCH("env.conditioning", "collision_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_collision_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "collision_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_collision_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "offroad_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_offroad_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "offroad_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_offroad_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "goal_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_goal_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "goal_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_goal_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "entropy_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->entropy_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "entropy_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->entropy_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "discount_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->discount_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "discount_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->discount_weight_ub = atof(value); + } + // Population play settings + else if (MATCH("env", "co_player_enabled")) { + if (strcmp(value, "True") == 0 || strcmp(value, "true") == 0 || strcmp(value, "1") == 0) { + env_config->co_player_enabled = 1; + } else { + env_config->co_player_enabled = 0; + } + } else if (MATCH("env", "num_ego_agents")) { + env_config->num_ego_agents = atoi(value); + } + // Co-player policy settings + else if (MATCH("env.co_player_policy", "policy_path")) { + if (sscanf(value, "\"%255[^\"]\"", env_config->co_player_policy_path) != 1) { + strncpy(env_config->co_player_policy_path, value, sizeof(env_config->co_player_policy_path) - 1); + env_config->co_player_policy_path[sizeof(env_config->co_player_policy_path) - 1] = '\0'; + } + } else if (MATCH("env.co_player_policy.conditioning", "type")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + if (value[0] == '"') { + size_t len = strlen(value) - 2; + env_config->co_player_conditioning->type = (char *)malloc(len + 1); + strncpy(env_config->co_player_conditioning->type, value + 1, len); + env_config->co_player_conditioning->type[len] = '\0'; + } else { + env_config->co_player_conditioning->type = strdup(value); + } + } else if (MATCH("env.co_player_policy.conditioning", "collision_weight_lb")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_collision_weight_lb = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "collision_weight_ub")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_collision_weight_ub = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "offroad_weight_lb")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_offroad_weight_lb = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "offroad_weight_ub")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_offroad_weight_ub = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "goal_weight_lb")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_goal_weight_lb = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "goal_weight_ub")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->reward_goal_weight_ub = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "entropy_weight_lb")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->entropy_weight_lb = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "entropy_weight_ub")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->entropy_weight_ub = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "discount_weight_lb")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->discount_weight_lb = atof(value); + } else if (MATCH("env.co_player_policy.conditioning", "discount_weight_ub")) { + if (env_config->co_player_conditioning == NULL) { + env_config->co_player_conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->co_player_conditioning->discount_weight_ub = atof(value); + } + + else { + return 0; // Unknown section/name, indicate failure to handle } - #undef MATCH +#undef MATCH return 1; } diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 7aa3daa388..8a2a4c0b39 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -10,12 +10,19 @@ Recurrent = pufferlib.models.LSTMWrapper +Transformer = pufferlib.models.TransformerWrapper class Drive(nn.Module): def __init__(self, env, input_size=128, hidden_size=128, **kwargs): super().__init__() self.hidden_size = hidden_size + self.observation_size = env.single_observation_space.shape[0] + self.max_partner_objects = env.max_partner_objects + self.partner_features = env.partner_features + self.max_road_objects = env.max_road_objects + self.road_features = env.road_features + self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories # Conditioning setup self.use_rc = env.reward_conditioned @@ -24,25 +31,42 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs): self.conditioning_dims = (3 if self.use_rc else 0) + (1 if self.use_ec else 0) + (1 if self.use_dc else 0) # Determine ego dimension from environment's dynamics model - base_ego_dim = 10 if env.dynamics_model == "jerk" else 7 - self.ego_dim = base_ego_dim + self.conditioning_dims + # Use binding constants for ego dimensions (includes lane features) + from pufferlib.ocean.drive import binding + base_ego_dim = binding.EGO_FEATURES_JERK if env.dynamics_model == "jerk" else binding.EGO_FEATURES_CLASSIC + self.ego_dim = base_ego_dim + self.conditioning_dims + cond_flags = [] + if self.use_rc: + cond_flags.append("reward(3)") + if self.use_ec: + cond_flags.append("entropy(1)") + if self.use_dc: + cond_flags.append("discount(1)") + cond_str = "+".join(cond_flags) if cond_flags else "none" + print( + f"[Drive policy] dynamics={env.dynamics_model} ego_dim={self.ego_dim} " + f"(base={base_ego_dim}+cond={self.conditioning_dims}: {cond_str}) " + f"obs_dim={self.observation_size} max_partners={self.max_partner_objects} " + f"max_road={self.max_road_objects} hidden={hidden_size}", + flush=True, + ) self.ego_encoder = nn.Sequential( pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - max_road_objects = 13 + self.road_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(max_road_objects, input_size)), + pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - max_partner_objects = 7 + self.partner_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(max_partner_objects, input_size)), + pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), @@ -72,17 +96,18 @@ def forward_train(self, x, state=None): def encode_observations(self, observations, state=None): ego_dim = self.ego_dim - partner_dim = 63 * 7 - road_dim = 200 * 7 + partner_dim = self.max_partner_objects * self.partner_features + road_dim = self.max_road_objects * self.road_features ego_obs = observations[:, :ego_dim] partner_obs = observations[:, ego_dim : ego_dim + partner_dim] road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim] - partner_objects = partner_obs.view(-1, 63, 7) - road_objects = road_obs.view(-1, 200, 7) - road_continuous = road_objects[:, :, :6] # First 6 features - road_categorical = road_objects[:, :, 6] - road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, 200, 7] + partner_objects = partner_obs.view(-1, self.max_partner_objects, self.partner_features) + + road_objects = road_obs.view(-1, self.max_road_objects, self.road_features) + road_continuous = road_objects[:, :, : self.road_features - 1] + road_categorical = road_objects[:, :, self.road_features - 1] + road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, ROAD_MAX_OBJECTS, 7] road_objects = torch.cat([road_continuous, road_onehot], dim=2) ego_features = self.ego_encoder(ego_obs) partner_features, _ = self.partner_encoder(partner_objects).max(dim=1) diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 68b2e1ef2a..08e680cd13 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -20,6 +20,7 @@ import configparser from threading import Thread from collections import defaultdict, deque +from pathlib import Path import numpy as np import psutil @@ -34,6 +35,7 @@ import pufferlib.vector import pufferlib.pytorch import pufferlib.utils +import pufferlib.utils try: from pufferlib import _C @@ -74,7 +76,17 @@ def __init__(self, config, vecenv, policy, logger=None): # Vecenv info self.adaptive_driving_agent = getattr(vecenv.driver_env, "env_name", None) == "adaptive_drive" if self.adaptive_driving_agent: - config["bptt_horizon"] = vecenv.driver_env.episode_length + if config.get("policy_architecture", "Recurrent") == "Recurrent": + config["bptt_horizon"] = vecenv.driver_env.episode_length + if config.get("policy_architecture", "Recurrent") == "Transformer": + config["context_length"] = self.context_length = vecenv.driver_env.episode_length + config["bptt_horizon"] = ( + vecenv.driver_env.episode_length + ) ## this is used downstream so you need to define it too + else: + if config.get("policy_architecture", "Recurrent") == "Transformer": + self.context_length = config["context_length"] + config["bptt_horizon"] = config["context_length"] vecenv.async_reset(seed) obs_space = vecenv.single_observation_space @@ -84,7 +96,10 @@ def __init__(self, config, vecenv, policy, logger=None): if self.population_play: total_ego_agents = vecenv.num_ego_agents agents_for_calc = total_ego_agents - batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers + if config.get("policy_architecture", "Recurrent") == "Recurrent": + batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers + if config.get("policy_architecture", "Recurrent") == "Transformer": + batch_size = vecenv.driver_env.num_ego_agents * config["context_length"] * vecenv.num_workers config["batch_size"] = batch_size ## this is dynamic and based on ego agents else: agents_for_calc = total_agents @@ -93,17 +108,43 @@ def __init__(self, config, vecenv, policy, logger=None): self.total_agents = total_agents # Experience - if config["batch_size"] == "auto" and config["bptt_horizon"] == "auto": - raise pufferlib.APIUsageError("Must specify batch_size or bptt_horizon") + if ( + config["batch_size"] == "auto" + and config.get("bptt_horizon", "auto") == "auto" + and config.get("context_length", "auto") == "auto" + ): + raise pufferlib.APIUsageError("Must specify batch_size, bptt_horizon, or context_length") elif config["batch_size"] == "auto": - config["batch_size"] = agents_for_calc * config["bptt_horizon"] - elif config["bptt_horizon"] == "auto": + if config.get("policy_architecture", "Recurrent") == "Recurrent": + config["batch_size"] = agents_for_calc * config["bptt_horizon"] + elif config.get("policy_architecture", "Recurrent") == "Transformer": + config["batch_size"] = agents_for_calc * config["context_length"] + elif ( + config.get("bptt_horizon", "auto") == "auto" + and config.get("policy_architecture", "Recurrent") == "Recurrent" + ): config["bptt_horizon"] = config["batch_size"] // agents_for_calc + elif ( + config.get("context_length", "auto") == "auto" + and config.get("policy_architecture", "Recurrent") == "Transformer" + ): + config["context_length"] = config["batch_size"] // agents_for_calc batch_size = config["batch_size"] - horizon = config["bptt_horizon"] + + # Set horizon based on model type + if config.get("policy_architecture", "Recurrent") == "Recurrent": + horizon = config["bptt_horizon"] + elif config.get("policy_architecture", "Recurrent") == "Transformer": + horizon = config["context_length"] + else: + horizon = config.get("bptt_horizon", config.get("context_length", 1)) + + config["bptt_horizon"] = horizon # For backward compatibility + segments = batch_size // horizon self.segments = segments + self.horizon = horizon if not self.population_play: if total_agents > segments: raise pufferlib.APIUsageError(f"Total agents {total_agents} <= segments {segments}") @@ -140,11 +181,8 @@ def __init__(self, config, vecenv, policy, logger=None): self.render = config["render"] self.render_interval = config["render_interval"] - if self.render: - ensure_drive_binary() - # LSTM - if config["use_rnn"]: + if config.get("rnn_name", "Recurrent") == "Recurrent": h = policy.hidden_size if self.population_play: n = vecenv.ego_agents_per_batch # Use ego agents per batch @@ -156,6 +194,47 @@ def __init__(self, config, vecenv, policy, logger=None): self.lstm_h = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} self.lstm_c = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} + # TRANSFORMER + if config.get("rnn_name", "Recurrent") == "Transformer": + h = policy.hidden_size + + if self.population_play: + n = vecenv.ego_agents_per_batch # Use ego agents per batch + num_chunks = total_ego_agents // n + # Initialize transformer context buffers + self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)} + self.transformer_position = { + i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks) + } + else: + n = vecenv.agents_per_batch + num_chunks = total_agents // n + # Initialize transformer context buffers + self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)} + self.transformer_position = { + i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks) + } + # K/V cache persistence for the streaming forward_eval path. + # The model lazy-allocates k_cache and v_cache (list of per-layer + # tensors) on first call when state.get("k_cache") is None. We + # persist them here so the next rollout step finds the cache + # already populated with past timesteps' projections — without + # this, every step would lazy-allocate fresh empty caches and + # the policy would attend only to the current step (silent bug + # discovered 2026-05-02; broke all in-context-learning runs + # prior to that). None initially → first call allocates. + self.transformer_k_cache = {i * n: None for i in range(num_chunks)} + self.transformer_v_cache = {i * n: None for i in range(num_chunks)} + + # Regression detector for the rnn_name plumbing bug — fires once. + print( + f"[VERIFY rnn_name] config.get('rnn_name')={config.get('rnn_name')!r}, " + f"policy_architecture={config.get('policy_architecture')!r}, " + f"has_lstm_h={hasattr(self, 'lstm_h')}, " + f"has_transformer_k_cache={hasattr(self, 'transformer_k_cache')}", + flush=True, + ) + # Minibatching & gradient accumulation if self.adaptive_driving_agent: minibatch_size = config["minibatch_multiplier"] * horizon @@ -179,7 +258,7 @@ def __init__(self, config, vecenv, policy, logger=None): self.minibatch_segments = self.minibatch_size // horizon if self.minibatch_segments * horizon != self.minibatch_size: raise pufferlib.APIUsageError( - f"minibatch_size {self.minibatch_size} must be divisible by bptt_horizon {horizon}" + f"minibatch_size {self.minibatch_size} must be divisible by horizon {horizon}" ) # Torch compile @@ -187,7 +266,8 @@ def __init__(self, config, vecenv, policy, logger=None): self.policy = policy if config["compile"]: self.policy = torch.compile(policy, mode=config["compile_mode"]) - self.policy.forward_eval = torch.compile(policy, mode=config["compile_mode"]) + if hasattr(policy, "forward_eval"): + self.policy.forward_eval = torch.compile(policy.forward_eval, mode=config["compile_mode"]) pufferlib.pytorch.sample_logits = torch.compile( pufferlib.pytorch.sample_logits, mode=config["compile_mode"] ) @@ -217,26 +297,95 @@ def __init__(self, config, vecenv, policy, logger=None): raise ValueError(f"Unknown optimizer: {config['optimizer']}") self.optimizer = optimizer + + # ---- Resume optimizer / epoch / global_step from trainer_state.pt ---- + # When --load-model-path points at a checkpoint that has a sibling + # trainer_state.pt (which the trainer writes alongside every model + # checkpoint), restore optimizer momentum + counters so the resumed + # run continues mid-cosine instead of warm-restarting at peak LR with + # cold Adam moments. + resume_epoch = 0 + resume_global_step = 0 + load_path = config.get("load_model_path") + if load_path: + state_path = os.path.join(os.path.dirname(load_path), "trainer_state.pt") + if os.path.exists(state_path): + try: + # weights_only=False: trainer_state.pt contains optimizer + # state (with class refs), not just tensors. + saved = torch.load(state_path, map_location=config["device"], weights_only=False) + optimizer.load_state_dict(saved["optimizer_state_dict"]) + resume_epoch = int(saved.get("update", 0)) + resume_global_step = int(saved.get("global_step", 0)) + print( + f"[trainer-state] Resumed optimizer state from {state_path}\n" + f"[trainer-state] epoch={resume_epoch} global_step={resume_global_step}", + flush=True, + ) + except Exception as e: + print(f"[trainer-state] WARNING: could not load {state_path}: {e}", flush=True) + # Logging self.logger = logger if logger is None: self.logger = NoLogger(config) if self.population_play: + # Under external_co_player_actions, driver_env.co_player_policy is + # None (worker doesn't load it); the GPU-bound copy lives on the + # vecenv as co_player_policy_func. + export_co_player = getattr(vecenv, "co_player_policy_func", None) or vecenv.driver_env.co_player_policy co_player_path = f"resources/drive/{config['env']}_co_player.bin" export_args = {"env_name": config["env"], "path": co_player_path, **config} export( args=export_args, env_name=config["env"], vecenv=vecenv, - policy=vecenv.driver_env.co_player_policy, + policy=export_co_player, path=co_player_path, silent=True, ) - # Learning rate scheduler + # ---- Centralized GPU co-player inference (when enabled) ------------ + self.external_co_player = bool( + self.population_play + and getattr(vecenv, "co_player_policy_func", None) is not None + and config.get("env_config", {}).get("external_co_player_actions", False) + ) + if self.external_co_player: + co_policy = vecenv.co_player_policy_func.to(config["device"]) + co_policy.eval() + self.co_player_policy = co_policy + self.co_player_conditioning_dims = getattr(vecenv, "co_player_conditioning_dims", 0) + # One state dict per worker (each worker holds its own slice of + # co_players; the per-worker batch size is num_co_players_per_env). + num_co_per_worker = vecenv.driver_env.num_co_players + num_workers = vecenv.num_workers + self._co_player_num_per_worker = num_co_per_worker + # Per-worker state dicts. Start each as an empty dict so that + # `forward_eval` lazily allocates the K/V cache on first call with + # the correct (obs-derived) dtype — avoiding a cache dtype that + # mismatches the layer-output dtype during reset_eval_state's + # cache-prime path. + self.co_player_state = {w: {} for w in range(num_workers)} + print( + f"[external co-player] Loaded co-player on {device}; " + f"per-worker batch={num_co_per_worker}, conditioning_dims={self.co_player_conditioning_dims}, " + f"num_workers={num_workers}", + flush=True, + ) + + # Learning rate scheduler — if resuming, advance to the saved epoch + # position so cosine annealing continues smoothly. epochs = config["total_timesteps"] // config["batch_size"] - self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + last_epoch_arg = -1 + if resume_epoch > 0: + # CosineAnnealingLR requires `initial_lr` in each param_group when + # last_epoch != -1; old optimizer states sometimes lack it. + for group in optimizer.param_groups: + group.setdefault("initial_lr", config["learning_rate"]) + last_epoch_arg = resume_epoch - 1 # next .step() lands on resume_epoch + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, last_epoch=last_epoch_arg) self.total_epochs = epochs # Automatic mixed precision @@ -250,9 +399,9 @@ def __init__(self, config, vecenv, policy, logger=None): # Initializations self.config = config self.vecenv = vecenv - self.epoch = 0 - self.global_step = 0 - self.last_log_step = 0 + self.epoch = resume_epoch + self.global_step = resume_global_step + self.last_log_step = resume_global_step self.last_log_time = time.time() self.start_time = time.time() self.utilization = Utilization() @@ -260,6 +409,7 @@ def __init__(self, config, vecenv, policy, logger=None): self.stats = defaultdict(list) self.last_stats = defaultdict(list) self.losses = {} + # Dashboard self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad) self.print_dashboard(clear=True) @@ -275,6 +425,119 @@ def sps(self): return (self.global_step - self.last_log_step) / (time.time() - self.last_log_time) + def _fill_external_co_player_actions(self, full_obs, info, env_id, dones, truncs): + """Centralized co-player inference on GPU. + + Workers receive co_player actions via the shared `actions` SHM buffer + (filled here) instead of running per-worker CPU forward passes. + + Args: + full_obs: numpy obs from vecenv.recv(), shape + (num_agents_per_recv_batch, *obs_shape). + info: the raw info list from vecenv.recv(). Must contain a dict + with key "_external_co_player_ids" giving the actual list of + co-player agent indices (worker-local). Computing this from + complement-of-ego_ids is wrong: many slots are padding or + otherwise inactive, and forwarding garbage obs through them + pollutes the shared KV cache. + env_id: numpy array of agent indices for the current recv batch. + dones, truncs: numpy bool per-agent done/trunc flags. + """ + import numpy as np + import torch + + device = self.config["device"] + agents_per_worker = self.vecenv.agents_per_worker + # batch_size > 1 packs multiple workers into one recv; we handle the + # batch_size=1 case (the production setup) here. Generalizing to + # batch_size>1 is straightforward (loop over workers in the batch). + batch_size = self.vecenv.batch_size + if batch_size != 1: + raise NotImplementedError( + f"external_co_player_actions currently only supports batch_size=1; got batch_size={batch_size}." + ) + + # Map env_id back to a worker index so we know which co_player_state + # to use and which row of vecenv.actions to write to. + # For population_play, recv() returns ego-only agent ids + # (vecenv.ego_agent_ids), so divide by the per-worker ego count, not + # the full agent count. + ego_agents_per_worker = getattr(self.vecenv, "ego_agents_per_worker", agents_per_worker) + worker_id = int(env_id[0]) // ego_agents_per_worker + + # Pull the actual co_player_ids from info (the env knows them). + # Also check for the scenario-boundary cache reset signal. + co_ids = None + reset_cache = False + for item in info: + if isinstance(item, dict): + if "_external_co_player_ids" in item: + co_ids = list(item["_external_co_player_ids"]) + if item.get("_external_reset_co_cache"): + reset_cache = True + if co_ids is None: + raise RuntimeError( + "external_co_player_actions=True but the env did not " + "publish '_external_co_player_ids' in info. Is drive.py up to date?" + ) + if not co_ids: + return # no co-players in this env this step + + # Drop the cache at scenario boundaries — mirrors the per-worker + # OFF path's _reset_co_player_state() which fully reinits state. + # Replacing the dict makes forward_eval lazy-allocate fresh K/V on + # the next call, matching legacy behavior bit-for-bit at scenario + # boundaries. + if reset_cache: + self.co_player_state[worker_id] = {} + + # Slice the co-player observations. When the ego is in oracle mode + # (drive.py `ego_is_oracle=True`) the env's obs is wider than the + # partner policy expects — partner conditioning is appended to ego + # rows only. Strip trailing oracle dims by slicing columns to the + # env's `_c_obs_dim` (the C-side obs width). Defaults to None when + # oracle is off → take the full width as before. + co_obs_width = getattr(self.vecenv.driver_env, "_c_obs_dim", None) + if co_obs_width is None: + co_obs_np = full_obs[co_ids] + else: + co_obs_np = full_obs[co_ids, :co_obs_width] + co_obs = torch.as_tensor(co_obs_np, device=device) + if self.co_player_conditioning_dims > 0: + # Pull this worker's conditioning slice from the SHM buffer the + # env wrote at scenario boundaries (or at env init). Insert right + # after the base ego_features — matches drive.py's + # `_add_co_player_conditioning` exactly. + cond_shm = self.vecenv.co_player_conditioning # (num_workers, max_co, cdim) + cond_np = cond_shm[worker_id, : len(co_ids), :] # only the rows we'll use + cond = torch.as_tensor(cond_np, device=device, dtype=co_obs.dtype) + from pufferlib.ocean.drive import binding as _b + + base_ego_dim = ( + _b.EGO_FEATURES_JERK if self.vecenv.driver_env.dynamics_model == "jerk" else _b.EGO_FEATURES_CLASSIC + ) + co_obs = torch.cat([co_obs[:, :base_ego_dim], cond, co_obs[:, base_ego_dim:]], dim=1) + + # NOTE: the OFF (per-worker) path only resets cache at scenario + # boundary or reset(), never for individual done agents. So we + # don't reset per-done here either — it would diverge from OFF. + + with torch.no_grad(): + logits, _ = self.co_player_policy.forward_eval(co_obs, self.co_player_state[worker_id]) + + # Match the per-worker code path: argmax for discrete actions. + if isinstance(logits, tuple): + co_action = torch.cat([l.argmax(dim=-1, keepdim=True) for l in logits], dim=-1) + else: + co_action = logits.argmax(dim=-1) + co_action_np = co_action.cpu().numpy().reshape(len(co_ids), -1) + + # Write directly to the worker's slot in the shared-memory action + # buffer. The worker's env.step() will call vec_step using these + # actions because it has external_co_player_actions=True. + co_action_view = self.vecenv.actions[worker_id] # shape (agents_per_worker, *atn_shape) + co_action_view[co_ids] = co_action_np.reshape((len(co_ids),) + co_action_view.shape[1:]) + def evaluate(self): profile = self.profile epoch = self.epoch @@ -284,36 +547,59 @@ def evaluate(self): config = self.config device = config["device"] - if config["use_rnn"]: + # Reset hidden states for both RNN and Transformer + if config.get("rnn_name", "Recurrent") == "Recurrent": for k in self.lstm_h: self.lstm_h[k] = torch.zeros(self.lstm_h[k].shape, device=device) self.lstm_c[k] = torch.zeros(self.lstm_c[k].shape, device=device) + if config.get("rnn_name", "Recurrent") == "Transformer": + h = self.policy.hidden_size + for k in self.transformer_context: + n = self.transformer_context[k].shape[0] + # Pre-allocate full buffer instead of empty + self.transformer_context[k] = torch.zeros(n, self.horizon, h, device=device) + self.transformer_position[k] = torch.zeros(1, dtype=torch.long, device=device) + # Drop K/V cache so the model lazy-allocates fresh on + # the first forward_eval call of this rollout. + self.transformer_k_cache[k] = None + self.transformer_v_cache[k] = None + self.full_rows = 0 while self.full_rows < self.segments: profile("env", epoch) + # print(".", end="", flush=True) # Workaround: visible I/O prevents multiprocessing deadlock o, r, d, t, info, env_id, mask = self.vecenv.recv() # print(f"o shape is {o.shape}", flush = True) if self.population_play: batch_size = self.vecenv.batch_size - ego_ids = info[-1] + # Filter info to get only the ego_ids lists (not the metrics dicts) + ego_ids_per_env = [item for item in info if isinstance(item, list)] + + if self.external_co_player: + # Run co-player forward on GPU before the ego-only slicing + # below (we need the FULL obs array to extract co-player obs). + self._fill_external_co_player_actions(o, info, env_id, d, t) if batch_size > 1: total_agents = len(o) num_agents_per_env = total_agents // batch_size - original_shape = o.shape - - o = o.reshape(batch_size, num_agents_per_env, *original_shape[1:]) - r = r.reshape(batch_size, num_agents_per_env) - d = d.reshape(batch_size, num_agents_per_env) - t = t.reshape(batch_size, num_agents_per_env) - - o = o[:, ego_ids].reshape(batch_size * len(ego_ids), *original_shape[1:]) - r = r[:, ego_ids].flatten() - d = d[:, ego_ids].flatten() - t = t[:, ego_ids].flatten() + # Create flat ego_ids by adding batch offset + flat_ego_ids = [] + for env_idx in range(batch_size): + ego_ids = ego_ids_per_env[env_idx] + offset = env_idx * num_agents_per_env + flat_ego_ids.extend([int(idx) + offset for idx in ego_ids]) + + # Simply index with the flat ego_ids + o = o[flat_ego_ids] + r = r[flat_ego_ids] + d = d[flat_ego_ids] + t = t[flat_ego_ids] else: + ego_ids = ego_ids_per_env[0] # Single environment + ego_ids = [int(idx) for idx in ego_ids] # Convert to int o = o[ego_ids] r = r[ego_ids] d = d[ego_ids] @@ -326,9 +612,9 @@ def evaluate(self): profile("eval_copy", epoch) o = torch.as_tensor(o) - o_device = o.to(device) # , non_blocking=True) - r = torch.as_tensor(r).to(device) # , non_blocking=True) - d = torch.as_tensor(d).to(device) # , non_blocking=True) + o_device = o.to(device, non_blocking=True) + r = torch.as_tensor(r, device=device) + d = torch.as_tensor(d, device=device) profile("eval_forward", epoch) with torch.no_grad(), self.amp_context: @@ -338,29 +624,84 @@ def evaluate(self): env_id=env_id, mask=mask, ) - - if config["use_rnn"]: - state["lstm_h"] = self.lstm_h[env_id.start] - state["lstm_c"] = self.lstm_c[env_id.start] - + # Get appropriate batch key for state lookup + if self.population_play: + batch_size = self.vecenv.ego_agents_per_batch + else: + batch_size = self.vecenv.agents_per_batch + state_key = (env_id.start // batch_size) * batch_size + + if config.get("rnn_name", "Recurrent") == "Recurrent": + state["lstm_h"] = self.lstm_h[state_key] + state["lstm_c"] = self.lstm_c[state_key] + + if config.get("rnn_name", "Recurrent") == "Transformer": + state["transformer_context"] = self.transformer_context[state_key] + state["transformer_position"] = self.transformer_position[state_key] + # K/V cache for streaming attention. None on the first + # call → model lazy-allocates (and resets pos to 0). + # Subsequent calls reuse the populated cache, which is + # the whole point: each step appends one new K/V slot + # and the policy attends over the full accumulated past. + state["k_cache"] = self.transformer_k_cache[state_key] + state["v_cache"] = self.transformer_v_cache[state_key] + # Note: terminals not needed for eval since we're doing single-step inference + + # print(".", end="", flush=True) # Prevents multiprocessing deadlock logits, value = self.policy.forward_eval(o_device, state) action, logprob, _ = pufferlib.pytorch.sample_logits(logits) r = torch.clamp(r, -1, 1) profile("eval_copy", epoch) with torch.no_grad(): - if config["use_rnn"]: - # Use the same lstm_key calculation + # Update hidden states after forward pass + if config.get("rnn_name", "Recurrent") == "Recurrent": if self.population_play: batch_size = self.vecenv.ego_agents_per_batch else: batch_size = self.vecenv.agents_per_batch lstm_key = (env_id.start // batch_size) * batch_size - self.lstm_h[lstm_key] = state["lstm_h"] self.lstm_c[lstm_key] = state["lstm_c"] + if config.get("rnn_name", "Recurrent") == "Transformer": + if self.population_play: + batch_size = self.vecenv.ego_agents_per_batch + else: + batch_size = self.vecenv.agents_per_batch + + transformer_key = (env_id.start // batch_size) * batch_size + self.transformer_context[transformer_key] = state["transformer_context"] + self.transformer_position[transformer_key] = state["transformer_position"] + # Persist the K/V cache the model just wrote/updated so + # the next forward_eval call sees the accumulated past. + # state.get(...) is defensive: model may not have set + # these if it took the legacy path. + self.transformer_k_cache[transformer_key] = state.get("k_cache") + self.transformer_v_cache[transformer_key] = state.get("v_cache") + + # Episode-boundary reset. pos is a shared (1,) scalar + # across the chunk; cache rows are per-agent. Filter + # done indices against the cache's batch dim, not the + # pos buffer's (1,) shape. + if done_mask.any(): + done_indices = torch.where(torch.from_numpy(done_mask))[0] + if len(done_indices) > 0: + batch_start_in_group = env_id.start % batch_size + global_indices = batch_start_in_group + done_indices + kc = self.transformer_k_cache[transformer_key] + vc = self.transformer_v_cache[transformer_key] + cache_batch_dim = kc[0].shape[0] if kc is not None else 0 + valid_mask = global_indices < cache_batch_dim + valid_indices = global_indices[valid_mask] + if len(valid_indices) > 0: + self.transformer_position[transformer_key][:] = 0 + if kc is not None and vc is not None: + for c in kc: + c[valid_indices] = 0 + for c in vc: + c[valid_indices] = 0 # Fast path for fully vectorized envs l = self.ep_lengths[env_id.start].item() batch_rows = slice(self.ep_indices[env_id.start].item(), 1 + self.ep_indices[env_id.stop - 1].item()) @@ -378,7 +719,13 @@ def evaluate(self): # Note: We are not yet handling masks in this version self.ep_lengths[env_id] += 1 - if l + 1 >= config["bptt_horizon"]: + # Use appropriate horizon based on model type + horizon = ( + config.get("context_length") + if config.get("policy_architecture", "Recurrent") == "Transformer" + else config["bptt_horizon"] + ) + if l + 1 >= horizon: num_full = env_id.stop - env_id.start self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config["device"]).int() self.ep_lengths[env_id] = 0 @@ -400,12 +747,17 @@ def evaluate(self): self.stats[k].append(v) profile("env", epoch) - self.vecenv.send(action) profile("eval_misc", epoch) self.free_idx = self.total_agents - self.ep_indices = torch.arange(self.total_agents, device=device, dtype=torch.int32) + + if self.population_play: + total_agents = self.vecenv.num_ego_agents + else: + total_agents = self.total_agents + + self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32) self.ep_lengths.zero_() profile.end() return self.stats @@ -438,9 +790,9 @@ def train(self): hasattr(self.vecenv.driver_env, "dynamics_model") and self.vecenv.driver_env.dynamics_model == "jerk" ): - disc_idx = 10 # base ego obs + disc_idx = 12 # EGO_FEATURES_JERK (was 10 before lane features) else: - disc_idx = 7 + disc_idx = 9 # EGO_FEATURES_CLASSIC (was 7 before lane features) if self.vecenv.driver_env.reward_conditioned: disc_idx += 3 @@ -468,7 +820,15 @@ def train(self): prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6) idx = torch.multinomial(prio_probs, self.minibatch_segments) mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta - mb_obs = self.observations[idx] + # When cpu_offload=True, self.observations lives on CPU but `idx` + # is on the training device (GPU). PyTorch refuses cross-device + # fancy indexing, so move the index to CPU for the gather, then + # ship the resulting minibatch to the device. Buffer was allocated + # with pin_memory=True (see __init__) so the H2D copy is fast. + if config["cpu_offload"]: + mb_obs = self.observations[idx.cpu()].to(device, non_blocking=True) + else: + mb_obs = self.observations[idx] mb_actions = self.actions[idx] mb_logprobs = self.logprobs[idx] mb_rewards = self.rewards[idx] @@ -480,17 +840,45 @@ def train(self): mb_advantages = advantages[idx] profile("train_forward", epoch) - if not config["use_rnn"]: + + # Handle observation reshaping based on model type + if ( + not config.get("rnn_name", "Recurrent") == "Recurrent" + and not config.get("rnn_name", "Recurrent") == "Transformer" + ): + # Flatten for non-recurrent models mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) state = dict( action=mb_actions, - lstm_h=None, - lstm_c=None, ) + # Add appropriate state based on model type + if config.get("rnn_name", "Recurrent") == "Recurrent": + state["lstm_h"] = None + state["lstm_c"] = None + elif config.get("rnn_name", "Recurrent") == "Transformer": + state["transformer_context"] = None + state["transformer_position"] = None + state["terminals"] = mb_terminals # For episode boundary masking + logits, newvalue = self.policy(mb_obs, state) - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + + # Handle action sampling based on observation shape + if ( + config.get("rnn_name", "Recurrent") == "Recurrent" + or config.get("rnn_name", "Recurrent") == "Transformer" + ): + # Add this right before calling sample_logits + if isinstance(logits, tuple): + logits = logits[0] + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + else: + # Need to flatten actions for non-recurrent models + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits( + logits, + action=mb_actions.reshape(-1, *mb_actions.shape[2:]) if len(mb_actions.shape) > 2 else mb_actions, + ) profile("train_misc", epoch) newlogprob = newlogprob.reshape(mb_logprobs.shape) @@ -508,6 +896,8 @@ def train(self): mb_gammas = gammas[idx] else: mb_gammas = torch.full((len(idx),), config["gamma"], device=device, dtype=torch.float32) + + # Recompute advantages with new ratios adv = compute_puff_advantage( mb_values, mb_rewards, @@ -541,9 +931,9 @@ def train(self): hasattr(self.vecenv.driver_env, "dynamics_model") and self.vecenv.driver_env.dynamics_model == "jerk" ): - ent_idx = 10 # base ego obs + ent_idx = 12 # EGO_FEATURES_JERK (was 10 before lane features) else: - ent_idx = 7 + ent_idx = 9 # EGO_FEATURES_CLASSIC (was 7 before lane features) if self.vecenv.driver_env.reward_conditioned: ent_idx += 3 @@ -607,30 +997,15 @@ def train(self): self.msg = f"Checkpoint saved at update {self.epoch}" if self.render and self.epoch % self.render_interval == 0: - model_dir = os.path.join(self.config["data_dir"], f"{self.config['env']}_{self.logger.run_id}") - model_files = glob.glob(os.path.join(model_dir, "model_*.pt")) - - if model_files: - # Take the latest checkpoint - latest_cpt = max(model_files, key=os.path.getctime) - bin_path = f"{model_dir}.bin" - - # Export to .bin for rendering with raylib - try: - export_args = {"env_name": self.config["env"], "load_model_path": latest_cpt, **self.config} - - export( - args=export_args, - env_name=self.config["env"], - vecenv=self.vecenv, - policy=self.uncompiled_policy, - path=bin_path, - silent=True, - ) - pufferlib.utils.render_videos(self.config, self.vecenv, self.logger, self.global_step, bin_path) - - except Exception as e: - print(f"Failed to export model weights: {e}") + torch.cuda.empty_cache() + pufferlib.utils.render_videos( + config=self.config, + policy=self.uncompiled_policy, + logger=self.logger, + epoch=self.epoch, + global_step=self.global_step, + device=self.config["device"], + ) if self.config["eval"]["wosac_realism_eval"] and ( self.epoch % self.config["eval"]["eval_interval"] == 0 or done_training @@ -641,6 +1016,16 @@ def train(self): self.epoch % self.config["eval"]["eval_interval"] == 0 or done_training ): pufferlib.utils.run_human_replay_eval_in_subprocess(self.config, self.logger, self.global_step) + torch.cuda.empty_cache() + pufferlib.utils.render_videos( + config=self.config, + policy=self.uncompiled_policy, + logger=self.logger, + epoch=self.epoch, + global_step=self.global_step, + device=self.config["device"], + human_replay=True, + ) def mean_and_log(self): config = self.config @@ -716,6 +1101,12 @@ def save_checkpoint(self): state_path = os.path.join(path, "trainer_state.pt") torch.save(state, state_path + ".tmp") os.rename(state_path + ".tmp", state_path) + + # Sidecar metadata: every render/eval can recover the right + # conditioning, dataset, and architecture from /info.json + # without the user having to re-pass them on the CLI. + write_run_info(path, self.config, run_id) + return model_path def print_dashboard(self, clear=False, idx=[0], c1="[cyan]", c2="[white]", b1="[bright_cyan]", b2="[bright_white]"): @@ -1052,6 +1443,7 @@ def __init__(self, args, load_id=None, resume="allow"): save_code=False, resume=resume, config=args, + name=args.get("wandb_name"), tags=[args["tag"]] if args["tag"] is not None else [], ) self.wandb = wandb @@ -1103,11 +1495,26 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None): policy = model.to(local_rank) if args["neptune"]: - logger = NeptuneLogger(args) + logger = NeptuneLogger(args, load_id=args.get("load_id")) elif args["wandb"]: - logger = WandbLogger(args) - - train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {})) + # Pass load_id so the wandb logger resumes the existing run instead + # of creating a fresh one. WandbLogger uses resume="allow", so wandb + # picks up where the original run left off (history, name, tags). + logger = WandbLogger(args, load_id=args.get("load_id")) + + train_config = dict( + **args["train"], + env=env_name, + eval=args.get("eval", {}), + env_config=args.get("env", {}), + policy_architecture=args.get("policy_architecture", "Recurrent"), + # rnn_name lives at args top level — must be explicitly propagated, + # else config.get("rnn_name") defaults to "Recurrent" and the + # Transformer init/rollout branches in PuffeRL never fire. + rnn_name=args.get("rnn_name", args.get("policy_architecture", "Recurrent")), + load_model_path=args.get("load_model_path"), + load_id=args.get("load_id"), + ) pufferl = PuffeRL(train_config, vecenv, policy, logger) all_logs = [] @@ -1149,15 +1556,24 @@ def eval(env_name, args=None, vecenv=None, policy=None): wosac_enabled = args["eval"]["wosac_realism_eval"] human_replay_enabled = args["eval"]["human_replay_eval"] + # Honor eval.map_dir only when explicitly set; otherwise inherit the + # training env.map_dir so eval doesn't silently switch datasets. + eval_map_dir = args["eval"].get("map_dir") + if eval_map_dir in (None, "", "None"): + eval_map_dir = args["env"].get("map_dir") + args["env"]["map_dir"] = eval_map_dir + args["eval"]["map_dir"] = eval_map_dir + args["env"]["num_maps"] = args["eval"]["num_maps"] + args["env"]["use_all_maps"] = True + dataset_name = args["env"]["map_dir"].split("/")[-1] if wosac_enabled: - print(f"Running WOSAC realism evaluation. \n") + print(f"Running WOSAC realism evaluation with {dataset_name} dataset. \n") from pufferlib.ocean.benchmark.evaluator import WOSACEvaluator backend = args["eval"]["backend"] assert backend == "PufferEnv" or not wosac_enabled, "WOSAC evaluation only supports PufferEnv backend." args["vec"] = dict(backend=backend, num_envs=1) - args["env"]["num_agents"] = args["eval"]["wosac_num_agents"] args["env"]["init_mode"] = args["eval"]["wosac_init_mode"] args["env"]["control_mode"] = args["eval"]["wosac_control_mode"] args["env"]["init_steps"] = args["eval"]["wosac_init_steps"] @@ -1172,6 +1588,10 @@ def eval(env_name, args=None, vecenv=None, policy=None): # Collect ground truth trajectories from the dataset gt_trajectories = evaluator.collect_ground_truth_trajectories(vecenv) + print(f"Number of scenarios: {len(np.unique(gt_trajectories['scenario_id']))}") + print(f"Number of controlled agents: {gt_trajectories['x'].shape[0]}") + print(f"Number of evaluated agents: {np.sum(gt_trajectories['id'] >= 0)}") + # Roll out trained policy in the simulator simulated_trajectories = evaluator.collect_simulated_trajectories(args, vecenv, policy) @@ -1179,32 +1599,52 @@ def eval(env_name, args=None, vecenv=None, policy=None): evaluator._quick_sanity_check(gt_trajectories, simulated_trajectories) # Analyze and compute metrics + agent_state = vecenv.driver_env.get_global_agent_state() + road_edge_polylines = vecenv.driver_env.get_road_edge_polylines() results = evaluator.compute_metrics( - gt_trajectories, simulated_trajectories, args["eval"]["wosac_aggregate_results"] + gt_trajectories, + simulated_trajectories, + agent_state, + road_edge_polylines, + args["eval"]["wosac_aggregate_results"], ) if args["eval"]["wosac_aggregate_results"]: import json - print("WOSAC_METRICS_START") + print("\nWOSAC_METRICS_START") print(json.dumps(results)) print("WOSAC_METRICS_END") return results elif human_replay_enabled: - print("Running human replay evaluation.\n") + print(f"Running human replay evaluation with {dataset_name} dataset.\n") from pufferlib.ocean.benchmark.evaluator import HumanReplayEvaluator backend = args["eval"].get("backend", "PufferEnv") args["vec"] = dict(backend=backend, num_envs=1) - args["env"]["num_agents"] = args["eval"]["human_replay_num_agents"] args["env"]["control_mode"] = args["eval"]["human_replay_control_mode"] - args["env"]["scenario_length"] = 91 # Standard scenario length + # episode_length is NOT hardcoded here — inherits scenario_length + # from training config (WOMD=91, nuPlan=201). + # Human replay: only 1 ego is policy-controlled, others follow logged trajectories + args["env"]["co_player_enabled"] = False + args["env"]["max_controlled_agents"] = 1 + # `human_replay_mode` is only accepted by AdaptiveDrivingAgent + if "adaptive" in env_name: + args["env"]["human_replay_mode"] = True + if args["eval"].get("human_replay_num_agents") is not None: + args["env"]["num_agents"] = args["eval"]["human_replay_num_agents"] + if args["eval"].get("human_replay_num_maps") is not None: + args["env"]["num_maps"] = args["eval"]["human_replay_num_maps"] + if args["eval"].get("map_dir") not in (None, "", "None"): + args["env"]["map_dir"] = args["eval"]["map_dir"] vecenv = vecenv or load_env(env_name, args) policy = policy or load_policy(args, vecenv, env_name) + print(f"Effective number of scenarios used: {len(vecenv.driver_env.agent_offsets) - 1}") + evaluator = HumanReplayEvaluator(args) # Run rollouts with human replays @@ -1231,10 +1671,6 @@ def eval(env_name, args=None, vecenv=None, policy=None): num_agents = vecenv.observation_space.shape[0] device = args["train"]["device"] - # Rebuild visualize binary if saving frames (for C-based rendering) - if args["save_frames"] > 0: - ensure_drive_binary() - state = {} if args["train"]["use_rnn"]: state = dict( @@ -1312,6 +1748,118 @@ def sweep(args=None, env_name=None): args["train"]["total_timesteps"] = total_timesteps +def controlled_exp(env_name, args=None): + """Run experiments with all combinations of specified parameter values.""" + import itertools + from copy import deepcopy + + args = args or load_config(env_name) + if not args["wandb"] and not args["neptune"]: + raise pufferlib.APIUsageError("Targeted experiments require either wandb or neptune") + + # Check if controlled_exp config exists + if "controlled_exp" not in args: + raise pufferlib.APIUsageError("No [controlled_exp.*] sections found in config") + + # Extract parameters from controlled_exp namespace + params = {} + for section, section_config in args["controlled_exp"].items(): + if isinstance(section_config, dict): + for param, param_config in section_config.items(): + if isinstance(param_config, dict) and "values" in param_config: + params[f"{section}.{param}"] = param_config["values"] + + if not params: + raise pufferlib.APIUsageError("No parameters with 'values' lists found in [controlled_exp.*] sections") + + # Generate all combinations + keys = list(params.keys()) + combinations = list(itertools.product(*[params[k] for k in keys])) + + print(f"Running a total of {len(combinations)} experiments with parameters: {keys}") + + # Run each combination + for i, combo in enumerate(combinations, 1): + exp_args = deepcopy(args) + + # Set parameters + for key, value in zip(keys, combo): + section, param = key.split(".") + exp_args[section][param] = value + + print(f"\nExperiment {i}/{len(combinations)}: {dict(zip(keys, combo))}") + + # Train + train(env_name, args=exp_args) + + print(f"\n✓ Completed all {len(combinations)} experiments") + + +def sanity(env_name, args=None): + args = args or load_config(env_name) + base_dir = Path(__file__).resolve().parent / "resources" / "drive" / "sanity" + json_dir = base_dir / "sanity_jsons" + binary_dir = base_dir / "sanity_binaries" + + available_maps = {p.stem: p for p in json_dir.glob("*.json")} + selected = args.get("sanity_maps") + if isinstance(selected, str): + selected = [selected] + + if selected: + missing = [name for name in selected if name not in available_maps] + if missing: + raise pufferlib.APIUsageError(f"Unknown sanity maps: {', '.join(sorted(missing))}") + chosen = [(name, available_maps[name]) for name in selected] + else: + chosen = sorted(available_maps.items()) + + if not chosen: + raise pufferlib.APIUsageError(f"No sanity maps found in {json_dir}") + + from pufferlib.ocean.drive.drive import load_map + + binary_dir.mkdir(parents=True, exist_ok=True) + binaries = [] + for idx, (name, json_path) in enumerate(chosen): + output_path = binary_dir / f"{name}.bin" + load_map(str(json_path), idx, str(output_path)) + binaries.append((name, output_path)) + + runs = [] + for name, binary in binaries: + map_zero = binary_dir / "map_000.bin" + shutil.copy2(binary, map_zero) + + run_args = { + **args, + "env": {**args["env"], "num_maps": 1, "map_dir": str(binary_dir)}, + "train": {**args["train"], "render_map": str(map_zero)}, + } + if run_args.get("wandb"): + run_args["wandb_name"] = name + + print(f"Running sanity map '{name}' from {binary.name}") + run_logs = train(env_name=env_name, args=run_args) + runs.append({"map": name, "logs": run_logs}) + + print("Sanity checklist:") + for entry in runs: + name = entry["map"] + logs = entry.get("logs") or [] + final = logs[-1] if logs else {} + score = final.get("environment/score") + if score is None: + status = "unknown (no score)" + elif score >= 0.95: + status = "✅ Solved" + else: + status = "❌ unsolved" + print(f" - {name}: {status} (score={score})") + + return runs + + def profile(args=None, env_name=None, vecenv=None, policy=None): args = load_config() vecenv = vecenv or load_env(env_name, args) @@ -1352,31 +1900,63 @@ def export(args=None, env_name=None, vecenv=None, policy=None, path=None, silent print(f"Saved {len(weights)} weights to {path}") -def ensure_drive_binary(): - """Delete existing visualize binary and rebuild it. This ensures the - binary is always up-to-date with the latest code changes. +def write_run_info(run_dir, config, run_id): + """Persist the bits of training config that render/eval need to replay. + + We only record fields that change observation/architecture shape or + dataset identity — the things you can't recover from the .pt alone. + Existing checkpoints don't have this file; readers must treat it as + optional. """ - if os.path.exists("./visualize"): - print("Removing existing visualize binary...") - try: - os.remove("./visualize") - except FileNotFoundError: - print("Binary not found") - print("Building visualize binary...") + import json + + env_cfg = config.get("env_config", {}) + info = { + "run_id": run_id, + "env_name": config.get("env"), + "policy_architecture": config.get("policy_architecture"), + "rnn_name": config.get("rnn_name"), + "env": { + "map_dir": env_cfg.get("map_dir"), + "num_maps": env_cfg.get("num_maps"), + "num_agents": env_cfg.get("num_agents"), + "num_ego_agents": env_cfg.get("num_ego_agents"), + "k_scenarios": env_cfg.get("k_scenarios", 1), + "scenario_length": env_cfg.get("scenario_length", 91), + "dynamics_model": env_cfg.get("dynamics_model", "classic"), + "co_player_enabled": bool(env_cfg.get("co_player_enabled")), + "conditioning": env_cfg.get("conditioning", {}), + "co_player_policy": env_cfg.get("co_player_policy", {}), + }, + } + info_path = os.path.join(run_dir, "info.json") try: - result = subprocess.run( - ["bash", "scripts/build_ocean.sh", "visualize", "local"], capture_output=True, text=True, timeout=300 - ) - - if result.returncode == 0: - print("Successfully built visualize binary") - else: - print(f"Build failed: {result.stderr}") - raise RuntimeError("Failed to build visualize binary for rendering") - except subprocess.TimeoutExpired: - raise RuntimeError("Build timed out") + with open(info_path + ".tmp", "w") as f: + json.dump(info, f, indent=2, default=str) + os.rename(info_path + ".tmp", info_path) except Exception as e: - raise RuntimeError(f"Build error: {e}") + print(f"[info.json] failed to write: {e}") + + +def load_run_info(model_path): + """Look up the run_dir/info.json for a given checkpoint path. Returns {} if missing.""" + import json + + candidates = [] + parent = os.path.dirname(model_path) or "." + candidates.append(os.path.join(parent, "info.json")) + # Allow `experiments/puffer_drive_.pt` (flat copy) → look in sibling run dir. + base, ext = os.path.splitext(model_path) + if ext == ".pt" and os.path.basename(base).startswith("puffer_"): + candidates.append(os.path.join(base, "info.json")) + for path in candidates: + if os.path.exists(path): + try: + with open(path) as f: + return json.load(f) + except Exception as e: + print(f"[info.json] failed to read {path}: {e}") + return {} def autotune(args=None, env_name=None, vecenv=None, policy=None): @@ -1402,40 +1982,57 @@ def load_policy(args, vecenv, env_name=""): env_module = importlib.import_module(module_name) device = args["train"]["device"] - policy_cls = getattr(env_module.torch, args["policy_name"]) - policy = policy_cls(vecenv.driver_env, **args["policy"]) - - rnn_name = args["rnn_name"] - if rnn_name is not None: - rnn_cls = getattr(env_module.torch, args["rnn_name"]) - policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"]) - - policy = policy.to(device) load_id = args["load_id"] - if load_id is not None: + load_path = args.get("load_model_path") + state_dict = None + rnn_name = args.get("policy_architecture", "Recurrent") + + if load_path is not None: + state_dict = torch.load(load_path, map_location=device) + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + elif load_id is not None: if args["neptune"]: path = NeptuneLogger(args, load_id, mode="read-only").download() elif args["wandb"]: path = WandbLogger(args, load_id).download() else: raise pufferlib.APIUsageError("No run id provided for eval") - state_dict = torch.load(path, map_location=device) state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} - policy.load_state_dict(state_dict) - load_path = args["load_model_path"] - if load_path == "latest": - load_path = max(glob.glob(f"experiments/{env_name}*.pt"), key=os.path.getctime) + # Auto-detect architecture from state_dict keys + if state_dict is not None: + if "positional_embedding" in state_dict: + rnn_name = "Transformer" + elif "lstm.weight_ih_l0" in state_dict: + rnn_name = "Recurrent" - if load_path is not None: - state_dict = torch.load(load_path, map_location=device) - state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + policy_cls = getattr(env_module.torch, args["policy_name"]) + policy = policy_cls(vecenv.driver_env, **args["policy"]) + + # Handle both RNN and Transformer wrappers via rnn_name + if rnn_name == "Transformer": + # Load transformer wrapper + transformer_cls = getattr(env_module.torch, rnn_name) + # For adaptive_driving_agent, use episode_length as horizon (k_scenarios * scenario_length) + # Otherwise, use config horizon with fallback to episode_length + is_adaptive = getattr(vecenv.driver_env, "env_name", None) == "adaptive_drive" + if is_adaptive: + args["transformer"]["horizon"] = vecenv.driver_env.episode_length + else: + args["transformer"]["horizon"] = args["train"].get("horizon", vecenv.driver_env.episode_length) + policy = transformer_cls(vecenv.driver_env, policy, **args["transformer"]) + elif rnn_name is not None: + # Load RNN wrapper (Recurrent) + rnn_cls = getattr(env_module.torch, rnn_name) + policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"]) + + policy = policy.to(device) + + # Load the state dict if we have one + if state_dict is not None: policy.load_state_dict(state_dict) - # state_path = os.path.join(*load_path.split('/')[:-1], 'state.pt') - # optim_state = torch.load(state_path)['optimizer_state_dict'] - # pufferl.optimizer.load_state_dict(optim_state) return policy @@ -1466,7 +2063,7 @@ def load_config(env_name): parser.add_argument("--neptune-project", type=str, default="ablations") parser.add_argument("--local-rank", type=int, default=0, help="Used by torchrun for DDP") parser.add_argument("--tag", type=str, default=None, help="Tag for experiment") - + parser.add_argument("--sanity-maps", nargs="*", default=None, help="Optional list of sanity map base names to run") args = parser.parse_known_args()[0] # Load defaults and config @@ -1517,9 +2114,7 @@ def puffer_type(value): def main(): - err = ( - "Usage: puffer [train, eval, sweep, autotune, profile, export] [env_name] [optional args]. --help for more info" - ) + err = "Usage: puffer [train, eval, sweep, controlled_exp, autotune, profile, export, sanity] [env_name] [optional args]. --help for more info" if len(sys.argv) < 3: raise pufferlib.APIUsageError(err) @@ -1531,12 +2126,16 @@ def main(): eval(env_name=env_name) elif mode == "sweep": sweep(env_name=env_name) + elif mode == "controlled_exp": + controlled_exp(env_name=env_name) elif mode == "autotune": autotune(env_name=env_name) elif mode == "profile": profile(env_name=env_name) elif mode == "export": export(env_name=env_name) + elif mode == "sanity": + sanity(env_name=env_name) else: raise pufferlib.APIUsageError(err) diff --git a/pufferlib/resources/drive/binaries/map_000.bin b/pufferlib/resources/drive/binaries/map_000.bin deleted file mode 100644 index 434b98c255..0000000000 Binary files a/pufferlib/resources/drive/binaries/map_000.bin and /dev/null differ diff --git a/pufferlib/resources/drive/binaries/training/map_000.bin b/pufferlib/resources/drive/binaries/training/map_000.bin new file mode 100644 index 0000000000..ef87c59af2 Binary files /dev/null and b/pufferlib/resources/drive/binaries/training/map_000.bin differ diff --git a/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin b/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin new file mode 100644 index 0000000000..c3840a5351 Binary files /dev/null and b/pufferlib/resources/drive/puffer_adaptive_drive_co_player.bin differ diff --git a/pufferlib/utils.py b/pufferlib/utils.py index a93c6b2945..7210f54d06 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -7,9 +7,11 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): - """ - Run human replay evaluation in a subprocess and log metrics to wandb. + """Run human replay evaluation in a subprocess and log metrics to wandb. + Routes through `pufferl eval --eval.human-replay-eval True` for both adaptive + and non-adaptive agents. The subprocess uses HumanReplayEvaluator, which + handles both architectures (LSTM, Transformer) and both agent types. """ try: run_id = logger.run_id @@ -22,8 +24,19 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): latest_cpt = max(model_files, key=os.path.getctime) - # Prepare evaluation command - eval_config = config["eval"] + env_config = config.get("env_config", {}) + eval_config = config.get("eval", {}) + conditioning = env_config.get("conditioning", {}) + conditioning_type = conditioning.get("type", "none") + # Resolve map_dir on the parent so the child can't silently fall back to the ini default + map_dir = eval_config.get("map_dir") or env_config.get("map_dir") + # Adaptive runs override k_scenarios and the resulting episode length; + # the child must use the same values or it will build a model whose + # positional_embedding shape doesn't match the trained checkpoint. + k_scenarios = env_config.get("k_scenarios", 1) + scenario_length = env_config.get("scenario_length", 91) + train_horizon = config.get("horizon", scenario_length * k_scenarios) + cmd = [ sys.executable, "-m", @@ -37,35 +50,77 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): "--eval.human-replay-eval", "True", "--eval.human-replay-num-agents", - str(eval_config["human_replay_num_agents"]), + str(eval_config.get("human_replay_num_agents", 64)), + "--eval.human-replay-num-maps", + str(eval_config.get("human_replay_num_maps", 100)), + "--eval.human-replay-num-rollouts", + str(eval_config.get("human_replay_num_rollouts", 100)), "--eval.human-replay-control-mode", - str(eval_config["human_replay_control_mode"]), + str(eval_config.get("human_replay_control_mode", "control_vehicles")), + *(["--eval.map-dir", str(map_dir)] if map_dir else []), + "--eval.num-maps", + str(eval_config.get("num_maps", 20)), + "--env.k-scenarios", + str(k_scenarios), + "--env.scenario-length", + str(scenario_length), + "--train.horizon", + str(train_horizon), + # Match training goal_behavior. The "stop is cleaner" claim was wrong + # in practice — sparse reward under gb=2 produces worse drivers; gb=0 + # respawn captures real efficiency adaptation in scen_1 vs scen_0. + "--env.goal-behavior", + "0", + "--env.conditioning.type", + conditioning_type, + "--env.conditioning.collision-weight-lb", + str(conditioning.get("collision_weight_lb", -3.0)), + "--env.conditioning.collision-weight-ub", + str(conditioning.get("collision_weight_ub", -3.0)), + "--env.conditioning.offroad-weight-lb", + str(conditioning.get("offroad_weight_lb", -1.0)), + "--env.conditioning.offroad-weight-ub", + str(conditioning.get("offroad_weight_ub", -1.0)), + "--env.conditioning.goal-weight-lb", + str(conditioning.get("goal_weight_lb", 1.0)), + "--env.conditioning.goal-weight-ub", + str(conditioning.get("goal_weight_ub", 1.0)), + "--env.conditioning.entropy-weight-lb", + str(conditioning.get("entropy_weight_lb", 0.001)), + "--env.conditioning.entropy-weight-ub", + str(conditioning.get("entropy_weight_ub", 0.001)), + "--env.conditioning.discount-weight-lb", + str(conditioning.get("discount_weight_lb", 0.98)), + "--env.conditioning.discount-weight-ub", + str(conditioning.get("discount_weight_ub", 0.98)), ] - # Run human replay evaluation in subprocess result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd()) - if result.returncode == 0: - # Extract JSON from stdout between markers - stdout = result.stdout - if "HUMAN_REPLAY_METRICS_START" in stdout and "HUMAN_REPLAY_METRICS_END" in stdout: - start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START") - end = stdout.find("HUMAN_REPLAY_METRICS_END") - json_str = stdout[start:end].strip() - human_replay_metrics = json.loads(json_str) - - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - logger.wandb.log( - { - "eval/human_replay_collision_rate": human_replay_metrics["collision_rate"], - "eval/human_replay_offroad_rate": human_replay_metrics["offroad_rate"], - "eval/human_replay_completion_rate": human_replay_metrics["completion_rate"], - }, - step=global_step, - ) - else: - print(f"Human replay evaluation failed with exit code {result.returncode}: {result.stderr}") + if result.returncode != 0: + print(f"Human replay evaluation failed (exit {result.returncode}): {result.stderr}") + return + + stdout = result.stdout + if "HUMAN_REPLAY_METRICS_START" not in stdout or "HUMAN_REPLAY_METRICS_END" not in stdout: + return + + start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START") + end = stdout.find("HUMAN_REPLAY_METRICS_END") + metrics = json.loads(stdout[start:end].strip()) + + if not (hasattr(logger, "wandb") and logger.wandb): + return + + # Forward every metric the evaluator emitted under eval/human_replay_*. + # This includes `_std` (variance across rollouts), eval-scale metadata + # (n_rollouts, n_agents_per_rollout, n_total_evals), and the full + # ada_delta_*/scenario_* family without an explicit allow-list. + log_data = {} + for k, v in metrics.items(): + if isinstance(v, (int, float)): + log_data[f"eval/human_replay_{k}"] = v + logger.wandb.log(log_data, step=global_step) except subprocess.TimeoutExpired: print("Human replay evaluation timed out") @@ -74,18 +129,7 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): def run_wosac_eval_in_subprocess(config, logger, global_step): - """ - Run WOSAC evaluation in a subprocess and log metrics to wandb. - - Args: - config: Configuration dictionary containing data_dir, env, and wosac settings - logger: Logger object with run_id and optional wandb attribute - epoch: Current training epoch - global_step: Current global training step - - Returns: - None. Prints error messages if evaluation fails. - """ + """Run WOSAC realism evaluation in a subprocess and log metrics to wandb.""" try: run_id = logger.run_id model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}") @@ -97,8 +141,11 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): latest_cpt = max(model_files, key=os.path.getctime) - # Prepare evaluation command + env_config = config.get("env_config", {}) eval_config = config.get("eval", {}) + # Forward training map_dir so eval doesn't silently fall back to ini default + map_dir = eval_config.get("map_dir") or env_config.get("map_dir") + cmd = [ sys.executable, "-m", @@ -114,7 +161,7 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): "--eval.wosac-init-mode", str(eval_config.get("wosac_init_mode", "create_all_valid")), "--eval.wosac-control-mode", - str(eval_config.get("wosac_control_mode", "control_tracks_to_predict")), + str(eval_config.get("wosac_control_mode", "control_wosac")), "--eval.wosac-init-steps", str(eval_config.get("wosac_init_steps", 10)), "--eval.wosac-goal-behavior", @@ -125,167 +172,184 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): str(eval_config.get("wosac_sanity_check", False)), "--eval.wosac-aggregate-results", str(eval_config.get("wosac_aggregate_results", True)), + *(["--eval.map-dir", str(map_dir)] if map_dir else []), ] - # Run WOSAC evaluation in subprocess result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd()) - if result.returncode == 0: - # Extract JSON from stdout between markers - stdout = result.stdout - if "WOSAC_METRICS_START" in stdout and "WOSAC_METRICS_END" in stdout: - start = stdout.find("WOSAC_METRICS_START") + len("WOSAC_METRICS_START") - end = stdout.find("WOSAC_METRICS_END") - json_str = stdout[start:end].strip() - wosac_metrics = json.loads(json_str) - - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - logger.wandb.log( - { - "eval/wosac_realism_meta_score": wosac_metrics["realism_meta_score"], - "eval/wosac_ade": wosac_metrics["ade"], - "eval/wosac_min_ade": wosac_metrics["min_ade"], - "eval/wosac_total_num_agents": wosac_metrics["total_num_agents"], - }, - step=global_step, - ) - else: - print(f"WOSAC evaluation failed with exit code {result.returncode}: {result.stderr}") + if result.returncode != 0: + print(f"WOSAC evaluation failed (exit {result.returncode}): {result.stderr}") + stderr_lower = result.stderr.lower() + if "out of memory" in stderr_lower: + print("GPU OOM during WOSAC eval; skipping.") + return + + stdout = result.stdout + if "WOSAC_METRICS_START" not in stdout or "WOSAC_METRICS_END" not in stdout: + return + + start = stdout.find("WOSAC_METRICS_START") + len("WOSAC_METRICS_START") + end = stdout.find("WOSAC_METRICS_END") + metrics = json.loads(stdout[start:end].strip()) + + if hasattr(logger, "wandb") and logger.wandb: + logger.wandb.log( + { + "eval/wosac_realism_meta_score": metrics["realism_meta_score"], + "eval/wosac_ade": metrics["ade"], + "eval/wosac_min_ade": metrics["min_ade"], + "eval/wosac_total_num_agents": metrics["total_num_agents"], + }, + step=global_step, + ) except subprocess.TimeoutExpired: - print("WOSAC evaluation timed out") + print("WOSAC evaluation timed out after 600 seconds") except Exception as e: - print(f"Failed to run WOSAC evaluation: {e}") + print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}") -def render_videos(config, vecenv, logger, global_step, bin_path): - """ - Generate and log training videos using C-based rendering. +_VIEW_NAMES = {0: "sim_state", 1: "bev", 2: "persp"} + + +def render_videos(config, policy, logger, epoch, global_step, device="cuda", human_replay=False): + """Generate and log videos via Python rollout (works with any policy architecture). + + Policy inference is in PyTorch; rendering goes through the C bindings + (`vec_render`, `vec_set_video_suffix`). Saves under + /_/renders/ with names that include the map_id, + mode, and view so multiple modes on the same map don't collide. Args: - config: Configuration dictionary containing data_dir, env, and render settings - vecenv: Vectorized environment with driver_env attribute - logger: Logger object with run_id and optional wandb attribute - epoch: Current training epoch - global_step: Current global training step - bin_path: Path to the exported .bin model weights file - - Returns: - None. Prints error messages if rendering fails. + config: PuffeRL flat train config. + policy: PyTorch policy (LSTM or Transformer wrapper). + logger: Logger with run_id and optional .wandb. + epoch: Current training epoch. + global_step: Current global step (for wandb step alignment). + device: Inference device. + human_replay: If True, render in human-replay mode (1 ego, others = logs). """ - if not os.path.exists(bin_path): - print(f"Binary weights file does not exist: {bin_path}") - return + import copy + import torch + from pufferlib.pufferl import load_env + from pufferlib.ocean.drive.rollout import RenderContext, RenderView, rollout_loop - run_id = logger.run_id - model_dir = os.path.join(config["data_dir"], f"{config['env']}_{run_id}") - - # Now call the C rendering function try: - # Create output directory for videos - video_output_dir = os.path.join(model_dir, "videos") + run_id = logger.run_id + env_name = config.get("env", "drive") + model_dir = os.path.join(config["data_dir"], f"{env_name}_{run_id}") + video_output_dir = os.path.join(model_dir, "renders") os.makedirs(video_output_dir, exist_ok=True) - # Copy the binary weights to the expected location - expected_weights_path = "resources/drive/puffer_drive_weights.bin" - os.makedirs(os.path.dirname(expected_weights_path), exist_ok=True) - shutil.copy2(bin_path, expected_weights_path) - - # TODO: Fix memory leaks so that this is not needed - # Suppress AddressSanitizer exit code (temp) - env = os.environ.copy() - env["ASAN_OPTIONS"] = "exitcode=0" - - cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"] - - # Add render configurations - if config["show_grid"]: - cmd.append("--show-grid") - if config["obs_only"]: - cmd.append("--obs-only") - if config["show_lasers"]: - cmd.append("--lasers") - if config["show_human_logs"]: - cmd.append("--log-trajectories") - if vecenv.driver_env.goal_radius is not None: - cmd.extend(["--goal-radius", str(vecenv.driver_env.goal_radius)]) - if vecenv.driver_env.init_steps > 0: - cmd.extend(["--init-steps", str(vecenv.driver_env.init_steps)]) - if config["render_map"] is not None: - map_path = config["render_map"] - if os.path.exists(map_path): - cmd.extend(["--map-name", map_path]) - if vecenv.driver_env.init_mode is not None: - cmd.extend(["--init-mode", str(vecenv.driver_env.init_mode)]) - if vecenv.driver_env.control_mode is not None: - cmd.extend(["--control-mode", str(vecenv.driver_env.control_mode)]) - - if hasattr(vecenv.driver_env, "reward_conditioned"): - cmd.extend(["--use-rc", "1" if vecenv.driver_env.reward_conditioned else "0"]) - if hasattr(vecenv.driver_env, "entropy_conditioned"): - cmd.extend(["--use-ec", "1" if vecenv.driver_env.entropy_conditioned else "0"]) - if hasattr(vecenv.driver_env, "discount_conditioned"): - cmd.extend(["--use-dc", "1" if vecenv.driver_env.discount_conditioned else "0"]) - - # Specify output paths for videos - cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"]) - cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"]) - - # Add environment configuration - env_cfg = getattr(vecenv, "driver_env", None) - if env_cfg is not None: - n_policy = getattr(env_cfg, "max_controlled_agents", -1) + view_modes = config.get("render_view_modes", [RenderView.FULL_SIM_STATE]) + if isinstance(view_modes, int): + view_modes = [view_modes] + + env_kwargs = copy.deepcopy(config.get("env_config", {})) + env_kwargs["render_mode"] = 1 # RENDER_HEADLESS + # Render env runs alongside training and has to fit in the same VRAM / + # RAM budget — override the training num_agents (often 1024+) down to a + # render-sized footprint so we don't OOM on first render call. + env_kwargs["num_agents"] = min(env_kwargs.get("num_agents", 64), 64) + if env_kwargs.get("num_ego_agents") is not None: + env_kwargs["num_ego_agents"] = min(env_kwargs["num_ego_agents"], 32) + env_kwargs["num_maps"] = min(env_kwargs.get("num_maps", 500), 500) + # Force per-env inline co-player inference for the render env. Training + # runs with external_co_player_actions=True expect PufferL's centralized + # GPU inference path to write co-player actions into shared memory each + # step. The render env is a single Serial env with no central path + # attached — without this override the SHM is never written, co-players + # freeze, and training-style renders show stuck partners. + env_kwargs["external_co_player_actions"] = False + + if human_replay: + env_kwargs["co_player_enabled"] = False + env_kwargs["max_controlled_agents"] = 1 + # Inherit goal_behavior from training config (no override) so the + # render matches whatever regime the policy actually trained under. + if "adaptive" in env_name: + env_kwargs["human_replay_mode"] = True + + # Force Serial backend for render: raylib's GLFW needs DISPLAY, which + # xvfb-run sets in the parent process. A Multiprocessing worker would + # spawn without DISPLAY and segfault on InitWindow. + render_args = { + "env": env_kwargs, + "vec": {"num_envs": 1, "backend": "Serial"}, + "package": config.get("package", "ocean"), + } + + use_rnn = config.get("use_rnn", False) + episode_length = env_kwargs.get("scenario_length", 91) + k_scenarios = env_kwargs.get("k_scenarios", 1) + if k_scenarios > 1: + episode_length = k_scenarios * episode_length + + mode = "human_replay" if human_replay else ("coplayer" if env_kwargs.get("co_player_enabled") else "baseline") + videos_to_log_world = [] + videos_to_log_agent = [] + + for view_mode in view_modes: + render_env = load_env(env_name, render_args) try: - n_policy = int(n_policy) - except (TypeError, ValueError): - n_policy = -1 - if n_policy > 0: - cmd += ["--num-policy-controlled-agents", str(n_policy)] - if getattr(env_cfg, "num_maps", False): - cmd.extend(["--num-maps", str(env_cfg.num_maps)]) - if getattr(env_cfg, "scenario_length", None): - cmd.extend(["--scenario-length", str(env_cfg.scenario_length)]) - - # Call C code that runs eval_gif() in subprocess - result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=120, env=env) - - vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists( - "resources/drive/output_agent.mp4" - ) - - if result.returncode == 0 or (result.returncode == 1 and vids_exist): - # Move both generated videos to the model directory - videos = [ - ("resources/drive/output_topdown.mp4", f"step_{global_step:09d}_topdown.mp4"), - ("resources/drive/output_agent.mp4", f"step_{global_step:09d}_agent.mp4"), - ] - - for source_vid, target_filename in videos: - if os.path.exists(source_vid): - target_gif = os.path.join(video_output_dir, target_filename) - shutil.move(source_vid, target_gif) - - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - import wandb - - view_type = "world_state" if "topdown" in target_filename else "agent_view" - logger.wandb.log( - {f"render/{view_type}": wandb.Video(target_gif, format="mp4")}, - step=global_step, - ) + driver = render_env.driver_env + map_ids = getattr(driver, "map_ids", None) + map_id = int(map_ids[0]) if map_ids is not None and len(map_ids) > 0 else 0 + view = _VIEW_NAMES.get(int(view_mode), "view") + basename = f"epoch_{epoch:06d}_{mode}_k{k_scenarios}_map{map_id:03d}_{view}" + + # Tell the env to keep raylib + ffmpeg alive across map swaps so + # the in-step _reinit_envs_with_new_maps() at scenario boundaries + # doesn't kill the render. Single mp4 captures all k scenarios + # with the maps rotating mid-stream. + if getattr(driver, "map_rand_per_scenario", False): + driver._render_keep_client_on_swap = True + + policy.eval() + rollout_loop( + policy=policy, + env=render_env, + device=device, + use_rnn=use_rnn, + max_steps=episode_length, + render_ctx=RenderContext( + view_mode=view_mode, + env_id=0, + draw_traces=True, + video_basename=basename, + ), + ) + finally: + render_env.close() + + src = f"{basename}.mp4" + if not os.path.exists(src): + print(f"render: expected {src} not produced") + continue + + target_path = os.path.join(video_output_dir, src) + shutil.move(src, target_path) + + if hasattr(logger, "wandb") and logger.wandb: + import wandb + + if view == "sim_state": + videos_to_log_world.append(wandb.Video(target_path, format="mp4")) else: - print(f"Video generation completed but {source_vid} not found") - else: - print(f"C rendering failed with exit code {result.returncode}: {result.stdout}") + videos_to_log_agent.append(wandb.Video(target_path, format="mp4")) + + if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent): + payload = {} + world_key = "eval/human_replay_world_view" if human_replay else "render/world_state" + agent_key = "eval/human_replay_agent_view" if human_replay else "render/agent_view" + if videos_to_log_world: + payload[world_key] = videos_to_log_world + if videos_to_log_agent: + payload[agent_key] = videos_to_log_agent + logger.wandb.log(payload, step=global_step) - except subprocess.TimeoutExpired: - print("C rendering timed out") except Exception as e: - print(f"Failed to generate GIF: {e}") + print(f"Failed to render videos: {e}") + import traceback - finally: - # Clean up bin weights file - if os.path.exists(expected_weights_path): - os.remove(expected_weights_path) + traceback.print_exc() diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 24ac492405..deee21d9f9 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -104,6 +104,13 @@ def __init__(self, env_creators, env_args, env_kwargs, num_envs, buf=None, seed= self.initialized = False self.flag = RESET + # Handle population play mode (ego agents only controlled by policy) + self.population_play = getattr(self.driver_env, "population_play", False) + if self.population_play: + ego_agents_per_batch = self.driver_env.num_ego_agents * num_envs + self.num_ego_agents = ego_agents_per_batch + self.ego_action_space = pufferlib.spaces.joint_space(self.single_action_space, ego_agents_per_batch) + def _avg_infos(self): infos = {} for e in self.infos: @@ -334,6 +341,9 @@ def __init__( self.agents_per_batch = driver_env.num_agents * batch_size agents_per_worker = driver_env.num_agents * envs_per_worker + # Persisted on the vecenv so PuffeRL can map per-recv `env_id` back + # to a worker index for centralized co-player inference. + self.agents_per_worker = agents_per_worker obs_space = driver_env.single_observation_space obs_shape = obs_space.shape self.obs_shape = obs_shape @@ -398,6 +408,49 @@ def __init__( self.atn_batch_shape = (self.workers_per_batch, agents_per_worker, *atn_shape) self.actions = np.ndarray((*shape, *atn_shape), dtype=atn_dtype, buffer=self.shm["actions"]) + # ---- Centralized GPU co-player conditioning SHM ---- + # When `external_co_player_actions` is set in env_kwargs AND the + # co-player has non-zero conditioning dims, we allocate a per-worker + # buffer so the env (which still samples conditioning at scenario + # boundaries) can deposit values for the main process to read before + # each forward pass. Sized to the worst-case `co_players_per_worker`. + env_k0 = env_kwargs[0] if env_kwargs else {} + external_coplayer_flag = env_k0.get("external_co_player_actions", False) and env_k0.get( + "co_player_enabled", False + ) + co_player_conditioning_dim = 0 + if external_coplayer_flag: + cond = env_k0.get("co_player_policy", {}).get("conditioning", {}) or {} + ctype = cond.get("type", "none") + co_player_conditioning_dim = ( + (3 if ctype in ("reward", "all") else 0) + + (1 if ctype in ("entropy", "all") else 0) + + (1 if ctype in ("discount", "all") else 0) + ) + if self.population_play and co_player_conditioning_dim > 0: + co_players_per_worker = agents_per_worker - ego_agents_per_worker + self.shm["co_player_conditioning"] = RawArray( + "f", num_workers * co_players_per_worker * co_player_conditioning_dim + ) + self.co_player_conditioning = np.ndarray( + (num_workers, co_players_per_worker, co_player_conditioning_dim), + dtype=np.float32, + buffer=self.shm["co_player_conditioning"], + ) + # CRITICAL: pufferlib.vector.make() builds env_kwargs as + # `[env_kwargs] * num_envs`, which is a list of N references + # to the SAME dict. Mutating env_kwargs[i] modifies all + # entries. We have to replace each slot with a per-env + # copy before adding worker_idx / SHM-slice entries. + for i in range(len(env_kwargs)): + w_idx = i // envs_per_worker + env_kwargs[i] = { + **env_kwargs[i], + "worker_idx": w_idx, + "co_player_conditioning_shm": self.co_player_conditioning[w_idx], + } + self._co_player_conditioning_dim = co_player_conditioning_dim + self.buf = dict( observations=np.ndarray((*shape, *obs_shape), dtype=obs_dtype, buffer=self.shm["observations"]), rewards=np.ndarray(shape, dtype=np.float32, buffer=self.shm["rewards"]), @@ -831,6 +884,13 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer # TODO: First step action space check env_k = env_kwargs[0] + # When external_co_player_actions is set, the *main* process owns the + # co-player policy on GPU and writes actions into the shared-memory + # action buffer at co_player slots before vec_step. Workers don't load + # the model and don't need single-thread CPU mode. We still build the + # policy here (to ship it to main via vecenv.co_player_policy_func), + # but on GPU and not stuffed into env_kwargs for workers. + external_coplayer = env_k.get("external_co_player_actions", False) and env_k.get("co_player_enabled", False) if env_k.get("co_player_enabled", False): import torch import os @@ -838,19 +898,18 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer import gymnasium from pufferlib.ocean.torch import Drive import pufferlib.models + from pufferlib.ocean.drive import binding - dynamics_model = env_k.get("dynamics_model", "jerk") - # Observation space calculation - if dynamics_model == "classic": - ego_features = 7 - elif dynamics_model == "jerk": - ego_features = 10 + dynamics_model = env_k.get("dynamics_model", "classic") + action_type = env_k.get("action_type", "discrete") co_player_policy = env_k["co_player_policy"] input_size = co_player_policy.get("input_size", 256) hidden_size = co_player_policy.get("hidden_size", 256) co_player_rnn = co_player_policy.get("rnn", None) + co_player_architecture = co_player_policy.get("architecture", "Recurrent") + co_player_transformer = co_player_policy.get("transformer", {}) # Get conditioning type from env_k co_player_conditioning = co_player_policy.get("conditioning") @@ -858,27 +917,79 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer reward_conditioned = condition_type in ("reward", "all") entropy_conditioned = condition_type in ("entropy", "all") discount_conditioned = condition_type in ("discount", "all") - # Calculate conditioning dimensions + + if action_type == "discrete": + if dynamics_model == "classic": + # Joint action space (assume dependence) + single_action_space = gymnasium.spaces.MultiDiscrete([7 * 13]) + # Multi discrete (assume independence) + # self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13]) + elif dynamics_model == "jerk": + # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12 + single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3]) + else: + raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}") + elif action_type == "continuous": + single_action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + else: + raise ValueError(f"action_space must be 'discrete' or 'continuous'. Got: {action_type}") + + # # Observation space calculation + ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get(dynamics_model) + conditioning_dims = ( (3 if reward_conditioned else 0) + (1 if entropy_conditioned else 0) + (1 if discount_conditioned else 0) ) - # Base observations + conditioning observations - num_obs = ego_features + conditioning_dims + 63 * 7 + 200 * 7 - temp_env = SimpleNamespace( - single_action_space=gymnasium.spaces.MultiDiscrete([7 * 13]), - single_observation_space=gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32), + ego_features += conditioning_dims + + # # Extract observation shapes from constants + # # These need to be defined in C, since they determine the shape of the arrays + # max_road_objects = 200 + # max_partner_objects = 63 + # partner_features = 7 + # road_features = 7 + + # Extract observation shapes from constants + # These need to be defined in C, since they determine the shape of the arrays + max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS + max_partner_objects = binding.MAX_AGENTS - 1 + partner_features = binding.PARTNER_FEATURES + road_features = binding.ROAD_FEATURES + + num_obs = ego_features + max_partner_objects * partner_features + max_road_objects * road_features + + single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32) + + co_player_env = SimpleNamespace( + single_action_space=single_action_space, + single_observation_space=single_observation_space, reward_conditioned=reward_conditioned, entropy_conditioned=entropy_conditioned, discount_conditioned=discount_conditioned, dynamics_model=dynamics_model, ## keep these the same I think, multiple dynamics models could get weird + max_partner_objects=max_partner_objects, + partner_features=partner_features, + max_road_objects=max_road_objects, + road_features=road_features, ) - base_policy = Drive(temp_env, input_size=input_size, hidden_size=hidden_size) + base_policy = Drive(co_player_env, input_size=input_size, hidden_size=hidden_size) - if co_player_rnn: + if co_player_architecture == "Transformer": + policy = pufferlib.models.TransformerWrapper( + co_player_env, + base_policy, + input_size=co_player_transformer.get("input_size", 256), + hidden_size=co_player_transformer.get("hidden_size", 256), + num_layers=co_player_transformer.get("num_layers", 2), + num_heads=co_player_transformer.get("num_heads", 4), + horizon=co_player_transformer.get("horizon", 91), + dropout=co_player_transformer.get("dropout", 0.0), + ) + elif co_player_rnn: policy = pufferlib.models.LSTMWrapper( - temp_env, + co_player_env, base_policy, input_size=co_player_rnn.get("input_size"), hidden_size=co_player_rnn.get("hidden_size"), @@ -891,37 +1002,68 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location="cpu") + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) policy.load_state_dict(state_dict, strict=True) - policy.eval() + if external_coplayer: + # Main owns the co-player on GPU. Don't pin to CPU; don't pass to + # workers. We hand the (still-on-CPU) policy to the caller via the + # vecenv attribute below — caller will move it to its own device. + policy.eval() + else: + policy = policy.to("cpu") # Ensure all buffers are on CPU for forked subprocesses + policy.eval() print( - f"Co player policy loaded with {conditioning_dims} conditioning dims (condition_type={condition_type})", + f"Co player policy loaded with {conditioning_dims} conditioning dims " + f"(condition_type={condition_type}, external={external_coplayer})", flush=True, ) - # Store policy and conditioning info in env_k - env_k["co_player_policy"]["co_player_policy_func"] = policy - torch.set_num_threads( - 1 - ) # NOTE this is the only way I could get co-player policies to work inside environment evaluation - torch.set_num_interop_threads(1) - import os + if not external_coplayer: + # Per-worker CPU path (legacy): hand the policy to env_kwargs so + # each forked worker sees it via env_k["co_player_policy"][...]. + env_k["co_player_policy"]["co_player_policy_func"] = policy - os.environ["OMP_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - os.environ["NUMEXPR_NUM_THREADS"] = "1" + # NOTE: Setting threads to 1 is required for co-player policies to work + # inside environment evaluation. Higher values cause deadlock. + # set_num_interop_threads can only be called once per process; on + # subsequent vector-construction calls (e.g. multiple renders) it + # raises RuntimeError. Guard so that's idempotent. + try: + torch.set_num_threads(1) + except RuntimeError: + pass + try: + torch.set_num_interop_threads(1) + except RuntimeError: + pass - # Disable MKL if available - try: - torch.backends.mkl.enabled = False - except: - pass + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + os.environ["NUMEXPR_NUM_THREADS"] = "1" - for i in range(len(env_kwargs)): - env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = policy + # Disable MKL if available + try: + torch.backends.mkl.enabled = False + except: + pass - return backend(env_creators, env_args, env_kwargs, num_envs, **kwargs) + for i in range(len(env_kwargs)): + env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = policy + else: + # External path: workers should NOT carry the policy. Ensure their + # env_kwargs don't accidentally hold a stale reference. + for i in range(len(env_kwargs)): + env_kwargs[i]["co_player_policy"]["co_player_policy_func"] = None + + vecenv = backend(env_creators, env_args, env_kwargs, num_envs, **kwargs) + if env_k.get("co_player_enabled", False) and external_coplayer: + # Stash the GPU-bound co-player policy + the per-worker conditioning + # dimension on the vecenv so PuffeRL can pick them up in __init__. + vecenv.co_player_policy_func = policy + vecenv.co_player_conditioning_dims = conditioning_dims + vecenv.co_player_condition_type = condition_type + return vecenv def make_seeds(seed, num_envs): diff --git a/pyproject.toml b/pyproject.toml index 4bcc818849..72a1b34949 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,8 +117,10 @@ metta = [ 'hydra-core', 'duckdb', 'raylib>=5.5.0', - 'metta-common @ git+https://github.com/metta-ai/metta.git@main#subdirectory=common', - 'metta-mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=mettagrid', + # 'metta-common @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=common', + # 'metta-mettagrid @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=mettagrid', + 'mettagrid' + ] microrts = [ diff --git a/render.py b/render.py new file mode 100644 index 0000000000..4742f0c1f2 --- /dev/null +++ b/render.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python +"""Unified Python rendering CLI for PufferDrive. + +Replaces the old C `./visualize` binary, `render_test.py`, and the rendering +side of `evaluate_human_logs.py`. Works with both LSTM and Transformer policies +on both nuPlan and WOMD datasets. + +Modes +----- + Baseline (no co-players): + python render.py --model-path X.pt + With a frozen co-player population: + python render.py --model-path adaptive.pt --co-player-path coplayer.pt + Human replay (one ego, others follow logged trajectories): + python render.py --model-path X.pt --human-replay + +Architecture is auto-detected from the checkpoint state-dict, but can be +overridden with --policy-architecture. +""" + +import argparse +import copy +import glob +import os +import shutil +import sys + +import torch + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from pufferlib.pufferl import load_config, load_env, load_policy +from pufferlib.ocean.drive.rollout import RenderContext, RenderView, rollout_loop + + +VIEW_MODE_BY_NAME = { + "sim_state": RenderView.FULL_SIM_STATE, + "bev": RenderView.BEV_AGENT_OBS, + "persp": RenderView.AGENT_PERSPECTIVE, +} + +VIEW_NAME = { + RenderView.FULL_SIM_STATE: "sim_state", + RenderView.BEV_AGENT_OBS: "bev", + RenderView.AGENT_PERSPECTIVE: "persp", +} + + +def model_id_from_path(path): + """Short, readable id derived from the checkpoint filename or its run dir.""" + fname = os.path.splitext(os.path.basename(path))[0] + # If the file lives inside a run directory like puffer_drive_/, + # prefer that (handles intermediate model_*.pt files). + parent = os.path.basename(os.path.dirname(path) or "") + candidate = parent if parent.startswith("puffer_") else fname + return candidate.replace("puffer_adaptive_drive_", "").replace("puffer_drive_", "") + + +def run_dir_for(model_path): + """Locate the experiment directory that owns this checkpoint.""" + parent = os.path.dirname(model_path) + parent_name = os.path.basename(parent or "") + if parent_name.startswith("puffer_"): + return parent + # File is `/puffer_..._.pt`: matching run dir sits next to it. + fname = os.path.splitext(os.path.basename(model_path))[0] + if fname.startswith("puffer_"): + candidate = os.path.join(parent, fname) + if os.path.isdir(candidate): + return candidate + return None + + +def default_output_dir(model_path): + """Default to /renders so artifacts live with the experiment.""" + run_dir = run_dir_for(model_path) + if run_dir: + return os.path.join(run_dir, "renders") + return os.path.join("./render_output", model_id_from_path(model_path)) + + +def detect_architecture(model_path, device="cpu"): + state_dict = torch.load(model_path, map_location=device) + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + if "positional_embedding" in state_dict: + return "Transformer" + if "lstm.weight_ih_l0" in state_dict: + return "Recurrent" + return None + + +def build_config(args): + """Build the env/vec/policy config dict for one render.""" + if args.adaptive or args.k_scenarios > 1 or args.co_player_path is not None: + env_name = "puffer_adaptive_drive" + else: + env_name = "puffer_drive" + + saved_argv = sys.argv + sys.argv = [sys.argv[0]] + try: + config = load_config(env_name) + finally: + sys.argv = saved_argv + + arch = args.policy_architecture or detect_architecture(args.model_path) or "Recurrent" + config["policy_architecture"] = arch + config["rnn_name"] = arch + config["use_rnn"] = True + config["load_model_path"] = args.model_path + config["vec"] = {"backend": "Serial", "num_envs": 1} + + config["env"]["render_mode"] = 1 # RENDER_HEADLESS + config["env"]["map_dir"] = args.map_dir + config["env"]["num_maps"] = args.num_maps + config["env"]["num_agents"] = args.num_agents + config["env"]["num_ego_agents"] = args.num_ego_agents + config["env"]["k_scenarios"] = args.k_scenarios + config["env"]["scenario_length"] = args.scenario_length + + if args.human_replay: + if env_name == "puffer_adaptive_drive": + config["env"]["human_replay_mode"] = True + config["env"]["co_player_enabled"] = False + config["env"]["max_controlled_agents"] = 1 + elif args.co_player_path is not None: + config["env"]["co_player_enabled"] = True + cpp = config["env"].setdefault("co_player_policy", {}) + cpp["policy_path"] = args.co_player_path + cpp["architecture"] = args.co_player_architecture or detect_architecture(args.co_player_path) or "Recurrent" + cpp_cond = cpp.setdefault("conditioning", {}) + cpp_cond["type"] = args.co_player_conditioning_type + cpp_cond["collision_weight_lb"] = args.co_player_collision_weight_lb + cpp_cond["collision_weight_ub"] = args.co_player_collision_weight_ub + cpp_cond["offroad_weight_lb"] = args.co_player_offroad_weight_lb + cpp_cond["offroad_weight_ub"] = args.co_player_offroad_weight_ub + cpp_cond["goal_weight_lb"] = args.co_player_goal_weight_lb + cpp_cond["goal_weight_ub"] = args.co_player_goal_weight_ub + cpp_cond["entropy_weight_lb"] = args.co_player_entropy_weight_lb + cpp_cond["entropy_weight_ub"] = args.co_player_entropy_weight_ub + cpp_cond["discount_weight_lb"] = args.co_player_discount_weight_lb + cpp_cond["discount_weight_ub"] = args.co_player_discount_weight_ub + else: + config["env"]["co_player_enabled"] = False + + cond = config["env"].setdefault("conditioning", {}) + cond["type"] = args.conditioning_type + cond["collision_weight_lb"] = args.collision_weight_lb + cond["collision_weight_ub"] = args.collision_weight_ub + cond["offroad_weight_lb"] = args.offroad_weight_lb + cond["offroad_weight_ub"] = args.offroad_weight_ub + cond["goal_weight_lb"] = args.goal_weight_lb + cond["goal_weight_ub"] = args.goal_weight_ub + cond["entropy_weight_lb"] = args.entropy_weight_lb + cond["entropy_weight_ub"] = args.entropy_weight_ub + cond["discount_weight_lb"] = args.discount_weight_lb + cond["discount_weight_ub"] = args.discount_weight_ub + + config["train"]["device"] = args.device + return env_name, config + + +def mode_tag(args): + if args.human_replay: + return "human_replay" + if args.co_player_path is not None: + return "coplayer" + return "baseline" + + +def render_one(env_name, base_config, view_modes, render_idx, seed, args): + """One render = (one map seed × all view modes), one rollout per view.""" + print(f"\n[Render {render_idx + 1}/{args.num_renders}] map_seed={seed}") + cfg = copy.deepcopy(base_config) + cfg["env"]["map_seed"] = seed + + vecenv = load_env(env_name, cfg) + try: + policy = load_policy(cfg, vecenv, env_name) + policy.eval() + + # Pull the actual map id loaded — beats relying on the seed only + map_ids = getattr(vecenv.driver_env, "map_ids", None) + map_id = int(map_ids[0]) if map_ids is not None and len(map_ids) > 0 else seed + + model_id = model_id_from_path(args.model_path) + if args.co_player_path is not None and not args.human_replay: + mode = f"vs_{model_id_from_path(args.co_player_path)}" + else: + mode = mode_tag(args) + coplayer_part = "" + + max_steps = args.max_steps if args.max_steps is not None else (args.k_scenarios * args.scenario_length) + os.makedirs(args.output_dir, exist_ok=True) + saved = [] + + for view_mode in view_modes: + view = VIEW_NAME[view_mode] + basename = f"{model_id}_{mode}{coplayer_part}_k{args.k_scenarios}_map{map_id:03d}_{view}" + vecenv.reset(seed=seed) + rollout_loop( + policy=policy, + env=vecenv, + device=args.device, + use_rnn=True, + max_steps=max_steps, + render_ctx=RenderContext( + view_mode=view_mode, + env_id=0, + draw_traces=True, + video_basename=basename, + ), + ) + src = f"{basename}.mp4" + if os.path.exists(src): + target = os.path.join(args.output_dir, src) + shutil.move(src, target) + print(f" saved {target}") + saved.append(target) + else: + print(f" WARNING: expected {src} not produced") + finally: + vecenv.close() + + return saved + + +def main(): + p = argparse.ArgumentParser(description="Unified Python rendering for PufferDrive") + p.add_argument("--model-path", required=True, help="Trained ego policy checkpoint (.pt)") + p.add_argument( + "--co-player-path", default=None, help="Frozen co-player policy (.pt). Omit for baseline / human-replay." + ) + p.add_argument("--human-replay", action="store_true", help="Render in human-replay mode (one ego, others = log)") + p.add_argument("--adaptive", action="store_true", help="Force puffer_adaptive_drive env even when k_scenarios=1") + + p.add_argument( + "--policy-architecture", + choices=["Recurrent", "Transformer"], + default=None, + help="Override ego architecture (auto-detected from checkpoint if omitted)", + ) + p.add_argument( + "--co-player-architecture", + choices=["Recurrent", "Transformer"], + default=None, + help="Override co-player architecture", + ) + + p.add_argument( + "--map-dir", + default="resources/drive/binaries/training", + help="Map binary directory (e.g. resources/drive/binaries/nuplan)", + ) + p.add_argument("--num-maps", type=int, default=None, help="Map pool size (default: max(100, num_renders))") + p.add_argument("--num-renders", type=int, default=1, help="Number of independent renders (different map seeds)") + p.add_argument("--start-seed", type=int, default=1) + p.add_argument( + "--seed-stride", + type=int, + default=1009, + help="Stride between consecutive map seeds. Spaced apart so adjacent renders pick different maps.", + ) + p.add_argument("--num-agents", type=int, default=64) + p.add_argument("--num-ego-agents", type=int, default=32) + + p.add_argument("--k-scenarios", type=int, default=2, help="Number of scenarios per episode (adaptive)") + p.add_argument("--scenario-length", type=int, default=91) + p.add_argument( + "--max-steps", type=int, default=None, help="Steps per render (default: k_scenarios * scenario_length)" + ) + + p.add_argument("--view-mode", choices=["sim_state", "bev", "persp", "all"], default="sim_state") + p.add_argument("--output-dir", default=None, help="Where to write mp4s (default: /renders)") + p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + + p.add_argument("--conditioning-type", choices=["none", "reward", "entropy", "discount", "all"], default="none") + p.add_argument("--collision-weight-lb", type=float, default=-3.0) + p.add_argument("--collision-weight-ub", type=float, default=-3.0) + p.add_argument("--offroad-weight-lb", type=float, default=-1.0) + p.add_argument("--offroad-weight-ub", type=float, default=-1.0) + p.add_argument("--goal-weight-lb", type=float, default=1.0) + p.add_argument("--goal-weight-ub", type=float, default=1.0) + p.add_argument("--entropy-weight-lb", type=float, default=0.001) + p.add_argument("--entropy-weight-ub", type=float, default=0.001) + p.add_argument("--discount-weight-lb", type=float, default=0.98) + p.add_argument("--discount-weight-ub", type=float, default=0.98) + + p.add_argument( + "--co-player-conditioning-type", choices=["none", "reward", "entropy", "discount", "all"], default="all" + ) + p.add_argument("--co-player-collision-weight-lb", type=float, default=-1.0) + p.add_argument("--co-player-collision-weight-ub", type=float, default=0.0) + p.add_argument("--co-player-offroad-weight-lb", type=float, default=-0.4) + p.add_argument("--co-player-offroad-weight-ub", type=float, default=0.0) + p.add_argument("--co-player-goal-weight-lb", type=float, default=0.0) + p.add_argument("--co-player-goal-weight-ub", type=float, default=1.0) + p.add_argument("--co-player-entropy-weight-lb", type=float, default=0.0) + p.add_argument("--co-player-entropy-weight-ub", type=float, default=0.1) + p.add_argument("--co-player-discount-weight-lb", type=float, default=0.8) + p.add_argument("--co-player-discount-weight-ub", type=float, default=1.0) + + args = p.parse_args() + + if args.num_maps is None: + args.num_maps = max(100, args.num_renders) + if args.output_dir is None: + args.output_dir = default_output_dir(args.model_path) + + if args.view_mode == "all": + view_modes = list(VIEW_MODE_BY_NAME.values()) + else: + view_modes = [VIEW_MODE_BY_NAME[args.view_mode]] + + env_name, config = build_config(args) + + print(f"Env: {env_name}") + print(f"Architecture: {config['policy_architecture']}") + print(f"Map dir: {args.map_dir}, num_maps: {args.num_maps}, k_scenarios: {args.k_scenarios}") + print(f"Mode: {'human_replay' if args.human_replay else ('coplayer' if args.co_player_path else 'baseline')}") + print(f"Views: {[vm.name for vm in view_modes]}") + + saved = [] + for i in range(args.num_renders): + seed = args.start_seed + i * args.seed_stride + saved.extend(render_one(env_name, config, view_modes, i, seed, args)) + + print(f"\nDone. {len(saved)} videos in {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/ablations/all_coplayers.sh b/scripts/ablations/all_coplayers.sh new file mode 100755 index 0000000000..49b5caf1d2 --- /dev/null +++ b/scripts/ablations/all_coplayers.sh @@ -0,0 +1,105 @@ +#!/bin/bash +#SBATCH --job-name=coplayer_ablation +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=128GB +#SBATCH --time=24:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:1 +#SBATCH --array=0-7 + +# Train every co-player ablation variant in one slurm array. +# +# sbatch scripts/ablations/all_coplayers.sh +# +# 8 array tasks = 2 datasets × 2 archs × 2 conditioning types: +# idx | dataset | architecture | conditioning +# ----+---------+--------------+------------- +# 0 | womd | Recurrent | none +# 1 | womd | Recurrent | all +# 2 | womd | Transformer | none +# 3 | womd | Transformer | all +# 4 | nuplan | Recurrent | none +# 5 | nuplan | Recurrent | all +# 6 | nuplan | Transformer | none +# 7 | nuplan | Transformer | all +# +# Resulting checkpoints land at experiments/puffer_drive_.pt. +# Once they've trained, plug them into scripts/adaptive/*.sh as ZIPPED_RUNS +# entries. + +# Decode array index → (dataset, arch, cond_type) +RUNS=( + "womd Recurrent none" + "womd Recurrent all" + "womd Transformer none" + "womd Transformer all" + "nuplan Recurrent none" + "nuplan Recurrent all" + "nuplan Transformer none" + "nuplan Transformer all" +) + +read -r DATASET ARCH COND <<< "${RUNS[$SLURM_ARRAY_TASK_ID]}" + +# Dataset → map_dir + num_maps +case "$DATASET" in + nuplan) + MAP_DIR="resources/drive/binaries/nuplan" + NUM_MAPS=5000 + ;; + womd) + MAP_DIR="resources/drive/binaries/training" + NUM_MAPS=10000 + ;; + *) echo "unknown dataset $DATASET" >&2; exit 1 ;; +esac + +# Conditioning → puffer flags +# `all` sweeps entropy 0→0.1 and discount 0.8→1.0 (the same range used in +# scripts/coplayers/* with the cell at idx 0). +case "$COND" in + none) + COND_ARGS="--env.conditioning.type none" + ;; + all) + COND_ARGS="--env.conditioning.type all \ + --env.conditioning.entropy-weight-lb 0 \ + --env.conditioning.entropy-weight-ub 0.1 \ + --env.conditioning.discount-weight-lb 0.8 \ + --env.conditioning.discount-weight-ub 1.0" + ;; + *) echo "unknown conditioning $COND" >&2; exit 1 ;; +esac + +TAG="coplayer_${DATASET}_${ARCH,,}_cond-${COND}" + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! + + puffer train puffer_drive --wandb \ + --wandb-project ada_coplayer_ablation \ + --tag $TAG \ + --policy-architecture $ARCH \ + --rnn-name $ARCH \ + --env.map-dir $MAP_DIR \ + --env.num-maps $NUM_MAPS \ + --eval.map-dir $MAP_DIR \ + --train.checkpoint-interval 50 \ + $COND_ARGS + + kill \$HEARTBEAT_PID + " diff --git a/scripts/ablations/human_align_ablation.sh b/scripts/ablations/human_align_ablation.sh new file mode 100755 index 0000000000..e2aa962f58 --- /dev/null +++ b/scripts/ablations/human_align_ablation.sh @@ -0,0 +1,74 @@ +#!/bin/bash +#SBATCH --job-name=human_align_ablation +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=128GB +#SBATCH --time=24:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:1 +#SBATCH --array=0-5 + +# Human behavior alignment ablation: 6 experiments +# collision_weight_lb × offroad_weight_lb grid +# +# | Exp | collision_weight_lb | offroad_weight_lb | +# |-----|---------------------|-------------------| +# | 0 | -3 | -2.0 | +# | 1 | -3 | -1.0 | +# | 2 | -3 | -0.5 | +# | 3 | -2 | -2.0 | +# | 4 | -2 | -1.0 | +# | 5 | -2 | -0.5 | + +# Define the grid +COLLISION_WEIGHTS=(-3 -3 -3 -2 -2 -2) +OFFROAD_WEIGHTS=(-2.0 -1.0 -0.5 -2.0 -1.0 -0.5) + +# Get values for this array task +COLLISION_LB=${COLLISION_WEIGHTS[$SLURM_ARRAY_TASK_ID]} +OFFROAD_LB=${OFFROAD_WEIGHTS[$SLURM_ARRAY_TASK_ID]} + +# Fixed parameters +DISCOUNT_UB=0.995 +SEED=42 +NUPLAN_NUM_MAPS=4999 + +echo "Running experiment $SLURM_ARRAY_TASK_ID: collision_weight_lb=$COLLISION_LB, offroad_weight_lb=$OFFROAD_LB" + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! + + puffer train puffer_drive \ + --wandb --wandb-project human-align-ablation \ + --tag ablation_collision${COLLISION_LB}_offroad${OFFROAD_LB} \ + --env.map-dir resources/drive/binaries/nuplan \ + --env.num-maps $NUPLAN_NUM_MAPS \ + --env.conditioning.type all \ + --env.conditioning.collision-weight-lb $COLLISION_LB \ + --env.conditioning.collision-weight-ub 0 \ + --env.conditioning.offroad-weight-lb $OFFROAD_LB \ + --env.conditioning.offroad-weight-ub 0 \ + --env.conditioning.discount-weight-lb 0.8 \ + --env.conditioning.discount-weight-ub $DISCOUNT_UB \ + --policy-architecture Transformer \ + --rnn-name Transformer \ + --train.context-length 91 \ + --train.horizon 91 \ + --train.seed $SEED \ + --eval.map-dir resources/drive/binaries/nuplan + + kill \$HEARTBEAT_PID + " diff --git a/scripts/ablations/human_align_local.sh b/scripts/ablations/human_align_local.sh new file mode 100755 index 0000000000..72af9454d8 --- /dev/null +++ b/scripts/ablations/human_align_local.sh @@ -0,0 +1,57 @@ +#!/bin/bash +set -e + +# Local 8-GPU launcher for human-align ablation. +# One tmux window per experiment, each pinned to a single GPU. + +COLLISION_WEIGHTS=(-3 -3 -2 -2) +OFFROAD_WEIGHTS=(-2.0 -0.5 -2.0 -0.5) + +DISCOUNT_UB=0.995 +SEED=42 +NUPLAN_NUM_MAPS=4999 + +SESSION=human_align_redo +tmux new-session -d -s "$SESSION" -n exp0 + +for i in "${!COLLISION_WEIGHTS[@]}"; do + C=${COLLISION_WEIGHTS[$i]} + O=${OFFROAD_WEIGHTS[$i]} + L=${REWARD_LANE[$i]} + WIN="exp${i}" + + if [ "$i" -ne 0 ]; then + tmux new-window -t "$SESSION" -n "$WIN" + fi + + CMD="cd /workspace/ADA && \ +source .venv/bin/activate && \ +export CUDA_VISIBLE_DEVICES=$((i+4)) && \ +echo 'Running exp${i}: collision_weight_lb=$C, offroad_weight_lb=$O on GPU $i' && \ +xvfb-run -a puffer train puffer_drive \ + --wandb --wandb-project human-align-ablation \ + --tag human_ablation_lane_rewards_apr26_redo \ + --env.map-dir resources/drive/binaries/nuplan \ + --env.num-maps $NUPLAN_NUM_MAPS \ + --env.conditioning.type all \ + --env.conditioning.collision-weight-lb $C \ + --env.conditioning.collision-weight-ub 0 \ + --env.conditioning.offroad-weight-lb $O \ + --env.conditioning.offroad-weight-ub 0 \ + --env.conditioning.discount-weight-lb 0.8 \ + --env.conditioning.discount-weight-ub $DISCOUNT_UB \ + --env.reward-lane-align 0.01 \ + --env.reward-vel-align 1.0 \ + --policy-architecture Transformer \ + --train.context-length 91 \ + --train.horizon 91 \ + --train.seed $SEED \ + --eval.map-dir resources/drive/binaries/nuplan" + + tmux send-keys -t "$SESSION:$WIN" "$CMD" C-m +done + +echo "Launched ${#COLLISION_WEIGHTS[@]} runs in tmux session '$SESSION'." +echo "Attach: tmux attach -t $SESSION" +echo "Switch windows: Ctrl-b 0..7 (or Ctrl-b n / Ctrl-b p)" +echo "Detach: Ctrl-b d" diff --git a/scripts/adaptive/nuplan_recurrent.sh b/scripts/adaptive/nuplan_recurrent.sh new file mode 100755 index 0000000000..93136910f6 --- /dev/null +++ b/scripts/adaptive/nuplan_recurrent.sh @@ -0,0 +1,81 @@ +#!/bin/bash +#SBATCH --job-name=adaptive_nuplan_rnn +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=128GB +#SBATCH --time=24:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:1 +#SBATCH --array=0-15 + +# Train adaptive agents on NuPlan with Recurrent architecture +# Uses pre-trained NuPlan Recurrent co-players with varied conditioning +# +# PREREQUISITE: Train co-players first with scripts/coplayers/nuplan_recurrent.sh +# Then update ZIPPED_RUNS with the trained policy paths from wandb + +# Co-player policies trained with scripts/coplayers/nuplan_recurrent.sh +# Each entry: "policy_path entropy_weight_ub discount_weight_lb" +ZIPPED_RUNS=( + "experiments/puffer_drive_mwiatx5g.pt 0.5 0.8" + "experiments/puffer_drive_old90mw2.pt 0.1 0.8" + "TODO_FAILED 0.01 0.8" + "TODO_FAILED 0 0.8" + + "experiments/puffer_drive_hous32qj.pt 0.5 0.6" + "experiments/puffer_drive_dhxoorxc.pt 0.1 0.6" + "experiments/puffer_drive_0h5radmw.pt 0.01 0.6" + "experiments/puffer_drive_elet1zj8.pt 0 0.6" + + "experiments/puffer_drive_a57umusk.pt 0.5 0.4" + "experiments/puffer_drive_zozl26ek.pt 0.1 0.4" + "experiments/puffer_drive_jco8adma.pt 0.01 0.4" + "experiments/puffer_drive_5hhwfmmt.pt 0 0.4" + + "experiments/puffer_drive_a56dgt1x.pt 0.5 0.2" + "experiments/puffer_drive_dqzqt7qx.pt 0.1 0.2" + "experiments/puffer_drive_qetoozcn.pt 0.01 0.2" + "experiments/puffer_drive_q30rjcbu.pt 0 0.2" +) + +read -r COPLAYER_PATH ENTROPY_UB DISCOUNT_LB <<< "${ZIPPED_RUNS[$SLURM_ARRAY_TASK_ID]}" + +# Fixed values +CONDITION_TYPE="all" +DISCOUNT_UB=1 +ENTROPY_LB=0 +NUPLAN_NUM_MAPS=5000 + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! + + puffer train puffer_adaptive_drive --wandb --tag adaptive_nuplan_recurrent_k2 \ + --env.map-dir resources/drive/binaries/nuplan \ + --env.num-maps $NUPLAN_NUM_MAPS \ + --env.conditioning.type none \ + --env.co-player-enabled 1 \ + --env.co-player-policy.policy-path $COPLAYER_PATH \ + --env.co-player-policy.conditioning.type $CONDITION_TYPE \ + --env.co-player-policy.conditioning.discount-weight-lb $DISCOUNT_LB \ + --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \ + --env.co-player-policy.conditioning.entropy-weight-lb $ENTROPY_LB \ + --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB \ + --policy-architecture Recurrent \ + --rnn-name Recurrent \ + --eval.map-dir resources/drive/binaries/nuplan + + kill \$HEARTBEAT_PID + " diff --git a/scripts/adaptive/nuplan_transformer.sh b/scripts/adaptive/nuplan_transformer.sh new file mode 100755 index 0000000000..e9c105bc03 --- /dev/null +++ b/scripts/adaptive/nuplan_transformer.sh @@ -0,0 +1,78 @@ +#!/bin/bash +#SBATCH --job-name=adaptive_nuplan_tfm +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=128GB +#SBATCH --time=48:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_advanced +#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:1 +#SBATCH --array=0-8 + +# Train adaptive agents on NuPlan with Transformer architecture +# Uses pre-trained NuPlan Transformer co-players with varied conditioning +# +# PREREQUISITE: Train co-players first with scripts/coplayers/nuplan_transformer.sh +# Then update ZIPPED_RUNS with the trained policy paths from wandb + +# Co-player policies trained with scripts/coplayers/nuplan_transformer.sh +# Each entry: "policy_path entropy_weight_ub discount_weight_lb" +ZIPPED_RUNS=( +"experiments/puffer_drive_joqbmi4s.pt 0.01 0.2" +"experiments/puffer_drive_0h81rtfi.pt 0 0.4" +# "experiments/puffer_drive_medzmgum.pt 0.1 0.2" +# "experiments/puffer_drive_f4a8yoi9.pt 0.5 0.2" +"experiments/puffer_drive_b1yx43w2.pt 0.01 0.4" +# "experiments/puffer_drive_lv4x8hlt.pt 0.1 0.4" +# "experiments/puffer_drive_j51yz49e.pt 0.5 0.4" +"experiments/puffer_drive_iry1wanp.pt 0 0.6" +"experiments/puffer_drive_kx8bhu3v.pt 0.01 0.6" +"experiments/puffer_drive_js4mf85k.pt 0.1 0.6" +# "experiments/puffer_drive_52u5onve.pt 0.5 0.6" +"experiments/puffer_drive_nu7lkmx4.pt 0 0.8" +"experiments/puffer_drive_f8dpkpbq.pt 0.01 0.8" +"experiments/puffer_drive_zxsxu6z7.pt 0.1 0.8" +# "experiments/puffer_drive_mfahi5bc.pt 0.5 0.8" +) + +read -r COPLAYER_PATH ENTROPY_UB DISCOUNT_LB <<< "${ZIPPED_RUNS[$SLURM_ARRAY_TASK_ID]}" + +# Fixed values +CONDITION_TYPE="all" +DISCOUNT_UB=1 +ENTROPY_LB=0 +NUPLAN_NUM_MAPS=5000 + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! + + puffer train puffer_adaptive_drive --wandb --tag adaptive_nuplan_transformer_4apr \ + --env.map-dir resources/drive/binaries/nuplan \ + --env.num-maps $NUPLAN_NUM_MAPS \ + --env.conditioning.type none \ + --env.co-player-enabled 1 \ + --env.co-player-policy.policy-path $COPLAYER_PATH \ + --env.co-player-policy.architecture Transformer \ + --env.co-player-policy.conditioning.type $CONDITION_TYPE \ + --env.co-player-policy.conditioning.discount-weight-lb $DISCOUNT_LB \ + --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \ + --env.co-player-policy.conditioning.entropy-weight-lb $ENTROPY_LB \ + --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB \ + --policy-architecture Transformer \ + --rnn-name Transformer \ + --eval.map-dir resources/drive/binaries/nuplan + + kill \$HEARTBEAT_PID + " diff --git a/scripts/adaptive/nuplan_transformer_local.sh b/scripts/adaptive/nuplan_transformer_local.sh new file mode 100644 index 0000000000..892af6113b --- /dev/null +++ b/scripts/adaptive/nuplan_transformer_local.sh @@ -0,0 +1,111 @@ +#!/bin/bash +set -e + +# Local 4-GPU adaptive launcher (tmux pattern, mirrors human_align_local.sh). +# Trains 4 adaptive ego agents, one against each of the new local co-player +# checkpoints (collision_lb=-2, offroad_lb=-2, lane_reward=0.01). +# +# Each ego run uses a co-player conditioning band that matches what THAT +# co-player saw during its own training (so the policy is queried inside the +# distribution it learned). +# +# DEFAULT GPU ASSIGNMENT: 0,1,2,3. Override with: GPUS="4 5 6 7" bash