diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 46b42c2ba..5a54f6417 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -996,6 +996,8 @@ " \"stabilityai/stablelm-base-alpha-7b\",\n", " \"stabilityai/stablelm-tuned-alpha-3b\",\n", " \"stabilityai/stablelm-tuned-alpha-7b\",\n", + " \"swiss-ai/Apertus-8B-2509\",\n", + " \"swiss-ai/Apertus-8B-Instruct-2509\",\n", "]\n", "\n", "if IN_COLAB:\n", diff --git a/tests/unit/test_xielu.py b/tests/unit/test_xielu.py new file mode 100644 index 000000000..866dce1b0 --- /dev/null +++ b/tests/unit/test_xielu.py @@ -0,0 +1,93 @@ +import pytest +import torch +import torch.nn as nn + +from transformer_lens.utils import ACTIVATION_FN_DICT, XIELU, xielu + + +class TestXIELUClass: + def test_default_parameters(self): + act = XIELU() + assert act.alpha_p.item() == pytest.approx(0.8) + assert act.alpha_n.item() == pytest.approx(0.8) + assert act.beta.item() == pytest.approx(0.5) + + def test_custom_parameters(self): + act = XIELU(alpha_p_init=1.0, alpha_n_init=0.5, beta_init=0.25) + assert act.alpha_p.item() == pytest.approx(1.0) + assert act.alpha_n.item() == pytest.approx(0.5) + assert act.beta.item() == pytest.approx(0.25) + + def test_parameters_are_trainable(self): + act = XIELU() + assert isinstance(act.alpha_p, nn.Parameter) + assert isinstance(act.alpha_n, nn.Parameter) + assert isinstance(act.beta, nn.Parameter) + assert act.alpha_p.requires_grad + assert act.alpha_n.requires_grad + assert act.beta.requires_grad + + def test_output_shape_preserved(self): + act = XIELU() + x = torch.randn(2, 10, 32) + assert act(x).shape == x.shape + + def test_positive_branch(self): + """For x > 0: f(x) = alpha_p * x^2 + beta * x.""" + act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0) + x = torch.tensor([[[1.0, 2.0, 3.0]]]) # (1, 1, 3) + expected = 1.0 * x**2 + 1.0 * x # alpha_p * x^2 + beta * x + torch.testing.assert_close(act(x), expected) + + def test_negative_branch(self): + """For x <= 0: f(x) = alpha_n * (exp(clamp(x, eps)) - 1) - alpha_n * x + beta * x.""" + alpha_n, beta, eps = 1.0, 1.0, -1e-6 + act = XIELU(alpha_p_init=1.0, alpha_n_init=alpha_n, beta_init=beta, eps=eps) + x = torch.tensor([[[-1.0, -2.0, -3.0]]]) # (1, 1, 3) + expected = alpha_n * torch.expm1(torch.clamp_max(x, eps)) - alpha_n * x + beta * x + torch.testing.assert_close(act(x), expected) + + def test_zero_input(self): + """x = 0 falls into the negative branch.""" + act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0, eps=-1e-6) + x = torch.tensor([[[0.0]]]) # (1, 1, 1) + expected = 1.0 * torch.expm1(torch.clamp_max(x, -1e-6)) - 1.0 * x + 1.0 * x + torch.testing.assert_close(act(x), expected) + + def test_gradients_flow_through(self): + act = XIELU() + x = torch.randn(4, 8, 16, requires_grad=True) + out = act(x).sum() + out.backward() + assert x.grad is not None + assert act.alpha_p.grad is not None + assert act.alpha_n.grad is not None + assert act.beta.grad is not None + + def test_class_and_function_agree_at_defaults(self): + """XIELU class with default params should match the fixed xielu function.""" + act = XIELU() # defaults match xielu() fixed values + x = torch.randn(2, 5, 16) + torch.testing.assert_close(act(x), xielu(x)) + + +class TestXIELUFunction: + def test_output_shape_preserved(self): + x = torch.randn(2, 10, 32) + assert xielu(x).shape == x.shape + + def test_positive_values(self): + """For x > 0: f(x) = 0.8*x^2 + 0.5*x.""" + x = torch.tensor([[[1.0, 2.0]]]) # (1, 1, 2) + expected = 0.8 * x**2 + 0.5 * x + torch.testing.assert_close(xielu(x), expected) + + def test_negative_values(self): + """For x <= 0: f(x) = 0.8*(exp(clamp(x,-1e-6))-1) - 0.8*x + 0.5*x.""" + x = torch.tensor([[[-1.0, -2.0]]]) # (1, 1, 2) + expected = 0.8 * torch.expm1(torch.clamp_max(x, -1e-6)) - 0.8 * x + 0.5 * x + torch.testing.assert_close(xielu(x), expected) + + def test_registered_in_activation_dict(self): + assert "xielu" in ACTIVATION_FN_DICT + assert ACTIVATION_FN_DICT["xielu"] is xielu diff --git a/transformer_lens/components/mlps/can_be_used_as_mlp.py b/transformer_lens/components/mlps/can_be_used_as_mlp.py index bf51a3a28..fc4599918 100644 --- a/transformer_lens/components/mlps/can_be_used_as_mlp.py +++ b/transformer_lens/components/mlps/can_be_used_as_mlp.py @@ -18,6 +18,7 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities.activation_functions import ActivationFunction +from transformer_lens.utils import XIELU class CanBeUsedAsMLP(nn.Module): @@ -66,7 +67,10 @@ def select_activation_function(self) -> None: ValueError: If the configure activation function is not supported. """ - self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) + if self.cfg.act_fn == "xielu": + self.act_fn = XIELU() + else: + self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) if self.cfg.is_layer_norm_activation(): self.hook_mid = HookPoint() diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index bc1d8fe82..461a6d0a9 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -24,6 +24,7 @@ import transformer_lens.utils as utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( + convert_apertus_weights, convert_bert_weights, convert_bloom_weights, convert_coder_weights, @@ -248,6 +249,8 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-8B-2509", + "swiss-ai/Apertus-8B-Instruct-2509", "google/gemma-2b", "google/gemma-7b", "google/gemma-2b-it", @@ -720,6 +723,8 @@ "microsoft/phi-2": ["phi-2"], "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], "microsoft/phi-4": ["phi-4"], + "swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"], + "swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"], "google/gemma-2b": ["gemma-2b"], "google/gemma-7b": ["gemma-7b"], "google/gemma-2b-it": ["gemma-2b-it"], @@ -775,6 +780,7 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-", ) @@ -1506,6 +1512,44 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } + elif architecture == "ApertusForCausalLM": + n_heads = hf_config.num_attention_heads + d_head = hf_config.hidden_size // n_heads + num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) + n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": d_head, + "n_heads": n_heads, + "n_key_value_heads": n_kv_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_dim": d_head, + "rotary_base": getattr(hf_config, "rope_theta", None), + "gated_mlp": False, + "final_rms": True, + "use_qk_norm": getattr(hf_config, "qk_norm", False), + } + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling: + rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() + else: + rope_type = "" + if rope_type == "llama3": + assert rope_scaling is not None + cfg_dict["use_NTK_by_parts_rope"] = True + cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( + "original_max_position_embeddings", hf_config.max_position_embeddings + ) + cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) + cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) + cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) elif official_model_name.startswith("google/gemma-3-270m"): # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models @@ -2431,6 +2475,8 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture == "ApertusForCausalLM": + state_dict = convert_apertus_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma3ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma3ForConditionalGeneration": diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index b58cce705..b8d940f62 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,4 +19,5 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .apertus import convert_apertus_weights from .openai import convert_gpt_oss_weights diff --git a/transformer_lens/pretrained/weight_conversions/apertus.py b/transformer_lens/pretrained/weight_conversions/apertus.py new file mode 100644 index 000000000..9c76f8e87 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/apertus.py @@ -0,0 +1,132 @@ +""" +Apertus is Llama like model architecture from Swiss AI. +convert weights to standardized format for HookedTransformer +""" + +from typing import cast + +import einops +import torch + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): + state_dict = {} + + state_dict["embed.W_E"] = apertus.model.embed_tokens.weight + + using_gqa = cfg.n_key_value_heads is not None + gqa_uscore = "_" if using_gqa else "" + + n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) + + assert cfg.d_mlp is not None # keep mypy happy + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = apertus.model.layers[l].attention_layernorm.weight + state_dict[f"blocks.{l}.ln1.b"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + W_Q = apertus.model.layers[l].self_attn.q_proj.weight + W_K = apertus.model.layers[l].self_attn.k_proj.weight + W_V = apertus.model.layers[l].self_attn.v_proj.weight + + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) + + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( + n_kv_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, + ) + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( + n_kv_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, + ) + + W_O = apertus.model.layers[l].self_attn.o_proj.weight + + if not cfg.load_in_4bit: + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + + state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + state_dict[f"blocks.{l}.ln2.w"] = apertus.model.layers[l].feedforward_layernorm.weight + state_dict[f"blocks.{l}.ln2.b"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight.T + else: + state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight + state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight + + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( + cfg.d_mlp, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + # Extract trainable activation parameters + mlp = apertus.model.layers[l].mlp + try: + if hasattr(mlp, "act_fn"): + alpha_p = mlp.act_fn.alpha_p + alpha_n = mlp.act_fn.alpha_n + beta = mlp.act_fn.beta + elif hasattr(mlp, "act"): + alpha_p = mlp.act.alpha_p + alpha_n = mlp.act.alpha_n + beta = mlp.act.beta + else: + alpha_p = mlp.alpha_p + alpha_n = mlp.alpha_n + beta = mlp.beta + state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p + state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n + state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta + except AttributeError: + # If parameters not found, use defaults + print(f"Activation parameters not found in layer {l}, using defaults") + state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor( + 0.8, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor( + 0.8, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor( + 0.5, dtype=cfg.dtype, device=cfg.device + ) + + state_dict["ln_final.w"] = apertus.model.norm.weight + state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + + state_dict["unembed.W_U"] = apertus.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) + + return state_dict diff --git a/transformer_lens/utilities/activation_functions.py b/transformer_lens/utilities/activation_functions.py index 4cecf3988..283f3b899 100644 --- a/transformer_lens/utilities/activation_functions.py +++ b/transformer_lens/utilities/activation_functions.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from transformer_lens.utils import gelu_fast, gelu_new, solu +from transformer_lens.utils import gelu_fast, gelu_new, solu, xielu # Convenient type for the format of each activation function ActivationFunction = Callable[..., torch.Tensor] @@ -24,4 +24,5 @@ "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"), + "xielu": xielu, } diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 62a6e474e..697faab39 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -220,6 +220,74 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, " return input * F.softmax(input, dim=-1) +class XIELU(nn.Module): + """ + Trainable xIELU activation function as described by + https://arxiv.org/abs/2411.13010 + + Defined as: + f(x) = { + α_p * x² + β * x, if x > 0 + α_n * (exp(min(x, ε)) - 1) - α_n * x + β * x, if x ≤ 0 + } + where α_p, α_n, β are trainable parameters. + """ + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta_init: float = 0.5, + eps: float = -1e-6, + ): + super().__init__() + self.alpha_p = nn.Parameter(torch.tensor(alpha_p_init, dtype=torch.float32)) + self.alpha_n = nn.Parameter(torch.tensor(alpha_n_init, dtype=torch.float32)) + self.beta = nn.Parameter(torch.tensor(beta_init, dtype=torch.float32)) + self.eps = eps + + def forward( + self, input: Float[torch.Tensor, "batch pos d_mlp"] + ) -> Float[torch.Tensor, "batch pos d_mlp"]: + return torch.where( + input > 0, + self.alpha_p * input**2 + self.beta * input, + self.alpha_n * torch.expm1(torch.clamp_max(input, self.eps)) + - self.alpha_n * input + + self.beta * input, + ) + + +def xielu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: + """ + xIELU activation function as described by + https://arxiv.org/abs/2411.13010 + + and original code in: + https://github.com/rubber-duck-debug/xielu + + Defined as + + f(x) = { + α_p * x² + β * x, if x > 0 + α_n * (exp(min(x, ε)) - 1) - α_n * x + β * x, if x ≤ 0 + } + + in this function the values are FIXED. However, the script can_be_used_as_mlp.py correctly used the XIELU class with trainable parameters, so the parameters can be trained if desired. + """ + alpha_p: float = 0.8 + alpha_n: float = 0.8 + beta: float = 0.5 + eps: float = -1e-6 + + # The core calculation logic: + return torch.where( + input > 0, + alpha_p * input * input + beta * input, + alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input, + ) + + ACTIVATION_FN_DICT = { "solu": solu, "solu_ln": solu, @@ -229,6 +297,7 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, " "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": gelu_pytorch_tanh, + "xielu": xielu, }