Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ initdb
inmemory
INR
isready
itk
ITK
jcs
jit
jku
JOSE
JPY
Expand Down Expand Up @@ -106,11 +109,13 @@ protoc
pydantic
pyi
pypistats
pyproto
pyupgrade
pyversions
redef
respx
resub
rmi
RS256
RUF
SECP256R1
Expand All @@ -127,7 +132,7 @@ taskupdate
testuuid
Tful
tiangolo
TResponse
typ
typeerror
vulnz
TResponse
Empty file added itk/__init__.py
Empty file.
353 changes: 353 additions & 0 deletions itk/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
import argparse # noqa: I001
import asyncio
import base64
import logging
import uuid

import grpc
import httpx
import uvicorn

from fastapi import FastAPI

from pyproto import instruction_pb2

from a2a.client import ClientConfig, ClientFactory
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import a2a_pb2_grpc
from a2a.types.a2a_pb2 import (
AgentCapabilities,
AgentCard,
AgentInterface,
Message,
Part,
SendMessageRequest,
TaskState,
)
from a2a.utils import TransportProtocol


# Configure the root logger at import time so INFO-level messages are emitted,
# then create the module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def extract_instruction(
    message: Message | None,
) -> instruction_pb2.Instruction | None:
    """Extracts an Instruction proto from an A2A Message.

    Parts are examined in order and the first one that decodes successfully
    wins. Two encodings are recognised:

    * A part tagged as protobuf (media type ``application/x-protobuf`` or
      filename ``instruction.bin``) carrying the serialized bytes in ``raw``,
      or base64 text in ``text`` when ``raw`` is empty.
    * Any other part whose ``text`` is a base64-encoded serialized
      Instruction.

    Parts that carry no payload at all are skipped rather than yielding a
    blank, never-parsed Instruction.

    Args:
        message: The incoming A2A message, or None.

    Returns:
        The first Instruction that could be decoded, or None when the message
        is missing, has no parts, or no part holds a decodable instruction.
    """
    if not message or not message.parts:
        return None

    for part in message.parts:
        tagged_as_proto = (
            part.media_type == 'application/x-protobuf'
            or part.filename == 'instruction.bin'
        )
        # Raw bytes are only trusted on parts explicitly tagged as protobuf;
        # everything else must arrive as base64 text.
        use_raw = tagged_as_proto and bool(part.raw)
        if not use_raw and not part.text:
            # Nothing to decode in this part; look at the next one.
            continue

        try:
            payload = part.raw if use_raw else base64.b64decode(part.text)
            inst = instruction_pb2.Instruction()
            inst.ParseFromString(payload)
        except Exception:  # noqa: BLE001
            # Deliberately broad: this is best-effort probing of arbitrary
            # client input, where base64 decoding and protobuf parsing can
            # each fail in their own way. A failure only means "this part is
            # not an instruction", so log it and try the next part.
            if tagged_as_proto:
                logger.debug(
                    'Failed to parse instruction from binary part',
                    exc_info=True,
                )
            else:
                logger.debug(
                    'Failed to parse instruction from text part', exc_info=True
                )
            continue
        return inst
    return None


def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
    """Builds a user-role A2A Message that carries a serialized Instruction.

    The instruction is serialized to bytes and attached as a single raw part
    tagged with the protobuf media type and the ``instruction.bin`` filename.
    The message receives a freshly generated UUID4 as its id.

    Args:
        inst: The Instruction to embed.

    Returns:
        A new Message with exactly one part holding the serialized bytes.
    """
    payload_part = Part(
        raw=inst.SerializeToString(),
        media_type='application/x-protobuf',
        filename='instruction.bin',
    )
    return Message(
        role='ROLE_USER',
        message_id=str(uuid.uuid4()),
        parts=[payload_part],
    )


async def handle_call_agent(
    call: instruction_pb2.CallAgent, *, http_timeout: float = 30.0
) -> list[str]:
    """Handles the CallAgent instruction by invoking another agent.

    The nested instruction in ``call.instruction`` is wrapped into an A2A
    message, sent to the agent at ``call.agent_card_uri`` over the requested
    transport, and every non-empty text part found in the resulting events is
    collected.

    Args:
        call: The CallAgent instruction describing the target agent, the
            transport to use, whether to stream, and the nested instruction.
        http_timeout: Timeout in seconds for the HTTP client used by the
            outbound call. Defaults to 30 seconds.

    Returns:
        The text of every non-empty text part received, in arrival order.

    Raises:
        RuntimeError: If connecting to or communicating with the remote agent
            fails; the original exception is chained as the cause.
    """
    logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)

    # Mapping transport string to TransportProtocol enum
    transport_map = {
        'JSONRPC': TransportProtocol.JSONRPC,
        'HTTP+JSON': TransportProtocol.HTTP_JSON,
        'HTTP_JSON': TransportProtocol.HTTP_JSON,
        'REST': TransportProtocol.HTTP_JSON,
        'GRPC': TransportProtocol.GRPC,
    }

    requested_transport = call.transport.upper()
    selected_transport = transport_map.get(requested_transport)
    if selected_transport is None:
        # Unknown or unset transports fall back to JSON-RPC. An explicit but
        # unrecognised name is almost certainly a caller mistake, so make the
        # fallback visible instead of silent.
        if requested_transport:
            logger.warning(
                'Unrecognised transport %r; falling back to JSONRPC',
                call.transport,
            )
        selected_transport = TransportProtocol.JSONRPC

    # The HTTP client is created per call, so it must also be closed per call;
    # otherwise every outbound call leaks its connection pool.
    http_client = httpx.AsyncClient(timeout=http_timeout)

    config = ClientConfig()
    config.httpx_client = http_client
    config.grpc_channel_factory = grpc.aio.insecure_channel
    config.supported_protocol_bindings = [selected_transport]
    # gRPC calls are always made in streaming mode, regardless of the flag.
    config.streaming = call.streaming or (
        selected_transport == TransportProtocol.GRPC
    )

    results: list[str] = []
    try:
        client = await ClientFactory.connect(
            call.agent_card_uri,
            client_config=config,
        )

        # Wrap nested instruction
        nested_msg = wrap_instruction_to_request(call.instruction)
        request = SendMessageRequest(message=nested_msg)

        async for event in client.send_message(request):
            # Event is streaming response and task
            logger.info('Event: %s', event)
            stream_resp, task = event

            # Look for a message in priority order: a direct message on the
            # response, then the task status, then a status update.
            message = None
            if stream_resp.HasField('message'):
                message = stream_resp.message
            elif task and task.status.HasField('message'):
                message = task.status.message
            elif stream_resp.HasField(
                'status_update'
            ) and stream_resp.status_update.status.HasField('message'):
                message = stream_resp.status_update.status.message

            if message:
                results.extend(part.text for part in message.parts if part.text)

    except Exception as e:
        # Boundary of the outbound call: log with traceback and re-raise as a
        # single error type that names the failing target.
        logger.exception('Failed to call outbound agent')
        raise RuntimeError(
            f'Outbound call to {call.agent_card_uri} failed: {e!s}'
        ) from e
    else:
        return results
    finally:
        await http_client.aclose()


async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
    """Evaluates an instruction, recursing into nested step lists.

    Args:
        inst: The instruction to evaluate.

    Returns:
        The result of the outbound call for ``call_agent``, a one-element
        list holding the response for ``return_response``, or the
        concatenated results of every nested instruction for ``steps``.

    Raises:
        ValueError: If none of the known instruction fields is set.
    """
    if inst.HasField('call_agent'):
        return await handle_call_agent(inst.call_agent)

    if inst.HasField('return_response'):
        return [inst.return_response.response]

    if not inst.HasField('steps'):
        raise ValueError('Unknown instruction type')

    # Steps run sequentially; their results are flattened in order.
    collected = []
    for nested in inst.steps.instructions:
        collected += await handle_instruction(nested)
    return collected


class V10AgentExecutor(AgentExecutor):
    """Agent executor that runs ITK instructions carried in A2A messages."""

    async def execute(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Runs the instruction found in the request and reports the outcome.

        The task is moved to SUBMITTED and then WORKING. It finishes as
        COMPLETED with the newline-joined results as the agent message, or as
        FAILED with an explanatory message when no instruction can be
        extracted or handling raises.

        Args:
            context: Request context providing the task id, context id and
                the incoming message.
            event_queue: Queue on which status updates are published.
        """
        logger.info('Executing task %s', context.task_id)
        updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )

        for state in (
            TaskState.TASK_STATE_SUBMITTED,
            TaskState.TASK_STATE_WORKING,
        ):
            await updater.update_status(state)

        instruction = extract_instruction(context.message)
        if not instruction:
            error_msg = 'No valid instruction found in request'
            logger.error(error_msg)
            failure_message = updater.new_agent_message([Part(text=error_msg)])
            await updater.update_status(
                TaskState.TASK_STATE_FAILED,
                message=failure_message,
            )
            return

        try:
            logger.info('Instruction: %s', instruction)
            outputs = await handle_instruction(instruction)
            response_text = '\n'.join(outputs)
            logger.info('Response: %s', response_text)
            success_message = updater.new_agent_message(
                [Part(text=response_text)]
            )
            await updater.update_status(
                TaskState.TASK_STATE_COMPLETED,
                message=success_message,
            )
            logger.info('Task %s completed', context.task_id)
        except Exception as e:
            # Any failure while handling the instruction (or while publishing
            # the completion) is reported to the caller as a FAILED task.
            logger.exception('Error during instruction handling')
            error_message = updater.new_agent_message([Part(text=str(e))])
            await updater.update_status(
                TaskState.TASK_STATE_FAILED,
                message=error_message,
            )

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Marks the task as canceled.

        Args:
            context: Request context providing the task and context ids.
            event_queue: Queue on which the status update is published.
        """
        logger.info('Cancel requested for task %s', context.task_id)
        updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )
        await updater.update_status(TaskState.TASK_STATE_CANCELED)


async def main_async(http_port: int, grpc_port: int) -> None:
    """Starts the Agent with HTTP and gRPC interfaces.

    A gRPC server (serving both the current and the v0.3 compatibility
    service) is started on ``grpc_port``, and a FastAPI app exposing JSON-RPC
    under ``/jsonrpc`` and REST under ``/rest`` is served by uvicorn on
    ``http_port``. Everything binds to 127.0.0.1. The coroutine returns when
    the uvicorn server stops; the gRPC server is shut down at that point.

    Args:
        http_port: TCP port for the HTTP (JSON-RPC and REST) endpoints.
        grpc_port: TCP port for the gRPC endpoint.
    """
    grpc_shutdown_grace_seconds = 5.0
    # Bindings that are advertised once per protocol version, on the same URL.
    versioned_protocols = ('1.0', '0.3')
    grpc_url = f'127.0.0.1:{grpc_port}'
    rest_url = f'http://127.0.0.1:{http_port}/rest/'

    interfaces = [
        *(
            AgentInterface(
                protocol_binding=TransportProtocol.GRPC,
                url=grpc_url,
                protocol_version=version,
            )
            for version in versioned_protocols
        ),
        AgentInterface(
            protocol_binding=TransportProtocol.JSONRPC,
            url=f'http://127.0.0.1:{http_port}/jsonrpc/',
        ),
        *(
            AgentInterface(
                protocol_binding=TransportProtocol.HTTP_JSON,
                url=rest_url,
                protocol_version=version,
            )
            for version in versioned_protocols
        ),
    ]

    agent_card = AgentCard(
        name='ITK v10 Agent',
        description='Python agent using SDK 1.0.',
        version='1.0.0',
        capabilities=AgentCapabilities(streaming=True),
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        supported_interfaces=interfaces,
    )

    task_store = InMemoryTaskStore()
    handler = DefaultRequestHandler(
        agent_executor=V10AgentExecutor(),
        task_store=task_store,
        queue_manager=InMemoryQueueManager(),
    )

    # One FastAPI app hosts both HTTP transports as mounted sub-applications.
    app = FastAPI()

    json_rpc_app = A2AFastAPIApplication(
        agent_card, handler, enable_v0_3_compat=True
    ).build()
    app.mount('/jsonrpc', json_rpc_app)
    rest_app = A2ARESTFastAPIApplication(
        http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True
    ).build()
    app.mount('/rest', rest_app)

    server = grpc.aio.server()

    # Register both the v0.3 compatibility servicer and the current servicer
    # on the same gRPC server.
    compat_servicer = CompatGrpcHandler(agent_card, handler)
    a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server)
    servicer = GrpcHandler(agent_card, handler)
    a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)

    server.add_insecure_port(f'127.0.0.1:{grpc_port}')
    await server.start()

    try:
        logger.info(
            'Starting ITK v10 Agent on HTTP port %s and gRPC port %s',
            http_port,
            grpc_port,
        )

        config = uvicorn.Config(
            app, host='127.0.0.1', port=http_port, log_level='info'
        )
        uvicorn_server = uvicorn.Server(config)

        # Blocks until the HTTP server shuts down.
        await uvicorn_server.serve()
    finally:
        # The gRPC server was started above, so it must be stopped whether
        # the HTTP server exited normally or failed.
        await server.stop(grpc_shutdown_grace_seconds)


def str2bool(v: str | bool) -> bool:
"""Converts a string to a boolean value."""
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
if v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')


def main() -> None:
    """Parses the command-line ports and runs the agent until it exits.

    Recognised options are ``--httpPort`` (default 10102) and ``--grpcPort``
    (default 11002), both integers.
    """
    cli = argparse.ArgumentParser()
    for flag, default_port in (('--httpPort', 10102), ('--grpcPort', 11002)):
        cli.add_argument(flag, type=int, default=default_port)
    options = cli.parse_args()

    asyncio.run(main_async(options.httpPort, options.grpcPort))


# Start the agent only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()
Loading
Loading