Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/unit/tools/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,96 @@ def test_required_quant_library(self, model_id, expected_library):
)

assert required_quant_library_for_model(model_id) == expected_library


class TestRegistrySyncedWithFactory:
    """Guard the registry <-> adapter-factory correspondence in both directions.

    The registry's HF_SUPPORTED_ARCHITECTURES and CANONICAL_AUTHORS_BY_ARCH are
    expected to mirror architecture_adapter_factory.SUPPORTED_ARCHITECTURES.
    The comment on HF_SUPPORTED_ARCHITECTURES in
    transformer_lens.tools.model_registry says its entries must "correspond to
    adapters registered in architecture_adapter_factory.py", and carves out two
    exception groups: internal-only architectures and factory-internal alias
    casings.

    Two tests check factory -> registry (minus the documented exceptions) and
    two check registry -> factory, so an adapter PR that skips the registry
    update, or a registry entry with no adapter behind it, fails CI.
    """

    # Factory keys deliberately absent from both registry collections.
    # The two groups mirror the exceptions documented on
    # HF_SUPPORTED_ARCHITECTURES:
    #   1. Internal-only architectures that never appear on HuggingFace Hub.
    #   2. Factory-internal alias casings that route to canonical adapters
    #      under names HF does not emit in config.architectures[].
    INTENTIONAL_EXCLUDES = frozenset(
        {
            # Group 1: internal-only architectures that never appear on HuggingFace Hub.
            "NanoGPTForCausalLM",
            "MinGPTForCausalLM",
            "NeelSoluOldForCausalLM",
            "GPT2LMHeadCustomModel",
            "TransformerLensNative",
            # Group 2: factory-internal alias casings (HF emits the canonical name).
            "Gemma1ForCausalLM",  # HF emits: GemmaForCausalLM
            "NeoForCausalLM",  # HF emits: GPTNeoForCausalLM
            "NeoXForCausalLM",  # HF emits: GPTNeoXForCausalLM
        }
    )

    def test_every_factory_arch_is_in_hf_supported(self):
        """Factory -> registry: non-excluded factory keys appear in HF_SUPPORTED_ARCHITECTURES."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )
        from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES

        missing = []
        for arch in SUPPORTED_ARCHITECTURES:
            # Skip documented exceptions and anything already registered.
            if arch in self.INTENTIONAL_EXCLUDES or arch in HF_SUPPORTED_ARCHITECTURES:
                continue
            missing.append(arch)
        missing.sort()  # stable, readable failure message

        assert not missing, (
            f"Factory keys missing from HF_SUPPORTED_ARCHITECTURES: {missing}. "
            "Add them to HF_SUPPORTED_ARCHITECTURES, or list them in "
            "INTENTIONAL_EXCLUDES with a one-line reason."
        )

    def test_every_factory_arch_has_canonical_authors(self):
        """Factory -> registry: non-excluded factory keys appear in CANONICAL_AUTHORS_BY_ARCH."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )
        from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH

        missing = []
        for arch in SUPPORTED_ARCHITECTURES:
            # Skip documented exceptions and anything that already has authors.
            if arch in self.INTENTIONAL_EXCLUDES or arch in CANONICAL_AUTHORS_BY_ARCH:
                continue
            missing.append(arch)
        missing.sort()  # stable, readable failure message

        assert not missing, (
            f"Factory keys missing from CANONICAL_AUTHORS_BY_ARCH: {missing}. "
            "Add them to CANONICAL_AUTHORS_BY_ARCH, or list them in "
            "INTENTIONAL_EXCLUDES with a one-line reason."
        )

    def test_hf_supported_keys_have_a_factory_adapter(self):
        """Registry -> factory: each HF_SUPPORTED_ARCHITECTURES entry is a factory key."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )
        from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES

        orphaned = [
            arch for arch in HF_SUPPORTED_ARCHITECTURES if arch not in SUPPORTED_ARCHITECTURES
        ]
        orphaned.sort()  # set iteration order is arbitrary; sort for a stable message

        assert (
            not orphaned
        ), f"HF_SUPPORTED_ARCHITECTURES entries with no factory adapter: {orphaned}"

    def test_canonical_authors_keys_have_a_factory_adapter(self):
        """Registry -> factory: each CANONICAL_AUTHORS_BY_ARCH key is a factory key."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )
        from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH

        orphaned = [
            arch for arch in CANONICAL_AUTHORS_BY_ARCH if arch not in SUPPORTED_ARCHITECTURES
        ]
        orphaned.sort()  # sort for a stable, readable failure message

        assert (
            not orphaned
        ), f"CANONICAL_AUTHORS_BY_ARCH entries with no factory adapter: {orphaned}"
12 changes: 10 additions & 2 deletions transformer_lens/tools/model_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@
# These must match the exact strings found in HF model config.architectures[]
# and correspond to adapters registered in architecture_adapter_factory.py.
#
# Internal-only architectures (NanoGPT, MinGPT, NeelSoluOld, GPT2LMHeadCustomModel,
# TransformerLensNative) are excluded since they never appear on HuggingFace Hub.
# Factory-internal alias casings (Gemma1, Neo, NeoX) are also excluded since they
# route to canonical adapters but HF reports the canonical names (Gemma, GPTNeo,
# GPTNeoX) in config.architectures instead.
HF_SUPPORTED_ARCHITECTURES: set[str] = {
"ApertusForCausalLM",
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BertForMaskedLM",
"BloomForCausalLM",
"CodeGenForCausalLM",
Expand Down Expand Up @@ -85,6 +90,7 @@
"QwenForCausalLM",
"Qwen2ForCausalLM",
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"Qwen3_5ForCausalLM",
"StableLmForCausalLM",
Expand All @@ -97,6 +103,7 @@
# download-threshold bypass and the docs table's "Canonical only" toggle.
CANONICAL_AUTHORS_BY_ARCH: dict[str, list[str]] = {
"ApertusForCausalLM": ["swiss-ai"],
"BaiChuanForCausalLM": ["baichuan-inc"],
"BaichuanForCausalLM": ["baichuan-inc"],
"BertForMaskedLM": ["google-bert"],
"BloomForCausalLM": ["bigscience"],
Expand Down Expand Up @@ -140,6 +147,7 @@
"PhiForCausalLM": ["microsoft"],
"Qwen2ForCausalLM": ["Qwen", "nvidia"],
"Qwen3ForCausalLM": ["Qwen", "nvidia"],
"Qwen3MoeForCausalLM": ["Qwen"],
"Qwen3NextForCausalLM": ["Qwen"],
"Qwen3_5ForCausalLM": ["Qwen"],
"QwenForCausalLM": ["Qwen"],
Expand Down
Loading