Skip to content

Commit fac55c1

Browse files
feat(ml): add qdrant ingestion
refactor: use local qdrant implementation for tests chore: clean up imports chore: add qdrant dependency to ml_test extra chore: run precommit chore: add comment to CHANGES.md fix: guard against import error fix: import
1 parent 6dd599c commit fac55c1

File tree

4 files changed

+449
-4
lines changed

4 files changed

+449
-4
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
* Updates minimum Go version to 1.26.1 ([#37897](https://github.com/apache/beam/issues/37897)).
111111
* (Python) Added image embedding support in `apache_beam.ml.rag` package ([#37628](https://github.com/apache/beam/issues/37628)).
112112
* (Python) Added support for Python version 3.14 ([#37247](https://github.com/apache/beam/issues/37247)).
113+
* (Python) Added [Qdrant](https://qdrant.tech/) VectorDatabaseWriteConfig implementation ([#38141](https://github.com/apache/beam/issues/38141)).
113114

114115
## Breaking Changes
115116

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
from dataclasses import dataclass, field
19+
from typing import Any, Callable, Dict, Optional
20+
21+
try:
22+
from qdrant_client import QdrantClient, models
23+
except ImportError:
24+
logging.warning("Qdrant client library is not installed.")
25+
26+
import apache_beam as beam
27+
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
28+
from apache_beam.ml.rag.types import EmbeddableItem
29+
30+
DEFAULT_WRITE_BATCH_SIZE = 1000
31+
32+
33+
@dataclass
34+
class QdrantConnectionParameters:
35+
location: Optional[str] = None
36+
url: Optional[str] = None
37+
port: Optional[int] = 6333
38+
grpc_port: int = 6334
39+
prefer_grpc: bool = False
40+
https: Optional[bool] = None
41+
api_key: Optional[str] = None
42+
prefix: Optional[str] = None
43+
timeout: Optional[int] = None
44+
host: Optional[str] = None
45+
path: Optional[str] = None
46+
kwargs: Dict[str, Any] = field(default_factory=dict)
47+
48+
def __post_init__(self):
49+
if not (self.location or self.url or self.host or self.path):
50+
raise ValueError(
51+
"One of location, url, host, or path must be provided for Qdrant")
52+
53+
54+
@dataclass
55+
class QdrantWriteConfig(VectorDatabaseWriteConfig):
56+
connection_params: QdrantConnectionParameters
57+
collection_name: str
58+
timeout: Optional[float] = None
59+
batch_size: int = DEFAULT_WRITE_BATCH_SIZE
60+
kwargs: Dict[str, Any] = field(default_factory=dict)
61+
dense_embedding_key: str = "dense"
62+
sparse_embedding_key: str = "sparse"
63+
64+
def __post_init__(self):
65+
if not self.collection_name:
66+
raise ValueError("Collection name must be provided")
67+
68+
def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
69+
return _QdrantWriteTransform(self)
70+
71+
def create_converter(
72+
self) -> Callable[[EmbeddableItem], 'models.PointStruct']:
73+
def convert(item: EmbeddableItem) -> 'models.PointStruct':
74+
if item.dense_embedding is None and item.sparse_embedding is None:
75+
raise ValueError(
76+
"EmbeddableItem must have at least one embedding (dense or sparse)")
77+
vector = {}
78+
if item.dense_embedding is not None:
79+
vector[self.dense_embedding_key] = item.dense_embedding
80+
if item.sparse_embedding is not None:
81+
sparse_indices, sparse_values = item.sparse_embedding
82+
vector[self.sparse_embedding_key] = models.SparseVector(
83+
indices=sparse_indices,
84+
values=sparse_values,
85+
)
86+
id = (
87+
int(item.id)
88+
if isinstance(item.id, str) and item.id.isdigit() else item.id)
89+
return models.PointStruct(
90+
id=id,
91+
vector=vector,
92+
payload=item.metadata if item.metadata else None,
93+
)
94+
95+
return convert
96+
97+
98+
class _QdrantWriteTransform(beam.PTransform):
99+
def __init__(self, config: QdrantWriteConfig):
100+
self.config = config
101+
102+
def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]):
103+
return (
104+
input_or_inputs
105+
| "Convert to Records" >> beam.Map(self.config.create_converter())
106+
| beam.ParDo(_QdrantWriteFn(self.config)))
107+
108+
109+
class _QdrantWriteFn(beam.DoFn):
110+
def __init__(self, config: QdrantWriteConfig):
111+
self.config = config
112+
self._batch = []
113+
self._client: 'Optional[QdrantClient]' = None
114+
115+
def process(self, element, *args, **kwargs):
116+
self._batch.append(element)
117+
if len(self._batch) >= self.config.batch_size:
118+
self._flush()
119+
120+
def setup(self):
121+
params = self.config.connection_params
122+
self._client = QdrantClient(
123+
location=params.location,
124+
url=params.url,
125+
port=params.port,
126+
grpc_port=params.grpc_port,
127+
prefer_grpc=params.prefer_grpc,
128+
https=params.https,
129+
api_key=params.api_key,
130+
prefix=params.prefix,
131+
timeout=params.timeout,
132+
host=params.host,
133+
path=params.path,
134+
check_compatibility=False,
135+
**params.kwargs,
136+
)
137+
138+
def teardown(self):
139+
if self._client:
140+
self._client.close()
141+
self._client = None
142+
143+
def finish_bundle(self):
144+
self._flush()
145+
146+
def _flush(self):
147+
if len(self._batch) == 0:
148+
return
149+
if not self._client:
150+
raise RuntimeError("Qdrant client is not initialized")
151+
self._client.upsert(
152+
collection_name=self.config.collection_name,
153+
points=self._batch,
154+
timeout=self.config.timeout,
155+
**self.config.kwargs,
156+
)
157+
self._batch = []
158+
159+
def display_data(self):
160+
res = super().display_data()
161+
res["collection"] = self.config.collection_name
162+
res["batch_size"] = self.config.batch_size
163+
return res

0 commit comments

Comments
 (0)