-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathagent_tools.py
More file actions
112 lines (94 loc) · 3.46 KB
/
agent_tools.py
File metadata and controls
112 lines (94 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import logging
import os
import random
import sys
from datetime import datetime
from typing import Annotated
from agent_framework import Agent, tool
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import Field
from rich import print
from rich.logging import RichHandler
# Setup logging: Rich renders every record; the root level stays at WARNING so
# third-party libraries are quiet, while this module's own logger emits INFO so
# the tool invocations below are visible.
handler = RichHandler(show_path=False, rich_tracebacks=True, show_level=False)
logging.basicConfig(level=logging.WARNING, handlers=[handler], force=True, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Configure the OpenAI-compatible chat client based on the API_HOST setting
# (values from a local .env file override the process environment):
#   "azure"        -> Azure OpenAI v1 endpoint, token-based auth via DefaultAzureCredential
#   "github"       -> GitHub Models inference endpoint (the default)
#   anything else  -> OpenAI API using the client's default base URL
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")

# Only created on the Azure path; stays None otherwise so callers can tell
# whether there is a credential to close.
async_credential = None
if API_HOST == "azure":
    async_credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    client = OpenAIChatClient(
        # Endpoints copied from the Azure portal typically end with "/"; strip it
        # so the base URL is ".../openai/v1/" rather than "...//openai/v1/".
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT'].rstrip('/')}/openai/v1/",
        api_key=token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
elif API_HOST == "github":
    client = OpenAIChatClient(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
        model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4.1-mini"),
    )
else:
    client = OpenAIChatClient(
        api_key=os.environ["OPENAI_API_KEY"], model_id=os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
    )
@tool
def get_weather(
    city: Annotated[str, Field(description="The city to get the weather for.")],
) -> dict:
    """Returns weather data for a given city, a dictionary with temperature and description."""
    # NOTE: the docstring above is kept verbatim: @tool presumably publishes it to
    # the model as this tool's description, so it is effectively runtime text.
    # Canned demo data: `city` is only logged, never used for a lookup. About one
    # call in twenty reports sunshine; every other call reports rain.
    logger.info("Getting weather for %s", city)
    sunny = random.random() < 0.05
    temperature, description = (72, "Sunny") if sunny else (60, "Rainy")
    return {"temperature": temperature, "description": description}
@tool
def get_activities(
    city: Annotated[str, Field(description="The city to get activities for.")],
    date: Annotated[str, Field(description="The date to get activities for in format YYYY-MM-DD.")],
) -> list[dict]:
    """Returns a list of activities for a given city and date."""
    # NOTE: docstring and Field descriptions are kept verbatim because @tool
    # presumably sends them to the model as the tool schema.
    # Static demo data: every city gets the same three activities, and `date` is
    # only logged, not validated or used for filtering.
    logger.info("Getting activities for %s on %s", city, date)
    activity_names = ("Hiking", "Beach", "Museum")
    return [{"name": activity_name, "location": city} for activity_name in activity_names]
@tool
def get_current_date() -> str:
    """Gets the current date from the system and returns as a string in format YYYY-MM-DD."""
    # Reads the local system clock (naive datetime), so "today" follows the
    # machine's timezone. ISO format of a date is exactly YYYY-MM-DD.
    logger.info("Getting current date")
    today = datetime.now().date()
    return today.isoformat()
# The single agent used by both entry points: the scripted run in main() and
# the dev UI. It can call the three demo tools defined above.
agent = Agent(
    name="weekend-planner",
    client=client,
    tools=[get_weather, get_activities, get_current_date],
    # System prompt: the text is concatenated from the two literals below and is
    # sent to the model verbatim.
    instructions=(
        "You help users plan their weekends and choose the best activities for the given weather. "
        "If an activity would be unpleasant in weather, don't suggest it. Include date of the weekend in response."
    ),
)
async def main():
    """Run the weekend-planner agent on one fixed question and print its reply.

    The Azure credential (only created when API_HOST is "azure") is closed in a
    ``finally`` block so it is released even when the agent run raises.
    """
    try:
        response = await agent.run("what can I do this weekend in San Francisco?")
        print(response.text)
    finally:
        # Previously the close was skipped whenever agent.run() raised.
        if async_credential is not None:
            await async_credential.close()
if __name__ == "__main__":
    # Default: run the single scripted query once. With `--devui`, serve the
    # agent through agent_framework's dev UI instead (auto-opening a browser).
    if "--devui" not in sys.argv:
        asyncio.run(main())
    else:
        # Imported only on this branch, so the dev UI module is loaded only when requested.
        from agent_framework.devui import serve

        serve(entities=[agent], auto_open=True)