diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 48743613b3..94da9ac225 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -151,7 +151,7 @@ def __post_init__(self): @dataclass class HybridQuery: """ - Hyrbid query. Could be used for dense-only or sparse-only or hybrid queries. + Hybrid query. Could be used for dense-only or sparse-only or hybrid queries. dense_embedding (List[float]): Optional. The dense part of the hybrid queries. @@ -281,7 +281,7 @@ def from_index_datapoint( ) # retrieve embedding metadata if index_datapoint.embedding_metadata is not None: - self.embedding_metadata = index_datapoint.embedding_metadata + self.embedding_metadata = dict(index_datapoint.embedding_metadata) return self def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor": @@ -328,6 +328,10 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb if embedding.sparse_embedding: self.sparse_embedding_values = embedding.sparse_embedding.float_val self.sparse_embedding_dimensions = embedding.sparse_embedding.dimension + + # retrieve embedding metadata + if embedding.embedding_metadata: + self.embedding_metadata = dict(embedding.embedding_metadata) return self @@ -1883,7 +1887,7 @@ def find_neighbors( [ MatchNeighbor( id=neighbor.datapoint.datapoint_id, - distance=neighbor.distance, + distance=neighbor.distance if neighbor.distance else None, sparse_distance=( neighbor.sparse_distance if neighbor.sparse_distance else None ), @@ -2219,19 +2223,18 @@ def match( # Wrap the results in MatchNeighbor objects and return match_neighbors_response = [] for resp in response.responses[0].responses: - match_neighbors_id_map = {} + embedding_map = {embedding.id: embedding for embedding in resp.embeddings} + neighbors_list = [] for neighbor in resp.neighbor: - match_neighbors_id_map[neighbor.id] = MatchNeighbor( + match_neighbor = MatchNeighbor( id=neighbor.id, - distance=neighbor.distance, + distance=neighbor.distance if neighbor.distance else None, sparse_distance=( neighbor.sparse_distance if neighbor.sparse_distance else None ), ) - for embedding in resp.embeddings: - if embedding.id in match_neighbors_id_map: - match_neighbors_id_map[embedding.id] = match_neighbors_id_map[ - embedding.id - ].from_embedding(embedding=embedding) - match_neighbors_response.append(list(match_neighbors_id_map.values())) + if neighbor.id in embedding_map: + match_neighbor.from_embedding(embedding=embedding_map[neighbor.id]) + neighbors_list.append(match_neighbor) + match_neighbors_response.append(neighbors_list) return match_neighbors_response diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index a8009e099e..7eb6d0c22a 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -15,47 +15,46 @@ # limitations under the License. # -import uuid from importlib import reload from unittest import mock from unittest.mock import patch +import uuid from google.api_core import operation from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer -from google.cloud.aiplatform.matching_engine._protos import ( - match_service_pb2, - match_service_pb2_grpc, -) -from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( - Namespace, - NumericNamespace, - MatchNeighbor, - HybridQuery, +from google.cloud.aiplatform.compat.services import ( + index_endpoint_service_client, + index_service_client, + match_service_client_v1beta1, ) from google.cloud.aiplatform.compat.types import ( - matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, + encryption_spec as gca_encryption_spec, index_endpoint as gca_index_endpoint, + index_v1beta1 as gca_index_v1beta1, index as gca_index, match_service_v1beta1 as gca_match_service_v1beta1, - index_v1beta1 as gca_index_v1beta1, + matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, service_networking as gca_service_networking, - encryption_spec as gca_encryption_spec, ) -from google.cloud.aiplatform.compat.services import ( - index_endpoint_service_client, - index_service_client, - match_service_client_v1beta1, +from google.cloud.aiplatform.matching_engine._protos import ( + match_service_pb2, + match_service_pb2_grpc, +) +from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( + HybridQuery, + MatchNeighbor, + Namespace, + NumericNamespace, ) import constants as test_constants - -from google.protobuf import field_mask_pb2 - import grpc - import pytest +from google.protobuf import field_mask_pb2 +from google.protobuf import struct_pb2 + # project _TEST_PROJECT = test_constants.ProjectConstants._TEST_PROJECT _TEST_LOCATION = test_constants.ProjectConstants._TEST_LOCATION @@ -2409,70 +2408,82 @@ def test_index_endpoint_read_index_datapoints_for_private_service_connect_automa class TestMatchNeighbor: - def test_from_index_datapoint(self): - index_datapoint = gca_index_v1beta1.IndexDatapoint() - index_datapoint.datapoint_id = "test_datapoint_id" - index_datapoint.feature_vector = [1.0, 2.0, 3.0] - index_datapoint.crowding_tag = gca_index_v1beta1.IndexDatapoint.CrowdingTag( + def test_from_index_datapoint(self): + index_datapoint = gca_index_v1beta1.IndexDatapoint() + index_datapoint.datapoint_id = "test_datapoint_id" + index_datapoint.feature_vector = [1.0, 2.0, 3.0] + index_datapoint.crowding_tag = gca_index_v1beta1.IndexDatapoint.CrowdingTag( crowding_attribute="test_crowding" ) - index_datapoint.restricts = [ + index_datapoint.restricts = [ gca_index_v1beta1.IndexDatapoint.Restriction( namespace="namespace1", allow_list=["token1"], deny_list=["token2"] ), ] - index_datapoint.numeric_restricts = [ + index_datapoint.numeric_restricts = [ gca_index_v1beta1.IndexDatapoint.NumericRestriction( namespace="namespace2", value_int=0, ) ] + index_datapoint.embedding_metadata = {"key": "value", "key2": "value2"} + + result = MatchNeighbor( + id="index_datapoint_id", distance=0.3 + ).from_index_datapoint(index_datapoint) + + assert result.feature_vector == [1.0, 2.0, 3.0] + assert result.crowding_tag == "test_crowding" + assert len(result.restricts) == 1 + assert result.restricts[0].name == "namespace1" + assert result.restricts[0].allow_tokens == ["token1"] + assert result.restricts[0].deny_tokens == ["token2"] + assert len(result.numeric_restricts) == 1 + assert result.numeric_restricts[0].name == "namespace2" + assert result.numeric_restricts[0].value_int == 0 + assert result.numeric_restricts[0].value_float is None + assert result.numeric_restricts[0].value_double is None + assert result.embedding_metadata == {"key": "value", "key2": "value2"} + + def test_from_embedding(self): + embedding_metadata_struct = struct_pb2.Struct() + embedding_metadata_struct.update({"key": "value", "key2": "value2"}) + + embedding = match_service_pb2.Embedding( + id="test_embedding_id", + float_val=[1.0, 2.0, 3.0], + crowding_attribute=1, + restricts=[ + match_service_pb2.Namespace( + name="namespace1", + allow_tokens=["token1"], + deny_tokens=["token2"], + ), + ], + numeric_restricts=[ + match_service_pb2.NumericNamespace( + name="namespace2", + value_int=10, + value_float=None, + value_double=None, + ) + ], + embedding_metadata=embedding_metadata_struct, + ) - result = MatchNeighbor( - id="index_datapoint_id", distance=0.3 - ).from_index_datapoint(index_datapoint) - - assert result.feature_vector == [1.0, 2.0, 3.0] - assert result.crowding_tag == "test_crowding" - assert len(result.restricts) == 1 - assert result.restricts[0].name == "namespace1" - assert result.restricts[0].allow_tokens == ["token1"] - assert result.restricts[0].deny_tokens == ["token2"] - assert len(result.numeric_restricts) == 1 - assert result.numeric_restricts[0].name == "namespace2" - assert result.numeric_restricts[0].value_int == 0 - assert result.numeric_restricts[0].value_float is None - assert result.numeric_restricts[0].value_double is None - - def test_from_embedding(self): - embedding = match_service_pb2.Embedding( - id="test_embedding_id", - float_val=[1.0, 2.0, 3.0], - crowding_attribute=1, - restricts=[ - match_service_pb2.Namespace( - name="namespace1", allow_tokens=["token1"], deny_tokens=["token2"] - ), - ], - numeric_restricts=[ - match_service_pb2.NumericNamespace( - name="namespace2", value_int=10, value_float=None, value_double=None - ) - ], - ) - - result = MatchNeighbor(id="embedding_id", distance=0.3).from_embedding( - embedding - ) + result = MatchNeighbor(id="embedding_id", distance=0.3).from_embedding( + embedding + ) - assert result.feature_vector == [1.0, 2.0, 3.0] - assert result.crowding_tag == "1" - assert len(result.restricts) == 1 - assert result.restricts[0].name == "namespace1" - assert result.restricts[0].allow_tokens == ["token1"] - assert result.restricts[0].deny_tokens == ["token2"] - assert len(result.numeric_restricts) == 1 - assert result.numeric_restricts[0].name == "namespace2" - assert result.numeric_restricts[0].value_int == 10 - assert not result.numeric_restricts[0].value_float - assert not result.numeric_restricts[0].value_double + assert result.feature_vector == [1.0, 2.0, 3.0] + assert result.crowding_tag == "1" + assert len(result.restricts) == 1 + assert result.restricts[0].name == "namespace1" + assert result.restricts[0].allow_tokens == ["token1"] + assert result.restricts[0].deny_tokens == ["token2"] + assert len(result.numeric_restricts) == 1 + assert result.numeric_restricts[0].name == "namespace2" + assert result.numeric_restricts[0].value_int == 10 + assert not result.numeric_restricts[0].value_float + assert not result.numeric_restricts[0].value_double + assert result.embedding_metadata == {"key": "value", "key2": "value2"}