Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,44 @@ class A2AClientError(Exception):
pass


def _build_agent_card_description(agent_card: AgentCard) -> str:
"""Build transfer context from an A2A agent card."""
description_parts: list[str] = []
if agent_card.description:
description_parts.append(agent_card.description)

skill_lines: list[str] = []
for skill in getattr(agent_card, "skills", None) or []:
skill_name = getattr(skill, "name", "")
skill_description = getattr(skill, "description", "")
skill_tags = getattr(skill, "tags", None) or []
skill_examples = getattr(skill, "examples", None) or []

if skill_name and skill_description:
skill_line = f"- {skill_name}: {skill_description}"
elif skill_name:
skill_line = f"- {skill_name}"
elif skill_description:
skill_line = f"- {skill_description}"
else:
continue

if skill_tags:
skill_line += f" [{', '.join(str(tag) for tag in skill_tags)}]"
skill_lines.append(skill_line)

for index, example in enumerate(skill_examples, start=1):
skill_lines.append(f" Example {index}: {example}")

if skill_lines:
if description_parts:
description_parts.append("")
description_parts.append("Capabilities:")
description_parts.extend(skill_lines)

return "\n".join(description_parts).strip()


def _add_mock_function_call(event: Event, state: TaskState) -> None:
"""Generates a mock function call for input-required events if applicable."""
if event.content is None:
Expand Down Expand Up @@ -328,18 +366,20 @@ async def _ensure_resolved(self) -> None:
return

try:
# Resolve agent card if needed
if not self._agent_card:
self._agent_card = await self._resolve_agent_card()

# Resolve agent card if needed
if not self._agent_card:
self._agent_card = await self._resolve_agent_card()
assert self._agent_card is not None

# Validate agent card
await self._validate_agent_card(self._agent_card)
# Validate agent card
await self._validate_agent_card(self._agent_card)

# Update description if empty
if not self.description and self._agent_card.description:
self.description = self._agent_card.description
# Keep transfer descriptions aligned with the resolved agent card.
if agent_card_description := _build_agent_card_description(
self._agent_card
):
self.description = agent_card_description

# Initialize A2A client
if not self._a2a_client:
Expand Down
21 changes: 21 additions & 0 deletions src/google/adk/flows/llm_flows/agent_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from __future__ import annotations

import asyncio
import inspect
import typing
from typing import Any
from typing import AsyncGenerator
Expand Down Expand Up @@ -48,6 +50,8 @@ async def run_async(
if not transfer_targets:
return

await _resolve_transfer_target_descriptions(transfer_targets)

transfer_to_agent_tool = TransferToAgentTool(
agent_names=[agent.name for agent in transfer_targets]
)
Expand All @@ -72,6 +76,23 @@ async def run_async(
request_processor = _AgentTransferLlmRequestProcessor()


async def _resolve_transfer_target_descriptions(
target_agents: list[Any],
) -> None:
"""Resolve target-agent metadata before transfer instructions are built."""
resolve_tasks = []
for target_agent in target_agents:
ensure_resolved = getattr(target_agent, '_ensure_resolved', None)
if not callable(ensure_resolved):
continue
maybe_awaitable = ensure_resolved()
if inspect.isawaitable(maybe_awaitable):
resolve_tasks.append(maybe_awaitable)

if resolve_tasks:
await asyncio.gather(*resolve_tasks)


def _build_target_agents_info(target_agent: Any) -> str:
# TODO: Refactor the annotation of the parameters
return f"""
Expand Down
52 changes: 51 additions & 1 deletion tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,54 @@ async def test_ensure_resolved_with_direct_agent_card(self):
assert agent._is_resolved is True
assert agent._a2a_client == mock_a2a_client

@pytest.mark.asyncio
async def test_ensure_resolved_enhances_description_from_agent_card(self):
"""Test _ensure_resolved builds transfer context from the agent card."""
agent_card = AgentCard(
name="research-agent",
url="https://example.com/rpc",
description="Answers research questions.",
version="1.0",
capabilities=AgentCapabilities(),
default_input_modes=["text/plain"],
default_output_modes=["application/json"],
skills=[
AgentSkill(
id="sec-search",
name="SEC Search",
description="Searches SEC filings.",
tags=["finance", "filings"],
examples=["Find Apple's latest 10-K risk factors."],
)
],
)
agent = RemoteA2aAgent(
name="test_agent",
agent_card=agent_card,
description="Local placeholder",
)

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value = mock_client

with patch(
"google.adk.agents.remote_a2a_agent.A2AClientFactory"
) as mock_factory_class:
mock_factory = Mock()
mock_a2a_client = Mock()
mock_factory.create.return_value = mock_a2a_client
mock_factory_class.return_value = mock_factory

await agent._ensure_resolved()

assert agent.description == (
"Answers research questions.\n\n"
"Capabilities:\n"
"- SEC Search: Searches SEC filings. [finance, filings]\n"
" Example 1: Find Apple's latest 10-K risk factors."
)

@pytest.mark.asyncio
async def test_ensure_resolved_with_direct_agent_card_with_factory(self):
"""Test _ensure_resolved with direct agent card."""
Expand Down Expand Up @@ -509,7 +557,9 @@ async def test_ensure_resolved_with_url_source(self):

assert agent._is_resolved is True
assert agent._agent_card == agent_card
assert agent.description == agent_card.description
assert agent.description == (
"Test agent\n\nCapabilities:\n- Test Skill: A test skill [test]"
)

@pytest.mark.asyncio
async def test_ensure_resolved_already_resolved(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ async def _run_async_impl(
yield Event(author=self.name, invocation_id=ctx.invocation_id)


class _ResolvableAgent(_NonLlmAgent):
"""A non-LLM agent whose description is populated asynchronously."""

resolved: bool = False

async def _ensure_resolved(self) -> None:
self.description = 'Description from resolved agent card'
self.resolved = True


async def create_test_invocation_context(agent: Agent) -> InvocationContext:
"""Helper to create constructed InvocationContext."""
session_service = InMemorySessionService()
Expand Down Expand Up @@ -342,3 +352,30 @@ async def test_agent_transfer_with_non_llm_peer_agent():

instructions = llm_request.config.system_instruction
assert 'non_llm_peer' in instructions


@pytest.mark.asyncio
async def test_agent_transfer_resolves_target_descriptions_before_prompt():
"""Remote-style targets can populate their description before delegation."""
mockModel = testing_utils.MockModel.create(responses=[])

remote_sub_agent = _ResolvableAgent(name='remote_sub_agent')
main_agent = Agent(
name='main_agent',
model=mockModel,
sub_agents=[remote_sub_agent],
description='Main agent',
)

invocation_context = await create_test_invocation_context(main_agent)
llm_request = LlmRequest()

async for _ in agent_transfer.request_processor.run_async(
invocation_context, llm_request
):
pass

instructions = llm_request.config.system_instruction
assert remote_sub_agent.resolved is True
assert 'Agent name: remote_sub_agent' in instructions
assert 'Description from resolved agent card' in instructions