-
Notifications
You must be signed in to change notification settings - Fork 103
Expand file tree
/
Copy pathagent_history_sqlite.py
More file actions
159 lines (130 loc) · 5.56 KB
/
agent_history_sqlite.py
File metadata and controls
159 lines (130 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import asyncio
import logging
import os
import random
import sqlite3
import uuid
from collections.abc import Sequence
from typing import Annotated, Any
from agent_framework import Agent, BaseHistoryProvider, Message, tool
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import Field
from rich import print
from rich.logging import RichHandler
# Logging: records are rendered by Rich. The root logger is held at WARNING,
# while this module's own logger is raised to INFO.
handler = RichHandler(
    show_path=False,
    rich_tracebacks=True,
    show_level=False,
)
logging.basicConfig(
    level=logging.WARNING,
    handlers=[handler],
    force=True,
    format="%(message)s",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Build the chat client selected by API_HOST: "azure", "github" (the default),
# or any other value for the plain OpenAI API.
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")

# Stays None unless the Azure branch below creates a credential.
async_credential = None

if API_HOST == "azure":
    # Azure OpenAI: a bearer-token provider is passed in place of a static key.
    async_credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(
        async_credential,
        "https://cognitiveservices.azure.com/.default",
    )
    client = OpenAIChatClient(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
elif API_HOST == "github":
    # GitHub Models: token from GITHUB_TOKEN, model overridable via GITHUB_MODEL.
    client = OpenAIChatClient(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
        model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4.1-mini"),
    )
else:
    # Any other host value: OpenAI with OPENAI_API_KEY and an optional OPENAI_MODEL.
    client = OpenAIChatClient(
        api_key=os.environ["OPENAI_API_KEY"],
        model_id=os.getenv("OPENAI_MODEL", "gpt-4.1-mini"),
    )
class SQLiteHistoryProvider(BaseHistoryProvider):
    """A custom history provider backed by SQLite.

    Implements the BaseHistoryProvider to persist chat messages
    in a local SQLite database — useful when you want file-based
    persistence without an external service like Redis.

    Each message is stored as one row holding the session ID and the
    message's JSON form; rows are read back ordered by their
    autoincrementing ``id``, i.e. in insertion order.
    """

    def __init__(self, db_path: str):
        """Open the database at ``db_path`` and make sure the table exists.

        Args:
            db_path: Database location handed to ``sqlite3.connect``.
        """
        super().__init__("sqlite-history")
        self.db_path = db_path
        self._conn = sqlite3.connect(self.db_path)
        # Using the connection as a context manager commits on success and
        # rolls back if the statement raises.
        with self._conn:
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    message_json TEXT NOT NULL
                )
                """
            )

    async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
        """Retrieve all messages for this session from SQLite.

        Args:
            session_id: Session whose messages to load; ``None`` yields no messages.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The stored messages for the session, oldest first.
        """
        if session_id is None:
            return []
        cursor = self._conn.execute(
            "SELECT message_json FROM messages WHERE session_id = ? ORDER BY id",
            (session_id,),
        )
        return [Message.from_json(row[0]) for row in cursor.fetchall()]

    async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
        """Save messages to the SQLite database.

        The whole batch is written in a single transaction, so either every
        message is stored or none is.

        Args:
            session_id: Session the messages belong to; ``None`` is a no-op.
            messages: Messages to append, in order. An empty batch is a no-op.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        if session_id is None or not messages:
            return
        # Serialize up front so a failing to_json() never touches the database.
        rows = [(session_id, message.to_json()) for message in messages]
        # Commit on success, roll back on error: a partially inserted batch must
        # not stay pending on the connection and get committed by a later save.
        with self._conn:
            self._conn.executemany(
                "INSERT INTO messages (session_id, message_json) VALUES (?, ?)",
                rows,
            )

    def close(self) -> None:
        """Close the SQLite connection."""
        self._conn.close()
# NOTE(review): the docstring below is kept verbatim on purpose — the @tool
# decorator presumably exposes it as the tool description; confirm before editing.
#
# Produces a made-up report: a random condition from the list and a random
# high between 10 and 30 °C (inclusive). No real weather service is called.
@tool
def get_weather(
    city: Annotated[str, Field(description="The city to get the weather for.")],
) -> str:
    """Returns weather data for a given city."""
    logger.info("Getting weather for %s", city)
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # random.choice follows the list's actual length, so adding or removing a
    # condition cannot leave a stale hard-coded index bound behind.
    condition = random.choice(conditions)
    high = random.randint(10, 30)
    return f"The weather in {city} is {condition} with a high of {high}°C."
async def main() -> None:
    """Demonstrate a SQLite-backed session that persists conversation history to a local file.

    Phase 1 runs two turns through an agent whose history provider writes to
    ``chat_history.sqlite3``. Phase 2 builds a fresh provider and agent on the
    same file and session ID and asks a follow-up question about the earlier
    turns.

    Each provider's connection is closed as soon as its phase ends, and the
    Azure credential (when one was created) is closed on exit — including
    when a run raises.
    """
    db_path = "chat_history.sqlite3"
    session_id = str(uuid.uuid4())

    try:
        # Phase 1: Start a conversation with a SQLite-backed history provider
        print("\n[bold]=== Persistent SQLite Session ===[/bold]")
        print("[bold]--- Phase 1: Starting conversation ---[/bold]")

        sqlite_provider = SQLiteHistoryProvider(db_path=db_path)
        try:
            agent = Agent(
                client=client,
                instructions="You are a helpful weather agent.",
                tools=[get_weather],
                context_providers=[sqlite_provider],
            )
            session = agent.create_session(session_id=session_id)

            print("[blue]User:[/blue] What's the weather like in Tokyo?")
            response = await agent.run("What's the weather like in Tokyo?", session=session)
            print(f"[green]Agent:[/green] {response.text}")

            print("\n[blue]User:[/blue] How about Paris?")
            response = await agent.run("How about Paris?", session=session)
            print(f"[green]Agent:[/green] {response.text}")
        finally:
            # Release the first connection before the simulated restart.
            sqlite_provider.close()

        # Phase 2: Simulate an application restart — reconnect to the same session ID in SQLite
        print("\n[bold]--- Phase 2: Resuming after 'restart' ---[/bold]")

        sqlite_provider2 = SQLiteHistoryProvider(db_path=db_path)
        try:
            agent2 = Agent(
                client=client,
                instructions="You are a helpful weather agent.",
                tools=[get_weather],
                context_providers=[sqlite_provider2],
            )
            session2 = agent2.create_session(session_id=session_id)

            print("[blue]User:[/blue] Which of the cities I asked about had better weather?")
            response = await agent2.run("Which of the cities I asked about had better weather?", session=session2)
            print(f"[green]Agent:[/green] {response.text}")
        finally:
            sqlite_provider2.close()
    finally:
        # Only set when API_HOST selected the Azure client.
        if async_credential:
            await async_credential.close()
# Script entry point: run the async demo only when executed directly.
if __name__ == "__main__":
    asyncio.run(main())