From 84e5c5737b35815cb97d66c261221393cf71edfc Mon Sep 17 00:00:00 2001 From: Gautier Masse Date: Thu, 5 Mar 2026 13:44:52 +0100 Subject: [PATCH] fix(tools): support secure regional discovery engine endpoints --- .../adk/tools/discovery_engine_search_tool.py | 84 +++++++- .../test_discovery_engine_search_tool.py | 187 ++++++++++++++++++ 2 files changed, 267 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index 11bc37c7ed..4724731981 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import re from typing import Any from typing import Optional @@ -25,6 +26,75 @@ from .function_tool import FunctionTool +_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com" +_GLOBAL_LOCATION = "global" +_LOCATION_PATTERN = re.compile( + r"/locations/([a-z0-9-]+)(?:/|$)", flags=re.IGNORECASE +) +_VALID_LOCATION_PATTERN = re.compile(r"^[a-z0-9-]+$") + + +def _normalize_location(location: str, location_type: str) -> str: + """Normalizes and validates a location value.""" + normalized_location = location.strip().lower() + if not normalized_location: + raise ValueError(f"{location_type} must not be empty if specified.") + if not _VALID_LOCATION_PATTERN.fullmatch(normalized_location): + raise ValueError( + f"{location_type} must contain only letters, digits, and hyphens." + ) + return normalized_location + + +def _extract_resource_location(resource_id: str) -> Optional[str]: + """Extracts and validates location from a resource id.""" + if "/locations/" not in resource_id.lower(): + return None + + location_match = _LOCATION_PATTERN.search(resource_id) + if not location_match: + raise ValueError("Invalid location in data_store_id or search_engine_id.") + return _normalize_location(location_match.group(1), "resource location") + + +def _resolve_location(resource_id: str, location: Optional[str]) -> str: + """Resolves the Discovery Engine location to use for the endpoint.""" + inferred_location = _extract_resource_location(resource_id) + + if location is not None: + normalized_location = _normalize_location(location, "location") + if inferred_location and normalized_location != inferred_location: + raise ValueError( + "location must match the location in data_store_id or " + "search_engine_id." + ) + return normalized_location + + if inferred_location: + return inferred_location + return _GLOBAL_LOCATION + + +def _build_client_options( + resource_id: str, + quota_project_id: Optional[str], + location: Optional[str], +) -> Optional[client_options.ClientOptions]: + """Builds client options for Discovery Engine requests.""" + client_options_kwargs = {} + resolved_location = _resolve_location(resource_id, location) + + if resolved_location != _GLOBAL_LOCATION: + client_options_kwargs["api_endpoint"] = ( + f"{resolved_location}-{_DEFAULT_ENDPOINT}" + ) + if quota_project_id: + client_options_kwargs["quota_project_id"] = quota_project_id + + if not client_options_kwargs: + return None + return client_options.ClientOptions(**client_options_kwargs) + class DiscoveryEngineSearchTool(FunctionTool): """Tool for searching the discovery engine.""" @@ -38,6 +108,7 @@ def __init__( search_engine_id: Optional[str] = None, filter: Optional[str] = None, max_results: Optional[int] = None, + location: Optional[str] = None, ): """Initializes the DiscoveryEngineSearchTool. @@ -51,6 +122,9 @@ def __init__( "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}". filter: The filter to be applied to the search request. Default is None. max_results: The maximum number of results to return. Default is None. + location: Optional endpoint location override. + Examples: "global", "us", "eu". If not specified, location is inferred + from `data_store_id` or `search_engine_id` and defaults to "global". """ super().__init__(self.discovery_engine_search) if (data_store_id is None and search_engine_id is None) or ( @@ -71,13 +145,15 @@ def __init__( self._search_engine_id = search_engine_id self._filter = filter self._max_results = max_results + self._location = location credentials, _ = google.auth.default() quota_project_id = getattr(credentials, "quota_project_id", None) - options = ( - client_options.ClientOptions(quota_project_id=quota_project_id) - if quota_project_id - else None + resource_id = data_store_id or search_engine_id or "" + options = _build_client_options( + resource_id=resource_id, + quota_project_id=quota_project_id, + location=location, ) self._discovery_engine_client = discoveryengine.SearchServiceClient( credentials=credentials, client_options=options diff --git a/tests/unittests/tools/test_discovery_engine_search_tool.py b/tests/unittests/tools/test_discovery_engine_search_tool.py index 7cba5f1841..d25566a67d 100644 --- a/tests/unittests/tools/test_discovery_engine_search_tool.py +++ b/tests/unittests/tools/test_discovery_engine_search_tool.py @@ -79,6 +79,193 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error( data_store_id="test_data_store", data_store_specs=[{"id": "123"}] ) + @pytest.mark.parametrize( + ("tool_kwargs", "expected_endpoint"), + [ + ( + { + "data_store_id": ( + "projects/test/locations/eu/collections/default_collection/" + "dataStores/test_data_store" + ) + }, + "eu-discoveryengine.googleapis.com", + ), + ( + { + "search_engine_id": ( + "projects/test/locations/us/collections/default_collection/" + "engines/test_search_engine" + ) + }, + "us-discoveryengine.googleapis.com", + ), + ( + { + "data_store_id": ( + "projects/test/locations/europe-west1/collections/" + "default_collection/dataStores/test_data_store" + ) + }, + "europe-west1-discoveryengine.googleapis.com", + ), + ], + ) + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_regional_location_uses_regional_endpoint( + self, + mock_search_client, + mock_client_options, + tool_kwargs, + expected_endpoint, + ): + """Test initialization uses the expected regional API endpoint.""" + DiscoveryEngineSearchTool(**tool_kwargs) + + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint=expected_endpoint + ) + mock_search_client.assert_called_once_with( + credentials="credentials", + client_options=mock_client_options.ClientOptions.return_value, + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_explicit_location_override_uses_input_location( + self, mock_search_client, mock_client_options + ): + """Test initialization uses explicit location when resource has none.""" + DiscoveryEngineSearchTool( + data_store_id="test_data_store", + location="eu", + ) + + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="eu-discoveryengine.googleapis.com" + ) + mock_search_client.assert_called_once_with( + credentials="credentials", + client_options=mock_client_options.ClientOptions.return_value, + ) + + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_mismatched_location_raises_error(self, mock_search_client): + """Test initialization rejects mismatched location overrides.""" + with pytest.raises( + ValueError, + match=( + "location must match the location in data_store_id or " + "search_engine_id." + ), + ): + DiscoveryEngineSearchTool( + data_store_id=( + "projects/test/locations/us/collections/default_collection/" + "dataStores/test_data_store" + ), + location="eu", + ) + + mock_search_client.assert_not_called() + + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_empty_location_raises_error(self, mock_search_client): + """Test initialization rejects an empty location override.""" + with pytest.raises( + ValueError, match="location must not be empty if specified." + ): + DiscoveryEngineSearchTool( + data_store_id=( + "projects/test/locations/us/collections/default_collection/" + "dataStores/test_data_store" + ), + location=" ", + ) + + mock_search_client.assert_not_called() + + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_invalid_override_location_raises_error( + self, mock_search_client + ): + """Test initialization rejects invalid override location characters.""" + with pytest.raises( + ValueError, + match="location must contain only letters, digits, and hyphens.", + ): + DiscoveryEngineSearchTool( + data_store_id="test_data_store", + location="attacker.com#", + ) + + mock_search_client.assert_not_called() + + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_invalid_resource_location_raises_error( + self, mock_search_client + ): + """Test initialization rejects invalid resource location characters.""" + with pytest.raises( + ValueError, + match="Invalid location in data_store_id or search_engine_id.", + ): + DiscoveryEngineSearchTool( + data_store_id=( + "projects/test/locations/attacker.com#/collections/" + "default_collection/dataStores/test_data_store" + ) + ) + + mock_search_client.assert_not_called() + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_global_location_keeps_default_endpoint( + self, mock_search_client, mock_client_options + ): + """Test initialization keeps default API endpoint for global location.""" + DiscoveryEngineSearchTool( + data_store_id=( + "projects/test/locations/global/collections/default_collection/" + "dataStores/test_data_store" + ) + ) + + mock_client_options.ClientOptions.assert_not_called() + mock_search_client.assert_called_once_with( + credentials="credentials", client_options=None + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_init_with_regional_location_and_quota_project_id( + self, mock_search_client, mock_client_options + ): + """Test initialization uses endpoint and quota project id together.""" + mock_credentials = mock.MagicMock() + mock_credentials.quota_project_id = "test-quota-project" + + with mock.patch.object( + auth, "default", return_value=(mock_credentials, "project") + ): + DiscoveryEngineSearchTool( + data_store_id=( + "projects/test/locations/eu/collections/default_collection/" + "dataStores/test_data_store" + ) + ) + + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="eu-discoveryengine.googleapis.com", + quota_project_id="test-quota-project", + ) + mock_search_client.assert_called_once_with( + credentials=mock_credentials, + client_options=mock_client_options.ClientOptions.return_value, + ) + @mock.patch.object(discovery_engine_search_tool, "client_options") @mock.patch.object( discoveryengine,