Skip to content

Commit 86f062e

Browse files
stephantul and Pringled authored
fix: refix tokenizer with added token shenanigans (#304)
* fix: refix tokenizer with added token shenanigans
* fix: remove all added tokens except important ones
* fix bug where 0 was sent to else path
* up skeletoken version
* update test + lockfile
* Update model2vec/distill/distillation.py
---------
Co-authored-by: Thomas van Dongen <thomas123@live.nl>
1 parent af989d5 commit 86f062e

6 files changed

Lines changed: 72 additions & 26 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
from huggingface_hub.hf_api import model_info
1010
from skeletoken import TokenizerModel
11+
from skeletoken.external.transformers import reshape_embeddings
1112
from transformers import AutoModel, AutoTokenizer
1213
from transformers.modeling_utils import PreTrainedModel
1314
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -77,17 +78,22 @@ def distill_from_model(
7778

7879
device = select_optimal_device(device)
7980
original_tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer)
81+
original_tokenizer_model = original_tokenizer_model.prune_added_tokens()
8082

8183
# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
8284
# Copy the original tokenizer model.
83-
tokenizer_model = original_tokenizer_model.model_copy(deep=True)
85+
tokenizer_model = original_tokenizer_model.deep_copy()
8486
if tokenizer_model.adds_prefix_space is not None:
8587
tokenizer_model.adds_prefix_space = True
8688

8789
# Create the vocabulary in the new tokenizer.
8890
tokenizer_model = clean_and_create_vocabulary(tokenizer_model, vocabulary, token_remove_regex=token_remove_regex)
8991
# Remove the post processor, this is not necessary.
9092
tokenizer_model.post_processor = None
93+
# Prune again now that the post processor is gone.
94+
# We can't do this before because we need the post processor and associated
95+
# tokens before to add eos/bos.
96+
tokenizer_model = tokenizer_model.prune_added_tokens()
9197

9298
# All tokens in a single list.
9399
all_tokens = tokenizer_model.sorted_vocabulary
@@ -97,12 +103,15 @@ def distill_from_model(
97103
# Turn all _new_ tokens into ids using the original tokenizer
98104
token_ids = turn_tokens_into_ids(all_tokens, original_tokenizer_model)
99105

106+
# Reshape the transformer
107+
model = reshape_embeddings(model, original_tokenizer_model)
108+
100109
# Create the embeddings using the ids from the original tokenizer.
101110
embeddings = create_embeddings(
102111
tokenized=token_ids,
103112
model=model,
104113
device=device,
105-
pad_token_id=tokenizer_model.pad_token_id or 0,
114+
pad_token_id=original_tokenizer_model.pad_token_id or 0,
106115
pooling=pooling,
107116
)
108117

model2vec/tokenizer/tokenizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def turn_tokens_into_ids(tokens: list[str], model: TokenizerModel) -> list[list[
8989

9090
token_ids: list[list[int]] = []
9191
for token in tokens:
92-
if token_id := vocabulary.get(token):
92+
token_id = vocabulary.get(token)
93+
if token_id is not None:
9394
token_ids.append([*prefix, token_id, *suffix])
9495
else:
9596
token_ids.append(tokenizer.encode(token).ids)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ dev = [
6060
"ruff",
6161
]
6262

63-
distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.0"]
63+
distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.1"]
6464
onnx = ["onnx", "torch"]
6565
# train also installs inference
6666
train = ["torch", "lightning", "scikit-learn", "skops"]

tests/conftest.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,24 @@ def mock_tokenizermodel() -> TokenizerModel:
6363

6464

6565
@pytest.fixture
66-
def mock_transformer() -> PreTrainedModel:
66+
def mock_transformer(request: pytest.FixtureRequest) -> PreTrainedModel:
6767
"""Create a mock transformer model."""
68+
params = getattr(request, "param", {}) or {}
69+
# Default vocab size
70+
vocab_size: int = params.get("vocab_size", 30522)
71+
dim: int = params.get("dim", 768)
72+
with_pooler: bool = params.get("with_pooler", True)
73+
pooler_value: float = params.get("pooler_value", 7.0)
6874

6975
class MockPreTrainedModel:
70-
def __init__(self, dim: int = 768, with_pooler: bool = True, pooler_value: float = 7.0) -> None:
76+
def __init__(self, vocab_size: int, dim: int, with_pooler: bool, pooler_value: float) -> None:
7177
self.device = "cpu"
7278
self.name_or_path = "mock-model"
7379
self.dim = dim
7480
self.with_pooler = with_pooler
7581
self.pooler_value = pooler_value
82+
self.input_embs = torch.nn.Embedding(vocab_size, dim)
83+
self.config: dict[str, Any] = {}
7684

7785
def to(self, device: str) -> MockPreTrainedModel:
7886
self.device = device
@@ -92,7 +100,29 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
92100

93101
__call__ = forward
94102

95-
return cast(PreTrainedModel, MockPreTrainedModel())
103+
def get_input_embeddings(self) -> torch.nn.Embedding:
104+
return self.input_embs
105+
106+
def resize_token_embeddings(self, vocab_size: int) -> None:
107+
curr_size = len(self.input_embs.weight)
108+
if vocab_size == curr_size:
109+
return
110+
if vocab_size < curr_size:
111+
self.input_embs.weight.data = self.input_embs.weight.data[: vocab_size + 1]
112+
else:
113+
self.input_embs.weight.data = torch.cat(
114+
[self.input_embs.weight, torch.zeros(vocab_size - curr_size, self.dim)], dim=0
115+
)
116+
117+
return cast(
118+
PreTrainedModel,
119+
MockPreTrainedModel(
120+
dim=dim,
121+
with_pooler=with_pooler,
122+
pooler_value=pooler_value,
123+
vocab_size=vocab_size,
124+
),
125+
)
96126

97127

98128
@pytest.fixture(scope="session")

tests/test_distillation.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,22 @@ def test_distill_removal_pattern_all_tokens(
9292
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
9393
mock_auto_model.return_value = mock_transformer
9494

95-
with pytest.raises(ValueError):
96-
distill_from_model(
97-
model=mock_transformer,
98-
tokenizer=mock_berttokenizer,
99-
vocabulary=None,
100-
device="cpu",
101-
token_remove_pattern=r".*",
102-
)
95+
# Even if we remove all tokens, we can't remove the [UNK] token
96+
model = distill_from_model(
97+
model=mock_transformer,
98+
tokenizer=mock_berttokenizer,
99+
vocabulary=None,
100+
device="cpu",
101+
token_remove_pattern=r".*",
102+
)
103+
104+
# So the only token left is the [UNK] token.
105+
assert model.tokens == ("[UNK]",)
103106

104107

105108
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
106109
@patch("transformers.AutoModel.from_pretrained")
110+
@pytest.mark.parametrize("mock_transformer", [{"vocab_size": 35022}], indirect=True)
107111
def test_distill_removal_pattern(
108112
mock_auto_model: MagicMock,
109113
mock_model_info: MagicMock,
@@ -114,7 +118,8 @@ def test_distill_removal_pattern(
114118
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
115119
mock_auto_model.return_value = mock_transformer
116120

117-
expected_vocab_size = mock_berttokenizer.vocab_size
121+
# Because the added [MASK], [CLS] and [SEP] get removed
122+
expected_vocab_size = mock_berttokenizer.vocab_size - 3
118123

119124
static_model = distill_from_model(
120125
model=mock_transformer,
@@ -159,18 +164,19 @@ def test_distill_removal_pattern(
159164
@pytest.mark.parametrize(
160165
"vocabulary, pca_dims, sif_coefficient, expected_shape",
161166
[
162-
(None, 256, None, (30522, 256)), # PCA applied, SIF off
163-
(None, "auto", None, (30522, 768)), # PCA 'auto', SIF off
164-
(None, "auto", 1e-4, (30522, 768)), # PCA 'auto', SIF on
167+
(None, 256, None, (30519, 256)), # PCA applied, SIF off
168+
(None, "auto", None, (30519, 768)), # PCA 'auto', SIF off
169+
(None, "auto", 1e-4, (30519, 768)), # PCA 'auto', SIF on
165170
(None, "auto", 0, None), # invalid SIF (too low) -> raises
166171
(None, "auto", 1, None), # invalid SIF (too high) -> raises
167-
(None, 1024, None, (30522, 768)), # PCA set high (no reduction)
168-
(["wordA", "wordB"], 4, None, (30524, 4)), # Custom vocab, PCA applied
169-
(None, None, None, (30522, 768)), # No PCA, SIF off
172+
(None, 1024, None, (30519, 768)), # PCA set high (no reduction)
173+
(["wordA", "wordB"], 4, None, (30521, 4)), # Custom vocab, PCA applied
174+
(None, None, None, (30519, 768)), # No PCA, SIF off
170175
],
171176
)
172177
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
173178
@patch("transformers.AutoModel.from_pretrained")
179+
@pytest.mark.parametrize("mock_transformer", [{"vocab_size": 30522}], indirect=True)
174180
def test_distill(
175181
mock_auto_model: MagicMock,
176182
mock_model_info: MagicMock,

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)