Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions model2vec/train/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch.utils.data import DataLoader, Dataset

from model2vec import StaticModel
from model2vec.train.utils import get_probable_pad_token_id

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,7 +83,7 @@ def from_pretrained(

@classmethod
def from_static_model(
cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str = "[PAD]", **kwargs: Any
cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str | None = None, **kwargs: Any
) -> ModelType:
"""Load the model from a static model."""
model.embedding = np.nan_to_num(model.embedding)
Expand All @@ -92,9 +93,13 @@ def from_static_model(
token_mapping = model.token_mapping.tolist()
else:
token_mapping = None
if pad_token is not None:
pad_id = model.tokenizer.get_vocab()[pad_token]
else:
pad_id = get_probable_pad_token_id(model.tokenizer)
return cls(
vectors=embeddings_converted,
pad_id=model.tokenizer.token_to_id(pad_token),
pad_id=pad_id,
out_dim=out_dim,
tokenizer=model.tokenizer,
token_mapping=token_mapping,
Expand Down
21 changes: 21 additions & 0 deletions model2vec/train/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import logging

from tokenizers import Tokenizer

logger = logging.getLogger(__name__)

_KNOWN_PAD_TOKENS = ("[PAD]", "<pad>")


def get_probable_pad_token_id(tokenizer: Tokenizer) -> int:
    """
    Infer the most likely pad token id for a tokenizer.

    Resolution order:
    1. If the tokenizer has an active padding configuration, trust its ``pad_id``.
    2. Otherwise, look the conventional pad tokens up in the vocabulary.
    3. As a last resort, warn and fall back to id 0.

    :param tokenizer: The tokenizer to inspect.
    :return: The probable pad token id.
    """
    padding_config = tokenizer.padding
    if padding_config is not None:
        # The padding module is authoritative when it is configured.
        return padding_config["pad_id"]
    vocabulary = tokenizer.get_vocab()
    for candidate in _KNOWN_PAD_TOKENS:
        if candidate in vocabulary:
            return vocabulary[candidate]

    logger.warning("No known pad token found, using 0 as default")
    return 0
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dev = [
"ruff",
]

distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.1"]
distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.2"]
onnx = ["onnx", "torch"]
# train also installs inference
train = ["torch", "lightning", "scikit-learn", "skops"]
Expand Down
50 changes: 50 additions & 0 deletions tests/test_trainable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import torch
from skeletoken import TokenizerModel
from tokenizers import Tokenizer
from transformers import AutoTokenizer

from model2vec.model import StaticModel
from model2vec.train import StaticModelForClassification
from model2vec.train.base import FinetunableStaticModel, TextDataset
from model2vec.train.utils import get_probable_pad_token_id


@pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
Expand Down Expand Up @@ -67,6 +70,21 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To
assert s.w.shape[0] == mock_vectors.shape[0]


def test_pad_token(mock_tokenizer: Tokenizer) -> None:
    """Test initialization from a static model with an explicit pad token."""
    tok_model = TokenizerModel.from_tokenizer(mock_tokenizer)
    tok_model.pad_token = "[HELLO]"
    rebuilt_tokenizer = tok_model.to_tokenizer()
    embeddings = np.random.RandomState().randn(6, 10)
    static_model = StaticModel(vectors=embeddings, tokenizer=rebuilt_tokenizer)
    classifier = StaticModelForClassification.from_static_model(model=static_model, pad_token="[HELLO]")
    # The embedding table must keep its size, and the pad id must resolve to the new token.
    assert classifier.w.shape[0] == embeddings.shape[0]
    assert classifier.pad_id == 5

    # An unknown pad token should surface as a KeyError from the vocab lookup.
    with pytest.raises(KeyError):
        StaticModelForClassification.from_static_model(model=static_model, pad_token="[BRR]")


def test_encode(mock_trained_pipeline: StaticModelForClassification) -> None:
"""Test the encode function."""
result = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long())
Expand Down Expand Up @@ -231,3 +249,35 @@ def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
else:
# Ignore the type error since we don't support int labels in our typing, but the code does
mock_trained_pipeline.evaluate(["dog cat", "dog"], [1, 1]) # type: ignore


def test_get_probable_pad_token_id(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None:
    """Test that the probable pad token id is inferred correctly in each fallback scenario."""
    tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
    tok = tokenizer_model.to_tokenizer()
    assert get_probable_pad_token_id(tok) == 0

    # Setting a brand-new pad token appends it to the vocabulary.
    tokenizer_model.pad_token = "haha"
    tok = tokenizer_model.to_tokenizer()
    assert get_probable_pad_token_id(tok) == 5

    # Re-using an existing vocabulary entry as the pad token keeps its id.
    tokenizer_model.pad_token = "word1"
    tok = tokenizer_model.to_tokenizer()
    assert get_probable_pad_token_id(tok) == 1

    # Without a padding config, the helper falls back to the known "[PAD]" token.
    tokenizer_model.pad_token = None
    tok = tokenizer_model.to_tokenizer()
    assert get_probable_pad_token_id(tok) == tokenizer_model.vocabulary["[PAD]"]

    # With no padding config and no known pad token, it warns and defaults to 0.
    tokenizer_model = tokenizer_model.remove_token_from_vocabulary("[PAD]")
    tok = tokenizer_model.to_tokenizer()
    with caplog.at_level(logging.WARNING, logger="model2vec.train.utils"):
        assert get_probable_pad_token_id(tok) == 0
        assert "No known pad token found, using 0 as default" in caplog.text
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading