diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index b8b925429..2d2e76fc4 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -104,8 +104,10 @@ def token_count( include_extension: bool = False, **kwargs: Any, ) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well + if not hasattr(self, "tokenizer") or self.tokenizer is None: + self._load_tokenizer(model_dir=self._model_dir) + if self.query_tokenizer is None: + self._load_query_tokenizer() token_num = 0 texts = [texts] if isinstance(texts, str) else texts tokenizer = self.tokenizer if is_doc else self.query_tokenizer @@ -218,6 +220,9 @@ def load_onnx_model(self) -> None: device_id=self.device_id, extra_session_options=self._extra_session_options, ) + self._load_query_tokenizer() + + def _load_query_tokenizer(self) -> None: self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir) assert self.tokenizer is not None diff --git a/fastembed/late_interaction_multimodal/colmodernvbert.py b/fastembed/late_interaction_multimodal/colmodernvbert.py index 20b8e4f7b..90570dc85 100644 --- a/fastembed/late_interaction_multimodal/colmodernvbert.py +++ b/fastembed/late_interaction_multimodal/colmodernvbert.py @@ -203,8 +203,8 @@ def token_count( include_extension: bool = False, **kwargs: Any, ) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well + if not hasattr(self, "tokenizer") or self.tokenizer is None: + self._load_tokenizer(model_dir=self._model_dir) token_num = 0 texts = [texts] if isinstance(texts, str) else texts assert self.tokenizer is not None diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 85fbcd06b..69876c640 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -180,8 +180,8 @@ def token_count( include_extension: bool = False, **kwargs: Any, ) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well + if not hasattr(self, "tokenizer") or self.tokenizer is None: + self._load_tokenizer(model_dir=self._model_dir) token_num = 0 texts = [texts] if isinstance(texts, str) else texts assert self.tokenizer is not None diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 934368957..e603a450d 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -75,10 +75,13 @@ def _load_onnx_model( device_id=device_id, extra_session_options=extra_session_options, ) - self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + self._load_tokenizer(model_dir=model_dir) assert self.tokenizer is not None self.processor = load_preprocessor(model_dir=model_dir) + def _load_tokenizer(self, model_dir: Path) -> None: + self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 55f3ea85c..013c0af36 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -44,9 +44,12 @@ def _load_onnx_model( device_id=device_id, extra_session_options=extra_session_options, ) - self.tokenizer, _ = load_tokenizer(model_dir=model_dir) + self._load_tokenizer(model_dir=model_dir) assert self.tokenizer is not None + def _load_tokenizer(self, model_dir: Path) -> None: + self.tokenizer, _ = load_tokenizer(model_dir=model_dir) + def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]: return self.tokenizer.encode_batch(pairs) # type: ignore[union-attr] @@ -168,8 +171,11 @@ def _preprocess_onnx_input( def _token_count( self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any ) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well + if not hasattr(self, "tokenizer") or self.tokenizer is None: + model_dir = getattr(self, "_model_dir", None) + if model_dir is None: + raise ValueError("Tokenizer cannot be loaded before model files are resolved.") + self._load_tokenizer(model_dir=Path(model_dir)) token_num = 0 assert self.tokenizer is not None diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index 2b090f749..b554ccb21 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -355,8 +355,6 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]: def token_count( self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any ) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well return self._token_count(texts, batch_size=batch_size, **kwargs) diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 10a4aa17e..d5dfbe50d 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -38,6 +38,9 @@ def __init__(self) -> None: self.tokenizer: Tokenizer | None = None self.special_token_to_id: dict[str, int] = {} + def _load_tokenizer(self, model_dir: Path) -> None: + self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + def _preprocess_onnx_input( self, onnx_input: dict[str, NumpyArray], **kwargs: Any ) -> dict[str, NumpyArray | NDArray[np.int64]]: @@ -65,7 +68,7 @@ def _load_onnx_model( device_id=device_id, extra_session_options=extra_session_options, ) - self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + self._load_tokenizer(model_dir=model_dir) def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") @@ -167,8 +170,11 @@ def _embed_documents( yield from self._post_process_onnx_output(batch, **kwargs) # type: ignore def _token_count(self, texts: str | Iterable[str], batch_size: int = 1024, **_: Any) -> int: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() # loads the tokenizer as well + if not hasattr(self, "tokenizer") or self.tokenizer is None: + model_dir = getattr(self, "_model_dir", None) + if model_dir is None: + raise ValueError("Tokenizer cannot be loaded before model files are resolved.") + self._load_tokenizer(model_dir=Path(model_dir)) token_num = 0 assert self.tokenizer is not None diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index ea83e76a3..0e433fce8 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -276,6 +276,8 @@ def test_lazy_load(model_name: str): assert not hasattr(model.model, "model") docs = ["hello world", "flag embedding"] + assert model.token_count(docs, is_doc=False, include_extension=True) > 0 + assert not hasattr(model.model, "model") list(model.embed(docs)) assert hasattr(model.model, "model") diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 4d0d5b7d6..3d03f3383 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -100,6 +100,9 @@ def test_lazy_load(model_name: str) -> None: assert not hasattr(model.model, "model") query = "What is the capital of France?" documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] + pairs = [(query, doc) for doc in documents] + assert model.token_count(pairs) > 0 + assert not hasattr(model.model, "model") list(model.rerank(query, documents)) assert hasattr(model.model, "model") diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index cdd22d79a..84107dfec 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -215,6 +215,8 @@ def test_lazy_load(model_name: str) -> None: model = TextEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") docs = ["hello world", "flag embedding"] + assert model.token_count(docs) > 0 + assert not hasattr(model.model, "model") list(model.embed(docs)) assert hasattr(model.model, "model")