diff --git a/CHANGES.md b/CHANGES.md index 319520f94309..4e14e4728dad 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -110,6 +110,7 @@ * Updates minimum Go version to 1.26.1 ([#37897](https://github.com/apache/beam/issues/37897)). * (Python) Added image embedding support in `apache_beam.ml.rag` package ([#37628](https://github.com/apache/beam/issues/37628)). * (Python) Added support for Python version 3.14 ([#37247](https://github.com/apache/beam/issues/37247)). +* (Python) Added [Qdrant](https://qdrant.tech/) VectorDatabaseWriteConfig implementation ([#38141](https://github.com/apache/beam/issues/38141)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py b/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py new file mode 100644 index 000000000000..636ef139b2c1 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

try:
  from qdrant_client import QdrantClient, models
except ImportError:
  logging.warning("Qdrant client library is not installed.")

import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.types import EmbeddableItem

# Number of points buffered before issuing a single upsert call to Qdrant.
DEFAULT_WRITE_BATCH_SIZE = 1000


@dataclass
class QdrantConnectionParameters:
  """Connection settings for a Qdrant instance.

  Mirrors the ``qdrant_client.QdrantClient`` constructor arguments; at least
  one of ``location``, ``url``, ``host`` or ``path`` must be provided.
  """
  location: Optional[str] = None  # e.g. ":memory:" or a service URL.
  url: Optional[str] = None  # Fully qualified URL of the Qdrant service.
  port: Optional[int] = 6333  # REST API port.
  grpc_port: int = 6334  # gRPC port.
  prefer_grpc: bool = False  # Use gRPC instead of REST where possible.
  https: Optional[bool] = None  # Use HTTPS for REST connections.
  api_key: Optional[str] = None  # API key for secured deployments.
  prefix: Optional[str] = None  # URL path prefix for REST endpoints.
  timeout: Optional[int] = None  # Connection timeout, in seconds.
  host: Optional[str] = None  # Hostname of the Qdrant service.
  path: Optional[str] = None  # Filesystem path for local (embedded) storage.
  # Additional keyword arguments forwarded verbatim to ``QdrantClient``.
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    # Fail fast at pipeline-construction time rather than inside workers.
    if not (self.location or self.url or self.host or self.path):
      raise ValueError(
          "One of location, url, host, or path must be provided for Qdrant")


@dataclass
class QdrantWriteConfig(VectorDatabaseWriteConfig):
  """Configuration for writing RAG embeddings to a Qdrant collection.

  Each ``EmbeddableItem`` is converted into a ``models.PointStruct`` whose
  dense and/or sparse embeddings are stored as named vectors, then upserted
  in batches of ``batch_size``.
  """
  connection_params: QdrantConnectionParameters
  collection_name: str
  timeout: Optional[float] = None  # Per-upsert timeout, in seconds.
  batch_size: int = DEFAULT_WRITE_BATCH_SIZE
  # Additional keyword arguments forwarded verbatim to ``upsert``.
  kwargs: Dict[str, Any] = field(default_factory=dict)
  dense_embedding_key: str = "dense"  # Named-vector key for dense vectors.
  sparse_embedding_key: str = "sparse"  # Named-vector key for sparse vectors.

  def __post_init__(self):
    if not self.collection_name:
      raise ValueError("Collection name must be provided")

  def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
    """Return the PTransform that writes EmbeddableItems to Qdrant."""
    return _QdrantWriteTransform(self)

  def create_converter(
      self) -> Callable[[EmbeddableItem], 'models.PointStruct']:
    """Return a function mapping an ``EmbeddableItem`` to a ``PointStruct``.

    The returned converter raises ``ValueError`` for items that carry
    neither a dense nor a sparse embedding.
    """
    def convert(item: EmbeddableItem) -> 'models.PointStruct':
      if item.dense_embedding is None and item.sparse_embedding is None:
        raise ValueError(
            "EmbeddableItem must have at least one embedding (dense or sparse)")
      vector = {}
      if item.dense_embedding is not None:
        vector[self.dense_embedding_key] = item.dense_embedding
      if item.sparse_embedding is not None:
        sparse_indices, sparse_values = item.sparse_embedding
        vector[self.sparse_embedding_key] = models.SparseVector(
            indices=sparse_indices,
            values=sparse_values,
        )
      # Qdrant point ids must be unsigned integers or UUIDs, so purely
      # numeric string ids are converted to ints. (Named `point_id` to
      # avoid shadowing the `id` builtin.)
      point_id = (
          int(item.id)
          if isinstance(item.id, str) and item.id.isdigit() else item.id)
      return models.PointStruct(
          id=point_id,
          vector=vector,
          payload=item.metadata if item.metadata else None,
      )

    return convert


class _QdrantWriteTransform(beam.PTransform):
  """Converts EmbeddableItems to points and writes them via a batching DoFn."""
  def __init__(self, config: QdrantWriteConfig):
    self.config = config

  def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]):
    return (
        input_or_inputs
        | "Convert to Records" >> beam.Map(self.config.create_converter())
        | beam.ParDo(_QdrantWriteFn(self.config)))


class _QdrantWriteFn(beam.DoFn):
  """DoFn that buffers PointStructs and upserts them to Qdrant in batches."""
  def __init__(self, config: QdrantWriteConfig):
    self.config = config
    self._batch = []
    self._client: Optional['QdrantClient'] = None

  def setup(self):
    # The client is created on the worker rather than at construction time,
    # since QdrantClient holds live connections and is not picklable.
    params = self.config.connection_params
    self._client = QdrantClient(
        location=params.location,
        url=params.url,
        port=params.port,
        grpc_port=params.grpc_port,
        prefer_grpc=params.prefer_grpc,
        https=params.https,
        api_key=params.api_key,
        prefix=params.prefix,
        timeout=params.timeout,
        host=params.host,
        path=params.path,
        check_compatibility=False,
        **params.kwargs,
    )

  def start_bundle(self):
    # Defensive reset: finish_bundle always flushes, but making the buffer's
    # per-bundle lifecycle explicit guards against runner-specific reuse.
    self._batch = []

  def process(self, element, *args, **kwargs):
    self._batch.append(element)
    if len(self._batch) >= self.config.batch_size:
      self._flush()

  def finish_bundle(self):
    # Flush any partial batch so no points are lost at bundle boundaries.
    self._flush()

  def teardown(self):
    if self._client:
      self._client.close()
      self._client = None

  def _flush(self):
    """Upsert the buffered points, if any, then clear the buffer."""
    if not self._batch:
      return
    if not self._client:
      raise RuntimeError("Qdrant client is not initialized")
    self._client.upsert(
        collection_name=self.config.collection_name,
        points=self._batch,
        timeout=self.config.timeout,
        **self.config.kwargs,
    )
    self._batch = []

  def display_data(self):
    res = super().display_data()
    res["collection"] = self.config.collection_name
    res["batch_size"] = self.config.batch_size
    return res
import contextlib
import tempfile
import unittest
from typing import Iterator

import apache_beam as beam
from apache_beam.ml.rag.ingestion.qdrant import QdrantConnectionParameters
from apache_beam.ml.rag.ingestion.qdrant import QdrantWriteConfig
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.rag.types import Embedding
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=ungrouped-imports
try:
  from qdrant_client import QdrantClient, models
  QDRANT_AVAILABLE = True
except ImportError:
  QDRANT_AVAILABLE = False
# pylint: enable=ungrouped-imports

# Three distinguishable 2-d dense vectors, one per document.
TEST_CORPUS = [
    EmbeddableItem(
        id="1",
        content=Content(text="Test document one"),
        metadata={"source": "test1"},
        embedding=Embedding(dense_embedding=[1.0, 0.0]),
    ),
    EmbeddableItem(
        id="2",
        content=Content(text="Test document two"),
        metadata={"source": "test2"},
        embedding=Embedding(dense_embedding=[0.0, 1.0]),
    ),
    EmbeddableItem(
        id="3",
        content=Content(text="Test document three"),
        metadata={"source": "test3"},
        embedding=Embedding(dense_embedding=[-1.0, 0.0]),
    ),
]


@unittest.skipIf(not QDRANT_AVAILABLE, "qdrant dependencies not installed.")
class TestQdrantIngestion(unittest.TestCase):
  """Integration tests for QdrantWriteConfig against local (embedded) Qdrant."""
  @contextlib.contextmanager
  def qdrant_client(self) -> Iterator['QdrantClient']:
    """Yield a client backed by this test's temporary local storage.

    Local-mode Qdrant holds a filesystem lock on the storage path, so the
    client must always be closed before another one is opened.
    """
    client = QdrantClient(path=self._temp_dir.name)
    try:
      yield client
    finally:
      client.close()

  def setUp(self):
    self._temp_dir = tempfile.TemporaryDirectory()
    # A unique collection per test method avoids cross-test interference.
    self._collection_name = f"test_collection_{self._testMethodName}"

    with self.qdrant_client() as client:
      client.create_collection(
          collection_name=self._collection_name,
          vectors_config={
              "dense": models.VectorParams(
                  size=2, distance=models.Distance.COSINE)
          },
          sparse_vectors_config={"sparse": models.SparseVectorParams()},
      )
      assert client.collection_exists(collection_name=self._collection_name)

    self._connection_params = QdrantConnectionParameters(
        path=self._temp_dir.name)

  def tearDown(self):
    self._temp_dir.cleanup()

  def _points_by_id(self, client: 'QdrantClient'):
    """Scroll every point in the test collection, keyed by point id."""
    points, _ = client.scroll(
        collection_name=self._collection_name,
        limit=100,
        with_payload=True,
        with_vectors=True,
    )
    return {point.id: point for point in points}

  def test_write_on_non_existent_collection(self):
    """Writing to a missing collection must fail the pipeline."""
    write_config = QdrantWriteConfig(
        connection_params=self._connection_params,
        collection_name="nonexistent_collection",
        batch_size=1,
    )

    with self.assertRaises(Exception):
      with TestPipeline() as p:
        _ = p | beam.Create(TEST_CORPUS) | write_config.create_write_transform()

  def test_write_dense_embeddings_only(self):
    """All dense-only items are written and round-trip exactly."""
    write_config = QdrantWriteConfig(
        connection_params=self._connection_params,
        collection_name=self._collection_name,
        batch_size=len(TEST_CORPUS),
    )

    with TestPipeline() as p:
      _ = p | beam.Create(TEST_CORPUS) | write_config.create_write_transform()

    with self.qdrant_client() as client:
      count_result = client.count(collection_name=self._collection_name)
      self.assertEqual(count_result.count, len(TEST_CORPUS))

      points_by_id = self._points_by_id(client)
      for item in TEST_CORPUS:
        expected_record = models.Record(
            id=int(item.id),
            vector={"dense": item.dense_embedding},
            payload=item.metadata,
        )
        self.assertEqual(expected_record, points_by_id[int(item.id)])

  def test_write_sparse_embeddings_only(self):
    """Sparse-only items are written as named SparseVectors."""
    sparse_corpus = [
        EmbeddableItem(
            id="1",
            content=Content(text="Sparse doc one"),
            metadata={"source": "sparse1"},
            embedding=Embedding(sparse_embedding=([0, 1, 2], [0.1, 0.2, 0.3])),
        ),
        EmbeddableItem(
            id="2",
            content=Content(text="Sparse doc two"),
            metadata={"source": "sparse2"},
            embedding=Embedding(sparse_embedding=([1, 3, 5], [0.4, 0.5, 0.6])),
        ),
    ]

    write_config = QdrantWriteConfig(
        connection_params=self._connection_params,
        collection_name=self._collection_name,
        batch_size=len(sparse_corpus),
    )

    with TestPipeline() as p:
      _ = p | beam.Create(sparse_corpus) | write_config.create_write_transform()

    with self.qdrant_client() as client:
      count_result = client.count(collection_name=self._collection_name)
      self.assertEqual(count_result.count, len(sparse_corpus))

      points_by_id = self._points_by_id(client)
      for item in sparse_corpus:
        expected_record = models.Record(
            id=int(item.id),
            vector={
                "sparse": models.SparseVector(
                    indices=item.sparse_embedding[0],
                    values=item.sparse_embedding[1],
                )
            },
            payload=item.metadata,
        )
        self.assertEqual(expected_record, points_by_id[int(item.id)])

  def test_write_both_dense_and_sparse(self):
    """Hybrid items carry both named vectors in a single point."""
    hybrid_corpus = [
        EmbeddableItem(
            id="1",
            content=Content(text="Hybrid doc one"),
            metadata={"source": "hybrid1"},
            embedding=Embedding(
                dense_embedding=[1.0, 0.0],
                sparse_embedding=([0, 1], [0.1, 0.2])),
        ),
        EmbeddableItem(
            id="2",
            content=Content(text="Hybrid doc two"),
            metadata={"source": "hybrid2"},
            embedding=Embedding(
                dense_embedding=[0.0, 1.0],
                sparse_embedding=([2, 3], [0.3, 0.4])),
        ),
    ]

    write_config = QdrantWriteConfig(
        connection_params=self._connection_params,
        collection_name=self._collection_name,
        batch_size=len(hybrid_corpus),
    )

    with TestPipeline() as p:
      _ = p | beam.Create(hybrid_corpus) | write_config.create_write_transform()

    with self.qdrant_client() as client:
      count_result = client.count(collection_name=self._collection_name)
      self.assertEqual(count_result.count, len(hybrid_corpus))

      points_by_id = self._points_by_id(client)
      for item in hybrid_corpus:
        expected_record = models.Record(
            id=int(item.id),
            vector={
                "dense": item.dense_embedding,
                "sparse": models.SparseVector(
                    indices=item.sparse_embedding[0],
                    values=item.sparse_embedding[1]),
            },
            payload=item.metadata,
        )
        self.assertEqual(expected_record, points_by_id[int(item.id)])

  def test_write_with_batching(self):
    """7 items with batch_size=3 exercise both full and partial flushes."""
    batch_corpus = [
        EmbeddableItem(
            id=str(i),
            content=Content(text=f"Batch doc {i}"),
            metadata={"batch_id": i},
            embedding=Embedding(dense_embedding=[1.0, 0.0]),
        ) for i in range(1, 8)
    ]

    write_config = QdrantWriteConfig(
        connection_params=self._connection_params,
        collection_name=self._collection_name,
        batch_size=3,
    )

    with TestPipeline() as p:
      _ = p | beam.Create(batch_corpus) | write_config.create_write_transform()

    with self.qdrant_client() as client:
      count_result = client.count(collection_name=self._collection_name)
      self.assertEqual(count_result.count, len(batch_corpus))

      points_by_id = self._points_by_id(client)
      for item in batch_corpus:
        expected_record = models.Record(
            id=int(item.id),
            vector={
                "dense": item.dense_embedding,
            },
            payload=item.metadata,
        )
        self.assertEqual(expected_record, points_by_id[int(item.id)])


if __name__ == "__main__":
  unittest.main()
Those @@ -584,14 +585,14 @@ def get_portability_package_data(): 'tf2onnx>=1.16.1,<1.17', ] + ml_base_core, 'p310_ml_test': [ - 'datatable', - ] + ml_base, + 'datatable', + ] + ml_base + qdrant_dependency, 'p312_ml_test': [ 'datatable', - ] + ml_base, + ] + ml_base + qdrant_dependency, # maintainer: milvus tests only run with this extension. Make sure it # is covered by docker-in-docker test when changing py version - 'p313_ml_test': ml_base + milvus_dependency, + 'p313_ml_test': ml_base + milvus_dependency + qdrant_dependency, 'aws': ['boto3>=1.9,<2'], 'azure': [ 'azure-storage-blob>=12.3.2,<13', @@ -662,6 +663,7 @@ def get_portability_package_data(): 'xgboost': ['xgboost>=1.6.0,<2.1.3', 'datatable==1.0.0'], 'tensorflow-hub': ['tensorflow-hub>=0.14.0,<0.16.0'], 'milvus': milvus_dependency, + 'qdrant': qdrant_dependency, 'vllm': ['openai==1.107.1', 'vllm==0.10.1.1', 'triton==3.3.1'] }, zip_safe=False,