Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
Gpt2LmHeadCustomArchitectureAdapter,
GptjArchitectureAdapter,
GPTOSSArchitectureAdapter,
GraniteArchitectureAdapter,
GraniteMoeArchitectureAdapter,
GraniteMoeHybridArchitectureAdapter,
LlamaArchitectureAdapter,
LlavaArchitectureAdapter,
LlavaNextArchitectureAdapter,
Expand Down Expand Up @@ -51,6 +54,9 @@
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
"Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
"Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
"GraniteForCausalLM": GraniteArchitectureAdapter,
"GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
"GPT2LMHeadModel": GPT2ArchitectureAdapter,
"GptOssForCausalLM": GPTOSSArchitectureAdapter,
"GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
Expand Down
6 changes: 3 additions & 3 deletions transformer_lens/model_bridge/generalized_components/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def get_random_inputs(
if dtype is None:
dtype = torch.float32
d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
return {
"hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
}
# Use positional args to avoid parameter name mismatches across MoE implementations
# (e.g., Mixtral uses "hidden_states", GraniteMoe uses "layer_input")
return {"args": (torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),)}

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through the MoE bridge.
Expand Down
5 changes: 5 additions & 0 deletions transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def boot(
attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
if attn_logit_softcapping is not None:
bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
# Propagate position_embedding_type for Granite Hybrid models that use
# "nope" (no positional embeddings) instead of "rope" on some/all layers.
position_embedding_type = getattr(hf_config, "position_embedding_type", None)
if position_embedding_type is not None:
bridge_config.position_embedding_type = position_embedding_type
# Propagate vision config for multimodal models so the adapter can
# select the correct vision encoder bridge (CLIP vs SigLIP).
if hasattr(hf_config, "vision_config") and hf_config.vision_config is not None:
Expand Down
12 changes: 12 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
from transformer_lens.model_bridge.supported_architectures.gemma3_multimodal import (
Gemma3MultimodalArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.granite import (
GraniteArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.granite_moe import (
GraniteMoeArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.granite_moe_hybrid import (
GraniteMoeHybridArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gpt2 import (
GPT2ArchitectureAdapter,
)
Expand Down Expand Up @@ -116,6 +125,9 @@
"Gemma2ArchitectureAdapter",
"Gemma3ArchitectureAdapter",
"Gemma3MultimodalArchitectureAdapter",
"GraniteArchitectureAdapter",
"GraniteMoeArchitectureAdapter",
"GraniteMoeHybridArchitectureAdapter",
"GPT2ArchitectureAdapter",
"GPTOSSArchitectureAdapter",
"Gpt2LmHeadCustomArchitectureAdapter",
Expand Down
164 changes: 164 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/granite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Granite architecture adapter.

Base adapter for the IBM Granite model family. Provides shared config setup and
helper methods used by GraniteMoe and GraniteMoeHybrid variants.
"""

from typing import Any, Dict

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
GatedMLPBridge,
LinearBridge,
PositionEmbeddingsAttentionBridge,
RMSNormalizationBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)


class GraniteArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for IBM Granite models (dense).

    Granite is a Llama-like architecture with RMSNorm, rotary position embeddings
    (RoPE), GQA, and a gated MLP (SiLU activation). Granite-specific scaling
    multipliers are handled by the HF model's native forward pass.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    Granite models do NOT have biases on attention and MLP projections:

    - blocks.{i}.attn.b_Q/b_K/b_V/b_O - No bias on attention projections
    - blocks.{i}.mlp.b_in/b_gate/b_out - No bias on MLP projections
    - blocks.{i}.ln1.b, blocks.{i}.ln2.b, ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Granite architecture adapter.

        Args:
            cfg: Config object; must expose d_model, n_heads, n_layers and
                d_vocab, and may expose n_key_value_heads for GQA.
        """
        super().__init__(cfg)

        # Order matters: _setup_common_config mutates self.cfg (including
        # copying n_key_value_heads onto it), which _get_n_kv_heads and the
        # mapping builders read afterwards.
        self._setup_common_config(cfg)
        n_kv_heads = self._get_n_kv_heads()
        self.weight_processing_conversions = self._build_attn_weight_conversions(n_kv_heads)
        self.component_mapping = self._build_component_mapping()

    def _setup_common_config(self, cfg: Any) -> None:
        """Set up config variables shared across all Granite variants.

        Mutates ``self.cfg`` in place and populates ``self.default_config``.
        Subclasses (GraniteMoe, GraniteMoeHybrid) reuse this and may override
        individual fields afterwards (e.g. positional_embedding_type).
        """
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Attribute name under which HF Granite RMSNorm modules store their
        # epsilon value.
        self.cfg.eps_attr = "variance_epsilon"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # Only propagate the GQA head count when the config defines it;
        # otherwise _get_n_kv_heads falls back to n_heads (plain MHA).
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

    def _get_n_kv_heads(self) -> int:
        """Get the number of key-value heads (for GQA or MHA).

        Returns:
            n_key_value_heads when set on the config, else n_heads.
        """
        if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None:
            return self.cfg.n_key_value_heads
        return self.cfg.n_heads

    def _build_attn_weight_conversions(
        self, n_kv_heads: int
    ) -> Dict[str, ParamProcessingConversion | str]:
        """Build weight processing conversions for attention projections.

        Args:
            n_kv_heads: Head count used for the K/V rearranges (may differ
                from n_heads under GQA).

        Returns:
            Mapping from TransformerLens parameter path templates to the
            rearranges that split HF's fused projection matrices per head:
            Q/K/V split the output dim into (heads, head_dim); O splits the
            input dim.
        """
        return {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

    def _build_attention_bridge(self) -> PositionEmbeddingsAttentionBridge:
        """Build the standard Granite attention bridge.

        Submodule names ("q_proj", "k_proj", ...) must match the HF Granite
        attention module's attribute names exactly.
        """
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

    def _build_mlp_bridge(self) -> GatedMLPBridge:
        """Build the dense gated MLP bridge.

        Subclasses override this hook to substitute a different MLP-style
        component (e.g. a sparse MoE block) without rebuilding the mapping.
        """
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

    def _build_component_mapping(self) -> dict:
        """Build the full component mapping for dense Granite.

        Keys are TransformerLens component names; bridge ``name`` values are
        the dotted paths of the corresponding HF modules.
        """
        return {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": self._build_attention_bridge(),
                    "mlp": self._build_mlp_bridge(),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Granite component testing.

        Args:
            hf_model: The HuggingFace Granite model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Nothing to wire up if the HF model exposes no rotary embedding
        # module (e.g. an unexpected variant).
        if not hasattr(hf_model.model, "rotary_emb"):
            return

        rotary_emb = hf_model.model.rotary_emb

        # Share the single HF rotary-embedding module with every block's
        # attention bridge on the bridge model, when one was provided.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Best-effort: also attach it to the adapter's own first attention
        # bridge; tolerate adapters/mappings that don't expose it.
        try:
            attn_bridge = self.get_generalized_component("blocks.0.attn")
            attn_bridge.set_rotary_emb(rotary_emb)
        except (AttributeError, KeyError):
            pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Granite MoE architecture adapter."""

from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
MoEBridge,
RMSNormalizationBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)
from transformer_lens.model_bridge.supported_architectures.granite import (
GraniteArchitectureAdapter,
)


class GraniteMoeArchitectureAdapter(GraniteArchitectureAdapter):
    """Architecture adapter for IBM Granite MoE models.

    Identical to dense Granite but replaces the gated MLP with a Sparse Mixture
    of Experts block (block_sparse_moe) using batched expert parameters and
    top-k routing.
    """

    def _build_mlp_bridge(self) -> MoEBridge:  # type: ignore[override]
        """Build the sparse MoE bridge used in place of the dense gated MLP.

        The parent's ``_build_component_mapping`` calls this hook when
        assembling each block's submodules, so overriding it is sufficient
        to swap in the MoE block — the rest of the mapping (embed,
        rotary_emb, blocks/norms/attn, ln_final, unembed) is inherited
        unchanged instead of being duplicated here.

        Returns:
            An ``MoEBridge`` mapped onto the HF ``block_sparse_moe`` module.
        """
        return MoEBridge(
            name="block_sparse_moe",
            config=self.cfg,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Granite MoE Hybrid architecture adapter.

GraniteMoeHybridForCausalLM is a hybrid Mamba + Attention architecture with
Sparse Mixture of Experts. Layers alternate between Mamba SSM blocks and
standard attention blocks, with a shared MLP and optional sparse MoE on
every layer.

Since self_attn is None on Mamba layers and mamba is None on attention
layers, we only map submodules that exist on ALL layers (norms, shared_mlp,
block_sparse_moe). The HF native forward handles mamba/attention dispatch.
"""

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
LinearBridge,
MLPBridge,
MoEBridge,
RMSNormalizationBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)
from transformer_lens.model_bridge.supported_architectures.granite import (
GraniteArchitectureAdapter,
)


class GraniteMoeHybridArchitectureAdapter(GraniteArchitectureAdapter):
    """Architecture adapter for IBM Granite MoE Hybrid models.

    Hybrid Mamba2 + Attention architecture with Sparse MoE. Most layers are Mamba
    SSM blocks; a few are standard attention (determined by config.layer_types).

    Since self_attn is None on Mamba layers and mamba is None on attention layers,
    we only map submodules present on ALL layers (norms, shared_mlp, MoE). The HF
    native forward handles mamba/attention dispatch internally.

    Hook coverage:
    - Block-level: hook_resid_pre, hook_resid_post on every layer
    - Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm)
    - MLP: shared_mlp input/output hooks
    - MoE: block_sparse_moe input/output and router_scores hooks
    - Attention/Mamba internals are NOT individually hooked (conditional per layer)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Granite MoE Hybrid architecture adapter.

        Args:
            cfg: Config object; same required attributes as the dense Granite
                adapter, plus optionally position_embedding_type and
                num_experts / num_local_experts.
        """
        # Call ArchitectureAdapter.__init__ directly, not GraniteArchitectureAdapter.__init__,
        # because we need to customize the setup sequence
        ArchitectureAdapter.__init__(self, cfg)

        # Applies the shared Granite defaults; note this sets
        # positional_embedding_type to "rotary", which the override just
        # below may undo for "nope" configs.
        self._setup_common_config(cfg)

        # Hybrid may use "rope" or "nope" (no positional embeddings)
        pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
        if pos_emb_type != "rope":
            self.cfg.positional_embedding_type = "none"

        # No attention weight conversions — attn Q/K/V aren't mapped as submodules
        self.weight_processing_conversions = {}
        self.component_mapping = self._build_component_mapping()

    def _build_component_mapping(self) -> dict:
        """Build component mapping with only universal (all-layer) submodules."""
        # Submodules guaranteed to exist on every layer regardless of whether
        # it is a Mamba or an attention layer.
        block_submodules = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            "shared_mlp": MLPBridge(
                name="shared_mlp",
                config=self.cfg,
            submodules={
                    "in": LinearBridge(name="input_linear"),
                    "out": LinearBridge(name="output_linear"),
                },
            ),
        }

        # HF configs name the expert count either "num_experts" or
        # "num_local_experts"; only map the MoE block when experts exist
        # (a dense hybrid config may set this to 0).
        num_experts = getattr(self.cfg, "num_experts", None) or getattr(
            self.cfg, "num_local_experts", 0
        )
        if num_experts and num_experts > 0:
            block_submodules["moe"] = MoEBridge(
                name="block_sparse_moe",
                config=self.cfg,
            )

        mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules=block_submodules,
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

        # Only expose a rotary component when __init__ left the type as
        # "rotary" (i.e. the config was "rope", not "nope").
        if self.cfg.positional_embedding_type == "rotary":
            mapping["rotary_emb"] = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        return mapping

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """No-op for hybrid models.

        Hybrid models don't map attention as a submodule (it's conditional per
        layer), so there are no rotary embedding references to set up.
        """
Loading
Loading