|
| 1 | +from importlib.util import find_spec |
1 | 2 | from typing import Union |
2 | 3 |
|
3 | 4 | from vicinity.backends.base import AbstractBackend |
4 | 5 | from vicinity.backends.basic import BasicBackend, BasicVectorStore |
5 | 6 | from vicinity.datatypes import Backend |
6 | 7 |
|
7 | 8 |
|
| 9 | +class OptionalDependencyError(ImportError): |
| 10 | + def __init__(self, backend: Backend, extra: str) -> None: |
| 11 | + msg = f"{backend} requires extra '{extra}'.\n" f"Install it with: pip install 'vicinity[{extra}]'\n" |
| 12 | + super().__init__(msg) |
| 13 | + self.backend = backend |
| 14 | + self.extra = extra |
| 15 | + |
| 16 | + |
| 17 | +def _require(module_name: str, backend: Backend, extra: str) -> None: |
| 18 | + """Check if a dependency is importable, otherwise raise an error.""" |
| 19 | + if find_spec(module_name) is None: |
| 20 | + raise OptionalDependencyError(backend, extra) |
| 21 | + |
| 22 | + |
8 | 23 | def get_backend_class(backend: Union[Backend, str]) -> type[AbstractBackend]: |
9 | | - """Get all available backends.""" |
| 24 | + """Get the requested backend and ensure its dependencies are installed.""" |
10 | 25 | backend = Backend(backend) |
| 26 | + |
11 | 27 | if backend == Backend.BASIC: |
12 | 28 | return BasicBackend |
| 29 | + |
13 | 30 | elif backend == Backend.HNSW: |
| 31 | + _require("hnswlib", backend, "hnsw") |
14 | 32 | from vicinity.backends.hnsw import HNSWBackend |
15 | 33 |
|
16 | 34 | return HNSWBackend |
| 35 | + |
17 | 36 | elif backend == Backend.ANNOY: |
| 37 | + _require("annoy", backend, "annoy") |
18 | 38 | from vicinity.backends.annoy import AnnoyBackend |
19 | 39 |
|
20 | 40 | return AnnoyBackend |
| 41 | + |
21 | 42 | elif backend == Backend.PYNNDESCENT: |
| 43 | + _require("pynndescent", backend, "pynndescent") |
22 | 44 | from vicinity.backends.pynndescent import PyNNDescentBackend |
23 | 45 |
|
24 | 46 | return PyNNDescentBackend |
25 | 47 |
|
26 | 48 | elif backend == Backend.FAISS: |
| 49 | + _require("faiss", backend, "faiss") |
27 | 50 | from vicinity.backends.faiss import FaissBackend |
28 | 51 |
|
29 | 52 | return FaissBackend |
30 | 53 |
|
31 | 54 | elif backend == Backend.USEARCH: |
| 55 | + _require("usearch", backend, "usearch") |
32 | 56 | from vicinity.backends.usearch import UsearchBackend |
33 | 57 |
|
34 | 58 | return UsearchBackend |
35 | 59 |
|
36 | 60 | elif backend == Backend.VOYAGER: |
| 61 | + _require("voyager", backend, "voyager") |
37 | 62 | from vicinity.backends.voyager import VoyagerBackend |
38 | 63 |
|
39 | 64 | return VoyagerBackend |
|
0 commit comments