-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy patha2a.py
More file actions
256 lines (204 loc) · 8.18 KB
/
a2a.py
File metadata and controls
256 lines (204 loc) · 8.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""A2A protocol support for Bedrock AgentCore Runtime.
Provides Bedrock-specific glue around the official a2a-sdk, handling header
extraction, health checks, and Docker host detection.
"""
import logging
import time
import uuid
from typing import Any, Callable, Optional
from .context import BedrockAgentCoreContext
from .models import (
ACCESS_TOKEN_HEADER,
AGENTCORE_RUNTIME_URL_ENV,
AUTHORIZATION_HEADER,
CUSTOM_HEADER_PREFIX,
OAUTH2_CALLBACK_URL_HEADER,
REQUEST_ID_HEADER,
SESSION_HEADER,
PingStatus,
)
# Module-level logger (used, e.g., to report failures of a custom ping handler).
logger = logging.getLogger(__name__)
def _check_a2a_sdk() -> None:
"""Raise ImportError with install instructions if a2a-sdk is missing."""
try:
import a2a # noqa: F401
except ImportError:
raise ImportError(
'a2a-sdk is required for A2A protocol support. Install it with: pip install "bedrock-agentcore[a2a]"'
) from None
def _build_agent_card(executor: Any, url: str) -> Any:
    """Create a default ``AgentCard`` describing *executor*.

    When the executor exposes an ``agent`` attribute (as a StrandsA2AExecutor
    does), that object's truthy ``name`` and ``description`` attributes are
    used for the card. Anything missing or falsy falls back to generic defaults.

    Args:
        executor: The agent executor being served.
        url: The URL to advertise in the card.

    Returns:
        An ``a2a.types.AgentCard`` with version ``"0.1.0"``, streaming enabled,
        a single ``"main"`` skill, and text-only input/output modes.
    """
    from a2a.types import AgentCapabilities, AgentCard, AgentSkill

    agent = getattr(executor, "agent", None)
    # ``getattr(None, ..., None)`` yields None, so a missing agent also falls
    # through to the defaults below.
    card_name = getattr(agent, "name", None) or "agent"
    card_description = getattr(agent, "description", None) or "A Bedrock AgentCore agent"

    main_skill = AgentSkill(id="main", name=card_name, description=card_description, tags=["main"])
    return AgentCard(
        name=card_name,
        description=card_description,
        url=url,
        version="0.1.0",
        capabilities=AgentCapabilities(streaming=True),
        skills=[main_skill],
        default_input_modes=["text"],
        default_output_modes=["text"],
    )
class BedrockCallContextBuilder:
    """Propagate Bedrock AgentCore runtime headers into ``BedrockAgentCoreContext``.

    Intended to satisfy the a2a-sdk ``CallContextBuilder`` interface, so the A2A
    server invokes :meth:`build` once per incoming request.
    """

    def build(self, request: Any) -> Any:
        """Record Bedrock headers for this request and wrap them in a ``ServerCallContext``.

        Side effects on ``BedrockAgentCoreContext``:

        * always sets the request context (request id, session id); a fresh
          UUID4 string is used when the request-id header is absent or empty;
        * sets the workload access token and the OAuth2 callback URL only when
          the corresponding headers are non-empty;
        * sets request headers (the Authorization header, if present, plus every
          header whose name starts with ``CUSTOM_HEADER_PREFIX``, compared
          case-insensitively) only when at least one such header exists.

        Args:
            request: A Starlette ``Request``.

        Returns:
            A ``ServerCallContext`` whose ``state`` contains ``request_id`` and
            ``session_id``, plus ``workload_access_token`` and
            ``oauth2_callback_url`` when those headers were non-empty.
        """
        from a2a.server.context import ServerCallContext

        incoming = request.headers

        request_id = incoming.get(REQUEST_ID_HEADER) or str(uuid.uuid4())
        session_id = incoming.get(SESSION_HEADER)
        BedrockAgentCoreContext.set_request_context(request_id, session_id)
        state = {"request_id": request_id, "session_id": session_id}

        workload_access_token = incoming.get(ACCESS_TOKEN_HEADER)
        if workload_access_token:
            BedrockAgentCoreContext.set_workload_access_token(workload_access_token)
            state["workload_access_token"] = workload_access_token

        oauth2_callback_url = incoming.get(OAUTH2_CALLBACK_URL_HEADER)
        if oauth2_callback_url:
            BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url)
            state["oauth2_callback_url"] = oauth2_callback_url

        # Forward the Authorization header plus any custom-prefixed headers.
        custom_prefix = CUSTOM_HEADER_PREFIX.lower()
        forwarded: dict[str, str] = {}
        authorization_value = incoming.get(AUTHORIZATION_HEADER)
        if authorization_value is not None:
            forwarded[AUTHORIZATION_HEADER] = authorization_value
        forwarded.update(
            (name, value) for name, value in incoming.items() if name.lower().startswith(custom_prefix)
        )
        if forwarded:
            BedrockAgentCoreContext.set_request_headers(forwarded)

        return ServerCallContext(state=state)
# Register as a virtual subclass so isinstance checks against the a2a-sdk
# CallContextBuilder ABC pass without requiring a2a-sdk to be importable at
# class-definition time. Registration is best-effort. A missing a2a-sdk (an
# optional extra) is expected and ignored. Any other failure is logged at
# debug level instead of being silently discarded.
try:
    from a2a.server.apps import CallContextBuilder

    CallContextBuilder.register(BedrockCallContextBuilder)
except ImportError:
    pass
except Exception:  # pragma: no cover - depends on the installed a2a-sdk version
    logger.debug("Could not register BedrockCallContextBuilder with a2a-sdk CallContextBuilder", exc_info=True)
def _make_ping_endpoint(ping_handler: Optional[Callable[[], PingStatus]]) -> Callable[[Any], Any]:
    """Create the Starlette endpoint that serves ``GET /ping``.

    The endpoint responds with ``{"status": <PingStatus value>,
    "time_of_last_update": <int Unix seconds>}``. ``time_of_last_update`` is
    the time at which the reported status last changed; the first ping always
    counts as a change.

    Args:
        ping_handler: Optional callback returning the current status. Its
            result is coerced with ``PingStatus(...)``, so a member's plain
            value is accepted as well. If the callback raises, or returns
            something that is not a valid ``PingStatus``, the failure is logged
            and ``PingStatus.HEALTHY`` is reported. When ``None``, the status
            is always ``PingStatus.HEALTHY``.

    Returns:
        A synchronous request handler suitable for a Starlette ``Route``.
    """
    from starlette.responses import JSONResponse

    last_status: Optional[PingStatus] = None
    last_status_update_time = time.time()

    def _handle_ping(request: Any) -> JSONResponse:
        nonlocal last_status, last_status_update_time
        try:
            # NOTE(review): assumes PingStatus is an Enum (the response uses
            # ``.value``). Calling it with a member or a member's value returns
            # that member; anything else raises ValueError and falls back below.
            status = PingStatus(ping_handler()) if ping_handler is not None else PingStatus.HEALTHY
        except Exception:
            logger.exception("Custom ping handler failed, falling back to Healthy")
            status = PingStatus.HEALTHY
        # Only a status transition moves the timestamp, matching the field name.
        if status != last_status:
            last_status = status
            last_status_update_time = time.time()
        return JSONResponse({"status": status.value, "time_of_last_update": int(last_status_update_time)})

    return _handle_ping


def build_a2a_app(
    executor: Any,
    agent_card: Any = None,
    *,
    task_store: Any = None,
    context_builder: Any = None,
    ping_handler: Optional[Callable[[], PingStatus]] = None,
) -> Any:
    """Build a Starlette app wired for the A2A protocol with Bedrock extras.

    The advertised agent URL is the value of the ``AGENTCORE_RUNTIME_URL_ENV``
    environment variable when it is set to a non-empty value, otherwise
    ``http://localhost:9000/``. A caller-supplied ``agent_card`` keeps its own
    URL unless that variable is set. If it is set, the card's ``url`` is
    overwritten in place.

    Args:
        executor: An ``AgentExecutor`` that implements the agent logic.
        agent_card: Optional ``a2a.types.AgentCard`` describing the agent.
            If ``None``, one is built automatically by introspecting the executor.
        task_store: Optional ``TaskStore``; defaults to ``InMemoryTaskStore``.
        context_builder: Optional ``CallContextBuilder``; defaults to
            ``BedrockCallContextBuilder``.
        ping_handler: Optional callback returning a ``PingStatus`` for
            ``GET /ping``. Its failures are logged and reported as healthy.

    Returns:
        A Starlette application serving the A2A routes plus ``GET /ping``.

    Raises:
        ImportError: If ``a2a-sdk`` is not installed.
    """
    import os

    _check_a2a_sdk()

    from a2a.server.apps import A2AStarletteApplication
    from a2a.server.request_handlers import DefaultRequestHandler
    from a2a.server.tasks import InMemoryTaskStore
    from starlette.routing import Route

    # Treat an empty variable as unset, both for the fallback and the override.
    env_runtime_url = os.environ.get(AGENTCORE_RUNTIME_URL_ENV)
    runtime_url = env_runtime_url or "http://localhost:9000/"
    if agent_card is None:
        agent_card = _build_agent_card(executor, runtime_url)
    elif env_runtime_url:
        agent_card.url = runtime_url

    if task_store is None:
        task_store = InMemoryTaskStore()
    if context_builder is None:
        context_builder = BedrockCallContextBuilder()

    http_handler = DefaultRequestHandler(
        agent_executor=executor,
        task_store=task_store,
    )
    app = A2AStarletteApplication(
        agent_card=agent_card,
        http_handler=http_handler,
        context_builder=context_builder,
    ).build()

    app.routes.append(Route("/ping", _make_ping_endpoint(ping_handler), methods=["GET"]))
    return app
def serve_a2a(
    executor: Any,
    agent_card: Any = None,
    *,
    port: int = 9000,
    host: Optional[str] = None,
    task_store: Any = None,
    context_builder: Any = None,
    ping_handler: Optional[Callable[[], PingStatus]] = None,
    **kwargs: Any,
) -> None:
    """Start a Bedrock-compatible A2A server with uvicorn.

    Args:
        executor: An ``AgentExecutor`` that implements the agent logic.
        agent_card: Optional ``a2a.types.AgentCard`` describing the agent.
            If ``None``, one is built automatically by introspecting the
            executor. Outside the AgentCore runtime (no runtime URL in the
            environment), that card advertises ``http://localhost:<port>/``.
        port: Port to serve on (default 9000).
        host: Host to bind to. If ``None``, ``0.0.0.0`` is used when
            ``/.dockerenv`` exists or ``DOCKER_CONTAINER`` is set, otherwise
            ``127.0.0.1``.
        task_store: Optional ``TaskStore``; defaults to ``InMemoryTaskStore``.
        context_builder: Optional ``CallContextBuilder``; defaults to
            ``BedrockCallContextBuilder``.
        ping_handler: Optional callback returning a ``PingStatus``.
        **kwargs: Additional arguments forwarded to ``uvicorn.run()``
            (they may override the default ``log_level="info"``).

    Raises:
        ImportError: If ``a2a-sdk`` is not installed.
    """
    import os

    import uvicorn

    if agent_card is None and not os.environ.get(AGENTCORE_RUNTIME_URL_ENV):
        # Without a runtime URL, build_a2a_app() would fall back to a URL on
        # port 9000 regardless of ``port``. Build the default card here so it
        # advertises the port we actually bind.
        _check_a2a_sdk()
        agent_card = _build_agent_card(executor, f"http://localhost:{port}/")

    app = build_a2a_app(
        executor,
        agent_card,
        task_store=task_store,
        context_builder=context_builder,
        ping_handler=ping_handler,
    )

    if host is None:
        if os.path.exists("/.dockerenv") or os.environ.get("DOCKER_CONTAINER"):
            host = "0.0.0.0"  # nosec B104 - Container needs this to expose the port
        else:
            host = "127.0.0.1"

    uvicorn_params: dict[str, Any] = {"host": host, "port": port, "log_level": "info", **kwargs}
    uvicorn.run(app, **uvicorn_params)