diff --git a/.claude/commands/add-model-support.md b/.claude/commands/add-model-support.md new file mode 100644 index 000000000..8508ff64e --- /dev/null +++ b/.claude/commands/add-model-support.md @@ -0,0 +1,53 @@ +--- +description: Guided workflow for adding a new architecture adapter to TransformerBridge. +argument-hint: +--- + +Adding TransformerBridge support for HF model `$ARGUMENTS`. If empty, ask the user for the HF repo path first. + +Each step names the doc to read **when you reach that step** — don't load all up front. + +1. **Check registry state and decide whether to verify.** + + State: + - Architecture supported? Check `SUPPORTED_ARCHITECTURES` in [`architecture_adapter_factory.py`](../../transformer_lens/factories/architecture_adapter_factory.py). + - Model in registry? Check [`supported_models.json`](../../transformer_lens/tools/model_registry/data/supported_models.json); note `status` (0=unverified, 1=verified, 2=skipped, 3=failed). + + Branch: + + - **Supported AND `status==1`** → already verified. Ask the user the symptom (bug-report path, not add-support). Stop. + - **Supported, `status != 1`** → proceed to **Confirm before verification**. If `status==3`, read existing `note` for the prior failure mode. + - **Supported, not in registry** → add an entry per [§Adding the HF repo to the registry](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#adding-the-hf-repo-to-the-registry) with `status: 0` and null scores, then proceed. + - **Not supported** → skip to step 2. + + ### Confirm before verification + + Always ask the user first, even for small models: + + 1. Dry-run to project cost: + ``` + set -a; source .env; set +a + uv run python -m transformer_lens.tools.model_registry.verify_models --model "$ARGUMENTS" --dry-run + ``` + 2. 
Show: model ID, architecture class, estimated parameters, projected memory (GB), HF_TOKEN needed?, runtime (30 s–2 min sub-1B, 2–15 min 1B–7B, 15+ min 7B+/multimodal), what verification does (Phases 1–4; updates `supported_models.json` on success). + 3. Ask: "Run verification on this machine? (Y/N)" + + **Confirm** → `/verify-model $ARGUMENTS`. On pass, done. On fail, see [debugging_numerical_divergence.md](../../docs/source/content/debugging_numerical_divergence.md) (per-sibling adapter bug). + + **Reject** → `gh issue create --template verify-model.md` (fill from dry-run output). No `gh`? . Stop. + +2. **Analyze the HF model.** Read `config.json` and source — identify embedding, attention, MLP, normalization, output-head layouts. Read [§Config-attr propagation](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#config-attr-propagation) and decide which non-standard attrs (`final_logit_softcapping`, `sliding_window`, etc.) need surfacing on `self.cfg`. + +3. **Pick a starting adapter.** See [§Starter-adapter table](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#starter-adapter-table). Copy into [`supported_architectures/`](../../transformer_lens/model_bridge/supported_architectures/) as `.py`. **Tokenizer-policy flags are per-model** — see [§Tokenizer policy](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#tokenizer-policy). + +4. **Fill `self.component_mapping`.** Bridge-native hook names. Reference: [§Minimal contract](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#minimal-contract), [§Common gotchas](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#common-gotchas). + +5. **Register in all four sites** per [§Registration steps](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#registration-steps). Then run the invariant test: `uv run pytest tests/unit/tools/test_model_registry.py -k TestRegistrySyncedWithFactory`. + +6. 
**Add the HF repo entry** to [`data/supported_models.json`](../../transformer_lens/tools/model_registry/data/supported_models.json) per [§Adding the HF repo to the registry](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#adding-the-hf-repo-to-the-registry). Ask the user about adding canonical sibling variants from `CANONICAL_AUTHORS_BY_ARCH[]`. + +7. **Verify** end-to-end: `/verify-model $ARGUMENTS`. Read both `status` AND per-phase scores. `STATUS_VERIFIED` means hard gates passed (see [§Phase-score thresholds](../../transformer_lens/tools/model_registry/AGENTS.md#phase-score-thresholds)) — but P4's 50% bar is intentionally lenient. P4 well below 100% on a small parity-test model + `status==1` → suspect missing `preprocess_weights` fold or wrong `default_prepend_bos`; investigate before step 8. + +8. **Write tests** per [§Required tests](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#required-tests) (unit + integration). Copy the closest sibling. + +9. **`/task-complete`** — comment cleanup, `/format`, standard test tiers, loop until clean. diff --git a/.claude/commands/build-docs.md b/.claude/commands/build-docs.md new file mode 100644 index 000000000..ffa90f979 --- /dev/null +++ b/.claude/commands/build-docs.md @@ -0,0 +1,16 @@ +--- +description: Source .env then build the Sphinx docs. +--- + +Build the documentation locally: + +``` +set -a; source .env; set +a +uv run build-docs +``` + +Sourcing `.env` is required so `HF_TOKEN` is available — some doctests and notebook embeddings load gated models. Output goes to [docs/build/](../../docs/build/). + +For an interactive live-reloading preview instead, run `uv run docs-hot-reload`. + +Docs follow Google docstring style with reST extensions; see [docs/source/content/contributing.md](../../docs/source/content/contributing.md) for the style guide. 
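As a rough illustration of the docstring convention mentioned above, a Google-style docstring with an embedded doctest might look like the following. The function `format_hook_name` is hypothetical and exists only for this example; the reST extensions that the style guide refers to are not shown here.

```python
def format_hook_name(layer: int, suffix: str) -> str:
    """Build a block-level hook name.

    Args:
        layer: Zero-based block index.
        suffix: Hook suffix within the block, e.g. ``"hook_out"``.

    Returns:
        The dotted hook name for that block.

    Raises:
        ValueError: If ``layer`` is negative.

    Example:
        >>> format_hook_name(3, "hook_out")
        'blocks.3.hook_out'
    """
    # Hypothetical helper: not part of TransformerLens.
    if layer < 0:
        raise ValueError("layer must be non-negative")
    return f"blocks.{layer}.{suffix}"
```

The `Example:` section doubles as a doctest, which is the kind of snippet a docstring test tier would execute.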
diff --git a/.claude/commands/format.md b/.claude/commands/format.md new file mode 100644 index 000000000..fec0d2dda --- /dev/null +++ b/.claude/commands/format.md @@ -0,0 +1,14 @@ +--- +description: Type-check then format the working tree. +--- + +Run mypy first, then format. Mypy fixes (`isinstance`, `cast`, signatures) can introduce format drift — running format after means a single pass. + +``` +uv run mypy . +make format +``` + +`uv run mypy .` uses the config in [pyproject.toml](../../pyproject.toml). `make format` runs `pycln --all` (unused imports), `isort`, and `black` (line length 100). + +If mypy reports errors, fix the underlying typing issue — never add `# type: ignore`. Prefer `isinstance` / `typing.cast` ([AGENTS.md §10](../../AGENTS.md#10-hard-rules)). diff --git a/.claude/commands/task-complete.md b/.claude/commands/task-complete.md new file mode 100644 index 000000000..318a5bb91 --- /dev/null +++ b/.claude/commands/task-complete.md @@ -0,0 +1,46 @@ +--- +description: End-of-task gate. Clean up new comments, format, type-check, and run the standard test tiers (unit + docstring + acceptance + integration) — fixing issues along the way. +--- + +Run the end-of-task gate. Do not declare the task complete until every step below passes cleanly. + +### 1. Clean up new comments + +Review every comment and docstring **added or modified during this task** against the rules in [AGENTS.md §10](../../AGENTS.md#10-hard-rules): + +- Comments should be terse one-liners; docstrings are one-line where possible. +- Inline comments explain WHY, not WHAT — delete any that just restate the code. +- Multi-paragraph explanations belong in PR descriptions or design docs, not source. +- Remove any references to plan files, audit IDs, finding IDs, or "see plan section X" — those rot as the codebase evolves and belong only in the PR description. 
+ +Use `git diff` against the merge-base to scope the review to genuinely new comments — do NOT rewrite unrelated comments elsewhere in the file. + +### 2. Type-check, then format + +Run mypy **before** format. Mypy fixes (`isinstance`, `typing.cast`, signature changes) can introduce format drift — running format after mypy means a single format pass. + +``` +uv run mypy . +make format +``` + +If mypy reports new errors, fix the underlying typing issue. Do not add `# type: ignore`. + +### 3. Run the standard test tiers + +``` +set -a; source .env; set +a +make test-pr +``` + +`make test-pr` runs unit + docstring + acceptance + integration — the tiers that gate PR review for almost every change. Notebook and benchmark suites are intentionally skipped (slow, gated models, CI runs them separately). If your change specifically touched a notebook or a benchmark, also run that file directly (`pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/.ipynb` or `make benchmark-test`). + +Investigate every failure. Do not dismiss any failure as "pre-existing" or "unrelated" — fix the underlying issue, even if it predates this task (see [AGENTS.md §10](../../AGENTS.md#10-hard-rules)). Do not add platform skips or `xfail` markers to dodge a failing test. + +### 4. Re-loop on failure + +If any step surfaces issues, fix them and restart from step 1 — fixes can reintroduce comment, format, type, or test drift. + +### 5. Report + +Report the actual final command output, not a summary. Reviewers re-run tests; agent self-reports are not evidence ([AGENTS.md §10](../../AGENTS.md#10-hard-rules)). diff --git a/.claude/commands/test-all.md b/.claude/commands/test-all.md new file mode 100644 index 000000000..583eb08a7 --- /dev/null +++ b/.claude/commands/test-all.md @@ -0,0 +1,16 @@ +--- +description: Run the full test suite (unit + integration + acceptance + benchmark + docstring + notebook). Slow. 
+--- + +Run every test tier in TransformerLens via the top-level `make test` target: + +``` +make test +``` + +This is slow — it runs unit, integration, acceptance, benchmark, docstring, and notebook tests sequentially. It hits HuggingFace Hub and loads multiple models. Before running, confirm: + +1. `.env` is sourced so `HF_TOKEN` is set (`set -a; source .env; set +a`). +2. No other heavy GPU/MPS jobs are running on this machine — model verification cannot run concurrently (see [AGENTS.md §10](../../AGENTS.md#10-hard-rules)). + +Report the actual command output. Investigate any failures rather than dismissing them. diff --git a/.claude/commands/test-unit.md b/.claude/commands/test-unit.md new file mode 100644 index 000000000..11274ff81 --- /dev/null +++ b/.claude/commands/test-unit.md @@ -0,0 +1,11 @@ +--- +description: Run the unit test suite. +--- + +Run the TransformerLens unit tests: + +``` +make unit-test +``` + +If any test fails, investigate the failure rather than dismissing it as "pre-existing" or unrelated — see [AGENTS.md §10](../../AGENTS.md#10-hard-rules). Report the actual command output, not a summary. diff --git a/.claude/commands/typecheck.md b/.claude/commands/typecheck.md new file mode 100644 index 000000000..68dd34921 --- /dev/null +++ b/.claude/commands/typecheck.md @@ -0,0 +1,11 @@ +--- +description: Run mypy across the project. +--- + +Run the type checker: + +``` +uv run mypy . +``` + +Config lives in `[tool.mypy]` of [pyproject.toml](../../pyproject.toml). If mypy reports errors, fix the underlying typing issue — do not add `# type: ignore`. Prefer `isinstance` assertions or `typing.cast` for narrowing. diff --git a/.claude/commands/verify-model.md b/.claude/commands/verify-model.md new file mode 100644 index 000000000..8eec7501a --- /dev/null +++ b/.claude/commands/verify-model.md @@ -0,0 +1,69 @@ +--- +description: Run verify_models.py against a single model (non-parallel). Always dry-run first. 
+argument-hint: +--- + +Verify model `$ARGUMENTS`. If empty, ask for an HF repo path (e.g. `gpt2`, `meta-llama/Llama-2-7b-hf`) or registry alias. + +## Always dry-run first + +Verification loads the full model and runs Phases 1–4 — 30 s to 30 min, needs memory to hold the model. **Never invoke the real run blindly.** + +``` +set -a; source .env; set +a +uv run python -m transformer_lens.tools.model_registry.verify_models --model "$ARGUMENTS" --dry-run +``` + +Capture: estimated parameter count, projected memory (GB), HF_TOKEN requirement, architecture class. + +| Model | Action | +|---|---| +| Cached small (`gpt2`, `attn-only-*`, `tiny-stories-1M`, `distilgpt2`, …) | Proceed; report dry-run in your response so user can intervene | +| ≥1B params, gated, or anything else | Present dry-run, ask before running | + +## Run the verification + +``` +set -a; source .env; set +a +uv run python -m transformer_lens.tools.model_registry.verify_models --model "$ARGUMENTS" +``` + +## Optional flags + +Full reference: [tools/model_registry/AGENTS.md §Flag reference](../../transformer_lens/tools/model_registry/AGENTS.md#flag-reference). + +- `--device cpu|cuda|mps` — override device selection +- `--dtype float32|bfloat16` — override dtype +- `--max-memory ` — skip if param estimate exceeds; e.g. `16` on a 24 GB GPU leaves headroom for activations +- `--phases 1 2 3` — restrict (P4 is slowest; restrict when debugging P1 forward parity) +- `--dry-run` — see above; always first +- `--no-hf-reference` / `--no-ht-reference` — skip HF / HT comparison (faster, lower confidence) +- `--reverify` — re-test `status==1` +- `--retry-failed` — re-test `status==3` (read existing `note` first) + +Batch flags (`--architectures`, `--per-arch`, `--limit`, `--resume`) don't apply to `--model ` — use [§Canonical invocations](../../transformer_lens/tools/model_registry/AGENTS.md#canonical-invocations). 
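The dry-run-first sequence and the optional flags above can be sketched as a small command builder. This is a minimal sketch, not part of the repository: the helper name `build_verify_command` is an assumption, and only flags listed in this file are used.

```python
import shlex
from typing import List, Optional, Sequence

MODULE = "transformer_lens.tools.model_registry.verify_models"


def build_verify_command(
    model: str,
    *,
    dry_run: bool = False,
    device: Optional[str] = None,
    dtype: Optional[str] = None,
    phases: Optional[Sequence[int]] = None,
) -> List[str]:
    """Assemble an argv list for a single-model verify_models run (hypothetical helper)."""
    cmd = ["uv", "run", "python", "-m", MODULE, "--model", model]
    if device is not None:
        cmd += ["--device", device]
    if dtype is not None:
        cmd += ["--dtype", dtype]
    if phases:
        # --phases takes space-separated phase numbers, e.g. "--phases 1 2 3".
        cmd += ["--phases", *[str(p) for p in phases]]
    if dry_run:
        cmd.append("--dry-run")
    return cmd


if __name__ == "__main__":
    # Print the dry run first, then the real run restricted to phases 1-3.
    print(shlex.join(build_verify_command("gpt2", dry_run=True)))
    print(shlex.join(build_verify_command("gpt2", phases=[1, 2, 3])))
```

Building an argv list rather than a shell string avoids quoting problems with repo paths such as `meta-llama/Llama-2-7b-hf`; sourcing `.env` still has to happen in the calling shell.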
+ +## Interpreting the output + +Hard thresholds (`_MIN_PHASE_SCORES` in `verify_models.py`): + +| Phase | Min score | Required tests | Below = | +|---|---|---|---| +| 1 | 100% | — | `STATUS_FAILED` | +| 2 | 75% | `logits_equivalence`, `loss_equivalence` | `STATUS_FAILED` | +| 3 | 75% | `logits_equivalence`, `loss_equivalence` | `STATUS_FAILED` | +| 4 | 50% | — | **Non-gating** — adds `"low text quality"` to `note`; never fails. | +| 7 | 75% | `multimodal_forward` | `STATUS_FAILED`. NULL = fail. | +| 8 | 75% | `audio_forward` | `STATUS_FAILED`. NULL = fail. | + +`STATUS_VERIFIED` means hard gates passed. `note` carries quality flags or failure details. + +**Adapter-author caveat:** P4's 50% bar is intentionally lenient (coherence, not correctness). P4 well below 100% on a small parity-test model can indicate a real bug the system doesn't gate on — most often a missing [`preprocess_weights` fold](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#when-to-override-preprocess_weights) or wrong [`default_prepend_bos`](../../transformer_lens/model_bridge/supported_architectures/AGENTS.md#tokenizer-policy). Investigate even on VERIFIED. + +Full reference: [§Phase-score thresholds](../../transformer_lens/tools/model_registry/AGENTS.md#phase-score-thresholds). + +## Hard rules + +**Use `verify_models`, never `main_benchmark`** — only `verify_models` writes `data/supported_models.json` ([tools/model_registry/AGENTS.md](../../transformer_lens/tools/model_registry/AGENTS.md)). + +One model at a time — concurrent loads OOM. Report actual per-phase scores; investigate failures per [AGENTS.md §10](../../AGENTS.md#10-hard-rules). 
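The gating logic described by the threshold table under "Interpreting the output" can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in `verify_models.py`: the function `classify` is hypothetical, the integer values for `STATUS_VERIFIED` and `STATUS_FAILED` are inferred from the registry status codes (1 and 3), the required-test checks for phases 2, 3, 7, and 8 are omitted, and a phase is only evaluated when it appears in the input.

```python
from typing import Dict, List, Optional, Tuple

STATUS_VERIFIED = 1  # assumed to match status==1 in the registry
STATUS_FAILED = 3    # assumed to match status==3 in the registry

# Minimum percentage per phase, copied from the threshold table.
MIN_PHASE_SCORES: Dict[int, float] = {1: 100.0, 2: 75.0, 3: 75.0, 4: 50.0, 7: 75.0, 8: 75.0}
NON_GATING_PHASES = {4}     # below-threshold only adds a note
NULL_FAILS_PHASES = {7, 8}  # a missing (None) score counts as failure


def classify(scores: Dict[int, Optional[float]]) -> Tuple[int, List[str]]:
    """Return (status, notes) for the phases present in ``scores``."""
    status = STATUS_VERIFIED
    notes: List[str] = []
    for phase in sorted(scores):
        minimum = MIN_PHASE_SCORES.get(phase)
        if minimum is None:
            continue  # phase not covered by the table
        score = scores[phase]
        if score is None:
            if phase in NULL_FAILS_PHASES:
                status = STATUS_FAILED
                notes.append(f"phase {phase}: missing score")
            continue
        if score >= minimum:
            continue
        if phase in NON_GATING_PHASES:
            notes.append("low text quality")
        else:
            status = STATUS_FAILED
            notes.append(f"phase {phase}: {score:.1f}% < {minimum:.1f}%")
    return status, notes
```

Note how a phase 4 score of 40% still yields `STATUS_VERIFIED` with a note attached, which is why the adapter-author caveat asks for per-phase scores to be read alongside `status`.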
diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 000000000..992d4abaf --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,6 @@ +{ + "permissions": { + "allow": [], + "deny": [] + } +} diff --git a/.cursor/rules/transformerlens.mdc b/.cursor/rules/transformerlens.mdc new file mode 100644 index 000000000..845d076da --- /dev/null +++ b/.cursor/rules/transformerlens.mdc @@ -0,0 +1,21 @@ +--- +description: TransformerLens project conventions for Cursor agents. +alwaysApply: true +--- + +Read `AGENTS.md` at the repo root before doing any work. It is the single source of truth for project conventions, quickstart commands, repo layout, hook-naming rules, the HookedTransformer ↔ TransformerBridge mirroring rule, PR conventions, and hard rules. + +Sub-folder `AGENTS.md` files apply when you're working in those directories — read them too: + +- `tests/AGENTS.md` — tier placement, conftest hierarchy, MPS rules +- `transformer_lens/model_bridge/supported_architectures/AGENTS.md` — adapter contract, starter-adapter table, 4-place registration +- `transformer_lens/tools/model_registry/AGENTS.md` — `verify_models` workflow, the `main_benchmark` trap + +Quick reminders that override common defaults: + +- Use `uv`, not `pip` or `poetry`. Install with `uv sync`; run commands with `uv run …` or `make` targets. +- This repo has two parallel systems (`HookedTransformer` legacy and `TransformerBridge` v3). Changes to HookedTransformer that have equivalents in TransformerBridge must be mirrored to TransformerBridge. +- Base PRs against `dev`, not `main`. Never name a branch `main` or `dev`. +- No pre-commit hook is installed. Run `make format` and `uv run mypy .` manually before push. +- Source `.env` (e.g. `set -a; source .env; set +a`) before any HuggingFace-Hub-hitting command. 
+- Never add `# type: ignore`, never dismiss failing tests as "pre-existing", never add platform skips to dodge CI, never claim drift is "fp noise" without empirical evidence. diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..19a8b9a2e --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# TransformerLens environment variables. +# Copy to .env (gitignored) and fill in. Then source before HF-Hub-hitting commands: +# set -a; source .env; set +a + +# HuggingFace access token. Required for gated models: Llama, Mistral, Gemma, +# gated Qwen variants, etc. Create at https://huggingface.co/settings/tokens +HF_TOKEN= + +# Optional: retry HuggingFace Hub 429s. Only matters for ad-hoc non-pytest scripts +# (docs build, demo notebooks, one-off `python -m ...` invocations). The pytest +# suite already enables retries unconditionally via tests/conftest.py's +# _enable_hf_retry_for_tests fixture, and CI sets this env var in checks.yml. +# TRANSFORMERLENS_HF_RETRY=1 + +# Optional: allow MPS device on macOS (off by default to avoid hard-to-debug +# divergence from CUDA / CPU). CI's mps-checks job sets this. +# TRANSFORMERLENS_ALLOW_MPS=1 + +# Optional: silence the tokenizers parallelism warning when running under uv. 
+# TOKENIZERS_PARALLELISM=false diff --git a/.github/ISSUE_TEMPLATE/verify-model.md b/.github/ISSUE_TEMPLATE/verify-model.md new file mode 100644 index 000000000..3e8292a53 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/verify-model.md @@ -0,0 +1,51 @@ +--- +name: Verify model support +about: Track a request to run verify_models on a specific model that someone without appropriate hardware can't run themselves +title: "[Verify Model] org/model-id" +labels: verification-request + +--- + + + +## Model + +- HF repo: `https://huggingface.co/REPLACE_WITH_MODEL_ID` +- Architecture class (from `config.architectures[0]`): `REPLACE_WITH_HF_ARCH_CLASS` +- Estimated parameters: `REPLACE_WITH_PARAM_COUNT` +- Projected memory: `REPLACE_WITH_GB` GB (from `verify_models --dry-run`) +- Gated repo: `yes / no` (HF_TOKEN required: `yes / no`) + +## Registry state + +- Architecture adapter exists: `yes` (file: `transformer_lens/model_bridge/supported_architectures/REPLACE_WITH_ADAPTER.py`) +- Currently in `data/supported_models.json`: + - [ ] No — entry added in the PR linked from this issue + - [ ] Yes, status: `REPLACE_WITH_STATUS` (0=unverified, 2=skipped, 3=failed) + +## Motivation + + + + +## How to run + +```bash +set -a; source .env; set +a +uv run python -m transformer_lens.tools.model_registry.verify_models --model REPLACE_WITH_MODEL_ID +``` + +For the full workflow, see [Creating Architecture Adapters in contributing.md](../../docs/source/content/contributing.md#creating-architecture-adapters). 
+ +## Result + + diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 000000000..bcb46e53e --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,26 @@ +# GitHub Copilot Instructions + +**Read [AGENTS.md](../AGENTS.md) at the repo root for the full set of project conventions, quickstart commands, repo layout, and hard rules — start with its TL;DR section.** This file inlines the highest-friction defaults Copilot most often gets wrong. + +## Top rules to remember + +1. **Use `uv`, not `pip` or `poetry`.** `uv sync` to install; `uv run ` or a `make` target to run anything. +2. **Mirror `HookedTransformer` → `TransformerBridge`** in the same PR when behaviour exists in both. The HT registry [`transformer_lens/supported_models.py`](../transformer_lens/supported_models.py) is HT-only — Bridge-only models go in the Bridge registry under [`transformer_lens/tools/model_registry/`](../transformer_lens/tools/model_registry/). +3. **Base PRs against `dev`**, not `main`. PRs to `main` are maintainer-only. + +## Common commands + +```bash +uv sync # install +make unit-test # fast tests +make format # pycln + isort + black +uv run mypy . # type check +uv run docs-hot-reload # live docs preview +``` + +## Copilot-specific anti-patterns + +- Don't add `# type: ignore`. Prefer `isinstance` / `typing.cast`. +- Don't dismiss failing tests as "pre-existing" — investigate every failure. + +The full set of hard rules (numerics, parallel benchmarks, plan-file references, etc.) lives in [AGENTS.md §10](../AGENTS.md#10-hard-rules). 
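To make the `# type: ignore` anti-pattern concrete, here is a minimal sketch of the two preferred alternatives. The function names and the `n_layers` key are illustrative assumptions, not code from the repository.

```python
from typing import Any, Union, cast


def token_count(value: Union[str, list[str]]) -> int:
    """Count tokens in a whitespace-separated string or a token list."""
    # isinstance narrows the union, so mypy accepts both branches without an ignore.
    if isinstance(value, str):
        return len(value.split())
    return len(value)


def read_n_layers(raw_config: dict[str, Any]) -> int:
    """Read an integer field from an untyped config mapping."""
    # cast records an invariant mypy cannot infer; it performs no runtime check.
    return cast(int, raw_config["n_layers"])
```

`isinstance` is usually the safer of the two because it is checked at runtime, whereas `typing.cast` only informs the type checker.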
diff --git a/.gitignore b/.gitignore index 5b2752d66..570bd085c 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,8 @@ docs/source/generated .venv .env .adapter-workspace -.claude +.claude/* +!.claude/settings.json +!.claude/commands/ .adapter-progress.json transformer_lens/tools/model_registry/data/verification_checkpoint.json diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..10fbef1a2 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,245 @@ +# AGENTS.md + +Guidance for AI coding agents contributing to **TransformerLens**. + +This file is the single source of truth. Vendor-specific files ([CLAUDE.md](CLAUDE.md), [.github/copilot-instructions.md](.github/copilot-instructions.md), [.cursor/rules/transformerlens.mdc](.cursor/rules/transformerlens.mdc)) and sub-folder `AGENTS.md` files all defer here. + +> **Just want to use the library?** Start with the [README](README.md). **Just want to run the tests?** Skip to [§3 Quickstart](#3-quickstart). **Contributing?** Read on. + +## TL;DR + +1. **Use `uv`**, not `pip` or `poetry` (`uv sync`). +2. **Source `.env`** (`set -a; source .env; set +a`) before any HF-Hub command. +3. **Base PRs on `dev`**, not `main`. Never name a branch `main` or `dev`. +4. **Mirror HookedTransformer → TransformerBridge** when behaviour exists in both ([§2](#2-two-systems-live-in-this-repo)). +5. **`make format` + `uv run mypy .` before push** — no pre-commit hook. +6. **Never add `# type: ignore`** ([§10](#10-hard-rules)). +7. **Never dismiss a failing test as "pre-existing"** ([§10](#10-hard-rules)). + +Sub-folder rules: [tests/AGENTS.md](tests/AGENTS.md) · [supported_architectures/AGENTS.md](transformer_lens/model_bridge/supported_architectures/AGENTS.md) · [tools/model_registry/AGENTS.md](transformer_lens/tools/model_registry/AGENTS.md). + +--- + +## 1. What this repo is + +**TransformerLens** — mechanistic-interpretability library. 
Loads 9,000+ models across 50+ architecture families (see [supported_models.json](transformer_lens/tools/model_registry/data/supported_models.json)) and exposes internal activations through a hook system for caching, editing, and ablating intermediate state. Built on HuggingFace `transformers`. + +## 2. Two systems live in this repo + +| System | Status | Lives in | Numerics | Registry | +|---|---|---|---|---| +| **`TransformerBridge`** | v3 — default for new work | [transformer_lens/model_bridge/](transformer_lens/model_bridge/) | Raw HF weights by default; `bridge.enable_compatibility_mode()` for HT-equivalent | [transformer_lens/tools/model_registry/data/supported_models.json](transformer_lens/tools/model_registry/data/supported_models.json) | +| **`HookedTransformer`** | Legacy, maintenance mode, deprecated in 3.0 | [transformer_lens/HookedTransformer.py](transformer_lens/HookedTransformer.py) + [transformer_lens/components/](transformer_lens/components/) | Folds LayerNorm + centres weights → does NOT match HF | [transformer_lens/supported_models.py](transformer_lens/supported_models.py) (**HT-only**) | + +> ⚠ The **HookedTransformer acceptance suite is quarantined** ([test_hooked_transformer.py](tests/acceptance/test_hooked_transformer.py), [test_hooked_encoder.py](tests/acceptance/test_hooked_encoder.py), [test_hooked_encoder_decoder.py](tests/acceptance/test_hooked_encoder_decoder.py); see [QUARANTINES.md](tests/QUARANTINES.md)). HT changes land untested at the acceptance level — extra manual care required. + +Bridge architecture-adapter pattern: each HF architecture has one file in [supported_architectures/](transformer_lens/model_bridge/supported_architectures/) mapping HF module paths to canonical names. Bridge hooks are architecture-native (e.g. `blocks.{i}.hook_out`); HT-style aliases live in [bridge.py](transformer_lens/model_bridge/bridge.py). 
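The relationship between HT-style aliases and Bridge-native hook names can be pictured with a toy resolver. This is only a sketch: `to_bridge_name` and the alias table are hypothetical, the two alias pairs are assumptions chosen for illustration, and the real mapping lives in `bridge.py`.

```python
import re

# Illustrative alias table: HT-style suffix -> Bridge-native suffix.
# These pairs are assumptions for the example, not read from bridge.py.
ALIAS_SUFFIXES = {
    "attn.hook_q": "attn.q.hook_out",
    "hook_resid_post": "hook_out",
}

_BLOCK_NAME = re.compile(r"^blocks\.(\d+)\.(.+)$")


def to_bridge_name(name: str) -> str:
    """Translate an HT-style block hook name to its Bridge-native form."""
    match = _BLOCK_NAME.match(name)
    if match is None:
        return name  # not a per-block hook; leave unchanged
    layer, suffix = match.groups()
    # Unknown suffixes pass through, so Bridge-native names are left as-is.
    return f"blocks.{layer}.{ALIAS_SUFFIXES.get(suffix, suffix)}"
```

A resolver of this shape lets legacy hook names keep working while new code uses the Bridge-native names directly.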
+ +**Mirroring rule:** if you change `HookedTransformer` behaviour that has a `TransformerBridge` counterpart, update both in the same PR. [supported_models.py](transformer_lens/supported_models.py) is HT-only — Bridge-only models go in the Bridge registry data file. + +## 3. Quickstart + +```bash +# Bootstrap uv if missing +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install — uv only (not pip, not poetry) +uv sync +source .venv/bin/activate + +# First-time only: create .env from the template, then fill in HF_TOKEN +cp .env.example .env + +# Source HF token before any HF-Hub-hitting command +set -a; source .env; set +a + +# Tests +make unit-test # fast, no model loads +make integration-test # cross-component +make acceptance-test # end-to-end +make docstring-test # doctest + doctest-plus +make notebook-test # slow; subset run in CI +make test-pr # unit + docstring + acceptance + integration (PR-review surface) +make test # everything (long; includes benchmarks + notebooks) + +# Format + typecheck — no pre-commit hook; run manually before push +make format # pycln + isort + black +make check-format # CI-equivalent check +uv run mypy . + +# Docs +uv run docs-hot-reload # live preview +uv run build-docs # build to docs/build/ +``` + +Python: **>=3.10, <4.0**. CI tests 3.10, 3.11, 3.12. Format/type/docstring checks run on 3.12. + +**Windows:** use WSL2 — the bash invocations above (`source`, `set -a; source .env; set +a`, heredocs in slash commands) don't translate to PowerShell. + +## 4. 
Repo map + +| Path | What's there | +|---|---| +| [transformer_lens/](transformer_lens/) | Core package | +| [transformer_lens/HookedTransformer.py](transformer_lens/HookedTransformer.py) | Legacy `HookedTransformer` API | +| [transformer_lens/HookedEncoder.py](transformer_lens/HookedEncoder.py), [HookedEncoderDecoder.py](transformer_lens/HookedEncoderDecoder.py), [HookedAudioEncoder.py](transformer_lens/HookedAudioEncoder.py) | Encoder-only / seq2seq / audio variants | +| [transformer_lens/model_bridge/](transformer_lens/model_bridge/) | `TransformerBridge` system | +| [transformer_lens/model_bridge/supported_architectures/](transformer_lens/model_bridge/supported_architectures/) | One adapter file per HF architecture | +| [transformer_lens/model_bridge/generalized_components/](transformer_lens/model_bridge/generalized_components/) | Bridge-side reusable components | +| [transformer_lens/components/](transformer_lens/components/) | HT-side components (attention, MLP, LN, embed) | +| [transformer_lens/factories/](transformer_lens/factories/) | `architecture_adapter_factory.py`, `mlp_factory.py`, `activation_function_factory.py` | +| [transformer_lens/config/](transformer_lens/config/) | `HookedTransformerConfig` and `TransformerBridgeConfig` | +| [transformer_lens/utilities/](transformer_lens/utilities/) | Device management, weight processing, HF utilities | +| [transformer_lens/hook_points.py](transformer_lens/hook_points.py) | `HookPoint` class and `LensHandle` | +| [transformer_lens/supported_models.py](transformer_lens/supported_models.py) | **HT-only** registry (`OFFICIAL_MODEL_NAMES`, `MODEL_ALIASES`) | +| [transformer_lens/tools/model_registry/](transformer_lens/tools/model_registry/) | Bridge-side registry + `verify_models.py` benchmark suite | +| [transformer_lens/patching.py](transformer_lens/patching.py), [evals.py](transformer_lens/evals.py) | Activation patching, IOI, ROME, etc. 
| +| [tests/unit/](tests/unit/), [tests/integration/](tests/integration/), [tests/acceptance/](tests/acceptance/), [tests/benchmarks/](tests/benchmarks/), [tests/mps/](tests/mps/) | Test tiers | +| [demos/](demos/) | Jupyter notebooks; a subset runs in CI under `nbval` with sanitization from [demos/doc_sanitize.cfg](demos/doc_sanitize.cfg) | +| [docs/source/content/](docs/source/content/) | Sphinx markdown sources | +| [docs/source/content/adapter_development/](docs/source/content/adapter_development/) | Adapter-authoring guides — read these before adding a new architecture | +| [makefile](makefile) | Canonical test/format/docs targets | +| [pyproject.toml](pyproject.toml) | Deps, pytest / mypy / format / build config | +| [.github/workflows/checks.yml](.github/workflows/checks.yml) | CI gates | +| [.github/PULL_REQUEST_TEMPLATE.md](.github/PULL_REQUEST_TEMPLATE.md) | PR template + base-branch rules | + +## 5. Hook naming — HT vs Bridge + +- **HT canonical**: uniform across architectures — `hook_embed`, `blocks.{i}.hook_resid_pre`, `blocks.{i}.attn.hook_q`, `blocks.{i}.hook_resid_post`. +- **Bridge-native**: architecture-shaped — `blocks.{i}.hook_out`, `blocks.{i}.attn.q.hook_out`. HT aliases registered via `build_alias_to_canonical_map()` in [bridge.py](transformer_lens/model_bridge/bridge.py). + +Prefer Bridge-native names in new code. Raw-HF-forward drivers comparing against `boot_transformers` must match its load configuration (fp32, eager attention) and probe for optional features like `resid_mid` rather than assume. + +## 6. Adding a model + +Adapters are written **per architecture family**, not per individual model — adding `gpt2` registers all GPT-2 variants. Full workflow (starter-adapter table, 4-place registration, common gotchas, anti-patterns): **[supported_architectures/AGENTS.md](transformer_lens/model_bridge/supported_architectures/AGENTS.md)**. Verification flow: **[tools/model_registry/AGENTS.md](transformer_lens/tools/model_registry/AGENTS.md)**. 
Claude Code users: invoke `/add-model-support `. + +## 7. Prioritization + +When picking solutions to any problem, prioritize by **research impact**, not implementation ease. A correct, broadly-applicable feature is worth more than a one-off shortcut. + +## 8. PR conventions + +**Base branch**: +- `dev` — default for all PRs (new features, refactors, docs, most bug fixes). +- `main` — **only** for bug fixes against the currently-released version. PRs to `main` are made by maintainers; request permission before basing off `main`. + +**Branch naming**: +- Do NOT name your branch `main` or `dev` — these conflict with the canonical branches when maintainers periodically refresh PRs against upstream. +- Include the word `docs` in the branch name if your PR is primarily a docs change (this triggers the docs build job). +- New branches must track their own remote — `git push -u origin ` from your branch, not from `main`/`dev`. + +**Pre-push checks**: no pre-commit hook — `make format` + `uv run mypy .` manually. + +**PR template**: [.github/PULL_REQUEST_TEMPLATE.md](.github/PULL_REQUEST_TEMPLATE.md). No conventional-commits enforcement. + +**Changelog**: no per-PR file. Note user-facing changes in your PR description and commit messages — release essays in [news/](docs/source/content/news/) and GitHub Releases are drafted by a maintainer from those at release time. + +## 9. CI gates a PR must pass + +From [.github/workflows/checks.yml](.github/workflows/checks.yml): + +| Job | Runs | When it fails, run locally | +|---|---|---| +| `compatibility-checks` | `make unit-test` + `make acceptance-test` + `uv build` × py 3.10 / 3.11 / 3.12 | `make unit-test` / `make acceptance-test` / `uv build`. 
Repro Python-specific failures with `uv python install ` then `uv run --python pytest …` | +| `mps-checks` | macOS MPS unit + integration + smoke (PRs targeting `main` or pushes to `main` only) | `TRANSFORMERLENS_ALLOW_MPS=1 uv run pytest tests/mps` on a Mac with MPS | +| `format-check` | `make check-format` | `make format` (writes the fix) then commit | +| `type-check` | `uv run mypy .` | `uv run mypy .` — fix the error; never add `# type: ignore` | +| `docstring-test` | `make docstring-test` | `make docstring-test` — failures are doctest mismatches in `transformer_lens/` | +| `coverage-test` | Full suite + coverage artifact | `make test-pr` covers most of this; full reproduction with `make coverage-report-test` | +| `notebook-checks` | `nbval` over subset of `demos/*.ipynb`; `HF_TOKEN`-gated notebooks skip when secret absent | `uv run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/.ipynb` | +| `build-docs` | Sphinx build (push to `main`/`dev` or branch containing `docs`) | `uv run build-docs` (sources `.env` if needed for HF-gated doctest examples) | +| `deploy-docs` | GitHub Pages (push to `main` only) | Maintainer-only; rarely the contributor's fault — usually a `build-docs` artifact issue | + +In-progress PR runs cancel on new commit; tag/release runs do not. + +## 10. Hard rules + +**Test failures:** + +- Never dismiss a failing test as "pre-existing" — investigate every failure. +- Never fabricate test counts or claim passes without running. Reviewers re-run. +- Never add `@pytest.mark.skip` / `xfail` / platform-skip to make CI green. Debug the bug. +- If new tests surface pre-existing bugs, fix them — don't punt. + +**Code quality:** + +- Never add `# type: ignore` — use `isinstance` / `typing.cast`. +- Comments: terse one-liners; inline comments explain WHY, not WHAT. +- Never reference plan-file details (audit IDs, "see plan section X") in source. 
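
The `isinstance` / `typing.cast` rule above can be sketched as follows. This is a minimal illustration, not code from this repository: `head_count` and `layer_types` are hypothetical helpers, and the config keys are made up for the example.

```python
from typing import Union, cast


def head_count(value: Union[int, str, None]) -> int:
    """Narrow a loosely typed value with isinstance instead of `# type: ignore`."""
    if isinstance(value, int):
        return value  # the type checker treats this branch as int
    if isinstance(value, str):
        return int(value)  # and this branch as str
    raise TypeError(f"expected int or str, got {type(value).__name__}")


def layer_types(raw_config: dict[str, object]) -> list[str]:
    """Hypothetical helper: validate at runtime, then cast for the type checker."""
    value = raw_config.get("layer_types", [])
    # isinstance cannot express "list of str", so check the elements by hand.
    if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
        raise TypeError("layer_types must be a list of strings")
    # cast() has no runtime effect; the check above is what makes it safe.
    return cast(list[str], value)
```

Pairing `cast` with an explicit runtime check keeps the invariant visible in the code, which a bare `# type: ignore` would hide.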
+ +**Numerical work:** + +- Never claim drift is "fp noise" without empirical evidence — bugs and accumulated rounding look identical at noise scale. +- Never run model verification or benchmarks in parallel — a single CUDA/MPS device OOMs. + +**Environment:** + +- `uv` only — never `pip` or `poetry`. Run via `uv run ` or `make`. +- Source `.env` before any HF-Hub-hitting command. `HF_TOKEN` required for Llama, Mistral, Gemma, gated Qwen. +- No pre-commit hook — `make format` + `uv run mypy .` manually before push. + +## 11. "Done" checklist + +Before declaring a task complete: + +| Task type | Must do | +|---|---| +| Bug fix | Reproduce in a test, fix, confirm test passes, `make format`, `uv run mypy .` | +| New adapter | Adapter file + 4-place registration + integration test asserting HF logit parity + `verify_models` run (see [supported_architectures/AGENTS.md](transformer_lens/model_bridge/supported_architectures/AGENTS.md)) | +| Docs change | `uv run build-docs` succeeds; branch name contains `docs` so the docs job triggers in CI | +| Notebook change | `pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/.ipynb` passes locally | +| Anything else | `make format` + `uv run mypy .` + `make test-pr` clean before push | + +Claude Code: `/task-complete` automates the last row. See [§15 Workflow shortcuts](#15-workflow-shortcuts). + +## 12. 
Pointers for further reading + +- [docs/source/content/migrating_to_v3.md](docs/source/content/migrating_to_v3.md) — HT → Bridge migration recipes +- [docs/source/content/adapter_development/](docs/source/content/adapter_development/) — adapter authoring deep dive +- [docs/source/content/compatibility_mode.md](docs/source/content/compatibility_mode.md) — when to call `bridge.enable_compatibility_mode()`, what each flag does, four-quadrant test matrix +- [docs/source/content/debugging_numerical_divergence.md](docs/source/content/debugging_numerical_divergence.md) — bisection workflow for HT-vs-Bridge / Bridge-vs-HF logit drift +- [tests/QUARANTINES.md](tests/QUARANTINES.md) — inventory of every `skip` / `xfail` and when each can be un-skipped + +## 13. Local-only conventions + +Gitignored, first-class (entries checked in): + +- **`transformer_lens/scratch.py`** — one-off bisection scripts, ad-hoc imports. +- **`.adapter-workspace/`** — adapter WIP (notes, dumps, repros). + +## 14. Upstream dependency pins + +Load-bearing pins live in [pyproject.toml](pyproject.toml): + +| Pin | Where | Why it matters | +|---|---|---| +| `transformers>=5.4.0` | `[project] dependencies` | The Bridge adapter contract is written against HF module layouts; every minor HF release can break adapter component-mappings. Bumping is a real test pass. | +| `torch>=2.6` | `[project] dependencies` | Hook system relies on PyTorch's forward / backward hook semantics; major torch bumps occasionally change ordering. | +| `accelerate>=0.23.0` | `[project] dependencies` | Required for Llama-family loading. | +| `numpy>=1.24` / `>=1.26` | `[project] dependencies` (python-version-conditional) | Doctest float formatting can drift across NumPy versions. | +| `isort==5.8.0` | `[dependency-groups] dev` (exact) | Format check pins to exactly this version; a bump flips the formatting of every file. | + +**Bumping upstream pins:** + +1. Bump in `pyproject.toml`, `uv lock`. +2. 
`make test-pr` — expect real breakages, don't pin around them. +3. `uv run python -m transformer_lens.tools.model_registry.verify_models --architectures ` to catch numerical regressions. +4. Land in a focused PR — pin bumps are reviewed separately (wide blast radius). + +HF specifically: a `transformers` minor bump that adds or renames an architecture impacts `_HF_PASSTHROUGH_ATTRS` in [`_bridge_builder.py`](transformer_lens/model_bridge/sources/_bridge_builder.py) and `SUPPORTED_ARCHITECTURES` in [`architecture_adapter_factory.py`](transformer_lens/factories/architecture_adapter_factory.py). + +## 15. Workflow shortcuts + +Claude Code users have slash commands in [.claude/commands/](.claude/commands/) that wrap common workflows. Non-Claude agents (Cursor, Codex, Copilot, Aider, etc.) can run the manual equivalents: + +| Claude shortcut | Manual equivalent | Reference | +|---|---|---| +| `/test-unit` | `make unit-test` | [§3 Quickstart](#3-quickstart) | +| `/test-all` | `make test` (long) | [§3 Quickstart](#3-quickstart) | +| `/format` | `make format && uv run mypy .` | [§3 Quickstart](#3-quickstart) | +| `/typecheck` | `uv run mypy .` | [§3 Quickstart](#3-quickstart) | +| `/build-docs` | `set -a; source .env; set +a && uv run build-docs` | [§3 Quickstart](#3-quickstart) | +| `/verify-model ` | `set -a; source .env; set +a` then `uv run python -m transformer_lens.tools.model_registry.verify_models --model --dry-run` first, confirm, then drop `--dry-run` | [tools/model_registry/AGENTS.md](transformer_lens/tools/model_registry/AGENTS.md) | +| `/add-model-support ` | Follow the 4-way branch + adapter-authoring workflow | [supported_architectures/AGENTS.md](transformer_lens/model_bridge/supported_architectures/AGENTS.md) | +| `/task-complete` | `make format && uv run mypy . 
&& make test-pr` (loop until clean) | [§11 Done checklist](#11-done-checklist) | diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..f075c399b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,34 @@ +# CLAUDE.md + +**Read [AGENTS.md](AGENTS.md) first** — single source of truth. Everything below is Claude-Code-specific. + +## Slash commands ([.claude/commands/](.claude/commands/)) + +- `/test-unit` — `make unit-test` +- `/test-all` — `make test` (long; benchmarks + notebooks too) +- `/format` — `make format` + `uv run mypy .` +- `/typecheck` — `uv run mypy .` +- `/build-docs` — sources `.env`, runs `uv run build-docs` +- `/verify-model ` — guided single-model `verify_models.py` run +- `/add-model-support ` — checklist-driven new-adapter workflow +- `/task-complete` — end-of-task gate: clean new comments, `/format`, `make test-pr` (unit + docstring + acceptance + integration), loop until clean. Skips notebook + benchmark tiers. + +## Settings + +- [.claude/settings.json](.claude/settings.json) is the checked-in scaffold (hooks, env); ships with no permission allowlist. +- **First-time permission setup**: every command prompts on first use. Either (1) run `/fewer-permission-prompts` to populate `.claude/settings.local.json` from your transcripts, or (2) approve with "Don't ask again." Both write to the gitignored local file. +- `.claude/settings.local.json`, `.claude/agents/`, `.claude/worktrees/` are gitignored — per-user. + +## Pointers + +- [AGENTS.md §10](AGENTS.md#10-hard-rules) — hard rules; load-bearing. +- [AGENTS.md §2](AGENTS.md#2-two-systems-live-in-this-repo) — HT → Bridge mirroring; most common PR-review pushback. +- [tests/QUARANTINES.md](tests/QUARANTINES.md) — check before debugging any failing test. The macOS-arm64 KV-cache skip is the most common time-sink. +- [debugging_numerical_divergence.md](docs/source/content/debugging_numerical_divergence.md) — Bridge-vs-HF logit drift bisection. 
+- [compatibility_mode.md](docs/source/content/compatibility_mode.md) — `bridge.enable_compatibility_mode()` contract; read before adding tests that use it. + +## Starter tasks + +- [`good first issue`](https://github.com/TransformerLensOrg/TransformerLens/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) · [`help wanted`](https://github.com/TransformerLensOrg/TransformerLens/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) +- [`verification-request`](https://github.com/TransformerLensOrg/TransformerLens/issues?q=is%3Aissue+is%3Aopen+label%3Averification-request) — models awaiting verification; pick one that fits your machine, run `/verify-model `. +- Backfilling per-adapter unit tests in [`tests/unit/model_bridge/supported_architectures/`](tests/unit/model_bridge/supported_architectures/) is high-leverage; copy a sibling adapter test. diff --git a/README.md b/README.md index 28c47cecb..d909863c6 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,9 @@ interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style lan goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. -TransformerLens lets you load in 50+ different open source language models, and exposes the internal -activations of the model to you. You can cache any internal activation in the model, and add in -functions to edit, remove or replace these activations as the model runs. +TransformerLens lets you load in 9,000+ open source language models across 50+ architecture families, +and exposes the internal activations of the model to you. You can cache any internal activation in the +model, and add in functions to edit, remove or replace these activations as the model runs. ## Quick Start @@ -52,7 +52,7 @@ logits, activations = bridge.run_with_cache("Hello World") > Gated models (Llama, Mistral, Gemma, ...) require `HF_TOKEN` in your environment. 
See [Environment Variables](https://TransformerLensOrg.github.io/TransformerLens/content/getting_started.html#environment-variables) for the full list. -`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated — see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide. +`TransformerBridge` is the recommended 3.0 path and supports 9,000+ models across 50+ architecture families (see [`supported_models.json`](transformer_lens/tools/model_registry/data/supported_models.json) for the full inventory). By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated — see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide. ## Key Tutorials diff --git a/docs/source/content/compatibility_mode.md b/docs/source/content/compatibility_mode.md new file mode 100644 index 000000000..de6eb49ab --- /dev/null +++ b/docs/source/content/compatibility_mode.md @@ -0,0 +1,94 @@ +# TransformerBridge Compatibility Mode + +`TransformerBridge.boot_transformers(...)` returns a bridge whose **default numerics match HuggingFace** — raw weights, no folding, no centering. 
Calling `bridge.enable_compatibility_mode()` afterwards puts the bridge into **HookedTransformer-equivalent numerics** — weights folded, centered, and the legacy hook aliases registered. + +Most research code that was written against `HookedTransformer.from_pretrained(...)` assumes compatibility mode. Most new code that needs HF-faithful logits does not. + +> Source: [`transformer_lens/model_bridge/bridge.py:enable_compatibility_mode`](../../../transformer_lens/model_bridge/bridge.py). + +--- + +## When to enable it + +| Use case | Compatibility mode? | Why | +|---|---|---| +| Logit lens / direct logit attribution | **Yes** | These analyses reason in the post-fold-LN coordinate system; raw HF weights produce different (wrong) attributions. | +| Residual-stream norm analysis | **Yes** | Centered weights give the residual a meaningful zero. | +| Circuit analysis using HT-style hook names (`blocks.{i}.attn.hook_q`, `hook_resid_pre`, etc.) | **Yes** | Legacy aliases register only after compat mode. | +| Logit parity against HuggingFace | **No** | Folding changes weights; logits will not match HF. | +| Generation / inference vs HF baseline | **No** | Same reason. | +| Verifying a new adapter's forward pass | **No (initially)** | Use `enable_compatibility_mode(no_processing=True)` to get hook aliases without weight processing — isolates forward-pass bugs from weight-processing bugs. | + +## What each flag does + +```python +bridge.enable_compatibility_mode( + disable_warnings: bool = False, + no_processing: bool = False, + fold_ln: bool = True, + center_writing_weights: bool = True, + center_unembed: bool = True, + fold_value_biases: bool = True, + refactor_factored_attn_matrices: bool = False, +) +``` + +| Flag | Default | Effect | +|---|---|---| +| `no_processing` | `False` | If `True`, **overrides all other processing flags to False** — registers the legacy hook aliases only, leaves weights raw. The "I want HT hook names but HF numerics" mode. 
| +| `fold_ln` | `True` | Folds LayerNorm scale + bias into the subsequent linear weights so the LayerNorm modules become pure normalization. Changes weights; mathematically equivalent. | +| `center_writing_weights` | `True` | Subtracts the mean from each "writing" weight (`W_out` in attention, MLP-down). Makes residual contributions sum to zero per layer, which makes residual-stream norms interpretable. | +| `center_unembed` | `True` | Subtracts the mean from the unembedding matrix. Logits become mean-zero — affects logit-lens output but not argmax. | +| `fold_value_biases` | `True` | Folds attention value biases into the output bias. Same numerics, fewer parameters. | +| `refactor_factored_attn_matrices` | `False` | Refactors `W_Q @ W_K.T` and `W_V @ W_O` for analysis. Off by default because it's slow and only matters for specific factored-matrix research. | +| `disable_warnings` | `False` | Suppresses warnings emitted by legacy component aliases when accessed. | + +After processing, the bridge **also**: + +- Re-initializes the hook registry. +- Calls `_setup_hook_compatibility()` on every component (installs HT-style hook conversions like reshaping `hook_z` from `[batch, seq, d_model]` to `[batch, seq, n_heads, d_head]`). +- Registers HT-style hook aliases recursively across blocks. + +`compatibility_mode` is then `True` on the bridge and on every component, so subsequent operations behave as if the bridge were loaded by `HookedTransformer.from_pretrained()`. 
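
To make the `fold_ln` row in the table above concrete, here is a small pure-Python sketch of why folding the LayerNorm scale and bias into the next linear layer leaves the output unchanged. It is an illustration of the algebra only, under the assumption of a single vector and a `[d_in][d_out]` weight layout; `normalize`, `linear`, and `fold_layernorm` are hypothetical names and not the bridge's actual weight-processing code.

```python
import math


def normalize(x, eps=1e-5):
    """LayerNorm without its learned scale and bias: zero mean, unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(x, weight, bias):
    """y = x @ weight + bias, with weight shaped [d_in][d_out]."""
    d_out = len(bias)
    return [sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j] for j in range(d_out)]


def fold_layernorm(gamma, beta, weight, bias):
    """Absorb the LayerNorm scale (gamma) and bias (beta) into the next linear layer."""
    d_in, d_out = len(gamma), len(bias)
    folded_weight = [[gamma[i] * weight[i][j] for j in range(d_out)] for i in range(d_in)]
    folded_bias = [bias[j] + sum(beta[i] * weight[i][j] for i in range(d_in)) for j in range(d_out)]
    return folded_weight, folded_bias


x = [0.5, -1.25, 2.0]
gamma, beta = [1.5, 0.5, 2.0], [0.1, -0.2, 0.3]
weight = [[0.2, -0.4], [0.7, 0.1], [-0.3, 0.6]]
bias = [0.05, -0.15]

# Unfolded: full LayerNorm (normalize, scale, shift) followed by the linear layer.
ln_out = [g * n + b for g, n, b in zip(gamma, normalize(x), beta)]
unfolded = linear(ln_out, weight, bias)

# Folded: pure normalization followed by the rewritten linear layer.
folded_weight, folded_bias = fold_layernorm(gamma, beta, weight, bias)
folded = linear(normalize(x), folded_weight, folded_bias)

max_gap = max(abs(a - b) for a, b in zip(unfolded, folded))
print(f"max |unfolded - folded| = {max_gap:.2e}")
```

The two paths agree up to floating-point rounding, which is the sense in which the table calls folding "mathematically equivalent" even though the stored weights change.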
+ +## Hook semantic parity + +After `enable_compatibility_mode()`, these HT hook names fire on the **pre-norm residual** (matching HookedTransformer semantics): + +- `blocks.{i}.attn.hook_q_input`, `hook_k_input`, `hook_v_input` +- `blocks.{i}.hook_attn_in` +- `blocks.{i}.hook_mlp_in` (gated on `cfg.use_hook_mlp_in`; toggle via `bridge.set_use_hook_mlp_in(True)`) + +**Carve-outs** ([issue #1317](https://github.com/TransformerLensOrg/TransformerLens/issues/1317)): + +- **Post-norm architectures** (OLMo 2, BERT-style) read the **post-attention residual** instead, because the norm semantically lives elsewhere in the block. +- **MLA blocks** (DeepSeek V2 / V3 / R1) do **not** expose the split-qkv aliases — MLA's compressed K/V doesn't have a clean split. + +An adapter author for a new post-norm or MLA-style architecture must handle these carve-outs in `setup_hook_compatibility`. The Gemma1/Gemma2 adapters are exemplars of when **not** to override `setup_hook_compatibility` — `GemmaTextScaledWordEmbedding` already scales internally, so any added `hook_conversion` would double-scale `embed.hook_out`. + +## The four-quadrant test matrix + +The integration conftest at [`tests/integration/model_bridge/conftest.py`](../../../tests/integration/model_bridge/conftest.py) provides four bridge variants for every test model: + +| Variant | `compatibility_mode` | `no_processing` | Tests… | +|---|---|---|---| +| `gpt2_bridge` | off | n/a | HF-faithful numerics | +| `gpt2_bridge_compat` | on | `False` | HT-equivalent numerics | +| `gpt2_bridge_compat_no_processing` | on | `True` | Hook aliases without weight processing — used to bisect numerical bugs | +| (HT side) `gpt2_hooked_processed`, `gpt2_hooked_unprocessed` | n/a | n/a | Reference HookedTransformer with/without weight processing | + +New integration tests should use the variant that matches the property they're testing. Tests of HF parity → `gpt2_bridge`. Tests of HT-API behaviour → `gpt2_bridge_compat`. 
Tests of hook semantics regardless of weights → `gpt2_bridge_compat_no_processing`. + +## Cost + +`enable_compatibility_mode()` mutates the bridge's weights in-place. It is: + +- **One-shot:** calling it twice re-runs the centering subtractions. Don't. +- **Not reversible** from within the bridge — re-boot for raw weights. +- **`_setup_hook_compatibility` is idempotent**; only `process_weights` mutates weights. + +## See also + +- [Creating Architecture Adapters in contributing.md](contributing.md#creating-architecture-adapters) — adapter contract and the four-place registration; adapter authors override `setup_hook_compatibility` only when an architecture has the post-norm / MLA carve-outs mentioned above. +- [Debugging Numerical Divergence](debugging_numerical_divergence.md) — uses `no_processing=True` as a key bisection tool. +- [Migrating to TransformerLens 3](migrating_to_v3.md) — when porting HT code, you almost always want `enable_compatibility_mode()`. diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index 87a28f6b0..e23986081 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -2,6 +2,8 @@ ```{warning} `HookedTransformer` is deprecated as of TransformerLens 3.0 and will be removed in the next major version. New code should use [`TransformerBridge`](migrating_to_v3.md) instead. Existing `HookedTransformer` code continues to work through the 3.x branch via a compatibility layer. See the [migration guide](migrating_to_v3.md) for conversion recipes. + +The HookedTransformer **acceptance test suite is currently quarantined** due to a CI test-pollution issue (see `tests/QUARANTINES.md` in the repo). Changes that touch HookedTransformer internals therefore land essentially untested at the acceptance level — extra manual care is required until the suite is re-enabled. 
``` ## Setup @@ -22,6 +24,8 @@ As of TransformerLens 3.0, this project uses [UV](https://docs.astral.sh/uv/gett uv sync # activate the virtual environment source .venv/bin/activate +# first-time only: create .env from the template, then fill in HF_TOKEN for gated models +cp .env.example .env ``` Dependency groups are defined in `pyproject.toml` under `[dependency-groups]`. The project sets `default-groups = ["dev", "docs", "jupyter"]`, so `uv sync` installs all three out of the box — you do not need to pass `--group` flags for the standard contributor setup. @@ -33,6 +37,12 @@ You can also add individual groups with `uv sync --group `, or install wit Requires Python 3.10 or higher. +**Windows users:** the bash invocations above don't translate to PowerShell. Use WSL2. + +### Environment variables + +Source `.env` (`set -a; source .env; set +a`) before any HuggingFace-Hub-hitting command. `HF_TOKEN` is required for gated models (Llama, Mistral, Gemma, gated Qwen variants). See `.env.example` in the repo root for the full list. + ## Testing If adding a feature, please add unit tests for it. If you need a model, please use one of the ones @@ -43,11 +53,33 @@ quite slow (as we only have CPU actions) so the smaller models like `attn-only-1 ### Running the tests -- Unit tests only via `make unit-test` -- Acceptance tests only via `make acceptance-test` -- Docstring tests only via `make docstring-test` -- Notebook tests only via `make notebook-test` -- Run all test suites mentioned `make test` +- Standard PR-review surface (unit + docstring + acceptance + integration): `make test-pr` +- Just unit tests (fast feedback): `make unit-test` +- Integration tests: `make integration-test` +- Acceptance tests: `make acceptance-test` +- Docstring tests: `make docstring-test` +- Notebook tests: `make notebook-test` +- All test suites including benchmarks + notebooks (long): `make test` + +### Test tiers + +| Tier | Path | Loads models? | Hits HF Hub? 
| Scope | +|---|---|---|---|---| +| `unit` | `tests/unit/` | None / synthetic (rare exceptions) | No | Function or single module | +| `integration` | `tests/integration/` | 1–2 cached models, module-scoped | Yes | Cross-component | +| `acceptance` | `tests/acceptance/` | Full models, session-scoped | Yes | End-to-end behaviour | +| `benchmarks` | `tests/benchmarks/` | Varies; performance focus | Yes | Throughput / memory | +| `mps` | `tests/mps/` | TinyStories-1M, fp32 only | Yes | macOS-MPS smoke only | + +Rule of thumb: new tests that load a model should land in `integration/` by default. Tests that need real (large) weights to verify go through the model registry's `verify_models` workflow rather than running in pytest. + +The flaky-retry policy (`--reruns 2 --reruns-delay 5`) wraps every `make` target and exists to absorb HF Hub 429s. The root `tests/conftest.py` also enables `enable_hf_retry()` session-wide. + +### Quarantined tests + +Some tests carry persistent `skip` / `skipif` / `xfail` markers — for optional dependencies (LIT, bitsandbytes), hardware requirements (CUDA, MPS, multi-GPU), CI cost / network budget, or upstream platform bugs. The `tests/QUARANTINES.md` file in the repo inventories every one with an "un-skip when…" line. **Before debugging a failing test, check whether it's a known quarantine.** + +The HookedTransformer acceptance suite (`tests/acceptance/test_hooked_transformer.py`, `test_hooked_encoder.py`, `test_hooked_encoder_decoder.py`) is currently a whole-file quarantine — see the warning at the top of this page. ## Formatting @@ -56,9 +88,64 @@ actions. - Format all files via `make format` - Only check the formatting via `make check-format` +- Type-check via `uv run mypy .` Note that `black` line length is set to 100 in `pyproject.toml` (instead of the default 88). +**No pre-commit hook is installed.** Run `make format` and `uv run mypy .` manually before pushing — CI will fail otherwise. 
+ +## Project conventions + +A few conventions have grown out of recurring failure modes. Following them tends to make reviews go faster. + +### Test failures + +Treat every failing test as a signal worth investigating, even if the failure looks unrelated to your change. "Pre-existing failure" is rarely true once you dig in, and the cost of finding out is small compared to shipping a regression. If a new test surfaces a pre-existing bug, it's worth fixing in the same PR rather than deferring. + +Avoid reaching for `@pytest.mark.skip`, `xfail`, or platform-gated `skipif` to make a red CI go green. If a skip is genuinely required — an optional dependency, a hardware requirement, an upstream platform bug — document it in `tests/QUARANTINES.md` so the next person debugging knows why it's quarantined. + +### Code quality + +We avoid `# type: ignore` comments throughout the codebase; if the type system disagrees with you, the usual escape hatches are `isinstance` narrowing or `typing.cast`. Keep comments terse — one-line docstrings, and inline comments that explain *why* code does what it does rather than restating *what* it does. + +### Numerical work + +When a forward pass disagrees with HuggingFace, it's tempting to call the difference "floating-point noise" and move on. The trouble is that genuine bugs and accumulated rounding error are indistinguishable at small magnitudes until you measure. A cheap check: rerun the comparison in `dtype=torch.float64` on both sides. If the diff stays the same magnitude, it's a bug; if it drops by roughly eight orders of magnitude, it was noise. The [Debugging Numerical Divergence](debugging_numerical_divergence.md) guide walks the rest of the bisection. + +Model verification and benchmark runs each load a full model; a single CUDA or MPS device generally lacks the memory to host two at once. Run them serially. 
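
The float64 check described above can be illustrated with a toy example. This sketch uses only the standard library and emulates float32 by round-tripping through `struct`; it stands in for rerunning a real comparison with `dtype=torch.float64` and is not the project's tooling. The helper names (`to_f32`, `accumulate`, `drift`) and the three-term inputs are assumptions chosen so the arithmetic is easy to follow.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(values, *, float32: bool) -> float:
    """Sum left to right, optionally rounding each partial sum to float32."""
    acc = 0.0
    for v in values:
        acc = acc + v
        if float32:
            acc = to_f32(acc)
    return acc


def drift(reference, candidate, *, float32: bool) -> float:
    """Absolute disagreement between two computations at a given precision."""
    return abs(accumulate(reference, float32=float32) - accumulate(candidate, float32=float32))


reference = [1e8, -1e8, 1.0]
reordered = [1e8, 1.0, -1e8]  # same math, different summation order
buggy = [1e8, -1e8, 2.0]      # a genuinely wrong term

for name, candidate in [("reordered", reordered), ("buggy", buggy)]:
    low = drift(reference, candidate, float32=True)
    high = drift(reference, candidate, float32=False)
    print(f"{name}: float32 drift = {low}, float64 drift = {high}")
```

At float32 both candidates disagree with the reference by the same amount, so the two cases look identical. At float64 the reordering drift vanishes while the buggy drift is unchanged, which is the signal the paragraph above tells you to look for.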
+ +### Environment and tooling + +Use `uv` rather than `pip` or `poetry` — commands run via `uv run ` or the appropriate `make` target. Source `.env` (`set -a; source .env; set +a`) before any command that hits the HuggingFace Hub. + +## Two systems live in this repo + +The library is mid-transition between two parallel paths: + +| System | Status | Lives in | Numerics | Registry | +|---|---|---|---|---| +| `TransformerBridge` | v3 — default for new work | `transformer_lens/model_bridge/` | Raw HF weights by default; `bridge.enable_compatibility_mode()` for HT-equivalent — see [Compatibility Mode](compatibility_mode.md) | `transformer_lens/tools/model_registry/data/supported_models.json` | +| `HookedTransformer` | Legacy, maintenance mode, deprecated in 3.0 | `transformer_lens/HookedTransformer.py` + `transformer_lens/components/` | Folds LayerNorm + centres weights → does NOT match HF | `transformer_lens/supported_models.py` (**HT-only**) | + +Because the two systems are parallel implementations of the same surface, behavioural changes on one side usually need a matching change on the other. If you change a feature in `HookedTransformer` that has a counterpart in `TransformerBridge` (or vice versa), update both in the same PR — drift between them has historically been a steady source of bugs. The registries are *not* parallel, though: `supported_models.py` is HookedTransformer-only, while Bridge-only models live in the Bridge registry data file under `transformer_lens/tools/model_registry/`. + +## PR conventions + +### Base branch + +- `dev` — default for all PRs (new features, refactors, docs, most bug fixes). +- `main` — **only** for bug fixes against the currently-released version. PRs to `main` are made by maintainers; request permission before basing off `main`. + +### Branch naming + +- Do NOT name your branch `main` or `dev` — these conflict with the canonical branches when maintainers periodically refresh PRs against upstream. 
+- Include the word `docs` in the branch name if your PR is primarily a docs change (this triggers the docs build job). +- New branches must track their own remote (`git push -u origin ` from your branch, not from `main`/`dev`). + +### Changelog + +There is no per-PR changelog file. **Note user-facing changes in your PR description and commit messages** — release essays in `docs/source/content/news/` and GitHub Releases are drafted by a maintainer by reviewing those at release time. Breaking changes and notable user-facing features should be called out explicitly so the rollup picks them up. + ## Documentation Please make sure to add thorough documentation for any features you add. You should do this directly @@ -170,7 +257,74 @@ Two guides walk through the process: - [Architecture Adapter Creation Guide](adapter_development/adapter-creation-guide.md) — start here. A step-by-step workflow for taking an HF model from unsupported to tested, registered adapter. - [HuggingFace Model Analysis Guide](adapter_development/hf-model-analysis-guide.md) — a reference for reading an HF model's `config.json` and source files to extract the attributes you'll set on `self.cfg`. -Adapters live in `transformer_lens/model_bridge/supported_architectures/.py` and are registered in two places: `supported_architectures/__init__.py` and `factories/architecture_adapter_factory.py`. Both steps are covered in the creation guide. If you want a starter file, copy [adapter-template.py](../_static/adapter-template.py) into `supported_architectures/` and rename it. +Adapters live in `transformer_lens/model_bridge/supported_architectures/.py` and need to be registered in **four** places. Each registration site has a different consequence if you skip it, which is why the next section's invariant test is worth running before you open the PR. + +1. **`transformer_lens/model_bridge/supported_architectures/__init__.py`** — import the adapter class and add it to `__all__`. 
The package fails to import at boot if this is missed. +2. **`transformer_lens/factories/architecture_adapter_factory.py`** — import the class and add it to the `SUPPORTED_ARCHITECTURES` dict. The dict key must match `config.architectures[0]` exactly; otherwise `boot_transformers` raises "unsupported architecture." +3. **`transformer_lens/tools/model_registry/__init__.py`** — add the architecture name to `HF_SUPPORTED_ARCHITECTURES` (the set) and `CANONICAL_AUTHORS_BY_ARCH` (the foundation-orgs map). Without this, the HF scraper won't discover canonical models for the new architecture. +4. **`transformer_lens/tools/model_registry/generate_report.py`** — add a one-line entry to `ARCHITECTURE_DESCRIPTIONS` so the generated coverage table covers the new architecture. + +If you want a starter file, copy [adapter-template.py](../_static/adapter-template.py) into `supported_architectures/` and rename it. + +After registering, verify the four-place wiring by running: + +```bash +uv run pytest tests/unit/tools/test_model_registry.py -k TestRegistrySyncedWithFactory +``` + +The `TestRegistrySyncedWithFactory` class bidirectionally asserts that `SUPPORTED_ARCHITECTURES`, `HF_SUPPORTED_ARCHITECTURES`, and `CANONICAL_AUTHORS_BY_ARCH` stay in sync — the failure message names exactly which set is missing your new key. + +### Required tests for a new adapter + +Two test layers, both required: + +1. **Unit adapter test** at `tests/unit/model_bridge/supported_architectures/test__adapter.py`. ~26 of these exist; copy the closest sibling. The pattern: a `_make_cfg()` factory, an `adapter` fixture, and one test per architecture-specific quirk. Unit adapter tests instantiate the adapter from a synthetic config and assert structural properties — they don't load weights and don't hit HF Hub. +2. **Integration parity test** at `tests/integration/model_bridge/test__adapter.py`. Loads a real cached HF model and asserts logit parity vs HuggingFace at fp32 + eager attention. 
+ +### Common adapter gotchas + +- **HF raw config attributes are invisible to TL-side consumers unless explicitly propagated to `self.cfg`.** Walk the HF `config.json` and mirror any non-standard knobs (`final_logit_softcapping`, `attn_logit_softcapping`, `query_pre_attn_scalar`, `sliding_window`, `layer_types`, custom `eps_attr` names) onto `self.cfg` so weight processing and forward passes can see them. +- **Some config attrs need both surface-on-cfg AND fold-into-weight** via a `preprocess_weights()` override. The trigger: a numerical operation HF's forward applies natively must also be baked into the raw weights, or `bridge.enable_compatibility_mode()` (which calls `process_weights` on raw weights) produces wrong results. Concrete examples in-tree: Cohere `logit_scale` → `unembed.weight`; Gemma embedding scale (`√d_model`) → `embed.weight`. Skip the fold and Phase 3 / Phase 4 of `verify_models` will silently degrade. +- **Tokenizer policy is per-model, not per-architecture.** Sibling models in the same family routinely differ — the chat-instruct variant may prepend BOS where the base does not, padding side can flip, EOS handling can differ. It's worth re-checking `default_prepend_bos`, padding side, and EOS handling against the specific target rather than copying them from a starter adapter. `tokenizer_config.json` is not always reliable on its own — some architectures (Cohere is a notable example) declare `add_bos_token=False` but HF's `__call__` prepends BOS anyway. The most reliable check is to invoke the tokenizer directly: + + ```python + from transformers import AutoTokenizer + t = AutoTokenizer.from_pretrained("") + print(t("hello").input_ids) # what generation actually uses + print(t.bos_token_id) + ``` + + If `t("hello").input_ids[0] == t.bos_token_id`, set `cfg.default_prepend_bos = True`; otherwise leave the flag unset. +- **Hook names inside adapters are Bridge-native** (e.g., `blocks.{i}.hook_out`). 
HookedTransformer-style aliases (e.g., `blocks.{i}.hook_resid_post`) are registered elsewhere — in `transformer_lens/model_bridge/bridge.py` via `build_alias_to_canonical_map()`. Adapters declare canonical names only. +- **`ComponentMapping` types do not need `# type: ignore`.** If the type system disagrees, prefer `isinstance` narrowing or `typing.cast`; the project as a whole avoids `# type: ignore`. + +### Verifying a new model + +After your adapter is registered, the model registry's `verify_models` runner exercises it end-to-end: + +```bash +set -a; source .env; set +a +uv run python -m transformer_lens.tools.model_registry.verify_models --model +``` + +`verify_models` runs phases 1–4 (forward correctness vs HF, hook firing + gradients, weight processing, generation quality) and updates `data/supported_models.json` with the resulting status and per-phase scores. We recommend running `--dry-run` first to project memory and parameter count without loading the model, and verifying one model at a time — concurrent loads tend to OOM a single device. + +A note on entry points: `verify_models` is the script that writes the registry. `main_benchmark` runs the same underlying benchmarks but defaults to *not* writing the registry (it requires `--update-registry`, and even then it doesn't record Phase 7 / 8 scores or the resume checkpoint). If you want the registry updated after your run, use `verify_models`. + +It's worth reading the per-phase scores in addition to the final status — the verifier enforces hard pass/fail thresholds, and a model that just clears the bar tells you something different than one that breezes through. 
The current thresholds: + +| Phase | Min score | Required tests | Effect when below threshold | +|---|---|---|---| +| 1 | 100% | — | Verification fails | +| 2 | 75% | `logits_equivalence`, `loss_equivalence` | Verification fails | +| 3 | 75% | `logits_equivalence`, `loss_equivalence` | Verification fails | +| 4 | 50% | — | **Non-gating** — below 50% adds `"low text quality"` to the registry `note`; never fails verification. | +| 7 | 75% | `multimodal_forward` | Verification fails. A NULL score also fails. | +| 8 | 75% | `audio_forward` | Verification fails. A NULL score also fails. | + +Phase 4 is intentionally lenient — it's a coherence metric, not a correctness check. A sub-100% Phase-4 score on a small parity-test model can still indicate a real adapter bug that the gates don't catch (missing `preprocess_weights` fold, wrong `default_prepend_bos`, and so on); the model can pass verification overall and still be worth a manual look. + +If verification fails by `~1e-3` or more against the HF reference, the bisection workflow lives at [Debugging Numerical Divergence](debugging_numerical_divergence.md). ```{toctree} :hidden: diff --git a/docs/source/content/debugging_numerical_divergence.md b/docs/source/content/debugging_numerical_divergence.md new file mode 100644 index 000000000..01568a3ce --- /dev/null +++ b/docs/source/content/debugging_numerical_divergence.md @@ -0,0 +1,111 @@ +# Debugging Numerical Divergence + +When a Bridge adapter's integration test fails by `~1e-3` (or any larger delta) against the HuggingFace reference, the failure mode is almost always one of a small set of recurring bugs. This page walks the bisection workflow. + +> A note before you start: it's tempting to attribute small drift to "floating-point noise" and move on, but genuine bugs and accumulated rounding error are indistinguishable at small magnitudes until you measure. 
The [numerical-work conventions in contributing.md](contributing.md#numerical-work) describe the cheap fp64 check that disambiguates the two. + +--- + +## 0. Setup checklist + +- **fp32 + eager attention on both sides.** `dtype=torch.float32, attn_implementation="eager"`. `sdpa` / `flash_attention_2` mask bugs. +- **`enable_compatibility_mode(no_processing=True)`** for the first pass — isolates forward-pass bugs from weight-processing bugs. See [compatibility_mode.md](compatibility_mode.md). +- **Single-token first**, then 5–10, then longer. Most adapter bugs surface single-token. +- **Same seed / no dropout.** A stray `nn.Dropout(p=0.1)` in a generalized component silently de-correlates runs. + +## 1. Bisect by component + +Walk Bridge hooks vs HF `output_hidden_states=True` / `output_attentions=True` and find the first layer where they diverge: + +| Stage | Bridge hook | HF output | +|---|---|---| +| Embedding | `embed.hook_out` | `outputs.hidden_states[0]` | +| Block i pre-attn-norm | `blocks.{i}.ln1.hook_out` | (HF doesn't expose; compute from `hidden_states[i]`) | +| Block i Q / K / V | `blocks.{i}.attn.q.hook_out`, `.k.hook_out`, `.v.hook_out` | (HF doesn't expose; instrument `model.layers[i].self_attn` directly) | +| Block i attention output | `blocks.{i}.attn.hook_out` | `outputs.attentions[i]` (pattern), then hidden_state delta | +| Block i MLP output | `blocks.{i}.mlp.hook_out` | (HF doesn't expose; hidden_state delta) | +| Block i residual out | `blocks.{i}.hook_resid_post` | `outputs.hidden_states[i+1]` | +| Final norm | `ln_final.hook_out` | (HF inlines into lm_head) | +| Logits | `unembed.hook_out` | `outputs.logits` | + +The first hop where they disagree localizes the bug. + +## 2. 
Common root causes, in order of frequency + +| Symptom | Likely cause | Where to look | +|---|---|---| +| Logits off everywhere but Q/K/V close | RoPE base / scaling mismatch | Adapter's `RotaryEmbeddingBridge` setup; check `cfg.rotary_base`, `cfg.rope_scaling` | +| Attention output drifts; Q / K / V match | Wrong `n_key_value_heads`, wrong head reshape | `_qkvo_weight_conversions(n_kv_heads=...)`; GQA-aware split | +| First-layer outputs off; embeddings off | Embedding scaling missing (Gemma, T5) | `preprocess_weights()` override; `cfg.scale_embeddings` | +| Off by a constant scale in residual | Final-RMS-norm offset missing | `cfg.rmsnorm_uses_offset = True` + `ArithmeticTensorConversion(ADDITION, 1.0)` | +| Logits flat / saturated at extremes | Missing logit softcap | `cfg.output_logits_soft_cap` from HF's `final_logit_softcapping` | +| Attention pattern collapses to argmax | Missing attention-score softcap | `cfg.attn_scores_soft_cap` from HF's `attn_logit_softcapping` | +| Off by `eps` magnitudes in norm | Wrong RMSNorm eps attribute name | `cfg.eps_attr` (Llama uses `"variance_epsilon"`, most others use `"eps"`) | +| First MLP off; gate matches | Forgot gated-MLP wiring | `GatedMLPBridge` with `{gate, in, out}` submodules — not `MLPBridge` | +| Bias-related drift | Adapter assumes biases that don't exist (Llama / RMSNorm) | `ProcessWeights._safe_get_tensor` handles `None`; check the weight-processing conversions are bias-aware | +| Drift only in compatibility mode | Hook semantic carve-out missing for post-norm or MLA | See [compatibility_mode.md](compatibility_mode.md) §"Hook semantic parity" | + +## 3. Isolating weight-processing bugs + +If `no_processing=True` matches HF but `enable_compatibility_mode()` (default) drifts: + +- The bug is in weight processing, not the forward pass. +- Bisect by toggling individual flags: `fold_ln`, `center_writing_weights`, `center_unembed`, `fold_value_biases`. The first one that introduces drift is the culprit. 
+- See [compatibility_mode.md §"What each flag does"](compatibility_mode.md#what-each-flag-does). + +## 4. Comparing against `boot_transformers` for the same model + +If Bridge ≠ HF on a model that already passes `verify_models`, your adapter likely diverges from the canonical Bridge load configuration. Quick sanity check: + +```python +import torch +from transformer_lens.model_bridge.bridge import TransformerBridge +from transformers import AutoModelForCausalLM + +# model_name is the HF repo id of the model under test. +ref = TransformerBridge.boot_transformers(model_name, device="cpu", dtype=torch.float32) +hf = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, attn_implementation="eager" +) + +# Single-token input: BOS if the config defines one, else token id 1. +# (`bos_token_id or 1` would wrongly replace a valid BOS id of 0.) +bos = hf.config.bos_token_id +ids = torch.tensor([[bos if bos is not None else 1]]) +ref_logits = ref(ids) +hf_logits = hf(ids).logits + +print((ref_logits - hf_logits).abs().max()) # should be < 1e-4 in fp32 +``` + +If `boot_transformers` itself disagrees with HF on the same model, the issue is upstream of your adapter (probably a `_HF_PASSTHROUGH_ATTRS` gap in `transformer_lens/model_bridge/sources/_bridge_builder.py`, or a non-standard HF config attribute that the adapter never propagated onto `self.cfg`). HF raw config attributes are invisible to TL-side consumers unless explicitly mirrored. Common attributes that need propagation: `final_logit_softcapping` (Gemma2/3), `attn_logit_softcapping` (Gemma2/3), `query_pre_attn_scalar` (Gemma2/3), `sliding_window` (Mistral, Qwen2, Gemma2), `layer_types` (hybrid models), and non-standard RMSNorm eps attribute names (Llama uses `variance_epsilon`). + +## 5. Bisecting `verify_models` phase failures + +`verify_models` reports phase-by-phase. Map the failing phase to the bisection focus: + +| Phase | What failed | Start here | +|---|---|---| +| 1 | Forward correctness vs HF | Steps 1–4 above; this is the standard parity workflow | +| 2 | Hook firing / gradient flow | The hook isn't registered, or it's firing on a tensor that's been replaced (in-place op). 
Grep adapter for in-place ops on hookable tensors. | +| 3 | Weight processing | Run with `no_processing=True` to isolate. Then bisect compat-mode flags per §3 above. | +| 4 | Text-generation quality | Usually tokenizer policy: `default_prepend_bos`, padding side, EOS handling, chat-template wiring. Tokenizer behaviour is per-model, not per-architecture — check the target's `tokenizer_config.json`, don't inherit from a sibling. Less often, a generation-loop divergence; rerun with `--no-ht-reference` to skip HT comparison. | +| 7 | Multimodal alignment | Vision encoder output drift or projection mismatch. Llava / Gemma3-multimodal only. | +| 8 | Audio | HuBERT only; check CTC head and audio-feature alignment. | + +## 6. What "fp noise" actually looks like + +Empirically, in this codebase: + +- **fp32, eager attention, single forward**: HF vs Bridge max-abs diff is typically `< 5e-5`. Anything ≥ `1e-4` is suspicious. +- **bf16, eager**: `< 1e-2` is the noise floor. +- **fp32, sdpa**: `< 5e-4` due to sdpa's internal reductions. Use eager for parity tests. + +If you suspect noise, the cheap proof is to run **fp64**: `dtype=torch.float64` on both sides. If the diff stays the same magnitude, it's a bug. If it drops by ~8 orders of magnitude, it was noise. See the [numerical-work conventions in contributing.md](contributing.md#numerical-work) for more context on why this check is worth running. + +## 7. Tooling + +- `make integration-test PYTEST_ADDOPTS="-k -s"` — focused run with stdout. +- `transformer_lens/scratch.py` (gitignored) — drop one-off bisection scripts here without polluting `git status`. +- `.adapter-workspace/` (gitignored) — sibling directory for WIP adapter notes / repros. +- `bridge.run_with_cache(ids)` — returns `(logits, cache)`; `cache["blocks.{i}.hook_resid_post"]` is the easiest path to per-layer diffs. 
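
The per-layer comparison from §1, combined with the noise floors quoted in §6, can be sketched without loading any model. This is only an illustration under stated assumptions: `NOISE_FLOOR`, `max_abs_diff`, and `first_divergent_layer` are hypothetical names that do not exist in TransformerLens, and plain Python lists stand in for the cached activations you would pull from `cache[...]` and `outputs.hidden_states[...]`.

```python
from typing import Optional, Sequence

# Thresholds taken from section 6 of this page, keyed by (dtype, attention impl).
# Hypothetical table for illustration; not a TransformerLens constant.
NOISE_FLOOR = {
    ("fp32", "eager"): 1e-4,
    ("bf16", "eager"): 1e-2,
    ("fp32", "sdpa"): 5e-4,
}


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("activations must have the same length")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def first_divergent_layer(
    bridge_acts: Sequence[Sequence[float]],
    hf_acts: Sequence[Sequence[float]],
    tol: float,
) -> Optional[int]:
    """Return the index of the first layer whose diff reaches `tol`, else None."""
    for i, (bridge_layer, hf_layer) in enumerate(zip(bridge_acts, hf_acts)):
        if max_abs_diff(bridge_layer, hf_layer) >= tol:
            return i
    return None


# Toy data: layers 0 and 1 agree exactly, layer 2 is off by about 0.05.
bridge = [[0.10, 0.20], [0.30, 0.40], [0.50, 0.60]]
hf = [[0.10, 0.20], [0.30, 0.40], [0.50, 0.65]]
print(first_divergent_layer(bridge, hf, NOISE_FLOOR[("fp32", "eager")]))
```

With real tensors the same loop applies; only `max_abs_diff` would change to something like `(a - b).abs().max().item()`.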
+ +--- + +If you exhaust this guide and still can't localize the bug, the failure pattern is worth adding to §2 above so the next contributor doesn't repeat the bisection. diff --git a/docs/source/content/hook_system.md b/docs/source/content/hook_system.md new file mode 100644 index 000000000..dbf17cf43 --- /dev/null +++ b/docs/source/content/hook_system.md @@ -0,0 +1,177 @@ +# The Hook System + +Hooks are the primary value proposition of TransformerLens. They let you intercept, cache, edit, or ablate intermediate activations as a model runs — without modifying the model code. This page covers the user-facing API; for the implementation, see [`transformer_lens/hook_points.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/hook_points.py). + +--- + +## What a hook is + +A `HookPoint` is an `nn.Identity` module placed at a named position inside a model (e.g., `blocks.0.attn.hook_q`). At runtime it passes its input through unchanged — but a registered hook function can read the tensor (caching), modify it (intervention), or replace it (ablation, patching). + +The hook system is purely PyTorch's `register_forward_hook` / `register_full_backward_hook` underneath, but with two additions that matter: + +1. **Named positions.** Every hook point has a string name like `blocks.{i}.hook_resid_post`, so you can address it without traversing the module tree. +2. **Context-scoped lifecycle.** Hooks added inside a `with model.hooks(...)` context auto-remove when the block exits — no manual cleanup required. + +--- + +## Three ways to use hooks + +### 1. `run_with_cache` — read everything + +The simplest workflow: run a forward pass and get back both the logits and a dict of every cached activation, keyed by hook name. 
+ +```python +logits, cache = model.run_with_cache("Hello, world") +cache["blocks.0.attn.hook_q"] # Q tensor at layer 0 +cache["blocks.5.hook_resid_post"] # residual stream after block 5 +cache["ln_final.hook_normalized"] # post-final-norm activations +``` + +`cache` is an `ActivationCache` — a dict-like with conveniences (`cache.decompose_resid()`, `cache.apply_ln_to_stack(...)`, etc.). See [`transformer_lens/ActivationCache.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/ActivationCache.py). + +For TransformerBridge: + +```python +from transformer_lens.model_bridge import TransformerBridge +bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") +logits, cache = bridge.run_with_cache("Hello, world") +``` + +Pass `names_filter=` to cache only a subset (saves memory): + +```python +logits, cache = model.run_with_cache( + "Hello, world", + names_filter=lambda name: name.endswith("hook_resid_post"), +) +``` + +### 2. `run_with_hooks` — intervene during the forward pass + +Attach temporary hooks for one forward pass, then auto-remove them. + +```python +def zero_out_head_3(activation, hook): + activation[:, :, 3, :] = 0 # ablate attention head 3 + return activation + +logits = model.run_with_hooks( + "Hello, world", + fwd_hooks=[("blocks.5.attn.hook_z", zero_out_head_3)], +) +``` + +Each hook is a `(hook_name, hook_fn)` tuple. The hook function signature is **always**: + +```python +def hook_fn(tensor: torch.Tensor, *, hook: HookPoint) -> Optional[torch.Tensor]: + # Read or modify `tensor`. Return None to leave the activation unchanged, + # or return a tensor of the same shape to replace it. + ... +``` + +`hook` (the `HookPoint` instance) exposes `hook.name` so a single function can dispatch on which hook called it. + +For backward hooks (gradient interventions), use `bwd_hooks=[...]` with the same tuple shape. + +### 3. 
`add_hook` + `remove_all_hook_fns` — manual lifecycle + +When you need hooks that persist across multiple forward passes (e.g., during training), drop down to the underlying API: + +```python +hook_point = model.blocks[5].attn.hook_z +hook_point.add_hook(my_hook_fn, dir="fwd") # temporary +hook_point.add_hook(my_hook_fn, dir="fwd", is_permanent=True) # survives reset + +# later +model.remove_all_hook_fns() # removes temporary hooks +model.remove_all_hook_fns(including_permanent=True) # also removes permanent +``` + +`add_hook` returns nothing useful; lifecycle is owned by the `HookPoint`. The `is_permanent` flag is the only way to survive a `remove_all_hook_fns()` call. + +--- + +## Hook naming + +Stable strings; differ between HookedTransformer and TransformerBridge: + +| System | Style | Example | +|---|---|---| +| `HookedTransformer` (legacy) | Uniform across architectures | `blocks.5.attn.hook_q`, `blocks.5.hook_resid_post`, `hook_embed` | +| `TransformerBridge` (default) | Architecture-native | `blocks.5.attn.q.hook_out`, `blocks.5.hook_out`, `embed.hook_out` | +| `TransformerBridge` + compatibility mode | Bridge-native AND HT-style aliases | Above + `blocks.5.attn.hook_q` etc. | + +Full catalogue: [Main Demo](generated/demos/Main_Demo), [Exploratory Analysis Demo](generated/demos/Exploratory_Analysis_Demo). Architecture diagram: [TransformerLens_Diagram.svg](../_static/TransformerLens_Diagram.svg). + +Porting HT code to Bridge: `bridge.enable_compatibility_mode()` (see [Compatibility Mode](compatibility_mode.md)) registers HT aliases so existing names resolve. 
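
To make the relationship between the two naming styles concrete, here is a small sketch that translates the three HT-style examples from the table into their Bridge-native counterparts. The pairing is an assumption read off the order of the examples in the table; `HT_TO_BRIDGE_PATTERNS` and `to_bridge_name` are hypothetical names for illustration, not the library's alias table. In practice, `enable_compatibility_mode()` handles alias resolution for you.

```python
import re

# Illustrative pairs only, inferred from the table's examples.
# The library's real alias table may contain different or additional entries.
HT_TO_BRIDGE_PATTERNS = [
    (re.compile(r"^blocks\.(\d+)\.attn\.hook_q$"), r"blocks.\1.attn.q.hook_out"),
    (re.compile(r"^blocks\.(\d+)\.hook_resid_post$"), r"blocks.\1.hook_out"),
    (re.compile(r"^hook_embed$"), "embed.hook_out"),
]


def to_bridge_name(ht_name: str) -> str:
    """Translate an HT-style hook name to its assumed Bridge-native form."""
    for pattern, replacement in HT_TO_BRIDGE_PATTERNS:
        if pattern.match(ht_name):
            return pattern.sub(replacement, ht_name)
    raise KeyError(f"no illustrative mapping for {ht_name!r}")


print(to_bridge_name("blocks.5.attn.hook_q"))
```

Raising on unknown names, rather than passing them through, keeps a typo from silently addressing a hook that does not exist.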
+ +--- + +## Common patterns + +### Cache one activation, run a single forward pass + +```python +logits, cache = model.run_with_cache("text", names_filter="blocks.5.hook_resid_post") +resid_5 = cache["blocks.5.hook_resid_post"] +``` + +### Zero-ablate a head + +```python +def ablate(z, hook): + z[:, :, head_idx, :] = 0 + return z + +model.run_with_hooks("text", fwd_hooks=[(f"blocks.{layer}.attn.hook_z", ablate)]) +``` + +### Activation patching (swap an activation from one prompt into another) + +```python +# Assumes clean_prompt and corrupted_prompt tokenize to the same length. +_, clean_cache = model.run_with_cache(clean_prompt) +target = clean_cache["blocks.5.hook_resid_post"] + +def patch(resid, hook): + return target # replace corrupted's activation with clean's + +logits = model.run_with_hooks( + corrupted_prompt, + fwd_hooks=[("blocks.5.hook_resid_post", patch)], +) +``` + +[`transformer_lens/patching.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/patching.py) wraps this pattern for systematic sweeps across layers / positions. + +### Gradient intervention via backward hook + +```python +def scale_grad(grad, hook): + return grad * 0.1 + +# run_with_hooks resets hooks when it returns, i.e. before any backward pass, +# so keep the hook alive in a context and call backward() inside it. +with model.hooks(bwd_hooks=[("blocks.5.hook_resid_post", scale_grad)]): + loss = model("text", return_type="loss") + loss.backward() +``` + +--- + +## Lifecycle gotchas + +- **Temporary hooks added outside `run_with_hooks` / `model.hooks(...)` do NOT auto-clean.** Call `model.remove_all_hook_fns()` or you'll leak hooks across runs. +- **Permanent hooks (`is_permanent=True`) survive `remove_all_hook_fns()`** — use `including_permanent=True` to clear them. +- **Hook functions that return a tensor replace the activation in-flight.** Returning `None` leaves it unchanged. In-place modification (`tensor[…] = …`) + `return tensor` is the common pattern. +- **Backward hooks see `(grad,)` tuples** at the PyTorch level; the wrapper in `hook_points.py` unwraps them, so your hook function receives a bare tensor. 
+- **Hooks fire in registration order**, with `prepend=True` to register at the front. + +--- + +## See also + +- [Compatibility Mode](compatibility_mode.md) — when to enable HT-style hook aliases on a Bridge model. +- [Migrating to TransformerLens 3](migrating_to_v3.md) — porting HookedTransformer hook patterns to TransformerBridge. +- [Main Demo](generated/demos/Main_Demo) — end-to-end walkthrough using the hook system. +- [`transformer_lens/hook_points.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/hook_points.py), [`transformer_lens/ActivationCache.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/ActivationCache.py), [`transformer_lens/patching.py`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/patching.py) — source. diff --git a/docs/source/index.md b/docs/source/index.md index de5cc84eb..b10bd3a9d 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -54,6 +54,9 @@ content/migrating_to_v3 content/tutorials content/citation content/contributing +content/hook_system +content/compatibility_mode +content/debugging_numerical_divergence generated/demos/Main_Demo generated/demos/Exploratory_Analysis_Demo content/special_cases diff --git a/makefile b/makefile index de8c99738..2b5834e28 100644 --- a/makefile +++ b/makefile @@ -62,6 +62,12 @@ notebook-test: $(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb $(RERUN_ARGS) $(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb $(RERUN_ARGS) +test-pr: + $(MAKE) unit-test + $(MAKE) docstring-test + $(MAKE) acceptance-test + $(MAKE) integration-test + test: $(MAKE) unit-test $(MAKE) integration-test diff --git a/tests/AGENTS.md b/tests/AGENTS.md new file mode 100644 index 000000000..d233280ec --- /dev/null +++ b/tests/AGENTS.md @@ -0,0 +1,100 @@ +# Tests — AGENTS.md + +Read [the root AGENTS.md](../AGENTS.md) for project-wide 
rules. This file covers conventions specific to `tests/`. + +> **Just running tests?** Run `make test-pr` for the standard PR-review surface (unit + docstring + acceptance + integration). `make unit-test` for fast feedback. Full tier breakdown in [Test tiers](#test-tiers) below. + +--- + +## TL;DR + +- **Tier placement matters** — a unit test loading a model belongs in `integration/`. See [tier table](#test-tiers). +- **No mocking model loads or HF Hub.** Session-scoped fixtures amortize the cost. +- **Cached models for fast tests** (`gpt2`, `attn-only-{1,2,3,4}l`, `tiny-stories-1M`). Anything else → `@pytest.mark.slow`. +- **MPS is a carve-out** — `TRANSFORMERLENS_ALLOW_MPS=1` required; only [`tests/mps/`](mps/) runs there. +- **[AGENTS.md §10](../AGENTS.md#10-hard-rules) applies**: no `xfail`/`skipif` to dodge CI; no platform skips outside MPS. +- **Check [QUARANTINES.md](QUARANTINES.md) before debugging a failing test** — known quarantines have documented reasons. + +--- + +## Test tiers + +| Tier | Path | Run | Loads models? | Hits HF Hub? 
| Scope | Example | +|---|---|---|---|---|---|---| +| `unit` | [`tests/unit/`](unit/) | `make unit-test` | None / synthetic (rare exceptions) | No | Function or single module | [`tests/unit/test_key_value_cache_entry.py`](unit/test_key_value_cache_entry.py) | +| `integration` | [`tests/integration/`](integration/) | `make integration-test` | 1–2 cached models, module-scoped | Yes | Cross-component | [`tests/integration/test_generation_compatibility.py`](integration/test_generation_compatibility.py) | +| `acceptance` | [`tests/acceptance/`](acceptance/) | `make acceptance-test` | Full models (`gpt2`, `bloom-560m`), session-scoped | Yes | End-to-end behaviour | [`tests/acceptance/conftest.py`](acceptance/conftest.py) | +| `benchmarks` | [`tests/benchmarks/`](benchmarks/) | `make benchmark-test` | Varies; performance focus | Yes | Throughput / memory | [`tests/benchmarks/test_boot_memory.py`](benchmarks/test_boot_memory.py) | +| `mps` | [`tests/mps/`](mps/) | `pytest tests/mps -v` (needs `TRANSFORMERLENS_ALLOW_MPS=1`) | TinyStories-1M, fp32 only | Yes | macOS-MPS smoke only | [`tests/mps/test_mps_basic.py`](mps/test_mps_basic.py) | + +Common combinations: `make test-pr` (unit + docstring + acceptance + integration — the PR-review surface), `make test` (everything including benchmarks + notebooks). + +**Rule of thumb:** new tests that load a model should land in `integration/` by default. The `unit/` tier has a few legitimate model-loading exceptions (e.g. `test_bridge_vs_hooked_transformer_*.py` compares numerics across architectures, which is conceptually unit-scoped) — match that pattern only when the test really is testing isolated behaviour that happens to need a model. 
+ +--- + +## Conftest hierarchy + +[`tests/conftest.py`](conftest.py) — root, provides: + +- `cleanup_memory` (function autouse), `cleanup_class_memory` — CUDA/MPS cache + GC +- `_enable_hf_retry_for_tests` (session autouse) — wraps HF `from_pretrained` with 429 retry +- Seeded RNG (numpy/torch/Python @ 42) +- `gpt2_tokenizer` (session) +- `gpt2_hooked_processed` (session) +- `temp_dir` + +Sub-folder conftests: + +| Path | Provides | +|---|---| +| [`tests/acceptance/conftest.py`](acceptance/conftest.py) | `gpt2_model`, `bloom_560m_hooked`, `bloom_560m_hf_model`, `bloom_560m_hf_tokenizer` (all session) | +| [`tests/acceptance/model_bridge/conftest.py`](acceptance/model_bridge/conftest.py) | Bridge variants of gpt2 with/without compat mode | +| [`tests/integration/model_bridge/conftest.py`](integration/model_bridge/conftest.py) | distilgpt2 + gpt2 Bridge variants × {compat, no-compat, no-processing} | + +Two cross-cutting rules: + +- All `transformer_lens` imports inside conftest fixtures live in fixture bodies, not at module top — jaxtyping's `pytest_configure` hook must install before the package is first imported. +- Session-scoped model fixtures (`gpt2_hooked_processed`, `gpt2_bridge`, …) are read-only — mutating them leaks across the entire test session. + +--- + +## Cached-model allowlist + +CI cache ([`checks.yml`](../.github/workflows/checks.yml)) covers: `gpt2`, `gpt2-xl`, `distilgpt2`, `pythia-70m`, `gpt-neo-125M`, `gemma-2-2b-it`, `bloom-560m`, `Qwen2-0.5B`, `bert-base-cased`, `NeelNanda/Attn_Only*`, `roneneldan/TinyStories-1M*`, `NeelNanda/SoLU*`, `redwood_attn_2l`, `tiny-random-llama-2`, `DialoGPT-medium`. + +Prefer `attn-only-{1,2,3,4}l` and `tiny-stories-1M` for fast tests — `gpt2` is slow on CI's CPU runners. Use `gpt2` only when you need GPT-2 numerics. Anything outside the cached set → `@pytest.mark.slow`. + +--- + +## The `slow` marker + +`pyproject.toml`: `"slow: marks tests as slow (deselect with '-m \"not slow\"')"`. 
Add when the test: + +- loads a non-cached model +- iterates exhaustively over many param combos +- takes >5 s per invocation + +Deselect with `pytest -m "not slow"`. Default `make` targets do NOT filter; the marker is for ad-hoc runs. + +--- + +## MPS rules + +- [`mps-checks`](../.github/workflows/checks.yml) sets `TRANSFORMERLENS_ALLOW_MPS=1` and runs `tests/unit`, `tests/integration`, `tests/mps` on `macos-latest`. +- `get_device()` returns `"cpu"` unless `TRANSFORMERLENS_ALLOW_MPS=1` — protects against silent MPS divergence. +- The workflow's long `--ignore=` list documents existing MPS divergence (MoE, optimizer compat, KV-cache layout); it's **not** a license to add new skips. +- [`tests/mps/test_mps_basic.py`](mps/test_mps_basic.py) is the template: float32 only (no bfloat16 on MPS), TinyStories-1M only (50 MB fits the runner), `torch.mps.empty_cache()` + `gc.collect()` between tests. +- MPS-only modules need: `pytestmark = pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")`. + +--- + +## Hard "don'ts" + +Plus [AGENTS.md §10](../AGENTS.md#10-hard-rules): + +- **No mocking model loads** — session-scoped fixtures are cheap enough. +- **No mocking the HF Hub** — tests hit the real hub with `enable_hf_retry()` handling 429s. +- **No platform `skipif` outside MPS** — no `skipif(sys.platform == 'win32')` or `skipif(not torch.cuda.is_available())` to dodge CI. +- **No `xfail` to dodge a failing test** — fix the bug, even if pre-existing. +- **No copying acceptance-tier tests as unit-test templates** — their model fixtures time out / OOM at the unit tier. diff --git a/tests/QUARANTINES.md b/tests/QUARANTINES.md new file mode 100644 index 000000000..5cf3e9fb2 --- /dev/null +++ b/tests/QUARANTINES.md @@ -0,0 +1,118 @@ +# Test Quarantines + +Inventory of every `skip` / `skipif` / `xfail` in [`tests/`](.). A test on this list with the matching reason is **not your bug** — don't debug blindly. A failure NOT on this list is real. 
+ +Rule ([AGENTS.md §10](../AGENTS.md#10-hard-rules)): **never add `xfail` / `skipif` to dodge a failing CI.** New skips need a row here. + +--- + +## Permanent — optional dependency + +| Path | Marker | Trigger | +|---|---|---| +| [`unit/test_lit.py`](unit/test_lit.py) (×18) | `skipif(not LIT_AVAILABLE)` | `pip install lit-nlp` (`lit` group) | +| [`unit/components/test_attention.py`:48](unit/components/test_attention.py) | `skipif(not is_bitsandbytes_available())` | `uv sync --group quantization` | +| [`unit/test_weight_processing.py`:477](unit/test_weight_processing.py) | same | same | +| [`unit/factories/test_mlp_factory.py`:40](unit/factories/test_mlp_factory.py) | same | same | + +**Un-skip:** never. Install the optional group to run locally. + +--- + +## Permanent — hardware requirement + +| Path | Marker | Required | +|---|---|---| +| [`unit/test_next_sentence_prediction.py`:131](unit/test_next_sentence_prediction.py) | `skipif(not cuda)` | Any CUDA | +| [`unit/model_bridge/compatibility/test_next_sentence_prediction.py`:95](unit/model_bridge/compatibility/test_next_sentence_prediction.py) | `skipif(not cuda)` | Any CUDA | +| [`unit/components/test_attention.py`:83](unit/components/test_attention.py) | `skipif(not cuda)` (half/bfloat16) | Any CUDA | +| [`acceptance/test_hooked_encoder.py`:227](acceptance/test_hooked_encoder.py) | `skipif(not cuda)` | Any CUDA | +| [`acceptance/test_hooked_encoder_decoder.py`:421](acceptance/test_hooked_encoder_decoder.py) | `skipif(not cuda)` | Any CUDA | +| [`acceptance/test_multi_gpu.py`:91,105](acceptance/test_multi_gpu.py) | `skipif(device_count < 2)` | 2+ CUDA | +| [`acceptance/test_multi_gpu.py`:22](acceptance/test_multi_gpu.py) | `skipif(device_count < 4)` | 4+ CUDA | +| [`acceptance/model_bridge/test_multi_gpu_bridge.py`:257](acceptance/model_bridge/test_multi_gpu_bridge.py) | `skipif(device_count < 2)` | 2+ CUDA | +| [`mps/test_mps_basic.py`](mps/test_mps_basic.py) module-level | `skipif(not mps)` | Apple Silicon | 
+ +**Un-skip:** never. CI provides each tier (CUDA via compatibility-checks → CPU-only in practice; MPS via `mps-checks`; multi-GPU local-only). See [tests/AGENTS.md §MPS rules](AGENTS.md#mps-rules) and the `--ignore=` list in [`checks.yml`](../.github/workflows/checks.yml). + +--- + +## Intentional — CI cost / network budget + +`skipif(os.getenv("CI"))` to avoid expensive HF fetches / large loads. + +| Path | Reason | +|---|---| +| [`unit/model_bridge/supported_architectures/test_gemma2_adapter.py`:49](unit/model_bridge/supported_architectures/test_gemma2_adapter.py) | "Network/disk fetch of tiny Gemma2" | +| [`integration/model_bridge/test_bridge_integration.py`:801](integration/model_bridge/test_bridge_integration.py) | "Skip Gemma2 in CI to avoid timeout" | +| [`acceptance/model_bridge/compatibility/test_hook_completeness.py`:156](acceptance/model_bridge/compatibility/test_hook_completeness.py) | "Gemma2 too large for CI" | + +**Un-skip:** locally with `HF_TOKEN` sourced. + +--- + +## Intentional — manual verification only + +| Path | Reason | +|---|---| +| [`integration/model_bridge/test_qwen3_moe_bridge.py`:155,166](integration/model_bridge/test_qwen3_moe_bridge.py) | "Requires real weights — run via `verify_models`" | + +**Un-skip:** `/verify-model Qwen/Qwen3-MoE-...` ([tools/model_registry/AGENTS.md](../transformer_lens/tools/model_registry/AGENTS.md)). + +--- + +## Upstream / platform bug + +| Path | Reason | Issue | +|---|---|---| +| [`unit/model_bridge/test_bridge_generate_no_tokenizer.py`:30,128](unit/model_bridge/test_bridge_generate_no_tokenizer.py) | `skipif(_MACOS_ARM64)` — KV-cache NaN | Upstream PyTorch/HF on M-series Macs | + +**Un-skip:** when upstream resolves. Don't bypass — produces NaN logits. + +--- + +## ⚠️ Technical debt — whole-file + +Entire test modules quarantined via module-level `pytestmark`. Significant coverage gap — priority to re-enable. 
+ +| Path | Reason | +|---|---| +| [`acceptance/test_hooked_transformer.py`:19](acceptance/test_hooked_transformer.py) | "CI test pollution" | +| [`acceptance/test_hooked_encoder.py`:13](acceptance/test_hooked_encoder.py) | same | +| [`acceptance/test_hooked_encoder_decoder.py`:10](acceptance/test_hooked_encoder_decoder.py) | same | + +**Un-skip:** root-cause the test pollution (fixture-scope or import-ordering bug). Until then, these acceptance tiers are dark. + +--- + +## Technical debt — individual + +| Path | Marker | Covers | +|---|---|---| +| [`unit/factored_matrix/test_constructor.py`:54](unit/factored_matrix/test_constructor.py) | `skip` | FactoredMatrix constructor edge case | +| [`unit/model_bridge/test_architecture_adapter.py`:436](unit/model_bridge/test_architecture_adapter.py) | `skip` | Adapter behaviour | +| [`unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py`:138,142](unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py) | `skipif`/`xfail` | Bridge↔HT patching parity | +| [`unit/model_bridge/test_hook_alias_resolution.py`:89](unit/model_bridge/test_hook_alias_resolution.py) | `xfail(strict=True)` per-arch | Hook-alias gaps | +| [`unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py`:574,609,660,680,771](unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py) | `skipif` ×5 | Qwen3.5 quirks | +| [`unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py`:531](unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py) | `skipif` | Qwen3-Next quirks | +| [`integration/test_weight_processing_integration.py`:238](integration/test_weight_processing_integration.py) | `skip` | Weight-processing edge case | +| [`integration/test_tensor_extraction_consistency.py`:33](integration/test_tensor_extraction_consistency.py) | `skip` | Tensor extraction | +| [`integration/test_tokenization_methods.py`:53](integration/test_tokenization_methods.py) | `skipif` | Tokenization coverage | +| 
[`integration/test_hooked_encoder_properties.py`:71](integration/test_hooked_encoder_properties.py) | `xfail` | HookedEncoder properties | +| [`acceptance/model_bridge/compatibility/test_backward_hooks.py`:11](acceptance/model_bridge/compatibility/test_backward_hooks.py) | `skip` | Backward-hook compatibility | +| [`acceptance/test_hooked_transformer.py`:551,560](acceptance/test_hooked_transformer.py) | `skipif` ×2 (inside module-level skip) | `from_pretrained_no_processing` | + +**Un-skip:** debug the underlying issue and remove the marker. Each removal lands in a focused PR with a regression test. + +--- + +## Adding a new quarantine + +Read [AGENTS.md §10](../AGENTS.md#10-hard-rules) first — default answer is "fix the bug instead." + +If a quarantine is genuinely right: + +1. Pick the right marker — `skipif(condition)` for env gates; `skip(reason=)` for known-bad paths; `xfail(strict=True, reason=)` when you expect failure and want CI to alert if it passes. +2. Use a `reason=` descriptive enough to look up — not `"flaky"` or `"broken"`. +3. Add a row above with path, marker, "un-skip when" line. +4. Whole-module `pytestmark` skips go in the ⚠️ section for visibility. diff --git a/transformer_lens/benchmarks/AGENTS.md b/transformer_lens/benchmarks/AGENTS.md new file mode 100644 index 000000000..4f9159bb5 --- /dev/null +++ b/transformer_lens/benchmarks/AGENTS.md @@ -0,0 +1,19 @@ +# Benchmarks — AGENTS.md + +This is the benchmark library that `verify_models` calls into. Read [the root AGENTS.md](../../AGENTS.md) for project-wide rules. + +## ⚠ Not the registry-update path + +**Use [`verify_models`](../tools/model_registry/AGENTS.md), not [`main_benchmark.py`](main_benchmark.py), for registry updates.** `main_benchmark` runs the same benchmark math but defaults to NOT writing `data/supported_models.json` (needs `--update-registry`, and even with that flag misses Phase 7 / 8 scores and the resume checkpoint). 
+ +If an agent is here because the user asked to "update the registry" or "verify a model," they took the wrong entry point — redirect to [tools/model_registry/AGENTS.md](../tools/model_registry/AGENTS.md). + +## What this directory IS for + +- The phase-by-phase benchmark implementations (`forward_pass.py`, `generation.py`, `hook_registration.py`, `weight_processing.py`, `multimodal.py`, `audio.py`, `text_quality.py`, `granular_weight_processing.py`, `component_outputs.py`, `backward_gradients.py`, `activation_cache.py`, `component_benchmark.py`, `hook_structure.py`). +- `main_benchmark.py` — exploratory benchmark runner for ad-hoc comparison. Useful for debugging a single model's phase scores without touching the registry. +- `utils.py` — shared helpers including `BenchmarkSeverity`. + +## When you ARE in the right place + +You're modifying a benchmark implementation (e.g., adding a new check to Phase 2 hook firing), debugging why a phase score is wrong, or adding a benchmark for a new architectural feature. In that case, follow project conventions in [AGENTS.md §10](../../AGENTS.md#10-hard-rules) and ensure the change is exercised by an existing or new `verify_models` phase. diff --git a/transformer_lens/config/AGENTS.md b/transformer_lens/config/AGENTS.md new file mode 100644 index 000000000..c88f0ef56 --- /dev/null +++ b/transformer_lens/config/AGENTS.md @@ -0,0 +1,44 @@ +# Config — AGENTS.md + +The config dataclasses that drive both `HookedTransformer` and `TransformerBridge`. Read [the root AGENTS.md](../../AGENTS.md) for project-wide rules. 
+ +## File map + +| File | Class | Role | +|---|---|---| +| [`transformer_lens_config.py`](transformer_lens_config.py) | `TransformerLensConfig` | Minimal base — only fields actually used by the system | +| [`transformer_bridge_config.py`](transformer_bridge_config.py) | `TransformerBridgeConfig(TransformerLensConfig)` | The Bridge config; what every Bridge adapter receives as `cfg` | +| [`hooked_transformer_config.py`](hooked_transformer_config.py) | `HookedTransformerConfig` | Legacy HT-only config (deprecated; see [AGENTS.md §2](../../AGENTS.md#2-two-systems-live-in-this-repo)) | + +## Adding a new HF-config attr to `TransformerBridgeConfig` — decision tree + +The `logit_scale` bug existed because the rules below weren't documented anywhere. **Pick exactly one of the four paths**; doing more than one creates silent-override risk. + +| Use case | Path | +|---|---| +| First-class TL field that adapters / hooks / weight processing read | **Declare as a dataclass parameter** on `TransformerBridgeConfig`. Set a sensible default. Update `map_default_transformer_lens_config` in [`sources/transformers.py`](../model_bridge/sources/transformers.py) to translate the HF-config attr name to your field name. | +| HF attr the adapter reads at runtime, no semantic translation needed | **Add to `_HF_PASSTHROUGH_ATTRS`** in BOTH [`sources/transformers.py:481`](../model_bridge/sources/transformers.py) AND [`sources/_bridge_builder.py:18`](../model_bridge/sources/_bridge_builder.py). The trap: adding it to only one of the two lists is a half-fix. See [sources/AGENTS.md](../model_bridge/sources/AGENTS.md). | +| HF attr name differs from existing TL field | **Add an explicit handler** in `map_default_transformer_lens_config` (e.g. Gemma2's `final_logit_softcapping` → `output_logits_soft_cap`). Don't also add to PASSTHROUGH. | +| Just a derived view of an existing field | **Add a `@property`** on `TransformerBridgeConfig` (e.g. `head_dim` aliases `d_head`). 
| + +**Don't** declare the same attr both as a dataclass field AND a PASSTHROUGH entry — PASSTHROUGH writes happen AFTER `from_dict`, so the runtime value will silently overwrite whatever defaults / explicit handlers set. + +## Existing properties that may surprise you + +- **`head_dim`** is a read-only `@property` aliasing `d_head` (line 212). `setattr(cfg, "head_dim", val)` raises `AttributeError: property 'head_dim' has no setter`. Don't add it to PASSTHROUGH. +- **`n_heads`** has `= -1` as its placeholder in the constructor signature, deliberately. Comment in the source: *"Add n_heads to signature so it's not filtered out by from_dict"*. Don't "fix" this default. + +## Verifying a new field propagates + +Integration test pattern (the regression test that would have caught the `logit_scale` bug): + +```python +def test_cfg__matches_hf(bridge: TransformerBridge, hf_model: Any) -> None: + """Regression: must propagate from HF (not silently fall back).""" + bridge_val = getattr(bridge.cfg, "") + assert bridge_val == hf_model.config. +``` + +The `getattr` form sidesteps mypy's "undeclared field" complaint without `# type: ignore` (see [AGENTS.md §10](../../AGENTS.md#10-hard-rules)). + +Pick a test model whose `` value differs from the adapter's hardcoded fallback (if any) — otherwise the assertion passes by tautology. diff --git a/transformer_lens/model_bridge/sources/AGENTS.md b/transformer_lens/model_bridge/sources/AGENTS.md new file mode 100644 index 000000000..a75bbe210 --- /dev/null +++ b/transformer_lens/model_bridge/sources/AGENTS.md @@ -0,0 +1,47 @@ +# Sources — AGENTS.md + +Backend loaders that translate external models (HF, native) into a `TransformerBridge`. Read [the root AGENTS.md](../../../AGENTS.md) for project-wide rules. 
+ +## File map + +| File | Role | +|---|---| +| [`_bridge_builder.py`](_bridge_builder.py) | Loader-agnostic helpers: `build_bridge_config_from_hf`, `build_bridge_from_module`, `detect_tokenizer_bos_eos` | +| [`transformers.py`](transformers.py) | HuggingFace backend: `boot` (entry point used by `TransformerBridge.boot_transformers`), `map_default_transformer_lens_config` (HF→TL config translation), `check_model_support`, `list_supported_models` | +| [`native/model.py`](native/model.py), [`native/init.py`](native/init.py) | TL-native models built from scratch (no HF), used by `TransformerBridge.boot_native` | + +## ⚠ The duplicate `_HF_PASSTHROUGH_ATTRS` trap + +There are **two** identical `_HF_PASSTHROUGH_ATTRS` lists: + +- `transformers.py:481` — inside `map_default_transformer_lens_config`, copies HF-config attrs onto `tl_config`. +- `_bridge_builder.py:18` — module-level, copies HF-config attrs onto `bridge_config` AFTER `TransformerBridgeConfig.from_dict` runs. + +Both fire sequentially on the same `hf_config` during `boot_transformers`. **When adding a passthrough attr, add it to BOTH lists** — adding to only one half-fixes the bug. A previous regen-agent run shipped a half-fix that the regression test caught only because we also added an assertion on the canonical bridge config. + +## Two-pass config-translation pipeline + +``` +hf_config + │ + ├─ map_default_transformer_lens_config() [transformers.py] + │ ├─ explicit handlers (d_model, n_heads, head_dim → d_head, etc.) + │ └─ _HF_PASSTHROUGH_ATTRS copy → tl_config + │ + ├─ TransformerBridgeConfig.from_dict() [filters to declared fields] + │ + └─ _HF_PASSTHROUGH_ATTRS copy → bridge_config [_bridge_builder.py] +``` + +**Implications:** + +- If an attr has an explicit handler in `map_default_transformer_lens_config` (like `head_dim` → `d_head`), it's translated to a declared `TransformerBridgeConfig` field — adding it to PASSTHROUGH is **wrong** (often raises `AttributeError` from a read-only property). 
+- If an attr is purely runtime / adapter-specific (like Cohere's `logit_scale`), it has no explicit handler and is dropped by `from_dict` — it MUST be in both PASSTHROUGH lists, or the adapter's `getattr(cfg, "", default)` silently falls back to its default forever. +- See [config/AGENTS.md](../../config/AGENTS.md) for the decision tree. + +## `boot_transformers` vs `boot_native` + +- `boot_transformers` is the HF path. Goes through `boot()` in `transformers.py` → `build_bridge_from_module` in `_bridge_builder.py` → adapter init. +- `boot_native` builds a TL-native transformer from a `TransformerBridgeConfig` directly (no HF dependency). Uses [`native.py`](../supported_architectures/native.py) as the single adapter; the model class is in [`native/model.py`](native/model.py). + +Almost every contributor change goes through the `boot_transformers` path. Touch `native/` only when adding a primitive (new norm type, new positional embedding variant) that the cfg-driven dispatch in `native/model.py` doesn't already cover. diff --git a/transformer_lens/model_bridge/supported_architectures/AGENTS.md b/transformer_lens/model_bridge/supported_architectures/AGENTS.md new file mode 100644 index 000000000..8d1bf047d --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/AGENTS.md @@ -0,0 +1,374 @@ +# Supported Architectures — AGENTS.md + +Read [the root AGENTS.md](../../../AGENTS.md) for project-wide rules. This file covers conventions specific to writing TransformerBridge architecture adapters. + +For the **verification** workflow after writing an adapter, see [tools/model_registry/AGENTS.md](../../tools/model_registry/AGENTS.md). Do not duplicate verification content here. + +--- + +## TL;DR + +- **One file per architecture family**, not per model. `llama.py` covers Llama 1 / 2 / 3 / 3.1 / 3.2; do NOT add `llama7b.py` or `llama2.py`. 
Family splits happen only when internal structure changes (`gemma1.py` / `gemma2.py` / `gemma3.py`, `qwen.py` / `qwen2.py` / `qwen3.py`, `mistral.py` / `mixtral.py`). +- **Four-place registration is mandatory.** See [Registration steps](#registration-steps) — missing any one causes silent runtime failure or stale generated docs. +- **`self.component_mapping: ComponentMapping`** is the load-bearing contract. Maps canonical TransformerLens paths → Bridge components wrapping HF module paths. +- **Hook names are Bridge-native** (`blocks.{i}.hook_out`, `blocks.{i}.attn.q.hook_out`). Don't invent HT-style aliases here — those are registered separately in [`bridge.py`](../bridge.py). +- **Copy from the closest existing adapter** before writing from scratch. See the [starter table](#starter-adapter-table). +- **Surface non-standard HF config attrs explicitly** — they are invisible to TL-side consumers unless propagated to `self.cfg`. See [Config-attr propagation](#config-attr-propagation). +- **Tokenizer policy is per-model, not per-architecture.** Never inherit `default_prepend_bos`, `default_padding_side`, EOS handling, or chat-template wiring from the starter adapter without checking the target's `tokenizer_config.json`. See [Tokenizer policy](#tokenizer-policy). 
+ +--- + +## Minimal contract + +Every adapter inherits from `ArchitectureAdapter` ([`../architecture_adapter.py`](../architecture_adapter.py)) and declares in `__init__`: + +```python +class MyArchitectureAdapter(ArchitectureAdapter): + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + # Required config flags — set on self.cfg, not just default_config + self.cfg.normalization_type = "RMS" # "RMS" or "LN" + self.cfg.positional_embedding_type = "rotary" # "rotary" | "standard" | "relative_positional_bias" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + + # Architecture-conditional flags + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads # GQA / MQA + + # Weight reshape rules where Bridge layout differs from HF + self.weight_processing_conversions = { + **self._qkvo_weight_conversions(), + } + + # The load-bearing mapping + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "attn": PositionEmbeddingsAttentionBridge(...), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "mlp": GatedMLPBridge(...), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } +``` + +Optional overrides on the class: + +- `default_cfg` — class-level defaults merged into the runtime config +- `preprocess_weights()` — pre-load weight transforms (e.g. Gemma embedding scaling, Cohere `logit_scale` fold). See [When to override `preprocess_weights`](#when-to-override-preprocess_weights). 
+- `applicable_phases` — subset of `[1, 2, 3, 4, 7, 8]` (default `[1, 2, 3, 4]`); see [tools/model_registry/AGENTS.md](../../tools/model_registry/AGENTS.md) for what each phase tests +- `supports_generation` — set `False` for encoder-only models (BERT, HuBERT) + +--- + +## Starter-adapter table + +| Target shape | Start from | Why | +|---|---|---| +| Decoder-only causal LM, RMSNorm + RoPE | [`llama.py`](llama.py) | Canonical modern shape; already handles GQA via `n_key_value_heads` | +| Decoder-only causal LM, LayerNorm + combined QKV | [`gpt2.py`](gpt2.py) | Older GPT-style; demonstrates QKV split/rearrange | +| GQA / MQA | [`llama.py`](llama.py) | Pass `n_kv_heads=` to `_qkvo_weight_conversions()` | +| RMSNorm with offset | [`gemma1.py`](gemma1.py) | Uses `rmsnorm_uses_offset = True` + `ArithmeticTensorConversion` | +| Mixture-of-experts | [`mixtral.py`](mixtral.py) | Uses `MoEBridge` with batched experts | +| Vision-language | [`llava.py`](llava.py) | Dual encoder + projection pathways | +| Encoder-decoder (T5-style) | [`t5.py`](t5.py) | Sets `supports_generation = False`, separate encoder/decoder block lists | +| State-space model | [`mamba.py`](mamba.py), [`mamba2.py`](mamba2.py) | Off the transformer path entirely | +| Encoder-only with CTC head | [`bert.py`](bert.py), [`hubert.py`](hubert.py) | `supports_generation = False` | + +--- + +## Config-attr propagation + +HF raw config attributes are invisible to TL-side consumers unless propagated to `self.cfg`. Walk the `config.json` for non-standard attributes; mirror anything the base machinery doesn't already handle. 
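
A toy sketch of why an undeclared HF attribute goes missing, assuming a config class whose `from_dict` keeps only declared dataclass fields. `ToyBridgeConfig`, `toy_from_dict`, and `PASSTHROUGH` are hypothetical stand-ins written for illustration; they are not the real `TransformerBridgeConfig` or `_HF_PASSTHROUGH_ATTRS`, and the real code paths may differ in detail.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class ToyBridgeConfig:
    """Hypothetical stand-in for a config with a fixed set of declared fields."""

    d_model: int = 0
    n_heads: int = -1


# Hypothetical passthrough list: undeclared attrs copied over after construction.
PASSTHROUGH = ("logit_scale",)


def toy_from_dict(raw: dict[str, Any], passthrough: tuple[str, ...] = ()) -> ToyBridgeConfig:
    """Keep declared fields only, then copy any passthrough attrs onto the instance."""
    declared = {f.name for f in fields(ToyBridgeConfig)}
    cfg = ToyBridgeConfig(**{k: v for k, v in raw.items() if k in declared})
    for name in passthrough:
        if name in raw:
            setattr(cfg, name, raw[name])
    return cfg


hf_like = {"d_model": 64, "n_heads": 4, "logit_scale": 0.25}

dropped = toy_from_dict(hf_like)             # "logit_scale" is filtered out
kept = toy_from_dict(hf_like, PASSTHROUGH)   # "logit_scale" survives

# Without passthrough, the adapter-style fallback fires regardless of the model's value.
print(getattr(dropped, "logit_scale", 0.0625))
print(getattr(kept, "logit_scale", 0.0625))
```

The first lookup silently returns the fallback even though the raw config carried a different value, which is the failure mode a propagation assertion is meant to catch.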
+ +**Surface-on-cfg attributes** (set on `self.cfg.`; bridge reads from there): + +| HF attribute | Surface as | Used by | +|---|---|---| +| `final_logit_softcapping` | `self.cfg.final_logit_softcapping` | Gemma2/3 — final-layer logit clip | +| `attn_logit_softcapping` | `self.cfg.attn_logit_softcapping` | Gemma2/3 — attention-score clip | +| `query_pre_attn_scalar` | `self.cfg.query_pre_attn_scalar` | Gemma2/3 — query scaling override | +| `sliding_window` | `self.cfg.sliding_window` | Mistral, Qwen2, Gemma2 — local-attention layers | +| `layer_types` | `self.cfg.layer_types` | Hybrid models with per-layer attention type lists | +| Non-standard RMSNorm eps key | `self.cfg.eps_attr = ""` | Llama uses `"variance_epsilon"` instead of `"eps"` | + +**Weight-fold attributes** (need BOTH surface-on-cfg AND fold-into-weight via `preprocess_weights` — see [the next section](#when-to-override-preprocess_weights)): + +| HF attribute | Fold target | Used by | +|---|---|---| +| `logit_scale` | `unembed.weight` (multiply in fp32 then cast back) | Cohere — final logits scaled by `1/16` | +| `embedding_multiplier` / embed-scale flags | `embed.weight` (multiply) | Gemma — embeddings scaled by `√d_model` | +| Tied unembed with extra scale | `unembed.weight` | T5-family tied projection variants | + +Rule of thumb: if the model card or HF source mentions a numerical knob, assume it needs to land on `self.cfg`. If that knob changes weights or final outputs and HF's forward applies it natively, you ALSO need a `preprocess_weights` override or compatibility mode will diverge. + +**Passthrough gotcha:** if your attr is NOT a declared `TransformerBridgeConfig` field, `TransformerBridgeConfig.from_dict(hf_config)` silently filters it out — `getattr(cfg, "", None)` in your adapter will return `None` and your fallback default fires regardless of the model's actual value. 
To propagate, add the attr name to `_HF_PASSTHROUGH_ATTRS` in [`sources/transformers.py`](../sources/transformers.py) (and the duplicate list in [`sources/_bridge_builder.py`](../sources/_bridge_builder.py)). Verify with an integration-test assertion: `assert bridge.cfg. == hf_model.config.`. + +--- + +## When to override `preprocess_weights` + +Default is a no-op pass-through. Override when a numerical op HF applies natively in forward must also be baked into raw weights — otherwise `bridge.enable_compatibility_mode()` (which calls `process_weights` expecting the math to already be in the weights) diverges. + +> **Trigger:** if a config attr changes the forward-pass math AND isn't re-applied by compat-mode weight processing, fold it into the relevant weight in `preprocess_weights()`. + +Examples: + +- **Cohere** — `cfg.logit_scale` (default `0.0625`) folds into `unembed.weight`. HF forward multiplies logits by `logit_scale`; compat-mode doesn't. +- **Gemma1/2/3** — embedding scale (`√d_model`) folds into `embed.weight`. HF's `GemmaTextScaledWordEmbedding` scales on forward; compat-mode reads raw. + +Skeleton: + +```python +import torch + +def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Fold into before ProcessWeights runs. + + bridge.py clones unembed.weight before calling this, so the scale does not + leak into the tied embed.weight. + """ + scale: float = getattr(self.cfg, "") # set in __init__ + if scale != 1.0: # no-op when scale is identity + key = "unembed.weight" + if key in state_dict: + orig_dtype = state_dict[key].dtype + state_dict[key] = (state_dict[key].float() * scale).to(orig_dtype) + return state_dict +``` + +**Skip the override when:** the op lives inside an HF submodule the bridge delegates to (e.g. RoPE inside `CohereRotaryEmbedding`) — HF forward applies it on both paths, parity is free. 
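
Why a fold gives parity can be checked with a few lines of arithmetic: scaling the logits after the unembed matmul is the same as scaling the unembed weight before it. The sketch below uses nested lists in place of tensors; `matvec` and all the numbers are made up for illustration, apart from the `0.0625` default quoted above.

```python
def matvec(weight: list[list[float]], x: list[float]) -> list[float]:
    """Multiply a (vocab x d_model) matrix by a d_model vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


scale = 0.0625  # the logit_scale default quoted above
unembed = [[1.0, 2.0], [3.0, -4.0], [0.5, 8.0]]  # toy 3-token vocab, d_model = 2
resid = [2.0, -1.0]

# Forward-time scaling: matmul first, then scale the logits.
scaled_after = [scale * logit for logit in matvec(unembed, resid)]

# Folded: bake the scale into the weight once, then do a plain matmul.
folded = [[scale * w for w in row] for row in unembed]
scaled_before = matvec(folded, resid)

# Exact here because every value is a small dyadic number.
assert scaled_after == scaled_before
print(scaled_before)
```

With real low-precision weights the two paths agree only up to rounding, which is presumably why the skeleton multiplies in fp32 before casting back to the original dtype.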
+ +**Verification:** a missing fold degrades Phase 4 (text-generation), and Phase 3 if severe enough to trip the strict gate. P4's 50% bar is intentionally lenient, so sub-100% P4 on a small parity-test model is worth investigating even with `STATUS_VERIFIED`. See [tools/model_registry/AGENTS.md §Phase-score thresholds](../../tools/model_registry/AGENTS.md#phase-score-thresholds). + +--- + +## Tokenizer policy + +Tokenizer behaviour is **per-model**, not per-architecture — siblings routinely differ (instruct vs base BOS, padding side, EOS, chat template). **Don't copy these from your starter adapter blindly.** + +| Flag | Default | Failure mode | +|---|---|---| +| `default_prepend_bos` | Framework default | Off-by-one logits; generation starts from wrong context | +| `default_padding_side` (`"left"` / `"right"`) | HF tokenizer default | Batched-uneven-length generation produces garbage | +| EOS handling (single or list) | `cfg.eos_token_id` | Generation runs past or stops short | +| Chat-template auto-apply | Off (user calls explicitly) | Instruct models produce base-style continuations | +| Tokenizer class mismatch | HF resolves | Subtle BPE/SentencePiece divergence (rare, real for forks) | + +**How to verify:** + +1. **Run the tokenizer.** `tokenizer_config.json` lies for some architectures — Cohere declares `add_bos_token=False` but `__call__` prepends BOS via `add_special_tokens=True`. Only reliable check is to invoke it: + + ```python + from transformers import AutoTokenizer + t = AutoTokenizer.from_pretrained("") + print(t("hello").input_ids, t.bos_token_id) + ``` + + First-token == BOS → set `cfg.default_prepend_bos = True`. Otherwise leave unset. + +2. **Cross-check `tokenizer_config.json`** for `padding_side`, `bos_token`, `eos_token`, `chat_template` (these tend to be honest; runtime mostly overrides BOS-prepending). +3. **Compare to closest sibling** in registry; differences are usually deliberate. +4. 
**When unsure, leave the framework default** — explicit wrongness hides under "I configured it." + +`default_prepend_bos` is the most common trap (Cohere config-vs-runtime mismatch, instruct/base divergence in Llama/Mistral/Gemma). + +--- + +## Registration steps + +After writing `myarch.py` with `MyArchitectureAdapter`, register in **all four** sites. Missing any one breaks something: + +1. **[`__init__.py`](__init__.py)** — import + add to `__all__`: + ```python + from transformer_lens.model_bridge.supported_architectures.myarch import ( + MyArchitectureAdapter, + ) + ``` + Missing → import error at boot. + +2. **[`../../factories/architecture_adapter_factory.py`](../../factories/architecture_adapter_factory.py)** — import + `SUPPORTED_ARCHITECTURES` entry: + ```python + from transformer_lens.model_bridge.supported_architectures import ( + ..., + MyArchitectureAdapter, + ) + + SUPPORTED_ARCHITECTURES = { + ..., + "MyArchForCausalLM": MyArchitectureAdapter, # key must match config.architectures[0] exactly + } + ``` + Missing → `boot_transformers` raises "unsupported architecture." + +3. **[`../../tools/model_registry/__init__.py`](../../tools/model_registry/__init__.py)** — two updates in this file: + - Add `"MyArchForCausalLM"` to `HF_SUPPORTED_ARCHITECTURES` + - Add `"MyArchForCausalLM": ["foundation-org-1", "foundation-org-2"]` to `CANONICAL_AUTHORS_BY_ARCH` + + Missing → HF scraper misses canonical models for this architecture (download-threshold bypass). + +4. **[`../../tools/model_registry/generate_report.py`](../../tools/model_registry/generate_report.py)** — one-line entry in `ARCHITECTURE_DESCRIPTIONS`: + ```python + "MyArchForCausalLM": "Short human-readable description.", + ``` + Missing → generated docs table omits the new architecture. 
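
The four sites above have to agree on the architecture key. A minimal sketch of that kind of consistency check, using plain sets: the contents below are toy stand-ins, not the real registries, and `find_registration_gaps` is a hypothetical helper rather than anything in the repo.

```python
def find_registration_gaps(sites: dict[str, set[str]]) -> dict[str, set[str]]:
    """Return, per site, the architecture keys present elsewhere but missing there."""
    all_keys = set().union(*sites.values())
    return {name: all_keys - keys for name, keys in sites.items() if all_keys - keys}


# Toy stand-ins for the keyed registration sites (made-up contents).
sites = {
    "SUPPORTED_ARCHITECTURES": {"MyArchForCausalLM", "LlamaForCausalLM"},
    "HF_SUPPORTED_ARCHITECTURES": {"LlamaForCausalLM"},  # new arch forgotten here
    "CANONICAL_AUTHORS_BY_ARCH": {"MyArchForCausalLM", "LlamaForCausalLM"},
    "ARCHITECTURE_DESCRIPTIONS": {"MyArchForCausalLM", "LlamaForCausalLM"},
}

gaps = find_registration_gaps(sites)
for site, missing in gaps.items():
    print(f"{site} is missing: {sorted(missing)}")
```

A check like this only covers sites keyed by the HF class name; the `__init__.py` export is keyed by adapter class name and would need its own comparison.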
+ +**Verify the four-place wiring**: [`TestRegistrySyncedWithFactory`](../../../tests/unit/tools/test_model_registry.py) asserts four bidirectional invariants across `SUPPORTED_ARCHITECTURES`, `HF_SUPPORTED_ARCHITECTURES`, `CANONICAL_AUTHORS_BY_ARCH`. Run: + +```bash +uv run pytest tests/unit/tools/test_model_registry.py -k TestRegistrySyncedWithFactory +``` + +Failure message names the missing set. (`INTENTIONAL_EXCLUDES` in the test handles internal-only architectures and HF-emits-different-casing aliases; new adapters rarely belong there.) + +--- + +## Common gotchas + +| HF convention | Handled in | How | +|---|---|---| +| RoPE (rotary positional embeddings) | `llama.py`, `mistral.py`, `qwen2.py`+ | `RotaryEmbeddingBridge(name="model.rotary_emb")` + `cfg.positional_embedding_type = "rotary"` | +| GQA / MQA (`n_key_value_heads < n_heads`) | `llama.py`, `mistral.py`, `falcon.py`, `cohere.py` | Set `cfg.n_key_value_heads`; pass `n_kv_heads=` to `_qkvo_weight_conversions()` | +| RMSNorm with offset | `gemma1.py`, `gemma2.py`, `gemma3.py` | `cfg.rmsnorm_uses_offset = True` + `ArithmeticTensorConversion(ADDITION, 1.0)` | +| Custom RMSNorm eps attribute | `llama.py` | `cfg.eps_attr = "variance_epsilon"` (Llama uses this instead of `eps`) | +| Standard LayerNorm | `gpt2.py`, `bloom.py` | `cfg.normalization_type = "LN"` | +| Gated MLP (`gate_proj`, `up_proj`, `down_proj`) | `llama.py`, `mistral.py`, `gemma1.py`, `qwen2.py`+ | `GatedMLPBridge` with submodules `{gate, in, out}` | +| Combined QKV (`c_attn`) | `gpt2.py`, `bloom.py` | `QKVSplitRearrangeConversion` to split + rearrange | +| Split Q/K/V (standard) | `llama.py`, `mistral.py`, most modern | `self._qkvo_weight_conversions()` helper | +| MoE routing | `mixtral.py`, `deepseek_v3.py`, `qwen3_moe.py`, `granite_moe.py` | `MoEBridge` with `gate` + batched expert submodules | +| Missing biases (RMSNorm has no `b`; Llama has no attn/MLP biases) | `llama.py` (documented in docstring) | Weight processing handles `None` via 
`ProcessWeights._safe_get_tensor()` | +| KV cache layout | All (implicit) | Adapter delegates; HF module manages internally | + +--- + +## Adding the HF repo to the registry + +After registration, add the model ID to [`data/supported_models.json`](../../tools/model_registry/data/supported_models.json) so `verify_models` can resolve it. Entry shape: + +```json +{ + "architecture_id": "MyArchForCausalLM", + "model_id": "org/repo-name", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, "phase2_score": null, "phase3_score": null, + "phase4_score": null, "phase7_score": null, "phase8_score": null +} +``` + +**Hand-edits add new entries only.** `status`, `verified_date`, `note`, `phaseN_score` are written by `update_model_status()`; never set manually. See [tools/model_registry/AGENTS.md](../../tools/model_registry/AGENTS.md). + +Prompt the user about canonical sibling variants from `CANONICAL_AUTHORS_BY_ARCH[]` — e.g. adding `google/gemma-2-2b`, ask about `-2b-it`, `-9b`, `-9b-it`, `-27b`, `-27b-it`. The HF scraper picks these up eventually; explicit entries unblock `verify_models` immediately. + +Do NOT add the model to [`supported_models.py`](../../supported_models.py) — HookedTransformer-only. + +--- + +## Model source paths: `boot_transformers` vs `boot_native` + +Two parallel load paths, same adapter system: + +| Source | Entry | Use case | +|---|---|---| +| `boot_transformers` | `TransformerBridge.boot_transformers(model_name, ...)` | Default — wraps an HF model. Every adapter here supports it. | +| `boot_native` | `TransformerBridge.boot_native(cfg, init_mode=...)` | TL-native, no HF dependency. Used by [`Realtime_Training_Telemetry_Demo.ipynb`](../../../demos/Realtime_Training_Telemetry_Demo.ipynb) and cfg-driven small models. 
| + +Native models live in [`sources/native/model.py`](../sources/native/model.py); routed through [`native.py`](native.py) whose `component_mapping` adapts to `cfg` (rotary drops `pos_embed`, RMS norm → `RMSNormalizationBridge`, gated → `GatedMLPBridge`, `attn_only` drops MLP). Init policies in [`sources/native/init.py`](../sources/native/init.py): `gpt2` (default, Normal + 1/√(2·n_layers) residual scaling), `xavier_uniform/normal`, `kaiming_uniform/normal`. Scoped `torch.Generator` keeps seeded init from perturbing global RNG. + +Cfg-driven Native features: `normalization_type` (LN / RMS / RMSPre), `final_rms`, `gated_mlp`, `attn_only`, `n_key_value_heads`, `attn_scores_soft_cap`, `output_logits_soft_cap`, `positional_embedding_type`, `rotary_dim`, `rotary_base`, `rope_scaling` (linear PI, dynamic/NTK, llama3 by-parts). + +Sister backends under [`sources/`](../sources/) — `transformers/`, `vllm/`, `inspect/` — are internal. Almost all new work writes a `supported_architectures/.py` file targeting the HF backend. + +**Need both paths?** Almost never. Touch `native.py` only when adding a primitive (e.g. new norm type) the cfg-driven dispatch doesn't cover. + +--- + +## Required tests + +Two test layers per architecture, both required for review: + +### 1. Unit adapter test — `tests/unit/model_bridge/supported_architectures/test__adapter.py` + +26 of these exist; they all follow the same shape. 
Copy from a sibling that matches your architecture's quirks: + +- **Standard decoder LM** template: [`test_gemma1_adapter.py`](../../../tests/unit/model_bridge/supported_architectures/test_gemma1_adapter.py) or [`test_gpt2_adapter.py`](../../../tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py) +- **GQA / modern RMSNorm+RoPE** template: [`test_qwen3_adapter.py`](../../../tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py) +- **Multimodal** template: [`test_llava_adapter.py`](../../../tests/unit/model_bridge/supported_architectures/test_llava_adapter.py), [`test_gemma3_multimodal_adapter.py`](../../../tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py) + +The minimal shape: + +```python +def _make_cfg(d_model: int = 32) -> TransformerBridgeConfig: + return TransformerBridgeConfig( + d_model=d_model, d_head=d_model // 4, n_layers=2, n_ctx=128, + n_heads=4, d_vocab=256, d_mlp=64, architecture="MyArchForCausalLM", + ) + +@pytest.fixture(scope="module") +def adapter() -> MyArchitectureAdapter: + return MyArchitectureAdapter(_make_cfg()) + +class TestMyArchHookCompatibility: + def test_(self, adapter): ... +``` + +No weight load, no HF Hub access — synthetic cfg + structural assertions only. Runs in default `make unit-test`. + +Add one test per architecture quirk (softcaps, RMSNorm offsets, sliding window, custom `eps_attr`, MoE routing). Gemma1's "must NOT override `setup_hook_compatibility`" is a good one-quirk-one-test example. + +### 2. Integration parity test — `tests/integration/model_bridge/test__adapter.py` + +Loads a real cached HF model; asserts logit parity at fp32 + eager attention. Templates: [`test_deepseek_adapter.py`](../../../tests/integration/model_bridge/test_deepseek_adapter.py), [`test_falcon_adapter.py`](../../../tests/integration/model_bridge/test_falcon_adapter.py), [`test_mamba_adapter.py`](../../../tests/integration/model_bridge/test_mamba_adapter.py). 
+ +Required test classes (one per concern; copy-rename from a sibling): + +| Class | Asserts | +|---|---| +| `TestBridgeCreation` | Boot succeeds, expected components present, `cfg.` set correctly | +| `TestForwardEquivalence` | Logit parity vs HF at fp32 + eager attention | +| `TestHFDelegation` | Bridge submodules hold live HF objects (e.g. `bridge.blocks[0].attn.q.original_component is hf_model.model.layers[0].self_attn.q_proj`) | +| `TestHookShapes` | Hook fires; output shape matches expectation (replace with `ParallelHooks` / `MoEHooks` etc. for architecture-specific block shapes) | +| `Test` | One class per architectural quirk: `LogitScale`, `RMSNormOffset`, `Softcap`, `TiedEmbedding`, etc. — each propagated config attr should have a test asserting `bridge.cfg. == hf_model.config.` | + +The boot fixture uses `dtype=` (Bridge's API), NOT `torch_dtype=` (HF's): + +```python +@pytest.fixture(scope="module") +def bridge(): + return TransformerBridge.boot_transformers( + "", device="cpu", dtype=torch.float32, attn_implementation="eager", + ) +``` + +If the reference model OOMs CI, gate with `@pytest.mark.skipif(bool(os.getenv("CI")), ...)` and add a row to [`QUARANTINES.md`](../../../tests/QUARANTINES.md) under "CI cost / network budget." + +### End-to-end registry verification + +```bash +set -a; source .env; set +a +uv run python -m transformer_lens.tools.model_registry.verify_models --architectures MyArchForCausalLM --per-arch 1 +``` + +Wrapped by [`/verify-model`](../../../.claude/commands/verify-model.md). See [tools/model_registry/AGENTS.md](../../tools/model_registry/AGENTS.md) for the contract and the [`verify_models` vs `main_benchmark` trap](../../tools/model_registry/AGENTS.md#tldr). For parity failures: [debugging_numerical_divergence.md](../../../docs/source/content/debugging_numerical_divergence.md). + +--- + +## Hard "don'ts" + +- **No single-model files** — one file per architecture family. 
+- **No `# type: ignore` on `ComponentMapping`** — `isinstance` / `cast` ([AGENTS.md §10](../../../AGENTS.md#10-hard-rules)). +- **No skipping any of the four registration sites** — each breaks something different. [Invariant test](../../../tests/unit/tools/test_model_registry.py) catches most. +- **No HT-style hook aliases inside adapters** — aliases live in [`bridge.py`](../bridge.py); adapters declare Bridge-native names only. +- **No inheriting tokenizer-policy flags from the starter** — see [Tokenizer policy](#tokenizer-policy). +- **No manual edits to existing `supported_models.json` entries' status/phase fields** — only `update_model_status()`. +- **No skipping the unit adapter test** — required for every architecture. diff --git a/transformer_lens/tools/model_registry/AGENTS.md b/transformer_lens/tools/model_registry/AGENTS.md new file mode 100644 index 000000000..e3ac11134 --- /dev/null +++ b/transformer_lens/tools/model_registry/AGENTS.md @@ -0,0 +1,153 @@ +# Model Registry — AGENTS.md + +Read [the root AGENTS.md](../../../AGENTS.md) for project-wide rules. This file covers only conventions specific to `transformer_lens/tools/model_registry/`. + +--- + +## TL;DR + +> **Use `verify_models`, not `main_benchmark`.** Only `verify_models` writes `data/supported_models.json`. `main_benchmark` runs the same math but defaults to NOT writing the registry (needs `--update-registry`, and even with it misses Phase 7/8 scores and the resume checkpoint). If you ran `main_benchmark`, the registry is stale. + +- **`update_model_status()` in `registry_io.py` is the only mutator of status/phase/note on existing entries.** Never set by hand. +- **Adding a new model-ID entry is allowed** (required before `verify_models --model ` can find it). See [Adding a new model entry](#adding-a-new-model-entry). +- **Never run in parallel** — single CUDA/MPS device OOMs ([AGENTS.md §10](../../../AGENTS.md#10-hard-rules)). 
+- **HF token required** for parameter estimation on gated models. Source `.env`: `set -a; source .env; set +a`. + +--- + +## Canonical invocations + +| Goal | Command | +|---|---| +| Verify one specific model + update registry | `uv run python -m transformer_lens.tools.model_registry.verify_models --model ` | +| Verify N models per architecture family | `uv run python -m transformer_lens.tools.model_registry.verify_models --architectures --per-arch ` | +| Verify N models across all architectures | `uv run python -m transformer_lens.tools.model_registry.verify_models --per-arch ` | +| Resume after Ctrl-C / crash | Re-run the same command with `--resume` (reads `data/verification_checkpoint.json`) | +| Re-verify already-verified models for an arch | `--reverify --architectures ` | +| See what would run without doing it | Add `--dry-run` | +| Restrict to specific phases | `--phases 1 2 3` | +| Override device / dtype / memory cap | `--device cuda --dtype float32 --max-memory 16` | + +`HFClassName` matches the strings in `HF_SUPPORTED_ARCHITECTURES` (see `__init__.py`) — e.g. `LlamaForCausalLM`, `GPT2LMHeadModel`, `Olmo2ForCausalLM`. + +### Flag reference + +| Flag | Meaning | +|---|---| +| `--model ` | Verify a single HF repo (must already exist as an entry in `supported_models.json`) | +| `--architectures ` | Restrict to one or more HF architecture classes | +| `--per-arch ` | Verify the top-N unverified models per architecture (default 10) | +| `--limit ` | Cap total models verified across all architectures | +| `--device ` | Override automatic device selection | +| `--dtype ` | Override automatic dtype selection | +| `--max-memory ` | Skip models whose parameter-count estimate exceeds this GB cap (default: tries every model that fits available device memory). Use this to avoid OOM on a small device — e.g. `--max-memory 16` on a 24 GB GPU leaves head-room for activations. 
| +| `--phases ` | Restrict to specific phases (default `1 2 3 4`; Phase 7/8 are auto-skipped for non-applicable architectures) | +| `--resume` | Read `data/verification_checkpoint.json` and skip models already tested in the in-flight run | +| `--reverify` | Re-test already-verified models (default skips status=1 entries) | +| `--retry-failed` | Re-test status=3 (failed) entries | +| `--dry-run` | Print what would be tested without running | +| `--no-hf-reference` / `--no-ht-reference` | Skip the HF / HT comparison passes (faster, lower confidence) | +| `--quiet` | Suppress per-model logging | + +--- + +## File roles + +| File | Role | +|---|---| +| `verify_models.py` | **Canonical CLI** for batch verification + registry updates | +| `registry_io.py` | I/O for `supported_models.json`; `update_model_status()` is the only writer | +| `verification.py` | `VerificationRecord` / `VerificationHistory` dataclasses (audit-trail schema) | +| `validate.py` | JSON-schema validation for registry files | +| `api.py` | Read-only programmatic access (`is_model_supported`, `get_architecture_models`, …) | +| `schemas.py` | Dataclasses for model entries, scan info, architecture stats | +| `exceptions.py` | Custom exception types | +| `alias_drift.py` | Detects when legacy `MODEL_ALIASES` and the registry have diverged | +| `discover_architectures.py` | Lightweight HF scan to enumerate architecture classes | +| `hf_scraper.py` | Full HF Hub scan; builds initial supported/unsupported model lists | +| `relevancy.py` | Filters models by download count, foundation-org provenance | +| `generate_report.py` | Renders human-readable status summaries; holds `ARCHITECTURE_DESCRIPTIONS` | + +`__init__.py` exports the canonical `HF_SUPPORTED_ARCHITECTURES` set and `CANONICAL_AUTHORS_BY_ARCH` map; agents adding a new HF architecture must update both. + +--- + +## Adding a new model entry + +To verify a model not yet in `data/supported_models.json`, hand-add the entry first. 
This is the **only** allowed hand-edit: + +```json +{ + "architecture_id": "MyArchForCausalLM", + "model_id": "org/repo-name", + "status": 0, + "verified_date": null, "metadata": null, "note": null, + "phase1_score": null, "phase2_score": null, "phase3_score": null, + "phase4_score": null, "phase7_score": null, "phase8_score": null +} +``` + +`verify_models --model org/repo-name` then populates status/score/note via `update_model_status()`. Never set those fields manually. + +--- + +## `data/verification_checkpoint.json` (gitignored) + +Resume state for long-running runs (tested/verified/failed/skipped IDs + timestamp): + +- Ctrl-C → SIGINT handler finishes current model, persists checkpoint, exits cleanly. +- `--resume` reads it, skips already-tested models. +- Deleted on successful full run; missing/corrupt → fresh run (safe). + +Never edit manually. + +--- + +## Phase reference + +`verify_models` runs the model through phases and writes per-phase scores back into the registry entry. Phases (some don't apply to every architecture — see `applicable_phases` on the adapter): + +| Phase | Checks | +|---|---| +| 1 | Core forward correctness vs HuggingFace logits | +| 2 | Hook firing + gradient flow | +| 3 | Weight processing (compatibility mode, fold/centre) | +| 4 | Text-generation quality | +| 7 | Multimodal (vision/text alignment) — only Llava / Gemma3-multimodal | +| 8 | Audio — only Hubert | + +### Phase-score thresholds + +`verify_models` enforces hard pass/fail at the thresholds in `_MIN_PHASE_SCORES` ([`verify_models.py:508`](verify_models.py)). Below threshold OR a required-test failure → `STATUS_FAILED`. 
The contract: + +| Phase | Min score | Required tests | Effect when below threshold or required tests fail | +|---|---|---|---| +| 1 | **100%** | — | `STATUS_FAILED` | +| 2 | 75% | `logits_equivalence`, `loss_equivalence` | `STATUS_FAILED` | +| 3 | 75% | `logits_equivalence`, `loss_equivalence` | `STATUS_FAILED` | +| 4 | 50% | — | **Non-gating.** Below 50% adds `"low text quality"` to the registry `note`; never causes `STATUS_FAILED`. | +| 7 | 75% | `multimodal_forward` | `STATUS_FAILED`. NULL score (processor unavailable) also fails. | +| 8 | 75% | `audio_forward` | `STATUS_FAILED`. NULL score also fails. | + +Phase 4 is intentionally lenient — source ([`verify_models.py:554`](verify_models.py)) calls it *"a quality metric, not a correctness check."* The 50% bar asks "is the text coherent at all?" not "is this adapter clean?" + +**For adapter authors:** a `STATUS_VERIFIED` entry with P4 well below 100% on a small parity-test model can still indicate a real bug the system doesn't gate on (e.g. missing `preprocess_weights` fold). Investigate manually even when VERIFIED. + +**Reading the result:** + +- `status==1` + `note="Full verification completed"` → all gates passed, no quality flag. Good. +- `status==1` + `note` mentions `"low text quality"` → P4 < 50%; investigate. +- `status==1` + P4 < 100% on a small model, no quality flag → potential weight-fold/tokenizer bug; investigate. +- `status==3` (FAILED) → `note` carries the failure reason; debug from there. + +P1/P3 failures: [supported_architectures/AGENTS.md §When to override preprocess_weights](../../model_bridge/supported_architectures/AGENTS.md#when-to-override-preprocess_weights), [debugging_numerical_divergence.md](../../../docs/source/content/debugging_numerical_divergence.md). P4 drift: [§Tokenizer policy](../../model_bridge/supported_architectures/AGENTS.md#tokenizer-policy) (logit-scale / embedding-scale folds typically degrade P4 without crossing the 50% gate). 
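The threshold table and the "Reading the result" rules above can be sketched as a small decision function. This is an illustration of the documented contract only, not the code in `verify_models.py`. The function name `evaluate`, the `GATING_PHASES` table, the shape of the inputs, the wording of failure notes, and the assumption that scores are percentages in the range 0 to 100 are all hypothetical. Treating a missing score as a failure is documented only for Phases 7 and 8; applying it to Phases 1 to 3 is an assumption here.

```python
from typing import Dict, Iterable, Optional, Set, Tuple

STATUS_VERIFIED = 1  # status codes as documented for the registry
STATUS_FAILED = 3

# phase -> (minimum score in percent, required test names), per the table above
GATING_PHASES: Dict[int, Tuple[float, Tuple[str, ...]]] = {
    1: (100.0, ()),
    2: (75.0, ("logits_equivalence", "loss_equivalence")),
    3: (75.0, ("logits_equivalence", "loss_equivalence")),
    7: (75.0, ("multimodal_forward",)),
    8: (75.0, ("audio_forward",)),
}
PHASE4_QUALITY_BAR = 50.0  # Phase 4 is non-gating; it only affects the note


def evaluate(
    scores: Dict[int, Optional[float]],
    failed_tests: Dict[int, Set[str]],
    applicable_phases: Iterable[int] = (1, 2, 3, 4),
) -> Tuple[int, str]:
    """Return (status, note) for one model under the documented thresholds."""
    phases = sorted(set(applicable_phases))
    problems = []
    for phase in phases:
        if phase not in GATING_PHASES:
            continue  # Phase 4 never causes a failure
        min_score, required = GATING_PHASES[phase]
        score = scores.get(phase)
        if score is None:
            problems.append(f"P{phase} score missing")
        elif score < min_score:
            problems.append(f"P{phase} {score:.1f}% below {min_score:.0f}%")
        broken = [t for t in required if t in failed_tests.get(phase, set())]
        if broken:
            problems.append(f"P{phase} required tests failed: {', '.join(broken)}")
    if problems:
        return STATUS_FAILED, "; ".join(problems)
    p4 = scores.get(4)
    if 4 in phases and p4 is not None and p4 < PHASE4_QUALITY_BAR:
        return STATUS_VERIFIED, "low text quality"
    return STATUS_VERIFIED, "Full verification completed"


def needs_manual_review(status: int, p4_score: Optional[float]) -> bool:
    """Failed entries, and verified entries with P4 under 100%, merit a closer look."""
    if status == STATUS_FAILED:
        return True
    return p4_score is not None and p4_score < 100.0


if __name__ == "__main__":
    status, note = evaluate({1: 100.0, 2: 90.0, 3: 80.0, 4: 62.0}, {})
    print(status, note, needs_manual_review(status, 62.0))
```

The second helper encodes the advice for adapter authors: a verified status with a Phase 4 score under 100% passes every gate yet still deserves manual investigation.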
+ +--- + +## Hard "don'ts" + +- **No `main_benchmark` for registry updates** — misses P7/P8, no checkpoint, no registry write without `--update-registry`. +- **No parallel `verify_models`** — device OOM ([AGENTS.md §10](../../../AGENTS.md#10-hard-rules)). +- **No manual edits to existing entries' `status`/`verified_date`/`note`/`phaseN_score`** — only `update_model_status()` writes those. (New entries OK — see [Adding a new model entry](#adding-a-new-model-entry).) +- **No deleting `data/verification_checkpoint.json` mid-run** — let SIGINT clean up. +- **No skipping `.env`** — gated-model verification needs `HF_TOKEN`.
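
Because hand-adding a new entry is the one permitted manual edit (see "Adding a new model entry"), a tiny generator can keep the entry shape consistent and avoid accidentally pre-filling a protected field. The sketch below reproduces the JSON template from that section. The helper name `new_registry_entry` is hypothetical, and where the entry belongs inside `supported_models.json` is not shown here, so the example only prints the entry rather than writing the file.

```python
import json
from typing import Any, Dict

# Score fields listed in the documented entry template.
SCORE_FIELDS = tuple(f"phase{n}_score" for n in (1, 2, 3, 4, 7, 8))


def new_registry_entry(architecture_id: str, model_id: str) -> Dict[str, Any]:
    """Build an unverified entry: status 0 and every result field left null."""
    if not architecture_id or not model_id:
        raise ValueError("architecture_id and model_id must be non-empty")
    entry: Dict[str, Any] = {
        "architecture_id": architecture_id,
        "model_id": model_id,
        "status": 0,  # 0 = unverified; verification fills in the rest
        "verified_date": None,
        "metadata": None,
        "note": None,
    }
    entry.update({field: None for field in SCORE_FIELDS})
    return entry


if __name__ == "__main__":
    # Placeholder identifiers taken from the template above.
    print(json.dumps(new_registry_entry("MyArchForCausalLM", "org/repo-name"), indent=2))
```

Generating the entry this way leaves status, scores, and note for `update_model_status()` to populate, which matches the rule that those fields are never set by hand.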