From ddcb273e9c35fe6e03e770141bfe36afabb135ca Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 10 Mar 2026 16:43:18 +0100 Subject: [PATCH 1/9] Add Apertus model support with XIeLU activation --- .../components/mlps/can_be_used_as_mlp.py | 6 +- transformer_lens/loading_from_pretrained.py | 45 +++++++ .../pretrained/weight_conversions/__init__.py | 1 + .../pretrained/weight_conversions/apertus.py | 123 ++++++++++++++++++ .../utilities/activation_functions.py | 3 +- transformer_lens/utils.py | 59 +++++++++ 6 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 transformer_lens/pretrained/weight_conversions/apertus.py diff --git a/transformer_lens/components/mlps/can_be_used_as_mlp.py b/transformer_lens/components/mlps/can_be_used_as_mlp.py index b0945276b..6ff660e0c 100644 --- a/transformer_lens/components/mlps/can_be_used_as_mlp.py +++ b/transformer_lens/components/mlps/can_be_used_as_mlp.py @@ -17,6 +17,7 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities.activation_functions import ActivationFunction +from transformer_lens.utils import XIELU class CanBeUsedAsMLP(nn.Module): @@ -65,7 +66,10 @@ def select_activation_function(self) -> None: ValueError: If the configure activation function is not supported. """ - self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) + if self.cfg.act_fn == "xielu": + self.act_fn = XIELU() + else: + self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) if self.cfg.is_layer_norm_activation(): self.hook_mid = HookPoint() diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..0a54c3304 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -24,6 +24,7 @@ import transformer_lens.utils as utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( + convert_apertus_weights, convert_bert_weights, convert_bloom_weights, convert_coder_weights, @@ -245,6 +246,8 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-8B-2509", + "swiss-ai/Apertus-8B-Instruct-2509", "google/gemma-2b", "google/gemma-7b", "google/gemma-2b-it", @@ -701,6 +704,8 @@ "microsoft/phi-2": ["phi-2"], "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], "microsoft/phi-4": ["phi-4"], + "swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"], + "swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"], "google/gemma-2b": ["gemma-2b"], "google/gemma-7b": ["gemma-7b"], "google/gemma-2b-it": ["gemma-2b-it"], @@ -742,6 +747,7 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-", ) @@ -1436,6 +1442,43 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } + elif architecture == "ApertusForCausalLM": + n_heads = hf_config.num_attention_heads + d_head = hf_config.hidden_size // n_heads + num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) + n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": d_head, + "n_heads": n_heads, + "n_key_value_heads": n_kv_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_dim": d_head, + "rotary_base": getattr(hf_config, "rope_theta", None), + "gated_mlp": False, + "final_rms": True, + "use_qk_norm": getattr(hf_config, "qk_norm", False), + } + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling: + rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() + else: + rope_type = "" + if rope_type == "llama3": + cfg_dict["use_NTK_by_parts_rope"] = True + cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( + "original_max_position_embeddings", hf_config.max_position_embeddings + ) + cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) + cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) + cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) elif official_model_name.startswith("google/gemma-2b"): # Architecture for Gemma 2b and Gemma 2b Instruct models @@ -1986,6 +2029,8 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture == "ApertusForCausalLM": + state_dict = convert_apertus_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..b0defcf4c 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .apertus import convert_apertus_weights diff --git a/transformer_lens/pretrained/weight_conversions/apertus.py b/transformer_lens/pretrained/weight_conversions/apertus.py new file mode 100644 index 000000000..739a84d83 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/apertus.py @@ -0,0 +1,123 @@ +""" +Apertus is Llama like model architecture from Swiss AI. +convert weights to standardized format for HookedTransformer +""" + +from typing import cast + +import einops +import torch + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): + state_dict = {} + + state_dict["embed.W_E"] = apertus.model.embed_tokens.weight + + using_gqa = cfg.n_key_value_heads is not None + gqa_uscore = "_" if using_gqa else "" + + n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) + + + assert cfg.d_mlp is not None # keep mypy happy + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = apertus.model.layers[l].attention_layernorm.weight + state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + + W_Q = apertus.model.layers[l].self_attn.q_proj.weight + W_K = apertus.model.layers[l].self_attn.k_proj.weight + W_V = apertus.model.layers[l].self_attn.v_proj.weight + + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) + + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( + n_kv_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, + ) + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( + n_kv_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, + ) + + W_O = apertus.model.layers[l].self_attn.o_proj.weight + + if not cfg.load_in_4bit: + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + + state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + state_dict[f"blocks.{l}.ln2.w"] = apertus.model.layers[l].feedforward_layernorm.weight + state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight.T + else: + state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight + state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight + + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( + cfg.d_mlp, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) + + # Extract trainable activation parameters + mlp = apertus.model.layers[l].mlp + try: + if hasattr(mlp, 'act_fn'): + alpha_p = mlp.act_fn.alpha_p + alpha_n = mlp.act_fn.alpha_n + beta = mlp.act_fn.beta + elif hasattr(mlp, 'act'): + alpha_p = mlp.act.alpha_p + alpha_n = mlp.act.alpha_n + beta = mlp.act.beta + else: + alpha_p = mlp.alpha_p + alpha_n = mlp.alpha_n + beta = mlp.beta + state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p + state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n + state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta + except AttributeError: + # If parameters not found, use defaults + print(f"Activation parameters not found in layer {l}, using defaults") + state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor(0.8, dtype=cfg.dtype, device=cfg.device) + state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor(0.8, dtype=cfg.dtype, device=cfg.device) + state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor(0.5, dtype=cfg.dtype, device=cfg.device) + + state_dict["ln_final.w"] = apertus.model.norm.weight + state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + + state_dict["unembed.W_U"] = apertus.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) + + return state_dict diff --git a/transformer_lens/utilities/activation_functions.py b/transformer_lens/utilities/activation_functions.py index 6cc701360..9cfa6eeb9 100644 --- a/transformer_lens/utilities/activation_functions.py +++ b/transformer_lens/utilities/activation_functions.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F -from transformer_lens.utils import gelu_fast, gelu_new, solu +from transformer_lens.utils import gelu_fast, gelu_new, solu, xielu # Convenient type for the format of each activation function ActivationFunction = Callable[..., torch.Tensor] @@ -23,4 +23,5 @@ "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"), + "xielu": xielu, } diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c0992848a..67c5c6652 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -30,6 +30,7 @@ from transformer_lens.FactoredMatrix import FactoredMatrix + CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE USE_DEFAULT_VALUE = None @@ -203,6 +204,63 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, " return input * F.softmax(input, dim=-1) +class XIELU(nn.Module): + """ + Trainable xIELU activation function as described by + https://arxiv.org/abs/2411.13010 + + Defined as: + f(x) = { + α_p * x² + β * x, if x > 0 + α_n * (exp(min(x, ε)) - 1) - α_n * x + β * x, if x ≤ 0 + } + where α_p, α_n, β are trainable parameters. + """ + def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta_init: float = 0.5, eps: float = -1e-6): + super().__init__() + self.alpha_p = nn.Parameter(torch.tensor(alpha_p_init, dtype=torch.float32)) + self.alpha_n = nn.Parameter(torch.tensor(alpha_n_init, dtype=torch.float32)) + self.beta = nn.Parameter(torch.tensor(beta_init, dtype=torch.float32)) + self.eps = eps + + def forward(self, input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: + return torch.where( + input > 0, + self.alpha_p * input ** 2 + self.beta * input, + self.alpha_n * torch.expm1(torch.clamp_max(input, self.eps)) - self.alpha_n * input + self.beta * input + ) + + +def xielu( + input: Float[torch.Tensor, "batch pos d_mlp"] +) -> Float[torch.Tensor, "batch pos d_mlp"]: + """ + xIELU activation function as described by + https://arxiv.org/abs/2411.13010 + + and original code in: + https://github.com/rubber-duck-debug/xielu + + Defined as + + f(x) = { + α_p * x² + β * x, if x > 0 + α_n * (exp(min(x, ε)) - 1) - α_n * x + β * x, if x ≤ 0 + } + + in this function the values are FIXED. However, the script can_be_used_as_mlp.py correctly used the XIELU class with trainable parameters, so the parameters can be trained if desired. + """ + alpha_p: float = 0.8 + alpha_n: float = 0.8 + beta: float = 0.5 + eps: float = -1e-6 + + # The core calculation logic: + return torch.where(input > 0, + alpha_p * input * input + beta * input, + alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input) + + ACTIVATION_FN_DICT = { "solu": solu, "solu_ln": solu, @@ -212,6 +270,7 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, " "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": gelu_pytorch_tanh, + "xielu": xielu, } From dc70868fcf8e2a1c2929d65b9a57cb8bba69514f Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 16 Mar 2026 09:10:46 -0500 Subject: [PATCH 2/9] Updating Interactive Neuroscope, CI to properly install demo --- .github/workflows/checks.yml | 2 +- demos/Interactive_Neuroscope.ipynb | 82 ++++++------------------------ 2 files changed, 17 insertions(+), 67 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 8dd414960..f8f069035 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -174,7 +174,7 @@ jobs: - name: Install dependencies run: | poetry check --lock - poetry install --with dev,jupyter + poetry install --with dev,jupyter,demo - name: Install pandoc uses: awalsh128/cache-apt-pkgs-action@latest with: diff --git a/demos/Interactive_Neuroscope.ipynb b/demos/Interactive_Neuroscope.ipynb index bc288411a..d4bffb6a6 100644 --- a/demos/Interactive_Neuroscope.ipynb +++ b/demos/Interactive_Neuroscope.ipynb @@ -101,18 +101,18 @@ "evalue": "tokenizers>=0.21,<0.22 is required for a normal functioning of this module, but found tokenizers==0.22.2.\nTry: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main", "output_type": "error", "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mImportError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[8], line 3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;66;03m# NBVAL_IGNORE_OUTPUT\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mgradio\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mgr\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtransformer_lens\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m HookedTransformer\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtransformer_lens\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m to_numpy\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mIPython\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdisplay\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m HTML\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m hook_points\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m utils\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m evals\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/hook_points.py:20\u001B[0m\n\u001B[1;32m 17\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mhooks\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mhooks\u001B[39;00m\n\u001B[1;32m 18\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m Tensor\n\u001B[0;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtransformer_lens\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m Slice, SliceInput\n\u001B[1;32m 23\u001B[0m \u001B[38;5;129m@dataclass\u001B[39m\n\u001B[1;32m 24\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mLensHandle\u001B[39;00m:\n\u001B[1;32m 25\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Dataclass that holds information about a PyTorch hook.\"\"\"\u001B[39;00m\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/utils.py:24\u001B[0m\n\u001B[1;32m 22\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mnn\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mnn\u001B[39;00m\n\u001B[1;32m 23\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mnn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfunctional\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mF\u001B[39;00m\n\u001B[0;32m---> 24\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtransformers\u001B[39;00m\n\u001B[1;32m 25\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mdatasets\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01marrow_dataset\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m Dataset\n\u001B[1;32m 26\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mdatasets\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mload\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m load_dataset\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/__init__.py:27\u001B[0m\n\u001B[1;32m 24\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtyping\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m TYPE_CHECKING\n\u001B[1;32m 26\u001B[0m \u001B[38;5;66;03m# Check the dependencies satisfy the minimal versions required.\u001B[39;00m\n\u001B[0;32m---> 27\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m dependency_versions_check\n\u001B[1;32m 28\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m 29\u001B[0m OptionalDependencyNotAvailable,\n\u001B[1;32m 30\u001B[0m _LazyModule,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 36\u001B[0m is_pretty_midi_available,\n\u001B[1;32m 37\u001B[0m )\n\u001B[1;32m 39\u001B[0m \u001B[38;5;66;03m# Note: the following symbols are deliberately exported with `as`\u001B[39;00m\n\u001B[1;32m 40\u001B[0m \u001B[38;5;66;03m# so that mypy, pylint or other static linters can recognize them,\u001B[39;00m\n\u001B[1;32m 41\u001B[0m \u001B[38;5;66;03m# given that they are not exported using `__all__` in this file.\u001B[39;00m\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/dependency_versions_check.py:57\u001B[0m\n\u001B[1;32m 54\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m is_accelerate_available():\n\u001B[1;32m 55\u001B[0m \u001B[38;5;28;01mcontinue\u001B[39;00m \u001B[38;5;66;03m# not required, check version only if installed\u001B[39;00m\n\u001B[0;32m---> 57\u001B[0m \u001B[43mrequire_version_core\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdeps\u001B[49m\u001B[43m[\u001B[49m\u001B[43mpkg\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 59\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcan\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mt find \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpkg\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m in \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mdeps\u001B[38;5;241m.\u001B[39mkeys()\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, check dependency_versions_table.py\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:117\u001B[0m, in \u001B[0;36mrequire_version_core\u001B[0;34m(requirement)\u001B[0m\n\u001B[1;32m 115\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"require_version wrapper which emits a core-specific hint on failure\"\"\"\u001B[39;00m\n\u001B[1;32m 116\u001B[0m hint \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mTry: `pip install transformers -U` or `pip install -e \u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.[dev]\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m` if you\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mre working with git main\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m--> 117\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mrequire_version\u001B[49m\u001B[43m(\u001B[49m\u001B[43mrequirement\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhint\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:111\u001B[0m, in \u001B[0;36mrequire_version\u001B[0;34m(requirement, hint)\u001B[0m\n\u001B[1;32m 109\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m want_ver \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 110\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m op, want_ver \u001B[38;5;129;01min\u001B[39;00m wanted\u001B[38;5;241m.\u001B[39mitems():\n\u001B[0;32m--> 111\u001B[0m \u001B[43m_compare_versions\u001B[49m\u001B[43m(\u001B[49m\u001B[43mop\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgot_ver\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mwant_ver\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mrequirement\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpkg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhint\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:44\u001B[0m, in \u001B[0;36m_compare_versions\u001B[0;34m(op, got_ver, want_ver, requirement, pkg, hint)\u001B[0m\n\u001B[1;32m 39\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[1;32m 40\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mUnable to compare versions for \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mrequirement\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m: need=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mwant_ver\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m found=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mgot_ver\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m. This is unusual. Consider\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 41\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m reinstalling \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpkg\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 42\u001B[0m )\n\u001B[1;32m 43\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m ops[op](version\u001B[38;5;241m.\u001B[39mparse(got_ver), version\u001B[38;5;241m.\u001B[39mparse(want_ver)):\n\u001B[0;32m---> 44\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mImportError\u001B[39;00m(\n\u001B[1;32m 45\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mrequirement\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m is required for a normal functioning of this module, but found \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpkg\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m==\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mgot_ver\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mhint\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 46\u001B[0m )\n", - "\u001B[0;31mImportError\u001B[0m: tokenizers>=0.21,<0.22 is required for a normal functioning of this module, but found tokenizers==0.22.2.\nTry: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# NBVAL_IGNORE_OUTPUT\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mgradio\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mgr\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtransformer_lens\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m HookedTransformer\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtransformer_lens\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m to_numpy\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m HTML\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m hook_points\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m utils\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m evals\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/hook_points.py:20\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mhooks\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mhooks\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Tensor\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtransformer_lens\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Slice, SliceInput\n\u001b[1;32m 23\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mLensHandle\u001b[39;00m:\n\u001b[1;32m 25\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Dataclass that holds information about a PyTorch hook.\"\"\"\u001b[39;00m\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/transformer_lens/utils.py:24\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnn\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mF\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtransformers\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mdatasets\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01marrow_dataset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Dataset\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mdatasets\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mload\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m load_dataset\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/__init__.py:27\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TYPE_CHECKING\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# Check the dependencies satisfy the minimal versions required.\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m dependency_versions_check\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 29\u001b[0m OptionalDependencyNotAvailable,\n\u001b[1;32m 30\u001b[0m _LazyModule,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 36\u001b[0m is_pretty_midi_available,\n\u001b[1;32m 37\u001b[0m )\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Note: the following symbols are deliberately exported with `as`\u001b[39;00m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;66;03m# so that mypy, pylint or other static linters can recognize them,\u001b[39;00m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# given that they are not exported using `__all__` in this file.\u001b[39;00m\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/dependency_versions_check.py:57\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_accelerate_available():\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m \u001b[38;5;66;03m# not required, check version only if installed\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[43mrequire_version_core\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdeps\u001b[49m\u001b[43m[\u001b[49m\u001b[43mpkg\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpkg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdeps\u001b[38;5;241m.\u001b[39mkeys()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, check dependency_versions_table.py\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:117\u001b[0m, in \u001b[0;36mrequire_version_core\u001b[0;34m(requirement)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"require_version wrapper which emits a core-specific hint on failure\"\"\"\u001b[39;00m\n\u001b[1;32m 116\u001b[0m hint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTry: `pip install transformers -U` or `pip install -e \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.[dev]\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m` if you\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mre working with git main\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrequire_version\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequirement\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:111\u001b[0m, in \u001b[0;36mrequire_version\u001b[0;34m(requirement, hint)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m want_ver \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m op, want_ver \u001b[38;5;129;01min\u001b[39;00m wanted\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 111\u001b[0m \u001b[43m_compare_versions\u001b[49m\u001b[43m(\u001b[49m\u001b[43mop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgot_ver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwant_ver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequirement\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpkg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Projects/_TRANSFORMER_LENS/TransformerLens/.venv/lib/python3.12/site-packages/transformers/utils/versions.py:44\u001b[0m, in \u001b[0;36m_compare_versions\u001b[0;34m(op, got_ver, want_ver, requirement, pkg, hint)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 40\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnable to compare versions for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrequirement\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: need=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mwant_ver\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m found=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgot_ver\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is unusual. Consider\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m reinstalling \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpkg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ops[op](version\u001b[38;5;241m.\u001b[39mparse(got_ver), version\u001b[38;5;241m.\u001b[39mparse(want_ver)):\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m 45\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrequirement\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is required for a normal functioning of this module, but found \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpkg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m==\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgot_ver\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhint\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 46\u001b[0m )\n", + "\u001b[0;31mImportError\u001b[0m: tokenizers>=0.21,<0.22 is required for a normal functioning of this module, but found tokenizers==0.22.2.\nTry: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main" ] } ], @@ -417,60 +417,10 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running on local URL: http://127.0.0.1:7860\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running on public URL: https://7a615281b36111d2e4.gradio.live\n", - "\n", - "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" - ] - }, - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "demo.launch(share=True, height=1000)" - ] + "outputs": [], + "source": "# NBVAL_SKIP\ndemo.launch(share=True, height=1000)" } ], "metadata": { @@ -500,4 +450,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file From 5873b8aed1892cafdcba34dcb47930861348251d Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 13:50:53 +0100 Subject: [PATCH 3/9] add xielu test --- tests/unit/test_xielu.py | 105 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 tests/unit/test_xielu.py diff --git a/tests/unit/test_xielu.py b/tests/unit/test_xielu.py new file mode 100644 index 000000000..77c331b2b --- /dev/null +++ b/tests/unit/test_xielu.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import pytest + +from transformer_lens.utils import XIELU, xielu, ACTIVATION_FN_DICT + + +class TestXIELUClass: + def test_default_parameters(self): + act = XIELU() + assert act.alpha_p.item() == pytest.approx(0.8) + assert act.alpha_n.item() == pytest.approx(0.8) + assert act.beta.item() == pytest.approx(0.5) + + def test_custom_parameters(self): + act = XIELU(alpha_p_init=1.0, alpha_n_init=0.5, beta_init=0.25) + assert act.alpha_p.item() == pytest.approx(1.0) + assert act.alpha_n.item() == pytest.approx(0.5) + assert act.beta.item() == pytest.approx(0.25) + + def test_parameters_are_trainable(self): + act = XIELU() + assert isinstance(act.alpha_p, nn.Parameter) + assert isinstance(act.alpha_n, nn.Parameter) + assert isinstance(act.beta, nn.Parameter) + assert act.alpha_p.requires_grad + assert act.alpha_n.requires_grad + assert act.beta.requires_grad + + def test_output_shape_preserved(self): + act = XIELU() + x = torch.randn(2, 10, 32) + assert act(x).shape == x.shape + + def test_positive_branch(self): + """For x > 0: f(x) = alpha_p * x^2 + beta * x.""" + act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0) + x = torch.tensor([1.0, 2.0, 3.0]) + expected = 1.0 * x ** 2 + 1.0 * x # alpha_p * x^2 + beta * x + torch.testing.assert_close(act(x), expected) + + def test_negative_branch(self): + """For x <= 0: f(x) = alpha_n * (exp(clamp(x, eps)) - 1) - alpha_n * x + beta * x.""" + alpha_n, beta, eps = 1.0, 1.0, -1e-6 + act = XIELU(alpha_p_init=1.0, alpha_n_init=alpha_n, beta_init=beta, eps=eps) + x = torch.tensor([-1.0, -2.0, -3.0]) + expected = ( + alpha_n * torch.expm1(torch.clamp_max(x, eps)) + - alpha_n * x + + beta * x + ) + torch.testing.assert_close(act(x), expected) + + def test_zero_input(self): + """x = 0 falls into the negative branch.""" + act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0, eps=-1e-6) + x = torch.tensor([0.0]) + expected = ( + 1.0 * torch.expm1(torch.clamp_max(x, -1e-6)) + - 1.0 * x + + 1.0 * x + ) + torch.testing.assert_close(act(x), expected) + + def test_gradients_flow_through(self): + act = XIELU() + x = torch.randn(4, 8, requires_grad=True) + out = act(x).sum() + out.backward() + assert x.grad is not None + assert act.alpha_p.grad is not None + assert act.alpha_n.grad is not None + assert act.beta.grad is not None + + def test_class_and_function_agree_at_defaults(self): + """XIELU class with default params should match the fixed xielu function.""" + act = XIELU() # defaults match xielu() fixed values + x = torch.randn(2, 5, 16) + torch.testing.assert_close(act(x), xielu(x)) + + +class TestXIELUFunction: + def test_output_shape_preserved(self): + x = torch.randn(2, 10, 32) + assert xielu(x).shape == x.shape + + def test_positive_values(self): + """For x > 0: f(x) = 0.8*x^2 + 0.5*x.""" + x = torch.tensor([1.0, 2.0]) + expected = 0.8 * x ** 2 + 0.5 * x + torch.testing.assert_close(xielu(x), expected) + + def test_negative_values(self): + """For x <= 0: f(x) = 0.8*(exp(clamp(x,-1e-6))-1) - 0.8*x + 0.5*x.""" + x = torch.tensor([-1.0, -2.0]) + expected = ( + 0.8 * torch.expm1(torch.clamp_max(x, -1e-6)) + - 0.8 * x + + 0.5 * x + ) + torch.testing.assert_close(xielu(x), expected) + + def test_registered_in_activation_dict(self): + assert "xielu" in ACTIVATION_FN_DICT + assert ACTIVATION_FN_DICT["xielu"] is xielu From b18a57159dc725fd49b691b2611373cd6f2df89a Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 13:52:44 +0100 Subject: [PATCH 4/9] sdd notebook compatability --- demos/Colab_Compatibility.ipynb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 46b42c2ba..5a54f6417 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -996,6 +996,8 @@ " \"stabilityai/stablelm-base-alpha-7b\",\n", " \"stabilityai/stablelm-tuned-alpha-3b\",\n", " \"stabilityai/stablelm-tuned-alpha-7b\",\n", + " \"swiss-ai/Apertus-8B-2509\",\n", + " \"swiss-ai/Apertus-8B-Instruct-2509\",\n", "]\n", "\n", "if IN_COLAB:\n", From cfce7e65c07b724d1526c7b23ccd5882dd7da841 Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 13:53:09 +0100 Subject: [PATCH 5/9] update utils import formatting --- transformer_lens/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 1f30c7254..234e3826b 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -34,7 +34,6 @@ from transformer_lens.FactoredMatrix import FactoredMatrix - CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE USE_DEFAULT_VALUE = None From ee86e87a22e62c43d687d507adabb7483fe2cb51 Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 14:00:11 +0100 Subject: [PATCH 6/9] fix import style and add 3D inputs --- tests/unit/test_xielu.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_xielu.py b/tests/unit/test_xielu.py index 77c331b2b..3ee04dc84 100644 --- a/tests/unit/test_xielu.py +++ b/tests/unit/test_xielu.py @@ -1,8 +1,8 @@ +import pytest import torch import torch.nn as nn -import pytest -from transformer_lens.utils import XIELU, xielu, ACTIVATION_FN_DICT +from transformer_lens.utils import ACTIVATION_FN_DICT, XIELU, xielu class TestXIELUClass: @@ -35,7 +35,7 @@ def test_output_shape_preserved(self): def test_positive_branch(self): """For x > 0: f(x) = alpha_p * x^2 + beta * x.""" act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0) - x = torch.tensor([1.0, 2.0, 3.0]) + x = torch.tensor([[[1.0, 2.0, 3.0]]]) # (1, 1, 3) expected = 1.0 * x ** 2 + 1.0 * x # alpha_p * x^2 + beta * x torch.testing.assert_close(act(x), expected) @@ -43,7 +43,7 @@ def test_negative_branch(self): """For x <= 0: f(x) = alpha_n * (exp(clamp(x, eps)) - 1) - alpha_n * x + beta * x.""" alpha_n, beta, eps = 1.0, 1.0, -1e-6 act = XIELU(alpha_p_init=1.0, alpha_n_init=alpha_n, beta_init=beta, eps=eps) - x = torch.tensor([-1.0, -2.0, -3.0]) + x = torch.tensor([[[-1.0, -2.0, -3.0]]]) # (1, 1, 3) expected = ( alpha_n * torch.expm1(torch.clamp_max(x, eps)) - alpha_n * x @@ -54,7 +54,7 @@ def test_negative_branch(self): def test_zero_input(self): """x = 0 falls into the negative branch.""" act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0, eps=-1e-6) - x = torch.tensor([0.0]) + x = torch.tensor([[[0.0]]]) # (1, 1, 1) expected = ( 1.0 * torch.expm1(torch.clamp_max(x, -1e-6)) - 1.0 * x @@ -64,7 +64,7 @@ def test_zero_input(self): def test_gradients_flow_through(self): act = XIELU() - x = torch.randn(4, 8, requires_grad=True) + x = torch.randn(4, 8, 16, requires_grad=True) out = act(x).sum() out.backward() assert x.grad is not None @@ -86,13 +86,13 @@ def test_output_shape_preserved(self): def test_positive_values(self): """For x > 0: f(x) = 0.8*x^2 + 0.5*x.""" - x = torch.tensor([1.0, 2.0]) + x = torch.tensor([[[1.0, 2.0]]]) # (1, 1, 2) expected = 0.8 * x ** 2 + 0.5 * x torch.testing.assert_close(xielu(x), expected) def test_negative_values(self): """For x <= 0: f(x) = 0.8*(exp(clamp(x,-1e-6))-1) - 0.8*x + 0.5*x.""" - x = torch.tensor([-1.0, -2.0]) + x = torch.tensor([[[-1.0, -2.0]]]) # (1, 1, 2) expected = ( 0.8 * torch.expm1(torch.clamp_max(x, -1e-6)) - 0.8 * x From f3dcc2ec82beb02b7d7c22c344a7429c80f2c3d8 Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 14:31:21 +0100 Subject: [PATCH 7/9] reformatting after installing pip black --- tests/unit/test_xielu.py | 22 +++------- .../pretrained/weight_conversions/apertus.py | 25 +++++++---- transformer_lens/utils.py | 43 ++++++++++++------- 3 files changed, 49 insertions(+), 41 deletions(-) diff --git a/tests/unit/test_xielu.py b/tests/unit/test_xielu.py index 3ee04dc84..866dce1b0 100644 --- a/tests/unit/test_xielu.py +++ b/tests/unit/test_xielu.py @@ -36,7 +36,7 @@ def test_positive_branch(self): """For x > 0: f(x) = alpha_p * x^2 + beta * x.""" act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0) x = torch.tensor([[[1.0, 2.0, 3.0]]]) # (1, 1, 3) - expected = 1.0 * x ** 2 + 1.0 * x # alpha_p * x^2 + beta * x + expected = 1.0 * x**2 + 1.0 * x # alpha_p * x^2 + beta * x torch.testing.assert_close(act(x), expected) def test_negative_branch(self): @@ -44,22 +44,14 @@ def test_negative_branch(self): alpha_n, beta, eps = 1.0, 1.0, -1e-6 act = XIELU(alpha_p_init=1.0, alpha_n_init=alpha_n, beta_init=beta, eps=eps) x = torch.tensor([[[-1.0, -2.0, -3.0]]]) # (1, 1, 3) - expected = ( - alpha_n * torch.expm1(torch.clamp_max(x, eps)) - - alpha_n * x - + beta * x - ) + expected = alpha_n * torch.expm1(torch.clamp_max(x, eps)) - alpha_n * x + beta * x torch.testing.assert_close(act(x), expected) def test_zero_input(self): """x = 0 falls into the negative branch.""" act = XIELU(alpha_p_init=1.0, alpha_n_init=1.0, beta_init=1.0, eps=-1e-6) x = torch.tensor([[[0.0]]]) # (1, 1, 1) - expected = ( - 1.0 * torch.expm1(torch.clamp_max(x, -1e-6)) - - 1.0 * x - + 1.0 * x - ) + expected = 1.0 * torch.expm1(torch.clamp_max(x, -1e-6)) - 1.0 * x + 1.0 * x torch.testing.assert_close(act(x), expected) def test_gradients_flow_through(self): @@ -87,17 +79,13 @@ def test_output_shape_preserved(self): def test_positive_values(self): """For x > 0: f(x) = 0.8*x^2 + 0.5*x.""" x = torch.tensor([[[1.0, 2.0]]]) # (1, 1, 2) - expected = 0.8 * x ** 2 + 0.5 * x + expected = 0.8 * x**2 + 0.5 * x torch.testing.assert_close(xielu(x), expected) def test_negative_values(self): """For x <= 0: f(x) = 0.8*(exp(clamp(x,-1e-6))-1) - 0.8*x + 0.5*x.""" x = torch.tensor([[[-1.0, -2.0]]]) # (1, 1, 2) - expected = ( - 0.8 * torch.expm1(torch.clamp_max(x, -1e-6)) - - 0.8 * x - + 0.5 * x - ) + expected = 0.8 * torch.expm1(torch.clamp_max(x, -1e-6)) - 0.8 * x + 0.5 * x torch.testing.assert_close(xielu(x), expected) def test_registered_in_activation_dict(self): diff --git a/transformer_lens/pretrained/weight_conversions/apertus.py b/transformer_lens/pretrained/weight_conversions/apertus.py index 739a84d83..9c76f8e87 100644 --- a/transformer_lens/pretrained/weight_conversions/apertus.py +++ b/transformer_lens/pretrained/weight_conversions/apertus.py @@ -21,12 +21,13 @@ def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) - assert cfg.d_mlp is not None # keep mypy happy for l in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = apertus.model.layers[l].attention_layernorm.weight - state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + state_dict[f"blocks.{l}.ln1.b"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) W_Q = apertus.model.layers[l].self_attn.q_proj.weight W_K = apertus.model.layers[l].self_attn.k_proj.weight @@ -71,7 +72,9 @@ def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): ) state_dict[f"blocks.{l}.ln2.w"] = apertus.model.layers[l].feedforward_layernorm.weight - state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) + state_dict[f"blocks.{l}.ln2.b"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) # in case of quantization, # parameters should stay as bitsandbytes.nn.modules.Params4bit @@ -92,11 +95,11 @@ def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): # Extract trainable activation parameters mlp = apertus.model.layers[l].mlp try: - if hasattr(mlp, 'act_fn'): + if hasattr(mlp, "act_fn"): alpha_p = mlp.act_fn.alpha_p alpha_n = mlp.act_fn.alpha_n beta = mlp.act_fn.beta - elif hasattr(mlp, 'act'): + elif hasattr(mlp, "act"): alpha_p = mlp.act.alpha_p alpha_n = mlp.act.alpha_n beta = mlp.act.beta @@ -110,9 +113,15 @@ def convert_apertus_weights(apertus, cfg: HookedTransformerConfig): except AttributeError: # If parameters not found, use defaults print(f"Activation parameters not found in layer {l}, using defaults") - state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor(0.8, dtype=cfg.dtype, device=cfg.device) - state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor(0.8, dtype=cfg.dtype, device=cfg.device) - state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor(0.5, dtype=cfg.dtype, device=cfg.device) + state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor( + 0.8, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor( + 0.8, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor( + 0.5, dtype=cfg.dtype, device=cfg.device + ) state_dict["ln_final.w"] = apertus.model.norm.weight state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 234e3826b..2d3eccd8c 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -232,32 +232,41 @@ class XIELU(nn.Module): } where α_p, α_n, β are trainable parameters. """ - def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta_init: float = 0.5, eps: float = -1e-6): + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta_init: float = 0.5, + eps: float = -1e-6, + ): super().__init__() self.alpha_p = nn.Parameter(torch.tensor(alpha_p_init, dtype=torch.float32)) self.alpha_n = nn.Parameter(torch.tensor(alpha_n_init, dtype=torch.float32)) self.beta = nn.Parameter(torch.tensor(beta_init, dtype=torch.float32)) self.eps = eps - def forward(self, input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: + def forward( + self, input: Float[torch.Tensor, "batch pos d_mlp"] + ) -> Float[torch.Tensor, "batch pos d_mlp"]: return torch.where( input > 0, - self.alpha_p * input ** 2 + self.beta * input, - self.alpha_n * torch.expm1(torch.clamp_max(input, self.eps)) - self.alpha_n * input + self.beta * input + self.alpha_p * input**2 + self.beta * input, + self.alpha_n * torch.expm1(torch.clamp_max(input, self.eps)) + - self.alpha_n * input + + self.beta * input, ) -def xielu( - input: Float[torch.Tensor, "batch pos d_mlp"] -) -> Float[torch.Tensor, "batch pos d_mlp"]: +def xielu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: """ xIELU activation function as described by https://arxiv.org/abs/2411.13010 and original code in: - https://github.com/rubber-duck-debug/xielu + https://github.com/rubber-duck-debug/xielu - Defined as + Defined as f(x) = { α_p * x² + β * x, if x > 0 @@ -270,11 +279,13 @@ def xielu( alpha_n: float = 0.8 beta: float = 0.5 eps: float = -1e-6 - + # The core calculation logic: - return torch.where(input > 0, - alpha_p * input * input + beta * input, - alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input) + return torch.where( + input > 0, + alpha_p * input * input + beta * input, + alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input, + ) ACTIVATION_FN_DICT = { @@ -421,9 +432,9 @@ def tokenize_and_concatenate( _deprecation_warnings_saved = None if hasattr(tokenizer, "deprecation_warnings"): _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy() - tokenizer.deprecation_warnings[ - "sequence-length-is-longer-than-the-specified-maximum" - ] = False + tokenizer.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = ( + False + ) try: # Define the length to chop things up into - leaving space for a bos_token if required if add_bos_token: From d1b5abf69df608e36c164ae8bc73ead745b5ae95 Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 15:04:53 +0100 Subject: [PATCH 8/9] reinstall black and fix formatting issues --- transformer_lens/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 2d3eccd8c..697faab39 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -432,9 +432,9 @@ def tokenize_and_concatenate( _deprecation_warnings_saved = None if hasattr(tokenizer, "deprecation_warnings"): _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy() - tokenizer.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = ( - False - ) + tokenizer.deprecation_warnings[ + "sequence-length-is-longer-than-the-specified-maximum" + ] = False try: # Define the length to chop things up into - leaving space for a bos_token if required if add_bos_token: From 35228f3053c691bfa728e4e57bd6e3fde3b951d2 Mon Sep 17 00:00:00 2001 From: sinievanderben Date: Tue, 17 Mar 2026 15:36:33 +0100 Subject: [PATCH 9/9] assertion insert --- transformer_lens/loading_from_pretrained.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7026f99d3..461a6d0a9 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1542,6 +1542,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): else: rope_type = "" if rope_type == "llama3": + assert rope_scaling is not None cfg_dict["use_NTK_by_parts_rope"] = True cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( "original_max_position_embeddings", hf_config.max_position_embeddings