Skip to content

Commit c850878

Browse files
aaronspringclaude
and committed
feat: Add Jina Embeddings v3 with task-specific LoRA support
Add support for jinaai/jina-embeddings-v3, a multilingual embedding model with 1024 dimensions supporting 89+ languages and task-specific LoRA adapters.

Features:
- Task-specific embeddings via LoRA adapters (retrieval.query, retrieval.passage, classification, text-matching, separation)
- Automatic task_id handling for ONNX inference
- Default to text-matching task for general-purpose use
- query_embed() and passage_embed() methods for retrieval tasks
- Matryoshka dimensions support (32-1024)
- 8,192 token context window

Model specs:
- 570M parameters
- 2.29 GB ONNX model
- Apache 2.0 license

Implementation:
- Added model configuration with additional_files for model.onnx_data
- Load lora_adaptations from config.json
- Preprocess ONNX input to add task_id parameter
- Override query_embed/passage_embed for automatic task selection
- Added comprehensive multi-task test with canonical vectors

Following the pattern from PR qdrant#561 but using task_id instead of text prefixes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ba1f605 commit c850878

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

fastembed/text/onnx_embedding.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import json
2+
from pathlib import Path
13
from typing import Any, Iterable, Optional, Sequence, Type, Union
24

5+
import numpy as np
6+
37
from fastembed.common.types import NumpyArray, OnnxProvider
48
from fastembed.common.onnx_model import OnnxOutputContext
59
from fastembed.common.utils import define_cache_dir, normalize
@@ -180,6 +184,24 @@
180184
sources=ModelSource(hf="jinaai/jina-clip-v1"),
181185
model_file="onnx/text_model.onnx",
182186
),
187+
DenseModelDescription(
188+
model="jinaai/jina-embeddings-v3",
189+
dim=1024,
190+
description=(
191+
"Text embeddings, Unimodal (text), Multilingual (89+ languages), 8192 input tokens truncation, "
192+
"Task-specific LoRA adapters (retrieval, classification, text-matching, clustering), "
193+
"Matryoshka dimensions: 32-1024, 2024 year."
194+
),
195+
license="apache-2.0",
196+
size_in_GB=2.29,
197+
sources=ModelSource(hf="jinaai/jina-embeddings-v3"),
198+
model_file="onnx/model.onnx",
199+
additional_files=["onnx/model.onnx_data"],
200+
tasks={
201+
"query_task": "retrieval.query",
202+
"passage_task": "retrieval.passage",
203+
},
204+
),
183205
]
184206

185207

@@ -255,6 +277,14 @@ def __init__(
255277
specific_model_path=self._specific_model_path,
256278
)
257279

280+
# Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3)
281+
self.lora_adaptations: Optional[list[str]] = None
282+
config_path = Path(self._model_dir) / "config.json"
283+
if config_path.exists():
284+
with open(config_path, "r") as f:
285+
config = json.load(f)
286+
self.lora_adaptations = config.get("lora_adaptations")
287+
258288
if not self.lazy_load:
259289
self.load_onnx_model()
260290

@@ -303,7 +333,20 @@ def _preprocess_onnx_input(
303333
) -> dict[str, NumpyArray]:
304334
"""
305335
Preprocess the onnx input.
336+
Adds task_id for models with LoRA adapters (e.g., Jina v3).
306337
"""
338+
# Handle task-specific embeddings for models with LoRA adapters
339+
if self.lora_adaptations:
340+
task_type = kwargs.get("task_type")
341+
342+
# If no task specified, use default (text-matching for general purpose)
343+
if not task_type:
344+
# Default to text-matching if available, otherwise first task
345+
task_type = "text-matching" if "text-matching" in self.lora_adaptations else self.lora_adaptations[0]
346+
347+
if task_type in self.lora_adaptations:
348+
task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64)
349+
onnx_input["task_id"] = task_id
307350
return onnx_input
308351

309352
def _post_process_onnx_output(
@@ -329,6 +372,46 @@ def load_onnx_model(self) -> None:
329372
device_id=self.device_id,
330373
)
331374

375+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
    """
    Embed one or more queries, selecting the query-specific task when available.

    Args:
        query (Union[str, Iterable[str]]): A single query string or an iterable of queries.
        **kwargs: Additional keyword arguments forwarded to ``embed``.

    Returns:
        Iterable[NumpyArray]: The query embeddings.
    """
    # Models with task-specific LoRA adapters (e.g. Jina v3) declare a dedicated
    # query task; pass it along so preprocessing can attach the matching task_id.
    task_map = self.model_description.tasks
    if task_map and "query_task" in task_map:
        kwargs["task_type"] = task_map["query_task"]

    # Treat a bare string as a one-element batch.
    batch = [query] if isinstance(query, str) else query
    yield from self.embed(batch, **kwargs)
394+
395+
def passage_embed(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
    """
    Embed one or more passages, selecting the passage-specific task when available.

    Args:
        texts (Union[str, Iterable[str]]): A single passage string or an iterable of passages.
        **kwargs: Additional keyword arguments forwarded to ``embed``.

    Returns:
        Iterable[NumpyArray]: The passage embeddings.
    """
    # Models with task-specific LoRA adapters (e.g. Jina v3) declare a dedicated
    # passage task; pass it along so preprocessing can attach the matching task_id.
    task_map = self.model_description.tasks
    if task_map and "passage_task" in task_map:
        kwargs["task_type"] = task_map["passage_task"]

    # Treat a bare string as a one-element batch.
    batch = [texts] if isinstance(texts, str) else texts
    yield from self.embed(batch, **kwargs)
414+
332415

333416
class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
334417
def init_embedding(

tests/test_text_onnx_embeddings.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
6868
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
6969
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
70+
"jinaai/jina-embeddings-v3": np.array([0.07257809, -0.08073004, 0.09241360, -0.01755937, 0.06534681]),
7071
}
7172

7273
MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
@@ -175,3 +176,64 @@ def test_embedding_size() -> None:
175176

176177
if is_ci:
177178
delete_model_cache(model.model._model_dir)
179+
180+
181+
@pytest.mark.parametrize("model_name", MULTI_TASK_MODELS)
def test_multi_task_embedding(model_name: str) -> None:
    """Exercise query/passage task-specific embeddings and the default-task path."""
    ci_env = os.getenv("CI")
    manual_run = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

    # These models are large, so only run the test in CI when triggered manually.
    if ci_env and not manual_run:
        pytest.skip("Skipping multi-task model tests in CI (large models)")

    model_desc = next(
        (d for d in TextEmbedding._list_supported_models() if d.model == model_name),
        None,
    )
    assert model_desc is not None, f"Model {model_name} not found in supported models"

    dim = model_desc.dim
    model = TextEmbedding(model_name=model_name)

    # Query-task embeddings.
    queries = ["What is the capital of France?", "How does photosynthesis work?"]
    query_embeddings = np.stack(list(model.query_embed(queries)), axis=0)
    assert query_embeddings.shape == (2, dim), f"Query embeddings shape mismatch for {model_name}"

    # Passage-task embeddings.
    passages = ["Paris is the capital of France.", "Photosynthesis is a process used by plants."]
    passage_embeddings = np.stack(list(model.passage_embed(passages)), axis=0)
    assert passage_embeddings.shape == (2, dim), f"Passage embeddings shape mismatch for {model_name}"

    # Plain embed() must also work with no explicit task (model falls back to its default).
    docs = ["hello world", "flag embedding"]
    embeddings = np.stack(list(model.embed(docs)), axis=0)
    assert embeddings.shape == (2, dim), f"Regular embeddings shape mismatch for {model_name}"

    # The same text routed through different LoRA adapters must produce different vectors.
    test_text = "This is a test sentence"
    query_emb = np.array(list(model.query_embed([test_text])))
    passage_emb = np.array(list(model.passage_embed([test_text])))
    assert not np.allclose(query_emb, passage_emb, atol=1e-6), \
        f"Query and passage embeddings should differ for {model_name}"

    # Compare against canonical vectors when available ("hello world" is docs[0]).
    if model_name in CANONICAL_VECTOR_VALUES:
        canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
        assert np.allclose(
            embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
        ), f"Canonical vector mismatch for {model_name}"

    if ci_env:
        delete_model_cache(model.model._model_dir)

0 commit comments

Comments
 (0)