diff --git a/tests/integration/model_bridge/test_cohere_adapter.py b/tests/integration/model_bridge/test_cohere_adapter.py index 5e92c1616..6111ed834 100644 --- a/tests/integration/model_bridge/test_cohere_adapter.py +++ b/tests/integration/model_bridge/test_cohere_adapter.py @@ -3,7 +3,8 @@ Model: trl-internal-testing/tiny-CohereForCausalLM - 2 layers, ~8M params, CPU-safe, no gating required - tie_word_embeddings=True by default - - logit_scale=0.0625 (1/16) + - logit_scale=0.125 (canonical Command-R is 0.0625; tiny diverges so + regression tests catch silent-fallback bugs in the passthrough) NOTE: The tiny model has use_qk_norm=False, so QK-norm is not exercised here. Cohere's QK-norm is a per-head LayerNorm inside CohereAttention.forward; it is @@ -96,8 +97,20 @@ def test_cfg_uses_rms_norm_false(self, cohere_bridge: TransformerBridge) -> None def test_cfg_logit_scale_is_float(self, cohere_bridge: TransformerBridge) -> None: assert isinstance(getattr(cohere_bridge.cfg, "logit_scale"), float) - def test_cfg_logit_scale_value(self, cohere_bridge: TransformerBridge) -> None: - assert getattr(cohere_bridge.cfg, "logit_scale") == pytest.approx(0.0625) + def test_cfg_logit_scale_matches_hf( + self, cohere_bridge: TransformerBridge, cohere_hf: Any + ) -> None: + """Regression: logit_scale must propagate from HF (not silently fall back to 0.0625).""" + bridge_scale = getattr(cohere_bridge.cfg, "logit_scale") + assert bridge_scale == cohere_hf.config.logit_scale + # Anchor 0.125 so a passthrough regression that defaults to 0.0625 also trips here. + assert bridge_scale == pytest.approx(0.125) + + def test_cfg_rope_parameters_matches_hf( + self, cohere_bridge: TransformerBridge, cohere_hf: Any + ) -> None: + """Regression: rope_parameters must propagate from HF (same passthrough trap as logit_scale).""" + assert getattr(cohere_bridge.cfg, "rope_parameters") == cohere_hf.config.rope_parameters # --------------------------------------------------------------------------- diff --git a/transformer_lens/model_bridge/sources/_bridge_builder.py b/transformer_lens/model_bridge/sources/_bridge_builder.py index 5ae2dc466..0f5adb927 100644 --- a/transformer_lens/model_bridge/sources/_bridge_builder.py +++ b/transformer_lens/model_bridge/sources/_bridge_builder.py @@ -39,6 +39,9 @@ "chunk_size", # Multimodal "vision_config", + # Cohere + "logit_scale", + "rope_parameters", ] diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index e30a022f4..049cda274 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -502,6 +502,9 @@ def boot( "chunk_size", # Multimodal "vision_config", + # Cohere + "logit_scale", + "rope_parameters", ] for attr in _HF_PASSTHROUGH_ATTRS: val = getattr(hf_config, attr, None)