Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/google/adk/tools/discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import re
from typing import Any
from typing import Optional

Expand All @@ -25,6 +26,9 @@

from .function_tool import FunctionTool

_LOCATION_PATTERN = re.compile(r"/locations/([^/]+)/")
_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com"


class DiscoveryEngineSearchTool(FunctionTool):
"""Tool for searching the discovery engine."""
Expand Down Expand Up @@ -74,10 +78,19 @@ def __init__(

credentials, _ = google.auth.default()
quota_project_id = getattr(credentials, "quota_project_id", None)
options = (
client_options.ClientOptions(quota_project_id=quota_project_id)
if quota_project_id
else None

resource_id = data_store_id or search_engine_id or ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The or "" is redundant here. The ValueError check on lines 60-65 ensures that either data_store_id or search_engine_id is a non-None string, so the expression data_store_id or search_engine_id will always evaluate to a string. Removing the fallback to an empty string makes the code slightly cleaner and relies on the existing validation.

Suggested change
resource_id = data_store_id or search_engine_id or ""
resource_id = data_store_id or search_engine_id

location_match = _LOCATION_PATTERN.search(resource_id)
location = location_match.group(1) if location_match else "global"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code/fix isn't crazy

api_endpoint = (
f"{location}-{_DEFAULT_ENDPOINT}"
if location != "global"
else _DEFAULT_ENDPOINT
)

options = client_options.ClientOptions(
api_endpoint=api_endpoint,
quota_project_id=quota_project_id,
)
self._discovery_engine_client = discoveryengine.SearchServiceClient(
credentials=credentials, client_options=options
Expand Down
71 changes: 70 additions & 1 deletion tests/unittests/tools/test_discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def test_discovery_engine_search_success(
assert result["results"][0]["content"] == "Test Content"
mock_auth.assert_called_once()
mock_client_options.ClientOptions.assert_called_once_with(
quota_project_id="test-quota-project"
api_endpoint="discoveryengine.googleapis.com",
quota_project_id="test-quota-project",
)
mock_search_client.assert_called_once_with(
credentials=mock_credentials,
Expand Down Expand Up @@ -156,3 +157,71 @@ def test_discovery_engine_search_no_results(self, mock_search_client):

assert result["status"] == "success"
assert not result["results"]

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_eu_data_store(
self, mock_search_client, mock_client_options
):
"""Test that an EU data store uses the EU regional endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/eu/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="eu-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_us_search_engine(
self, mock_search_client, mock_client_options
):
"""Test that a US search engine uses the US regional endpoint."""
DiscoveryEngineSearchTool(
search_engine_id="projects/my-project/locations/us/collections/default_collection/engines/my-engine"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="us-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_single_region(
self, mock_search_client, mock_client_options
):
"""Test that a single-region location uses the correct endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/europe-west1/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="europe-west1-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_global_endpoint_explicit(
self, mock_search_client, mock_client_options
):
"""Test that a global data store uses the default global endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/global/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_global_endpoint_no_location_in_id(
self, mock_search_client, mock_client_options
):
"""Test that a short ID without location falls back to global endpoint."""
DiscoveryEngineSearchTool(data_store_id="test_data_store")
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="discoveryengine.googleapis.com",
quota_project_id=None,
)
Comment on lines +161 to +227
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These five tests for endpoint selection are very similar. They can be consolidated into a single, more maintainable test by using pytest.mark.parametrize. This approach reduces code duplication and makes it clearer what is being tested across different inputs.

  @pytest.mark.parametrize(
      ("resource_id_key", "resource_id_value", "expected_endpoint"),
      [
          (
              "data_store_id",
              "projects/my-project/locations/eu/collections/default_collection/dataStores/my-ds",
              "eu-discoveryengine.googleapis.com",
          ),
          (
              "search_engine_id",
              "projects/my-project/locations/us/collections/default_collection/engines/my-engine",
              "us-discoveryengine.googleapis.com",
          ),
          (
              "data_store_id",
              "projects/my-project/locations/europe-west1/collections/default_collection/dataStores/my-ds",
              "europe-west1-discoveryengine.googleapis.com",
          ),
          (
              "data_store_id",
              "projects/my-project/locations/global/collections/default_collection/dataStores/my-ds",
              "discoveryengine.googleapis.com",
          ),
          ("data_store_id", "test_data_store", "discoveryengine.googleapis.com"),
      ],
  )
  @mock.patch.object(discovery_engine_search_tool, "client_options")
  @mock.patch.object(discoveryengine, "SearchServiceClient")
  def test_endpoint_selection(
      self,
      mock_search_client,
      mock_client_options,
      resource_id_key,
      resource_id_value,
      expected_endpoint,
  ):
    """Test that the correct API endpoint is selected based on resource ID."""
    kwargs = {resource_id_key: resource_id_value}
    DiscoveryEngineSearchTool(**kwargs)
    mock_client_options.ClientOptions.assert_called_once_with(
        api_endpoint=expected_endpoint,
        quota_project_id=None,
    )