Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions tests/integration/model_bridge/test_cohere_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
Model: trl-internal-testing/tiny-CohereForCausalLM
- 2 layers, ~8M params, CPU-safe, no gating required
- tie_word_embeddings=True by default
- logit_scale=0.0625 (1/16)
- logit_scale=0.125 (canonical Command-R is 0.0625; tiny diverges so
regression tests catch silent-fallback bugs in the passthrough)

NOTE: The tiny model has use_qk_norm=False, so QK-norm is not exercised here.
Cohere's QK-norm is a per-head LayerNorm inside CohereAttention.forward; it is
Expand Down Expand Up @@ -96,8 +97,20 @@ def test_cfg_uses_rms_norm_false(self, cohere_bridge: TransformerBridge) -> None
def test_cfg_logit_scale_is_float(self, cohere_bridge: TransformerBridge) -> None:
assert isinstance(getattr(cohere_bridge.cfg, "logit_scale"), float)

def test_cfg_logit_scale_value(self, cohere_bridge: TransformerBridge) -> None:
assert getattr(cohere_bridge.cfg, "logit_scale") == pytest.approx(0.0625)
def test_cfg_logit_scale_matches_hf(
self, cohere_bridge: TransformerBridge, cohere_hf: Any
) -> None:
"""Regression: logit_scale must propagate from HF (not silently fall back to 0.0625)."""
bridge_scale = getattr(cohere_bridge.cfg, "logit_scale")
assert bridge_scale == cohere_hf.config.logit_scale
# Anchor 0.125 so a passthrough regression that defaults to 0.0625 also trips here.
assert bridge_scale == pytest.approx(0.125)

def test_cfg_rope_parameters_matches_hf(
self, cohere_bridge: TransformerBridge, cohere_hf: Any
) -> None:
"""Regression: rope_parameters must propagate from HF (same passthrough trap as logit_scale)."""
assert getattr(cohere_bridge.cfg, "rope_parameters") == cohere_hf.config.rope_parameters


# ---------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions transformer_lens/model_bridge/sources/_bridge_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
"chunk_size",
# Multimodal
"vision_config",
# Cohere
"logit_scale",
"rope_parameters",
]


Expand Down
3 changes: 3 additions & 0 deletions transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,9 @@ def boot(
"chunk_size",
# Multimodal
"vision_config",
# Cohere
"logit_scale",
"rope_parameters",
]
for attr in _HF_PASSTHROUGH_ATTRS:
val = getattr(hf_config, attr, None)
Expand Down
Loading