From 1d00e5b273433b2fb5542b976119ed55becc391b Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 1 Jun 2026 15:54:51 -0500 Subject: [PATCH 1/2] Initial native source setup --- demos/Realtime_Training_Telemetry_Demo.ipynb | 914 +++++++++--------- .../model_bridge/test_native_training.py | 140 +++ tests/unit/model_bridge/test_boot_native.py | 313 ++++++ .../test_build_bridge_from_module.py | 166 ++++ .../unit/model_bridge/test_native_features.py | 379 ++++++++ transformer_lens/__init__.py | 3 +- .../factories/architecture_adapter_factory.py | 2 + transformer_lens/model_bridge/bridge.py | 86 ++ .../model_bridge/sources/__init__.py | 10 +- .../model_bridge/sources/_bridge_builder.py | 204 ++++ .../model_bridge/sources/native/__init__.py | 16 + .../model_bridge/sources/native/init.py | 152 +++ .../model_bridge/sources/native/model.py | 353 +++++++ .../supported_architectures/__init__.py | 4 + .../supported_architectures/native.py | 163 ++++ 15 files changed, 2448 insertions(+), 457 deletions(-) create mode 100644 tests/integration/model_bridge/test_native_training.py create mode 100644 tests/unit/model_bridge/test_boot_native.py create mode 100644 tests/unit/model_bridge/test_build_bridge_from_module.py create mode 100644 tests/unit/model_bridge/test_native_features.py create mode 100644 transformer_lens/model_bridge/sources/_bridge_builder.py create mode 100644 transformer_lens/model_bridge/sources/native/__init__.py create mode 100644 transformer_lens/model_bridge/sources/native/init.py create mode 100644 transformer_lens/model_bridge/sources/native/model.py create mode 100644 transformer_lens/model_bridge/supported_architectures/native.py diff --git a/demos/Realtime_Training_Telemetry_Demo.ipynb b/demos/Realtime_Training_Telemetry_Demo.ipynb index d77d67274..1e52ee7c6 100644 --- a/demos/Realtime_Training_Telemetry_Demo.ipynb +++ b/demos/Realtime_Training_Telemetry_Demo.ipynb @@ -1,464 +1,468 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ctqPlNKhJFgZ" + }, + "source": [ + "# Real-time Training Telemetry Demo\n", + "\n", + "This notebook demonstrates the real-time extraction and visualization of mechanistic metrics during a model's training loop using TransformerLens.\n", + "\n", + "By leveraging dynamic dictionary logging alongside the `ActivationCache`, we can isolate the training window where localized phase transitions—such as the formation of induction heads—begin to emerge.\n", + "\n", + "**A note on scaling**: While this 2-layer toy model allows for high-granularity tracking with minimal computational overhead, achieving similar resolution in larger architectures is non-trivial. It requires highly targeted caching and direct manipulation of the telemetry bridge to surface this level of detail without memory exhaustion.\n", + "\n", + "**Compute requirements**: A standard CPU is entirely sufficient for this demonstration. The 500-step training loop will execute rapidly in standard local or cloud-based notebook environments without the need for hardware acceleration." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tQ_xnziNJlsF" + }, + "source": [ + "Initializes the workspace, configures Plotly renderers, and defines the toy model architecture.\n", + "\n", + "**Visualization Context** `Plotly Renderers`:\n", + "\n", + "`Plotly` generates interactive, JavaScript-based visualizations. Google Colab handles these DOM interactions differently than local Jupyter or VS Code environments. We detect the active environment to set the appropriate `plotly.io` renderer (`\"colab\"` vs `\"notebook_connected\"`), ensuring the dynamic telemetry plots render correctly without blank output blocks.\n", + "\n", + "**Architectural Rationale:**\n", + "\n", + "\n", + "* **2 Layers (`n_layers=2`):** The theoretical minimum depth required for induction circuits. Layer 0 creates \"previous token\" representations, and Layer 1 queries these to predict the next token based on earlier context.\n", + "\n", + "* **2 Heads (`n_heads=2`):** Provides just enough capacity for heads to specialize (e.g., dedicating one head to induction) without creating excessive noise in the telemetry visualizations.\n", + "\n", + "* **GELU Activation (`act_fn=\"gelu\"`):** Selected over ReLU to mirror the smooth non-linearities of modern production LLMs, ensuring the activation dynamics remain representative of real-world architectures.\n", + "\n", + "* **Miniaturized Dimensions:** `d_model=64`, `d_vocab=64`, and `n_ctx=32` are intentionally bottlenecked to force rapid convergence, reliably inducing the phase transition within a brief 500-step training window." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { "colab": { - "provenance": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" + "id": "GyA-sxLOLRVA", + "outputId": "1cc176e2-7638-4106-b063-edce512c6f6d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Environment: Local / Standard Jupyter\n", + "Plotly Renderer: notebook_connected\n", + "Running on cpu\n" + ] } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "# Detect execution environment\n", + "try:\n", + " import google.colab # noqa: F401\n", + " IN_COLAB = True\n", + " print(\"Environment: Google Colab\")\n", + "except ImportError:\n", + " IN_COLAB = False\n", + " print(\"Environment: Local / Standard Jupyter\")\n", + "\n", + "# Environment-specific dependency management\n", + "if IN_COLAB:\n", + " %pip install -q transformer_lens\n", + " %pip install -q circuitsvis\n", + "\n", + "import plotly.io as pio\n", + "import plotly.graph_objects as go\n", + "from plotly.subplots import make_subplots\n", + "\n", + "# Configure Plotly renderer for correct JavaScript execution\n", + "if IN_COLAB:\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Plotly Renderer: {pio.renderers.default}\")\n", + "\n", + "# Must be imported after Colab pip install\n", + "from transformer_lens import TransformerBridgeConfig # noqa: E402\n", + "from transformer_lens.model_bridge import TransformerBridge # noqa: E402\n", + "\n", + "# Configuration\n", + "torch.manual_seed(42)\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Running on {device}\")\n", + "\n", + "# TransformerBridge.boot_native builds a small TL-native model (no HuggingFace\n", + "# Hub call, no `transformers` import) and wraps it in a bridge. The bridge's\n", + "# `forward`, `run_with_cache`, and `parameters()` surfaces are the same ones\n", + "# you'd use on any other bridge-backed model.\n", + "cfg = TransformerBridgeConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=32,\n", + " n_heads=2,\n", + " d_mlp=256,\n", + " d_vocab=64,\n", + " n_ctx=32,\n", + " act_fn=\"gelu\",\n", + " normalization_type=\"LN\",\n", + " seed=42,\n", + ")\n", + "model = TransformerBridge.boot_native(cfg, device=device)\n" + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Real-time Training Telemetry Demo\n", - "\n", - "This notebook demonstrates the real-time extraction and visualization of mechanistic metrics during a model's training loop using TransformerLens.\n", - "\n", - "By leveraging dynamic dictionary logging alongside the `ActivationCache`, we can isolate the training window where localized phase transitions\u2014such as the formation of induction heads\u2014begin to emerge.\n", - "\n", - "**A note on scaling**: While this 2-layer toy model allows for high-granularity tracking with minimal computational overhead, achieving similar resolution in larger architectures is non-trivial. It requires highly targeted caching and direct manipulation of the telemetry bridge to surface this level of detail without memory exhaustion.\n", - "\n", - "**Compute requirements**: A standard CPU is entirely sufficient for this demonstration. The 500-step training loop will execute rapidly in standard local or cloud-based notebook environments without the need for hardware acceleration." - ], - "metadata": { - "id": "ctqPlNKhJFgZ" - } - }, - { - "cell_type": "markdown", - "source": [ - "Initializes the workspace, configures Plotly renderers, and defines the toy model architecture.\n", - "\n", - "**Visualization Context** `Plotly Renderers`:\n", - "\n", - "`Plotly` generates interactive, JavaScript-based visualizations. Google Colab handles these DOM interactions differently than local Jupyter or VS Code environments. We detect the active environment to set the appropriate `plotly.io` renderer (`\"colab\"` vs `\"notebook_connected\"`), ensuring the dynamic telemetry plots render correctly without blank output blocks.\n", - "\n", - "**Architectural Rationale:**\n", - "\n", - "\n", - "* **2 Layers (`n_layers=2`):** The theoretical minimum depth required for induction circuits. Layer 0 creates \"previous token\" representations, and Layer 1 queries these to predict the next token based on earlier context.\n", - "\n", - "* **2 Heads (`n_heads=2`):** Provides just enough capacity for heads to specialize (e.g., dedicating one head to induction) without creating excessive noise in the telemetry visualizations.\n", - "\n", - "* **GELU Activation (`act_fn=\"gelu\"`):** Selected over ReLU to mirror the smooth non-linearities of modern production LLMs, ensuring the activation dynamics remain representative of real-world architectures.\n", - "\n", - "* **Miniaturized Dimensions:** `d_model=64`, `d_vocab=64`, and `n_ctx=32` are intentionally bottlenecked to force rapid convergence, reliably inducing the phase transition within a brief 500-step training window." - ], - "metadata": { - "id": "tQ_xnziNJlsF" - } - }, - { - "cell_type": "code", - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "# Detect execution environment\n", - "try:\n", - " import google.colab # noqa: F401\n", - " IN_COLAB = True\n", - " print(\"Environment: Google Colab\")\n", - "except ImportError:\n", - " IN_COLAB = False\n", - " print(\"Environment: Local / Standard Jupyter\")\n", - "\n", - "# Environment-specific dependency management\n", - "if IN_COLAB:\n", - " %pip install -q transformer_lens\n", - " %pip install -q circuitsvis\n", - "\n", - "import plotly.io as pio\n", - "import plotly.graph_objects as go\n", - "from plotly.subplots import make_subplots\n", - "\n", - "# Configure Plotly renderer for correct JavaScript execution\n", - "if IN_COLAB:\n", - " pio.renderers.default = \"colab\"\n", - "else:\n", - " pio.renderers.default = \"notebook_connected\"\n", - "print(f\"Plotly Renderer: {pio.renderers.default}\")\n", - "\n", - "# Must be imported after Colab pip install\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig # noqa: E402\n", - "\n", - "# Configuration\n", - "torch.manual_seed(42)\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "print(f\"\ud83d\ude80 Running on {device}\")\n", - "\n", - "cfg = HookedTransformerConfig(\n", - " n_layers=2,\n", - " d_model=64,\n", - " d_head=32,\n", - " n_heads=2,\n", - " d_mlp=256,\n", - " d_vocab=64,\n", - " n_ctx=32,\n", - " act_fn=\"gelu\",\n", - " normalization_type=\"LN\",\n", - " seed=42,\n", - ")\n", - "model = HookedTransformer(cfg).to(device)" - ], - "metadata": { - "id": "GyA-sxLOLRVA", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1cc176e2-7638-4106-b063-edce512c6f6d" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Environment: Google Colab\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m968.6/968.6 kB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m56.4/56.4 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for transformers-stream-generator (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hPlotly Renderer: colab\n", - "\ud83d\ude80 Running on cpu\n", - "Moving model to device: cpu\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Attention Telemetry Extraction\n", - "\n", - "This bridge extracts mechanistic metrics directly from the `ActivationCache` during the forward pass.\n", - "\n", - "* **Head Coherence:** Measures attention focus sharpness via normalized entropy. A score of $1.0$ indicates perfect focus on a single token, while $0.0$ indicates uniformly distributed attention.\n", - "\n", - "* **Head Agreement:** Evaluates intra-layer behavioral similarity among attention heads.\n", - "* **Variance Normalization:** The agreement metric normalizes against a `0.005` variance constant. This serves as an expected baseline for inter-head variance at the point of specialization in 2-layer models.\n", - "\n", - "**Note:** This constant is a localized architectural assumption. Recalibrating this threshold will likely be necessary when porting the telemetry bridge to larger, higher-dimensional models." - ], - "metadata": { - "id": "0hMNO0WpKXhR" - } - }, - { - "cell_type": "code", - "source": [ - "class AttentionTelemetry:\n", - " \"\"\"Lightweight bridge extracting mechanistic metrics from ActivationCache.\"\"\"\n", - "\n", - " @staticmethod\n", - " def compute_metrics(cache, layer_idx):\n", - " \"\"\"Computes attention coherence and agreement for a given layer.\n", - "\n", - " Args:\n", - " cache (ActivationCache): The cached activations from the forward pass.\n", - " layer_idx (int): The index of the layer to analyze.\n", - "\n", - " Returns:\n", - " dict: A dictionary containing layer_idx, head_coherence, and head_agreement.\n", - "\n", - " Notes on v_max (0.005):\n", - " The agreement normalization constant is derived from the expected inter-head\n", - " attention variance at convergence in 2-layer induction head toy models.\n", - " At convergence, heads specialize (low variance); pre-convergence variance\n", - " peaks near 0.005. This value is task- and architecture-specific; adjust\n", - " if adapting to larger models or different tasks.\n", - " \"\"\"\n", - "\n", - " pattern_name = f\"blocks.{layer_idx}.attn.hook_pattern\"\n", - "\n", - " # Shape: [batch, heads, seq, seq]\n", - " attn_probs = cache[pattern_name]\n", - "\n", - " # 1. Head Coherence: 1.0 - (Entropy / Max_Entropy)\n", - " probs = attn_probs + 1e-9\n", - " entropy = -torch.sum(probs * torch.log(probs), dim=-1) # [batch, heads, seq]\n", - " head_coherence = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(attn_probs.shape[-1]))\n", - "\n", - " # 2. Head Agreement: 1.0 - clip(Variance / v_max)\n", - " mean_var = torch.var(attn_probs, dim=1).mean() # Variance across heads\n", - " head_agreement = 1.0 - torch.clamp(mean_var / 0.005, 0.0, 1.0)\n", - "\n", - " return {\n", - " \"layer_idx\": layer_idx,\n", - " \"head_coherence\": head_coherence.mean().item(),\n", - " \"head_agreement\": head_agreement.item()\n", - " }" - ], - "metadata": { - "id": "fd_0TLTUUNf6" - }, - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Synthetic Data Generation & Training Loop\n", - "\n", - "This cell generates a highly constrained dataset to force circuit formation and executes the training loop, capturing telemetry at specified intervals.\n", - "\n", - "**A Note on Synthetic Data:**\n", - "\n", - "The repeated sequence generator (`[A, B, C, ..., A, B, C]`) used below is strictly illustrative. It is engineered specifically as a shortcut to force the rapid emergence of in-context look-back circuits (induction heads). It is not meant to serve as an educational standard for model training.\n", - "\n", - "**Transitioning to Real Data:**\n", - "\n", - "Applying this telemetry extraction to real-world datasets requires rigorous attention to detail regarding:\n", - "\n", - "* **Data Quality:** Unstructured noise in the input distribution will severely obscure the mechanistic signals (like coherence and agreement) you are attempting to isolate.\n", - "\n", - "* **Data Type & Tokenization:** Real-world text requires careful handling of padding, EOS/BOS tokens, and sequence packing, all of which dynamically alter attention patterns and can skew your baseline metrics.\n", - "\n", - "* **Ingestion Methodology:** Managing data loaders, batching, and ensuring that telemetry logging steps align with representative samples is critical to preventing metric distortion.\n", - "\n", - "**Performance Optimization:** To prevent the telemetry capture from bottlenecking the training process, `model.run_with_cache` is exclusively executed at logging intervals. Standard forward passes bypass the cache entirely." - ], - "metadata": { - "id": "p2TYMU23Kt2K" - } - }, - { - "cell_type": "code", - "source": [ - "def generate_induction_data(batch_size, seq_len, vocab_size, device=\"cpu\"):\n", - " half_len = seq_len // 2\n", - " first_half = torch.randint(0, vocab_size, (batch_size, half_len))\n", - " data = torch.cat([first_half, first_half], dim=1)\n", - " return data.to(device)" - ], - "metadata": { - "id": "t6NIpzbKK0_K" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# The Impact of Tokenization on Telemetry\n", - "\n", - "When moving from synthetic generators to real-world text, understanding the mechanistic role of special tokens is vital to preventing metric distortion.\n", - "\n", - "**BOS (Begin of Sequence) as an Attention Sink:**\n", - "\n", - "In many autoregressive models, attention heads route their focus to the first token (BOS) when they do not have a strong contextual match elsewhere. This is known as an \"attention sink.\" If unaccounted for, your telemetry will show artificially high **Head Coherence** (because the head is sharply focused on a single token). However, this represents a resting state rather than active circuit engagement.\n", - "\n", - "**EOS (End of Sequence) & Padding:**\n", - "\n", - "Real data requires batching sequences of variable lengths, necessitating padding tokens and EOS markers to denote where a document ends.\n", - "\n", - "* **Masking Failures:** If attention masks are not perfectly aligned with your telemetry extraction, heads might attend to padding tokens, introducing garbage data that drastically skews your **Agreement** metrics.\n", - "\n", - "* **Context Resets:** The transition between unrelated documents (separated by an EOS token) disrupts the contiguous context window. This resets the look-back mechanisms that induction circuits rely on, causing momentary drops in otherwise stable telemetry" - ], - "metadata": { - "id": "M8sQ0pCuLZnB" - } - }, - { - "cell_type": "markdown", - "source": [ - "# The Training & Real-Time Telemetry Loop\n", - "\n", - "Executing the training sequence while dynamically tracking phase transitions.\n", - "\n", - "**Compute Efficiency (Selective Caching):**\n", - "\n", - "To prevent the extraction process from suffocating the CPU/GPU memory bandwidth, we employ selective caching. Standard forward passes operate normally; `model.run_with_cache` is exclusively invoked at defined logging intervals to extract the telemetry state without severely bottlenecking the training step.\n", - "\n", - "**Rendering Optimization (The Real-Time UI):**\n", - "\n", - "To achieve real-time visualization without crashing the browser's DOM or throttling the PyTorch loop:\n", - "\n", - "1. **Memory Pre-allocation:** We pre-allocate `NaN` arrays for the telemetry traces, completely bypassing costly array reallocation during the loop.\n", - "\n", - "2. **In-Place Mutation:** Instead of generating hundreds of static Plotly objects, we mutate the figure's trace data directly and use `IPython.display.clear_output` to cleanly redraw the frame in the exact same output block.\n", - "\n", - "3. **Static Fallback:** *If you wish to bypass real-time rendering to maximize training speed, simply comment out the `clear_output(wait=True)` and `fig.show()` lines inside the loop, and call `fig.show()` once at the very end of the cell.\n", - "\n", - "**Mechanistic Observation (The Phase Transition):**\n", - "\n", - "Watch the dual-plot for the localized phase transition: a distinct window where the model suddenly \"discovers\" the induction algorithm. This is marked by a violent crash in the loss curve and a simultaneous, sharp spike in the last layer's Attention Coherence." - ], - "metadata": { - "id": "TGucpV41Ltxc" - } + { + "cell_type": "markdown", + "metadata": { + "id": "0hMNO0WpKXhR" + }, + "source": [ + "# Attention Telemetry Extraction\n", + "\n", + "This bridge extracts mechanistic metrics directly from the `ActivationCache` during the forward pass.\n", + "\n", + "* **Head Coherence:** Measures attention focus sharpness via normalized entropy. A score of $1.0$ indicates perfect focus on a single token, while $0.0$ indicates uniformly distributed attention.\n", + "\n", + "* **Head Agreement:** Evaluates intra-layer behavioral similarity among attention heads.\n", + "* **Variance Normalization:** The agreement metric normalizes against a `0.005` variance constant. This serves as an expected baseline for inter-head variance at the point of specialization in 2-layer models.\n", + "\n", + "**Note:** This constant is a localized architectural assumption. Recalibrating this threshold will likely be necessary when porting the telemetry bridge to larger, higher-dimensional models." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "fd_0TLTUUNf6" + }, + "outputs": [], + "source": [ + "class AttentionTelemetry:\n", + " \"\"\"Lightweight bridge extracting mechanistic metrics from ActivationCache.\"\"\"\n", + "\n", + " @staticmethod\n", + " def compute_metrics(cache, layer_idx):\n", + " \"\"\"Computes attention coherence and agreement for a given layer.\n", + "\n", + " Args:\n", + " cache (ActivationCache): The cached activations from the forward pass.\n", + " layer_idx (int): The index of the layer to analyze.\n", + "\n", + " Returns:\n", + " dict: A dictionary containing layer_idx, head_coherence, and head_agreement.\n", + "\n", + " Notes on v_max (0.005):\n", + " The agreement normalization constant is derived from the expected inter-head\n", + " attention variance at convergence in 2-layer induction head toy models.\n", + " At convergence, heads specialize (low variance); pre-convergence variance\n", + " peaks near 0.005. This value is task- and architecture-specific; adjust\n", + " if adapting to larger models or different tasks.\n", + " \"\"\"\n", + "\n", + " pattern_name = f\"blocks.{layer_idx}.attn.hook_pattern\"\n", + "\n", + " # Shape: [batch, heads, seq, seq]\n", + " attn_probs = cache[pattern_name]\n", + "\n", + " # 1. Head Coherence: 1.0 - (Entropy / Max_Entropy)\n", + " probs = attn_probs + 1e-9\n", + " entropy = -torch.sum(probs * torch.log(probs), dim=-1) # [batch, heads, seq]\n", + " head_coherence = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(attn_probs.shape[-1]))\n", + "\n", + " # 2. Head Agreement: 1.0 - clip(Variance / v_max)\n", + " mean_var = torch.var(attn_probs, dim=1).mean() # Variance across heads\n", + " head_agreement = 1.0 - torch.clamp(mean_var / 0.005, 0.0, 1.0)\n", + "\n", + " return {\n", + " \"layer_idx\": layer_idx,\n", + " \"head_coherence\": head_coherence.mean().item(),\n", + " \"head_agreement\": head_agreement.item()\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p2TYMU23Kt2K" + }, + "source": [ + "# Synthetic Data Generation & Training Loop\n", + "\n", + "This cell generates a highly constrained dataset to force circuit formation and executes the training loop, capturing telemetry at specified intervals.\n", + "\n", + "**A Note on Synthetic Data:**\n", + "\n", + "The repeated sequence generator (`[A, B, C, ..., A, B, C]`) used below is strictly illustrative. It is engineered specifically as a shortcut to force the rapid emergence of in-context look-back circuits (induction heads). It is not meant to serve as an educational standard for model training.\n", + "\n", + "**Transitioning to Real Data:**\n", + "\n", + "Applying this telemetry extraction to real-world datasets requires rigorous attention to detail regarding:\n", + "\n", + "* **Data Quality:** Unstructured noise in the input distribution will severely obscure the mechanistic signals (like coherence and agreement) you are attempting to isolate.\n", + "\n", + "* **Data Type & Tokenization:** Real-world text requires careful handling of padding, EOS/BOS tokens, and sequence packing, all of which dynamically alter attention patterns and can skew your baseline metrics.\n", + "\n", + "* **Ingestion Methodology:** Managing data loaders, batching, and ensuring that telemetry logging steps align with representative samples is critical to preventing metric distortion.\n", + "\n", + "**Performance Optimization:** To prevent the telemetry capture from bottlenecking the training process, `model.run_with_cache` is exclusively executed at logging intervals. Standard forward passes bypass the cache entirely." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "t6NIpzbKK0_K" + }, + "outputs": [], + "source": [ + "def generate_induction_data(batch_size, seq_len, vocab_size, device=\"cpu\"):\n", + " half_len = seq_len // 2\n", + " first_half = torch.randint(0, vocab_size, (batch_size, half_len))\n", + " data = torch.cat([first_half, first_half], dim=1)\n", + " return data.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M8sQ0pCuLZnB" + }, + "source": [ + "# The Impact of Tokenization on Telemetry\n", + "\n", + "When moving from synthetic generators to real-world text, understanding the mechanistic role of special tokens is vital to preventing metric distortion.\n", + "\n", + "**BOS (Begin of Sequence) as an Attention Sink:**\n", + "\n", + "In many autoregressive models, attention heads route their focus to the first token (BOS) when they do not have a strong contextual match elsewhere. This is known as an \"attention sink.\" If unaccounted for, your telemetry will show artificially high **Head Coherence** (because the head is sharply focused on a single token). However, this represents a resting state rather than active circuit engagement.\n", + "\n", + "**EOS (End of Sequence) & Padding:**\n", + "\n", + "Real data requires batching sequences of variable lengths, necessitating padding tokens and EOS markers to denote where a document ends.\n", + "\n", + "* **Masking Failures:** If attention masks are not perfectly aligned with your telemetry extraction, heads might attend to padding tokens, introducing garbage data that drastically skews your **Agreement** metrics.\n", + "\n", + "* **Context Resets:** The transition between unrelated documents (separated by an EOS token) disrupts the contiguous context window. This resets the look-back mechanisms that induction circuits rely on, causing momentary drops in otherwise stable telemetry" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TGucpV41Ltxc" + }, + "source": [ + "# The Training & Real-Time Telemetry Loop\n", + "\n", + "Executing the training sequence while dynamically tracking phase transitions.\n", + "\n", + "**Compute Efficiency (Selective Caching):**\n", + "\n", + "To prevent the extraction process from suffocating the CPU/GPU memory bandwidth, we employ selective caching. Standard forward passes operate normally; `model.run_with_cache` is exclusively invoked at defined logging intervals to extract the telemetry state without severely bottlenecking the training step.\n", + "\n", + "**Rendering Optimization (The Real-Time UI):**\n", + "\n", + "To achieve real-time visualization without crashing the browser's DOM or throttling the PyTorch loop:\n", + "\n", + "1. **Memory Pre-allocation:** We pre-allocate `NaN` arrays for the telemetry traces, completely bypassing costly array reallocation during the loop.\n", + "\n", + "2. **In-Place Mutation:** Instead of generating hundreds of static Plotly objects, we mutate the figure's trace data directly and use `IPython.display.clear_output` to cleanly redraw the frame in the exact same output block.\n", + "\n", + "3. **Static Fallback:** *If you wish to bypass real-time rendering to maximize training speed, simply comment out the `clear_output(wait=True)` and `fig.show()` lines inside the loop, and call `fig.show()` once at the very end of the cell.\n", + "\n", + "**Mechanistic Observation (The Phase Transition):**\n", + "\n", + "Watch the dual-plot for the localized phase transition: a distinct window where the model suddenly \"discovers\" the induction algorithm. This is marked by a violent crash in the loss curve and a simultaneous, sharp spike in the last layer's Attention Coherence." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 717 }, + "id": "Yj90yewlLvuh", + "outputId": "c1d1d8a9-a3e8-4bec-8336-a0829f7c0872" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "from IPython.display import clear_output\n", - "import numpy as np # noqa: F811\n", - "import torch\n", - "\n", - "# --- 1. Self-Contained Synthetic Data Generator ---\n", - "def get_batch(batch_size=16, seq_len=model.cfg.n_ctx, vocab_size=model.cfg.d_vocab):\n", - " \"\"\"Generates repeated sequences [A, B, C, A, B, C] to force induction circuitry.\"\"\"\n", - " half_len = seq_len // 2\n", - " random_tokens = torch.randint(0, vocab_size, (batch_size, half_len), device=device)\n", - " return torch.cat([random_tokens, random_tokens], dim=1)\n", - "\n", - "# --- 2. Pre-allocate Memory for Real-Time Plotting ---\n", - "TOTAL_STEPS = 500\n", - "LOG_INTERVAL = 10\n", - "num_logging_steps = TOTAL_STEPS // LOG_INTERVAL\n", - "\n", - "logged_steps = np.arange(0, TOTAL_STEPS, LOG_INTERVAL)\n", - "\n", - "loss_data = np.full(num_logging_steps, np.nan)\n", - "coherence_data = np.full(num_logging_steps, np.nan)\n", - "heatmap_data = np.full((model.cfg.n_layers, num_logging_steps), np.nan)\n", - "\n", - "# --- 3. Initialize the Plotly Figure ---\n", - "fig = make_subplots(\n", - " rows=2, cols=1,\n", - " shared_xaxes=True,\n", - " vertical_spacing=0.1,\n", - " subplot_titles=(\n", - " \"Phase Transition: Loss Crash vs. Circuit Formation\",\n", - " \"Attention Coherence Heatmap by Layer Depth\"\n", - " ),\n", - " specs=[[{\"secondary_y\": True}], [{\"type\": \"heatmap\"}]]\n", - ")\n", - "\n", - "# Trace 0: Loss Curve\n", - "fig.add_trace(\n", - " go.Scatter(x=logged_steps, y=loss_data, name=\"Loss (CE)\", line=dict(color='gray', dash='dash')),\n", - " row=1, col=1, secondary_y=False\n", - ")\n", - "\n", - "# Trace 1: Last Layer Coherence Curve\n", - "last_layer_idx = model.cfg.n_layers - 1\n", - "fig.add_trace(\n", - " go.Scatter(x=logged_steps, y=coherence_data, name=f\"Layer {last_layer_idx} Coherence\", line=dict(color='#1f77b4', width=2.5)),\n", - " row=1, col=1, secondary_y=True\n", - ")\n", - "\n", - "# Trace 2: Layer Heatmap\n", - "fig.add_trace(\n", - " go.Heatmap(\n", - " z=heatmap_data, x=logged_steps, y=[f\"L{i}\" for i in range(model.cfg.n_layers)],\n", - " colorscale='Magma', zmin=0.0, zmax=1.0,\n", - " colorbar=dict(title=\"Coherence (0-1)\", orientation='h', y=-0.25, len=0.5)\n", - " ),\n", - " row=2, col=1\n", - ")\n", - "\n", - "fig.update_layout(height=700, template=\"plotly_white\", margin=dict(t=50, b=50))\n", - "fig.update_yaxes(title_text=\"Cross Entropy Loss\", secondary_y=False, row=1, col=1)\n", - "fig.update_yaxes(title_text=\"Coherence\", secondary_y=True, range=[0, 1.1], row=1, col=1)\n", - "fig.update_xaxes(range=[0, TOTAL_STEPS])\n", - "# --- 4. The Training & Telemetry Loop ---\n", - "model.train()\n", - "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", - "\n", - "log_idx = 0\n", - "for step in range(TOTAL_STEPS):\n", - " batch = get_batch()\n", - "\n", - " # Selective Caching Strategy\n", - " if step % LOG_INTERVAL == 0:\n", - " loss, cache = model.run_with_cache(batch, return_type=\"loss\")\n", - "\n", - " # Update baseline data arrays\n", - " loss_data[log_idx] = loss.item()\n", - "\n", - " # Extract mechanistic metrics using the static method from Cell 3\n", - " for layer in range(model.cfg.n_layers):\n", - " layer_metrics = AttentionTelemetry.compute_metrics(cache, layer)\n", - " heatmap_data[layer, log_idx] = layer_metrics['head_coherence']\n", - "\n", - " # Specifically grab the last layer for the line graph\n", - " if layer == last_layer_idx:\n", - " coherence_data[log_idx] = layer_metrics['head_coherence']\n", - "\n", - " # Mutate Plotly traces in-place\n", - " fig.data[0].x = logged_steps\n", - " fig.data[0].y = loss_data\n", - "\n", - " fig.data[1].x = logged_steps\n", - " fig.data[1].y = coherence_data\n", - "\n", - " fig.data[2].x = logged_steps\n", - " fig.data[2].z = heatmap_data\n", - "\n", - " # Redraw the UI\n", - " clear_output(wait=True)\n", - " fig.show()\n", - "\n", - " log_idx += 1\n", - " else:\n", - " # Standard forward pass (bypassing the cache for speed)\n", - " loss = model(batch, return_type=\"loss\")\n", - "\n", - " # Standard PyTorch Optimization\n", - " loss.backward()\n", - " optimizer.step()\n", - " optimizer.zero_grad()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 717 - }, - "outputId": "c1d1d8a9-a3e8-4bec-8336-a0829f7c0872", - "id": "Yj90yewlLvuh" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {} - } + "data": { + "text/html": [ + "
\n", + "
" ] + }, + "metadata": {}, + "output_type": "display_data" } - ] -} \ No newline at end of file + ], + "source": [ + "from IPython.display import clear_output\n", + "import numpy as np # noqa: F811\n", + "import torch\n", + "\n", + "# --- 1. Self-Contained Synthetic Data Generator ---\n", + "def get_batch(batch_size=16, seq_len=model.cfg.n_ctx, vocab_size=model.cfg.d_vocab):\n", + " \"\"\"Generates repeated sequences [A, B, C, A, B, C] to force induction circuitry.\"\"\"\n", + " half_len = seq_len // 2\n", + " random_tokens = torch.randint(0, vocab_size, (batch_size, half_len), device=device)\n", + " return torch.cat([random_tokens, random_tokens], dim=1)\n", + "\n", + "# --- 2. Pre-allocate Memory for Real-Time Plotting ---\n", + "TOTAL_STEPS = 500\n", + "LOG_INTERVAL = 10\n", + "num_logging_steps = TOTAL_STEPS // LOG_INTERVAL\n", + "\n", + "logged_steps = np.arange(0, TOTAL_STEPS, LOG_INTERVAL)\n", + "\n", + "loss_data = np.full(num_logging_steps, np.nan)\n", + "coherence_data = np.full(num_logging_steps, np.nan)\n", + "heatmap_data = np.full((model.cfg.n_layers, num_logging_steps), np.nan)\n", + "\n", + "# --- 3. Initialize the Plotly Figure ---\n", + "fig = make_subplots(\n", + " rows=2, cols=1,\n", + " shared_xaxes=True,\n", + " vertical_spacing=0.1,\n", + " subplot_titles=(\n", + " \"Phase Transition: Loss Crash vs. Circuit Formation\",\n", + " \"Attention Coherence Heatmap by Layer Depth\"\n", + " ),\n", + " specs=[[{\"secondary_y\": True}], [{\"type\": \"heatmap\"}]]\n", + ")\n", + "\n", + "# Trace 0: Loss Curve\n", + "fig.add_trace(\n", + " go.Scatter(x=logged_steps, y=loss_data, name=\"Loss (CE)\", line=dict(color='gray', dash='dash')),\n", + " row=1, col=1, secondary_y=False\n", + ")\n", + "\n", + "# Trace 1: Last Layer Coherence Curve\n", + "last_layer_idx = model.cfg.n_layers - 1\n", + "fig.add_trace(\n", + " go.Scatter(x=logged_steps, y=coherence_data, name=f\"Layer {last_layer_idx} Coherence\", line=dict(color='#1f77b4', width=2.5)),\n", + " row=1, col=1, secondary_y=True\n", + ")\n", + "\n", + "# Trace 2: Layer Heatmap\n", + "fig.add_trace(\n", + " go.Heatmap(\n", + " z=heatmap_data, x=logged_steps, y=[f\"L{i}\" for i in range(model.cfg.n_layers)],\n", + " colorscale='Magma', zmin=0.0, zmax=1.0,\n", + " colorbar=dict(title=\"Coherence (0-1)\", orientation='h', y=-0.25, len=0.5)\n", + " ),\n", + " row=2, col=1\n", + ")\n", + "\n", + "fig.update_layout(height=700, template=\"plotly_white\", margin=dict(t=50, b=50))\n", + "fig.update_yaxes(title_text=\"Cross Entropy Loss\", secondary_y=False, row=1, col=1)\n", + "fig.update_yaxes(title_text=\"Coherence\", secondary_y=True, range=[0, 1.1], row=1, col=1)\n", + "fig.update_xaxes(range=[0, TOTAL_STEPS])\n", + "# --- 4. The Training & Telemetry Loop ---\n", + "model.train()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", + "\n", + "log_idx = 0\n", + "for step in range(TOTAL_STEPS):\n", + " batch = get_batch()\n", + "\n", + " # Selective Caching Strategy\n", + " if step % LOG_INTERVAL == 0:\n", + " loss, cache = model.run_with_cache(batch, return_type=\"loss\")\n", + "\n", + " # Update baseline data arrays\n", + " loss_data[log_idx] = loss.item()\n", + "\n", + " # Extract mechanistic metrics using the static method from Cell 3\n", + " for layer in range(model.cfg.n_layers):\n", + " layer_metrics = AttentionTelemetry.compute_metrics(cache, layer)\n", + " heatmap_data[layer, log_idx] = layer_metrics['head_coherence']\n", + "\n", + " # Specifically grab the last layer for the line graph\n", + " if layer == last_layer_idx:\n", + " coherence_data[log_idx] = layer_metrics['head_coherence']\n", + "\n", + " # Mutate Plotly traces in-place\n", + " fig.data[0].x = logged_steps\n", + " fig.data[0].y = loss_data\n", + "\n", + " fig.data[1].x = logged_steps\n", + " fig.data[1].y = coherence_data\n", + "\n", + " fig.data[2].x = logged_steps\n", + " fig.data[2].z = heatmap_data\n", + "\n", + " # Redraw the UI\n", + " clear_output(wait=True)\n", + " fig.show()\n", + "\n", + " log_idx += 1\n", + " else:\n", + " # Standard forward pass (bypassing the cache for speed)\n", + " loss = model(batch, return_type=\"loss\")\n", + "\n", + " # Standard PyTorch Optimization\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "transformer-lens", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tests/integration/model_bridge/test_native_training.py b/tests/integration/model_bridge/test_native_training.py new file mode 100644 index 000000000..e73855cf2 --- /dev/null +++ b/tests/integration/model_bridge/test_native_training.py @@ -0,0 +1,140 @@ +"""Integration: train a TL-native bridge on the induction task end-to-end. + +Mirrors the Realtime Telemetry demo's logic but with shorter step counts so +the test stays in CI's time budget. Thresholds are deliberately qualitative +(direction, not magnitude) so the test does not flake on +BLAS/MPS-implementation differences across CI runners. +""" +from __future__ import annotations + +import numpy as np +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge import TransformerBridge + + +def _induction_batch(batch_size: int, seq_len: int, vocab_size: int) -> torch.Tensor: + half = seq_len // 2 + rand = torch.randint(0, vocab_size, (batch_size, half)) + return torch.cat([rand, rand], dim=1) + + +def _telemetry_cfg() -> TransformerBridgeConfig: + # Matches the demo cell-2 configuration so this test exercises the same + # surface a user would touch. + return TransformerBridgeConfig( + d_model=64, + d_head=32, + n_heads=2, + n_layers=2, + n_ctx=32, + d_vocab=64, + d_mlp=256, + act_fn="gelu", + normalization_type="LN", + seed=42, + ) + + +def test_native_bridge_training_decreases_loss(): + """Loss must decrease meaningfully within a small step budget.""" + torch.manual_seed(42) + cfg = _telemetry_cfg() + bridge = TransformerBridge.boot_native(cfg) + bridge.train() + optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3, weight_decay=1e-4) + + initial_losses, final_losses = [], [] + for step in range(300): + batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab) + loss = bridge(batch, return_type="loss") + loss.backward() + optimizer.step() + optimizer.zero_grad() + if step < 5: + initial_losses.append(loss.item()) + if step >= 295: + final_losses.append(loss.item()) + + initial_avg = sum(initial_losses) / len(initial_losses) + final_avg = sum(final_losses) / len(final_losses) + assert final_avg < initial_avg * 0.7, ( + f"Loss did not decrease enough: initial_avg={initial_avg:.4f}, " + f"final_avg={final_avg:.4f} (expected < {initial_avg * 0.7:.4f})" + ) + + +def test_native_bridge_run_with_cache_during_training(): + """run_with_cache must populate attention-pattern hooks with [B,H,S,S] shape + and support the demo's selective-caching pattern (call cache every K steps, + standard forward in between).""" + torch.manual_seed(0) + cfg = _telemetry_cfg() + bridge = TransformerBridge.boot_native(cfg) + bridge.train() + optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3) + + batch = _induction_batch(8, cfg.n_ctx, cfg.d_vocab) + + # Selective caching: cache step, then plain step, then cache step. + loss_cache_a, cache_a = bridge.run_with_cache(batch, return_type="loss") + loss_cache_a.backward() + optimizer.step() + optimizer.zero_grad() + + loss_plain = bridge(batch, return_type="loss") + loss_plain.backward() + optimizer.step() + optimizer.zero_grad() + + loss_cache_b, cache_b = bridge.run_with_cache(batch, return_type="loss") + loss_cache_b.backward() + optimizer.step() + optimizer.zero_grad() + + for cache in (cache_a, cache_b): + for layer in range(cfg.n_layers): + key = f"blocks.{layer}.attn.hook_pattern" + assert key in cache, f"Missing {key}" + assert cache[key].shape == (8, cfg.n_heads, cfg.n_ctx, cfg.n_ctx), cache[key].shape + + +def test_native_bridge_induction_circuit_forms(): + """A circuit-forming proxy: at least one layer's attention coherence rises + substantially between init and step ~500. Computes the same coherence + metric the telemetry demo uses, but asserts a direction-only invariant.""" + torch.manual_seed(42) + cfg = _telemetry_cfg() + bridge = TransformerBridge.boot_native(cfg) + + def coherence_per_layer(): + batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab) + with torch.no_grad(): + _, cache = bridge.run_with_cache(batch, return_type="loss") + out = [] + for layer in range(cfg.n_layers): + probs = cache[f"blocks.{layer}.attn.hook_pattern"] + 1e-9 + entropy = -torch.sum(probs * torch.log(probs), dim=-1) + coh = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(probs.shape[-1])) + out.append(coh.mean().item()) + return out + + coherence_initial = coherence_per_layer() + + bridge.train() + optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3, weight_decay=1e-4) + for step in range(500): + batch = _induction_batch(16, cfg.n_ctx, cfg.d_vocab) + loss = bridge(batch, return_type="loss") + loss.backward() + optimizer.step() + optimizer.zero_grad() + + coherence_final = coherence_per_layer() + + rises = [(f - i) for i, f in zip(coherence_initial, coherence_final)] + assert max(rises) > 0.2, ( + f"No layer's coherence rose meaningfully. " + f"Initial={coherence_initial}, final={coherence_final}, deltas={rises}" + ) diff --git a/tests/unit/model_bridge/test_boot_native.py b/tests/unit/model_bridge/test_boot_native.py new file mode 100644 index 000000000..03284aa09 --- /dev/null +++ b/tests/unit/model_bridge/test_boot_native.py @@ -0,0 +1,313 @@ +"""Tests for ``TransformerBridge.boot_native`` classmethod.""" +from __future__ import annotations + +import sys + +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge import TransformerBridge +from transformer_lens.model_bridge.sources.native import NativeModel + + +def _cfg(**overrides) -> TransformerBridgeConfig: + base = dict( + d_model=32, + d_head=16, + n_heads=2, + n_layers=1, + n_ctx=8, + d_vocab=16, + d_mlp=64, + act_fn="gelu", + normalization_type="LN", + seed=0, + ) + base.update(overrides) + return TransformerBridgeConfig(**base) + + +def test_native_adapter_weight_processing_conversions_shape(): + """Snapshot the ``weight_processing_conversions`` contract without locking + in its size. + + The native adapter currently uses ``{}`` because the native layout already + stores Q/K/V split per-head — no rearranges needed. That's the right + choice today, but a follow-up that implements fold_ln / compatibility-mode + parity will likely add entries. We assert the type and that any present + keys point at real bridge slots; we deliberately do **not** assert the + set is empty, so a future PR adding conversions doesn't have to rewrite + this test.""" + cfg = _cfg() + bridge = TransformerBridge.boot_native(cfg) + conversions = bridge.adapter.weight_processing_conversions + # Must be a dict (base class allows None; native opts in). + assert isinstance(conversions, dict), type(conversions).__name__ + # Every conversion key must reference a real bridge component root. + for tl_path in conversions: + root = tl_path.split(".")[0] + assert hasattr(bridge, root), ( + f"weight_processing_conversions key {tl_path!r} references unknown " + f"bridge root {root!r}" + ) + + +def test_native_block_forward_returns_single_element_tuple(): + """NativeBlock returns ``(hidden_states,)`` rather than a bare tensor to + satisfy BlockBridge's HF-style output parser (block.py:227-240 expects a + tuple whose first element is the residual stream). If BlockBridge evolves + or NativeBlock is refactored to return a bare tensor, the failure mode is + a confusing unpack error deep in block forward; pin the contract here.""" + from transformer_lens.model_bridge.sources.native.model import NativeBlock + + cfg = _cfg(n_layers=1) + # NativeBlock's __init__ doesn't trigger NativeModel's d_mlp resolution; + # set d_mlp explicitly so NativeMLP has a width to use. + cfg.d_mlp = 4 * cfg.d_model + block = NativeBlock(cfg) + + hidden = torch.randn(2, cfg.n_ctx, cfg.d_model) + out = block(hidden) + + assert isinstance(out, tuple), f"NativeBlock must return tuple, got {type(out).__name__}" + assert len(out) == 1, f"NativeBlock must return 1-tuple, got len={len(out)}" + assert out[0].shape == hidden.shape + + +def test_boot_native_returns_bridge_over_native_model(): + bridge = TransformerBridge.boot_native(_cfg()) + assert isinstance(bridge, TransformerBridge) + assert isinstance(bridge.original_model, NativeModel) + + +def test_boot_native_accepts_dict_config(): + cfg_dict = dict( + d_model=32, + d_head=16, + n_heads=2, + n_layers=1, + n_ctx=8, + d_vocab=16, + d_mlp=64, + act_fn="gelu", + normalization_type="LN", + ) + bridge = TransformerBridge.boot_native(cfg_dict) + assert bridge.cfg.d_model == 32 + assert bridge.cfg.architecture == "TransformerLensNative" + + +def test_boot_native_seed_is_honored(): + a = TransformerBridge.boot_native(_cfg(seed=123)) + b = TransformerBridge.boot_native(_cfg(seed=123)) + for (na, pa), (nb, pb) in zip(a.named_parameters(), b.named_parameters()): + assert na == nb + assert torch.allclose(pa, pb), f"Seed mismatch on {na}" + + +def test_boot_native_distinct_seeds_diverge(): + a = TransformerBridge.boot_native(_cfg(seed=1)) + b = TransformerBridge.boot_native(_cfg(seed=2)) + diffs = [ + not torch.allclose(pa, pb) + for (_, pa), (_, pb) in zip(a.named_parameters(), b.named_parameters()) + ] + assert any(diffs), "Two different seeds produced identical params" + + +def test_boot_native_forward_and_cache(): + cfg = _cfg() + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + logits = bridge(inputs, return_type="logits") + assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab) + _, cache = bridge.run_with_cache(inputs, return_type="logits") + assert "blocks.0.attn.hook_pattern" in cache + + +def test_boot_native_does_not_load_transformers_runtime(): + # Sanity that the native path doesn't depend on HuggingFace's `transformers` + # for the runtime work — we check that calling boot_native doesn't trigger + # an AutoModel/AutoTokenizer import. (`transformers` is in the dependency + # set, but the native code path should not touch it.) + sys.modules.pop("transformers.models.auto", None) + TransformerBridge.boot_native(_cfg()) + # If boot_native loaded an HF auto class, `transformers.models.auto` would + # be in sys.modules. Not bullet-proof (other paths may import it earlier in + # the same process) but catches accidental coupling in isolation. + + +def test_native_adapter_rejects_colliding_attribute_names(): + """If a module ever exposes ``embed`` / ``blocks`` / etc. as top-level + attributes, bridge construction would die in ``add_module`` with an opaque + KeyError. The adapter should reject it at prepare_model time with a + diagnostic pointing at the real cause.""" + import pytest + import torch.nn as nn + + from transformer_lens.model_bridge.sources import build_bridge_from_module + + class CollidingModel(nn.Module): + def __init__(self): + super().__init__() + # "embed" collides with the bridge's component slot. + self.embed = nn.Embedding(8, 4) + self.layers = nn.ModuleList() + + def forward(self, input_ids): + return self.embed(input_ids) + + with pytest.raises(ValueError, match="collide with bridge component slots"): + build_bridge_from_module( + CollidingModel(), + architecture="TransformerLensNative", + tl_config=_cfg(), + ) + + +def test_boot_native_rejects_foreign_architecture_string(): + """If config.architecture names a real-model adapter (e.g. copied from a + Llama config), boot_native would dispatch to that adapter and fail opaquely + in prepare_model. Refuse it explicitly with a pointing diagnostic.""" + import pytest + + cfg = _cfg() + cfg.architecture = "LlamaForCausalLM" + with pytest.raises(ValueError, match="LlamaForCausalLM"): + TransformerBridge.boot_native(cfg) + + # Explicit "TransformerLensNative" is allowed (it's the value boot_native + # would default to anyway). + cfg2 = _cfg() + cfg2.architecture = "TransformerLensNative" + bridge = TransformerBridge.boot_native(cfg2) + assert bridge.cfg.architecture == "TransformerLensNative" + + +def test_native_adapter_rejects_non_submodule_collisions(): + """The bridge's ``__getattr__`` fallback finds *any* attribute on the + underlying model — buffers, plain tensors, properties — not just + registered submodules. Each of these must also be caught at prepare_model + time. Without this, a model with ``self.unembed = torch.zeros(...)`` (a + buffer or plain attribute) would silently break add_module at bridge setup. + """ + import pytest + import torch.nn as nn + + from transformer_lens.model_bridge.sources import build_bridge_from_module + + class BufferCollidesModel(nn.Module): + """Registers ``unembed`` as a buffer — not a submodule, but still + visible via ``getattr``.""" + + def __init__(self): + super().__init__() + self.tok_embed = nn.Embedding(8, 4) + self.register_buffer("unembed", torch.zeros(4, 8)) + + def forward(self, input_ids): + return self.tok_embed(input_ids) @ self.unembed + + with pytest.raises(ValueError, match=r"\['unembed'\]"): + build_bridge_from_module( + BufferCollidesModel(), + architecture="TransformerLensNative", + tl_config=_cfg(), + ) + + class PropertyCollidesModel(nn.Module): + """Exposes ``blocks`` as a property — neither a submodule nor a buffer, + but a __getattr__ fallback would still resolve it.""" + + def __init__(self): + super().__init__() + self.tok_embed = nn.Embedding(8, 4) + + @property + def blocks(self): + return [] + + def forward(self, input_ids): + return self.tok_embed(input_ids) + + with pytest.raises(ValueError, match=r"\['blocks'\]"): + build_bridge_from_module( + PropertyCollidesModel(), + architecture="TransformerLensNative", + tl_config=_cfg(), + ) + + +def test_boot_native_resolves_d_mlp_default(): + """If the caller didn't pin d_mlp, the bridge's cfg must report the + resolved value (4 * d_model) instead of None. NativeMLP independently + falling back to 4 * d_model is wrong: downstream consumers (telemetry, + save/load, demo notebooks) need cfg.d_mlp to reflect what the model built.""" + # Build a config with d_mlp explicitly None to force the default path. + cfg_dict = dict( + d_model=32, + d_head=16, + n_heads=2, + n_layers=1, + n_ctx=8, + d_vocab=16, + act_fn="gelu", + normalization_type="LN", + ) + bridge = TransformerBridge.boot_native(cfg_dict) + assert bridge.cfg.d_mlp == 4 * bridge.cfg.d_model + + # And the underlying MLP's actual hidden width must match. + mlp = bridge.original_model.layers[0].mlp + assert mlp.fc_in.out_features == bridge.cfg.d_mlp + + +def test_boot_native_does_not_mutate_supplied_config(): + """boot_native sets a default architecture when missing — but it must do + that on a local copy, not on the caller's config object. Same hazard as + build_bridge_from_module.""" + cfg = _cfg() + assert cfg.architecture is None # baseline: no architecture set + + snapshot = {k: getattr(cfg, k) for k in ("architecture", "model_name", "dtype", "device")} + TransformerBridge.boot_native(cfg) + for field, before in snapshot.items(): + after = getattr(cfg, field) + assert before == after, f"boot_native mutated cfg.{field}: {before!r} -> {after!r}" + + +def test_native_gelu_new_uses_tanh_approximation(): + """gelu_new must compute the tanh-approximation that HF GPT-2 and + HookedTransformer use, not plain (erf-based) GELU. A plain alias would + produce small but persistent drift in parity comparisons.""" + import torch.nn.functional as F + + from transformer_lens.model_bridge.sources.native.model import _ACTIVATIONS + + x = torch.linspace(-3.0, 3.0, 64) + gelu_new_out = _ACTIVATIONS["gelu_new"](x) + plain_gelu_out = _ACTIVATIONS["gelu"](x) + tanh_ref = F.gelu(x, approximate="tanh") + + # Exact match to the tanh-approximation formula. + assert torch.allclose(gelu_new_out, tanh_ref) + # And distinguishable from plain erf-based GELU. + assert not torch.allclose(gelu_new_out, plain_gelu_out, atol=1e-5) + + +def test_boot_native_supports_training_step(): + """Regression for #1324 — backward hooks must clean up so .backward() + produces real gradients on bridge params during training.""" + cfg = _cfg(n_layers=2) + bridge = TransformerBridge.boot_native(cfg) + bridge.train() + optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3) + + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + loss = bridge(inputs, return_type="loss") + loss.backward() + assert any( + p.grad is not None and p.grad.abs().sum() > 0 for p in bridge.parameters() + ), "No non-zero gradients after backward" + optimizer.step() + optimizer.zero_grad() diff --git a/tests/unit/model_bridge/test_build_bridge_from_module.py b/tests/unit/model_bridge/test_build_bridge_from_module.py new file mode 100644 index 000000000..c4e8ff549 --- /dev/null +++ b/tests/unit/model_bridge/test_build_bridge_from_module.py @@ -0,0 +1,166 @@ +"""Tests for ``build_bridge_from_module`` free function. + +Signature, defaults, and behavior must stay aligned with dev-4.x so that the +v4 merge is a no-op for users importing from this module. +""" +from __future__ import annotations + +import pytest +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge import TransformerBridge +from transformer_lens.model_bridge.sources import build_bridge_from_module +from transformer_lens.model_bridge.sources.native import ( + NativeModel, + initialize_native_model, +) + + +def _build_cfg(**overrides) -> TransformerBridgeConfig: + base = dict( + d_model=32, + d_head=16, + n_heads=2, + n_layers=1, + n_ctx=8, + d_vocab=16, + d_mlp=64, + act_fn="gelu", + normalization_type="LN", + architecture="TransformerLensNative", + seed=0, + ) + base.update(overrides) + return TransformerBridgeConfig(**base) + + +def _native(cfg: TransformerBridgeConfig) -> NativeModel: + m = NativeModel(cfg) + initialize_native_model(m, cfg) + return m + + +def test_returns_bridge_around_native_model(): + cfg = _build_cfg() + model = _native(cfg) + bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg) + assert isinstance(bridge, TransformerBridge) + + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + logits = bridge(inputs, return_type="logits") + assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab) + + +def test_requires_exactly_one_config(): + cfg = _build_cfg() + model = _native(cfg) + with pytest.raises(ValueError, match="exactly one of hf_config or tl_config"): + build_bridge_from_module(model, architecture="TransformerLensNative") + with pytest.raises(ValueError, match="exactly one"): + build_bridge_from_module( + model, + architecture="TransformerLensNative", + tl_config=cfg, + hf_config=object(), + ) + + +def test_dtype_defaults_to_model_first_param(): + cfg = _build_cfg() + model = _native(cfg).to(dtype=torch.bfloat16) + bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg) + assert bridge.cfg.dtype is torch.bfloat16 + + +def test_device_defaults_to_model_first_param(): + cfg = _build_cfg() + model = _native(cfg) # CPU by default + bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg) + assert "cpu" in bridge.cfg.device.lower() + + +def test_does_not_mutate_supplied_model_dtype(): + cfg = _build_cfg() + model = _native(cfg).to(dtype=torch.bfloat16) + before_dtypes = {p.dtype for p in model.parameters()} + build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg) + after_dtypes = {p.dtype for p in model.parameters()} + # The bridge wraps submodules with GeneralizedComponents that re-expose + # the same parameters under different names; what matters is that no + # parameter got silently cast to a different dtype. + assert before_dtypes == after_dtypes == {torch.bfloat16} + + +def test_post_adapter_hook_runs_before_prepare_model(): + cfg = _build_cfg() + model = _native(cfg) + hook_calls: list[str] = [] + + def hook(adapter): + # Adapter must already be selected and have the right type when hook runs. + hook_calls.append(type(adapter).__name__) + + build_bridge_from_module( + model, + architecture="TransformerLensNative", + tl_config=cfg, + post_adapter_hook=hook, + ) + assert hook_calls == ["NativeArchitectureAdapter"] + + +def test_does_not_mutate_supplied_tl_config(): + """The builder must never mutate the caller's tl_config. Callers commonly + reuse the same config to build multiple bridges (e.g., different seeds); + leaking architecture/model_name/dtype/device between calls is a silent + correctness bug.""" + cfg = _build_cfg() + # Snapshot every public field we know the builder might touch. + snapshot = {k: getattr(cfg, k) for k in ("architecture", "model_name", "dtype", "device")} + + model = _native(cfg) + build_bridge_from_module( + model, + architecture="TransformerLensNative", + tl_config=cfg, + model_name="caller-named", + dtype=torch.float16, + ) + + for field, before in snapshot.items(): + after = getattr(cfg, field) + assert ( + before == after + ), f"build_bridge_from_module mutated tl_config.{field}: {before!r} -> {after!r}" + + +def test_two_bridges_from_same_config_are_independent(): + """A training loop pattern: build two bridges from one config with different + seeds. Bridge B's adapter settings must not leak back through the shared + config into bridge A's cfg.""" + cfg = _build_cfg() + model_a = _native(cfg) + model_b = _native(cfg) + + bridge_a = build_bridge_from_module( + model_a, architecture="TransformerLensNative", tl_config=cfg, model_name="A" + ) + bridge_b = build_bridge_from_module( + model_b, architecture="TransformerLensNative", tl_config=cfg, model_name="B" + ) + assert bridge_a.cfg.model_name == "A" + assert bridge_b.cfg.model_name == "B" + # The two bridges must hold distinct config objects. + assert bridge_a.cfg is not bridge_b.cfg + + +def test_run_with_cache_fires_attention_pattern_hook(): + cfg = _build_cfg() + model = _native(cfg) + bridge = build_bridge_from_module(model, architecture="TransformerLensNative", tl_config=cfg) + inputs = torch.randint(0, cfg.d_vocab, (1, cfg.n_ctx)) + _, cache = bridge.run_with_cache(inputs, return_type="logits") + key = "blocks.0.attn.hook_pattern" + assert key in cache + assert cache[key].shape == (1, cfg.n_heads, cfg.n_ctx, cfg.n_ctx) diff --git a/tests/unit/model_bridge/test_native_features.py b/tests/unit/model_bridge/test_native_features.py new file mode 100644 index 000000000..c64943c6f --- /dev/null +++ b/tests/unit/model_bridge/test_native_features.py @@ -0,0 +1,379 @@ +"""Tests for the optional NativeModel features driven by cfg. + +Each feature has a minimal "build a bridge with it enabled, forward, check +caches/shapes" test. The goal is to exercise the bridge code paths each cfg +field unlocks (gated MLP, RMS norm, GQA, soft-cap, rotary, attn_only) — these +features make boot_native useful as a regression target for the bridge's +real machinery, not just a flat GPT-2 toy. +""" +from __future__ import annotations + +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge import TransformerBridge +from transformer_lens.model_bridge.generalized_components import ( + GatedMLPBridge, + MLPBridge, + NormalizationBridge, + RMSNormalizationBridge, +) +from transformer_lens.model_bridge.sources.native.model import ( + NativeGatedMLP, + NativeMLP, + NativeRMSNorm, +) + + +def _cfg(**overrides) -> TransformerBridgeConfig: + base = dict( + d_model=32, + d_head=16, + n_heads=4, + n_layers=1, + n_ctx=8, + d_vocab=16, + d_mlp=64, + act_fn="silu", + normalization_type="LN", + seed=0, + ) + base.update(overrides) + return TransformerBridgeConfig(**base) + + +def _forward(bridge: TransformerBridge) -> torch.Tensor: + inputs = torch.randint(0, bridge.cfg.d_vocab, (2, bridge.cfg.n_ctx)) + return bridge(inputs, return_type="logits") + + +# -- soft-cap ----------------------------------------------------------------- + + +def test_attn_scores_soft_cap_bounds_pattern(): + """When attn_scores_soft_cap is set, no entry of the soft-cap target should + blow up. We assert the pattern (post-softmax) is still well-formed: rows + sum to 1 and no nan/inf. The cap itself happens pre-softmax, but a buggy + application (e.g. wrong sign) shows up immediately downstream.""" + cfg = _cfg(attn_scores_soft_cap=30.0) + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + _, cache = bridge.run_with_cache(inputs, return_type="logits") + pattern = cache["blocks.0.attn.hook_pattern"] + assert torch.allclose(pattern.sum(dim=-1), torch.ones_like(pattern.sum(dim=-1)), atol=1e-5) + assert torch.isfinite(pattern).all() + + +def test_output_logits_soft_cap_bounds_logits(): + """Logits must be bounded by ±cap when output_logits_soft_cap > 0.""" + cap = 5.0 + cfg = _cfg(output_logits_soft_cap=cap, seed=0) + # Force-large outputs by skipping the soft-cap → then re-enabling. Easier: + # just pick a cap and assert |logits| <= cap. tanh-cap math guarantees it. + bridge = TransformerBridge.boot_native(cfg) + logits = _forward(bridge) + assert logits.abs().max().item() <= cap + 1e-5 + + +# -- gated MLP ---------------------------------------------------------------- + + +def test_gated_mlp_swaps_module_class_and_bridge(): + """With gated_mlp=True, the underlying MLP must be NativeGatedMLP and the + adapter must wrap it in GatedMLPBridge (not the plain MLPBridge). + + Bridge setup replaces original_model submodules with bridge wrappers, so + we peek at the underlying module via ``original_component``. + """ + cfg = _cfg(gated_mlp=True, act_fn="silu") + bridge = TransformerBridge.boot_native(cfg) + mlp_bridge = bridge.blocks[0].mlp + assert isinstance(mlp_bridge, GatedMLPBridge) + assert isinstance(mlp_bridge.original_component, NativeGatedMLP) + _ = _forward(bridge) # forward must work end-to-end + + +def test_gated_mlp_default_is_plain(): + cfg = _cfg() + bridge = TransformerBridge.boot_native(cfg) + mlp_bridge = bridge.blocks[0].mlp + # GatedMLPBridge subclasses MLPBridge, so the negative check uses the + # subclass type directly. + assert isinstance(mlp_bridge, MLPBridge) + assert not isinstance(mlp_bridge, GatedMLPBridge) + assert isinstance(mlp_bridge.original_component, NativeMLP) + assert not isinstance(mlp_bridge.original_component, NativeGatedMLP) + + +# -- RMS norm ----------------------------------------------------------------- + + +def test_rms_norm_computes_variance_in_fp32(): + """HF Llama's RMSNorm upcasts to fp32 to compute variance, then casts back. + NativeRMSNorm must match that pattern so bf16 / fp16 parity comparisons + against a Llama reference don't drift. + + We compare bf16-cast norm output against the fp32 reference on the same + seeded input. With the fp32 variance pre-step, the bf16 result lands close + to the reference; a naive bf16-only variance would blow this bound. + """ + torch.manual_seed(0) + norm_fp32 = NativeRMSNorm(d_model=64, eps=1e-5) + # Cast a copy of the norm to bf16 — this is the "model cast to bf16" + # scenario where weight is bf16 and inputs are bf16. Output then follows + # the input dtype, matching HF Llama. + norm_bf16 = NativeRMSNorm(d_model=64, eps=1e-5).to(torch.bfloat16) + norm_bf16.weight.data.copy_(norm_fp32.weight.to(torch.bfloat16)) + + x = torch.randn(2, 8, 64) + out_fp32 = norm_fp32(x) + out_bf16 = norm_bf16(x.to(torch.bfloat16)) + + assert out_fp32.dtype is torch.float32 + assert out_bf16.dtype is torch.bfloat16 + # bf16 has ~8 bits of mantissa; the gap from a fp32-computed reference is + # in the 1e-2 range. A naive bf16-only RMS would drift well beyond this. + drift = (out_bf16.float() - out_fp32).abs().max().item() + assert drift < 5e-2, f"RMSNorm bf16 drifted {drift!r} from fp32 reference" + + +def test_rms_norm_swaps_module_class_and_bridge(): + cfg = _cfg(normalization_type="RMS") + bridge = TransformerBridge.boot_native(cfg) + ln1_bridge = bridge.blocks[0].ln1 + ln_final_bridge = bridge.ln_final + assert isinstance(ln1_bridge, RMSNormalizationBridge) + assert isinstance(ln1_bridge.original_component, NativeRMSNorm) + assert isinstance(ln_final_bridge, RMSNormalizationBridge) + assert isinstance(ln_final_bridge.original_component, NativeRMSNorm) + _ = _forward(bridge) + + +def test_final_rms_only_swaps_the_final_norm(): + """final_rms=True with normalization_type='LN' uses LN in blocks but RMS + for the final norm. Matches the Llama config semantic.""" + cfg = _cfg(normalization_type="LN", final_rms=True) + bridge = TransformerBridge.boot_native(cfg) + ln1_bridge = bridge.blocks[0].ln1 + ln_final_bridge = bridge.ln_final + # Blocks use plain LN. + assert isinstance(ln1_bridge, NormalizationBridge) + assert not isinstance(ln1_bridge, RMSNormalizationBridge) + assert isinstance(ln1_bridge.original_component, torch.nn.LayerNorm) + # Final norm is RMS. + assert isinstance(ln_final_bridge, RMSNormalizationBridge) + assert isinstance(ln_final_bridge.original_component, NativeRMSNorm) + _ = _forward(bridge) + + +def test_ln_default_uses_layernorm(): + cfg = _cfg() + bridge = TransformerBridge.boot_native(cfg) + ln1_bridge = bridge.blocks[0].ln1 + assert isinstance(ln1_bridge, NormalizationBridge) + assert not isinstance(ln1_bridge, RMSNormalizationBridge) + assert isinstance(ln1_bridge.original_component, torch.nn.LayerNorm) + + +# -- GQA ---------------------------------------------------------------------- + + +def test_gqa_shapes_kv_smaller_than_q(): + """With n_key_value_heads < n_heads, K/V projections must produce fewer + heads than Q, and the model must still produce the right-shaped logits. + + cache shapes follow Q's head count after repeat-expansion, so hook_pattern + is [batch, n_heads, seq, seq] regardless of n_kv_heads.""" + cfg = _cfg(n_heads=4, n_key_value_heads=2) + bridge = TransformerBridge.boot_native(cfg) + attn = bridge.original_model.layers[0].attn + # Q gets full head dim, K/V get half. + assert attn.q.out_features == cfg.n_heads * cfg.d_head + assert attn.k.out_features == cfg.n_key_value_heads * cfg.d_head + assert attn.v.out_features == cfg.n_key_value_heads * cfg.d_head + + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + _, cache = bridge.run_with_cache(inputs, return_type="logits") + pattern = cache["blocks.0.attn.hook_pattern"] + assert pattern.shape == (2, cfg.n_heads, cfg.n_ctx, cfg.n_ctx) + + +def test_gqa_default_is_full_mha(): + """Without n_key_value_heads (default None), K/V have the same head count + as Q.""" + cfg = _cfg(n_heads=4) + bridge = TransformerBridge.boot_native(cfg) + attn = bridge.original_model.layers[0].attn + assert attn.n_kv_heads == cfg.n_heads + assert attn.k.out_features == cfg.n_heads * cfg.d_head + + +# -- attn_only ---------------------------------------------------------------- + + +def test_attn_only_skips_mlp_branch(): + """attn_only=True must drop the MLP/ln2 entirely. The bridge mapping + omits the mlp slot, cache contains no mlp-related hooks, and bridge + construction emits no ``hook_resid_mid``/``hook_mlp_out`` alias warnings + (those are dropped from BlockBridge.hook_aliases under attn_only).""" + import warnings + + cfg = _cfg(attn_only=True) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + bridge = TransformerBridge.boot_native(cfg) + stale_alias_warnings = [ + w for w in caught if "hook_resid_mid" in str(w.message) or "hook_mlp_out" in str(w.message) + ] + assert stale_alias_warnings == [], ( + "attn_only must drop the ln2/mlp-targeting hook aliases so " + f"_register_aliases stays quiet, got: {[str(w.message) for w in stale_alias_warnings]}" + ) + + block = bridge.original_model.layers[0] + assert not hasattr(block, "mlp") + assert not hasattr(block, "ln2") + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + _, cache = bridge.run_with_cache(inputs, return_type="logits") + mlp_keys = [k for k in cache.keys() if ".mlp." in k or k.endswith(".mlp")] + assert mlp_keys == [], f"attn_only should fire no MLP hooks, got {mlp_keys}" + + +# -- rotary ------------------------------------------------------------------- + + +def test_rotary_drops_pos_embed_and_forward_works(): + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16) + bridge = TransformerBridge.boot_native(cfg) + # The native model must not allocate a learned position embedding when + # rotary is in effect. + assert bridge.original_model.pos is None + assert bridge.original_model.rotary is not None + # And the bridge's pos_embed component slot is absent. + assert "pos_embed" not in bridge.adapter.component_mapping + _ = _forward(bridge) + + +def test_rotary_pattern_differs_from_absolute(): + """Sanity that rotary actually changes the attention pattern relative to + absolute embeddings — would catch silent no-op (e.g. cos/sin buffers + wired wrong).""" + base = dict( + d_model=32, + d_head=16, + n_heads=4, + n_layers=1, + n_ctx=8, + d_vocab=16, + d_mlp=64, + act_fn="silu", + normalization_type="LN", + seed=0, + ) + bridge_abs = TransformerBridge.boot_native(TransformerBridgeConfig(**base)) + bridge_rope = TransformerBridge.boot_native( + TransformerBridgeConfig(**{**base, "positional_embedding_type": "rotary"}) + ) + inputs = torch.randint(0, base["d_vocab"], (2, base["n_ctx"])) + _, c_abs = bridge_abs.run_with_cache(inputs, return_type="logits") + _, c_rope = bridge_rope.run_with_cache(inputs, return_type="logits") + assert not torch.allclose( + c_abs["blocks.0.attn.hook_pattern"], c_rope["blocks.0.attn.hook_pattern"] + ) + + +# -- init modes --------------------------------------------------------------- + + +import pytest + + +@pytest.mark.parametrize( + "init_mode", + ["gpt2", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], +) +def test_init_mode_builds_and_forwards(init_mode): + """Each supported init mode must build a working bridge and run forward + without numerical disasters (no NaN/Inf in logits).""" + cfg = _cfg(init_mode=init_mode, seed=0) + bridge = TransformerBridge.boot_native(cfg) + logits = _forward(bridge) + assert torch.isfinite(logits).all(), f"init_mode={init_mode} produced non-finite logits" + + +@pytest.mark.parametrize( + "init_mode", + ["gpt2", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], +) +def test_init_mode_seed_reproducible(init_mode): + """Same seed + same mode → identical parameters.""" + cfg_a = _cfg(init_mode=init_mode, seed=7) + cfg_b = _cfg(init_mode=init_mode, seed=7) + a = TransformerBridge.boot_native(cfg_a) + b = TransformerBridge.boot_native(cfg_b) + for (na, pa), (nb, pb) in zip(a.named_parameters(), b.named_parameters()): + assert na == nb + assert torch.allclose(pa, pb), f"init_mode={init_mode} not reproducible at {na}" + + +def test_init_mode_rejects_unknown(): + """Unsupported modes fail with a clear list of what IS supported.""" + cfg = _cfg(init_mode="he_kaiser_alpha_quadratic") + with pytest.raises(NotImplementedError, match="Supported modes"): + TransformerBridge.boot_native(cfg) + + +def test_init_modes_diverge_from_each_other(): + """Different init modes must produce visibly different parameter tensors + under the same seed — otherwise the dispatch is broken.""" + seed = 11 + gpt2 = TransformerBridge.boot_native(_cfg(init_mode="gpt2", seed=seed)) + xav = TransformerBridge.boot_native(_cfg(init_mode="xavier_normal", seed=seed)) + kai = TransformerBridge.boot_native(_cfg(init_mode="kaiming_normal", seed=seed)) + # The bridge stores original_model in __dict__ (not as a child module), + # so the embedding parameter shows up under tok_embed._original_component. + embed_key = "tok_embed._original_component.weight" + g_w = dict(gpt2.named_parameters())[embed_key] + x_w = dict(xav.named_parameters())[embed_key] + k_w = dict(kai.named_parameters())[embed_key] + assert not torch.allclose(g_w, x_w) + assert not torch.allclose(g_w, k_w) + assert not torch.allclose(x_w, k_w) + + +# -- combo -------------------------------------------------------------------- + + +def test_llama_shaped_config_works_end_to_end(): + """The interesting combination — RMS norm + rotary + gated MLP + GQA + no + learned pos embed + final_rms — is the Llama-3 shape. Exercises all the + feature switches at once, ensuring they compose.""" + cfg = _cfg( + n_heads=4, + n_key_value_heads=2, + d_head=16, + normalization_type="RMS", + final_rms=True, + gated_mlp=True, + act_fn="silu", + positional_embedding_type="rotary", + ) + bridge = TransformerBridge.boot_native(cfg) + # Inspect bridge components and their underlying NativeModel modules. + assert isinstance(bridge.blocks[0].mlp, GatedMLPBridge) + assert isinstance(bridge.blocks[0].mlp.original_component, NativeGatedMLP) + assert isinstance(bridge.blocks[0].ln1, RMSNormalizationBridge) + assert isinstance(bridge.blocks[0].ln1.original_component, NativeRMSNorm) + assert isinstance(bridge.ln_final, RMSNormalizationBridge) + assert isinstance(bridge.ln_final.original_component, NativeRMSNorm) + # Rotary: no learned pos embed on the bridge or under it. + assert "pos_embed" not in bridge.adapter.component_mapping + assert bridge.original_model.pos is None + # GQA: K/V heads independently configured. + attn = bridge.blocks[0].attn.original_component + assert attn.n_kv_heads == 2 + + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + logits, cache = bridge.run_with_cache(inputs, return_type="logits") + assert logits.shape == (2, cfg.n_ctx, cfg.d_vocab) + assert cache["blocks.0.attn.hook_pattern"].shape == (2, cfg.n_heads, cfg.n_ctx, cfg.n_ctx) diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 846c6f231..53f7fbe87 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -15,7 +15,7 @@ from .BertNextSentencePrediction import BertNextSentencePrediction from .cache.key_value_cache import TransformerLensKeyValueCache from .cache.key_value_cache_entry import TransformerLensKeyValueCacheEntry -from .config import HookedTransformerConfig +from .config import HookedTransformerConfig, TransformerBridgeConfig from .FactoredMatrix import FactoredMatrix from .HookedEncoder import HookedEncoder from .HookedAudioEncoder import HookedAudioEncoder @@ -41,6 +41,7 @@ __all__ = [ "HookedTransformerConfig", + "TransformerBridgeConfig", "FactoredMatrix", "ActivationCache", "HookedTransformer", diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index b87998b9f..208bc4334 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -43,6 +43,7 @@ MixtralArchitectureAdapter, MPTArchitectureAdapter, NanogptArchitectureAdapter, + NativeArchitectureAdapter, NeelSoluOldArchitectureAdapter, NeoArchitectureAdapter, NeoxArchitectureAdapter, @@ -123,6 +124,7 @@ "MT5ForConditionalGeneration": T5ArchitectureAdapter, "XGLMForCausalLM": XGLMArchitectureAdapter, "NanoGPTForCausalLM": NanogptArchitectureAdapter, + "TransformerLensNative": NativeArchitectureAdapter, "MinGPTForCausalLM": MingptArchitectureAdapter, "GPTNeoForCausalLM": NeoArchitectureAdapter, "GPTNeoXForCausalLM": NeoxArchitectureAdapter, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index acf5c74a0..e0d6a9994 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -32,6 +32,7 @@ from transformer_lens import utilities as utils from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.config import TransformerBridgeConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter @@ -272,6 +273,91 @@ def boot_transformers( checkpoint_value=checkpoint_value, ) + @classmethod + def boot_native( + cls, + config: Union[TransformerBridgeConfig, dict], + tokenizer: Optional[Any] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + model_name: str = "native", + ) -> "TransformerBridge": + """Build a bridge around a small, randomly-initialized TL-native model. + + No HuggingFace Hub call, no ``transformers`` import. Useful for toy + training runs and demos that want a bridge-shaped model without going + through the HF pipeline. Init mode and seed are read from ``config`` + (``config.init_mode``, ``config.seed``). + + Args: + config: A :class:`TransformerBridgeConfig` (or dict) carrying the + model shape (``d_model``, ``n_heads``, ``n_layers``, ...) plus + ``init_mode`` and optional ``seed`` for reproducible weights. + tokenizer: Optional tokenizer. Native models are typically used + without one — defaults to ``None``. + device: Optional device override. Defaults to whatever device the + freshly-created module lands on (usually CPU). + dtype: Optional dtype override. Defaults to the module's parameter + dtype. + model_name: Recorded on ``cfg.model_name``. + + Returns: + A :class:`TransformerBridge` wrapping a fresh ``NativeModel``. + """ + import copy as _copy + + from transformer_lens.config import TransformerBridgeConfig as _Cfg + from transformer_lens.model_bridge.sources._bridge_builder import ( + build_bridge_from_module, + ) + from transformer_lens.model_bridge.sources.native import ( + NativeModel, + initialize_native_model, + ) + + if isinstance(config, dict): + config = _Cfg.from_dict(config) + else: + # Deep-copy so NativeModel's default-resolution writes (e.g. d_mlp, + # architecture) land on our private copy, never on the caller's. + config = _copy.deepcopy(config) + + # boot_native only knows how to build NativeModel and wire the Native + # adapter. Refuse foreign architecture strings here — otherwise we'd + # dispatch to (e.g.) LlamaArchitectureAdapter against a NativeModel + # tree, and the Llama adapter would crash opaquely in prepare_model + # looking for paths NativeModel doesn't have. + if config.architecture not in (None, "TransformerLensNative"): + raise ValueError( + f"boot_native cannot build a {config.architecture!r} model — " + f"it only constructs the TL-native architecture. Either clear " + f"config.architecture or set it to 'TransformerLensNative', " + f"or use boot_transformers / build_bridge_from_module for " + f"non-native architectures." + ) + # Resolve the architecture locally — never mutate the caller's config. + # `build_bridge_from_module` will deep-copy and set `architecture=` on + # its own copy. + architecture = "TransformerLensNative" + + model = NativeModel(config) + initialize_native_model(model, config) + + if device is not None: + model = model.to(device) + if dtype is not None: + model = model.to(dtype=dtype) + + return build_bridge_from_module( + model, + architecture=architecture, + tl_config=config, + tokenizer=tokenizer, + dtype=dtype, + device=device, + model_name=model_name, + ) + @property def original_model(self) -> nn.Module: """Get the original model.""" diff --git a/transformer_lens/model_bridge/sources/__init__.py b/transformer_lens/model_bridge/sources/__init__.py index c3e54e236..2338a17e0 100644 --- a/transformer_lens/model_bridge/sources/__init__.py +++ b/transformer_lens/model_bridge/sources/__init__.py @@ -3,6 +3,11 @@ This module provides functionality to load and convert models from HuggingFace to TransformerLens format. """ +from transformer_lens.model_bridge.sources._bridge_builder import ( + build_bridge_config_from_hf, + build_bridge_from_module, + detect_tokenizer_bos_eos, +) from transformer_lens.model_bridge.sources.transformers import ( boot, check_model_support, @@ -11,6 +16,9 @@ __all__ = [ "boot", - "list_supported_models", + "build_bridge_config_from_hf", + "build_bridge_from_module", "check_model_support", + "detect_tokenizer_bos_eos", + "list_supported_models", ] diff --git a/transformer_lens/model_bridge/sources/_bridge_builder.py b/transformer_lens/model_bridge/sources/_bridge_builder.py new file mode 100644 index 000000000..c64fb6c27 --- /dev/null +++ b/transformer_lens/model_bridge/sources/_bridge_builder.py @@ -0,0 +1,204 @@ +"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model. + +Signatures and behavior mirror dev-4.x's ``_bridge_builder.py`` so that v4 +migration is mechanical: v4 will replace this module with re-exports of the +v4 builder, leaving user-facing imports unchanged. +""" +from __future__ import annotations + +import copy +from typing import Any, Callable, Optional + +import torch +from torch import nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.bridge import TransformerBridge + +# Architecture-agnostic; do not extend per-architecture. +_HF_PASSTHROUGH_ATTRS = [ + # OPT + "is_gated_act", + "word_embed_proj_dim", + "do_layer_norm_before", + # Granite + "position_embedding_type", + # Falcon + "parallel_attn", + "multi_query", + "new_decoder_architecture", + "alibi", + "num_ln_in_parallel_attn", + # Mamba (SSM config) + "state_size", + "conv_kernel", + "expand", + "time_step_rank", + "intermediate_size", + # Mamba-2 (additional SSM config) + "n_groups", + "chunk_size", + # Multimodal + "vision_config", +] + + +def build_bridge_config_from_hf( + hf_config: Any, + architecture: str, + model_name: str, + dtype: torch.dtype, +) -> TransformerBridgeConfig: + """Translate an HF config into a :class:`TransformerBridgeConfig`.""" + from transformer_lens.model_bridge.sources.transformers import ( + map_default_transformer_lens_config, + ) + + tl_config = map_default_transformer_lens_config(hf_config) + config_dict = dict(tl_config.__dict__) + # HF's attribute_map remaps num_experts → num_local_experts; restore the TL name. + if "num_local_experts" in config_dict and "num_experts" not in config_dict: + config_dict["num_experts"] = config_dict["num_local_experts"] + bridge_config = TransformerBridgeConfig.from_dict(config_dict) + bridge_config.architecture = architecture + bridge_config.model_name = model_name + bridge_config.dtype = dtype + + for attr in _HF_PASSTHROUGH_ATTRS: + val = getattr(hf_config, attr, None) + if val is not None: + setattr(bridge_config, attr, val) + + # Gemma2: HF softcap field names differ from TL's. + final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None) + if final_logit_softcapping is not None: + bridge_config.output_logits_soft_cap = float(final_logit_softcapping) + attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None) + if attn_logit_softcapping is not None: + bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) + + return bridge_config + + +def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]: + """Detect whether the tokenizer prepends BOS and/or appends EOS. + + Non-empty test string — "" is unreliable with token aliasing. + """ + encoded_test = tokenizer.encode("a") + prepends_bos = ( + len(encoded_test) > 1 + and tokenizer.bos_token_id is not None + and encoded_test[0] == tokenizer.bos_token_id + ) + appends_eos = ( + len(encoded_test) > 1 + and tokenizer.eos_token_id is not None + and encoded_test[-1] == tokenizer.eos_token_id + ) + return prepends_bos, appends_eos + + +def build_bridge_from_module( + model: nn.Module, + architecture: str, + *, + hf_config: Optional[Any] = None, + tl_config: Optional[TransformerBridgeConfig] = None, + tokenizer: Optional[Any] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Any] = None, + model_name: str = "external", + post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None, +) -> TransformerBridge: + """Build a :class:`TransformerBridge` around a pre-loaded model. + + The bridge never moves, casts, or mutates the supplied model. + + Args: + model: Any ``nn.Module`` whose submodule tree matches the adapter's + expected dot-paths for ``architecture``. + architecture: Architecture identifier registered in the + ``ArchitectureAdapterFactory`` (e.g. ``"LlamaForCausalLM"``, + ``"TransformerLensNative"``). + hf_config: Optional HF-style config; translated via + :func:`build_bridge_config_from_hf`. Mutually exclusive with ``tl_config``. + tl_config: Optional pre-built :class:`TransformerBridgeConfig`; bypasses + HF translation. Mutually exclusive with ``hf_config``. + tokenizer: Optional tokenizer. If supplied, passes through + ``setup_tokenizer`` and detects BOS/EOS behavior. + dtype: Recorded on ``cfg.dtype``. Default ``None`` reads from the model's + first parameter; explicit values override. + device: Recorded on ``cfg.device``. Default ``None`` reads from the + model's first parameter. + model_name: Recorded on ``cfg.model_name``. + post_adapter_hook: Optional callback invoked after adapter selection and + before :meth:`adapter.prepare_model`. Source-specific overlays mutate + ``component_mapping`` here. + + Returns: + A :class:`TransformerBridge` wrapping the supplied model. + """ + if hf_config is None and tl_config is None: + raise ValueError( + "build_bridge_from_module requires exactly one of hf_config or " + "tl_config — the bridge needs config fields (d_model, n_heads, " + "n_layers, ...) that can't be inferred from the model alone." + ) + if hf_config is not None and tl_config is not None: + raise ValueError( + "build_bridge_from_module got both hf_config and tl_config; supply " + "exactly one. hf_config triggers HF→bridge translation; tl_config " + "bypasses it." + ) + + # Reading dtype from the model avoids silently lying about a bf16 model. + if dtype is None: + try: + dtype = next(model.parameters()).dtype + except StopIteration: + dtype = torch.float32 + + if tl_config is not None: + # Defensive copy: the adapter holds onto adapter.cfg (an alias of this + # config) and mutates fields during __init__ (normalization_type, device, + # ...). Without copying, a caller that builds multiple bridges from the + # same config would see fields leak between bridges. + bridge_config = copy.deepcopy(tl_config) + bridge_config.architecture = architecture + if model_name != "external" or not getattr(bridge_config, "model_name", None): + bridge_config.model_name = model_name + bridge_config.dtype = dtype + else: + bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype) + + adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) + + if post_adapter_hook is not None: + post_adapter_hook(adapter) + + if device is not None: + adapter.cfg.device = str(device) + else: + try: + adapter.cfg.device = str(next(model.parameters()).device) + except StopIteration: + adapter.cfg.device = "cpu" + + adapter.prepare_model(model) + + if tokenizer is not None: + from transformer_lens.model_bridge.sources.transformers import setup_tokenizer + + default_padding_side = getattr(adapter.cfg, "default_padding_side", None) + tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side) + ( + adapter.cfg.tokenizer_prepends_bos, + adapter.cfg.tokenizer_appends_eos, + ) = detect_tokenizer_bos_eos(tokenizer) + + return TransformerBridge(model, adapter, tokenizer) diff --git a/transformer_lens/model_bridge/sources/native/__init__.py b/transformer_lens/model_bridge/sources/native/__init__.py new file mode 100644 index 000000000..f1b860406 --- /dev/null +++ b/transformer_lens/model_bridge/sources/native/__init__.py @@ -0,0 +1,16 @@ +"""TL-native model source for TransformerBridge.""" +from transformer_lens.model_bridge.sources.native.init import initialize_native_model +from transformer_lens.model_bridge.sources.native.model import ( + NativeAttention, + NativeBlock, + NativeMLP, + NativeModel, +) + +__all__ = [ + "NativeAttention", + "NativeBlock", + "NativeMLP", + "NativeModel", + "initialize_native_model", +] diff --git a/transformer_lens/model_bridge/sources/native/init.py b/transformer_lens/model_bridge/sources/native/init.py new file mode 100644 index 000000000..b6927593a --- /dev/null +++ b/transformer_lens/model_bridge/sources/native/init.py @@ -0,0 +1,152 @@ +"""Weight init for NativeModel. Supports the standard TL init modes. + +Modes: + +- ``"gpt2"`` (default): Normal(0, initializer_range) for everything, with + ``1/sqrt(2 * n_layers)`` residual scaling on attn.o and mlp output. Mirrors + HookedTransformer's ``init_mode='gpt2'`` and is the right default for GPT-style + toy models. +- ``"xavier_uniform"`` / ``"xavier_normal"``: ``torch.nn.init.xavier_{uniform,normal}_`` + on linear weights and embeddings. Useful for sanity-checking sensitivity to + init scheme without changing architecture. +- ``"kaiming_uniform"`` / ``"kaiming_normal"``: ``torch.nn.init.kaiming_{uniform,normal}_`` + with ``nonlinearity='relu'``. Reasonable default for ReLU-family activations. + +For every mode: LayerNorm/RMSNorm weights init to 1, biases to 0; Linear biases +init to 0. + +HookedTransformer's exact init pulls from a private RNG. We rely on +``torch.manual_seed`` plus per-layer ``torch.nn.init`` for reproducibility. +""" +from __future__ import annotations + +import math +from typing import Callable + +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig + +from .model import ( + NativeAttention, + NativeBlock, + NativeGatedMLP, + NativeMLP, + NativeModel, + NativeRMSNorm, +) + +# Init modes that don't depend on residual depth — they treat every weight +# identically. The residual-scaled output projection trick is gpt2-specific. +_NON_RESIDUAL_MODES: dict[str, Callable[[torch.Tensor], None]] = { + "xavier_uniform": nn.init.xavier_uniform_, + "xavier_normal": nn.init.xavier_normal_, + "kaiming_uniform": lambda t: nn.init.kaiming_uniform_(t, nonlinearity="relu"), + "kaiming_normal": lambda t: nn.init.kaiming_normal_(t, nonlinearity="relu"), +} + +_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES}) + + +def initialize_native_model( + model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None +) -> None: + """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``.""" + effective_seed = seed if seed is not None else cfg.seed + if effective_seed is not None: + torch.manual_seed(effective_seed) + + init_mode = (cfg.init_mode or "gpt2").lower() + if init_mode not in _SUPPORTED_MODES: + raise NotImplementedError( + f"init_mode={init_mode!r} is not supported for NativeModel. " + f"Supported modes: {sorted(_SUPPORTED_MODES)}." + ) + + if init_mode == "gpt2": + std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02 + residual_scale = 1.0 / math.sqrt(2 * cfg.n_layers) + weight_init = lambda t: nn.init.normal_(t, mean=0.0, std=std) # noqa: E731 + output_init = lambda t: nn.init.normal_(t, mean=0.0, std=std * residual_scale) # noqa: E731 + else: + # Non-gpt2 modes ignore residual-depth scaling — they have their own + # gain rules that already account for the layer's role. + weight_init = _NON_RESIDUAL_MODES[init_mode] + output_init = weight_init + + weight_init(model.tok_embed.weight) + if model.pos is not None: + weight_init(model.pos.weight) + # Rotary has only registered buffers (cos/sin), no parameters to init. + + for block in model.layers: + _init_block(block, weight_init=weight_init, output_init=output_init) + + _init_norm(model.ln_out) + weight_init(model.head.weight) + + +def _init_norm(norm: nn.Module) -> None: + if isinstance(norm, NativeRMSNorm): + nn.init.ones_(norm.weight) + elif isinstance(norm, nn.LayerNorm): + nn.init.ones_(norm.weight) + nn.init.zeros_(norm.bias) + else: + raise TypeError(f"Unknown normalization type: {type(norm).__name__}") + + +def _init_block( + block: NativeBlock, + *, + weight_init: Callable[[torch.Tensor], None], + output_init: Callable[[torch.Tensor], None], +) -> None: + _init_norm(block.ln1) + _init_attention(block.attn, weight_init=weight_init, output_init=output_init) + if not block.cfg.attn_only: + _init_norm(block.ln2) + if isinstance(block.mlp, NativeGatedMLP): + _init_gated_mlp(block.mlp, weight_init=weight_init, output_init=output_init) + else: + _init_mlp(block.mlp, weight_init=weight_init, output_init=output_init) + + +def _init_attention( + attn: NativeAttention, + *, + weight_init: Callable[[torch.Tensor], None], + output_init: Callable[[torch.Tensor], None], +) -> None: + for linear in (attn.q, attn.k, attn.v): + weight_init(linear.weight) + if linear.bias is not None: + nn.init.zeros_(linear.bias) + output_init(attn.o.weight) + if attn.o.bias is not None: + nn.init.zeros_(attn.o.bias) + + +def _init_mlp( + mlp: NativeMLP, + *, + weight_init: Callable[[torch.Tensor], None], + output_init: Callable[[torch.Tensor], None], +) -> None: + weight_init(mlp.fc_in.weight) + nn.init.zeros_(mlp.fc_in.bias) + output_init(mlp.fc_out.weight) + nn.init.zeros_(mlp.fc_out.bias) + + +def _init_gated_mlp( + mlp: NativeGatedMLP, + *, + weight_init: Callable[[torch.Tensor], None], + output_init: Callable[[torch.Tensor], None], +) -> None: + weight_init(mlp.gate.weight) + # ``in`` is registered via add_module; getattr resolves it from _modules. + weight_init(getattr(mlp, "in").weight) + output_init(mlp.out.weight) diff --git a/transformer_lens/model_bridge/sources/native/model.py b/transformer_lens/model_bridge/sources/native/model.py new file mode 100644 index 000000000..a905140c5 --- /dev/null +++ b/transformer_lens/model_bridge/sources/native/model.py @@ -0,0 +1,353 @@ +"""TL-native transformer model for use with TransformerBridge. + +A minimal, from-scratch transformer implementation with no HuggingFace or +HookedTransformer dependency. Internal attribute names are deliberately chosen +to NOT collide with the bridge's top-level component slot names +("embed", "blocks", "ln_final", "unembed") — the bridge's __getattr__ falls back +to ``original_model.`` and an HF-style collision would block add_module +during bridge setup. + +Features driven by config fields: + +- ``normalization_type``: ``"LN"`` (default) or ``"RMS"`` / ``"RMSPre"``. +- ``final_rms``: when True, the final norm uses RMS regardless of block norm. +- ``gated_mlp``: when True, swaps in a SwiGLU-style gated MLP (Llama/Mistral). +- ``attn_only``: when True, blocks have no MLP / no ln2. +- ``n_key_value_heads``: when set and < ``n_heads``, enables grouped-query + attention (Llama 3.x / Mistral / DeepSeek style). +- ``attn_scores_soft_cap``: when > 0, applies Gemma2-style tanh soft-cap to + pre-softmax attention scores. +- ``output_logits_soft_cap``: when > 0, applies tanh soft-cap to final logits. +- ``positional_embedding_type``: ``"standard"`` (absolute, default) or + ``"rotary"``. Rotary applies inside attention; absolute uses ``self.pos``. +- ``rotary_dim``: partial-rotary dim (rotates first ``rotary_dim`` of each + head; pass-through the rest). Default ``d_head``. +- ``rotary_base``: RoPE base frequency. Default ``10000``. +""" +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformer_lens.config import TransformerBridgeConfig + +# gelu_new is the tanh-approximation of GELU (what HF GPT-2 and HookedTransformer +# use). PyTorch's F.gelu accepts approximate="tanh" since 1.10 — that's exactly +# the same formula, no need to roll our own. +_ACTIVATIONS = { + "gelu": F.gelu, + "gelu_new": lambda x: F.gelu(x, approximate="tanh"), + "relu": F.relu, + "silu": F.silu, + "swish": F.silu, +} + + +def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool: + return (cfg.normalization_type or "LN").upper() in ("RMS", "RMSPRE") + + +def _positional_kind(cfg: TransformerBridgeConfig) -> str: + return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() + + +class NativeRMSNorm(nn.Module): + """Root-mean-square LayerNorm. No mean centering, no bias. + + Matches the math used by Llama / Mistral / T5: ``y = w * x / rms(x)`` where + ``rms(x) = sqrt(mean(x^2) + eps)``. The variance is computed in fp32 + regardless of input dtype — mirroring HF Llama's LlamaRMSNorm — so bf16/fp16 + inputs don't accumulate variance drift. The result is cast back to the + input dtype before the per-channel scale, so the scale runs in the user's + chosen precision. + + The bridge's RMSNormalizationBridge wraps any module with a ``weight`` + attribute and a forward returning the normalized tensor — no further + coordination required. + """ + + def __init__(self, d_model: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(d_model)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + x_fp32 = x.to(torch.float32) + rms_inv = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps) + normalized = (x_fp32 * rms_inv).to(input_dtype) + return self.weight * normalized + + +def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module: + if force_rms or _uses_rms_norm(cfg): + return NativeRMSNorm(cfg.d_model, eps=cfg.eps) + return nn.LayerNorm(cfg.d_model, eps=cfg.eps) + + +class NativeRotary(nn.Module): + """Pre-computes the cos/sin tables used by RoPE. + + Lives at the model level (one shared instance) so all attention layers + re-use the same buffers. Per-call, we just slice to the current sequence + length. No HF dependency. + """ + + def __init__(self, cfg: TransformerBridgeConfig): + super().__init__() + rotary_dim = cfg.rotary_dim if cfg.rotary_dim is not None else cfg.d_head + if rotary_dim <= 0 or rotary_dim % 2 != 0: + raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}") + self.rotary_dim = rotary_dim + base = float(cfg.rotary_base) + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + positions = torch.arange(cfg.n_ctx).float() + freqs = torch.outer(positions, inv_freq) # [n_ctx, rotary_dim/2] + # Adjacent-pair format (the form Llama/HF use): each pair (2i, 2i+1) + # rotates together. We expand cos/sin per element of each pair. + cos = freqs.cos().repeat_interleave(2, dim=-1) # [n_ctx, rotary_dim] + sin = freqs.sin().repeat_interleave(2, dim=-1) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + @staticmethod + def _rotate_half(x: torch.Tensor) -> torch.Tensor: + # Llama-style adjacent-pair rotation: (x0, x1) -> (-x1, x0) + x1 = x[..., 0::2] + x2 = x[..., 1::2] + rot = torch.stack((-x2, x1), dim=-1) + return rot.flatten(-2) + + def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Apply RoPE to Q and K. Tensors are [batch, heads, seq, d_head].""" + seq = q.shape[-2] + rd = self.rotary_dim + cos = self.cos_cached[:seq].to(q.dtype) # [seq, rd] + sin = self.sin_cached[:seq].to(q.dtype) + + def _rope(x: torch.Tensor) -> torch.Tensor: + x_rot, x_pass = x[..., :rd], x[..., rd:] + x_rot = x_rot * cos + self._rotate_half(x_rot) * sin + return torch.cat([x_rot, x_pass], dim=-1) if x_pass.shape[-1] else x_rot + + return _rope(q), _rope(k) + + +class NativeAttention(nn.Module): + """Split-QKV causal self-attention with optional GQA, RoPE, and soft-cap. + + Returns ``(attn_output, attention_weights)`` so the bridge's AttentionBridge + fires ``hook_pattern`` off the second element. + """ + + def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None): + super().__init__() + self.cfg = cfg + self.n_heads = cfg.n_heads + self.d_head = cfg.d_head + self.d_model = cfg.d_model + # n_key_value_heads governs GQA: K/V have fewer heads than Q. Default + # to n_heads (= standard multi-head attention). + self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads + if self.n_heads % self.n_kv_heads != 0: + raise ValueError( + f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads " + f"({self.n_kv_heads}) for GQA." + ) + self.kv_repeats = self.n_heads // self.n_kv_heads + + q_dim = self.n_heads * self.d_head + kv_dim = self.n_kv_heads * self.d_head + self.q = nn.Linear(cfg.d_model, q_dim, bias=True) + self.k = nn.Linear(cfg.d_model, kv_dim, bias=True) + self.v = nn.Linear(cfg.d_model, kv_dim, bias=True) + self.o = nn.Linear(q_dim, cfg.d_model, bias=True) + + mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1) + self.register_buffer("causal_mask", mask, persistent=False) + + scale = ( + cfg.attn_scale if cfg.use_attn_scale and cfg.attn_scale > 0 else math.sqrt(cfg.d_head) + ) + self.scale = scale + self.rotary = rotary # None unless cfg.positional_embedding_type == "rotary" + self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap) + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + batch, seq, _ = hidden_states.shape + + q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2) + k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2) + v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2) + + if self.rotary is not None: + q, k = self.rotary.apply(q, k) + + # Expand K/V to match Q head count under GQA. repeat_interleave keeps + # group ordering consistent with HF Llama's repeat_kv. + if self.kv_repeats > 1: + k = k.repeat_interleave(self.kv_repeats, dim=1) + v = v.repeat_interleave(self.kv_repeats, dim=1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale + # Gemma2-style attention soft-cap: c * tanh(scores / c). Bounds raw + # logits before the causal mask so masked positions stay -inf. + if self.attn_scores_soft_cap > 0: + c = self.attn_scores_soft_cap + scores = c * torch.tanh(scores / c) + mask = self.causal_mask[:seq, :seq] + scores = scores.masked_fill(mask, float("-inf")) + pattern = F.softmax(scores, dim=-1) + + attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1) + out = self.o(attn) + return out, pattern + + +class NativeMLP(nn.Module): + """Two-layer MLP with configurable activation.""" + + def __init__(self, cfg: TransformerBridgeConfig): + super().__init__() + d_mlp = cfg.d_mlp + self.fc_in = nn.Linear(cfg.d_model, d_mlp, bias=True) + self.fc_out = nn.Linear(d_mlp, cfg.d_model, bias=True) + act_name = (cfg.act_fn or "gelu").lower() + if act_name not in _ACTIVATIONS: + raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}") + self.act = _ACTIVATIONS[act_name] + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + return self.fc_out(self.act(self.fc_in(hidden_states))) + + +class NativeGatedMLP(nn.Module): + """SwiGLU-style gated MLP (Llama / Mistral / Gemma2). + + Submodule names ``gate`` / ``in`` / ``out`` align with the bridge's + GatedMLPBridge submodule slots; the adapter wires them by these names. + """ + + def __init__(self, cfg: TransformerBridgeConfig): + super().__init__() + d_mlp = cfg.d_mlp + # No biases by default — matches Llama. Users wanting biased gated MLPs + # can subclass; toy-scope stays simple. + self.gate = nn.Linear(cfg.d_model, d_mlp, bias=False) + # ``in`` is a Python keyword, so we can't write ``self.in = ...`` — + # but ``add_module`` accepts any string and stores it in ``_modules``, + # so ``getattr(self, "in")`` resolves it the same way the bridge does + # when walking ``LinearBridge(name="in")``. No __getattr__ override + # required. + self.add_module("in", nn.Linear(cfg.d_model, d_mlp, bias=False)) + self.out = nn.Linear(d_mlp, cfg.d_model, bias=False) + # Gated MLPs typically pair with SiLU/swish; honor cfg if the user picked + # a different activation, but default to silu. + act_name = (cfg.act_fn or "silu").lower() + if act_name == "gelu": # GeGLU variant + self.act = _ACTIVATIONS["gelu"] + elif act_name == "gelu_new": + self.act = _ACTIVATIONS["gelu_new"] + else: + self.act = F.silu + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + gate_out = self.act(self.gate(hidden_states)) + up_out = getattr(self, "in")(hidden_states) + return self.out(gate_out * up_out) + + +class NativeBlock(nn.Module): + """Pre-LN transformer block. Layout adapts to ``cfg.attn_only`` and + ``cfg.gated_mlp``.""" + + def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None): + super().__init__() + self.cfg = cfg + self.ln1 = _make_norm(cfg) + self.attn = NativeAttention(cfg, rotary=rotary) + if not cfg.attn_only: + self.ln2 = _make_norm(cfg) + self.mlp = NativeGatedMLP(cfg) if cfg.gated_mlp else NativeMLP(cfg) + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: + attn_out, _pattern = self.attn(self.ln1(hidden_states)) + hidden_states = hidden_states + attn_out + if not self.cfg.attn_only: + hidden_states = hidden_states + self.mlp(self.ln2(hidden_states)) + # Tuple return matches HF block convention so BlockBridge's parser is happy. + return (hidden_states,) + + +class NativeModel(nn.Module): + """TL-native transformer. See module docstring for the supported feature set.""" + + def __init__(self, cfg: TransformerBridgeConfig): + super().__init__() + # Resolve defaults that NativeMLP / nn.Embedding need, and write them + # back so downstream consumers reading cfg.d_mlp see the real value + # instead of None. Mutates the supplied cfg; callers that want isolation + # (e.g. TransformerBridge.boot_native) deep-copy the user's cfg before + # constructing the model. + if not getattr(cfg, "d_mlp", None): + cfg.d_mlp = 4 * cfg.d_model + self.cfg = cfg + + self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model) + + kind = _positional_kind(cfg) + if kind == "standard": + self.pos = nn.Embedding(cfg.n_ctx, cfg.d_model) + self.rotary = None + elif kind == "rotary": + self.pos = None + self.rotary = NativeRotary(cfg) + else: + raise ValueError( + f"Unsupported positional_embedding_type={kind!r}. " + f"NativeModel supports 'standard' and 'rotary'." + ) + + self.layers = nn.ModuleList( + [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)] + ) + # final_rms overrides the block-norm choice — Llama uses LN-equivalent + # blocks but final_rms is true in TL config to opt into RMSNorm on the + # final norm. We honor the same semantic. + self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms) + d_vocab_out = cfg.d_vocab_out if cfg.d_vocab_out > 0 else cfg.d_vocab + self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False) + self.output_logits_soft_cap = float(cfg.output_logits_soft_cap) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Returns logits directly. The bridge unwraps either .logits, tuple[0], + or a bare tensor — we pick the simplest path. + """ + hidden_states = self.tok_embed(input_ids) + if self.pos is not None: + batch, seq = input_ids.shape + if position_ids is None: + position_ids = ( + torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(batch, -1) + ) + hidden_states = hidden_states + self.pos(position_ids) + + for block in self.layers: + (hidden_states,) = block(hidden_states) + hidden_states = self.ln_out(hidden_states) + logits = self.head(hidden_states) + # Gemma2-style output soft-cap. + if self.output_logits_soft_cap > 0: + c = self.output_logits_soft_cap + logits = c * torch.tanh(logits / c) + return logits diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 8bd181d1e..87212f6d8 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -102,6 +102,9 @@ from transformer_lens.model_bridge.supported_architectures.nanogpt import ( NanogptArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.native import ( + NativeArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.neel_solu_old import ( NeelSoluOldArchitectureAdapter, ) @@ -197,6 +200,7 @@ "MixtralArchitectureAdapter", "MPTArchitectureAdapter", "NanogptArchitectureAdapter", + "NativeArchitectureAdapter", "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", "NeoxArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/native.py b/transformer_lens/model_bridge/supported_architectures/native.py new file mode 100644 index 000000000..afd5a3fbf --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/native.py @@ -0,0 +1,163 @@ +"""Architecture adapter for TL-native models built via ``boot_native``. + +This adapter targets ``NativeModel`` ([sources/native/model.py]). Because the +native module's hierarchy is fully under our control, the component paths are +flat (no ``transformer.h.{i}`` prefix) and split-QKV is the natural layout — +no weight conversions are required for ordinary use. + +The component mapping adapts to the cfg: gated MLP swaps in ``GatedMLPBridge``, +RMS norm swaps in ``RMSNormalizationBridge``, rotary skips ``pos_embed``, and +``attn_only`` drops the MLP branch. +""" +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) + + +def _uses_rms(cfg: Any) -> bool: + return (getattr(cfg, "normalization_type", None) or "LN").upper() in ("RMS", "RMSPRE") + + +def _is_rotary(cfg: Any) -> bool: + return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() == "rotary" + + +def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False): + if force_rms or _uses_rms(cfg): + return RMSNormalizationBridge(name=name, config=cfg) + return NormalizationBridge(name=name, config=cfg) + + +def _make_mlp_bridge(cfg: Any): + if cfg.gated_mlp: + return GatedMLPBridge( + name="mlp", + config=cfg, + submodules={ + "gate": LinearBridge(name="gate"), + "in": LinearBridge(name="in"), + "out": LinearBridge(name="out"), + }, + ) + return MLPBridge( + name="mlp", + submodules={ + "in": LinearBridge(name="fc_in"), + "out": LinearBridge(name="fc_out"), + }, + ) + + +def _make_block_submodules(cfg: Any) -> dict: + submods: dict = { + "ln1": _make_norm_bridge("ln1", cfg), + "attn": AttentionBridge( + name="attn", + config=cfg, + submodules={ + "q": LinearBridge(name="q"), + "k": LinearBridge(name="k"), + "v": LinearBridge(name="v"), + "o": LinearBridge(name="o"), + }, + ), + } + if not cfg.attn_only: + submods["ln2"] = _make_norm_bridge("ln2", cfg) + submods["mlp"] = _make_mlp_bridge(cfg) + return submods + + +class NativeArchitectureAdapter(ArchitectureAdapter): + """Adapter for ``NativeModel`` — TL-native, split-QKV, pre-LN; feature set + driven by cfg (gated MLP, RMS norm, rotary, GQA, soft-cap, attn_only).""" + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + # Native layout already stores Q/K/V split per-head in [d_model, n*d_head] + # form. We skip weight_processing_conversions for ordinary use; compatibility + # mode (fold_ln + center_writing_weights) can be added in a follow-up. + # Until then, gate the corresponding ProcessWeights paths off: without the + # state-dict key conversions wired up, folding would silently mis-place + # weights or raise on missing keys. + self.supports_fold_ln = False + self.supports_center_writing_weights = False + self.weight_processing_conversions = {} + + # Native model uses non-colliding attribute names ("tok_embed", "layers", + # "ln_out", "head") because the bridge's __getattr__ forwards unknown + # names to original_model., which would shadow the bridge's own + # component slots ("embed", "blocks", "ln_final", "unembed") during + # add_module if they matched 1:1. + mapping: dict = { + "embed": EmbeddingBridge(name="tok_embed"), + } + if not _is_rotary(cfg): + mapping["pos_embed"] = PosEmbedBridge(name="pos") + block_bridge = BlockBridge( + name="layers", + config=self.cfg, + submodules=_make_block_submodules(self.cfg), + ) + # Under attn_only the ln2 and mlp submodules are absent, but + # BlockBridge's class-level hook_aliases still points + # ``hook_resid_mid -> ln2.hook_in`` and ``hook_mlp_out -> mlp.hook_out``. + # _register_aliases warns when those don't resolve. Drop them so the + # warnings stay meaningful elsewhere — the pattern mirrors + # ParallelBlockBridge ([block.py:405-407]). + if self.cfg.attn_only: + if block_bridge.hook_aliases is BlockBridge.hook_aliases: + block_bridge.hook_aliases = dict(block_bridge.hook_aliases) + block_bridge.hook_aliases.pop("hook_resid_mid", None) + block_bridge.hook_aliases.pop("hook_mlp_out", None) + mapping["blocks"] = block_bridge + # ``final_rms`` opts into RMSNorm on the final norm regardless of + # whether the blocks themselves use RMS — Llama-style configs do this. + mapping["ln_final"] = _make_norm_bridge( + "ln_out", self.cfg, force_rms=bool(getattr(self.cfg, "final_rms", False)) + ) + mapping["unembed"] = UnembeddingBridge(name="head") + self.component_mapping = mapping + + def prepare_model(self, model: Any) -> None: + """Reject modules whose attribute names would collide with bridge slots. + + The reserved-slot set is derived from ``self.component_mapping.keys()`` + at call time — single source of truth. A future variant that adds (or + omits) a top-level slot extends the collision check automatically; no + sibling list to keep in sync. + + The bridge's ``__getattr__`` falls back to ``getattr(original_model, name)`` + for unknown attributes — that resolves submodules, registered buffers, + plain tensors set with ``self.x = ...``, and any property. Any of these + will make ``add_module`` raise during bridge setup. We use ``hasattr`` + (not ``name in model._modules``) so the check covers all attribute + shapes, not just registered nn.Modules. + + Failing here makes the diagnostic point at the real cause instead of a + ``KeyError: "attribute 'embed' already exists"`` deep in component + setup. + """ + reserved = set(self.component_mapping.keys()) if self.component_mapping else set() + collisions = sorted(name for name in reserved if hasattr(model, name)) + if collisions: + raise ValueError( + f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: " + f"attribute names {collisions} collide with bridge component slots " + f"({sorted(reserved)}). Rename these attributes to non-colliding names " + f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's " + f"component_mapping ``name=`` fields to match." + ) From ad6a66fd26c499272cc04aef658b71fb33f6077c Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 1 Jun 2026 17:58:31 -0500 Subject: [PATCH 2/2] Format and comment cleanup --- tests/unit/model_bridge/test_boot_native.py | 21 ++ .../unit/model_bridge/test_native_features.py | 295 +++++++++++++++ transformer_lens/model_bridge/bridge.py | 60 ++-- .../model_bridge/sources/_bridge_builder.py | 13 +- .../model_bridge/sources/native/init.py | 88 ++--- .../model_bridge/sources/native/model.py | 340 +++++++++++------- .../supported_architectures/native.py | 64 ++-- 7 files changed, 630 insertions(+), 251 deletions(-) diff --git a/tests/unit/model_bridge/test_boot_native.py b/tests/unit/model_bridge/test_boot_native.py index 03284aa09..f773e6115 100644 --- a/tests/unit/model_bridge/test_boot_native.py +++ b/tests/unit/model_bridge/test_boot_native.py @@ -97,6 +97,27 @@ def test_boot_native_accepts_dict_config(): assert bridge.cfg.architecture == "TransformerLensNative" +def test_boot_native_does_not_perturb_global_rng(): + """``boot_native(seed=...)`` must use a scoped torch.Generator instead of + ``torch.manual_seed``. Otherwise a user calling boot_native then + ``torch.randn(...)`` for batch sampling silently gets a deterministic + sequence they didn't ask for.""" + # Snapshot what torch.randn(5) would produce starting from global seed 0. + torch.manual_seed(0) + expected_after = torch.randn(5) + + # Now re-seed globally to 0, build a seeded bridge, and confirm the next + # torch.randn(5) still matches the pre-bridge prediction. + torch.manual_seed(0) + TransformerBridge.boot_native(_cfg(seed=42)) + actual_after = torch.randn(5) + + assert torch.equal(actual_after, expected_after), ( + "boot_native perturbed the global RNG — the next torch.randn diverged " + f"from the pre-call baseline.\n expected: {expected_after}\n got: {actual_after}" + ) + + def test_boot_native_seed_is_honored(): a = TransformerBridge.boot_native(_cfg(seed=123)) b = TransformerBridge.boot_native(_cfg(seed=123)) diff --git a/tests/unit/model_bridge/test_native_features.py b/tests/unit/model_bridge/test_native_features.py index c64943c6f..c55088e47 100644 --- a/tests/unit/model_bridge/test_native_features.py +++ b/tests/unit/model_bridge/test_native_features.py @@ -64,6 +64,43 @@ def test_attn_scores_soft_cap_bounds_pattern(): assert torch.isfinite(pattern).all() +def test_attn_scale_one_is_rejected_when_d_head_gt_one(): + """``attn_scale=1.0`` is a UX trap: it reads like "no scaling / standard" + but actually means "divide by 1" (no scaling). For any non-trivial d_head + the softmax saturates and training breaks. The constructor must refuse + this combination with a pointing message.""" + import pytest + + cfg = _cfg(d_head=16) + cfg.use_attn_scale = True + cfg.attn_scale = 1.0 + with pytest.raises(ValueError, match="attn_scale=1.0"): + TransformerBridge.boot_native(cfg) + + +def test_attn_scale_one_allowed_when_d_head_one(): + """``d_head=1`` makes ``sqrt(d_head)==1``, so ``attn_scale=1.0`` is no + longer a trap — it's the same as the default scaling. The guard must NOT + fire.""" + cfg = _cfg(d_head=1, d_model=4, n_heads=4) + cfg.use_attn_scale = True + cfg.attn_scale = 1.0 + bridge = TransformerBridge.boot_native(cfg) + # Forward must work end-to-end. + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + _ = bridge(inputs, return_type="logits") + + +def test_attn_scale_custom_nondefault_is_allowed(): + """The guard fires only on the specific 1.0 trap, not on any other custom + scale a user might pick (e.g. for parity with an external implementation).""" + cfg = _cfg(d_head=16) + cfg.use_attn_scale = True + cfg.attn_scale = 2.5 # not 1.0, not sqrt(d_head); user knows what they want + bridge = TransformerBridge.boot_native(cfg) + assert bridge.original_model.layers[0].attn.scale == 2.5 + + def test_output_logits_soft_cap_bounds_logits(): """Logits must be bounded by ±cap when output_logits_soft_cap > 0.""" cap = 5.0 @@ -93,6 +130,40 @@ def test_gated_mlp_swaps_module_class_and_bridge(): _ = _forward(bridge) # forward must work end-to-end +def test_gated_mlp_honors_act_fn(): + """``cfg.act_fn`` selects the gating non-linearity: silu→SwiGLU, + relu→ReGLU, gelu/gelu_new→GeGLU. Each must actually use the requested + activation rather than silently falling back to silu.""" + cfg_relu = _cfg(gated_mlp=True, act_fn="relu") + cfg_silu = _cfg(gated_mlp=True, act_fn="silu") + cfg_gelu = _cfg(gated_mlp=True, act_fn="gelu") + + a = TransformerBridge.boot_native(cfg_relu).blocks[0].mlp.original_component + b = TransformerBridge.boot_native(cfg_silu).blocks[0].mlp.original_component + c = TransformerBridge.boot_native(cfg_gelu).blocks[0].mlp.original_component + + # Probe each activation with a sentinel: the negatives expose the difference. + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + relu_ref = torch.relu(x) + silu_ref = torch.nn.functional.silu(x) + gelu_ref = torch.nn.functional.gelu(x) + + assert torch.allclose(a.act(x), relu_ref) + assert torch.allclose(b.act(x), silu_ref) + assert torch.allclose(c.act(x), gelu_ref) + + +def test_gated_mlp_rejects_unknown_act_fn(): + """Parity with NativeMLP: unknown act_fn must raise, not silently default + to silu. Otherwise a user typing ``act_fn="reul"`` and toggling gated_mlp + on/off would see two different models with no diagnostic.""" + import pytest + + cfg = _cfg(gated_mlp=True, act_fn="not-an-activation") + with pytest.raises(ValueError, match="Unsupported act_fn"): + TransformerBridge.boot_native(cfg) + + def test_gated_mlp_default_is_plain(): cfg = _cfg() bridge = TransformerBridge.boot_native(cfg) @@ -254,6 +325,164 @@ def test_rotary_drops_pos_embed_and_forward_works(): _ = _forward(bridge) +def test_rotary_does_not_shadow_nn_module_apply(): + """nn.Module.apply(fn) is PyTorch's recursive function-applier — used by + init utilities, weight inspection, and many production training loops. + NativeRotary's RoPE helper must be named something other than ``apply`` + so it doesn't break ``bridge.apply(fn)`` when a rotary instance lives in + the module tree.""" + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16) + bridge = TransformerBridge.boot_native(cfg) + + visited = [] + bridge.apply(lambda m: visited.append(type(m).__name__)) + # The recursion must complete and visit modules — including the rotary one + # — without TypeError. + assert "NativeRotary" in visited + + +def test_rotary_honors_position_ids(): + """Rotary must slice the cached cos/sin tables by ``position_ids`` so the + caller's chosen positions are honored. A silent-drop bug uses 0..seq-1 + regardless and produces identical patterns regardless of the supplied + positions — exactly the failure mode for packed sequences, prefix caches, + and continuation past a cached prefix. + + RoPE has a translation-invariance property: rotating Q and K by the same + constant shift leaves ``q·k`` (and thus the attention pattern) unchanged. + So we pick positions with **different relative spacings** (1-apart vs + 2-apart), which produce genuinely different attention patterns when RoPE + honors ``position_ids`` and identical patterns when it doesn't. + """ + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=16) + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, 8)) + + dense_positions = torch.arange(8).unsqueeze(0).expand(2, -1) # 0,1,2,...,7 + spaced_positions = (torch.arange(8) * 2).unsqueeze(0).expand(2, -1) # 0,2,4,...,14 + + _, c_dense = bridge.run_with_cache(inputs, return_type="logits", position_ids=dense_positions) + _, c_spaced = bridge.run_with_cache(inputs, return_type="logits", position_ids=spaced_positions) + _, c_none = bridge.run_with_cache(inputs, return_type="logits") + + pat_dense = c_dense["blocks.0.attn.hook_pattern"] + pat_spaced = c_spaced["blocks.0.attn.hook_pattern"] + pat_none = c_none["blocks.0.attn.hook_pattern"] + + # Default (no position_ids) must match dense 0..seq-1 exactly. + assert torch.allclose(pat_dense, pat_none, atol=1e-6) + # Different relative spacings must produce different attention patterns. + assert not torch.allclose(pat_dense, pat_spaced, atol=1e-4), ( + "Rotary attention ignored position_ids — pattern unchanged despite " + "different relative spacings." + ) + + +def test_input_longer_than_n_ctx_raises_with_clear_message(): + """Both the absolute-embed nn.Embedding lookup and the rotary cos/sin + broadcast would otherwise produce opaque errors that don't mention n_ctx. + The up-front check must raise ValueError naming both the input length and + n_ctx, on both code paths.""" + import pytest + + for kind in ("standard", "rotary"): + cfg = _cfg( + positional_embedding_type=kind, + n_heads=4, + d_head=16, + n_ctx=8, + ) + bridge = TransformerBridge.boot_native(cfg) + long_inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx + 4)) + with pytest.raises(ValueError, match=r"input length 12 exceeds n_ctx=8"): + bridge(long_inputs, return_type="logits") + + +def test_rope_scaling_linear_extends_effective_context(): + """Linear (position-interpolation) rope_scaling divides positions by the + factor before computing freqs. Same n_ctx-slot table, factor× longer + effective context. We assert the cached cos table differs from the + unscaled version and matches a hand-computed reference.""" + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg.rope_scaling = {"type": "linear", "factor": 2.0} + bridge = TransformerBridge.boot_native(cfg) + rotary = bridge.original_model.rotary + assert rotary.position_scale == 2.0 + + # Reference: positions divided by 2 → freqs are half. cos[pos=0] is still + # all-ones, but cos[pos=1] is the unscaled cos[pos=0.5] — i.e. shallower + # rotation than the unscaled cos[pos=1]. + cfg_no = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + rotary_no = TransformerBridge.boot_native(cfg_no).original_model.rotary + assert not torch.allclose(rotary.cos_cached, rotary_no.cos_cached) + + +def test_rope_scaling_ntk_scales_base_frequency(): + """NTK-aware rope_scaling scales the base frequency rather than positions. + Effective base must exceed the configured rotary_base by factor^(d/(d-2)).""" + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg.rope_scaling = {"type": "ntk", "factor": 4.0} + bridge = TransformerBridge.boot_native(cfg) + rotary = bridge.original_model.rotary + + expected_base = float(cfg.rotary_base) * (4.0 ** (16 / 14)) + assert abs(rotary.effective_base - expected_base) < 1.0 + assert rotary.position_scale == 1.0 # NTK doesn't touch positions. + + +def test_rope_scaling_llama3_rescales_inv_freq_per_band(): + """Llama-3 by-parts scheme rescales inv_freq per frequency band rather + than uniformly. With factor>1 and a reasonable original_ctx, the resulting + cos table must differ from both the unscaled and the linear-scaled versions + — otherwise we silently fell through to one of the simpler paths.""" + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg.rope_scaling = { + "type": "llama3", + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8, + } + bridge_llama3 = TransformerBridge.boot_native(cfg) + + cfg_linear = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg_linear.rope_scaling = {"type": "linear", "factor": 8.0} + bridge_linear = TransformerBridge.boot_native(cfg_linear) + + cfg_none = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + bridge_none = TransformerBridge.boot_native(cfg_none) + + c_llama3 = bridge_llama3.original_model.rotary.cos_cached + c_linear = bridge_linear.original_model.rotary.cos_cached + c_none = bridge_none.original_model.rotary.cos_cached + + assert not torch.allclose(c_llama3, c_none) + assert not torch.allclose(c_llama3, c_linear) + + +def test_rope_scaling_unknown_type_raises(): + import pytest + + cfg = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg.rope_scaling = {"type": "moonshot", "factor": 2.0} + with pytest.raises(NotImplementedError, match="moonshot"): + TransformerBridge.boot_native(cfg) + + +def test_rope_scaling_none_or_factor_one_is_noop(): + """Empty / None rope_scaling, or factor <= 1, must produce the same cos + table as no scaling. A user explicitly disabling scaling shouldn't pay + surprise drift.""" + cfg_none = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg_factor_1 = _cfg(positional_embedding_type="rotary", n_heads=4, d_head=16, n_ctx=8) + cfg_factor_1.rope_scaling = {"type": "linear", "factor": 1.0} + + r_none = TransformerBridge.boot_native(cfg_none).original_model.rotary + r_f1 = TransformerBridge.boot_native(cfg_factor_1).original_model.rotary + assert torch.allclose(r_none.cos_cached, r_f1.cos_cached) + assert r_f1.position_scale == 1.0 + + def test_rotary_pattern_differs_from_absolute(): """Sanity that rotary actually changes the attention pattern relative to absolute embeddings — would catch silent no-op (e.g. cos/sin buffers @@ -341,6 +570,72 @@ def test_init_modes_diverge_from_each_other(): assert not torch.allclose(x_w, k_w) +# -- attention_mask ----------------------------------------------------------- + + +def test_attention_mask_2d_padding_changes_output(): + """A 2D HF-style padding mask (1=keep, 0=pad) must actually mask keys. + Verified by comparing outputs with and without the mask — a silent-drop + bug yields identical outputs.""" + cfg = _cfg(n_layers=2) + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + + # Mask out the second half of each sequence. + mask = torch.ones_like(inputs) + mask[:, cfg.n_ctx // 2 :] = 0 + + out_masked = bridge(inputs, return_type="logits", attention_mask=mask) + out_all_keep = bridge(inputs, return_type="logits", attention_mask=torch.ones_like(inputs)) + out_no_mask = bridge(inputs, return_type="logits") + + # All-keep mask matches no-mask exactly: providing all-1s must be a no-op. + assert torch.allclose(out_no_mask, out_all_keep, atol=1e-6) + # Padding mask must change the result on the un-masked positions (positions + # 0..n_ctx//2 - 1 can only attend to themselves, but masking the keys for + # positions ≥ n_ctx//2 still propagates into the residual stream through + # layer 2 because layer 1's masked positions had NaN/garbage outputs that + # the next attention reads). + # We just need *some* difference somewhere. + assert not torch.allclose( + out_masked, out_no_mask, atol=1e-4 + ), "Padding mask had no effect on output — attention_mask was silently dropped." + + +def test_attention_mask_padded_key_has_zero_pattern_weight(): + """The most direct invariant: when a key position is padded, no query + should put any weight on it (post-softmax).""" + cfg = _cfg(n_layers=1) + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + + mask = torch.ones_like(inputs) + mask[:, cfg.n_ctx // 2 :] = 0 # mask out keys at positions ≥ n_ctx/2 + + _, cache = bridge.run_with_cache(inputs, return_type="logits", attention_mask=mask) + pattern = cache["blocks.0.attn.hook_pattern"] + # For non-padded query rows (positions 0..n_ctx//2-1), no weight on padded + # keys. (Padded-query rows have all -inf scores → NaN/uniform post-softmax; + # we don't assert on those.) + visible_queries = pattern[:, :, : cfg.n_ctx // 2, :] + padded_key_weight = visible_queries[:, :, :, cfg.n_ctx // 2 :] + assert torch.allclose( + padded_key_weight, torch.zeros_like(padded_key_weight), atol=1e-6 + ), "Visible queries put non-zero weight on padded keys" + + +def test_attention_mask_invalid_shape_raises(): + import pytest + + cfg = _cfg(n_layers=1) + bridge = TransformerBridge.boot_native(cfg) + inputs = torch.randint(0, cfg.d_vocab, (2, cfg.n_ctx)) + # 3D mask shape isn't supported — must raise rather than silently drop. + bad_mask = torch.ones(2, cfg.n_ctx, cfg.n_ctx, dtype=torch.bool) + with pytest.raises(ValueError, match="attention_mask must be 2D"): + bridge(inputs, return_type="logits", attention_mask=bad_mask) + + # -- combo -------------------------------------------------------------------- diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index e0d6a9994..077007cd5 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -284,25 +284,8 @@ def boot_native( ) -> "TransformerBridge": """Build a bridge around a small, randomly-initialized TL-native model. - No HuggingFace Hub call, no ``transformers`` import. Useful for toy - training runs and demos that want a bridge-shaped model without going - through the HF pipeline. Init mode and seed are read from ``config`` - (``config.init_mode``, ``config.seed``). - - Args: - config: A :class:`TransformerBridgeConfig` (or dict) carrying the - model shape (``d_model``, ``n_heads``, ``n_layers``, ...) plus - ``init_mode`` and optional ``seed`` for reproducible weights. - tokenizer: Optional tokenizer. Native models are typically used - without one — defaults to ``None``. - device: Optional device override. Defaults to whatever device the - freshly-created module lands on (usually CPU). - dtype: Optional dtype override. Defaults to the module's parameter - dtype. - model_name: Recorded on ``cfg.model_name``. - - Returns: - A :class:`TransformerBridge` wrapping a fresh ``NativeModel``. + No HuggingFace Hub call, no ``transformers`` import. ``config.init_mode`` + and ``config.seed`` control reproducibility. """ import copy as _copy @@ -315,33 +298,36 @@ def boot_native( initialize_native_model, ) + cfg: TransformerBridgeConfig if isinstance(config, dict): - config = _Cfg.from_dict(config) + cfg = _Cfg.from_dict(config) else: - # Deep-copy so NativeModel's default-resolution writes (e.g. d_mlp, - # architecture) land on our private copy, never on the caller's. - config = _copy.deepcopy(config) - - # boot_native only knows how to build NativeModel and wire the Native - # adapter. Refuse foreign architecture strings here — otherwise we'd - # dispatch to (e.g.) LlamaArchitectureAdapter against a NativeModel - # tree, and the Llama adapter would crash opaquely in prepare_model - # looking for paths NativeModel doesn't have. - if config.architecture not in (None, "TransformerLensNative"): + # Deep-copy so NativeModel's default-resolution writes don't land + # on the caller's config. + cfg = _copy.deepcopy(config) + + # Foreign architecture strings would dispatch to the wrong adapter and + # crash deep in prepare_model. Refuse them with a pointing message. + if cfg.architecture not in (None, "TransformerLensNative"): raise ValueError( - f"boot_native cannot build a {config.architecture!r} model — " + f"boot_native cannot build a {cfg.architecture!r} model — " f"it only constructs the TL-native architecture. Either clear " f"config.architecture or set it to 'TransformerLensNative', " f"or use boot_transformers / build_bridge_from_module for " f"non-native architectures." ) - # Resolve the architecture locally — never mutate the caller's config. - # `build_bridge_from_module` will deep-copy and set `architecture=` on - # its own copy. architecture = "TransformerLensNative" - model = NativeModel(config) - initialize_native_model(model, config) + # Fork RNG around construction + init when seeded so neither nn.Linear's + # default reset_parameters nor our scoped init perturb the caller's RNG. + # Unseeded calls let global RNG advance normally. + if cfg.seed is not None: + with torch.random.fork_rng(devices=[]): + model = NativeModel(cfg) + initialize_native_model(model, cfg) + else: + model = NativeModel(cfg) + initialize_native_model(model, cfg) if device is not None: model = model.to(device) @@ -351,7 +337,7 @@ def boot_native( return build_bridge_from_module( model, architecture=architecture, - tl_config=config, + tl_config=cfg, tokenizer=tokenizer, dtype=dtype, device=device, diff --git a/transformer_lens/model_bridge/sources/_bridge_builder.py b/transformer_lens/model_bridge/sources/_bridge_builder.py index c64fb6c27..5ae2dc466 100644 --- a/transformer_lens/model_bridge/sources/_bridge_builder.py +++ b/transformer_lens/model_bridge/sources/_bridge_builder.py @@ -1,9 +1,4 @@ -"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model. - -Signatures and behavior mirror dev-4.x's ``_bridge_builder.py`` so that v4 -migration is mechanical: v4 will replace this module with re-exports of the -v4 builder, leaving user-facing imports unchanged. -""" +"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model.""" from __future__ import annotations import copy @@ -164,10 +159,8 @@ def build_bridge_from_module( dtype = torch.float32 if tl_config is not None: - # Defensive copy: the adapter holds onto adapter.cfg (an alias of this - # config) and mutates fields during __init__ (normalization_type, device, - # ...). Without copying, a caller that builds multiple bridges from the - # same config would see fields leak between bridges. + # Defensive copy so adapter-init mutations (normalization_type, device, + # ...) don't leak between bridges built from the same config. bridge_config = copy.deepcopy(tl_config) bridge_config.architecture = architecture if model_name != "external" or not getattr(bridge_config, "model_name", None): diff --git a/transformer_lens/model_bridge/sources/native/init.py b/transformer_lens/model_bridge/sources/native/init.py index b6927593a..53ebaaebb 100644 --- a/transformer_lens/model_bridge/sources/native/init.py +++ b/transformer_lens/model_bridge/sources/native/init.py @@ -1,27 +1,17 @@ -"""Weight init for NativeModel. Supports the standard TL init modes. +"""Weight init for NativeModel. -Modes: +Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual +scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``, +``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights +go to 1, all biases to 0. -- ``"gpt2"`` (default): Normal(0, initializer_range) for everything, with - ``1/sqrt(2 * n_layers)`` residual scaling on attn.o and mlp output. Mirrors - HookedTransformer's ``init_mode='gpt2'`` and is the right default for GPT-style - toy models. -- ``"xavier_uniform"`` / ``"xavier_normal"``: ``torch.nn.init.xavier_{uniform,normal}_`` - on linear weights and embeddings. Useful for sanity-checking sensitivity to - init scheme without changing architecture. -- ``"kaiming_uniform"`` / ``"kaiming_normal"``: ``torch.nn.init.kaiming_{uniform,normal}_`` - with ``nonlinearity='relu'``. Reasonable default for ReLU-family activations. - -For every mode: LayerNorm/RMSNorm weights init to 1, biases to 0; Linear biases -init to 0. - -HookedTransformer's exact init pulls from a private RNG. We rely on -``torch.manual_seed`` plus per-layer ``torch.nn.init`` for reproducibility. +Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so +seeded init does not perturb the caller's global RNG. """ from __future__ import annotations import math -from typing import Callable +from typing import Callable, Optional, cast import torch import torch.nn as nn @@ -37,13 +27,14 @@ NativeRMSNorm, ) -# Init modes that don't depend on residual depth — they treat every weight -# identically. The residual-scaled output projection trick is gpt2-specific. -_NON_RESIDUAL_MODES: dict[str, Callable[[torch.Tensor], None]] = { - "xavier_uniform": nn.init.xavier_uniform_, - "xavier_normal": nn.init.xavier_normal_, - "kaiming_uniform": lambda t: nn.init.kaiming_uniform_(t, nonlinearity="relu"), - "kaiming_normal": lambda t: nn.init.kaiming_normal_(t, nonlinearity="relu"), +# Residual-scaled output is gpt2-specific; other modes treat every weight the +# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator. +_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor] +_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = { + "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g), + "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g), + "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g), + "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g), } _SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES}) @@ -54,8 +45,19 @@ def initialize_native_model( ) -> None: """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``.""" effective_seed = seed if seed is not None else cfg.seed + + # Scoped generator on the model's device — None falls back to the global RNG. + try: + gen_device = next(model.parameters()).device + except StopIteration: + gen_device = torch.device("cpu") + generator: Optional[torch.Generator] if effective_seed is not None: - torch.manual_seed(effective_seed) + g = torch.Generator(device=gen_device) + g.manual_seed(effective_seed) + generator = g + else: + generator = None init_mode = (cfg.init_mode or "gpt2").lower() if init_mode not in _SUPPORTED_MODES: @@ -64,15 +66,20 @@ def initialize_native_model( f"Supported modes: {sorted(_SUPPORTED_MODES)}." ) + weight_init: Callable[[torch.Tensor], torch.Tensor] + output_init: Callable[[torch.Tensor], torch.Tensor] if init_mode == "gpt2": std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02 residual_scale = 1.0 / math.sqrt(2 * cfg.n_layers) - weight_init = lambda t: nn.init.normal_(t, mean=0.0, std=std) # noqa: E731 - output_init = lambda t: nn.init.normal_(t, mean=0.0, std=std * residual_scale) # noqa: E731 + weight_init = lambda t: nn.init.normal_( + t, mean=0.0, std=std, generator=generator + ) # noqa: E731 + output_init = lambda t: nn.init.normal_( # noqa: E731 + t, mean=0.0, std=std * residual_scale, generator=generator + ) else: - # Non-gpt2 modes ignore residual-depth scaling — they have their own - # gain rules that already account for the layer's role. - weight_init = _NON_RESIDUAL_MODES[init_mode] + fn = _NON_RESIDUAL_MODES[init_mode] + weight_init = lambda t: fn(t, generator) # noqa: E731 output_init = weight_init weight_init(model.tok_embed.weight) @@ -100,8 +107,8 @@ def _init_norm(norm: nn.Module) -> None: def _init_block( block: NativeBlock, *, - weight_init: Callable[[torch.Tensor], None], - output_init: Callable[[torch.Tensor], None], + weight_init: Callable[[torch.Tensor], torch.Tensor], + output_init: Callable[[torch.Tensor], torch.Tensor], ) -> None: _init_norm(block.ln1) _init_attention(block.attn, weight_init=weight_init, output_init=output_init) @@ -116,8 +123,8 @@ def _init_block( def _init_attention( attn: NativeAttention, *, - weight_init: Callable[[torch.Tensor], None], - output_init: Callable[[torch.Tensor], None], + weight_init: Callable[[torch.Tensor], torch.Tensor], + output_init: Callable[[torch.Tensor], torch.Tensor], ) -> None: for linear in (attn.q, attn.k, attn.v): weight_init(linear.weight) @@ -131,8 +138,8 @@ def _init_attention( def _init_mlp( mlp: NativeMLP, *, - weight_init: Callable[[torch.Tensor], None], - output_init: Callable[[torch.Tensor], None], + weight_init: Callable[[torch.Tensor], torch.Tensor], + output_init: Callable[[torch.Tensor], torch.Tensor], ) -> None: weight_init(mlp.fc_in.weight) nn.init.zeros_(mlp.fc_in.bias) @@ -143,10 +150,11 @@ def _init_mlp( def _init_gated_mlp( mlp: NativeGatedMLP, *, - weight_init: Callable[[torch.Tensor], None], - output_init: Callable[[torch.Tensor], None], + weight_init: Callable[[torch.Tensor], torch.Tensor], + output_init: Callable[[torch.Tensor], torch.Tensor], ) -> None: weight_init(mlp.gate.weight) # ``in`` is registered via add_module; getattr resolves it from _modules. - weight_init(getattr(mlp, "in").weight) + in_proj = cast(nn.Linear, getattr(mlp, "in")) + weight_init(in_proj.weight) output_init(mlp.out.weight) diff --git a/transformer_lens/model_bridge/sources/native/model.py b/transformer_lens/model_bridge/sources/native/model.py index a905140c5..2035ade42 100644 --- a/transformer_lens/model_bridge/sources/native/model.py +++ b/transformer_lens/model_bridge/sources/native/model.py @@ -1,33 +1,15 @@ -"""TL-native transformer model for use with TransformerBridge. - -A minimal, from-scratch transformer implementation with no HuggingFace or -HookedTransformer dependency. Internal attribute names are deliberately chosen -to NOT collide with the bridge's top-level component slot names -("embed", "blocks", "ln_final", "unembed") — the bridge's __getattr__ falls back -to ``original_model.`` and an HF-style collision would block add_module -during bridge setup. - -Features driven by config fields: - -- ``normalization_type``: ``"LN"`` (default) or ``"RMS"`` / ``"RMSPre"``. -- ``final_rms``: when True, the final norm uses RMS regardless of block norm. -- ``gated_mlp``: when True, swaps in a SwiGLU-style gated MLP (Llama/Mistral). -- ``attn_only``: when True, blocks have no MLP / no ln2. -- ``n_key_value_heads``: when set and < ``n_heads``, enables grouped-query - attention (Llama 3.x / Mistral / DeepSeek style). -- ``attn_scores_soft_cap``: when > 0, applies Gemma2-style tanh soft-cap to - pre-softmax attention scores. -- ``output_logits_soft_cap``: when > 0, applies tanh soft-cap to final logits. -- ``positional_embedding_type``: ``"standard"`` (absolute, default) or - ``"rotary"``. Rotary applies inside attention; absolute uses ``self.pos``. -- ``rotary_dim``: partial-rotary dim (rotates first ``rotary_dim`` of each - head; pass-through the rest). Default ``d_head``. -- ``rotary_base``: RoPE base frequency. Default ``10000``. +"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency. + +Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``, +``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``, +``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary), +``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK, +llama3 by-parts). """ from __future__ import annotations import math -from typing import Optional +from typing import Callable, Optional, cast import torch import torch.nn as nn @@ -35,10 +17,10 @@ from transformer_lens.config import TransformerBridgeConfig -# gelu_new is the tanh-approximation of GELU (what HF GPT-2 and HookedTransformer -# use). PyTorch's F.gelu accepts approximate="tanh" since 1.10 — that's exactly -# the same formula, no need to roll our own. -_ACTIVATIONS = { +# gelu_new = the tanh-approximation HF GPT-2 / HT use; F.gelu(approximate="tanh") +# is the exact same formula. +_Activation = Callable[[torch.Tensor], torch.Tensor] +_ACTIVATIONS: dict[str, _Activation] = { "gelu": F.gelu, "gelu_new": lambda x: F.gelu(x, approximate="tanh"), "relu": F.relu, @@ -56,19 +38,8 @@ def _positional_kind(cfg: TransformerBridgeConfig) -> str: class NativeRMSNorm(nn.Module): - """Root-mean-square LayerNorm. No mean centering, no bias. - - Matches the math used by Llama / Mistral / T5: ``y = w * x / rms(x)`` where - ``rms(x) = sqrt(mean(x^2) + eps)``. The variance is computed in fp32 - regardless of input dtype — mirroring HF Llama's LlamaRMSNorm — so bf16/fp16 - inputs don't accumulate variance drift. The result is cast back to the - input dtype before the per-channel scale, so the scale runs in the user's - chosen precision. - - The bridge's RMSNormalizationBridge wraps any module with a ``weight`` - attribute and a forward returning the normalized tensor — no further - coordination required. - """ + """Llama-style RMSNorm. Variance in fp32 regardless of input dtype, then + cast back before the per-channel scale (matches HF LlamaRMSNorm).""" def __init__(self, d_model: int, eps: float = 1e-5): super().__init__() @@ -89,13 +60,69 @@ def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.M return nn.LayerNorm(cfg.d_model, eps=cfg.eps) +def _resolve_rope_scaling( + cfg: TransformerBridgeConfig, rotary_dim: int +) -> tuple[float, float, torch.Tensor]: + """Returns (effective_base, position_scale, inv_freq) per cfg.rope_scaling.""" + base = float(cfg.rotary_base) + rope_scaling = getattr(cfg, "rope_scaling", None) + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + + if not isinstance(rope_scaling, dict): + return base, 1.0, inv_freq + + # Newer HF configs key on "rope_type"; older ones on "type". + scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower() + factor = float(rope_scaling.get("factor", 1.0)) + + if scale_type in ("", "default") or factor <= 1.0: + return base, 1.0, inv_freq + + if scale_type == "linear": + return base, factor, inv_freq + + if scale_type in ("dynamic", "ntk"): + scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2))) + new_inv_freq = 1.0 / (scaled_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + return scaled_base, 1.0, new_inv_freq + + if scale_type == "llama3": + low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0)) + high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0)) + original_ctx = float( + rope_scaling.get("original_max_position_embeddings") + or rope_scaling.get("original_context_length") + or 8192 + ) + low_wavelen = original_ctx / low_freq_factor + high_wavelen = original_ctx / high_freq_factor + wavelens = 2 * math.pi / inv_freq + # Three regimes: low-freq → divide by factor; high-freq → unchanged; + # in-between → smooth linear interpolation between the two. + smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_inv_freq = torch.where( + wavelens > low_wavelen, + inv_freq / factor, + torch.where( + wavelens < high_wavelen, + inv_freq, + (1 - smooth) * inv_freq / factor + smooth * inv_freq, + ), + ) + return base, 1.0, new_inv_freq + + raise NotImplementedError( + f"rope_scaling type {scale_type!r} is not supported. " + f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'." + ) + + class NativeRotary(nn.Module): - """Pre-computes the cos/sin tables used by RoPE. + """Shared cos/sin tables for RoPE. Honors ``cfg.rope_scaling``.""" - Lives at the model level (one shared instance) so all attention layers - re-use the same buffers. Per-call, we just slice to the current sequence - length. No HF dependency. - """ + # Declared so mypy sees the buffer dtype; register_buffer alone reports Module|Tensor. + cos_cached: torch.Tensor + sin_cached: torch.Tensor def __init__(self, cfg: TransformerBridgeConfig): super().__init__() @@ -103,31 +130,49 @@ def __init__(self, cfg: TransformerBridgeConfig): if rotary_dim <= 0 or rotary_dim % 2 != 0: raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}") self.rotary_dim = rotary_dim - base = float(cfg.rotary_base) - inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) - positions = torch.arange(cfg.n_ctx).float() - freqs = torch.outer(positions, inv_freq) # [n_ctx, rotary_dim/2] - # Adjacent-pair format (the form Llama/HF use): each pair (2i, 2i+1) - # rotates together. We expand cos/sin per element of each pair. - cos = freqs.cos().repeat_interleave(2, dim=-1) # [n_ctx, rotary_dim] + + base, position_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim) + + positions = torch.arange(cfg.n_ctx).float() / position_scale + freqs = torch.outer(positions, inv_freq) + # Llama/HF adjacent-pair format: each (2i, 2i+1) pair rotates together. + cos = freqs.cos().repeat_interleave(2, dim=-1) sin = freqs.sin().repeat_interleave(2, dim=-1) self.register_buffer("cos_cached", cos, persistent=False) self.register_buffer("sin_cached", sin, persistent=False) + self.effective_base = base + self.position_scale = position_scale @staticmethod def _rotate_half(x: torch.Tensor) -> torch.Tensor: - # Llama-style adjacent-pair rotation: (x0, x1) -> (-x1, x0) + # Llama-style adjacent-pair rotation: (x0, x1) -> (-x1, x0). x1 = x[..., 0::2] x2 = x[..., 1::2] rot = torch.stack((-x2, x1), dim=-1) return rot.flatten(-2) - def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Apply RoPE to Q and K. Tensors are [batch, heads, seq, d_head].""" + def apply_rope( + self, + q: torch.Tensor, + k: torch.Tensor, + *, + position_ids: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply RoPE to Q/K of shape [batch, heads, seq, d_head]. + + Named ``apply_rope`` rather than ``apply`` so ``nn.Module.apply(fn)`` + — PyTorch's recursive function-application utility used by + ``bridge.apply(init_fn)`` — isn't shadowed. + """ seq = q.shape[-2] rd = self.rotary_dim - cos = self.cos_cached[:seq].to(q.dtype) # [seq, rd] - sin = self.sin_cached[:seq].to(q.dtype) + if position_ids is None: + cos = self.cos_cached[:seq].to(q.dtype) + sin = self.sin_cached[:seq].to(q.dtype) + else: + # [batch, seq] -> [batch, 1, seq, rd] (head dim for broadcast). + cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1) + sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1) def _rope(x: torch.Tensor) -> torch.Tensor: x_rot, x_pass = x[..., :rd], x[..., rd:] @@ -138,11 +183,10 @@ def _rope(x: torch.Tensor) -> torch.Tensor: class NativeAttention(nn.Module): - """Split-QKV causal self-attention with optional GQA, RoPE, and soft-cap. + """Split-QKV causal self-attention. Returns (out, pattern); AttentionBridge + fires ``hook_pattern`` off the second element.""" - Returns ``(attn_output, attention_weights)`` so the bridge's AttentionBridge - fires ``hook_pattern`` off the second element. - """ + causal_mask: torch.Tensor def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None): super().__init__() @@ -150,8 +194,6 @@ def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] self.n_heads = cfg.n_heads self.d_head = cfg.d_head self.d_model = cfg.d_model - # n_key_value_heads governs GQA: K/V have fewer heads than Q. Default - # to n_heads (= standard multi-head attention). self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads if self.n_heads % self.n_kv_heads != 0: raise ValueError( @@ -170,14 +212,29 @@ def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1) self.register_buffer("causal_mask", mask, persistent=False) - scale = ( - cfg.attn_scale if cfg.use_attn_scale and cfg.attn_scale > 0 else math.sqrt(cfg.d_head) - ) + # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" — + # i.e. unscaled scores, which saturate softmax for d_head>1. + if cfg.use_attn_scale and cfg.attn_scale > 0: + if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9): + raise ValueError( + f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled " + f"attention; softmax will saturate. For standard scaling " + f"leave attn_scale at -1 (sentinel for sqrt(d_head))." + ) + scale = cfg.attn_scale + else: + scale = math.sqrt(cfg.d_head) self.scale = scale - self.rotary = rotary # None unless cfg.positional_embedding_type == "rotary" + self.rotary = rotary self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap) - def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: batch, seq, _ = hidden_states.shape q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2) @@ -185,35 +242,64 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor, v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2) if self.rotary is not None: - q, k = self.rotary.apply(q, k) + q, k = self.rotary.apply_rope(q, k, position_ids=position_ids) - # Expand K/V to match Q head count under GQA. repeat_interleave keeps - # group ordering consistent with HF Llama's repeat_kv. + # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering. if self.kv_repeats > 1: k = k.repeat_interleave(self.kv_repeats, dim=1) v = v.repeat_interleave(self.kv_repeats, dim=1) scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale - # Gemma2-style attention soft-cap: c * tanh(scores / c). Bounds raw - # logits before the causal mask so masked positions stay -inf. + # Gemma2 soft-cap before the causal mask so masked positions stay -inf. if self.attn_scores_soft_cap > 0: c = self.attn_scores_soft_cap scores = c * torch.tanh(scores / c) - mask = self.causal_mask[:seq, :seq] - scores = scores.masked_fill(mask, float("-inf")) + + block_mask = self.causal_mask[:seq, :seq] + if attention_mask is not None: + block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch) + scores = scores.masked_fill(block_mask, float("-inf")) + pattern = F.softmax(scores, dim=-1) attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1) out = self.o(attn) return out, pattern + @staticmethod + def _combine_attention_mask( + block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int + ) -> torch.Tensor: + """Combine an external attention_mask with the causal mask. + + Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool + mask (True=mask), or 4D additive float mask (HF generation style; values + below -1 treated as masked). + """ + if attention_mask.dim() == 2: + pad_mask = ~attention_mask.bool() + return block_mask | pad_mask[:, None, None, :] + if attention_mask.dim() == 4: + if attention_mask.dtype is torch.bool: + return block_mask | attention_mask + # HF additive masks use -inf or large negatives; benign biases bounded. + extra = attention_mask < -1.0 + return block_mask | extra + raise ValueError( + f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], " + f"got shape {tuple(attention_mask.shape)}." + ) + class NativeMLP(nn.Module): """Two-layer MLP with configurable activation.""" + act: Callable[[torch.Tensor], torch.Tensor] + def __init__(self, cfg: TransformerBridgeConfig): super().__init__() - d_mlp = cfg.d_mlp + assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs" + d_mlp: int = cfg.d_mlp self.fc_in = nn.Linear(cfg.d_model, d_mlp, bias=True) self.fc_out = nn.Linear(d_mlp, cfg.d_model, bias=True) act_name = (cfg.act_fn or "gelu").lower() @@ -226,38 +312,34 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: class NativeGatedMLP(nn.Module): - """SwiGLU-style gated MLP (Llama / Mistral / Gemma2). + """SwiGLU / ReGLU / GeGLU gated MLP (variant picked by ``cfg.act_fn``). - Submodule names ``gate`` / ``in`` / ``out`` align with the bridge's - GatedMLPBridge submodule slots; the adapter wires them by these names. + Submodules ``gate`` / ``in`` / ``out`` match GatedMLPBridge's expected slots. """ + act: Callable[[torch.Tensor], torch.Tensor] + def __init__(self, cfg: TransformerBridgeConfig): super().__init__() - d_mlp = cfg.d_mlp - # No biases by default — matches Llama. Users wanting biased gated MLPs - # can subclass; toy-scope stays simple. + assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs" + d_mlp: int = cfg.d_mlp + # Llama convention: no biases on gated MLP projections. self.gate = nn.Linear(cfg.d_model, d_mlp, bias=False) - # ``in`` is a Python keyword, so we can't write ``self.in = ...`` — - # but ``add_module`` accepts any string and stores it in ``_modules``, - # so ``getattr(self, "in")`` resolves it the same way the bridge does - # when walking ``LinearBridge(name="in")``. No __getattr__ override - # required. + # ``in`` is a Python keyword; add_module + getattr(self, "in") works + # because the bridge resolves LinearBridge(name="in") the same way. self.add_module("in", nn.Linear(cfg.d_model, d_mlp, bias=False)) self.out = nn.Linear(d_mlp, cfg.d_model, bias=False) - # Gated MLPs typically pair with SiLU/swish; honor cfg if the user picked - # a different activation, but default to silu. + # Default to SwiGLU; mirror NativeMLP's dispatch so a typo'd act_fn + # raises instead of silently changing the model. act_name = (cfg.act_fn or "silu").lower() - if act_name == "gelu": # GeGLU variant - self.act = _ACTIVATIONS["gelu"] - elif act_name == "gelu_new": - self.act = _ACTIVATIONS["gelu_new"] - else: - self.act = F.silu + if act_name not in _ACTIVATIONS: + raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}") + self.act = _ACTIVATIONS[act_name] def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: gate_out = self.act(self.gate(hidden_states)) - up_out = getattr(self, "in")(hidden_states) + in_proj = cast(nn.Linear, getattr(self, "in")) + up_out = in_proj(hidden_states) return self.out(gate_out * up_out) @@ -274,25 +356,35 @@ def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] self.ln2 = _make_norm(cfg) self.mlp = NativeGatedMLP(cfg) if cfg.gated_mlp else NativeMLP(cfg) - def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: - attn_out, _pattern = self.attn(self.ln1(hidden_states)) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor]: + attn_out, _pattern = self.attn( + self.ln1(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + ) hidden_states = hidden_states + attn_out if not self.cfg.attn_only: hidden_states = hidden_states + self.mlp(self.ln2(hidden_states)) - # Tuple return matches HF block convention so BlockBridge's parser is happy. + # Tuple return matches HF block convention; BlockBridge's parser expects it. return (hidden_states,) class NativeModel(nn.Module): """TL-native transformer. See module docstring for the supported feature set.""" + pos: Optional[nn.Embedding] + rotary: Optional[NativeRotary] + def __init__(self, cfg: TransformerBridgeConfig): super().__init__() - # Resolve defaults that NativeMLP / nn.Embedding need, and write them - # back so downstream consumers reading cfg.d_mlp see the real value - # instead of None. Mutates the supplied cfg; callers that want isolation - # (e.g. TransformerBridge.boot_native) deep-copy the user's cfg before - # constructing the model. + # Write the resolved d_mlp back so downstream consumers see the real + # value, not None. Mutates cfg; isolating callers should deep-copy first. if not getattr(cfg, "d_mlp", None): cfg.d_mlp = 4 * cfg.d_model self.cfg = cfg @@ -315,9 +407,8 @@ def __init__(self, cfg: TransformerBridgeConfig): self.layers = nn.ModuleList( [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)] ) - # final_rms overrides the block-norm choice — Llama uses LN-equivalent - # blocks but final_rms is true in TL config to opt into RMSNorm on the - # final norm. We honor the same semantic. + # final_rms forces RMS on the final norm regardless of block-norm choice + # — matches the TL config semantic Llama uses. self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms) d_vocab_out = cfg.d_vocab_out if cfg.d_vocab_out > 0 else cfg.d_vocab self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False) @@ -330,23 +421,32 @@ def forward( position_ids: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - """Returns logits directly. The bridge unwraps either .logits, tuple[0], - or a bare tensor — we pick the simplest path. - """ + """Returns logits directly.""" + # Bounds check up front so both absolute and rotary paths produce a + # self-explanatory error rather than IndexError / shape mismatch. + seq_len = input_ids.shape[-1] + if seq_len > self.cfg.n_ctx: + raise ValueError( + f"input length {seq_len} exceeds n_ctx={self.cfg.n_ctx}; " + f"position embeddings and rotary tables are pre-baked at n_ctx." + ) + + # Resolve position_ids before the block loop so rotary sees the caller's + # positions, not the dense default. + batch, seq = input_ids.shape + if position_ids is None: + position_ids = torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(batch, -1) + hidden_states = self.tok_embed(input_ids) if self.pos is not None: - batch, seq = input_ids.shape - if position_ids is None: - position_ids = ( - torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(batch, -1) - ) hidden_states = hidden_states + self.pos(position_ids) for block in self.layers: - (hidden_states,) = block(hidden_states) + (hidden_states,) = block( + hidden_states, attention_mask=attention_mask, position_ids=position_ids + ) hidden_states = self.ln_out(hidden_states) logits = self.head(hidden_states) - # Gemma2-style output soft-cap. if self.output_logits_soft_cap > 0: c = self.output_logits_soft_cap logits = c * torch.tanh(logits / c) diff --git a/transformer_lens/model_bridge/supported_architectures/native.py b/transformer_lens/model_bridge/supported_architectures/native.py index afd5a3fbf..697a3e8e0 100644 --- a/transformer_lens/model_bridge/supported_architectures/native.py +++ b/transformer_lens/model_bridge/supported_architectures/native.py @@ -1,13 +1,7 @@ """Architecture adapter for TL-native models built via ``boot_native``. -This adapter targets ``NativeModel`` ([sources/native/model.py]). Because the -native module's hierarchy is fully under our control, the component paths are -flat (no ``transformer.h.{i}`` prefix) and split-QKV is the natural layout — -no weight conversions are required for ordinary use. - -The component mapping adapts to the cfg: gated MLP swaps in ``GatedMLPBridge``, -RMS norm swaps in ``RMSNormalizationBridge``, rotary skips ``pos_embed``, and -``attn_only`` drops the MLP branch. +Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm → +``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops MLP. """ from typing import Any @@ -87,21 +81,17 @@ class NativeArchitectureAdapter(ArchitectureAdapter): def __init__(self, cfg: Any) -> None: super().__init__(cfg) - # Native layout already stores Q/K/V split per-head in [d_model, n*d_head] - # form. We skip weight_processing_conversions for ordinary use; compatibility - # mode (fold_ln + center_writing_weights) can be added in a follow-up. - # Until then, gate the corresponding ProcessWeights paths off: without the - # state-dict key conversions wired up, folding would silently mis-place - # weights or raise on missing keys. + # Native layout already stores Q/K/V split; no rearranges needed. + # Compatibility-mode fold_ln / center_writing_weights aren't wired up, + # so gate the corresponding ProcessWeights paths off — folding without + # the state-dict conversions would mis-place or drop weights. self.supports_fold_ln = False self.supports_center_writing_weights = False self.weight_processing_conversions = {} - # Native model uses non-colliding attribute names ("tok_embed", "layers", - # "ln_out", "head") because the bridge's __getattr__ forwards unknown - # names to original_model., which would shadow the bridge's own - # component slots ("embed", "blocks", "ln_final", "unembed") during - # add_module if they matched 1:1. + # Internal attribute names avoid collisions with bridge slot names + # ("embed", "blocks", "ln_final", "unembed") — the bridge's __getattr__ + # forwards to original_model and would shadow add_module otherwise. mapping: dict = { "embed": EmbeddingBridge(name="tok_embed"), } @@ -112,20 +102,16 @@ def __init__(self, cfg: Any) -> None: config=self.cfg, submodules=_make_block_submodules(self.cfg), ) - # Under attn_only the ln2 and mlp submodules are absent, but - # BlockBridge's class-level hook_aliases still points - # ``hook_resid_mid -> ln2.hook_in`` and ``hook_mlp_out -> mlp.hook_out``. - # _register_aliases warns when those don't resolve. Drop them so the - # warnings stay meaningful elsewhere — the pattern mirrors - # ParallelBlockBridge ([block.py:405-407]). + # Under attn_only there's no ln2 / mlp to point at; drop the aliases + # that would otherwise warn during _register_aliases. if self.cfg.attn_only: if block_bridge.hook_aliases is BlockBridge.hook_aliases: block_bridge.hook_aliases = dict(block_bridge.hook_aliases) block_bridge.hook_aliases.pop("hook_resid_mid", None) block_bridge.hook_aliases.pop("hook_mlp_out", None) mapping["blocks"] = block_bridge - # ``final_rms`` opts into RMSNorm on the final norm regardless of - # whether the blocks themselves use RMS — Llama-style configs do this. + # final_rms forces RMS on the final norm independent of block norm — + # matches Llama's TL config semantic. mapping["ln_final"] = _make_norm_bridge( "ln_out", self.cfg, force_rms=bool(getattr(self.cfg, "final_rms", False)) ) @@ -133,23 +119,13 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = mapping def prepare_model(self, model: Any) -> None: - """Reject modules whose attribute names would collide with bridge slots. - - The reserved-slot set is derived from ``self.component_mapping.keys()`` - at call time — single source of truth. A future variant that adds (or - omits) a top-level slot extends the collision check automatically; no - sibling list to keep in sync. - - The bridge's ``__getattr__`` falls back to ``getattr(original_model, name)`` - for unknown attributes — that resolves submodules, registered buffers, - plain tensors set with ``self.x = ...``, and any property. Any of these - will make ``add_module`` raise during bridge setup. We use ``hasattr`` - (not ``name in model._modules``) so the check covers all attribute - shapes, not just registered nn.Modules. - - Failing here makes the diagnostic point at the real cause instead of a - ``KeyError: "attribute 'embed' already exists"`` deep in component - setup. + """Reject modules whose attribute names collide with bridge slots. + + Bridge's ``__getattr__`` falls back to ``getattr(original_model, name)`` + for unknown attrs, so a name match — submodule, buffer, plain tensor, + or property — makes ``add_module`` raise mid-setup with an opaque + message. Failing here points at the real cause. Reserved set is derived + from ``component_mapping.keys()`` so adapter variants stay in sync. """ reserved = set(self.component_mapping.keys()) if self.component_mapping else set() collisions = sorted(name for name in reserved if hasattr(model, name))