Commit e46d47b

Setup architecture adapters for the 3 Granite Architectures (#1206)
* Setup architecture adapters for the 3 Granite Architectures
* CI checks
1 parent 04ccabf commit e46d47b

10 files changed

Lines changed: 10191 additions & 981 deletions

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,9 @@
     Gpt2LmHeadCustomArchitectureAdapter,
     GptjArchitectureAdapter,
     GPTOSSArchitectureAdapter,
+    GraniteArchitectureAdapter,
+    GraniteMoeArchitectureAdapter,
+    GraniteMoeHybridArchitectureAdapter,
     LlamaArchitectureAdapter,
     LlavaArchitectureAdapter,
     LlavaNextArchitectureAdapter,
@@ -51,6 +54,9 @@
     "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
     "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
     "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
+    "GraniteForCausalLM": GraniteArchitectureAdapter,
+    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
+    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
     "GPT2LMHeadModel": GPT2ArchitectureAdapter,
     "GptOssForCausalLM": GPTOSSArchitectureAdapter,
     "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,

transformer_lens/model_bridge/generalized_components/moe.py

Lines changed: 3 additions & 3 deletions
@@ -65,9 +65,9 @@ def get_random_inputs(
         if dtype is None:
             dtype = torch.float32
         d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
-        return {
-            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
-        }
+        # Use positional args to avoid parameter name mismatches across MoE implementations
+        # (e.g., Mixtral uses "hidden_states", GraniteMoe uses "layer_input")
+        return {"args": (torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),)}

     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Forward pass through the MoE bridge.

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 5 additions & 0 deletions
@@ -342,6 +342,11 @@ def boot(
     attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
     if attn_logit_softcapping is not None:
         bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
+    # Propagate position_embedding_type for Granite Hybrid models that use
+    # "nope" (no positional embeddings) instead of "rope" on some/all layers.
+    position_embedding_type = getattr(hf_config, "position_embedding_type", None)
+    if position_embedding_type is not None:
+        bridge_config.position_embedding_type = position_embedding_type
     # Propagate vision config for multimodal models so the adapter can
     # select the correct vision encoder bridge (CLIP vs SigLIP).
     if hasattr(hf_config, "vision_config") and hf_config.vision_config is not None:
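
The propagated value is consumed downstream when an adapter decides whether a model actually has rotary embeddings (the hybrid adapter later in this commit maps anything other than "rope" to no positional embeddings). A minimal sketch of that decision, with a hypothetical helper name:

from typing import Any

def resolve_positional_embedding_type(cfg: Any) -> str:
    """Hypothetical helper: map an HF-style flag onto the bridge's vocabulary."""
    pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
    # Granite Hybrid configs may report "rope" or "nope"; treat anything
    # other than "rope" as having no positional embeddings.
    return "rotary" if pos_emb_type == "rope" else "none"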

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,15 @@
 from transformer_lens.model_bridge.supported_architectures.gemma3_multimodal import (
     Gemma3MultimodalArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.granite import (
+    GraniteArchitectureAdapter,
+)
+from transformer_lens.model_bridge.supported_architectures.granite_moe import (
+    GraniteMoeArchitectureAdapter,
+)
+from transformer_lens.model_bridge.supported_architectures.granite_moe_hybrid import (
+    GraniteMoeHybridArchitectureAdapter,
+)
 from transformer_lens.model_bridge.supported_architectures.gpt2 import (
     GPT2ArchitectureAdapter,
 )
@@ -116,6 +125,9 @@
     "Gemma2ArchitectureAdapter",
     "Gemma3ArchitectureAdapter",
     "Gemma3MultimodalArchitectureAdapter",
+    "GraniteArchitectureAdapter",
+    "GraniteMoeArchitectureAdapter",
+    "GraniteMoeHybridArchitectureAdapter",
     "GPT2ArchitectureAdapter",
     "GPTOSSArchitectureAdapter",
     "Gpt2LmHeadCustomArchitectureAdapter",

transformer_lens/model_bridge/supported_architectures/granite.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+"""Granite architecture adapter.
+
+Base adapter for the IBM Granite model family. Provides shared config setup and
+helper methods used by GraniteMoe and GraniteMoeHybrid variants.
+"""
+
+from typing import Any, Dict
+
+from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
+from transformer_lens.conversion_utils.param_processing_conversion import (
+    ParamProcessingConversion,
+)
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.generalized_components import (
+    BlockBridge,
+    EmbeddingBridge,
+    GatedMLPBridge,
+    LinearBridge,
+    PositionEmbeddingsAttentionBridge,
+    RMSNormalizationBridge,
+    RotaryEmbeddingBridge,
+    UnembeddingBridge,
+)
+
+
+class GraniteArchitectureAdapter(ArchitectureAdapter):
+    """Architecture adapter for IBM Granite models (dense).
+
+    Granite is a Llama-like architecture with RMSNorm, rotary position embeddings
+    (RoPE), GQA, and a gated MLP (SiLU activation). Granite-specific scaling
+    multipliers are handled by the HF model's native forward pass.
+
+    Optional Parameters (may not exist in state_dict):
+    -------------------------------------------------
+    Granite models do NOT have biases on attention and MLP projections:
+
+    - blocks.{i}.attn.b_Q/b_K/b_V/b_O - No bias on attention projections
+    - blocks.{i}.mlp.b_in/b_gate/b_out - No bias on MLP projections
+    - blocks.{i}.ln1.b, blocks.{i}.ln2.b, ln_final.b - RMSNorm has no bias
+    """
+
+    def __init__(self, cfg: Any) -> None:
+        """Initialize the Granite architecture adapter."""
+        super().__init__(cfg)
+
+        self._setup_common_config(cfg)
+        n_kv_heads = self._get_n_kv_heads()
+        self.weight_processing_conversions = self._build_attn_weight_conversions(n_kv_heads)
+        self.component_mapping = self._build_component_mapping()
+
+    def _setup_common_config(self, cfg: Any) -> None:
+        """Set up config variables shared across all Granite variants."""
+        self.cfg.normalization_type = "RMS"
+        self.cfg.positional_embedding_type = "rotary"
+        self.cfg.final_rms = True
+        self.cfg.gated_mlp = True
+        self.cfg.attn_only = False
+        self.cfg.uses_rms_norm = True
+        self.cfg.eps_attr = "variance_epsilon"
+
+        self.default_config = {
+            "d_model": cfg.d_model,
+            "d_head": cfg.d_model // cfg.n_heads,
+            "n_heads": cfg.n_heads,
+            "n_layers": cfg.n_layers,
+            "d_vocab": cfg.d_vocab,
+        }
+
+        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
+            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
+            self.cfg.n_key_value_heads = cfg.n_key_value_heads
+
+    def _get_n_kv_heads(self) -> int:
+        """Get the number of key-value heads (for GQA or MHA)."""
+        if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None:
+            return self.cfg.n_key_value_heads
+        return self.cfg.n_heads
+
+    def _build_attn_weight_conversions(
+        self, n_kv_heads: int
+    ) -> Dict[str, ParamProcessingConversion | str]:
+        """Build weight processing conversions for attention projections."""
+        return {
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
+                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
+            ),
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
+                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
+            ),
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
+                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
+            ),
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
+                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
+            ),
+        }
+
+    def _build_attention_bridge(self) -> PositionEmbeddingsAttentionBridge:
+        """Build the standard Granite attention bridge."""
+        return PositionEmbeddingsAttentionBridge(
+            name="self_attn",
+            config=self.cfg,
+            submodules={
+                "q": LinearBridge(name="q_proj"),
+                "k": LinearBridge(name="k_proj"),
+                "v": LinearBridge(name="v_proj"),
+                "o": LinearBridge(name="o_proj"),
+            },
+            requires_attention_mask=True,
+            requires_position_embeddings=True,
+        )
+
+    def _build_mlp_bridge(self) -> GatedMLPBridge:
+        """Build the dense gated MLP bridge."""
+        return GatedMLPBridge(
+            name="mlp",
+            config=self.cfg,
+            submodules={
+                "gate": LinearBridge(name="gate_proj"),
+                "in": LinearBridge(name="up_proj"),
+                "out": LinearBridge(name="down_proj"),
+            },
+        )
+
+    def _build_component_mapping(self) -> dict:
+        """Build the full component mapping for dense Granite."""
+        return {
+            "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
+            "blocks": BlockBridge(
+                name="model.layers",
+                submodules={
+                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
+                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
+                    "attn": self._build_attention_bridge(),
+                    "mlp": self._build_mlp_bridge(),
+                },
+            ),
+            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
+            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
+        }
+
+    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
+        """Set up rotary embedding references for Granite component testing.
+
+        Args:
+            hf_model: The HuggingFace Granite model instance
+            bridge_model: The TransformerBridge model (if available)
+        """
+        if not hasattr(hf_model.model, "rotary_emb"):
+            return
+
+        rotary_emb = hf_model.model.rotary_emb
+
+        if bridge_model is not None and hasattr(bridge_model, "blocks"):
+            for block in bridge_model.blocks:
+                if hasattr(block, "attn"):
+                    block.attn.set_rotary_emb(rotary_emb)
+
+        try:
+            attn_bridge = self.get_generalized_component("blocks.0.attn")
+            attn_bridge.set_rotary_emb(rotary_emb)
+        except (AttributeError, KeyError):
+            pass
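
For reference, the "(n h) m -> n m h" conversions above split HF's fused head dimension into explicit per-head weights. A standalone einops demo with made-up sizes (n_heads=4, d_head=16, d_model=64):

import torch
from einops import rearrange

n_heads, d_head, d_model = 4, 16, 64
w_q = torch.randn(n_heads * d_head, d_model)  # HF q_proj.weight layout: [(n h), m]
w_q_split = rearrange(w_q, "(n h) m -> n m h", n=n_heads)
print(w_q_split.shape)  # torch.Size([4, 64, 16]): one [d_model, d_head] slice per head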

transformer_lens/model_bridge/supported_architectures/granite_moe.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+"""Granite MoE architecture adapter."""
+
+from transformer_lens.model_bridge.generalized_components import (
+    BlockBridge,
+    EmbeddingBridge,
+    MoEBridge,
+    RMSNormalizationBridge,
+    RotaryEmbeddingBridge,
+    UnembeddingBridge,
+)
+from transformer_lens.model_bridge.supported_architectures.granite import (
+    GraniteArchitectureAdapter,
+)
+
+
+class GraniteMoeArchitectureAdapter(GraniteArchitectureAdapter):
+    """Architecture adapter for IBM Granite MoE models.
+
+    Identical to dense Granite but replaces the gated MLP with a Sparse Mixture
+    of Experts block (block_sparse_moe) using batched expert parameters and
+    top-k routing.
+    """
+
+    def _build_component_mapping(self) -> dict:
+        """Build component mapping with MoE instead of dense MLP."""
+        return {
+            "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
+            "blocks": BlockBridge(
+                name="model.layers",
+                submodules={
+                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
+                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
+                    "attn": self._build_attention_bridge(),
+                    "mlp": MoEBridge(
+                        name="block_sparse_moe",
+                        config=self.cfg,
+                    ),
+                },
+            ),
+            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
+            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
+        }

transformer_lens/model_bridge/supported_architectures/granite_moe_hybrid.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+"""Granite MoE Hybrid architecture adapter.
+
+GraniteMoeHybridForCausalLM is a hybrid Mamba + Attention architecture with
+Sparse Mixture of Experts. Layers alternate between Mamba SSM blocks and
+standard attention blocks, with a shared MLP and optional sparse MoE on
+every layer.
+
+Since self_attn is None on Mamba layers and mamba is None on attention
+layers, we only map submodules that exist on ALL layers (norms, shared_mlp,
+block_sparse_moe). The HF native forward handles mamba/attention dispatch.
+"""
+
+from typing import Any
+
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.generalized_components import (
+    BlockBridge,
+    EmbeddingBridge,
+    LinearBridge,
+    MLPBridge,
+    MoEBridge,
+    RMSNormalizationBridge,
+    RotaryEmbeddingBridge,
+    UnembeddingBridge,
+)
+from transformer_lens.model_bridge.supported_architectures.granite import (
+    GraniteArchitectureAdapter,
+)
+
+
+class GraniteMoeHybridArchitectureAdapter(GraniteArchitectureAdapter):
+    """Architecture adapter for IBM Granite MoE Hybrid models.
+
+    Hybrid Mamba2 + Attention architecture with Sparse MoE. Most layers are Mamba
+    SSM blocks; a few are standard attention (determined by config.layer_types).
+
+    Since self_attn is None on Mamba layers and mamba is None on attention layers,
+    we only map submodules present on ALL layers (norms, shared_mlp, MoE). The HF
+    native forward handles mamba/attention dispatch internally.
+
+    Hook coverage:
+    - Block-level: hook_resid_pre, hook_resid_post on every layer
+    - Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm)
+    - MLP: shared_mlp input/output hooks
+    - MoE: block_sparse_moe input/output and router_scores hooks
+    - Attention/Mamba internals are NOT individually hooked (conditional per layer)
+    """
+
+    def __init__(self, cfg: Any) -> None:
+        """Initialize the Granite MoE Hybrid architecture adapter."""
+        # Call ArchitectureAdapter.__init__ directly, not GraniteArchitectureAdapter.__init__,
+        # because we need to customize the setup sequence
+        ArchitectureAdapter.__init__(self, cfg)
+
+        self._setup_common_config(cfg)
+
+        # Hybrid may use "rope" or "nope" (no positional embeddings)
+        pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
+        if pos_emb_type != "rope":
+            self.cfg.positional_embedding_type = "none"
+
+        # No attention weight conversions — attn Q/K/V aren't mapped as submodules
+        self.weight_processing_conversions = {}
+        self.component_mapping = self._build_component_mapping()
+
+    def _build_component_mapping(self) -> dict:
+        """Build component mapping with only universal (all-layer) submodules."""
+        block_submodules = {
+            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
+            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
+            "shared_mlp": MLPBridge(
+                name="shared_mlp",
+                config=self.cfg,
+                submodules={
+                    "in": LinearBridge(name="input_linear"),
+                    "out": LinearBridge(name="output_linear"),
+                },
+            ),
+        }
+
+        num_experts = getattr(self.cfg, "num_experts", None) or getattr(
+            self.cfg, "num_local_experts", 0
+        )
+        if num_experts and num_experts > 0:
+            block_submodules["moe"] = MoEBridge(
+                name="block_sparse_moe",
+                config=self.cfg,
+            )
+
+        mapping = {
+            "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "blocks": BlockBridge(
+                name="model.layers",
+                submodules=block_submodules,
+            ),
+            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
+            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
+        }
+
+        if self.cfg.positional_embedding_type == "rotary":
+            mapping["rotary_emb"] = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)
+
+        return mapping
+
+    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
+        """No-op for hybrid models.
+
+        Hybrid models don't map attention as a submodule (it's conditional per
+        layer), so there are no rotary embedding references to set up.
+        """
