-
Notifications
You must be signed in to change notification settings - Fork 103
Expand file tree
/
Copy pathagent_supervisor.py
More file actions
206 lines (172 loc) · 6.81 KB
/
agent_supervisor.py
File metadata and controls
206 lines (172 loc) · 6.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import asyncio
import logging
import os
import random
import sys
from datetime import datetime
from typing import Annotated
from agent_framework import Agent, tool
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import Field
from rich import print
from rich.logging import RichHandler
# Setup logging: every record is rendered through Rich as the bare message.
# The root logger stays at WARNING, and only this module's logger is raised to INFO.
handler = RichHandler(show_level=False, rich_tracebacks=True, show_path=False)
logging.basicConfig(
    format="%(message)s",
    handlers=[handler],
    level=logging.WARNING,
    force=True,  # replace any handlers already attached to the root logger
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Configure OpenAI client based on environment.
# Values from .env take precedence over variables already in the environment.
load_dotenv(override=True)
API_HOST = os.environ.get("API_HOST", "github")

# Only the "azure" path creates a credential. Otherwise it stays None, and main()
# checks it before closing.
async_credential = None
if API_HOST == "github":
    # GitHub Models: token auth against the GitHub inference endpoint.
    client = OpenAIChatClient(
        model_id=os.environ.get("GITHUB_MODEL", "openai/gpt-4.1-mini"),
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
    )
elif API_HOST == "azure":
    # Azure OpenAI: bearer tokens from DefaultAzureCredential replace a static key.
    async_credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    client = OpenAIChatClient(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
else:
    # Any other API_HOST value: plain OpenAI with an API key.
    client = OpenAIChatClient(
        model_id=os.getenv("OPENAI_MODEL", "gpt-4.1-mini"),
        api_key=os.environ["OPENAI_API_KEY"],
    )
# ----------------------------------------------------------------------------------
# Sub-agent 1 tools: weekend planning
# ----------------------------------------------------------------------------------
@tool
def get_weather(
    city: Annotated[str, Field(description="The city to get the weather for.")],
    date: Annotated[str, Field(description="The date to get weather for in format YYYY-MM-DD.")],
) -> dict:
    """Returns weather data for a given city and date."""
    # Demo stub: city and date are only logged. The forecast is random, with a
    # 5% chance of sun and rain the rest of the time.
    logger.info(f"Getting weather for {city} on {date}")
    sunny = random.random() < 0.05
    return {"temperature": 72, "description": "Sunny"} if sunny else {"temperature": 60, "description": "Rainy"}
@tool
def get_activities(
    city: Annotated[str, Field(description="The city to get activities for.")],
    date: Annotated[str, Field(description="The date to get activities for in format YYYY-MM-DD.")],
) -> list[dict]:
    """Returns a list of activities for a given city and date."""
    # Demo stub: every city gets the same three activities, and `date` is only logged.
    logger.info(f"Getting activities for {city} on {date}")
    return [{"name": activity, "location": city} for activity in ("Hiking", "Beach", "Museum")]
@tool
def get_current_date() -> str:
    """Gets the current date from the system (YYYY-MM-DD)."""
    logger.info("Getting current date")
    # datetime.now() with no tz argument is the naive local system time.
    today = datetime.now()
    return f"{today:%Y-%m-%d}"
# Weekend-planning specialist. The supervisor reaches it through the plan_weekend tool.
weekend_agent = Agent(
    tools=[get_weather, get_activities, get_current_date],
    client=client,
    instructions=(
        "You help users plan their weekends and choose the best activities for the given weather. "
        "If an activity would be unpleasant in the weather, don't suggest it. "
        "Include the date of the weekend in your response."
    ),
)
@tool
async def plan_weekend(
    # Described like the other tools' parameters, so the supervisor model is told
    # what to pass in this argument.
    query: Annotated[
        str,
        Field(description="The weekend-planning request to forward to the weekend planning agent."),
    ],
) -> str:
    """Plan a weekend based on user query and return the final response."""
    logger.info("Tool: plan_weekend invoked")
    # Delegate to the weekend sub-agent and hand back only its final text.
    response = await weekend_agent.run(query)
    return response.text
# ----------------------------------------------------------------------------------
# Sub-agent 2 tools: meal planning
# ----------------------------------------------------------------------------------
@tool
def find_recipes(
    query: Annotated[str, Field(description="User query or desired meal/ingredient")],
) -> list[dict]:
    """Returns recipes (JSON) based on a query."""
    logger.info(f"Finding recipes for '{query}'")
    # Demo stub: a case-insensitive substring match picks one canned recipe.
    # "pasta" is checked before "tofu", and anything else falls back to grilled cheese.
    wanted = query.lower()
    if "pasta" in wanted:
        return [
            {
                "title": "Pasta Primavera",
                "ingredients": ["pasta", "vegetables", "olive oil"],
                "steps": ["Cook pasta.", "Sauté vegetables."],
            }
        ]
    if "tofu" in wanted:
        return [
            {
                "title": "Tofu Stir Fry",
                "ingredients": ["tofu", "soy sauce", "vegetables"],
                "steps": ["Cube tofu.", "Stir fry veggies."],
            }
        ]
    return [
        {
            "title": "Grilled Cheese Sandwich",
            "ingredients": ["bread", "cheese", "butter"],
            "steps": ["Butter bread.", "Place cheese between slices.", "Grill until golden brown."],
        }
    ]
@tool
def check_fridge() -> list[str]:
    """Returns a JSON list of ingredients currently in the fridge."""
    logger.info("Checking fridge for current ingredients")
    # Demo stub: a coin flip picks one of two canned fridge inventories.
    if random.random() < 0.5:
        items = ["pasta", "tomato sauce", "bell peppers", "olive oil"]
    else:
        items = ["tofu", "soy sauce", "broccoli", "carrots"]
    # Log which inventory was picked instead of a bare "Returned", so the trace
    # shows what the meal agent is planning around.
    logger.info(f"Fridge contains: {', '.join(items)}")
    return items
# Meal-planning specialist. The supervisor reaches it through the plan_meal tool.
meal_agent = Agent(
    tools=[find_recipes, check_fridge],
    client=client,
    instructions=(
        "You help users plan meals and choose the best recipes. "
        "Include the ingredients and cooking instructions in your response. "
        "Indicate what the user needs to buy from the store when their fridge is missing ingredients."
    ),
)
@tool
async def plan_meal(
    # Described like the other tools' parameters, so the supervisor model is told
    # what to pass in this argument.
    query: Annotated[
        str,
        Field(description="The meal-planning request to forward to the meal planning agent."),
    ],
) -> str:
    """Plan a meal based on user query and return the final response."""
    logger.info("Tool: plan_meal invoked")
    # Delegate to the meal sub-agent and hand back only its final text.
    response = await meal_agent.run(query)
    return response.text
# ----------------------------------------------------------------------------------
# Supervisor agent orchestrating sub-agents
# ----------------------------------------------------------------------------------
# Top-level agent. Its only tools are the two sub-agent wrappers, so every
# specialist call goes through plan_weekend or plan_meal.
supervisor_agent = Agent(
    tools=[plan_weekend, plan_meal],
    client=client,
    name="supervisor",
    instructions=(
        "You are a supervisor managing two specialist agents: a weekend planning agent and a meal planning agent. "
        "Break down the user's request, decide which specialist (or both) to call via the available tools, "
        "and then synthesize a final helpful answer. When invoking a tool, provide clear, concise queries."
    ),
)
async def main():
    """Run the supervisor on a sample request and print its answer.

    The Azure credential, which exists only when API_HOST is "azure", is closed
    in a ``finally`` block so it is released even if the agent run raises.
    """
    user_query = "my kids want pasta for dinner"
    try:
        response = await supervisor_agent.run(user_query)
        print(response.text)
    finally:
        if async_credential:
            await async_credential.close()
if __name__ == "__main__":
    logger.setLevel(logging.INFO)
    # Default: one console run. With --devui, serve the supervisor in the
    # Agent Framework dev UI instead (the import is deferred to that path).
    if "--devui" not in sys.argv:
        asyncio.run(main())
    else:
        from agent_framework.devui import serve

        serve(entities=[supervisor_agent], auto_open=True)