diff --git a/demos/Realtime_Training_Telemetry_Demo.ipynb b/demos/Realtime_Training_Telemetry_Demo.ipynb
index d77d67274..1e52ee7c6 100644
--- a/demos/Realtime_Training_Telemetry_Demo.ipynb
+++ b/demos/Realtime_Training_Telemetry_Demo.ipynb
@@ -1,464 +1,468 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ctqPlNKhJFgZ"
+ },
+ "source": [
+ "# Real-time Training Telemetry Demo\n",
+ "\n",
+ "This notebook demonstrates the real-time extraction and visualization of mechanistic metrics during a model's training loop using TransformerLens.\n",
+ "\n",
+ "By leveraging dynamic dictionary logging alongside the `ActivationCache`, we can isolate the training window where localized phase transitions—such as the formation of induction heads—begin to emerge.\n",
+ "\n",
+ "**A note on scaling**: While this 2-layer toy model allows for high-granularity tracking with minimal computational overhead, achieving similar resolution in larger architectures is non-trivial. It requires highly targeted caching and direct manipulation of the telemetry bridge to surface this level of detail without memory exhaustion.\n",
+ "\n",
+ "**Compute requirements**: A standard CPU is entirely sufficient for this demonstration. The 500-step training loop will execute rapidly in standard local or cloud-based notebook environments without the need for hardware acceleration."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tQ_xnziNJlsF"
+ },
+ "source": [
+ "Initializes the workspace, configures Plotly renderers, and defines the toy model architecture.\n",
+ "\n",
+ "**Visualization Context** `Plotly Renderers`:\n",
+ "\n",
+ "`Plotly` generates interactive, JavaScript-based visualizations. Google Colab handles these DOM interactions differently than local Jupyter or VS Code environments. We detect the active environment to set the appropriate `plotly.io` renderer (`\"colab\"` vs `\"notebook_connected\"`), ensuring the dynamic telemetry plots render correctly without blank output blocks.\n",
+ "\n",
+ "**Architectural Rationale:**\n",
+ "\n",
+ "\n",
+ "* **2 Layers (`n_layers=2`):** The theoretical minimum depth required for induction circuits. Layer 0 creates \"previous token\" representations, and Layer 1 queries these to predict the next token based on earlier context.\n",
+ "\n",
+ "* **2 Heads (`n_heads=2`):** Provides just enough capacity for heads to specialize (e.g., dedicating one head to induction) without creating excessive noise in the telemetry visualizations.\n",
+ "\n",
+ "* **GELU Activation (`act_fn=\"gelu\"`):** Selected over ReLU to mirror the smooth non-linearities of modern production LLMs, ensuring the activation dynamics remain representative of real-world architectures.\n",
+ "\n",
+ "* **Miniaturized Dimensions:** `d_model=64`, `d_vocab=64`, and `n_ctx=32` are intentionally bottlenecked to force rapid convergence, reliably inducing the phase transition within a brief 500-step training window."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
"colab": {
- "provenance": []
+ "base_uri": "https://localhost:8080/"
},
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
+ "id": "GyA-sxLOLRVA",
+ "outputId": "1cc176e2-7638-4106-b063-edce512c6f6d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Environment: Local / Standard Jupyter\n",
+ "Plotly Renderer: notebook_connected\n",
+ "Running on cpu\n"
+ ]
}
+ ],
+ "source": [
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "# Detect execution environment\n",
+ "try:\n",
+ " import google.colab # noqa: F401\n",
+ " IN_COLAB = True\n",
+ " print(\"Environment: Google Colab\")\n",
+ "except ImportError:\n",
+ " IN_COLAB = False\n",
+ " print(\"Environment: Local / Standard Jupyter\")\n",
+ "\n",
+ "# Environment-specific dependency management\n",
+ "if IN_COLAB:\n",
+ " %pip install -q transformer_lens\n",
+ " %pip install -q circuitsvis\n",
+ "\n",
+ "import plotly.io as pio\n",
+ "import plotly.graph_objects as go\n",
+ "from plotly.subplots import make_subplots\n",
+ "\n",
+ "# Configure Plotly renderer for correct JavaScript execution\n",
+ "if IN_COLAB:\n",
+ " pio.renderers.default = \"colab\"\n",
+ "else:\n",
+ " pio.renderers.default = \"notebook_connected\"\n",
+ "print(f\"Plotly Renderer: {pio.renderers.default}\")\n",
+ "\n",
+ "# Must be imported after Colab pip install\n",
+ "from transformer_lens import TransformerBridgeConfig # noqa: E402\n",
+ "from transformer_lens.model_bridge import TransformerBridge # noqa: E402\n",
+ "\n",
+ "# Configuration\n",
+ "torch.manual_seed(42)\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "print(f\"Running on {device}\")\n",
+ "\n",
+ "# TransformerBridge.boot_native builds a small TL-native model (no HuggingFace\n",
+ "# Hub call, no `transformers` import) and wraps it in a bridge. The bridge's\n",
+ "# `forward`, `run_with_cache`, and `parameters()` surfaces are the same ones\n",
+ "# you'd use on any other bridge-backed model.\n",
+ "cfg = TransformerBridgeConfig(\n",
+ " n_layers=2,\n",
+ " d_model=64,\n",
+ " d_head=32,\n",
+ " n_heads=2,\n",
+ " d_mlp=256,\n",
+ " d_vocab=64,\n",
+ " n_ctx=32,\n",
+ " act_fn=\"gelu\",\n",
+ " normalization_type=\"LN\",\n",
+ " seed=42,\n",
+ ")\n",
+ "model = TransformerBridge.boot_native(cfg, device=device)\n"
+ ]
},
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "# Real-time Training Telemetry Demo\n",
- "\n",
- "This notebook demonstrates the real-time extraction and visualization of mechanistic metrics during a model's training loop using TransformerLens.\n",
- "\n",
- "By leveraging dynamic dictionary logging alongside the `ActivationCache`, we can isolate the training window where localized phase transitions\u2014such as the formation of induction heads\u2014begin to emerge.\n",
- "\n",
- "**A note on scaling**: While this 2-layer toy model allows for high-granularity tracking with minimal computational overhead, achieving similar resolution in larger architectures is non-trivial. It requires highly targeted caching and direct manipulation of the telemetry bridge to surface this level of detail without memory exhaustion.\n",
- "\n",
- "**Compute requirements**: A standard CPU is entirely sufficient for this demonstration. The 500-step training loop will execute rapidly in standard local or cloud-based notebook environments without the need for hardware acceleration."
- ],
- "metadata": {
- "id": "ctqPlNKhJFgZ"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Initializes the workspace, configures Plotly renderers, and defines the toy model architecture.\n",
- "\n",
- "**Visualization Context** `Plotly Renderers`:\n",
- "\n",
- "`Plotly` generates interactive, JavaScript-based visualizations. Google Colab handles these DOM interactions differently than local Jupyter or VS Code environments. We detect the active environment to set the appropriate `plotly.io` renderer (`\"colab\"` vs `\"notebook_connected\"`), ensuring the dynamic telemetry plots render correctly without blank output blocks.\n",
- "\n",
- "**Architectural Rationale:**\n",
- "\n",
- "\n",
- "* **2 Layers (`n_layers=2`):** The theoretical minimum depth required for induction circuits. Layer 0 creates \"previous token\" representations, and Layer 1 queries these to predict the next token based on earlier context.\n",
- "\n",
- "* **2 Heads (`n_heads=2`):** Provides just enough capacity for heads to specialize (e.g., dedicating one head to induction) without creating excessive noise in the telemetry visualizations.\n",
- "\n",
- "* **GELU Activation (`act_fn=\"gelu\"`):** Selected over ReLU to mirror the smooth non-linearities of modern production LLMs, ensuring the activation dynamics remain representative of real-world architectures.\n",
- "\n",
- "* **Miniaturized Dimensions:** `d_model=64`, `d_vocab=64`, and `n_ctx=32` are intentionally bottlenecked to force rapid convergence, reliably inducing the phase transition within a brief 500-step training window."
- ],
- "metadata": {
- "id": "tQ_xnziNJlsF"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "# Detect execution environment\n",
- "try:\n",
- " import google.colab # noqa: F401\n",
- " IN_COLAB = True\n",
- " print(\"Environment: Google Colab\")\n",
- "except ImportError:\n",
- " IN_COLAB = False\n",
- " print(\"Environment: Local / Standard Jupyter\")\n",
- "\n",
- "# Environment-specific dependency management\n",
- "if IN_COLAB:\n",
- " %pip install -q transformer_lens\n",
- " %pip install -q circuitsvis\n",
- "\n",
- "import plotly.io as pio\n",
- "import plotly.graph_objects as go\n",
- "from plotly.subplots import make_subplots\n",
- "\n",
- "# Configure Plotly renderer for correct JavaScript execution\n",
- "if IN_COLAB:\n",
- " pio.renderers.default = \"colab\"\n",
- "else:\n",
- " pio.renderers.default = \"notebook_connected\"\n",
- "print(f\"Plotly Renderer: {pio.renderers.default}\")\n",
- "\n",
- "# Must be imported after Colab pip install\n",
- "from transformer_lens import HookedTransformer, HookedTransformerConfig # noqa: E402\n",
- "\n",
- "# Configuration\n",
- "torch.manual_seed(42)\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(f\"\ud83d\ude80 Running on {device}\")\n",
- "\n",
- "cfg = HookedTransformerConfig(\n",
- " n_layers=2,\n",
- " d_model=64,\n",
- " d_head=32,\n",
- " n_heads=2,\n",
- " d_mlp=256,\n",
- " d_vocab=64,\n",
- " n_ctx=32,\n",
- " act_fn=\"gelu\",\n",
- " normalization_type=\"LN\",\n",
- " seed=42,\n",
- ")\n",
- "model = HookedTransformer(cfg).to(device)"
- ],
- "metadata": {
- "id": "GyA-sxLOLRVA",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "1cc176e2-7638-4106-b063-edce512c6f6d"
- },
- "execution_count": 1,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Environment: Google Colab\n",
- " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m968.6/968.6 kB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m56.4/56.4 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Building wheel for transformers-stream-generator (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hPlotly Renderer: colab\n",
- "\ud83d\ude80 Running on cpu\n",
- "Moving model to device: cpu\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Attention Telemetry Extraction\n",
- "\n",
- "This bridge extracts mechanistic metrics directly from the `ActivationCache` during the forward pass.\n",
- "\n",
- "* **Head Coherence:** Measures attention focus sharpness via normalized entropy. A score of $1.0$ indicates perfect focus on a single token, while $0.0$ indicates uniformly distributed attention.\n",
- "\n",
- "* **Head Agreement:** Evaluates intra-layer behavioral similarity among attention heads.\n",
- "* **Variance Normalization:** The agreement metric normalizes against a `0.005` variance constant. This serves as an expected baseline for inter-head variance at the point of specialization in 2-layer models.\n",
- "\n",
- "**Note:** This constant is a localized architectural assumption. Recalibrating this threshold will likely be necessary when porting the telemetry bridge to larger, higher-dimensional models."
- ],
- "metadata": {
- "id": "0hMNO0WpKXhR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class AttentionTelemetry:\n",
- " \"\"\"Lightweight bridge extracting mechanistic metrics from ActivationCache.\"\"\"\n",
- "\n",
- " @staticmethod\n",
- " def compute_metrics(cache, layer_idx):\n",
- " \"\"\"Computes attention coherence and agreement for a given layer.\n",
- "\n",
- " Args:\n",
- " cache (ActivationCache): The cached activations from the forward pass.\n",
- " layer_idx (int): The index of the layer to analyze.\n",
- "\n",
- " Returns:\n",
- " dict: A dictionary containing layer_idx, head_coherence, and head_agreement.\n",
- "\n",
- " Notes on v_max (0.005):\n",
- " The agreement normalization constant is derived from the expected inter-head\n",
- " attention variance at convergence in 2-layer induction head toy models.\n",
- " At convergence, heads specialize (low variance); pre-convergence variance\n",
- " peaks near 0.005. This value is task- and architecture-specific; adjust\n",
- " if adapting to larger models or different tasks.\n",
- " \"\"\"\n",
- "\n",
- " pattern_name = f\"blocks.{layer_idx}.attn.hook_pattern\"\n",
- "\n",
- " # Shape: [batch, heads, seq, seq]\n",
- " attn_probs = cache[pattern_name]\n",
- "\n",
- " # 1. Head Coherence: 1.0 - (Entropy / Max_Entropy)\n",
- " probs = attn_probs + 1e-9\n",
- " entropy = -torch.sum(probs * torch.log(probs), dim=-1) # [batch, heads, seq]\n",
- " head_coherence = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(attn_probs.shape[-1]))\n",
- "\n",
- " # 2. Head Agreement: 1.0 - clip(Variance / v_max)\n",
- " mean_var = torch.var(attn_probs, dim=1).mean() # Variance across heads\n",
- " head_agreement = 1.0 - torch.clamp(mean_var / 0.005, 0.0, 1.0)\n",
- "\n",
- " return {\n",
- " \"layer_idx\": layer_idx,\n",
- " \"head_coherence\": head_coherence.mean().item(),\n",
- " \"head_agreement\": head_agreement.item()\n",
- " }"
- ],
- "metadata": {
- "id": "fd_0TLTUUNf6"
- },
- "execution_count": 2,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Synthetic Data Generation & Training Loop\n",
- "\n",
- "This cell generates a highly constrained dataset to force circuit formation and executes the training loop, capturing telemetry at specified intervals.\n",
- "\n",
- "**A Note on Synthetic Data:**\n",
- "\n",
- "The repeated sequence generator (`[A, B, C, ..., A, B, C]`) used below is strictly illustrative. It is engineered specifically as a shortcut to force the rapid emergence of in-context look-back circuits (induction heads). It is not meant to serve as an educational standard for model training.\n",
- "\n",
- "**Transitioning to Real Data:**\n",
- "\n",
- "Applying this telemetry extraction to real-world datasets requires rigorous attention to detail regarding:\n",
- "\n",
- "* **Data Quality:** Unstructured noise in the input distribution will severely obscure the mechanistic signals (like coherence and agreement) you are attempting to isolate.\n",
- "\n",
- "* **Data Type & Tokenization:** Real-world text requires careful handling of padding, EOS/BOS tokens, and sequence packing, all of which dynamically alter attention patterns and can skew your baseline metrics.\n",
- "\n",
- "* **Ingestion Methodology:** Managing data loaders, batching, and ensuring that telemetry logging steps align with representative samples is critical to preventing metric distortion.\n",
- "\n",
- "**Performance Optimization:** To prevent the telemetry capture from bottlenecking the training process, `model.run_with_cache` is exclusively executed at logging intervals. Standard forward passes bypass the cache entirely."
- ],
- "metadata": {
- "id": "p2TYMU23Kt2K"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def generate_induction_data(batch_size, seq_len, vocab_size, device=\"cpu\"):\n",
- " half_len = seq_len // 2\n",
- " first_half = torch.randint(0, vocab_size, (batch_size, half_len))\n",
- " data = torch.cat([first_half, first_half], dim=1)\n",
- " return data.to(device)"
- ],
- "metadata": {
- "id": "t6NIpzbKK0_K"
- },
- "execution_count": 3,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# The Impact of Tokenization on Telemetry\n",
- "\n",
- "When moving from synthetic generators to real-world text, understanding the mechanistic role of special tokens is vital to preventing metric distortion.\n",
- "\n",
- "**BOS (Begin of Sequence) as an Attention Sink:**\n",
- "\n",
- "In many autoregressive models, attention heads route their focus to the first token (BOS) when they do not have a strong contextual match elsewhere. This is known as an \"attention sink.\" If unaccounted for, your telemetry will show artificially high **Head Coherence** (because the head is sharply focused on a single token). However, this represents a resting state rather than active circuit engagement.\n",
- "\n",
- "**EOS (End of Sequence) & Padding:**\n",
- "\n",
- "Real data requires batching sequences of variable lengths, necessitating padding tokens and EOS markers to denote where a document ends.\n",
- "\n",
- "* **Masking Failures:** If attention masks are not perfectly aligned with your telemetry extraction, heads might attend to padding tokens, introducing garbage data that drastically skews your **Agreement** metrics.\n",
- "\n",
- "* **Context Resets:** The transition between unrelated documents (separated by an EOS token) disrupts the contiguous context window. This resets the look-back mechanisms that induction circuits rely on, causing momentary drops in otherwise stable telemetry"
- ],
- "metadata": {
- "id": "M8sQ0pCuLZnB"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "# The Training & Real-Time Telemetry Loop\n",
- "\n",
- "Executing the training sequence while dynamically tracking phase transitions.\n",
- "\n",
- "**Compute Efficiency (Selective Caching):**\n",
- "\n",
- "To prevent the extraction process from suffocating the CPU/GPU memory bandwidth, we employ selective caching. Standard forward passes operate normally; `model.run_with_cache` is exclusively invoked at defined logging intervals to extract the telemetry state without severely bottlenecking the training step.\n",
- "\n",
- "**Rendering Optimization (The Real-Time UI):**\n",
- "\n",
- "To achieve real-time visualization without crashing the browser's DOM or throttling the PyTorch loop:\n",
- "\n",
- "1. **Memory Pre-allocation:** We pre-allocate `NaN` arrays for the telemetry traces, completely bypassing costly array reallocation during the loop.\n",
- "\n",
- "2. **In-Place Mutation:** Instead of generating hundreds of static Plotly objects, we mutate the figure's trace data directly and use `IPython.display.clear_output` to cleanly redraw the frame in the exact same output block.\n",
- "\n",
- "3. **Static Fallback:** *If you wish to bypass real-time rendering to maximize training speed, simply comment out the `clear_output(wait=True)` and `fig.show()` lines inside the loop, and call `fig.show()` once at the very end of the cell.\n",
- "\n",
- "**Mechanistic Observation (The Phase Transition):**\n",
- "\n",
- "Watch the dual-plot for the localized phase transition: a distinct window where the model suddenly \"discovers\" the induction algorithm. This is marked by a violent crash in the loss curve and a simultaneous, sharp spike in the last layer's Attention Coherence."
- ],
- "metadata": {
- "id": "TGucpV41Ltxc"
- }
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0hMNO0WpKXhR"
+ },
+ "source": [
+ "# Attention Telemetry Extraction\n",
+ "\n",
+ "This bridge extracts mechanistic metrics directly from the `ActivationCache` during the forward pass.\n",
+ "\n",
+ "* **Head Coherence:** Measures attention focus sharpness via normalized entropy. A score of $1.0$ indicates perfect focus on a single token, while $0.0$ indicates uniformly distributed attention.\n",
+ "\n",
+ "* **Head Agreement:** Evaluates intra-layer behavioral similarity among attention heads.\n",
+ "* **Variance Normalization:** The agreement metric normalizes against a `0.005` variance constant. This serves as an expected baseline for inter-head variance at the point of specialization in 2-layer models.\n",
+ "\n",
+ "**Note:** This constant is a localized architectural assumption. Recalibrating this threshold will likely be necessary when porting the telemetry bridge to larger, higher-dimensional models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "fd_0TLTUUNf6"
+ },
+ "outputs": [],
+ "source": [
+ "class AttentionTelemetry:\n",
+ " \"\"\"Lightweight bridge extracting mechanistic metrics from ActivationCache.\"\"\"\n",
+ "\n",
+ " @staticmethod\n",
+ " def compute_metrics(cache, layer_idx):\n",
+ " \"\"\"Computes attention coherence and agreement for a given layer.\n",
+ "\n",
+ " Args:\n",
+ " cache (ActivationCache): The cached activations from the forward pass.\n",
+ " layer_idx (int): The index of the layer to analyze.\n",
+ "\n",
+ " Returns:\n",
+ " dict: A dictionary containing layer_idx, head_coherence, and head_agreement.\n",
+ "\n",
+ " Notes on v_max (0.005):\n",
+ " The agreement normalization constant is derived from the expected inter-head\n",
+ " attention variance at convergence in 2-layer induction head toy models.\n",
+ " At convergence, heads specialize (low variance); pre-convergence variance\n",
+ " peaks near 0.005. This value is task- and architecture-specific; adjust\n",
+ " if adapting to larger models or different tasks.\n",
+ " \"\"\"\n",
+ "\n",
+ " pattern_name = f\"blocks.{layer_idx}.attn.hook_pattern\"\n",
+ "\n",
+ " # Shape: [batch, heads, seq, seq]\n",
+ " attn_probs = cache[pattern_name]\n",
+ "\n",
+ " # 1. Head Coherence: 1.0 - (Entropy / Max_Entropy)\n",
+ " probs = attn_probs + 1e-9\n",
+ " entropy = -torch.sum(probs * torch.log(probs), dim=-1) # [batch, heads, seq]\n",
+ " head_coherence = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(attn_probs.shape[-1]))\n",
+ "\n",
+ " # 2. Head Agreement: 1.0 - clip(Variance / v_max)\n",
+ " mean_var = torch.var(attn_probs, dim=1).mean() # Variance across heads\n",
+ " head_agreement = 1.0 - torch.clamp(mean_var / 0.005, 0.0, 1.0)\n",
+ "\n",
+ " return {\n",
+ " \"layer_idx\": layer_idx,\n",
+ " \"head_coherence\": head_coherence.mean().item(),\n",
+ " \"head_agreement\": head_agreement.item()\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p2TYMU23Kt2K"
+ },
+ "source": [
+ "# Synthetic Data Generation & Training Loop\n",
+ "\n",
+ "This cell generates a highly constrained dataset to force circuit formation and executes the training loop, capturing telemetry at specified intervals.\n",
+ "\n",
+ "**A Note on Synthetic Data:**\n",
+ "\n",
+ "The repeated sequence generator (`[A, B, C, ..., A, B, C]`) used below is strictly illustrative. It is engineered specifically as a shortcut to force the rapid emergence of in-context look-back circuits (induction heads). It is not meant to serve as an educational standard for model training.\n",
+ "\n",
+ "**Transitioning to Real Data:**\n",
+ "\n",
+ "Applying this telemetry extraction to real-world datasets requires rigorous attention to detail regarding:\n",
+ "\n",
+ "* **Data Quality:** Unstructured noise in the input distribution will severely obscure the mechanistic signals (like coherence and agreement) you are attempting to isolate.\n",
+ "\n",
+ "* **Data Type & Tokenization:** Real-world text requires careful handling of padding, EOS/BOS tokens, and sequence packing, all of which dynamically alter attention patterns and can skew your baseline metrics.\n",
+ "\n",
+ "* **Ingestion Methodology:** Managing data loaders, batching, and ensuring that telemetry logging steps align with representative samples is critical to preventing metric distortion.\n",
+ "\n",
+ "**Performance Optimization:** To prevent the telemetry capture from bottlenecking the training process, `model.run_with_cache` is exclusively executed at logging intervals. Standard forward passes bypass the cache entirely."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "t6NIpzbKK0_K"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_induction_data(batch_size, seq_len, vocab_size, device=\"cpu\"):\n",
+ " half_len = seq_len // 2\n",
+ " first_half = torch.randint(0, vocab_size, (batch_size, half_len))\n",
+ " data = torch.cat([first_half, first_half], dim=1)\n",
+ " return data.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "M8sQ0pCuLZnB"
+ },
+ "source": [
+ "# The Impact of Tokenization on Telemetry\n",
+ "\n",
+ "When moving from synthetic generators to real-world text, understanding the mechanistic role of special tokens is vital to preventing metric distortion.\n",
+ "\n",
+ "**BOS (Begin of Sequence) as an Attention Sink:**\n",
+ "\n",
+ "In many autoregressive models, attention heads route their focus to the first token (BOS) when they do not have a strong contextual match elsewhere. This is known as an \"attention sink.\" If unaccounted for, your telemetry will show artificially high **Head Coherence** (because the head is sharply focused on a single token). However, this represents a resting state rather than active circuit engagement.\n",
+ "\n",
+ "**EOS (End of Sequence) & Padding:**\n",
+ "\n",
+ "Real data requires batching sequences of variable lengths, necessitating padding tokens and EOS markers to denote where a document ends.\n",
+ "\n",
+ "* **Masking Failures:** If attention masks are not perfectly aligned with your telemetry extraction, heads might attend to padding tokens, introducing garbage data that drastically skews your **Agreement** metrics.\n",
+ "\n",
+ "* **Context Resets:** The transition between unrelated documents (separated by an EOS token) disrupts the contiguous context window. This resets the look-back mechanisms that induction circuits rely on, causing momentary drops in otherwise stable telemetry"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TGucpV41Ltxc"
+ },
+ "source": [
+ "# The Training & Real-Time Telemetry Loop\n",
+ "\n",
+ "Executing the training sequence while dynamically tracking phase transitions.\n",
+ "\n",
+ "**Compute Efficiency (Selective Caching):**\n",
+ "\n",
+ "To prevent the extraction process from suffocating the CPU/GPU memory bandwidth, we employ selective caching. Standard forward passes operate normally; `model.run_with_cache` is exclusively invoked at defined logging intervals to extract the telemetry state without severely bottlenecking the training step.\n",
+ "\n",
+ "**Rendering Optimization (The Real-Time UI):**\n",
+ "\n",
+ "To achieve real-time visualization without crashing the browser's DOM or throttling the PyTorch loop:\n",
+ "\n",
+ "1. **Memory Pre-allocation:** We pre-allocate `NaN` arrays for the telemetry traces, completely bypassing costly array reallocation during the loop.\n",
+ "\n",
+ "2. **In-Place Mutation:** Instead of generating hundreds of static Plotly objects, we mutate the figure's trace data directly and use `IPython.display.clear_output` to cleanly redraw the frame in the exact same output block.\n",
+ "\n",
+ "3. **Static Fallback:** *If you wish to bypass real-time rendering to maximize training speed, simply comment out the `clear_output(wait=True)` and `fig.show()` lines inside the loop, and call `fig.show()` once at the very end of the cell.\n",
+ "\n",
+ "**Mechanistic Observation (The Phase Transition):**\n",
+ "\n",
+ "Watch the dual-plot for the localized phase transition: a distinct window where the model suddenly \"discovers\" the induction algorithm. This is marked by a violent crash in the loss curve and a simultaneous, sharp spike in the last layer's Attention Coherence."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 717
},
+ "id": "Yj90yewlLvuh",
+ "outputId": "c1d1d8a9-a3e8-4bec-8336-a0829f7c0872"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "source": [
- "from IPython.display import clear_output\n",
- "import numpy as np # noqa: F811\n",
- "import torch\n",
- "\n",
- "# --- 1. Self-Contained Synthetic Data Generator ---\n",
- "def get_batch(batch_size=16, seq_len=model.cfg.n_ctx, vocab_size=model.cfg.d_vocab):\n",
- " \"\"\"Generates repeated sequences [A, B, C, A, B, C] to force induction circuitry.\"\"\"\n",
- " half_len = seq_len // 2\n",
- " random_tokens = torch.randint(0, vocab_size, (batch_size, half_len), device=device)\n",
- " return torch.cat([random_tokens, random_tokens], dim=1)\n",
- "\n",
- "# --- 2. Pre-allocate Memory for Real-Time Plotting ---\n",
- "TOTAL_STEPS = 500\n",
- "LOG_INTERVAL = 10\n",
- "num_logging_steps = TOTAL_STEPS // LOG_INTERVAL\n",
- "\n",
- "logged_steps = np.arange(0, TOTAL_STEPS, LOG_INTERVAL)\n",
- "\n",
- "loss_data = np.full(num_logging_steps, np.nan)\n",
- "coherence_data = np.full(num_logging_steps, np.nan)\n",
- "heatmap_data = np.full((model.cfg.n_layers, num_logging_steps), np.nan)\n",
- "\n",
- "# --- 3. Initialize the Plotly Figure ---\n",
- "fig = make_subplots(\n",
- " rows=2, cols=1,\n",
- " shared_xaxes=True,\n",
- " vertical_spacing=0.1,\n",
- " subplot_titles=(\n",
- " \"Phase Transition: Loss Crash vs. Circuit Formation\",\n",
- " \"Attention Coherence Heatmap by Layer Depth\"\n",
- " ),\n",
- " specs=[[{\"secondary_y\": True}], [{\"type\": \"heatmap\"}]]\n",
- ")\n",
- "\n",
- "# Trace 0: Loss Curve\n",
- "fig.add_trace(\n",
- " go.Scatter(x=logged_steps, y=loss_data, name=\"Loss (CE)\", line=dict(color='gray', dash='dash')),\n",
- " row=1, col=1, secondary_y=False\n",
- ")\n",
- "\n",
- "# Trace 1: Last Layer Coherence Curve\n",
- "last_layer_idx = model.cfg.n_layers - 1\n",
- "fig.add_trace(\n",
- " go.Scatter(x=logged_steps, y=coherence_data, name=f\"Layer {last_layer_idx} Coherence\", line=dict(color='#1f77b4', width=2.5)),\n",
- " row=1, col=1, secondary_y=True\n",
- ")\n",
- "\n",
- "# Trace 2: Layer Heatmap\n",
- "fig.add_trace(\n",
- " go.Heatmap(\n",
- " z=heatmap_data, x=logged_steps, y=[f\"L{i}\" for i in range(model.cfg.n_layers)],\n",
- " colorscale='Magma', zmin=0.0, zmax=1.0,\n",
- " colorbar=dict(title=\"Coherence (0-1)\", orientation='h', y=-0.25, len=0.5)\n",
- " ),\n",
- " row=2, col=1\n",
- ")\n",
- "\n",
- "fig.update_layout(height=700, template=\"plotly_white\", margin=dict(t=50, b=50))\n",
- "fig.update_yaxes(title_text=\"Cross Entropy Loss\", secondary_y=False, row=1, col=1)\n",
- "fig.update_yaxes(title_text=\"Coherence\", secondary_y=True, range=[0, 1.1], row=1, col=1)\n",
- "fig.update_xaxes(range=[0, TOTAL_STEPS])\n",
- "# --- 4. The Training & Telemetry Loop ---\n",
- "model.train()\n",
- "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
- "\n",
- "log_idx = 0\n",
- "for step in range(TOTAL_STEPS):\n",
- " batch = get_batch()\n",
- "\n",
- " # Selective Caching Strategy\n",
- " if step % LOG_INTERVAL == 0:\n",
- " loss, cache = model.run_with_cache(batch, return_type=\"loss\")\n",
- "\n",
- " # Update baseline data arrays\n",
- " loss_data[log_idx] = loss.item()\n",
- "\n",
- " # Extract mechanistic metrics using the static method from Cell 3\n",
- " for layer in range(model.cfg.n_layers):\n",
- " layer_metrics = AttentionTelemetry.compute_metrics(cache, layer)\n",
- " heatmap_data[layer, log_idx] = layer_metrics['head_coherence']\n",
- "\n",
- " # Specifically grab the last layer for the line graph\n",
- " if layer == last_layer_idx:\n",
- " coherence_data[log_idx] = layer_metrics['head_coherence']\n",
- "\n",
- " # Mutate Plotly traces in-place\n",
- " fig.data[0].x = logged_steps\n",
- " fig.data[0].y = loss_data\n",
- "\n",
- " fig.data[1].x = logged_steps\n",
- " fig.data[1].y = coherence_data\n",
- "\n",
- " fig.data[2].x = logged_steps\n",
- " fig.data[2].z = heatmap_data\n",
- "\n",
- " # Redraw the UI\n",
- " clear_output(wait=True)\n",
- " fig.show()\n",
- "\n",
- " log_idx += 1\n",
- " else:\n",
- " # Standard forward pass (bypassing the cache for speed)\n",
- " loss = model(batch, return_type=\"loss\")\n",
- "\n",
- " # Standard PyTorch Optimization\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " optimizer.zero_grad()"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 717
- },
- "outputId": "c1d1d8a9-a3e8-4bec-8336-a0829f7c0872",
- "id": "Yj90yewlLvuh"
- },
- "execution_count": 4,
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/html": [
- "\n",
- "
\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {}
- }
+ "data": {
+ "text/html": [
+ ""
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
- ]
-}
\ No newline at end of file
+ ],
+ "source": [
+ "from IPython.display import clear_output\n",
+ "import numpy as np # noqa: F811\n",
+ "import torch\n",
+ "\n",
+ "# --- 1. Self-Contained Synthetic Data Generator ---\n",
+ "def get_batch(batch_size=16, seq_len=model.cfg.n_ctx, vocab_size=model.cfg.d_vocab):\n",
+ " \"\"\"Generates repeated sequences [A, B, C, A, B, C] to force induction circuitry.\"\"\"\n",
+ " half_len = seq_len // 2\n",
+ " random_tokens = torch.randint(0, vocab_size, (batch_size, half_len), device=device)\n",
+ " return torch.cat([random_tokens, random_tokens], dim=1)\n",
+ "\n",
+ "# --- 2. Pre-allocate Memory for Real-Time Plotting ---\n",
+ "TOTAL_STEPS = 500\n",
+ "LOG_INTERVAL = 10\n",
+ "num_logging_steps = TOTAL_STEPS // LOG_INTERVAL\n",
+ "\n",
+ "logged_steps = np.arange(0, TOTAL_STEPS, LOG_INTERVAL)\n",
+ "\n",
+ "loss_data = np.full(num_logging_steps, np.nan)\n",
+ "coherence_data = np.full(num_logging_steps, np.nan)\n",
+ "heatmap_data = np.full((model.cfg.n_layers, num_logging_steps), np.nan)\n",
+ "\n",
+ "# --- 3. Initialize the Plotly Figure ---\n",
+ "fig = make_subplots(\n",
+ " rows=2, cols=1,\n",
+ " shared_xaxes=True,\n",
+ " vertical_spacing=0.1,\n",
+ " subplot_titles=(\n",
+ " \"Phase Transition: Loss Crash vs. Circuit Formation\",\n",
+ " \"Attention Coherence Heatmap by Layer Depth\"\n",
+ " ),\n",
+ " specs=[[{\"secondary_y\": True}], [{\"type\": \"heatmap\"}]]\n",
+ ")\n",
+ "\n",
+ "# Trace 0: Loss Curve\n",
+ "fig.add_trace(\n",
+ " go.Scatter(x=logged_steps, y=loss_data, name=\"Loss (CE)\", line=dict(color='gray', dash='dash')),\n",
+ " row=1, col=1, secondary_y=False\n",
+ ")\n",
+ "\n",
+ "# Trace 1: Last Layer Coherence Curve\n",
+ "last_layer_idx = model.cfg.n_layers - 1\n",
+ "fig.add_trace(\n",
+ " go.Scatter(x=logged_steps, y=coherence_data, name=f\"Layer {last_layer_idx} Coherence\", line=dict(color='#1f77b4', width=2.5)),\n",
+ " row=1, col=1, secondary_y=True\n",
+ ")\n",
+ "\n",
+ "# Trace 2: Layer Heatmap\n",
+ "fig.add_trace(\n",
+ " go.Heatmap(\n",
+ " z=heatmap_data, x=logged_steps, y=[f\"L{i}\" for i in range(model.cfg.n_layers)],\n",
+ " colorscale='Magma', zmin=0.0, zmax=1.0,\n",
+ " colorbar=dict(title=\"Coherence (0-1)\", orientation='h', y=-0.25, len=0.5)\n",
+ " ),\n",
+ " row=2, col=1\n",
+ ")\n",
+ "\n",
+ "fig.update_layout(height=700, template=\"plotly_white\", margin=dict(t=50, b=50))\n",
+ "fig.update_yaxes(title_text=\"Cross Entropy Loss\", secondary_y=False, row=1, col=1)\n",
+ "fig.update_yaxes(title_text=\"Coherence\", secondary_y=True, range=[0, 1.1], row=1, col=1)\n",
+ "fig.update_xaxes(range=[0, TOTAL_STEPS])\n",
+ "# --- 4. The Training & Telemetry Loop ---\n",
+ "model.train()\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
+ "\n",
+ "log_idx = 0\n",
+ "for step in range(TOTAL_STEPS):\n",
+ " batch = get_batch()\n",
+ "\n",
+ " # Selective Caching Strategy\n",
+ " if step % LOG_INTERVAL == 0:\n",
+ " loss, cache = model.run_with_cache(batch, return_type=\"loss\")\n",
+ "\n",
+ " # Update baseline data arrays\n",
+ " loss_data[log_idx] = loss.item()\n",
+ "\n",
+ " # Extract mechanistic metrics using the static method from Cell 3\n",
+ " for layer in range(model.cfg.n_layers):\n",
+ " layer_metrics = AttentionTelemetry.compute_metrics(cache, layer)\n",
+ " heatmap_data[layer, log_idx] = layer_metrics['head_coherence']\n",
+ "\n",
+ " # Specifically grab the last layer for the line graph\n",
+ " if layer == last_layer_idx:\n",
+ " coherence_data[log_idx] = layer_metrics['head_coherence']\n",
+ "\n",
+ " # Mutate Plotly traces in-place\n",
+ " fig.data[0].x = logged_steps\n",
+ " fig.data[0].y = loss_data\n",
+ "\n",
+ " fig.data[1].x = logged_steps\n",
+ " fig.data[1].y = coherence_data\n",
+ "\n",
+ " fig.data[2].x = logged_steps\n",
+ " fig.data[2].z = heatmap_data\n",
+ "\n",
+ " # Redraw the UI\n",
+ " clear_output(wait=True)\n",
+ " fig.show()\n",
+ "\n",
+ " log_idx += 1\n",
+ " else:\n",
+ " # Standard forward pass (bypassing the cache for speed)\n",
+ " loss = model(batch, return_type=\"loss\")\n",
+ "\n",
+ " # Standard PyTorch Optimization\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " optimizer.zero_grad()"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "transformer-lens",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tests/integration/model_bridge/test_native_training.py b/tests/integration/model_bridge/test_native_training.py
new file mode 100644
index 000000000..e73855cf2
--- /dev/null
+++ b/tests/integration/model_bridge/test_native_training.py
@@ -0,0 +1,140 @@
+"""Integration: train a TL-native bridge on the induction task end-to-end.
+
+Mirrors the Realtime Telemetry demo's logic but with shorter step counts so
+the test stays in CI's time budget. Thresholds are deliberately qualitative
+(direction, not magnitude) so the test does not flake on
+BLAS/MPS-implementation differences across CI runners.
+"""
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.model_bridge import TransformerBridge
+
+
+def _induction_batch(batch_size: int, seq_len: int, vocab_size: int) -> torch.Tensor:
+ half = seq_len // 2
+ rand = torch.randint(0, vocab_size, (batch_size, half))
+ return torch.cat([rand, rand], dim=1)
+
+
+def _telemetry_cfg() -> TransformerBridgeConfig:
+ # Matches the demo cell-2 configuration so this test exercises the same
+ # surface a user would touch.
+ return TransformerBridgeConfig(
+ d_model=64,
+ d_head=32,
+ n_heads=2,
+ n_layers=2,
+ n_ctx=32,
+ d_vocab=64,
+ d_mlp=256,
+ act_fn="gelu",
+ normalization_type="LN",
+ seed=42,
+ )
+
+
+def test_native_bridge_training_decreases_loss():
+ """Loss must decrease meaningfully within a small step budget."""
+ torch.manual_seed(42)
+ cfg = _telemetry_cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ bridge.train()
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3, weight_decay=1e-4)
+
+ initial_losses, final_losses = [], []
+ for step in range(300):
+ batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab)
+ loss = bridge(batch, return_type="loss")
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ if step < 5:
+ initial_losses.append(loss.item())
+ if step >= 295:
+ final_losses.append(loss.item())
+
+ initial_avg = sum(initial_losses) / len(initial_losses)
+ final_avg = sum(final_losses) / len(final_losses)
+ assert final_avg < initial_avg * 0.7, (
+ f"Loss did not decrease enough: initial_avg={initial_avg:.4f}, "
+ f"final_avg={final_avg:.4f} (expected < {initial_avg * 0.7:.4f})"
+ )
+
+
+def test_native_bridge_run_with_cache_during_training():
+ """run_with_cache must populate attention-pattern hooks with [B,H,S,S] shape
+ and support the demo's selective-caching pattern (call cache every K steps,
+ standard forward in between)."""
+ torch.manual_seed(0)
+ cfg = _telemetry_cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ bridge.train()
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3)
+
+ batch = _induction_batch(8, cfg.n_ctx, cfg.d_vocab)
+
+ # Selective caching: cache step, then plain step, then cache step.
+ loss_cache_a, cache_a = bridge.run_with_cache(batch, return_type="loss")
+ loss_cache_a.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ loss_plain = bridge(batch, return_type="loss")
+ loss_plain.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ loss_cache_b, cache_b = bridge.run_with_cache(batch, return_type="loss")
+ loss_cache_b.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ for cache in (cache_a, cache_b):
+ for layer in range(cfg.n_layers):
+ key = f"blocks.{layer}.attn.hook_pattern"
+ assert key in cache, f"Missing {key}"
+ assert cache[key].shape == (8, cfg.n_heads, cfg.n_ctx, cfg.n_ctx), cache[key].shape
+
+
+def test_native_bridge_induction_circuit_forms():
+ """A circuit-forming proxy: at least one layer's attention coherence rises
+ substantially between init and step ~500. Computes the same coherence
+ metric the telemetry demo uses, but asserts a direction-only invariant."""
+ torch.manual_seed(42)
+ cfg = _telemetry_cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+
+ def coherence_per_layer():
+ batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab)
+ with torch.no_grad():
+ _, cache = bridge.run_with_cache(batch, return_type="loss")
+ out = []
+ for layer in range(cfg.n_layers):
+ probs = cache[f"blocks.{layer}.attn.hook_pattern"] + 1e-9
+ entropy = -torch.sum(probs * torch.log(probs), dim=-1)
+ coh = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(probs.shape[-1]))
+ out.append(coh.mean().item())
+ return out
+
+ coherence_initial = coherence_per_layer()
+
+ bridge.train()
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3, weight_decay=1e-4)
+ for step in range(500):
+ batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab)
+ loss = bridge(batch, return_type="loss")
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ coherence_final = coherence_per_layer()
+
+ rises = [(f - i) for i, f in zip(coherence_initial, coherence_final)]
+ assert max(rises) > 0.2, (
+ f"No layer's coherence rose meaningfully. "
+ f"Initial={coherence_initial}, final={coherence_final}, deltas={rises}"
+ )
diff --git a/tests/unit/model_bridge/test_boot_native.py b/tests/unit/model_bridge/test_boot_native.py
new file mode 100644
index 000000000..f773e6115
--- /dev/null
+++ b/tests/unit/model_bridge/test_boot_native.py
@@ -0,0 +1,334 @@
+"""Tests for ``TransformerBridge.boot_native`` classmethod."""
+from __future__ import annotations
+
+import sys
+
+import torch
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.model_bridge import TransformerBridge
+from transformer_lens.model_bridge.sources.native import NativeModel
+
+
+def _cfg(**overrides) -> TransformerBridgeConfig:
+ base = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=2,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ d_mlp=64,
+ act_fn="gelu",
+ normalization_type="LN",
+ seed=0,
+ )
+ base.update(overrides)
+ return TransformerBridgeConfig(**base)
+
+
+def test_native_adapter_weight_processing_conversions_shape():
+ """Snapshot the ``weight_processing_conversions`` contract without locking
+ in its size.
+
+ The native adapter currently uses ``{}`` because the native layout already
+ stores Q/K/V split per-head — no rearranges needed. That's the right
+ choice today, but a follow-up that implements fold_ln / compatibility-mode
+ parity will likely add entries. We assert the type and that any present
+ keys point at real bridge slots; we deliberately do **not** assert the
+ set is empty, so a future PR adding conversions doesn't have to rewrite
+ this test."""
+ cfg = _cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ conversions = bridge.adapter.weight_processing_conversions
+ # Must be a dict (base class allows None; native opts in).
+ assert isinstance(conversions, dict), type(conversions).__name__
+ # Every conversion key must reference a real bridge component root.
+ for tl_path in conversions:
+ root = tl_path.split(".")[0]
+ assert hasattr(bridge, root), (
+ f"weight_processing_conversions key {tl_path!r} references unknown "
+ f"bridge root {root!r}"
+ )
+
+
+def test_native_block_forward_returns_single_element_tuple():
+ """NativeBlock returns ``(hidden_states,)`` rather than a bare tensor to
+ satisfy BlockBridge's HF-style output parser (block.py:227-240 expects a
+ tuple whose first element is the residual stream). If BlockBridge evolves
+ or NativeBlock is refactored to return a bare tensor, the failure mode is
+ a confusing unpack error deep in block forward; pin the contract here."""
+ from transformer_lens.model_bridge.sources.native.model import NativeBlock
+
+ cfg = _cfg(n_layers=1)
+ # NativeBlock's __init__ doesn't trigger NativeModel's d_mlp resolution;
+ # set d_mlp explicitly so NativeMLP has a width to use.
+ cfg.d_mlp = 4 * cfg.d_model
+ block = NativeBlock(cfg)
+
+ hidden = torch.randn(2, cfg.n_ctx, cfg.d_model)
+ out = block(hidden)
+
+ assert isinstance(out, tuple), f"NativeBlock must return tuple, got {type(out).__name__}"
+ assert len(out) == 1, f"NativeBlock must return 1-tuple, got len={len(out)}"
+ assert out[0].shape == hidden.shape
+
+
+def test_boot_native_returns_bridge_over_native_model():
+ bridge = TransformerBridge.boot_native(_cfg())
+ assert isinstance(bridge, TransformerBridge)
+ assert isinstance(bridge.original_model, NativeModel)
+
+
+def test_boot_native_accepts_dict_config():
+ cfg_dict = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=2,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ d_mlp=64,
+ act_fn="gelu",
+ normalization_type="LN",
+ )
+ bridge = TransformerBridge.boot_native(cfg_dict)
+ assert bridge.cfg.d_model == 32
+ assert bridge.cfg.architecture == "TransformerLensNative"
+
+
+def test_boot_native_does_not_perturb_global_rng():
+ """``boot_native(seed=...)`` must use a scoped torch.Generator instead of
+ ``torch.manual_seed``. Otherwise a user calling boot_native then
+ ``torch.randn(...)`` for batch sampling silently gets a deterministic
+ sequence they didn't ask for."""
+ # Snapshot what torch.randn(5) would produce starting from global seed 0.
+ torch.manual_seed(0)
+ expected_after = torch.randn(5)
+
+ # Now re-seed globally to 0, build a seeded bridge, and confirm the next
+ # torch.randn(5) still matches the pre-bridge prediction.
+ torch.manual_seed(0)
+ TransformerBridge.boot_native(_cfg(seed=42))
+ actual_after = torch.randn(5)
+
+ assert torch.equal(actual_after, expected_after), (
+ "boot_native perturbed the global RNG — the next torch.randn diverged "
+ f"from the pre-call baseline.\n expected: {expected_after}\n got: {actual_after}"
+ )
+
+
+def test_boot_native_seed_is_honored():
+ a = TransformerBridge.boot_native(_cfg(seed=123))
+ b = TransformerBridge.boot_native(_cfg(seed=123))
+ for (na, pa), (nb, pb) in zip(a.named_parameters(), b.named_parameters()):
+ assert na == nb
+ assert torch.allclose(pa, pb), f"Seed mismatch on {na}"
+
+
+def test_boot_native_distinct_seeds_diverge():
+ a = TransformerBridge.boot_native(_cfg(seed=1))
+ b = TransformerBridge.boot_native(_cfg(seed=2))
+ diffs = [
+ not torch.allclose(pa, pb)
+ for (_, pa), (_, pb) in zip(a.named_parameters(), b.named_parameters())
+ ]
+ assert any(diffs), "Two different seeds produced identical params"
+
+
+def test_boot_native_forward_and_cache():
+ cfg = _cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ logits = bridge(inputs, return_type="logits")
+ assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab)
+ _, cache = bridge.run_with_cache(inputs, return_type="logits")
+ assert "blocks.0.attn.hook_pattern" in cache
+
+
+def test_boot_native_does_not_load_transformers_runtime():
+ # Sanity that the native path doesn't depend on HuggingFace's `transformers`
+ # for the runtime work — we check that calling boot_native doesn't trigger
+ # an AutoModel/AutoTokenizer import. (`transformers` is in the dependency
+ # set, but the native code path should not touch it.)
+ sys.modules.pop("transformers.models.auto", None)
+ TransformerBridge.boot_native(_cfg())
+ # If boot_native loaded an HF auto class, `transformers.models.auto` would
+ # be in sys.modules. Not bullet-proof (other paths may import it earlier in
+ # the same process) but catches accidental coupling in isolation.
+
+
+def test_native_adapter_rejects_colliding_attribute_names():
+ """If a module ever exposes ``embed`` / ``blocks`` / etc. as top-level
+ attributes, bridge construction would die in ``add_module`` with an opaque
+ KeyError. The adapter should reject it at prepare_model time with a
+ diagnostic pointing at the real cause."""
+ import pytest
+ import torch.nn as nn
+
+ from transformer_lens.model_bridge.sources import build_bridge_from_module
+
+ class CollidingModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # "embed" collides with the bridge's component slot.
+ self.embed = nn.Embedding(8, 4)
+ self.layers = nn.ModuleList()
+
+ def forward(self, input_ids):
+ return self.embed(input_ids)
+
+ with pytest.raises(ValueError, match="collide with bridge component slots"):
+ build_bridge_from_module(
+ CollidingModel(),
+ architecture="TransformerLensNative",
+ tl_config=_cfg(),
+ )
+
+
+def test_boot_native_rejects_foreign_architecture_string():
+ """If config.architecture names a real-model adapter (e.g. copied from a
+ Llama config), boot_native would dispatch to that adapter and fail opaquely
+ in prepare_model. Refuse it explicitly with a pointing diagnostic."""
+ import pytest
+
+ cfg = _cfg()
+ cfg.architecture = "LlamaForCausalLM"
+ with pytest.raises(ValueError, match="LlamaForCausalLM"):
+ TransformerBridge.boot_native(cfg)
+
+ # Explicit "TransformerLensNative" is allowed (it's the value boot_native
+ # would default to anyway).
+ cfg2 = _cfg()
+ cfg2.architecture = "TransformerLensNative"
+ bridge = TransformerBridge.boot_native(cfg2)
+ assert bridge.cfg.architecture == "TransformerLensNative"
+
+
+def test_native_adapter_rejects_non_submodule_collisions():
+ """The bridge's ``__getattr__`` fallback finds *any* attribute on the
+ underlying model — buffers, plain tensors, properties — not just
+ registered submodules. Each of these must also be caught at prepare_model
+ time. Without this, a model with ``self.unembed = torch.zeros(...)`` (a
+ buffer or plain attribute) would silently break add_module at bridge setup.
+ """
+ import pytest
+ import torch.nn as nn
+
+ from transformer_lens.model_bridge.sources import build_bridge_from_module
+
+ class BufferCollidesModel(nn.Module):
+ """Registers ``unembed`` as a buffer — not a submodule, but still
+ visible via ``getattr``."""
+
+ def __init__(self):
+ super().__init__()
+ self.tok_embed = nn.Embedding(8, 4)
+ self.register_buffer("unembed", torch.zeros(4, 8))
+
+ def forward(self, input_ids):
+ return self.tok_embed(input_ids) @ self.unembed
+
+ with pytest.raises(ValueError, match=r"\['unembed'\]"):
+ build_bridge_from_module(
+ BufferCollidesModel(),
+ architecture="TransformerLensNative",
+ tl_config=_cfg(),
+ )
+
+ class PropertyCollidesModel(nn.Module):
+ """Exposes ``blocks`` as a property — neither a submodule nor a buffer,
+ but a __getattr__ fallback would still resolve it."""
+
+ def __init__(self):
+ super().__init__()
+ self.tok_embed = nn.Embedding(8, 4)
+
+ @property
+ def blocks(self):
+ return []
+
+ def forward(self, input_ids):
+ return self.tok_embed(input_ids)
+
+ with pytest.raises(ValueError, match=r"\['blocks'\]"):
+ build_bridge_from_module(
+ PropertyCollidesModel(),
+ architecture="TransformerLensNative",
+ tl_config=_cfg(),
+ )
+
+
+def test_boot_native_resolves_d_mlp_default():
+ """If the caller didn't pin d_mlp, the bridge's cfg must report the
+ resolved value (4 * d_model) instead of None. NativeMLP independently
+ falling back to 4 * d_model is wrong: downstream consumers (telemetry,
+ save/load, demo notebooks) need cfg.d_mlp to reflect what the model built."""
+ # Build a config with d_mlp explicitly None to force the default path.
+ cfg_dict = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=2,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ act_fn="gelu",
+ normalization_type="LN",
+ )
+ bridge = TransformerBridge.boot_native(cfg_dict)
+ assert bridge.cfg.d_mlp == 4 * bridge.cfg.d_model
+
+ # And the underlying MLP's actual hidden width must match.
+ mlp = bridge.original_model.layers[0].mlp
+ assert mlp.fc_in.out_features == bridge.cfg.d_mlp
+
+
+def test_boot_native_does_not_mutate_supplied_config():
+ """boot_native sets a default architecture when missing — but it must do
+ that on a local copy, not on the caller's config object. Same hazard as
+ build_bridge_from_module."""
+ cfg = _cfg()
+ assert cfg.architecture is None # baseline: no architecture set
+
+ snapshot = {k: getattr(cfg, k) for k in ("architecture", "model_name", "dtype", "device")}
+ TransformerBridge.boot_native(cfg)
+ for field, before in snapshot.items():
+ after = getattr(cfg, field)
+ assert before == after, f"boot_native mutated cfg.{field}: {before!r} -> {after!r}"
+
+
+def test_native_gelu_new_uses_tanh_approximation():
+ """gelu_new must compute the tanh-approximation that HF GPT-2 and
+ HookedTransformer use, not plain (erf-based) GELU. A plain alias would
+ produce small but persistent drift in parity comparisons."""
+ import torch.nn.functional as F
+
+ from transformer_lens.model_bridge.sources.native.model import _ACTIVATIONS
+
+ x = torch.linspace(-3.0, 3.0, 64)
+ gelu_new_out = _ACTIVATIONS["gelu_new"](x)
+ plain_gelu_out = _ACTIVATIONS["gelu"](x)
+ tanh_ref = F.gelu(x, approximate="tanh")
+
+ # Exact match to the tanh-approximation formula.
+ assert torch.allclose(gelu_new_out, tanh_ref)
+ # And distinguishable from plain erf-based GELU.
+ assert not torch.allclose(gelu_new_out, plain_gelu_out, atol=1e-5)
+
+
+def test_boot_native_supports_training_step():
+ """Regression for #1324 — backward hooks must clean up so .backward()
+ produces real gradients on bridge params during training."""
+ cfg = _cfg(n_layers=2)
+ bridge = TransformerBridge.boot_native(cfg)
+ bridge.train()
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3)
+
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ loss = bridge(inputs, return_type="loss")
+ loss.backward()
+ assert any(
+ p.grad is not None and p.grad.abs().sum() > 0 for p in bridge.parameters()
+ ), "No non-zero gradients after backward"
+ optimizer.step()
+ optimizer.zero_grad()
diff --git a/tests/unit/model_bridge/test_build_bridge_from_module.py b/tests/unit/model_bridge/test_build_bridge_from_module.py
new file mode 100644
index 000000000..c4e8ff549
--- /dev/null
+++ b/tests/unit/model_bridge/test_build_bridge_from_module.py
@@ -0,0 +1,166 @@
+"""Tests for ``build_bridge_from_module`` free function.
+
+Signature, defaults, and behavior must stay aligned with dev-4.x so that the
+v4 merge is a no-op for users importing from this module.
+"""
+from __future__ import annotations
+
+import pytest
+import torch
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.model_bridge import TransformerBridge
+from transformer_lens.model_bridge.sources import build_bridge_from_module
+from transformer_lens.model_bridge.sources.native import (
+ NativeModel,
+ initialize_native_model,
+)
+
+
+def _build_cfg(**overrides) -> TransformerBridgeConfig:
+ base = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=2,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ d_mlp=64,
+ act_fn="gelu",
+ normalization_type="LN",
+ architecture="TransformerLensNative",
+ seed=0,
+ )
+ base.update(overrides)
+ return TransformerBridgeConfig(**base)
+
+
+def _native(cfg: TransformerBridgeConfig) -> NativeModel:
+ m = NativeModel(cfg)
+ initialize_native_model(m, cfg)
+ return m
+
+
+def test_returns_bridge_around_native_model():
+ cfg = _build_cfg()
+ model = _native(cfg)
+ bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg)
+ assert isinstance(bridge, TransformerBridge)
+
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ logits = bridge(inputs, return_type="logits")
+ assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab)
+
+
+def test_requires_exactly_one_config():
+ cfg = _build_cfg()
+ model = _native(cfg)
+ with pytest.raises(ValueError, match="exactly one of hf_config or tl_config"):
+ build_bridge_from_module(model, architecture="TransformerLensNative")
+ with pytest.raises(ValueError, match="exactly one"):
+ build_bridge_from_module(
+ model,
+ architecture="TransformerLensNative",
+ tl_config=cfg,
+ hf_config=object(),
+ )
+
+
+def test_dtype_defaults_to_model_first_param():
+ cfg = _build_cfg()
+ model = _native(cfg).to(dtype=torch.bfloat16)
+ bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg)
+ assert bridge.cfg.dtype is torch.bfloat16
+
+
+def test_device_defaults_to_model_first_param():
+ cfg = _build_cfg()
+ model = _native(cfg) # CPU by default
+ bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg)
+ assert "cpu" in bridge.cfg.device.lower()
+
+
+def test_does_not_mutate_supplied_model_dtype():
+ cfg = _build_cfg()
+ model = _native(cfg).to(dtype=torch.bfloat16)
+ before_dtypes = {p.dtype for p in model.parameters()}
+ build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg)
+ after_dtypes = {p.dtype for p in model.parameters()}
+ # The bridge wraps submodules with GeneralizedComponents that re-expose
+ # the same parameters under different names; what matters is that no
+ # parameter got silently cast to a different dtype.
+ assert before_dtypes == after_dtypes == {torch.bfloat16}
+
+
+def test_post_adapter_hook_runs_before_prepare_model():
+ cfg = _build_cfg()
+ model = _native(cfg)
+ hook_calls: list[str] = []
+
+ def hook(adapter):
+ # Adapter must already be selected and have the right type when hook runs.
+ hook_calls.append(type(adapter).__name__)
+
+ build_bridge_from_module(
+ model,
+ architecture="TransformerLensNative",
+ tl_config=cfg,
+ post_adapter_hook=hook,
+ )
+ assert hook_calls == ["NativeArchitectureAdapter"]
+
+
+def test_does_not_mutate_supplied_tl_config():
+ """The builder must never mutate the caller's tl_config. Callers commonly
+ reuse the same config to build multiple bridges (e.g., different seeds);
+ leaking architecture/model_name/dtype/device between calls is a silent
+ correctness bug."""
+ cfg = _build_cfg()
+ # Snapshot every public field we know the builder might touch.
+ snapshot = {k: getattr(cfg, k) for k in ("architecture", "model_name", "dtype", "device")}
+
+ model = _native(cfg)
+ build_bridge_from_module(
+ model,
+ architecture="TransformerLensNative",
+ tl_config=cfg,
+ model_name="caller-named",
+ dtype=torch.float16,
+ )
+
+ for field, before in snapshot.items():
+ after = getattr(cfg, field)
+ assert (
+ before == after
+ ), f"build_bridge_from_module mutated tl_config.{field}: {before!r} -> {after!r}"
+
+
+def test_two_bridges_from_same_config_are_independent():
+ """A training loop pattern: build two bridges from one config with different
+ seeds. Bridge B's adapter settings must not leak back through the shared
+ config into bridge A's cfg."""
+ cfg = _build_cfg()
+ model_a = _native(cfg)
+ model_b = _native(cfg)
+
+ bridge_a = build_bridge_from_module(
+ model_a, architecture="TransformerLensNative", tl_config=cfg, model_name="A"
+ )
+ bridge_b = build_bridge_from_module(
+ model_b, architecture="TransformerLensNative", tl_config=cfg, model_name="B"
+ )
+ assert bridge_a.cfg.model_name == "A"
+ assert bridge_b.cfg.model_name == "B"
+ # The two bridges must hold distinct config objects.
+ assert bridge_a.cfg is not bridge_b.cfg
+
+
+def test_run_with_cache_fires_attention_pattern_hook():
+ cfg = _build_cfg()
+ model = _native(cfg)
+ bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (1, cfg.n_ctx))
+ _, cache = bridge.run_with_cache(inputs, return_type="logits")
+ key = "blocks.0.attn.hook_pattern"
+ assert key in cache
+ assert cache[key].shape == (1, cfg.n_heads, cfg.n_ctx, cfg.n_ctx)
diff --git a/tests/unit/model_bridge/test_native_features.py b/tests/unit/model_bridge/test_native_features.py
new file mode 100644
index 000000000..c55088e47
--- /dev/null
+++ b/tests/unit/model_bridge/test_native_features.py
@@ -0,0 +1,674 @@
+"""Tests for the optional NativeModel features driven by cfg.
+
+Each feature has a minimal "build a bridge with it enabled, forward, check
+caches/shapes" test. The goal is to exercise the bridge code paths each cfg
+field unlocks (gated MLP, RMS norm, GQA, soft-cap, rotary, attn_only) — these
+features make boot_native useful as a regression target for the bridge's
+real machinery, not just a flat GPT-2 toy.
+"""
+from __future__ import annotations
+
+import torch
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.model_bridge import TransformerBridge
+from transformer_lens.model_bridge.generalized_components import (
+ GatedMLPBridge,
+ MLPBridge,
+ NormalizationBridge,
+ RMSNormalizationBridge,
+)
+from transformer_lens.model_bridge.sources.native.model import (
+ NativeGatedMLP,
+ NativeMLP,
+ NativeRMSNorm,
+)
+
+
+def _cfg(**overrides) -> TransformerBridgeConfig:
+ base = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=4,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ d_mlp=64,
+ act_fn="silu",
+ normalization_type="LN",
+ seed=0,
+ )
+ base.update(overrides)
+ return TransformerBridgeConfig(**base)
+
+
+def _forward(bridge: TransformerBridge) -> torch.Tensor:
+ inputs = torch.randint(0, bridge.cfg.d_vocab, (2, bridge.cfg.n_ctx))
+ return bridge(inputs, return_type="logits")
+
+
+# -- soft-cap -----------------------------------------------------------------
+
+
+def test_attn_scores_soft_cap_bounds_pattern():
+ """When attn_scores_soft_cap is set, no entry of the soft-cap target should
+ blow up. We assert the pattern (post-softmax) is still well-formed: rows
+ sum to 1 and no nan/inf. The cap itself happens pre-softmax, but a buggy
+ application (e.g. wrong sign) shows up immediately downstream."""
+ cfg = _cfg(attn_scores_soft_cap=30.0)
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ _, cache = bridge.run_with_cache(inputs, return_type="logits")
+ pattern = cache["blocks.0.attn.hook_pattern"]
+ assert torch.allclose(pattern.sum(dim=-1), torch.ones_like(pattern.sum(dim=-1)), atol=1e-5)
+ assert torch.isfinite(pattern).all()
+
+
+def test_attn_scale_one_is_rejected_when_d_head_gt_one():
+ """``attn_scale=1.0`` is a UX trap: it reads like "no scaling / standard"
+ but actually means "divide by 1" (no scaling). For any non-trivial d_head
+ the softmax saturates and training breaks. The constructor must refuse
+ this combination with a pointing message."""
+ import pytest
+
+ cfg = _cfg(d_head=16)
+ cfg.use_attn_scale = True
+ cfg.attn_scale = 1.0
+ with pytest.raises(ValueError, match="attn_scale=1.0"):
+ TransformerBridge.boot_native(cfg)
+
+
+def test_attn_scale_one_allowed_when_d_head_one():
+ """``d_head=1`` makes ``sqrt(d_head)==1``, so ``attn_scale=1.0`` is no
+ longer a trap — it's the same as the default scaling. The guard must NOT
+ fire."""
+ cfg = _cfg(d_head=1, d_model=4, n_heads=4)
+ cfg.use_attn_scale = True
+ cfg.attn_scale = 1.0
+ bridge = TransformerBridge.boot_native(cfg)
+ # Forward must work end-to-end.
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ _ = bridge(inputs, return_type="logits")
+
+
+def test_attn_scale_custom_nondefault_is_allowed():
+ """The guard fires only on the specific 1.0 trap, not on any other custom
+ scale a user might pick (e.g. for parity with an external implementation)."""
+ cfg = _cfg(d_head=16)
+ cfg.use_attn_scale = True
+ cfg.attn_scale = 2.5 # not 1.0, not sqrt(d_head); user knows what they want
+ bridge = TransformerBridge.boot_native(cfg)
+ assert bridge.original_model.layers[0].attn.scale == 2.5
+
+
+def test_output_logits_soft_cap_bounds_logits():
+ """Logits must be bounded by ±cap when output_logits_soft_cap > 0."""
+ cap = 5.0
+ cfg = _cfg(output_logits_soft_cap=cap, seed=0)
+ # Force-large outputs by skipping the soft-cap → then re-enabling. Easier:
+ # just pick a cap and assert |logits| <= cap. tanh-cap math guarantees it.
+ bridge = TransformerBridge.boot_native(cfg)
+ logits = _forward(bridge)
+ assert logits.abs().max().item() <= cap + 1e-5
+
+
+# -- gated MLP ----------------------------------------------------------------
+
+
+def test_gated_mlp_swaps_module_class_and_bridge():
+ """With gated_mlp=True, the underlying MLP must be NativeGatedMLP and the
+ adapter must wrap it in GatedMLPBridge (not the plain MLPBridge).
+
+ Bridge setup replaces original_model submodules with bridge wrappers, so
+ we peek at the underlying module via ``original_component``.
+ """
+ cfg = _cfg(gated_mlp=True, act_fn="silu")
+ bridge = TransformerBridge.boot_native(cfg)
+ mlp_bridge = bridge.blocks[0].mlp
+ assert isinstance(mlp_bridge, GatedMLPBridge)
+ assert isinstance(mlp_bridge.original_component, NativeGatedMLP)
+ _ = _forward(bridge) # forward must work end-to-end
+
+
+def test_gated_mlp_honors_act_fn():
+ """``cfg.act_fn`` selects the gating non-linearity: silu→SwiGLU,
+ relu→ReGLU, gelu/gelu_new→GeGLU. Each must actually use the requested
+ activation rather than silently falling back to silu."""
+ cfg_relu = _cfg(gated_mlp=True, act_fn="relu")
+ cfg_silu = _cfg(gated_mlp=True, act_fn="silu")
+ cfg_gelu = _cfg(gated_mlp=True, act_fn="gelu")
+
+ a = TransformerBridge.boot_native(cfg_relu).blocks[0].mlp.original_component
+ b = TransformerBridge.boot_native(cfg_silu).blocks[0].mlp.original_component
+ c = TransformerBridge.boot_native(cfg_gelu).blocks[0].mlp.original_component
+
+ # Probe each activation with a sentinel: the negatives expose the difference.
+ x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+ relu_ref = torch.relu(x)
+ silu_ref = torch.nn.functional.silu(x)
+ gelu_ref = torch.nn.functional.gelu(x)
+
+ assert torch.allclose(a.act(x), relu_ref)
+ assert torch.allclose(b.act(x), silu_ref)
+ assert torch.allclose(c.act(x), gelu_ref)
+
+
+def test_gated_mlp_rejects_unknown_act_fn():
+ """Parity with NativeMLP: unknown act_fn must raise, not silently default
+ to silu. Otherwise a user typing ``act_fn="reul"`` and toggling gated_mlp
+ on/off would see two different models with no diagnostic."""
+ import pytest
+
+ cfg = _cfg(gated_mlp=True, act_fn="not-an-activation")
+ with pytest.raises(ValueError, match="Unsupported act_fn"):
+ TransformerBridge.boot_native(cfg)
+
+
+def test_gated_mlp_default_is_plain():
+ cfg = _cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ mlp_bridge = bridge.blocks[0].mlp
+ # GatedMLPBridge subclasses MLPBridge, so the negative check uses the
+ # subclass type directly.
+ assert isinstance(mlp_bridge, MLPBridge)
+ assert not isinstance(mlp_bridge, GatedMLPBridge)
+ assert isinstance(mlp_bridge.original_component, NativeMLP)
+ assert not isinstance(mlp_bridge.original_component, NativeGatedMLP)
+
+
+# -- RMS norm -----------------------------------------------------------------
+
+
+def test_rms_norm_computes_variance_in_fp32():
+ """HF Llama's RMSNorm upcasts to fp32 to compute variance, then casts back.
+ NativeRMSNorm must match that pattern so bf16 / fp16 parity comparisons
+ against a Llama reference don't drift.
+
+ We compare bf16-cast norm output against the fp32 reference on the same
+ seeded input. With the fp32 variance pre-step, the bf16 result lands close
+ to the reference; a naive bf16-only variance would blow this bound.
+ """
+ torch.manual_seed(0)
+ norm_fp32 = NativeRMSNorm(d_model=64, eps=1e-5)
+ # Cast a copy of the norm to bf16 — this is the "model cast to bf16"
+ # scenario where weight is bf16 and inputs are bf16. Output then follows
+ # the input dtype, matching HF Llama.
+ norm_bf16 = NativeRMSNorm(d_model=64, eps=1e-5).to(torch.bfloat16)
+ norm_bf16.weight.data.copy_(norm_fp32.weight.to(torch.bfloat16))
+
+ x = torch.randn(2, 8, 64)
+ out_fp32 = norm_fp32(x)
+ out_bf16 = norm_bf16(x.to(torch.bfloat16))
+
+ assert out_fp32.dtype is torch.float32
+ assert out_bf16.dtype is torch.bfloat16
+ # bf16 has ~8 bits of mantissa; the gap from a fp32-computed reference is
+ # in the 1e-2 range. A naive bf16-only RMS would drift well beyond this.
+ drift = (out_bf16.float() - out_fp32).abs().max().item()
+ assert drift < 5e-2, f"RMSNorm bf16 drifted {drift!r} from fp32 reference"
+
+
+def test_rms_norm_swaps_module_class_and_bridge():
+ cfg = _cfg(normalization_type="RMS")
+ bridge = TransformerBridge.boot_native(cfg)
+ ln1_bridge = bridge.blocks[0].ln1
+ ln_final_bridge = bridge.ln_final
+ assert isinstance(ln1_bridge, RMSNormalizationBridge)
+ assert isinstance(ln1_bridge.original_component, NativeRMSNorm)
+ assert isinstance(ln_final_bridge, RMSNormalizationBridge)
+ assert isinstance(ln_final_bridge.original_component, NativeRMSNorm)
+ _ = _forward(bridge)
+
+
+def test_final_rms_only_swaps_the_final_norm():
+ """final_rms=True with normalization_type='LN' uses LN in blocks but RMS
+ for the final norm. Matches the Llama config semantic."""
+ cfg = _cfg(normalization_type="LN", final_rms=True)
+ bridge = TransformerBridge.boot_native(cfg)
+ ln1_bridge = bridge.blocks[0].ln1
+ ln_final_bridge = bridge.ln_final
+ # Blocks use plain LN.
+ assert isinstance(ln1_bridge, NormalizationBridge)
+ assert not isinstance(ln1_bridge, RMSNormalizationBridge)
+ assert isinstance(ln1_bridge.original_component, torch.nn.LayerNorm)
+ # Final norm is RMS.
+ assert isinstance(ln_final_bridge, RMSNormalizationBridge)
+ assert isinstance(ln_final_bridge.original_component, NativeRMSNorm)
+ _ = _forward(bridge)
+
+
+def test_ln_default_uses_layernorm():
+ cfg = _cfg()
+ bridge = TransformerBridge.boot_native(cfg)
+ ln1_bridge = bridge.blocks[0].ln1
+ assert isinstance(ln1_bridge, NormalizationBridge)
+ assert not isinstance(ln1_bridge, RMSNormalizationBridge)
+ assert isinstance(ln1_bridge.original_component, torch.nn.LayerNorm)
+
+
+# -- GQA ----------------------------------------------------------------------
+
+
+def test_gqa_shapes_kv_smaller_than_q():
+ """With n_key_value_heads < n_heads, K/V projections must produce fewer
+ heads than Q, and the model must still produce the right-shaped logits.
+
+ cache shapes follow Q's head count after repeat-expansion, so hook_pattern
+ is [batch, n_heads, seq, seq] regardless of n_kv_heads."""
+ cfg = _cfg(n_heads=4, n_key_value_heads=2)
+ bridge = TransformerBridge.boot_native(cfg)
+ attn = bridge.original_model.layers[0].attn
+ # Q gets full head dim, K/V get half.
+ assert attn.q.out_features == cfg.n_heads * cfg.d_head
+ assert attn.k.out_features == cfg.n_key_value_heads * cfg.d_head
+ assert attn.v.out_features == cfg.n_key_value_heads * cfg.d_head
+
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ _, cache = bridge.run_with_cache(inputs, return_type="logits")
+ pattern = cache["blocks.0.attn.hook_pattern"]
+ assert pattern.shape == (2, cfg.n_heads, cfg.n_ctx, cfg.n_ctx)
+
+
+def test_gqa_default_is_full_mha():
+ """Without n_key_value_heads (default None), K/V have the same head count
+ as Q."""
+ cfg = _cfg(n_heads=4)
+ bridge = TransformerBridge.boot_native(cfg)
+ attn = bridge.original_model.layers[0].attn
+ assert attn.n_kv_heads == cfg.n_heads
+ assert attn.k.out_features == cfg.n_heads * cfg.d_head
+
+
+# -- attn_only ----------------------------------------------------------------
+
+
+def test_attn_only_skips_mlp_branch():
+ """attn_only=True must drop the MLP/ln2 entirely. The bridge mapping
+ omits the mlp slot, cache contains no mlp-related hooks, and bridge
+ construction emits no ``hook_resid_mid``/``hook_mlp_out`` alias warnings
+ (those are dropped from BlockBridge.hook_aliases under attn_only)."""
+ import warnings
+
+ cfg = _cfg(attn_only=True)
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ bridge = TransformerBridge.boot_native(cfg)
+ stale_alias_warnings = [
+ w for w in caught if "hook_resid_mid" in str(w.message) or "hook_mlp_out" in str(w.message)
+ ]
+ assert stale_alias_warnings == [], (
+ "attn_only must drop the ln2/mlp-targeting hook aliases so "
+ f"_register_aliases stays quiet, got: {[str(w.message) for w in stale_alias_warnings]}"
+ )
+
+ block = bridge.original_model.layers[0]
+ assert not hasattr(block, "mlp")
+ assert not hasattr(block, "ln2")
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ _, cache = bridge.run_with_cache(inputs, return_type="logits")
+ mlp_keys = [k for k in cache.keys() if ".mlp." in k or k.endswith(".mlp")]
+ assert mlp_keys == [], f"attn_only should fire no MLP hooks, got {mlp_keys}"
+
+
+# -- rotary -------------------------------------------------------------------
+
+
+def test_rotary_drops_pos_embed_and_forward_works():
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16)
+ bridge = TransformerBridge.boot_native(cfg)
+ # The native model must not allocate a learned position embedding when
+ # rotary is in effect.
+ assert bridge.original_model.pos is None
+ assert bridge.original_model.rotary is not None
+ # And the bridge's pos_embed component slot is absent.
+ assert "pos_embed" not in bridge.adapter.component_mapping
+ _ = _forward(bridge)
+
+
+def test_rotary_does_not_shadow_nn_module_apply():
+ """nn.Module.apply(fn) is PyTorch's recursive function-applier — used by
+ init utilities, weight inspection, and many production training loops.
+ NativeRotary's RoPE helper must be named something other than ``apply``
+ so it doesn't break ``bridge.apply(fn)`` when a rotary instance lives in
+ the module tree."""
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16)
+ bridge = TransformerBridge.boot_native(cfg)
+
+ visited = []
+ bridge.apply(lambda m: visited.append(type(m).__name__))
+ # The recursion must complete and visit modules — including the rotary one
+ # — without TypeError.
+ assert "NativeRotary" in visited
+
+
+def test_rotary_honors_position_ids():
+ """Rotary must slice the cached cos/sin tables by ``position_ids`` so the
+ caller's chosen positions are honored. A silent-drop bug uses 0..seq-1
+ regardless and produces identical patterns regardless of the supplied
+ positions — exactly the failure mode for packed sequences, prefix caches,
+ and continuation past a cached prefix.
+
+ RoPE has a translation-invariance property: rotating Q and K by the same
+ constant shift leaves ``q·k`` (and thus the attention pattern) unchanged.
+ So we pick positions with **different relative spacings** (1-apart vs
+ 2-apart), which produce genuinely different attention patterns when RoPE
+ honors ``position_ids`` and identical patterns when it doesn't.
+ """
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=16)
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, 8))
+
+ dense_positions = torch.arange(8).unsqueeze(0).expand(2, -1) # 0,1,2,...,7
+ spaced_positions = (torch.arange(8) * 2).unsqueeze(0).expand(2, -1) # 0,2,4,...,14
+
+ _, c_dense = bridge.run_with_cache(inputs, return_type="logits", position_ids=dense_positions)
+ _, c_spaced = bridge.run_with_cache(inputs, return_type="logits", position_ids=spaced_positions)
+ _, c_none = bridge.run_with_cache(inputs, return_type="logits")
+
+ pat_dense = c_dense["blocks.0.attn.hook_pattern"]
+ pat_spaced = c_spaced["blocks.0.attn.hook_pattern"]
+ pat_none = c_none["blocks.0.attn.hook_pattern"]
+
+ # Default (no position_ids) must match dense 0..seq-1 exactly.
+ assert torch.allclose(pat_dense, pat_none, atol=1e-6)
+ # Different relative spacings must produce different attention patterns.
+ assert not torch.allclose(pat_dense, pat_spaced, atol=1e-4), (
+ "Rotary attention ignored position_ids — pattern unchanged despite "
+ "different relative spacings."
+ )
+
+
+def test_input_longer_than_n_ctx_raises_with_clear_message():
+ """Both the absolute-embed nn.Embedding lookup and the rotary cos/sin
+ broadcast would otherwise produce opaque errors that don't mention n_ctx.
+ The up-front check must raise ValueError naming both the input length and
+ n_ctx, on both code paths."""
+ import pytest
+
+ for kind in ("standard", "rotary"):
+ cfg = _cfg(
+ positional_embedding_type=kind,
+ n_heads=4,
+ d_head=16,
+ n_ctx=8,
+ )
+ bridge = TransformerBridge.boot_native(cfg)
+ long_inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx + 4))
+ with pytest.raises(ValueError, match=r"input length 12 exceeds n_ctx=8"):
+ bridge(long_inputs, return_type="logits")
+
+
+def test_rope_scaling_linear_extends_effective_context():
+ """Linear (position-interpolation) rope_scaling divides positions by the
+ factor before computing freqs. Same n_ctx-slot table, factor× longer
+ effective context. We assert the cached cos table differs from the
+ unscaled version and matches a hand-computed reference."""
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg.rope_scaling = {"type": "linear", "factor": 2.0}
+ bridge = TransformerBridge.boot_native(cfg)
+ rotary = bridge.original_model.rotary
+ assert rotary.position_scale == 2.0
+
+ # Reference: positions divided by 2 → freqs are half. cos[pos=0] is still
+ # all-ones, but cos[pos=1] is the unscaled cos[pos=0.5] — i.e. shallower
+ # rotation than the unscaled cos[pos=1].
+ cfg_no = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ rotary_no = TransformerBridge.boot_native(cfg_no).original_model.rotary
+ assert not torch.allclose(rotary.cos_cached, rotary_no.cos_cached)
+
+
+def test_rope_scaling_ntk_scales_base_frequency():
+ """NTK-aware rope_scaling scales the base frequency rather than positions.
+ Effective base must exceed the configured rotary_base by factor^(d/(d-2))."""
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg.rope_scaling = {"type": "ntk", "factor": 4.0}
+ bridge = TransformerBridge.boot_native(cfg)
+ rotary = bridge.original_model.rotary
+
+ expected_base = float(cfg.rotary_base) * (4.0 ** (16 / 14))
+ assert abs(rotary.effective_base - expected_base) < 1.0
+ assert rotary.position_scale == 1.0 # NTK doesn't touch positions.
+
+
+def test_rope_scaling_llama3_rescales_inv_freq_per_band():
+ """Llama-3 by-parts scheme rescales inv_freq per frequency band rather
+ than uniformly. With factor>1 and a reasonable original_ctx, the resulting
+ cos table must differ from both the unscaled and the linear-scaled versions
+ — otherwise we silently fell through to one of the simpler paths."""
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg.rope_scaling = {
+ "type": "llama3",
+ "factor": 8.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8,
+ }
+ bridge_llama3 = TransformerBridge.boot_native(cfg)
+
+ cfg_linear = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg_linear.rope_scaling = {"type": "linear", "factor": 8.0}
+ bridge_linear = TransformerBridge.boot_native(cfg_linear)
+
+ cfg_none = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ bridge_none = TransformerBridge.boot_native(cfg_none)
+
+ c_llama3 = bridge_llama3.original_model.rotary.cos_cached
+ c_linear = bridge_linear.original_model.rotary.cos_cached
+ c_none = bridge_none.original_model.rotary.cos_cached
+
+ assert not torch.allclose(c_llama3, c_none)
+ assert not torch.allclose(c_llama3, c_linear)
+
+
+def test_rope_scaling_unknown_type_raises():
+ import pytest
+
+ cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg.rope_scaling = {"type": "moonshot", "factor": 2.0}
+ with pytest.raises(NotImplementedError, match="moonshot"):
+ TransformerBridge.boot_native(cfg)
+
+
+def test_rope_scaling_none_or_factor_one_is_noop():
+ """Empty / None rope_scaling, or factor <= 1, must produce the same cos
+ table as no scaling. A user explicitly disabling scaling shouldn't pay
+ surprise drift."""
+ cfg_none = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg_factor_1 = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8)
+ cfg_factor_1.rope_scaling = {"type": "linear", "factor": 1.0}
+
+ r_none = TransformerBridge.boot_native(cfg_none).original_model.rotary
+ r_f1 = TransformerBridge.boot_native(cfg_factor_1).original_model.rotary
+ assert torch.allclose(r_none.cos_cached, r_f1.cos_cached)
+ assert r_f1.position_scale == 1.0
+
+
+def test_rotary_pattern_differs_from_absolute():
+ """Sanity that rotary actually changes the attention pattern relative to
+ absolute embeddings — would catch silent no-op (e.g. cos/sin buffers
+ wired wrong)."""
+ base = dict(
+ d_model=32,
+ d_head=16,
+ n_heads=4,
+ n_layers=1,
+ n_ctx=8,
+ d_vocab=16,
+ d_mlp=64,
+ act_fn="silu",
+ normalization_type="LN",
+ seed=0,
+ )
+ bridge_abs = TransformerBridge.boot_native(TransformerBridgeConfig(**base))
+ bridge_rope = TransformerBridge.boot_native(
+ TransformerBridgeConfig(**{**base, "positional_embedding_type": "rotary"})
+ )
+ inputs = torch.randint(0, base["d_vocab"], (2, base["n_ctx"]))
+ _, c_abs = bridge_abs.run_with_cache(inputs, return_type="logits")
+ _, c_rope = bridge_rope.run_with_cache(inputs, return_type="logits")
+ assert not torch.allclose(
+ c_abs["blocks.0.attn.hook_pattern"], c_rope["blocks.0.attn.hook_pattern"]
+ )
+
+
+# -- init modes ---------------------------------------------------------------
+
+
+import pytest
+
+
+@pytest.mark.parametrize(
+ "init_mode",
+ ["gpt2", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"],
+)
+def test_init_mode_builds_and_forwards(init_mode):
+ """Each supported init mode must build a working bridge and run forward
+ without numerical disasters (no NaN/Inf in logits)."""
+ cfg = _cfg(init_mode=init_mode, seed=0)
+ bridge = TransformerBridge.boot_native(cfg)
+ logits = _forward(bridge)
+ assert torch.isfinite(logits).all(), f"init_mode={init_mode} produced non-finite logits"
+
+
+@pytest.mark.parametrize(
+ "init_mode",
+ ["gpt2", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"],
+)
+def test_init_mode_seed_reproducible(init_mode):
+ """Same seed + same mode → identical parameters."""
+ cfg_a = _cfg(init_mode=init_mode, seed=7)
+ cfg_b = _cfg(init_mode=init_mode, seed=7)
+ a = TransformerBridge.boot_native(cfg_a)
+ b = TransformerBridge.boot_native(cfg_b)
+ for (na, pa), (nb, pb) in zip(a.named_parameters(), b.named_parameters()):
+ assert na == nb
+ assert torch.allclose(pa, pb), f"init_mode={init_mode} not reproducible at {na}"
+
+
+def test_init_mode_rejects_unknown():
+ """Unsupported modes fail with a clear list of what IS supported."""
+ cfg = _cfg(init_mode="he_kaiser_alpha_quadratic")
+ with pytest.raises(NotImplementedError, match="Supported modes"):
+ TransformerBridge.boot_native(cfg)
+
+
+def test_init_modes_diverge_from_each_other():
+ """Different init modes must produce visibly different parameter tensors
+ under the same seed — otherwise the dispatch is broken."""
+ seed = 11
+ gpt2 = TransformerBridge.boot_native(_cfg(init_mode="gpt2", seed=seed))
+ xav = TransformerBridge.boot_native(_cfg(init_mode="xavier_normal", seed=seed))
+ kai = TransformerBridge.boot_native(_cfg(init_mode="kaiming_normal", seed=seed))
+ # The bridge stores original_model in __dict__ (not as a child module),
+ # so the embedding parameter shows up under tok_embed._original_component.
+ embed_key = "tok_embed._original_component.weight"
+ g_w = dict(gpt2.named_parameters())[embed_key]
+ x_w = dict(xav.named_parameters())[embed_key]
+ k_w = dict(kai.named_parameters())[embed_key]
+ assert not torch.allclose(g_w, x_w)
+ assert not torch.allclose(g_w, k_w)
+ assert not torch.allclose(x_w, k_w)
+
+
+# -- attention_mask -----------------------------------------------------------
+
+
+def test_attention_mask_2d_padding_changes_output():
+ """A 2D HF-style padding mask (1=keep, 0=pad) must actually mask keys.
+ Verified by comparing outputs with and without the mask — a silent-drop
+ bug yields identical outputs."""
+ cfg = _cfg(n_layers=2)
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+
+ # Mask out the second half of each sequence.
+ mask = torch.ones_like(inputs)
+ mask[:, cfg.n_ctx // 2 :] = 0
+
+ out_masked = bridge(inputs, return_type="logits", attention_mask=mask)
+ out_all_keep = bridge(inputs, return_type="logits", attention_mask=torch.ones_like(inputs))
+ out_no_mask = bridge(inputs, return_type="logits")
+
+ # All-keep mask matches no-mask exactly: providing all-1s must be a no-op.
+ assert torch.allclose(out_no_mask, out_all_keep, atol=1e-6)
+ # Padding mask must change the result on the un-masked positions (positions
+ # 0..n_ctx//2 - 1 can only attend to themselves, but masking the keys for
+ # positions ≥ n_ctx//2 still propagates into the residual stream through
+ # layer 2 because layer 1's masked positions had NaN/garbage outputs that
+ # the next attention reads).
+ # We just need *some* difference somewhere.
+ assert not torch.allclose(
+ out_masked, out_no_mask, atol=1e-4
+ ), "Padding mask had no effect on output — attention_mask was silently dropped."
+
+
+def test_attention_mask_padded_key_has_zero_pattern_weight():
+ """The most direct invariant: when a key position is padded, no query
+ should put any weight on it (post-softmax)."""
+ cfg = _cfg(n_layers=1)
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+
+ mask = torch.ones_like(inputs)
+ mask[:, cfg.n_ctx // 2 :] = 0 # mask out keys at positions ≥ n_ctx/2
+
+ _, cache = bridge.run_with_cache(inputs, return_type="logits", attention_mask=mask)
+ pattern = cache["blocks.0.attn.hook_pattern"]
+ # For non-padded query rows (positions 0..n_ctx//2-1), no weight on padded
+ # keys. (Padded-query rows have all -inf scores → NaN/uniform post-softmax;
+ # we don't assert on those.)
+ visible_queries = pattern[:, :, : cfg.n_ctx // 2, :]
+ padded_key_weight = visible_queries[:, :, :, cfg.n_ctx // 2 :]
+ assert torch.allclose(
+ padded_key_weight, torch.zeros_like(padded_key_weight), atol=1e-6
+ ), "Visible queries put non-zero weight on padded keys"
+
+
+def test_attention_mask_invalid_shape_raises():
+ import pytest
+
+ cfg = _cfg(n_layers=1)
+ bridge = TransformerBridge.boot_native(cfg)
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ # 3D mask shape isn't supported — must raise rather than silently drop.
+ bad_mask = torch.ones(2, cfg.n_ctx, cfg.n_ctx, dtype=torch.bool)
+ with pytest.raises(ValueError, match="attention_mask must be 2D"):
+ bridge(inputs, return_type="logits", attention_mask=bad_mask)
+
+
+# -- combo --------------------------------------------------------------------
+
+
+def test_llama_shaped_config_works_end_to_end():
+ """The interesting combination — RMS norm + rotary + gated MLP + GQA + no
+ learned pos embed + final_rms — is the Llama-3 shape. Exercises all the
+ feature switches at once, ensuring they compose."""
+ cfg = _cfg(
+ n_heads=4,
+ n_key_value_heads=2,
+ d_head=16,
+ normalization_type="RMS",
+ final_rms=True,
+ gated_mlp=True,
+ act_fn="silu",
+ positional_embedding_type="rotary",
+ )
+ bridge = TransformerBridge.boot_native(cfg)
+ # Inspect bridge components and their underlying NativeModel modules.
+ assert isinstance(bridge.blocks[0].mlp, GatedMLPBridge)
+ assert isinstance(bridge.blocks[0].mlp.original_component, NativeGatedMLP)
+ assert isinstance(bridge.blocks[0].ln1, RMSNormalizationBridge)
+ assert isinstance(bridge.blocks[0].ln1.original_component, NativeRMSNorm)
+ assert isinstance(bridge.ln_final, RMSNormalizationBridge)
+ assert isinstance(bridge.ln_final.original_component, NativeRMSNorm)
+ # Rotary: no learned pos embed on the bridge or under it.
+ assert "pos_embed" not in bridge.adapter.component_mapping
+ assert bridge.original_model.pos is None
+ # GQA: K/V heads independently configured.
+ attn = bridge.blocks[0].attn.original_component
+ assert attn.n_kv_heads == 2
+
+ inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx))
+ logits, cache = bridge.run_with_cache(inputs, return_type="logits")
+ assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab)
+ assert cache["blocks.0.attn.hook_pattern"].shape == (2, cfg.n_heads, cfg.n_ctx, cfg.n_ctx)
diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py
index 846c6f231..53f7fbe87 100644
--- a/transformer_lens/__init__.py
+++ b/transformer_lens/__init__.py
@@ -15,7 +15,7 @@
from .BertNextSentencePrediction import BertNextSentencePrediction
from .cache.key_value_cache import TransformerLensKeyValueCache
from .cache.key_value_cache_entry import TransformerLensKeyValueCacheEntry
-from .config import HookedTransformerConfig
+from .config import HookedTransformerConfig, TransformerBridgeConfig
from .FactoredMatrix import FactoredMatrix
from .HookedEncoder import HookedEncoder
from .HookedAudioEncoder import HookedAudioEncoder
@@ -41,6 +41,7 @@
__all__ = [
"HookedTransformerConfig",
+ "TransformerBridgeConfig",
"FactoredMatrix",
"ActivationCache",
"HookedTransformer",
diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py
index b87998b9f..208bc4334 100644
--- a/transformer_lens/factories/architecture_adapter_factory.py
+++ b/transformer_lens/factories/architecture_adapter_factory.py
@@ -43,6 +43,7 @@
MixtralArchitectureAdapter,
MPTArchitectureAdapter,
NanogptArchitectureAdapter,
+ NativeArchitectureAdapter,
NeelSoluOldArchitectureAdapter,
NeoArchitectureAdapter,
NeoxArchitectureAdapter,
@@ -123,6 +124,7 @@
"MT5ForConditionalGeneration": T5ArchitectureAdapter,
"XGLMForCausalLM": XGLMArchitectureAdapter,
"NanoGPTForCausalLM": NanogptArchitectureAdapter,
+ "TransformerLensNative": NativeArchitectureAdapter,
"MinGPTForCausalLM": MingptArchitectureAdapter,
"GPTNeoForCausalLM": NeoArchitectureAdapter,
"GPTNeoXForCausalLM": NeoxArchitectureAdapter,
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index acf5c74a0..077007cd5 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -32,6 +32,7 @@
from transformer_lens import utilities as utils
from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
@@ -272,6 +273,77 @@ def boot_transformers(
checkpoint_value=checkpoint_value,
)
+ @classmethod
+ def boot_native(
+ cls,
+ config: Union[TransformerBridgeConfig, dict],
+ tokenizer: Optional[Any] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ dtype: Optional[torch.dtype] = None,
+ model_name: str = "native",
+ ) -> "TransformerBridge":
+ """Build a bridge around a small, randomly-initialized TL-native model.
+
+ No HuggingFace Hub call, no ``transformers`` import. ``config.init_mode``
+ and ``config.seed`` control reproducibility.
+ """
+ import copy as _copy
+
+ from transformer_lens.config import TransformerBridgeConfig as _Cfg
+ from transformer_lens.model_bridge.sources._bridge_builder import (
+ build_bridge_from_module,
+ )
+ from transformer_lens.model_bridge.sources.native import (
+ NativeModel,
+ initialize_native_model,
+ )
+
+ cfg: TransformerBridgeConfig
+ if isinstance(config, dict):
+ cfg = _Cfg.from_dict(config)
+ else:
+ # Deep-copy so NativeModel's default-resolution writes don't land
+ # on the caller's config.
+ cfg = _copy.deepcopy(config)
+
+ # Foreign architecture strings would dispatch to the wrong adapter and
+ # crash deep in prepare_model. Refuse them with a pointing message.
+ if cfg.architecture not in (None, "TransformerLensNative"):
+ raise ValueError(
+ f"boot_native cannot build a {cfg.architecture!r} model — "
+ f"it only constructs the TL-native architecture. Either clear "
+ f"config.architecture or set it to 'TransformerLensNative', "
+ f"or use boot_transformers / build_bridge_from_module for "
+ f"non-native architectures."
+ )
+ architecture = "TransformerLensNative"
+
+ # Fork RNG around construction + init when seeded so neither nn.Linear's
+ # default reset_parameters nor our scoped init perturb the caller's RNG.
+ # Unseeded calls let global RNG advance normally.
+ if cfg.seed is not None:
+ with torch.random.fork_rng(devices=[]):
+ model = NativeModel(cfg)
+ initialize_native_model(model, cfg)
+ else:
+ model = NativeModel(cfg)
+ initialize_native_model(model, cfg)
+
+ if device is not None:
+ model = model.to(device)
+ if dtype is not None:
+ model = model.to(dtype=dtype)
+
+ return build_bridge_from_module(
+ model,
+ architecture=architecture,
+ tl_config=cfg,
+ tokenizer=tokenizer,
+ dtype=dtype,
+ device=device,
+ model_name=model_name,
+ )
+
@property
def original_model(self) -> nn.Module:
"""Get the original model."""
diff --git a/transformer_lens/model_bridge/sources/__init__.py b/transformer_lens/model_bridge/sources/__init__.py
index c3e54e236..2338a17e0 100644
--- a/transformer_lens/model_bridge/sources/__init__.py
+++ b/transformer_lens/model_bridge/sources/__init__.py
@@ -3,6 +3,11 @@
This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
"""
+from transformer_lens.model_bridge.sources._bridge_builder import (
+ build_bridge_config_from_hf,
+ build_bridge_from_module,
+ detect_tokenizer_bos_eos,
+)
from transformer_lens.model_bridge.sources.transformers import (
boot,
check_model_support,
@@ -11,6 +16,9 @@
__all__ = [
"boot",
- "list_supported_models",
+ "build_bridge_config_from_hf",
+ "build_bridge_from_module",
"check_model_support",
+ "detect_tokenizer_bos_eos",
+ "list_supported_models",
]
diff --git a/transformer_lens/model_bridge/sources/_bridge_builder.py b/transformer_lens/model_bridge/sources/_bridge_builder.py
new file mode 100644
index 000000000..5ae2dc466
--- /dev/null
+++ b/transformer_lens/model_bridge/sources/_bridge_builder.py
@@ -0,0 +1,197 @@
+"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model."""
+from __future__ import annotations
+
+import copy
+from typing import Any, Callable, Optional
+
+import torch
+from torch import nn
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.factories.architecture_adapter_factory import (
+ ArchitectureAdapterFactory,
+)
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.bridge import TransformerBridge
+
+# Architecture-agnostic; do not extend per-architecture.
+_HF_PASSTHROUGH_ATTRS = [
+ # OPT
+ "is_gated_act",
+ "word_embed_proj_dim",
+ "do_layer_norm_before",
+ # Granite
+ "position_embedding_type",
+ # Falcon
+ "parallel_attn",
+ "multi_query",
+ "new_decoder_architecture",
+ "alibi",
+ "num_ln_in_parallel_attn",
+ # Mamba (SSM config)
+ "state_size",
+ "conv_kernel",
+ "expand",
+ "time_step_rank",
+ "intermediate_size",
+ # Mamba-2 (additional SSM config)
+ "n_groups",
+ "chunk_size",
+ # Multimodal
+ "vision_config",
+]
+
+
+def build_bridge_config_from_hf(
+ hf_config: Any,
+ architecture: str,
+ model_name: str,
+ dtype: torch.dtype,
+) -> TransformerBridgeConfig:
+ """Translate an HF config into a :class:`TransformerBridgeConfig`."""
+ from transformer_lens.model_bridge.sources.transformers import (
+ map_default_transformer_lens_config,
+ )
+
+ tl_config = map_default_transformer_lens_config(hf_config)
+ config_dict = dict(tl_config.__dict__)
+ # HF's attribute_map remaps num_experts → num_local_experts; restore the TL name.
+ if "num_local_experts" in config_dict and "num_experts" not in config_dict:
+ config_dict["num_experts"] = config_dict["num_local_experts"]
+ bridge_config = TransformerBridgeConfig.from_dict(config_dict)
+ bridge_config.architecture = architecture
+ bridge_config.model_name = model_name
+ bridge_config.dtype = dtype
+
+ for attr in _HF_PASSTHROUGH_ATTRS:
+ val = getattr(hf_config, attr, None)
+ if val is not None:
+ setattr(bridge_config, attr, val)
+
+ # Gemma2: HF softcap field names differ from TL's.
+ final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None)
+ if final_logit_softcapping is not None:
+ bridge_config.output_logits_soft_cap = float(final_logit_softcapping)
+ attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
+ if attn_logit_softcapping is not None:
+ bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
+
+ return bridge_config
+
+
+def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]:
+ """Detect whether the tokenizer prepends BOS and/or appends EOS.
+
+ Non-empty test string — "" is unreliable with token aliasing.
+ """
+ encoded_test = tokenizer.encode("a")
+ prepends_bos = (
+ len(encoded_test) > 1
+ and tokenizer.bos_token_id is not None
+ and encoded_test[0] == tokenizer.bos_token_id
+ )
+ appends_eos = (
+ len(encoded_test) > 1
+ and tokenizer.eos_token_id is not None
+ and encoded_test[-1] == tokenizer.eos_token_id
+ )
+ return prepends_bos, appends_eos
+
+
+def build_bridge_from_module(
+ model: nn.Module,
+ architecture: str,
+ *,
+ hf_config: Optional[Any] = None,
+ tl_config: Optional[TransformerBridgeConfig] = None,
+ tokenizer: Optional[Any] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Any] = None,
+ model_name: str = "external",
+ post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None,
+) -> TransformerBridge:
+ """Build a :class:`TransformerBridge` around a pre-loaded model.
+
+ The bridge never moves, casts, or mutates the supplied model.
+
+ Args:
+ model: Any ``nn.Module`` whose submodule tree matches the adapter's
+ expected dot-paths for ``architecture``.
+ architecture: Architecture identifier registered in the
+ ``ArchitectureAdapterFactory`` (e.g. ``"LlamaForCausalLM"``,
+ ``"TransformerLensNative"``).
+ hf_config: Optional HF-style config; translated via
+ :func:`build_bridge_config_from_hf`. Mutually exclusive with ``tl_config``.
+ tl_config: Optional pre-built :class:`TransformerBridgeConfig`; bypasses
+ HF translation. Mutually exclusive with ``hf_config``.
+ tokenizer: Optional tokenizer. If supplied, passes through
+ ``setup_tokenizer`` and detects BOS/EOS behavior.
+ dtype: Recorded on ``cfg.dtype``. Default ``None`` reads from the model's
+ first parameter; explicit values override.
+ device: Recorded on ``cfg.device``. Default ``None`` reads from the
+ model's first parameter.
+ model_name: Recorded on ``cfg.model_name``.
+ post_adapter_hook: Optional callback invoked after adapter selection and
+ before :meth:`adapter.prepare_model`. Source-specific overlays mutate
+ ``component_mapping`` here.
+
+ Returns:
+ A :class:`TransformerBridge` wrapping the supplied model.
+ """
+ if hf_config is None and tl_config is None:
+ raise ValueError(
+ "build_bridge_from_module requires exactly one of hf_config or "
+ "tl_config — the bridge needs config fields (d_model, n_heads, "
+ "n_layers, ...) that can't be inferred from the model alone."
+ )
+ if hf_config is not None and tl_config is not None:
+ raise ValueError(
+ "build_bridge_from_module got both hf_config and tl_config; supply "
+ "exactly one. hf_config triggers HF→bridge translation; tl_config "
+ "bypasses it."
+ )
+
+ # Reading dtype from the model avoids silently lying about a bf16 model.
+ if dtype is None:
+ try:
+ dtype = next(model.parameters()).dtype
+ except StopIteration:
+ dtype = torch.float32
+
+ if tl_config is not None:
+ # Defensive copy so adapter-init mutations (normalization_type, device,
+ # ...) don't leak between bridges built from the same config.
+ bridge_config = copy.deepcopy(tl_config)
+ bridge_config.architecture = architecture
+ if model_name != "external" or not getattr(bridge_config, "model_name", None):
+ bridge_config.model_name = model_name
+ bridge_config.dtype = dtype
+ else:
+ bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype)
+
+ adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
+
+ if post_adapter_hook is not None:
+ post_adapter_hook(adapter)
+
+ if device is not None:
+ adapter.cfg.device = str(device)
+ else:
+ try:
+ adapter.cfg.device = str(next(model.parameters()).device)
+ except StopIteration:
+ adapter.cfg.device = "cpu"
+
+ adapter.prepare_model(model)
+
+ if tokenizer is not None:
+ from transformer_lens.model_bridge.sources.transformers import setup_tokenizer
+
+ default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
+ tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
+ (
+ adapter.cfg.tokenizer_prepends_bos,
+ adapter.cfg.tokenizer_appends_eos,
+ ) = detect_tokenizer_bos_eos(tokenizer)
+
+ return TransformerBridge(model, adapter, tokenizer)
diff --git a/transformer_lens/model_bridge/sources/native/__init__.py b/transformer_lens/model_bridge/sources/native/__init__.py
new file mode 100644
index 000000000..f1b860406
--- /dev/null
+++ b/transformer_lens/model_bridge/sources/native/__init__.py
@@ -0,0 +1,16 @@
+"""TL-native model source for TransformerBridge."""
+from transformer_lens.model_bridge.sources.native.init import initialize_native_model
+from transformer_lens.model_bridge.sources.native.model import (
+ NativeAttention,
+ NativeBlock,
+ NativeMLP,
+ NativeModel,
+)
+
+__all__ = [
+ "NativeAttention",
+ "NativeBlock",
+ "NativeMLP",
+ "NativeModel",
+ "initialize_native_model",
+]
diff --git a/transformer_lens/model_bridge/sources/native/init.py b/transformer_lens/model_bridge/sources/native/init.py
new file mode 100644
index 000000000..53ebaaebb
--- /dev/null
+++ b/transformer_lens/model_bridge/sources/native/init.py
@@ -0,0 +1,160 @@
+"""Weight init for NativeModel.
+
+Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual
+scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``,
+``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights
+go to 1, all biases to 0.
+
+Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so
+seeded init does not perturb the caller's global RNG.
+"""
+from __future__ import annotations
+
+import math
+from typing import Callable, Optional, cast
+
+import torch
+import torch.nn as nn
+
+from transformer_lens.config import TransformerBridgeConfig
+
+from .model import (
+ NativeAttention,
+ NativeBlock,
+ NativeGatedMLP,
+ NativeMLP,
+ NativeModel,
+ NativeRMSNorm,
+)
+
+# Residual-scaled output is gpt2-specific; other modes treat every weight the
+# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator.
+_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor]
+_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = {
+ "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g),
+ "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g),
+ "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g),
+ "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g),
+}
+
+_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES})
+
+
+def initialize_native_model(
+ model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None
+) -> None:
+ """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``."""
+ effective_seed = seed if seed is not None else cfg.seed
+
+ # Scoped generator on the model's device — None falls back to the global RNG.
+ try:
+ gen_device = next(model.parameters()).device
+ except StopIteration:
+ gen_device = torch.device("cpu")
+ generator: Optional[torch.Generator]
+ if effective_seed is not None:
+ g = torch.Generator(device=gen_device)
+ g.manual_seed(effective_seed)
+ generator = g
+ else:
+ generator = None
+
+ init_mode = (cfg.init_mode or "gpt2").lower()
+ if init_mode not in _SUPPORTED_MODES:
+ raise NotImplementedError(
+ f"init_mode={init_mode!r} is not supported for NativeModel. "
+ f"Supported modes: {sorted(_SUPPORTED_MODES)}."
+ )
+
+ weight_init: Callable[[torch.Tensor], torch.Tensor]
+ output_init: Callable[[torch.Tensor], torch.Tensor]
+ if init_mode == "gpt2":
+ std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02
+ residual_scale = 1.0 / math.sqrt(2 * cfg.n_layers)
+ weight_init = lambda t: nn.init.normal_(
+ t, mean=0.0, std=std, generator=generator
+ ) # noqa: E731
+ output_init = lambda t: nn.init.normal_( # noqa: E731
+ t, mean=0.0, std=std * residual_scale, generator=generator
+ )
+ else:
+ fn = _NON_RESIDUAL_MODES[init_mode]
+ weight_init = lambda t: fn(t, generator) # noqa: E731
+ output_init = weight_init
+
+ weight_init(model.tok_embed.weight)
+ if model.pos is not None:
+ weight_init(model.pos.weight)
+ # Rotary has only registered buffers (cos/sin), no parameters to init.
+
+ for block in model.layers:
+ _init_block(block, weight_init=weight_init, output_init=output_init)
+
+ _init_norm(model.ln_out)
+ weight_init(model.head.weight)
+
+
+def _init_norm(norm: nn.Module) -> None:
+ if isinstance(norm, NativeRMSNorm):
+ nn.init.ones_(norm.weight)
+ elif isinstance(norm, nn.LayerNorm):
+ nn.init.ones_(norm.weight)
+ nn.init.zeros_(norm.bias)
+ else:
+ raise TypeError(f"Unknown normalization type: {type(norm).__name__}")
+
+
+def _init_block(
+ block: NativeBlock,
+ *,
+ weight_init: Callable[[torch.Tensor], torch.Tensor],
+ output_init: Callable[[torch.Tensor], torch.Tensor],
+) -> None:
+ _init_norm(block.ln1)
+ _init_attention(block.attn, weight_init=weight_init, output_init=output_init)
+ if not block.cfg.attn_only:
+ _init_norm(block.ln2)
+ if isinstance(block.mlp, NativeGatedMLP):
+ _init_gated_mlp(block.mlp, weight_init=weight_init, output_init=output_init)
+ else:
+ _init_mlp(block.mlp, weight_init=weight_init, output_init=output_init)
+
+
+def _init_attention(
+ attn: NativeAttention,
+ *,
+ weight_init: Callable[[torch.Tensor], torch.Tensor],
+ output_init: Callable[[torch.Tensor], torch.Tensor],
+) -> None:
+ for linear in (attn.q, attn.k, attn.v):
+ weight_init(linear.weight)
+ if linear.bias is not None:
+ nn.init.zeros_(linear.bias)
+ output_init(attn.o.weight)
+ if attn.o.bias is not None:
+ nn.init.zeros_(attn.o.bias)
+
+
+def _init_mlp(
+ mlp: NativeMLP,
+ *,
+ weight_init: Callable[[torch.Tensor], torch.Tensor],
+ output_init: Callable[[torch.Tensor], torch.Tensor],
+) -> None:
+ weight_init(mlp.fc_in.weight)
+ nn.init.zeros_(mlp.fc_in.bias)
+ output_init(mlp.fc_out.weight)
+ nn.init.zeros_(mlp.fc_out.bias)
+
+
+def _init_gated_mlp(
+ mlp: NativeGatedMLP,
+ *,
+ weight_init: Callable[[torch.Tensor], torch.Tensor],
+ output_init: Callable[[torch.Tensor], torch.Tensor],
+) -> None:
+ weight_init(mlp.gate.weight)
+ # ``in`` is registered via add_module; getattr resolves it from _modules.
+ in_proj = cast(nn.Linear, getattr(mlp, "in"))
+ weight_init(in_proj.weight)
+ output_init(mlp.out.weight)
diff --git a/transformer_lens/model_bridge/sources/native/model.py b/transformer_lens/model_bridge/sources/native/model.py
new file mode 100644
index 000000000..2035ade42
--- /dev/null
+++ b/transformer_lens/model_bridge/sources/native/model.py
@@ -0,0 +1,453 @@
+"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency.
+
+Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``,
+``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``,
+``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary),
+``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK,
+llama3 by-parts).
+"""
+from __future__ import annotations
+
+import math
+from typing import Callable, Optional, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformer_lens.config import TransformerBridgeConfig
+
+# gelu_new = the tanh-approximation HF GPT-2 / HT use; F.gelu(approximate="tanh")
+# is the exact same formula.
+_Activation = Callable[[torch.Tensor], torch.Tensor]
+_ACTIVATIONS: dict[str, _Activation] = {
+ "gelu": F.gelu,
+ "gelu_new": lambda x: F.gelu(x, approximate="tanh"),
+ "relu": F.relu,
+ "silu": F.silu,
+ "swish": F.silu,
+}
+
+
+def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool:
+ return (cfg.normalization_type or "LN").upper() in ("RMS", "RMSPRE")
+
+
+def _positional_kind(cfg: TransformerBridgeConfig) -> str:
+ return (getattr(cfg, "positional_embedding_type", None) or "standard").lower()
+
+
+class NativeRMSNorm(nn.Module):
+ """Llama-style RMSNorm. Variance in fp32 regardless of input dtype, then
+ cast back before the per-channel scale (matches HF LlamaRMSNorm)."""
+
+ def __init__(self, d_model: int, eps: float = 1e-5):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(d_model))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ input_dtype = x.dtype
+ x_fp32 = x.to(torch.float32)
+ rms_inv = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+ normalized = (x_fp32 * rms_inv).to(input_dtype)
+ return self.weight * normalized
+
+
+def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module:
+ if force_rms or _uses_rms_norm(cfg):
+ return NativeRMSNorm(cfg.d_model, eps=cfg.eps)
+ return nn.LayerNorm(cfg.d_model, eps=cfg.eps)
+
+
+def _resolve_rope_scaling(
+ cfg: TransformerBridgeConfig, rotary_dim: int
+) -> tuple[float, float, torch.Tensor]:
+ """Returns (effective_base, position_scale, inv_freq) per cfg.rope_scaling."""
+ base = float(cfg.rotary_base)
+ rope_scaling = getattr(cfg, "rope_scaling", None)
+ inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
+
+ if not isinstance(rope_scaling, dict):
+ return base, 1.0, inv_freq
+
+ # Newer HF configs key on "rope_type"; older ones on "type".
+ scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower()
+ factor = float(rope_scaling.get("factor", 1.0))
+
+ if scale_type in ("", "default") or factor <= 1.0:
+ return base, 1.0, inv_freq
+
+ if scale_type == "linear":
+ return base, factor, inv_freq
+
+ if scale_type in ("dynamic", "ntk"):
+ scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2)))
+ new_inv_freq = 1.0 / (scaled_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
+ return scaled_base, 1.0, new_inv_freq
+
+ if scale_type == "llama3":
+ low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0))
+ high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0))
+ original_ctx = float(
+ rope_scaling.get("original_max_position_embeddings")
+ or rope_scaling.get("original_context_length")
+ or 8192
+ )
+ low_wavelen = original_ctx / low_freq_factor
+ high_wavelen = original_ctx / high_freq_factor
+ wavelens = 2 * math.pi / inv_freq
+ # Three regimes: low-freq → divide by factor; high-freq → unchanged;
+ # in-between → smooth linear interpolation between the two.
+ smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ new_inv_freq = torch.where(
+ wavelens > low_wavelen,
+ inv_freq / factor,
+ torch.where(
+ wavelens < high_wavelen,
+ inv_freq,
+ (1 - smooth) * inv_freq / factor + smooth * inv_freq,
+ ),
+ )
+ return base, 1.0, new_inv_freq
+
+ raise NotImplementedError(
+ f"rope_scaling type {scale_type!r} is not supported. "
+ f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'."
+ )
+
+
+class NativeRotary(nn.Module):
+ """Shared cos/sin tables for RoPE. Honors ``cfg.rope_scaling``."""
+
+ # Declared so mypy sees the buffer dtype; register_buffer alone reports Module|Tensor.
+ cos_cached: torch.Tensor
+ sin_cached: torch.Tensor
+
+ def __init__(self, cfg: TransformerBridgeConfig):
+ super().__init__()
+ rotary_dim = cfg.rotary_dim if cfg.rotary_dim is not None else cfg.d_head
+ if rotary_dim <= 0 or rotary_dim % 2 != 0:
+ raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}")
+ self.rotary_dim = rotary_dim
+
+ base, position_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim)
+
+ positions = torch.arange(cfg.n_ctx).float() / position_scale
+ freqs = torch.outer(positions, inv_freq)
+ # Llama/HF adjacent-pair format: each (2i, 2i+1) pair rotates together.
+ cos = freqs.cos().repeat_interleave(2, dim=-1)
+ sin = freqs.sin().repeat_interleave(2, dim=-1)
+ self.register_buffer("cos_cached", cos, persistent=False)
+ self.register_buffer("sin_cached", sin, persistent=False)
+ self.effective_base = base
+ self.position_scale = position_scale
+
+ @staticmethod
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ # Llama-style adjacent-pair rotation: (x0, x1) -> (-x1, x0).
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+ rot = torch.stack((-x2, x1), dim=-1)
+ return rot.flatten(-2)
+
+ def apply_rope(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ *,
+ position_ids: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Apply RoPE to Q/K of shape [batch, heads, seq, d_head].
+
+ Named ``apply_rope`` rather than ``apply`` so ``nn.Module.apply(fn)``
+ — PyTorch's recursive function-application utility used by
+ ``bridge.apply(init_fn)`` — isn't shadowed.
+ """
+ seq = q.shape[-2]
+ rd = self.rotary_dim
+ if position_ids is None:
+ cos = self.cos_cached[:seq].to(q.dtype)
+ sin = self.sin_cached[:seq].to(q.dtype)
+ else:
+ # [batch, seq] -> [batch, 1, seq, rd] (head dim for broadcast).
+ cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1)
+ sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1)
+
+ def _rope(x: torch.Tensor) -> torch.Tensor:
+ x_rot, x_pass = x[..., :rd], x[..., rd:]
+ x_rot = x_rot * cos + self._rotate_half(x_rot) * sin
+ return torch.cat([x_rot, x_pass], dim=-1) if x_pass.shape[-1] else x_rot
+
+ return _rope(q), _rope(k)
+
+
+class NativeAttention(nn.Module):
+ """Split-QKV causal self-attention. Returns (out, pattern); AttentionBridge
+ fires ``hook_pattern`` off the second element."""
+
+ causal_mask: torch.Tensor
+
+ def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
+ super().__init__()
+ self.cfg = cfg
+ self.n_heads = cfg.n_heads
+ self.d_head = cfg.d_head
+ self.d_model = cfg.d_model
+ self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
+ if self.n_heads % self.n_kv_heads != 0:
+ raise ValueError(
+ f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads "
+ f"({self.n_kv_heads}) for GQA."
+ )
+ self.kv_repeats = self.n_heads // self.n_kv_heads
+
+ q_dim = self.n_heads * self.d_head
+ kv_dim = self.n_kv_heads * self.d_head
+ self.q = nn.Linear(cfg.d_model, q_dim, bias=True)
+ self.k = nn.Linear(cfg.d_model, kv_dim, bias=True)
+ self.v = nn.Linear(cfg.d_model, kv_dim, bias=True)
+ self.o = nn.Linear(q_dim, cfg.d_model, bias=True)
+
+ mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1)
+ self.register_buffer("causal_mask", mask, persistent=False)
+
+ # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" —
+ # i.e. unscaled scores, which saturate softmax for d_head>1.
+ if cfg.use_attn_scale and cfg.attn_scale > 0:
+ if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9):
+ raise ValueError(
+ f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled "
+ f"attention; softmax will saturate. For standard scaling "
+ f"leave attn_scale at -1 (sentinel for sqrt(d_head))."
+ )
+ scale = cfg.attn_scale
+ else:
+ scale = math.sqrt(cfg.d_head)
+ self.scale = scale
+ self.rotary = rotary
+ self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch, seq, _ = hidden_states.shape
+
+ q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
+ k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
+ v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
+
+ if self.rotary is not None:
+ q, k = self.rotary.apply_rope(q, k, position_ids=position_ids)
+
+ # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering.
+ if self.kv_repeats > 1:
+ k = k.repeat_interleave(self.kv_repeats, dim=1)
+ v = v.repeat_interleave(self.kv_repeats, dim=1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
+ # Gemma2 soft-cap before the causal mask so masked positions stay -inf.
+ if self.attn_scores_soft_cap > 0:
+ c = self.attn_scores_soft_cap
+ scores = c * torch.tanh(scores / c)
+
+ block_mask = self.causal_mask[:seq, :seq]
+ if attention_mask is not None:
+ block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch)
+ scores = scores.masked_fill(block_mask, float("-inf"))
+
+ pattern = F.softmax(scores, dim=-1)
+
+ attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1)
+ out = self.o(attn)
+ return out, pattern
+
+ @staticmethod
+ def _combine_attention_mask(
+ block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int
+ ) -> torch.Tensor:
+ """Combine an external attention_mask with the causal mask.
+
+ Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool
+ mask (True=mask), or 4D additive float mask (HF generation style; values
+ below -1 treated as masked).
+ """
+ if attention_mask.dim() == 2:
+ pad_mask = ~attention_mask.bool()
+ return block_mask | pad_mask[:, None, None, :]
+ if attention_mask.dim() == 4:
+ if attention_mask.dtype is torch.bool:
+ return block_mask | attention_mask
+ # HF additive masks use -inf or large negatives; benign biases bounded.
+ extra = attention_mask < -1.0
+ return block_mask | extra
+ raise ValueError(
+ f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], "
+ f"got shape {tuple(attention_mask.shape)}."
+ )
+
+
+class NativeMLP(nn.Module):
+ """Two-layer MLP with configurable activation."""
+
+ act: Callable[[torch.Tensor], torch.Tensor]
+
+ def __init__(self, cfg: TransformerBridgeConfig):
+ super().__init__()
+ assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
+ d_mlp: int = cfg.d_mlp
+ self.fc_in = nn.Linear(cfg.d_model, d_mlp, bias=True)
+ self.fc_out = nn.Linear(d_mlp, cfg.d_model, bias=True)
+ act_name = (cfg.act_fn or "gelu").lower()
+ if act_name not in _ACTIVATIONS:
+ raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
+ self.act = _ACTIVATIONS[act_name]
+
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ return self.fc_out(self.act(self.fc_in(hidden_states)))
+
+
+class NativeGatedMLP(nn.Module):
+ """SwiGLU / ReGLU / GeGLU gated MLP (variant picked by ``cfg.act_fn``).
+
+ Submodules ``gate`` / ``in`` / ``out`` match GatedMLPBridge's expected slots.
+ """
+
+ act: Callable[[torch.Tensor], torch.Tensor]
+
+ def __init__(self, cfg: TransformerBridgeConfig):
+ super().__init__()
+ assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
+ d_mlp: int = cfg.d_mlp
+ # Llama convention: no biases on gated MLP projections.
+ self.gate = nn.Linear(cfg.d_model, d_mlp, bias=False)
+ # ``in`` is a Python keyword; add_module + getattr(self, "in") works
+ # because the bridge resolves LinearBridge(name="in") the same way.
+ self.add_module("in", nn.Linear(cfg.d_model, d_mlp, bias=False))
+ self.out = nn.Linear(d_mlp, cfg.d_model, bias=False)
+ # Default to SwiGLU; mirror NativeMLP's dispatch so a typo'd act_fn
+ # raises instead of silently changing the model.
+ act_name = (cfg.act_fn or "silu").lower()
+ if act_name not in _ACTIVATIONS:
+ raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
+ self.act = _ACTIVATIONS[act_name]
+
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ gate_out = self.act(self.gate(hidden_states))
+ in_proj = cast(nn.Linear, getattr(self, "in"))
+ up_out = in_proj(hidden_states)
+ return self.out(gate_out * up_out)
+
+
+class NativeBlock(nn.Module):
+ """Pre-LN transformer block. Layout adapts to ``cfg.attn_only`` and
+ ``cfg.gated_mlp``."""
+
+ def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
+ super().__init__()
+ self.cfg = cfg
+ self.ln1 = _make_norm(cfg)
+ self.attn = NativeAttention(cfg, rotary=rotary)
+ if not cfg.attn_only:
+ self.ln2 = _make_norm(cfg)
+ self.mlp = NativeGatedMLP(cfg) if cfg.gated_mlp else NativeMLP(cfg)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ attn_out, _pattern = self.attn(
+ self.ln1(hidden_states),
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+ hidden_states = hidden_states + attn_out
+ if not self.cfg.attn_only:
+ hidden_states = hidden_states + self.mlp(self.ln2(hidden_states))
+ # Tuple return matches HF block convention; BlockBridge's parser expects it.
+ return (hidden_states,)
+
+
+class NativeModel(nn.Module):
+ """TL-native transformer. See module docstring for the supported feature set."""
+
+ pos: Optional[nn.Embedding]
+ rotary: Optional[NativeRotary]
+
+ def __init__(self, cfg: TransformerBridgeConfig):
+ super().__init__()
+ # Write the resolved d_mlp back so downstream consumers see the real
+ # value, not None. Mutates cfg; isolating callers should deep-copy first.
+ if not getattr(cfg, "d_mlp", None):
+ cfg.d_mlp = 4 * cfg.d_model
+ self.cfg = cfg
+
+ self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model)
+
+ kind = _positional_kind(cfg)
+ if kind == "standard":
+ self.pos = nn.Embedding(cfg.n_ctx, cfg.d_model)
+ self.rotary = None
+ elif kind == "rotary":
+ self.pos = None
+ self.rotary = NativeRotary(cfg)
+ else:
+ raise ValueError(
+ f"Unsupported positional_embedding_type={kind!r}. "
+ f"NativeModel supports 'standard' and 'rotary'."
+ )
+
+ self.layers = nn.ModuleList(
+ [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)]
+ )
+ # final_rms forces RMS on the final norm regardless of block-norm choice
+ # — matches the TL config semantic Llama uses.
+ self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms)
+ d_vocab_out = cfg.d_vocab_out if cfg.d_vocab_out > 0 else cfg.d_vocab
+ self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False)
+ self.output_logits_soft_cap = float(cfg.output_logits_soft_cap)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """Returns logits directly."""
+ # Bounds check up front so both absolute and rotary paths produce a
+ # self-explanatory error rather than IndexError / shape mismatch.
+ seq_len = input_ids.shape[-1]
+ if seq_len > self.cfg.n_ctx:
+ raise ValueError(
+ f"input length {seq_len} exceeds n_ctx={self.cfg.n_ctx}; "
+ f"position embeddings and rotary tables are pre-baked at n_ctx."
+ )
+
+ # Resolve position_ids before the block loop so rotary sees the caller's
+ # positions, not the dense default.
+ batch, seq = input_ids.shape
+ if position_ids is None:
+ position_ids = torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(batch, -1)
+
+ hidden_states = self.tok_embed(input_ids)
+ if self.pos is not None:
+ hidden_states = hidden_states + self.pos(position_ids)
+
+ for block in self.layers:
+ (hidden_states,) = block(
+ hidden_states, attention_mask=attention_mask, position_ids=position_ids
+ )
+ hidden_states = self.ln_out(hidden_states)
+ logits = self.head(hidden_states)
+ if self.output_logits_soft_cap > 0:
+ c = self.output_logits_soft_cap
+ logits = c * torch.tanh(logits / c)
+ return logits
diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py
index 8bd181d1e..87212f6d8 100644
--- a/transformer_lens/model_bridge/supported_architectures/__init__.py
+++ b/transformer_lens/model_bridge/supported_architectures/__init__.py
@@ -102,6 +102,9 @@
from transformer_lens.model_bridge.supported_architectures.nanogpt import (
NanogptArchitectureAdapter,
)
+from transformer_lens.model_bridge.supported_architectures.native import (
+ NativeArchitectureAdapter,
+)
from transformer_lens.model_bridge.supported_architectures.neel_solu_old import (
NeelSoluOldArchitectureAdapter,
)
@@ -197,6 +200,7 @@
"MixtralArchitectureAdapter",
"MPTArchitectureAdapter",
"NanogptArchitectureAdapter",
+ "NativeArchitectureAdapter",
"NeelSoluOldArchitectureAdapter",
"NeoArchitectureAdapter",
"NeoxArchitectureAdapter",
diff --git a/transformer_lens/model_bridge/supported_architectures/native.py b/transformer_lens/model_bridge/supported_architectures/native.py
new file mode 100644
index 000000000..697a3e8e0
--- /dev/null
+++ b/transformer_lens/model_bridge/supported_architectures/native.py
@@ -0,0 +1,139 @@
+"""Architecture adapter for TL-native models built via ``boot_native``.
+
+Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm →
+``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops MLP.
+"""
+from typing import Any
+
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.generalized_components import (
+ AttentionBridge,
+ BlockBridge,
+ EmbeddingBridge,
+ GatedMLPBridge,
+ LinearBridge,
+ MLPBridge,
+ NormalizationBridge,
+ PosEmbedBridge,
+ RMSNormalizationBridge,
+ UnembeddingBridge,
+)
+
+
+def _uses_rms(cfg: Any) -> bool:
+ return (getattr(cfg, "normalization_type", None) or "LN").upper() in ("RMS", "RMSPRE")
+
+
+def _is_rotary(cfg: Any) -> bool:
+ return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() == "rotary"
+
+
+def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False):
+ if force_rms or _uses_rms(cfg):
+ return RMSNormalizationBridge(name=name, config=cfg)
+ return NormalizationBridge(name=name, config=cfg)
+
+
+def _make_mlp_bridge(cfg: Any):
+ if cfg.gated_mlp:
+ return GatedMLPBridge(
+ name="mlp",
+ config=cfg,
+ submodules={
+ "gate": LinearBridge(name="gate"),
+ "in": LinearBridge(name="in"),
+ "out": LinearBridge(name="out"),
+ },
+ )
+ return MLPBridge(
+ name="mlp",
+ submodules={
+ "in": LinearBridge(name="fc_in"),
+ "out": LinearBridge(name="fc_out"),
+ },
+ )
+
+
+def _make_block_submodules(cfg: Any) -> dict:
+ submods: dict = {
+ "ln1": _make_norm_bridge("ln1", cfg),
+ "attn": AttentionBridge(
+ name="attn",
+ config=cfg,
+ submodules={
+ "q": LinearBridge(name="q"),
+ "k": LinearBridge(name="k"),
+ "v": LinearBridge(name="v"),
+ "o": LinearBridge(name="o"),
+ },
+ ),
+ }
+ if not cfg.attn_only:
+ submods["ln2"] = _make_norm_bridge("ln2", cfg)
+ submods["mlp"] = _make_mlp_bridge(cfg)
+ return submods
+
+
+class NativeArchitectureAdapter(ArchitectureAdapter):
+ """Adapter for ``NativeModel`` — TL-native, split-QKV, pre-LN; feature set
+ driven by cfg (gated MLP, RMS norm, rotary, GQA, soft-cap, attn_only)."""
+
+ def __init__(self, cfg: Any) -> None:
+ super().__init__(cfg)
+
+ # Native layout already stores Q/K/V split; no rearranges needed.
+ # Compatibility-mode fold_ln / center_writing_weights aren't wired up,
+ # so gate the corresponding ProcessWeights paths off — folding without
+ # the state-dict conversions would mis-place or drop weights.
+ self.supports_fold_ln = False
+ self.supports_center_writing_weights = False
+ self.weight_processing_conversions = {}
+
+ # Internal attribute names avoid collisions with bridge slot names
+ # ("embed", "blocks", "ln_final", "unembed") — the bridge's __getattr__
+ # forwards to original_model and would shadow add_module otherwise.
+ mapping: dict = {
+ "embed": EmbeddingBridge(name="tok_embed"),
+ }
+ if not _is_rotary(cfg):
+ mapping["pos_embed"] = PosEmbedBridge(name="pos")
+ block_bridge = BlockBridge(
+ name="layers",
+ config=self.cfg,
+ submodules=_make_block_submodules(self.cfg),
+ )
+ # Under attn_only there's no ln2 / mlp to point at; drop the aliases
+ # that would otherwise warn during _register_aliases.
+ if self.cfg.attn_only:
+ if block_bridge.hook_aliases is BlockBridge.hook_aliases:
+ block_bridge.hook_aliases = dict(block_bridge.hook_aliases)
+ block_bridge.hook_aliases.pop("hook_resid_mid", None)
+ block_bridge.hook_aliases.pop("hook_mlp_out", None)
+ mapping["blocks"] = block_bridge
+ # final_rms forces RMS on the final norm independent of block norm —
+ # matches Llama's TL config semantic.
+ mapping["ln_final"] = _make_norm_bridge(
+ "ln_out", self.cfg, force_rms=bool(getattr(self.cfg, "final_rms", False))
+ )
+ mapping["unembed"] = UnembeddingBridge(name="head")
+ self.component_mapping = mapping
+
+ def prepare_model(self, model: Any) -> None:
+ """Reject modules whose attribute names collide with bridge slots.
+
+ Bridge's ``__getattr__`` falls back to ``getattr(original_model, name)``
+ for unknown attrs, so a name match — submodule, buffer, plain tensor,
+ or property — makes ``add_module`` raise mid-setup with an opaque
+ message. Failing here points at the real cause. Reserved set is derived
+ from ``component_mapping.keys()`` so adapter variants stay in sync.
+ """
+ reserved = set(self.component_mapping.keys()) if self.component_mapping else set()
+ collisions = sorted(name for name in reserved if hasattr(model, name))
+ if collisions:
+ raise ValueError(
+ f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: "
+ f"attribute names {collisions} collide with bridge component slots "
+ f"({sorted(reserved)}). Rename these attributes to non-colliding names "
+ f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's "
+ f"component_mapping ``name=`` fields to match."
+ )