-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathagent_knowledge_pg.py
More file actions
381 lines (327 loc) Β· 14.7 KB
/
agent_knowledge_pg.py
File metadata and controls
381 lines (327 loc) Β· 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""
Knowledge retrieval (RAG) via PostgreSQL with hybrid search (vector + full-text).
Diagram:
Input ───▶ Agent ───────────────▶ LLM ───▶ Response
             │                     ▲
             │ search with input   │ relevant knowledge
             ▼                     │
      ┌──────────────┐             │
      │  Knowledge   │─────────────┘
      │    store     │
      │  (Postgres)  │
      └──────────────┘
This example uses pgvector for vector similarity search and PostgreSQL's
built-in tsvector for full-text search, combining them with Reciprocal
Rank Fusion (RRF) for hybrid retrieval. The agent searches the knowledge
store *before* asking the LLM — no tool call needed.
Requires:
- PostgreSQL with pgvector extension (see docker-compose.yml)
- An embedding model (GitHub Models, Azure OpenAI, or OpenAI)
See also: agent_knowledge_sqlite.py for a simpler SQLite-only (keyword search) version.
"""
import asyncio
import logging
import os
import sys
from typing import Any
import psycopg
from openai import OpenAI
from pgvector.psycopg import register_vector
from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext, SupportsAgentRun
from agent_framework.openai import OpenAIChatClient
from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential
from azure.identity import get_bearer_token_provider as sync_get_bearer_token_provider
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from rich import print
from rich.logging import RichHandler
# ── Logging ──────────────────────────────────────────────────────────
# Root logger stays at WARNING so third-party libraries are quiet; this
# module's own logger is raised to INFO for the "[Knowledge]" progress lines.
handler = RichHandler(show_path=False, rich_tracebacks=True, show_level=False)
logging.basicConfig(level=logging.WARNING, handlers=[handler], force=True, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# ── OpenAI clients (chat + embeddings) ───────────────────────────────
# Selects chat + embedding clients based on API_HOST: "azure", "github",
# or anything else (plain OpenAI). Each branch defines `chat_client`,
# `embed_client`, and `embed_model`.
load_dotenv(override=True)

API_HOST = os.getenv("API_HOST", "github")
# Default URL matches the credentials in the accompanying docker-compose.yml.
POSTGRES_URL = os.getenv("POSTGRES_URL", "postgresql://admin:LocalPasswordOnly@db:5432/postgres")
EMBEDDING_DIMENSIONS = 256  # Smaller dimension for efficiency

# Kept at module level so main() can close it on shutdown (Azure only).
async_credential = None
if API_HOST == "azure":
    # Async credential for the agent framework chat client
    async_credential = DefaultAzureCredential()
    async_token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    # Sync credential for the OpenAI SDK embed client
    sync_credential = SyncDefaultAzureCredential()
    sync_token_provider = sync_get_bearer_token_provider(sync_credential, "https://cognitiveservices.azure.com/.default")
    chat_client = OpenAIChatClient(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=async_token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
    # NOTE(review): the sync token provider is called once here, so the embed
    # client holds a token snapshot rather than a refreshing credential.
    embed_client = OpenAI(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=sync_token_provider(),
    )
    embed_model = os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-small")
elif API_HOST == "github":
    # GitHub Models inference endpoint; one PAT serves both chat and embeddings.
    chat_client = OpenAIChatClient(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
        model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4.1-mini"),
    )
    embed_client = OpenAI(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
    )
    embed_model = "text-embedding-3-small"
else:
    # Fallback: the public OpenAI API.
    chat_client = OpenAIChatClient(
        api_key=os.environ["OPENAI_API_KEY"], model_id=os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")
    )
    embed_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    embed_model = "text-embedding-3-small"
def get_embedding(text: str) -> list[float]:
    """Embed *text* and return its vector of EMBEDDING_DIMENSIONS floats."""
    result = embed_client.embeddings.create(
        input=text,
        model=embed_model,
        dimensions=EMBEDDING_DIMENSIONS,
    )
    return result.data[0].embedding
# ── Knowledge store (PostgreSQL + pgvector) ──────────────────────────
# Demo catalog seeded into the `products` table. Each entry has a name,
# category, price (float), and free-text description; name + description
# feed both the embedding and the full-text index.
PRODUCTS = [
    {
        "name": "TrailBlaze Hiking Boots",
        "category": "Footwear",
        "price": 149.99,
        "description": (
            "Waterproof hiking boots with Vibram soles, ankle support, "
            "and breathable Gore-Tex lining. Ideal for rocky trails and wet conditions."
        ),
    },
    {
        "name": "SummitPack 40L Backpack",
        "category": "Bags",
        "price": 89.95,
        "description": (
            "Lightweight 40-liter backpack with hydration sleeve, rain cover, "
            "and ergonomic hip belt. Great for day hikes and overnight trips."
        ),
    },
    {
        "name": "ArcticShield Down Jacket",
        "category": "Clothing",
        "price": 199.00,
        "description": (
            "800-fill goose down jacket rated to -20Β°F. "
            "Features a water-resistant shell, packable design, and adjustable hood."
        ),
    },
    {
        "name": "RiverRun Kayak Paddle",
        "category": "Water Sports",
        "price": 74.50,
        "description": (
            "Fiberglass kayak paddle with adjustable ferrule and drip rings. "
            "Lightweight at 28 oz, suitable for touring and recreational kayaking."
        ),
    },
    {
        "name": "TerraFirm Trekking Poles",
        "category": "Accessories",
        "price": 59.99,
        "description": (
            "Collapsible carbon-fiber trekking poles with cork grips and tungsten tips. "
            "Adjustable from 24 to 54 inches, with anti-shock springs."
        ),
    },
    {
        "name": "ClearView Binoculars 10x42",
        "category": "Optics",
        "price": 129.00,
        "description": (
            "Roof-prism binoculars with 10x magnification and 42mm objective lenses. "
            "Nitrogen-purged and waterproof. Ideal for birding and wildlife observation."
        ),
    },
    {
        "name": "NightGlow LED Headlamp",
        "category": "Lighting",
        "price": 34.99,
        "description": (
            "Rechargeable 350-lumen headlamp with red-light mode and adjustable beam. "
            "IPX6 waterproof rating, runs up to 40 hours on low."
        ),
    },
    {
        "name": "CozyNest Sleeping Bag",
        "category": "Camping",
        "price": 109.00,
        "description": (
            "Three-season mummy sleeping bag rated to 20Β°F. "
            "Synthetic insulation, compression sack included. Weighs 2.5 lbs."
        ),
    },
]
def create_knowledge_db(conn: psycopg.Connection) -> None:
    """Create the product catalog in PostgreSQL with pgvector and full-text search indexes."""
    # Ensure pgvector exists and register the `vector` type on this connection
    # before the table that uses it is created.
    conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
    register_vector(conn)

    # Drop and recreate so repeated runs always start from a clean table.
    conn.execute("DROP TABLE IF EXISTS products")
    conn.execute(
        f"""
        CREATE TABLE products (
            id SERIAL PRIMARY KEY,
            name TEXT NOT NULL,
            category TEXT NOT NULL,
            price REAL NOT NULL,
            description TEXT NOT NULL,
            embedding vector({EMBEDDING_DIMENSIONS})
        )
        """
    )
    # GIN index for full-text search on name + description
    conn.execute(
        "CREATE INDEX ON products USING GIN (to_tsvector('english', name || ' ' || description))"
    )

    logger.info("[π Knowledge] Generating embeddings for %d products...", len(PRODUCTS))
    for item in PRODUCTS:
        # Embed a single string combining name, category and description so
        # semantic search can match on any of them.
        embed_text = f"{item['name']} - {item['category']}: {item['description']}"
        vector = get_embedding(embed_text)
        conn.execute(
            "INSERT INTO products (name, category, price, description, embedding) VALUES (%s, %s, %s, %s, %s)",
            (item["name"], item["category"], item["price"], item["description"], vector),
        )
    conn.commit()
    logger.info("[π Knowledge] Product catalog seeded with embeddings.")
# ── Custom context provider for hybrid knowledge retrieval ───────────
# Hybrid search SQL using Reciprocal Rank Fusion (RRF):
# score = 1/(k + semantic_rank) + 1/(k + keyword_rank), with k supplied as a
# parameter (the demo passes the conventional k=60). A FULL OUTER JOIN keeps
# rows matched by only one of the two methods.
# Fix: the keyword_search CTE previously had no explicit `AS rank` alias and
# relied on PostgreSQL implicitly naming the RANK() output column "rank";
# the alias is now explicit, matching the semantic_search CTE.
HYBRID_SEARCH_SQL = f"""
WITH semantic_search AS (
    SELECT id, RANK() OVER (ORDER BY embedding <=> %(embedding)s::vector({EMBEDDING_DIMENSIONS})) AS rank
    FROM products
    ORDER BY embedding <=> %(embedding)s::vector({EMBEDDING_DIMENSIONS})
    LIMIT 20
),
keyword_search AS (
    SELECT id, RANK() OVER (ORDER BY ts_rank_cd(to_tsvector('english', name || ' ' || description), query) DESC) AS rank
    FROM products, plainto_tsquery('english', %(query)s) query
    WHERE to_tsvector('english', name || ' ' || description) @@ query
    ORDER BY ts_rank_cd(to_tsvector('english', name || ' ' || description), query) DESC
    LIMIT 20
)
SELECT
    COALESCE(semantic_search.id, keyword_search.id) AS id,
    COALESCE(1.0 / (%(k)s + semantic_search.rank), 0.0) +
    COALESCE(1.0 / (%(k)s + keyword_search.rank), 0.0) AS score
FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
ORDER BY score DESC
LIMIT %(limit)s
"""
class PostgresKnowledgeProvider(BaseContextProvider):
    """Retrieves relevant product knowledge via hybrid search (vector + full-text) with RRF.

    Uses pgvector for semantic similarity and PostgreSQL tsvector for keyword
    matching, combining results with Reciprocal Rank Fusion (RRF). This gives
    better retrieval than either method alone.
    """

    def __init__(self, conn: psycopg.Connection, max_results: int = 3):
        """Store the open connection (caller owns its lifecycle) and result cap."""
        super().__init__(source_id="postgres-knowledge")
        self.conn = conn
        self.max_results = max_results

    def _search(self, query: str) -> list[dict]:
        """Run hybrid search (vector + full-text) and return matching products.

        Returns up to ``max_results`` dicts with name/category/price/description,
        ordered by descending RRF score.
        """
        query_embedding = get_embedding(query)
        cursor = self.conn.execute(
            HYBRID_SEARCH_SQL,
            # k=60 is the conventional RRF damping constant.
            {"embedding": query_embedding, "query": query, "k": 60, "limit": self.max_results},
        )
        result_ids = [row[0] for row in cursor.fetchall()]
        if not result_ids:
            return []
        # Fetch all matched products in a single round trip (previously one
        # SELECT per id — an N+1 pattern), then restore the RRF ranking order,
        # since ANY() does not preserve the order of result_ids.
        rows = self.conn.execute(
            "SELECT id, name, category, price, description FROM products WHERE id = ANY(%s)",
            (result_ids,),
        ).fetchall()
        by_id = {row[0]: row for row in rows}
        return [
            {"name": row[1], "category": row[2], "price": row[3], "description": row[4]}
            for product_id in result_ids
            if (row := by_id.get(product_id)) is not None
        ]

    async def before_run(
        self,
        *,
        agent: SupportsAgentRun,
        session: AgentSession,
        context: SessionContext,
        state: dict[str, Any],
    ) -> None:
        """Search the knowledge base with the user's latest message and inject results."""
        # Most recent non-empty user message wins; nothing to do without one.
        user_text = next(
            (msg.text for msg in reversed(context.input_messages) if msg.role == "user" and msg.text),
            "",
        )
        if not user_text:
            return
        results = self._search(user_text)
        if not results:
            logger.info("[π Knowledge] No matching products found for: %s", user_text)
            return
        logger.info("[π Knowledge] Found %d matching product(s) for: %s", len(results), user_text)
        knowledge_lines = ["Here is relevant product information from our catalog:\n"]
        for product in results:
            knowledge_lines.append(
                f"- **{product['name']}** ({product['category']}, ${product['price']:.2f}): "
                f"{product['description']}"
            )
        knowledge_text = "\n".join(knowledge_lines)
        # Inject as a system message so the LLM treats it as grounding context.
        context.extend_messages(
            self.source_id,
            [Message(role="system", text=knowledge_text)],
        )
# ── Setup ────────────────────────────────────────────────────────────
def setup_db() -> psycopg.Connection:
    """Open a PostgreSQL connection and seed it with the product catalog.

    Returns the live connection; the caller is responsible for closing it.
    """
    connection = psycopg.connect(POSTGRES_URL)
    create_knowledge_db(connection)
    return connection
# Connect + seed at import time so both the demo and --devui mode share one
# ready-to-use agent. NOTE(review): module import performs DB and embedding
# I/O as a side effect — acceptable for a demo script.
conn = setup_db()
knowledge_provider = PostgresKnowledgeProvider(conn=conn)

agent = Agent(
    client=chat_client,
    # Instructions pin the model to the injected catalog context only.
    instructions=(
        "You are a helpful outdoor-gear shopping assistant for the store 'TrailBuddy'. "
        "Answer customer questions using ONLY the product information provided in the context. "
        "If no relevant products are found in the context, say you don't have information "
        "about that item. Include prices when recommending products."
    ),
    context_providers=[knowledge_provider],
)
async def main() -> None:
    """Demonstrate hybrid search RAG with several queries.

    Runs a fixed list of demo queries through the agent (the four stanzas were
    previously copy-pasted; a loop produces identical output), then closes the
    DB connection and, on Azure, the async credential.
    """
    print("\n[bold]=== Knowledge Retrieval (RAG) with PostgreSQL Hybrid Search ===[/bold]")
    print("[dim]The agent uses pgvector (semantic) + tsvector (keyword) with RRF before each LLM call.[/dim]\n")
    queries = [
        # Should match hiking boots and trekking poles
        "I'm planning a hiking trip. What boots and poles do you recommend?",
        # Should match the down jacket
        "I need something warm for winter camping, maybe a jacket?",
        # Should match the kayak paddle (semantic match for "water sports gear")
        "What water sports gear do you carry?",
        # Semantic match: "gadgets for wildlife watching" -> binoculars
        "I want gadgets for wildlife watching",
    ]
    for query in queries:
        print(f"[blue]User:[/blue] {query}")
        response = await agent.run(query)
        print(f"[green]Agent:[/green] {response.text}\n")
    conn.close()
    if async_credential:
        await async_credential.close()
if __name__ == "__main__":
    # `--devui` serves the agent in the framework's local dev UI instead of
    # running the scripted demo conversation.
    if "--devui" in sys.argv:
        from agent_framework.devui import serve

        serve(entities=[agent], auto_open=True)
    else:
        asyncio.run(main())