Skip to content

Commit 51ad6c2

Browse files
committed
feat(tools): Add CallableRetrieval for custom retrieval backends
Add a new CallableRetrieval class that wraps any user-provided sync or async callable as a BaseRetrievalTool, enabling custom retrieval backends (Elasticsearch, Pinecone, pgvector, etc.) to have first-class retrieval semantics without requiring specific framework dependencies.
1 parent eaf50ce commit 51ad6c2

3 files changed

Lines changed: 247 additions & 0 deletions

File tree

src/google/adk/tools/retrieval/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414

1515
from .base_retrieval_tool import BaseRetrievalTool
16+
from .callable_retrieval import CallableRetrieval
1617

1718
__all__ = [
1819
"BaseRetrievalTool",
20+
"CallableRetrieval",
1921
"FilesRetrieval",
2022
"LlamaIndexRetrieval",
2123
"VertexAiRagRetrieval",
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Retrieval tool that wraps a user-provided callable."""
16+
17+
from __future__ import annotations
18+
19+
import inspect
20+
from typing import Any
21+
from typing import Awaitable
22+
from typing import Callable
23+
from typing import Union
24+
25+
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
26+
from google.adk.tools.tool_context import ToolContext
27+
from typing_extensions import override
28+
29+
30+
class CallableRetrieval(BaseRetrievalTool):
31+
"""Retrieval tool backed by a user-provided function.
32+
33+
Wraps any callable that accepts a query string and returns results,
34+
making it a first-class retrieval tool in ADK.
35+
36+
Example:
37+
>>> def search_docs(query: str) -> list[str]:
38+
... return my_db.search(query)
39+
>>> tool = CallableRetrieval(
40+
... name="search_docs",
41+
... description="Search the knowledge base.",
42+
... retriever=search_docs,
43+
... )
44+
45+
Args:
46+
name: Tool name exposed to the LLM.
47+
description: Tool description exposed to the LLM.
48+
retriever: A sync or async callable. Must accept a ``query``
49+
string as its first argument. May optionally accept a
50+
``tool_context`` parameter.
51+
"""
52+
53+
def __init__(
54+
self,
55+
*,
56+
name: str,
57+
description: str,
58+
retriever: Union[
59+
Callable[[str], Any],
60+
Callable[[str], Awaitable[Any]],
61+
],
62+
):
63+
super().__init__(name=name, description=description)
64+
self._retriever = retriever
65+
self._pass_tool_context = (
66+
"tool_context" in inspect.signature(retriever).parameters
67+
)
68+
69+
@override
70+
async def run_async(
71+
self, *, args: dict[str, Any], tool_context: ToolContext
72+
) -> Any:
73+
query = args["query"]
74+
kwargs = {"tool_context": tool_context} if self._pass_tool_context else {}
75+
result = self._retriever(query, **kwargs)
76+
if inspect.isawaitable(result):
77+
return await result
78+
return result
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import MagicMock
16+
17+
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
18+
from google.adk.tools.retrieval.callable_retrieval import CallableRetrieval
19+
from google.adk.tools.tool_context import ToolContext
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def mock_tool_context():
25+
return MagicMock(spec=ToolContext)
26+
27+
28+
def test_isinstance_base_retrieval_tool():
29+
tool = CallableRetrieval(
30+
name="test",
31+
description="A test tool.",
32+
retriever=lambda query: [],
33+
)
34+
assert isinstance(tool, BaseRetrievalTool)
35+
36+
37+
def test_get_declaration():
38+
tool = CallableRetrieval(
39+
name="my_search",
40+
description="Search docs.",
41+
retriever=lambda query: [],
42+
)
43+
declaration = tool._get_declaration()
44+
assert declaration.name == "my_search"
45+
assert declaration.description == "Search docs."
46+
47+
48+
@pytest.mark.asyncio
49+
async def test_sync_callable(mock_tool_context):
50+
def my_retriever(query: str):
51+
return [f"result for {query}"]
52+
53+
tool = CallableRetrieval(
54+
name="sync_tool",
55+
description="A sync retrieval tool.",
56+
retriever=my_retriever,
57+
)
58+
result = await tool.run_async(
59+
args={"query": "hello"}, tool_context=mock_tool_context
60+
)
61+
assert result == ["result for hello"]
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_async_callable(mock_tool_context):
66+
async def my_retriever(query: str):
67+
return [f"async result for {query}"]
68+
69+
tool = CallableRetrieval(
70+
name="async_tool",
71+
description="An async retrieval tool.",
72+
retriever=my_retriever,
73+
)
74+
result = await tool.run_async(
75+
args={"query": "world"}, tool_context=mock_tool_context
76+
)
77+
assert result == ["async result for world"]
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_tool_context_passthrough(mock_tool_context):
82+
received_context = {}
83+
84+
def my_retriever(query: str, tool_context: ToolContext):
85+
received_context["ctx"] = tool_context
86+
return ["with context"]
87+
88+
tool = CallableRetrieval(
89+
name="ctx_tool",
90+
description="Tool with context.",
91+
retriever=my_retriever,
92+
)
93+
result = await tool.run_async(
94+
args={"query": "test"}, tool_context=mock_tool_context
95+
)
96+
assert result == ["with context"]
97+
assert received_context["ctx"] is mock_tool_context
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_tool_context_omission(mock_tool_context):
102+
def my_retriever(query: str):
103+
return ["no context needed"]
104+
105+
tool = CallableRetrieval(
106+
name="no_ctx_tool",
107+
description="Tool without context.",
108+
retriever=my_retriever,
109+
)
110+
result = await tool.run_async(
111+
args={"query": "test"}, tool_context=mock_tool_context
112+
)
113+
assert result == ["no context needed"]
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_async_callable_with_tool_context(mock_tool_context):
118+
async def my_retriever(query: str, tool_context: ToolContext):
119+
return [f"async {query} with context"]
120+
121+
tool = CallableRetrieval(
122+
name="async_ctx_tool",
123+
description="Async tool with context.",
124+
retriever=my_retriever,
125+
)
126+
result = await tool.run_async(
127+
args={"query": "test"}, tool_context=mock_tool_context
128+
)
129+
assert result == ["async test with context"]
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_sync_callable_object(mock_tool_context):
134+
135+
class MyRetriever:
136+
137+
def __call__(self, query: str):
138+
return [f"object result for {query}"]
139+
140+
tool = CallableRetrieval(
141+
name="obj_tool",
142+
description="Callable object tool.",
143+
retriever=MyRetriever(),
144+
)
145+
result = await tool.run_async(
146+
args={"query": "hello"}, tool_context=mock_tool_context
147+
)
148+
assert result == ["object result for hello"]
149+
150+
151+
@pytest.mark.asyncio
152+
async def test_async_callable_object(mock_tool_context):
153+
154+
class MyAsyncRetriever:
155+
156+
async def __call__(self, query: str):
157+
return [f"async object result for {query}"]
158+
159+
tool = CallableRetrieval(
160+
name="async_obj_tool",
161+
description="Async callable object tool.",
162+
retriever=MyAsyncRetriever(),
163+
)
164+
result = await tool.run_async(
165+
args={"query": "world"}, tool_context=mock_tool_context
166+
)
167+
assert result == ["async object result for world"]

0 commit comments

Comments
 (0)