diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index 11bc37c7ed..e9ad4abda1 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import re from typing import Any from typing import Optional @@ -25,6 +26,9 @@ from .function_tool import FunctionTool +_LOCATION_PATTERN = re.compile(r"/locations/([^/]+)/") +_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com" + class DiscoveryEngineSearchTool(FunctionTool): """Tool for searching the discovery engine.""" @@ -74,10 +78,19 @@ def __init__( credentials, _ = google.auth.default() quota_project_id = getattr(credentials, "quota_project_id", None) - options = ( - client_options.ClientOptions(quota_project_id=quota_project_id) - if quota_project_id - else None + + resource_id = data_store_id or search_engine_id or "" + location_match = _LOCATION_PATTERN.search(resource_id) + location = location_match.group(1) if location_match else "global" + api_endpoint = ( + f"{location}-{_DEFAULT_ENDPOINT}" + if location != "global" + else _DEFAULT_ENDPOINT + ) + + options = client_options.ClientOptions( + api_endpoint=api_endpoint, + quota_project_id=quota_project_id, ) self._discovery_engine_client = discoveryengine.SearchServiceClient( credentials=credentials, client_options=options diff --git a/tests/unittests/tools/test_discovery_engine_search_tool.py b/tests/unittests/tools/test_discovery_engine_search_tool.py index 7cba5f1841..a6be520f2f 100644 --- a/tests/unittests/tools/test_discovery_engine_search_tool.py +++ b/tests/unittests/tools/test_discovery_engine_search_tool.py @@ -121,7 +121,8 @@ def test_discovery_engine_search_success( assert result["results"][0]["content"] == "Test Content" mock_auth.assert_called_once() mock_client_options.ClientOptions.assert_called_once_with( - quota_project_id="test-quota-project" + api_endpoint="discoveryengine.googleapis.com", + quota_project_id="test-quota-project", ) mock_search_client.assert_called_once_with( credentials=mock_credentials, @@ -156,3 +157,71 @@ def test_discovery_engine_search_no_results(self, mock_search_client): assert result["status"] == "success" assert not result["results"] + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_regional_endpoint_eu_data_store( + self, mock_search_client, mock_client_options + ): + """Test that an EU data store uses the EU regional endpoint.""" + DiscoveryEngineSearchTool( + data_store_id="projects/my-project/locations/eu/collections/default_collection/dataStores/my-ds" + ) + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="eu-discoveryengine.googleapis.com", + quota_project_id=None, + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_regional_endpoint_us_search_engine( + self, mock_search_client, mock_client_options + ): + """Test that a US search engine uses the US regional endpoint.""" + DiscoveryEngineSearchTool( + search_engine_id="projects/my-project/locations/us/collections/default_collection/engines/my-engine" + ) + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="us-discoveryengine.googleapis.com", + quota_project_id=None, + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_regional_endpoint_single_region( + self, mock_search_client, mock_client_options + ): + """Test that a single-region location uses the correct endpoint.""" + DiscoveryEngineSearchTool( + data_store_id="projects/my-project/locations/europe-west1/collections/default_collection/dataStores/my-ds" + ) + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="europe-west1-discoveryengine.googleapis.com", + quota_project_id=None, + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_global_endpoint_explicit( + self, mock_search_client, mock_client_options + ): + """Test that a global data store uses the default global endpoint.""" + DiscoveryEngineSearchTool( + data_store_id="projects/my-project/locations/global/collections/default_collection/dataStores/my-ds" + ) + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="discoveryengine.googleapis.com", + quota_project_id=None, + ) + + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object(discoveryengine, "SearchServiceClient") + def test_global_endpoint_no_location_in_id( + self, mock_search_client, mock_client_options + ): + """Test that a short ID without location falls back to global endpoint.""" + DiscoveryEngineSearchTool(data_store_id="test_data_store") + mock_client_options.ClientOptions.assert_called_once_with( + api_endpoint="discoveryengine.googleapis.com", + quota_project_id=None, + )