Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions src/google/adk/tools/discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import re
from typing import Any
from typing import Optional

Expand All @@ -25,6 +26,75 @@

from .function_tool import FunctionTool

_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com"
_GLOBAL_LOCATION = "global"
_LOCATION_PATTERN = re.compile(
r"/locations/([a-z0-9-]+)(?:/|$)", flags=re.IGNORECASE
)
_VALID_LOCATION_PATTERN = re.compile(r"^[a-z0-9-]+$")


def _normalize_location(location: str, location_type: str) -> str:
"""Normalizes and validates a location value."""
normalized_location = location.strip().lower()
if not normalized_location:
raise ValueError(f"{location_type} must not be empty if specified.")
if not _VALID_LOCATION_PATTERN.fullmatch(normalized_location):
raise ValueError(
f"{location_type} must contain only letters, digits, and hyphens."
)
return normalized_location


def _extract_resource_location(resource_id: str) -> Optional[str]:
"""Extracts and validates location from a resource id."""
if "/locations/" not in resource_id.lower():
return None

location_match = _LOCATION_PATTERN.search(resource_id)
if not location_match:
raise ValueError("Invalid location in data_store_id or search_engine_id.")
return _normalize_location(location_match.group(1), "resource location")


def _resolve_location(resource_id: str, location: Optional[str]) -> str:
"""Resolves the Discovery Engine location to use for the endpoint."""
inferred_location = _extract_resource_location(resource_id)

if location is not None:
normalized_location = _normalize_location(location, "location")
if inferred_location and normalized_location != inferred_location:
raise ValueError(
"location must match the location in data_store_id or "
"search_engine_id."
)
return normalized_location

if inferred_location:
return inferred_location
return _GLOBAL_LOCATION


def _build_client_options(
resource_id: str,
quota_project_id: Optional[str],
location: Optional[str],
) -> Optional[client_options.ClientOptions]:
"""Builds client options for Discovery Engine requests."""
client_options_kwargs = {}
resolved_location = _resolve_location(resource_id, location)

if resolved_location != _GLOBAL_LOCATION:
client_options_kwargs["api_endpoint"] = (
f"{resolved_location}-{_DEFAULT_ENDPOINT}"
)
if quota_project_id:
client_options_kwargs["quota_project_id"] = quota_project_id

if not client_options_kwargs:
return None
return client_options.ClientOptions(**client_options_kwargs)


class DiscoveryEngineSearchTool(FunctionTool):
"""Tool for searching the discovery engine."""
Expand All @@ -38,6 +108,7 @@ def __init__(
search_engine_id: Optional[str] = None,
filter: Optional[str] = None,
max_results: Optional[int] = None,
location: Optional[str] = None,
):
"""Initializes the DiscoveryEngineSearchTool.

Expand All @@ -51,6 +122,9 @@ def __init__(
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
filter: The filter to be applied to the search request. Default is None.
max_results: The maximum number of results to return. Default is None.
location: Optional endpoint location override.
Examples: "global", "us", "eu". If not specified, location is inferred
from `data_store_id` or `search_engine_id` and defaults to "global".
"""
super().__init__(self.discovery_engine_search)
if (data_store_id is None and search_engine_id is None) or (
Expand All @@ -71,13 +145,15 @@ def __init__(
self._search_engine_id = search_engine_id
self._filter = filter
self._max_results = max_results
self._location = location

credentials, _ = google.auth.default()
quota_project_id = getattr(credentials, "quota_project_id", None)
options = (
client_options.ClientOptions(quota_project_id=quota_project_id)
if quota_project_id
else None
resource_id = data_store_id or search_engine_id or ""
options = _build_client_options(
resource_id=resource_id,
quota_project_id=quota_project_id,
location=location,
)
self._discovery_engine_client = discoveryengine.SearchServiceClient(
credentials=credentials, client_options=options
Expand Down
187 changes: 187 additions & 0 deletions tests/unittests/tools/test_discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,193 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
)

@pytest.mark.parametrize(
("tool_kwargs", "expected_endpoint"),
[
(
{
"data_store_id": (
"projects/test/locations/eu/collections/default_collection/"
"dataStores/test_data_store"
)
},
"eu-discoveryengine.googleapis.com",
),
(
{
"search_engine_id": (
"projects/test/locations/us/collections/default_collection/"
"engines/test_search_engine"
)
},
"us-discoveryengine.googleapis.com",
),
(
{
"data_store_id": (
"projects/test/locations/europe-west1/collections/"
"default_collection/dataStores/test_data_store"
)
},
"europe-west1-discoveryengine.googleapis.com",
),
],
)
@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_regional_location_uses_regional_endpoint(
self,
mock_search_client,
mock_client_options,
tool_kwargs,
expected_endpoint,
):
"""Test initialization uses the expected regional API endpoint."""
DiscoveryEngineSearchTool(**tool_kwargs)

mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint=expected_endpoint
)
mock_search_client.assert_called_once_with(
credentials="credentials",
client_options=mock_client_options.ClientOptions.return_value,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_explicit_location_override_uses_input_location(
self, mock_search_client, mock_client_options
):
"""Test initialization uses explicit location when resource has none."""
DiscoveryEngineSearchTool(
data_store_id="test_data_store",
location="eu",
)

mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="eu-discoveryengine.googleapis.com"
)
mock_search_client.assert_called_once_with(
credentials="credentials",
client_options=mock_client_options.ClientOptions.return_value,
)

@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_mismatched_location_raises_error(self, mock_search_client):
"""Test initialization rejects mismatched location overrides."""
with pytest.raises(
ValueError,
match=(
"location must match the location in data_store_id or "
"search_engine_id."
),
):
DiscoveryEngineSearchTool(
data_store_id=(
"projects/test/locations/us/collections/default_collection/"
"dataStores/test_data_store"
),
location="eu",
)

mock_search_client.assert_not_called()

@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_empty_location_raises_error(self, mock_search_client):
"""Test initialization rejects an empty location override."""
with pytest.raises(
ValueError, match="location must not be empty if specified."
):
DiscoveryEngineSearchTool(
data_store_id=(
"projects/test/locations/us/collections/default_collection/"
"dataStores/test_data_store"
),
location=" ",
)

mock_search_client.assert_not_called()

@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_invalid_override_location_raises_error(
self, mock_search_client
):
"""Test initialization rejects invalid override location characters."""
with pytest.raises(
ValueError,
match="location must contain only letters, digits, and hyphens.",
):
DiscoveryEngineSearchTool(
data_store_id="test_data_store",
location="attacker.com#",
)

mock_search_client.assert_not_called()

@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_invalid_resource_location_raises_error(
self, mock_search_client
):
"""Test initialization rejects invalid resource location characters."""
with pytest.raises(
ValueError,
match="Invalid location in data_store_id or search_engine_id.",
):
DiscoveryEngineSearchTool(
data_store_id=(
"projects/test/locations/attacker.com#/collections/"
"default_collection/dataStores/test_data_store"
)
)

mock_search_client.assert_not_called()

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_global_location_keeps_default_endpoint(
self, mock_search_client, mock_client_options
):
"""Test initialization keeps default API endpoint for global location."""
DiscoveryEngineSearchTool(
data_store_id=(
"projects/test/locations/global/collections/default_collection/"
"dataStores/test_data_store"
)
)

mock_client_options.ClientOptions.assert_not_called()
mock_search_client.assert_called_once_with(
credentials="credentials", client_options=None
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_init_with_regional_location_and_quota_project_id(
self, mock_search_client, mock_client_options
):
"""Test initialization uses endpoint and quota project id together."""
mock_credentials = mock.MagicMock()
mock_credentials.quota_project_id = "test-quota-project"

with mock.patch.object(
auth, "default", return_value=(mock_credentials, "project")
):
DiscoveryEngineSearchTool(
data_store_id=(
"projects/test/locations/eu/collections/default_collection/"
"dataStores/test_data_store"
)
)

mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="eu-discoveryengine.googleapis.com",
quota_project_id="test-quota-project",
)
mock_search_client.assert_called_once_with(
credentials=mock_credentials,
client_options=mock_client_options.ClientOptions.return_value,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(
discoveryengine,
Expand Down