diff --git a/tests/unit/test_model_configurations.py b/tests/unit/test_model_configurations.py
new file mode 100644
index 000000000..2337d40be
--- /dev/null
+++ b/tests/unit/test_model_configurations.py
@@ -0,0 +1,25 @@
+from functools import lru_cache
+
+import pytest
+
+from transformer_lens import loading
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+@lru_cache(maxsize=None)
+def get_cached_config(model_name: str) -> HookedTransformerConfig:
+    """Retrieve the configuration of a pretrained model.
+
+    Args:
+        model_name (str): Name of the pretrained model.
+
+    Returns:
+        HookedTransformerConfig: Configuration of the pretrained model.
+    """
+    return loading.get_pretrained_model_config(model_name)
+
+
+@pytest.mark.parametrize("model_name", loading.DEFAULT_MODEL_ALIASES)
+def test_model_configurations(model_name: str):
+    """Tests that all of the model configurations are in fact loaded (i.e. are not None)."""
+    assert get_cached_config(model_name) is not None, f"Configuration for {model_name} is None"
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 4458705de..ec48eb185 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -245,7 +245,7 @@ class HookedTransformerConfig:
     tokenizer_prepends_bos: Optional[bool] = None
     n_key_value_heads: Optional[int] = None
     post_embedding_ln: bool = False
-    rotary_base: int = 10000
+    rotary_base: float = 10000.0
     trust_remote_code: bool = False
     rotary_adjacent_pairs: bool = False
     load_in_4bit: bool = False
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 347548f34..6311aa2e4 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -469,7 +469,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: float = 10000,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
         """
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index aa544786f..2e929a810 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -777,7 +777,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 4096 // 32,
             "final_rms": True,
             "gated_mlp": True,
-            "rotary_base": 1000000,
+            "rotary_base": 1000000.0,
         }
         if "python" in official_model_name.lower():
            # The vocab size of python version of CodeLlama-7b is 32000