Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion tests/unit/utilities/test_logits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import pytest
import torch

from transformer_lens.utilities.logits_utils import logits_to_df
from transformer_lens.utilities.logits_utils import (
_apply_repetition_penalty,
logits_to_df,
sample_logits,
)


class _StubTokenizer:
Expand Down Expand Up @@ -85,3 +89,96 @@ def test_rejects_non_1d_input(self):

with pytest.raises(BeartypeCallHintParamViolation):
logits_to_df(torch.zeros(3, 4))


class TestSampleLogitsTopK:
    """Checks on how ``sample_logits`` handles the ``top_k`` argument."""

    def test_top_k_larger_than_vocab_does_not_crash(self):
        """A ``top_k`` above the vocab size must sample instead of raising."""
        # Regression guard: without clamping, final_logits.topk(top_k) failed
        # with "selected index k out of range" once top_k exceeded the vocab.
        vocab_size = 3
        sampled = sample_logits(torch.randn(1, vocab_size), top_k=10)
        assert sampled.shape == (1,)
        token_id = sampled.item()
        assert token_id >= 0 and token_id < vocab_size

    def test_top_k_larger_than_vocab_batched(self):
        """Same oversized ``top_k`` case, but with a batch of four rows."""
        batch_size, vocab_size = 4, 5
        sampled = sample_logits(torch.randn(batch_size, vocab_size), top_k=8)
        assert sampled.shape == (batch_size,)
        in_range = (sampled >= 0) & (sampled < vocab_size)
        assert in_range.all()

    def test_top_k_equal_to_vocab(self):
        """``top_k`` exactly equal to the vocab size is accepted."""
        vocab_size = 4
        sampled = sample_logits(torch.randn(1, vocab_size), top_k=vocab_size)
        assert sampled.shape == (1,)
        token_id = sampled.item()
        assert token_id >= 0 and token_id < vocab_size

    def test_top_k_restricts_to_dominant_token(self):
        """With ``top_k=1`` every draw must be the argmax token."""
        peaked_logits = torch.tensor([[0.0, 100.0, 0.0, 0.0]])
        seen = {sample_logits(peaked_logits, top_k=1).item() for _ in range(20)}
        assert seen == {1}

    def test_top_k_rejects_non_positive(self):
        """``top_k=0`` trips the positivity assertion."""
        random_logits = torch.randn(1, 4)
        with pytest.raises(AssertionError):
            sample_logits(random_logits, top_k=0)


class TestSampleLogitsTemperature:
    """Checks on the ``temperature=0.0`` (greedy) path of ``sample_logits``."""

    def test_temperature_zero_is_greedy_argmax(self):
        """A zero temperature picks the highest logit of a single row."""
        single_row = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
        chosen = sample_logits(single_row, temperature=0.0)
        assert chosen.tolist() == [1]

    def test_temperature_zero_batched_argmax(self):
        """A zero temperature picks the per-row argmax for a batch."""
        two_rows = torch.tensor(
            [
                [1.0, 3.0, 2.0],
                [5.0, 0.0, 1.0],
            ]
        )
        chosen = sample_logits(two_rows, temperature=0.0)
        assert chosen.tolist() == [1, 0]

    def test_temperature_zero_applies_repetition_penalty(self):
        """The repetition penalty is still applied before the greedy argmax."""
        # Token 1 has the largest raw logit but already occurs in the context,
        # so a heavy penalty should move the greedy pick to token 2, the best
        # token that has not been seen.
        raw_logits = torch.tensor([[0.0, 10.0, 9.0, 0.0]])
        context = torch.tensor([[1]])
        chosen = sample_logits(
            raw_logits,
            temperature=0.0,
            repetition_penalty=100.0,
            tokens=context,
        )
        assert chosen.tolist() == [2]


class TestSampleLogitsTopP:
    """Checks on how ``sample_logits`` handles the ``top_p`` argument."""

    def test_top_p_keeps_dominant_token(self):
        """A token holding nearly all the mass survives a small ``top_p``."""
        # Token 1 carries essentially the whole probability mass, so nucleus
        # filtering must retain it and no other token should ever be drawn.
        peaked_logits = torch.tensor([[0.0, 50.0, 0.0, 0.0]])
        seen = {sample_logits(peaked_logits, top_p=0.5).item() for _ in range(20)}
        assert seen == {1}

    def test_top_p_rejects_out_of_range(self):
        """``top_p`` values of 0.0 and 1.5 both trip an assertion."""
        for invalid_top_p in (0.0, 1.5):
            with pytest.raises(AssertionError):
                sample_logits(torch.randn(1, 4), top_p=invalid_top_p)


class TestSampleLogitsFreqPenalty:
    """Checks on how ``sample_logits`` handles the ``freq_penalty`` argument."""

    def test_freq_penalty_suppresses_repeated_token(self):
        """A heavily repeated favourite token is never sampled under penalty."""
        # Token 0 has the highest raw logit, but the context consists of
        # token 0 repeated 50 times, so a large frequency penalty should keep
        # it out of every draw.
        raw_logits = torch.tensor([[5.0, 4.0, 4.0, 4.0]])
        context = torch.zeros(1, 50, dtype=torch.long)
        draws = []
        for _ in range(50):
            picked = sample_logits(raw_logits, freq_penalty=10.0, tokens=context)
            draws.append(picked.item())
        assert all(token_id != 0 for token_id in draws)

    def test_freq_penalty_requires_tokens(self):
        """Passing ``freq_penalty`` without ``tokens`` trips an assertion."""
        random_logits = torch.randn(1, 4)
        with pytest.raises(AssertionError):
            sample_logits(random_logits, freq_penalty=1.0)


class TestApplyRepetitionPenalty:
    """Checks on the ``_apply_repetition_penalty`` helper."""

    def test_positive_logits_divided_negative_multiplied(self):
        """Seen tokens are scaled by sign; unseen tokens are left alone."""
        raw_logits = torch.tensor([[2.0, -2.0, 0.0]])
        seen_tokens = torch.tensor([[0, 1]])
        penalized = _apply_repetition_penalty(raw_logits, seen_tokens, penalty=2.0)
        # Token 0 is positive, so 2.0 / 2.0 = 1.0; token 1 is negative, so
        # -2.0 * 2.0 = -4.0; token 2 never appeared and keeps its value.
        expected = [[1.0, -4.0, 0.0]]
        assert penalized.tolist() == expected

    def test_does_not_mutate_input(self):
        """The helper must leave the caller's logits tensor unchanged."""
        raw_logits = torch.tensor([[2.0, -2.0, 0.0]])
        snapshot = raw_logits.clone()
        seen_tokens = torch.tensor([[0, 1]])
        _apply_repetition_penalty(raw_logits, seen_tokens, penalty=2.0)
        assert torch.equal(raw_logits, snapshot)
9 changes: 5 additions & 4 deletions transformer_lens/utilities/logits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ def sample_logits(

Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling.

#! TODO: Finish testing all the edge cases here. Useful testing code:
logits = torch.randn(4)
print(logits)
np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
When ``top_k`` exceeds the vocabulary size it is clamped to the vocabulary size (matching HuggingFace), rather than raising an error.
"""
if temperature == 0.0:
# Greedy sampling - still apply repetition penalty before argmax
Expand Down Expand Up @@ -128,6 +125,10 @@ def sample_logits(
)
if top_k is not None:
assert top_k > 0, "top_k has to be greater than 0"
# Clamp top_k to the vocab size so a large value does not raise
# "selected index k out of range" (matches HuggingFace's
# TopKLogitsWarper, which does top_k = min(top_k, logits.size(-1))).
top_k = min(top_k, final_logits.shape[-1])
top_logits, top_idx = final_logits.topk(top_k, dim=-1)
indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1)
final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))
Expand Down
Loading