Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit/test_model_configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from functools import lru_cache

import pytest

from transformer_lens import loading
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


@lru_cache(maxsize=None)
def get_cached_config(model_name: str) -> HookedTransformerConfig:
    """Fetch and memoize the pretrained configuration for ``model_name``.

    Repeated calls with the same name return the cached object rather than
    resolving the configuration again, which keeps the parametrized test
    suite fast.

    Args:
        model_name (str): Name of the pretrained model.

    Returns:
        HookedTransformerConfig: Configuration of the pretrained model.
    """
    config = loading.get_pretrained_model_config(model_name)
    return config


@pytest.mark.parametrize("model_name", loading.DEFAULT_MODEL_ALIASES)
def test_model_configurations(model_name: str):
    """Every default model alias must resolve to a non-None configuration."""
    config = get_cached_config(model_name)
    assert config is not None, f"Configuration for {model_name} is None"
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class HookedTransformerConfig:
tokenizer_prepends_bos: Optional[bool] = None
n_key_value_heads: Optional[int] = None
post_embedding_ln: bool = False
rotary_base: int = 10000
rotary_base: float = 10000.0
trust_remote_code: bool = False
rotary_adjacent_pairs: bool = False
load_in_4bit: bool = False
Expand Down
2 changes: 1 addition & 1 deletion transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def calculate_sin_cos_rotary(
self,
rotary_dim: int,
n_ctx: int,
base: int = 10000,
base: float = 10000,
dtype: torch.dtype = torch.float32,
) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
"""
Expand Down
2 changes: 1 addition & 1 deletion transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 4096 // 32,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 1000000,
"rotary_base": 1000000.0,
}
if "python" in official_model_name.lower():
# The vocab size of python version of CodeLlama-7b is 32000
Expand Down