diff --git a/tests/unit/tools/test_model_registry.py b/tests/unit/tools/test_model_registry.py
index eca6206cf..8fbb12c01 100644
--- a/tests/unit/tools/test_model_registry.py
+++ b/tests/unit/tools/test_model_registry.py
@@ -780,3 +780,96 @@ def test_required_quant_library(self, model_id, expected_library):
         )
 
         assert required_quant_library_for_model(model_id) == expected_library
+
+
+class TestRegistrySyncedWithFactory:
+    """Assert HF_SUPPORTED_ARCHITECTURES and CANONICAL_AUTHORS_BY_ARCH stay in
+    sync with architecture_adapter_factory.SUPPORTED_ARCHITECTURES.
+
+    The comment block above HF_SUPPORTED_ARCHITECTURES in
+    transformer_lens.tools.model_registry states the invariant: entries must
+    "correspond to adapters registered in architecture_adapter_factory.py",
+    with two documented exception groups (internal-only architectures and
+    factory-internal alias casings). These tests enforce it in both directions
+    so a future adapter PR that forgets the registry update fails CI.
+    """
+
+    # Factory keys that are NOT expected in the registry sets.
+    # Two groups, matching the comment block above HF_SUPPORTED_ARCHITECTURES:
+    #   1. Internal-only architectures that never appear on HuggingFace Hub.
+    #   2. Factory-internal alias casings that route to canonical adapters
+    #      under names HF does not emit in config.architectures[].
+    INTENTIONAL_EXCLUDES = frozenset(
+        {
+            # Group 1: internal-only architectures that never appear on HuggingFace Hub.
+            "NanoGPTForCausalLM",
+            "MinGPTForCausalLM",
+            "NeelSoluOldForCausalLM",
+            "GPT2LMHeadCustomModel",
+            "TransformerLensNative",
+            # Group 2: factory-internal alias casings (HF emits the canonical name).
+            "Gemma1ForCausalLM",  # HF emits: GemmaForCausalLM
+            "NeoForCausalLM",  # HF emits: GPTNeoForCausalLM
+            "NeoXForCausalLM",  # HF emits: GPTNeoXForCausalLM
+        }
+    )
+
+    def test_every_factory_arch_is_in_hf_supported(self):
+        """Every non-excluded factory key must be present in HF_SUPPORTED_ARCHITECTURES."""
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+        from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES
+
+        missing = sorted(
+            k
+            for k in SUPPORTED_ARCHITECTURES
+            if k not in self.INTENTIONAL_EXCLUDES and k not in HF_SUPPORTED_ARCHITECTURES
+        )
+        assert not missing, (
+            f"Factory keys missing from HF_SUPPORTED_ARCHITECTURES: {missing}. "
+            "Add them to HF_SUPPORTED_ARCHITECTURES, or list them in "
+            "INTENTIONAL_EXCLUDES with a one-line reason."
+        )
+
+    def test_every_factory_arch_has_canonical_authors(self):
+        """Every non-excluded factory key must have a CANONICAL_AUTHORS_BY_ARCH entry."""
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+        from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH
+
+        missing = sorted(
+            k
+            for k in SUPPORTED_ARCHITECTURES
+            if k not in self.INTENTIONAL_EXCLUDES and k not in CANONICAL_AUTHORS_BY_ARCH
+        )
+        assert not missing, (
+            f"Factory keys missing from CANONICAL_AUTHORS_BY_ARCH: {missing}. "
+            "Add them to CANONICAL_AUTHORS_BY_ARCH, or list them in "
+            "INTENTIONAL_EXCLUDES with a one-line reason."
+        )
+
+    def test_hf_supported_keys_have_a_factory_adapter(self):
+        """Every HF_SUPPORTED_ARCHITECTURES entry must correspond to a wired factory adapter."""
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+        from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES
+
+        orphaned = sorted(k for k in HF_SUPPORTED_ARCHITECTURES if k not in SUPPORTED_ARCHITECTURES)
+        assert (
+            not orphaned
+        ), f"HF_SUPPORTED_ARCHITECTURES entries with no factory adapter: {orphaned}"
+
+    def test_canonical_authors_keys_have_a_factory_adapter(self):
+        """Every CANONICAL_AUTHORS_BY_ARCH entry must correspond to a wired factory adapter."""
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+        from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH
+
+        orphaned = sorted(k for k in CANONICAL_AUTHORS_BY_ARCH if k not in SUPPORTED_ARCHITECTURES)
+        assert (
+            not orphaned
+        ), f"CANONICAL_AUTHORS_BY_ARCH entries with no factory adapter: {orphaned}"
diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py
index 511ad01c0..424e4903c 100644
--- a/transformer_lens/tools/model_registry/__init__.py
+++ b/transformer_lens/tools/model_registry/__init__.py
@@ -39,10 +39,15 @@
 # These must match the exact strings found in HF model config.architectures[]
 # and correspond to adapters registered in architecture_adapter_factory.py.
 #
-# Internal-only architectures (NanoGPT, MinGPT, NeelSoluOld, GPT2LMHeadCustomModel)
-# are excluded since they never appear on HuggingFace Hub.
+# Internal-only architectures (NanoGPT, MinGPT, NeelSoluOld, GPT2LMHeadCustomModel,
+# TransformerLensNative) are excluded since they never appear on HuggingFace Hub.
+# Factory-internal alias casings (Gemma1, Neo, NeoX) are also excluded since they
+# route to canonical adapters but HF reports the canonical names (Gemma, GPTNeo,
+# GPTNeoX) in config.architectures instead.
 HF_SUPPORTED_ARCHITECTURES: set[str] = {
     "ApertusForCausalLM",
+    "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BertForMaskedLM",
     "BloomForCausalLM",
     "CodeGenForCausalLM",
@@ -85,6 +90,7 @@
     "QwenForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen3ForCausalLM",
+    "Qwen3MoeForCausalLM",
     "Qwen3NextForCausalLM",
     "Qwen3_5ForCausalLM",
     "StableLmForCausalLM",
@@ -97,6 +103,7 @@
 # download-threshold bypass and the docs table's "Canonical only" toggle.
 CANONICAL_AUTHORS_BY_ARCH: dict[str, list[str]] = {
     "ApertusForCausalLM": ["swiss-ai"],
+    "BaiChuanForCausalLM": ["baichuan-inc"],
     "BaichuanForCausalLM": ["baichuan-inc"],
     "BertForMaskedLM": ["google-bert"],
     "BloomForCausalLM": ["bigscience"],
@@ -140,6 +147,7 @@
     "PhiForCausalLM": ["microsoft"],
     "Qwen2ForCausalLM": ["Qwen", "nvidia"],
     "Qwen3ForCausalLM": ["Qwen", "nvidia"],
+    "Qwen3MoeForCausalLM": ["Qwen"],
     "Qwen3NextForCausalLM": ["Qwen"],
     "Qwen3_5ForCausalLM": ["Qwen"],
     "QwenForCausalLM": ["Qwen"],