Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions demos/Colab_Compatibility.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,8 @@
" \"stabilityai/stablelm-base-alpha-7b\",\n",
" \"stabilityai/stablelm-tuned-alpha-3b\",\n",
" \"stabilityai/stablelm-tuned-alpha-7b\",\n",
" \"swiss-ai/Apertus-8B-2509\",\n",
" \"swiss-ai/Apertus-8B-Instruct-2509\",\n",
"]\n",
"\n",
"if IN_COLAB:\n",
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_xielu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import torch
import torch.nn as nn

from transformer_lens.utils import ACTIVATION_FN_DICT, XIELU, xielu


class TestXIELUClass:
    """Unit tests for the trainable XIELU activation module."""

    def test_default_parameters(self):
        """Constructing with no arguments yields the canonical xIELU defaults."""
        layer = XIELU()
        for name, value in [("alpha_p", 0.8), ("alpha_n", 0.8), ("beta", 0.5)]:
            assert getattr(layer, name).item() == pytest.approx(value)

    def test_custom_parameters(self):
        """Init keyword arguments are stored in the matching parameters."""
        layer = XIELU(alpha_p_init=1.0, alpha_n_init=0.5, beta_init=0.25)
        assert layer.alpha_p.item() == pytest.approx(1.0)
        assert layer.alpha_n.item() == pytest.approx(0.5)
        assert layer.beta.item() == pytest.approx(0.25)

    def test_parameters_are_trainable(self):
        """All three scalars are registered as trainable nn.Parameters."""
        layer = XIELU()
        for param in (layer.alpha_p, layer.alpha_n, layer.beta):
            assert isinstance(param, nn.Parameter)
            assert param.requires_grad

    def test_output_shape_preserved(self):
        """The activation is elementwise, so the input shape is unchanged."""
        layer = XIELU()
        batch = torch.randn(2, 10, 32)
        assert layer(batch).shape == batch.shape

    def test_positive_branch(self):
        """For x > 0: f(x) = alpha_p * x^2 + beta * x."""
        layer = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0)
        positive = torch.tensor([[[1.0, 2.0, 3.0]]])  # shape (1, 1, 3)
        reference = 1.0 * positive**2 + 1.0 * positive  # alpha_p * x^2 + beta * x
        torch.testing.assert_close(layer(positive), reference)

    def test_negative_branch(self):
        """For x <= 0: f(x) = alpha_n * expm1(clamp(x, eps)) - alpha_n * x + beta * x."""
        alpha_n, beta, eps = 1.0, 1.0, -1e-6
        layer = XIELU(alpha_p_init=1.0, alpha_n_init=alpha_n, beta_init=beta, eps=eps)
        negative = torch.tensor([[[-1.0, -2.0, -3.0]]])  # shape (1, 1, 3)
        reference = (
            alpha_n * torch.expm1(torch.clamp_max(negative, eps))
            - alpha_n * negative
            + beta * negative
        )
        torch.testing.assert_close(layer(negative), reference)

    def test_zero_input(self):
        """x = 0 is handled by the negative branch (the condition is x <= 0)."""
        layer = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0, eps=-1e-6)
        zeros = torch.tensor([[[0.0]]])  # shape (1, 1, 1)
        reference = 1.0 * torch.expm1(torch.clamp_max(zeros, -1e-6)) - 1.0 * zeros + 1.0 * zeros
        torch.testing.assert_close(layer(zeros), reference)

    def test_gradients_flow_through(self):
        """Backprop reaches both the input and all three activation parameters."""
        layer = XIELU()
        batch = torch.randn(4, 8, 16, requires_grad=True)
        layer(batch).sum().backward()
        assert batch.grad is not None
        for param in (layer.alpha_p, layer.alpha_n, layer.beta):
            assert param.grad is not None

    def test_class_and_function_agree_at_defaults(self):
        """XIELU() with default parameters matches the fixed-constant xielu function."""
        layer = XIELU()  # defaults match xielu() fixed values
        batch = torch.randn(2, 5, 16)
        torch.testing.assert_close(layer(batch), xielu(batch))


class TestXIELUFunction:
    """Unit tests for the stateless, fixed-parameter xielu activation function."""

    def test_output_shape_preserved(self):
        """The function is elementwise, so the input shape is unchanged."""
        batch = torch.randn(2, 10, 32)
        assert xielu(batch).shape == batch.shape

    def test_positive_values(self):
        """For x > 0: f(x) = 0.8*x^2 + 0.5*x."""
        positive = torch.tensor([[[1.0, 2.0]]])  # shape (1, 1, 2)
        reference = 0.8 * positive**2 + 0.5 * positive
        torch.testing.assert_close(xielu(positive), reference)

    def test_negative_values(self):
        """For x <= 0: f(x) = 0.8*(exp(clamp(x,-1e-6))-1) - 0.8*x + 0.5*x."""
        negative = torch.tensor([[[-1.0, -2.0]]])  # shape (1, 1, 2)
        reference = (
            0.8 * torch.expm1(torch.clamp_max(negative, -1e-6))
            - 0.8 * negative
            + 0.5 * negative
        )
        torch.testing.assert_close(xielu(negative), reference)

    def test_registered_in_activation_dict(self):
        """The fixed-parameter function is exposed under the "xielu" key."""
        assert "xielu" in ACTIVATION_FN_DICT
        assert ACTIVATION_FN_DICT["xielu"] is xielu
6 changes: 5 additions & 1 deletion transformer_lens/components/mlps/can_be_used_as_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.activation_functions import ActivationFunction
from transformer_lens.utils import XIELU


class CanBeUsedAsMLP(nn.Module):
Expand Down Expand Up @@ -66,7 +67,10 @@ def select_activation_function(self) -> None:
ValueError: If the configure activation function is not supported.
"""

self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)
if self.cfg.act_fn == "xielu":
self.act_fn = XIELU()
else:
self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

if self.cfg.is_layer_norm_activation():
self.hook_mid = HookPoint()
Expand Down
46 changes: 46 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions import (
convert_apertus_weights,
convert_bert_weights,
convert_bloom_weights,
convert_coder_weights,
Expand Down Expand Up @@ -248,6 +249,8 @@
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/phi-4",
"swiss-ai/Apertus-8B-2509",
"swiss-ai/Apertus-8B-Instruct-2509",
"google/gemma-2b",
"google/gemma-7b",
"google/gemma-2b-it",
Expand Down Expand Up @@ -720,6 +723,8 @@
"microsoft/phi-2": ["phi-2"],
"microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
"microsoft/phi-4": ["phi-4"],
"swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"],
"swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"],
"google/gemma-2b": ["gemma-2b"],
"google/gemma-7b": ["gemma-7b"],
"google/gemma-2b-it": ["gemma-2b-it"],
Expand Down Expand Up @@ -775,6 +780,7 @@
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/phi-4",
"swiss-ai/Apertus-",
)


Expand Down Expand Up @@ -1506,6 +1512,44 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"parallel_attn_mlp": False,
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
}
elif architecture == "ApertusForCausalLM":
n_heads = hf_config.num_attention_heads
d_head = hf_config.hidden_size // n_heads
num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": d_head,
"n_heads": n_heads,
"n_key_value_heads": n_kv_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_dim": d_head,
"rotary_base": getattr(hf_config, "rope_theta", None),
"gated_mlp": False,
"final_rms": True,
"use_qk_norm": getattr(hf_config, "qk_norm", False),
}
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling:
rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
else:
rope_type = ""
if rope_type == "llama3":
assert rope_scaling is not None
cfg_dict["use_NTK_by_parts_rope"] = True
cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
"original_max_position_embeddings", hf_config.max_position_embeddings
)
cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)

elif official_model_name.startswith("google/gemma-3-270m"):
# Architecture for Gemma-3 270m and Gemma-3 270m Instruct models
Expand Down Expand Up @@ -2431,6 +2475,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "ApertusForCausalLM":
state_dict = convert_apertus_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma3ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
Expand Down
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .apertus import convert_apertus_weights
from .openai import convert_gpt_oss_weights
132 changes: 132 additions & 0 deletions transformer_lens/pretrained/weight_conversions/apertus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Apertus is a Llama-like model architecture from Swiss AI.
This module converts Apertus weights to the standardized format used by HookedTransformer.
"""

from typing import cast

import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def _extract_xielu_params(mlp):
    """Locate the trainable xIELU parameters on an Apertus MLP block.

    Different transformers versions attach the activation module under
    different attribute names, so probe ``mlp.act_fn`` first, then
    ``mlp.act``, then fall back to the MLP module itself.

    Raises:
        AttributeError: If the chosen holder does not expose
            ``alpha_p`` / ``alpha_n`` / ``beta``.
    """
    if hasattr(mlp, "act_fn"):
        holder = mlp.act_fn
    elif hasattr(mlp, "act"):
        holder = mlp.act
    else:
        holder = mlp
    return holder.alpha_p, holder.alpha_n, holder.beta


def convert_apertus_weights(apertus, cfg: HookedTransformerConfig):
    """Convert a HuggingFace Apertus model's weights into HookedTransformer format.

    Args:
        apertus: The loaded HuggingFace ``ApertusForCausalLM`` model (or any
            object exposing the same module tree).
        cfg: The HookedTransformerConfig describing the target model.

    Returns:
        Dict mapping HookedTransformer parameter names to tensors.
    """
    state_dict = {}

    state_dict["embed.W_E"] = apertus.model.embed_tokens.weight

    # Grouped-query attention: K/V parameter names get a leading underscore.
    using_gqa = cfg.n_key_value_heads is not None
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        layer = apertus.model.layers[l]

        state_dict[f"blocks.{l}.ln1.w"] = layer.attention_layernorm.weight
        state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        W_Q = layer.self_attn.q_proj.weight
        W_K = layer.self_attn.k_proj.weight
        W_V = layer.self_attn.v_proj.weight

        # In case of quantization, parameters should stay as
        # bitsandbytes.nn.modules.Params4bit, so skip the reshape.
        if not cfg.load_in_4bit:
            W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
            W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
            W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # Apertus attention projections carry no biases; register zeros so the
        # HookedTransformer parameter set is complete.
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )

        W_O = layer.self_attn.o_proj.weight
        if not cfg.load_in_4bit:
            W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = layer.feedforward_layernorm.weight
        state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # In case of quantization, keep the raw Params4bit without transposing.
        if not cfg.load_in_4bit:
            state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
            state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        else:
            state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight

        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # Transfer the trainable xIELU activation parameters; fall back to the
        # fixed defaults (alpha_p=0.8, alpha_n=0.8, beta=0.5) when absent.
        try:
            alpha_p, alpha_n, beta = _extract_xielu_params(layer.mlp)
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta
        except AttributeError:
            print(f"Activation parameters not found in layer {l}, using defaults")
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor(
                0.5, dtype=cfg.dtype, device=cfg.device
            )

    state_dict["ln_final.w"] = apertus.model.norm.weight
    state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device)

    state_dict["unembed.W_U"] = apertus.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict
3 changes: 2 additions & 1 deletion transformer_lens/utilities/activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F

from transformer_lens.utils import gelu_fast, gelu_new, solu
from transformer_lens.utils import gelu_fast, gelu_new, solu, xielu

# Convenient type for the format of each activation function
ActivationFunction = Callable[..., torch.Tensor]
Expand All @@ -24,4 +24,5 @@
"relu": F.relu,
"gelu": F.gelu,
"gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
"xielu": xielu,
}
Loading
Loading