-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathworkflow_hitl_requests_structured.py
More file actions
200 lines (158 loc) · 6.59 KB
/
workflow_hitl_requests_structured.py
File metadata and controls
200 lines (158 loc) · 6.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Trip planner workflow with human-in-the-loop via requests and responses.

Demonstrates: ctx.request_info(), @response_handler, structured agent output,
and driving the HITL loop from application code.

The user starts with a vague travel request like "I want to go somewhere warm."
The trip planner agent asks clarifying questions one at a time (destination,
budget, interests, dates, group size). After each question, the workflow pauses
and waits for the human's answer. Once the agent has enough information, it
produces a final itinerary.

Run:
    uv run examples/workflow_hitl_requests_structured.py
"""
import asyncio
import os
from dataclasses import dataclass
from typing import Literal
from agent_framework import (
Agent,
AgentExecutorRequest,
AgentExecutorResponse,
AgentResponseUpdate,
Executor,
Message,
WorkflowBuilder,
WorkflowContext,
handler,
response_handler,
)
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import BaseModel
load_dotenv(override=True)

# Which backend to use: "azure", "github" (the default), or any other value for OpenAI.
API_HOST = os.getenv("API_HOST", "github")

# Configure the chat client based on the API host.
# The Azure credential is kept at module scope so main() can close it when done;
# it stays None for the non-Azure hosts.
async_credential = None
if API_HOST == "azure":
    async_credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    # Azure OpenAI endpoints are often given with a trailing slash; strip it so the
    # joined URL does not contain "//openai/v1/".
    azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"].rstrip("/")
    client = OpenAIChatClient(
        base_url=f"{azure_endpoint}/openai/v1/",
        api_key=token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
elif API_HOST == "github":
    # GitHub Models inference endpoint, authenticated with a GitHub token.
    client = OpenAIChatClient(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
        model_id=os.getenv("GITHUB_MODEL", "openai/gpt-5-mini"),
    )
else:
    # Plain OpenAI API using the client's default base URL.
    client = OpenAIChatClient(
        api_key=os.environ["OPENAI_API_KEY"], model_id=os.environ.get("OPENAI_MODEL", "gpt-5-mini")
    )
# --- Structured output models ---
class PlannerOutput(BaseModel):
    """Structured output from the trip planner agent.

    Passed to the agent as its ``response_format`` and also used to parse the
    agent's raw text reply in ``TripCoordinator.on_agent_response``. The field
    names, types and order define the JSON schema the model is asked to follow.
    """

    # "need_info": the agent wants more details and should fill in ``question``.
    # "complete": the agent is done and should fill in ``itinerary``.
    status: Literal["need_info", "complete"]
    # Clarifying question for the human; expected when status == "need_info".
    question: str | None = None
    # Final itinerary text; expected when status == "complete".
    itinerary: str | None = None
# --- Human-in-the-loop request payload ---
@dataclass()
class UserPrompt:
    """Payload handed to ``ctx.request_info()`` when the planner needs an answer.

    The application reads ``message`` and shows it to the human as the question.
    """

    message: str  # clarifying question text displayed to the user
# --- Executor that coordinates agent ↔ human turns ---
def _parse_planner_output(text: str) -> PlannerOutput:
    """Parse the agent's reply text into a PlannerOutput.

    Tolerates a surrounding Markdown code fence (e.g. ```json ... ```), which
    models may emit when the structured response_format is not enforced.

    Raises:
        pydantic.ValidationError: if the remaining text is not valid PlannerOutput JSON.
    """
    payload = text.strip()
    if payload.startswith("```"):
        # Drop the opening fence line (with any language tag) and the closing fence.
        _, _, payload = payload.partition("\n")
        payload = payload.rstrip().removesuffix("```")
    return PlannerOutput.model_validate_json(payload)


class TripCoordinator(Executor):
    """Coordinates turns between the trip planner agent and the human.

    - After each agent reply, checks if more info is needed.
    - If so, requests human input via ctx.request_info().
    - If the agent has enough info, yields the final itinerary.
    """

    def __init__(self, agent_id: str, id: str = "trip_coordinator"):
        """Create the coordinator.

        Args:
            agent_id: Executor id of the planner agent that requests are sent to.
            id: Executor id of this coordinator.
        """
        super().__init__(id=id)
        self._agent_id = agent_id

    @handler
    async def start(self, request: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
        """Kick off the first agent turn with the user's vague request."""
        user_msg = Message("user", text=request)
        await ctx.send_message(
            AgentExecutorRequest(messages=[user_msg], should_respond=True),
            target_id=self._agent_id,
        )

    @handler
    async def on_agent_response(self, result: AgentExecutorResponse, ctx: WorkflowContext) -> None:
        """Handle the agent's structured response.

        Asks the human a follow-up question when the agent reports ``need_info``
        with a question; otherwise yields the itinerary (or a fallback message)
        as the workflow output. Note that a ``need_info`` reply without a
        question also takes the output branch.
        """
        # Parse structured output from text (AgentResponse.value may be None
        # in streaming workflows due to response_format not propagating).
        output = _parse_planner_output(result.agent_response.text)
        if output.status == "need_info" and output.question:
            # Pause and ask the human; the answer arrives in on_human_answer.
            await ctx.request_info(
                request_data=UserPrompt(message=output.question),
                response_type=str,
            )
        else:
            await ctx.yield_output(output.itinerary or "No itinerary generated.")

    @response_handler
    async def on_human_answer(
        self,
        original_request: UserPrompt,
        answer: str,
        ctx: WorkflowContext[AgentExecutorRequest, str],
    ) -> None:
        """Forward the human's answer back to the agent.

        Args:
            original_request: The prompt that was shown to the human (unused here).
            answer: The human's reply text.
            ctx: Workflow context used to send the next agent request.
        """
        user_msg = Message("user", text=answer)
        await ctx.send_message(
            AgentExecutorRequest(messages=[user_msg], should_respond=True),
            target_id=self._agent_id,
        )
# --- Main ---
async def main() -> None:
    """Run the trip planner HITL workflow from the console.

    Builds a two-node workflow (coordinator <-> planner agent) and starts it with
    a fixed vague request. It then loops: gather every ``request_info`` event from
    the current run, prompt the human on stdin for each one, and resume the
    workflow with the answers. The loop stops when a run ends with no pending
    requests. The Azure credential, if one was created, is closed on exit,
    including when an exception is raised.
    """
    planner_agent = Agent(
        name="TripPlanner",
        instructions=(
            "You are a helpful trip planner. The user has a vague travel idea and you need to "
            "gather enough details to create a personalized itinerary.\n"
            "Ask clarifying questions ONE AT A TIME about: destination preferences, travel dates, "
            "budget, interests/activities, and group size.\n"
            "Once you have enough information (at least destination, dates, and budget), "
            "produce a final itinerary.\n\n"
            "You MUST return ONLY a JSON object matching this schema:\n"
            '  {"status": "need_info", "question": "your question here"}\n'
            "  OR\n"
            '  {"status": "complete", "itinerary": "your full itinerary here"}\n'
            "No explanations or additional text outside the JSON."
        ),
        client=client,
        default_options={"response_format": PlannerOutput},
    )

    # The coordinator addresses the agent by its name, which serves as its executor id.
    coordinator = TripCoordinator(agent_id="TripPlanner")
    workflow = (
        WorkflowBuilder(start_executor=coordinator)
        .add_edge(coordinator, planner_agent)
        .add_edge(planner_agent, coordinator)
        .build()
    )

    user_request = "I want to go somewhere warm next month"
    print(f"▶️ Starting trip planner with: \"{user_request}\"\n")

    try:
        stream = workflow.run(user_request, stream=True)
        while True:
            # Human-input requests raised during this run, keyed by request id.
            pending: dict[str, UserPrompt] = {}
            async for event in stream:
                if event.type == "request_info":
                    pending[event.request_id] = event.data
                elif event.type == "output" and not isinstance(event.data, AgentResponseUpdate):
                    # AgentResponseUpdate outputs are skipped (presumably streamed
                    # agent chunks); anything else is printed as the itinerary.
                    print(f"\n📍 Itinerary:\n{event.data}")

            if not pending:
                break

            # Collect answers in their own dict instead of overwriting the requests.
            responses: dict[str, str] = {}
            for request_id, prompt in pending.items():
                print(f"\n⏸️ Agent asks: {prompt.message}")
                responses[request_id] = input("💬 Your answer: ")
            stream = workflow.run(stream=True, responses=responses)
    finally:
        # Release the Azure credential even if the workflow fails or the user aborts.
        if async_credential is not None:
            await async_credential.close()


if __name__ == "__main__":
    asyncio.run(main())