diff --git a/tests/unit/utilities/test_logits_utils.py b/tests/unit/utilities/test_logits_utils.py index b5a8c2602..bdd7ed5a3 100644 --- a/tests/unit/utilities/test_logits_utils.py +++ b/tests/unit/utilities/test_logits_utils.py @@ -4,7 +4,11 @@ import pytest import torch -from transformer_lens.utilities.logits_utils import logits_to_df +from transformer_lens.utilities.logits_utils import ( + _apply_repetition_penalty, + logits_to_df, + sample_logits, +) class _StubTokenizer: @@ -85,3 +89,96 @@ def test_rejects_non_1d_input(self): with pytest.raises(BeartypeCallHintParamViolation): logits_to_df(torch.zeros(3, 4)) + + +class TestSampleLogitsTopK: + def test_top_k_larger_than_vocab_does_not_crash(self): + # Regression test: before clamping top_k, final_logits.topk(top_k) + # raised "selected index k out of range" when top_k > vocab size. + out = sample_logits(torch.randn(1, 3), top_k=10) + assert out.shape == (1,) + assert 0 <= out.item() < 3 + + def test_top_k_larger_than_vocab_batched(self): + out = sample_logits(torch.randn(4, 5), top_k=8) + assert out.shape == (4,) + assert torch.all((out >= 0) & (out < 5)) + + def test_top_k_equal_to_vocab(self): + out = sample_logits(torch.randn(1, 4), top_k=4) + assert out.shape == (1,) + assert 0 <= out.item() < 4 + + def test_top_k_restricts_to_dominant_token(self): + # With top_k=1 only the argmax token is ever sampled. + logits = torch.tensor([[0.0, 100.0, 0.0, 0.0]]) + outs = [sample_logits(logits, top_k=1).item() for _ in range(20)] + assert set(outs) == {1} + + def test_top_k_rejects_non_positive(self): + with pytest.raises(AssertionError): + sample_logits(torch.randn(1, 4), top_k=0) + + +class TestSampleLogitsTemperature: + def test_temperature_zero_is_greedy_argmax(self): + logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]]) + out = sample_logits(logits, temperature=0.0) + assert out.tolist() == [1] + + def test_temperature_zero_batched_argmax(self): + logits = torch.tensor([[1.0, 3.0, 2.0], [5.0, 0.0, 1.0]]) + out = sample_logits(logits, temperature=0.0) + assert out.tolist() == [1, 0] + + def test_temperature_zero_applies_repetition_penalty(self): + # Token 1 is the argmax but has appeared, so the penalty should push + # the greedy choice onto the next-best unseen token (token 2). + logits = torch.tensor([[0.0, 10.0, 9.0, 0.0]]) + tokens = torch.tensor([[1]]) + out = sample_logits(logits, temperature=0.0, repetition_penalty=100.0, tokens=tokens) + assert out.tolist() == [2] + + +class TestSampleLogitsTopP: + def test_top_p_keeps_dominant_token(self): + # One token holds essentially all the probability mass, so even a small + # top_p must keep it and it is the only token ever sampled. + logits = torch.tensor([[0.0, 50.0, 0.0, 0.0]]) + outs = [sample_logits(logits, top_p=0.5).item() for _ in range(20)] + assert set(outs) == {1} + + def test_top_p_rejects_out_of_range(self): + with pytest.raises(AssertionError): + sample_logits(torch.randn(1, 4), top_p=0.0) + with pytest.raises(AssertionError): + sample_logits(torch.randn(1, 4), top_p=1.5) + + +class TestSampleLogitsFreqPenalty: + def test_freq_penalty_suppresses_repeated_token(self): + # Token 0 starts as the clear favourite, but appears many times in the + # context; a large frequency penalty should make it never get sampled. + logits = torch.tensor([[5.0, 4.0, 4.0, 4.0]]) + tokens = torch.zeros((1, 50), dtype=torch.long) # token 0 repeated 50x + outs = [sample_logits(logits, freq_penalty=10.0, tokens=tokens).item() for _ in range(50)] + assert 0 not in outs + + def test_freq_penalty_requires_tokens(self): + with pytest.raises(AssertionError): + sample_logits(torch.randn(1, 4), freq_penalty=1.0) + + +class TestApplyRepetitionPenalty: + def test_positive_logits_divided_negative_multiplied(self): + logits = torch.tensor([[2.0, -2.0, 0.0]]) + tokens = torch.tensor([[0, 1]]) + out = _apply_repetition_penalty(logits, tokens, penalty=2.0) + # token 0 positive -> divided; token 1 negative -> multiplied; token 2 untouched + assert out.tolist() == [[1.0, -4.0, 0.0]] + + def test_does_not_mutate_input(self): + logits = torch.tensor([[2.0, -2.0, 0.0]]) + original = logits.clone() + _apply_repetition_penalty(logits, torch.tensor([[0, 1]]), penalty=2.0) + assert torch.equal(logits, original) diff --git a/transformer_lens/utilities/logits_utils.py b/transformer_lens/utilities/logits_utils.py index 2bbc7a78f..2baacc22f 100644 --- a/transformer_lens/utilities/logits_utils.py +++ b/transformer_lens/utilities/logits_utils.py @@ -96,10 +96,7 @@ def sample_logits( Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling. - #! TODO: Finish testing all the edge cases here. Useful testing code: - logits = torch.randn(4) - print(logits) - np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) + When ``top_k`` exceeds the vocabulary size it is clamped to the vocabulary size (matching HuggingFace), rather than raising an error. """ if temperature == 0.0: # Greedy sampling - still apply repetition penalty before argmax @@ -128,6 +125,10 @@ def sample_logits( ) if top_k is not None: assert top_k > 0, "top_k has to be greater than 0" + # Clamp top_k to the vocab size so a large value does not raise + # "selected index k out of range" (matches HuggingFace's + # TopKLogitsWarper, which does top_k = min(top_k, logits.size(-1))). + top_k = min(top_k, final_logits.shape[-1]) top_logits, top_idx = final_logits.topk(top_k, dim=-1) indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1) final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))