-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathworkflow_rag_ingest.py
More file actions
146 lines (113 loc) · 5.27 KB
/
workflow_rag_ingest.py
File metadata and controls
146 lines (113 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""RAG ingestion pipeline using plain Python executors.

Demonstrates: Executor subclasses, @handler, typed WorkflowContext, and
WorkflowBuilder with explicit edges — no AI agents involved.

Pipeline:
    Extract → Chunk → Embed

Run:
    uv run examples/workflow_rag_ingest.py
    uv run examples/workflow_rag_ingest.py --devui  (opens DevUI at http://localhost:8090)

In the DevUI, enter a filename relative to the examples/ folder, e.g.: sample_document.pdf
"""
import asyncio
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from agent_framework import Executor, WorkflowBuilder, WorkflowContext, handler
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from markitdown import MarkItDown
from openai import OpenAI
from rich.logging import RichHandler
from typing_extensions import Never
# Send every log record through Rich; the handler is configured to hide the
# source path and the level column, and to render tracebacks with Rich.
log_handler = RichHandler(
    show_path=False,
    rich_tracebacks=True,
    show_level=False,
)

# The root logger is set to WARNING; force=True replaces any handlers that
# were installed on the root logger before this module was imported.
logging.basicConfig(
    level=logging.WARNING,
    handlers=[log_handler],
    force=True,
    format="%(message)s",
)

# This module's own logger is lowered to INFO so its progress messages show.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Load .env; override=True lets values from the file replace variables that
# are already set in the process environment.
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")
EMBEDDING_DIMENSIONS = 256  # Smaller dimension for efficiency

# Model name used by the GitHub Models and OpenAI hosts; the Azure branch
# replaces it with a deployment name when one is configured.
embed_model = "text-embedding-3-small"

# Configure the embedding client based on the API host
if API_HOST == "azure":
    sync_credential = DefaultAzureCredential()
    sync_token_provider = get_bearer_token_provider(
        sync_credential,
        "https://cognitiveservices.azure.com/.default",
    )
    # NOTE(review): the bearer token is requested once, here at import time,
    # and handed to the client as a fixed api_key — a long-lived process
    # (e.g. the DevUI server) may outlive it; confirm whether refresh is needed.
    embed_client = OpenAI(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=sync_token_provider(),
    )
    embed_model = os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", embed_model)
elif API_HOST == "github":
    embed_client = OpenAI(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
    )
else:
    # Any other API_HOST value falls back to the default OpenAI endpoint.
    embed_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
@dataclass
class EmbeddedChunk:
    """One piece of source text together with the embedding computed for it.

    Attributes:
        text: The chunk's text.
        vector: The embedding values; a fresh empty list per instance when
            none is supplied.
    """

    text: str
    vector: list[float] = field(default_factory=list)
class ExtractExecutor(Executor):
    """Convert a local file to plain markdown text."""

    @handler
    async def extract(self, path: str, ctx: WorkflowContext[str]) -> None:
        """Convert the file at the given path to markdown.

        Accepts an absolute path or a filename relative to the folder that
        contains this script (e.g. ``sample_document.pdf``). Surrounding
        whitespace and quotes are stripped automatically.

        WorkflowContext[str] means this node sends a str to the next node.

        Args:
            path: Absolute path, or a path relative to this script's folder.
            ctx: Workflow context used to forward the extracted markdown.
        """
        # Input typed or pasted into a UI may carry stray whitespace or quotes.
        cleaned = path.strip().strip("'\"")
        candidate = Path(cleaned)

        if candidate.is_absolute():
            resolved = candidate
        else:
            # Relative names are resolved next to this script, which is the
            # same folder main() reads sample_document.pdf from.
            script_dir = Path(__file__).parent
            resolved = script_dir / candidate
            # Earlier versions resolved against the parent of the script
            # folder; keep honouring that location when the file exists only
            # there so existing inputs continue to work.
            legacy = script_dir.parent / candidate
            if not resolved.exists() and legacy.exists():
                resolved = legacy

        result = MarkItDown().convert(str(resolved))
        await ctx.send_message(result.text_content)
class ChunkExecutor(Executor):
    """Break markdown into paragraph chunks and drop the unsubstantial ones."""

    @handler
    async def chunk(self, markdown: str, ctx: WorkflowContext[list[str]]) -> None:
        """Split the text on blank lines and keep only body paragraphs.

        A paragraph is kept when, after trimming whitespace, it is at least
        80 characters long and does not begin with ``#`` (a markdown heading).

        WorkflowContext[list[str]] means this node sends a list[str] downstream.
        """
        chunks = []
        for block in markdown.split("\n\n"):
            candidate = block.strip()
            # Skip short fragments and anything that starts like a heading.
            if len(candidate) < 80 or candidate.startswith("#"):
                continue
            chunks.append(candidate)

        logger.info(f"→ {len(chunks)} chunks extracted")
        await ctx.send_message(chunks)
class EmbedExecutor(Executor):
    """Embed each chunk with the configured OpenAI embedding model."""

    @handler
    async def embed(self, chunks: list[str], ctx: WorkflowContext[Never, list[EmbeddedChunk]]) -> None:
        """Call the embeddings API for each chunk and yield the results.

        One request is made per chunk, in order, using the module-level
        ``embed_client``, ``embed_model`` and ``EMBEDDING_DIMENSIONS``.

        WorkflowContext[Never, list[EmbeddedChunk]] means this terminal node
        yields workflow output but does not forward messages further.

        Args:
            chunks: Text chunks to embed.
            ctx: Workflow context used to yield the list of embedded chunks.
        """
        embedded: list[EmbeddedChunk] = []
        for chunk in chunks:
            # embed_client is the synchronous OpenAI client, so calling it
            # directly here would block the event loop for the duration of
            # every HTTP request. Run it in a worker thread instead; requests
            # are still issued one at a time, preserving chunk order.
            response = await asyncio.to_thread(
                embed_client.embeddings.create,
                input=chunk,
                model=embed_model,
                dimensions=EMBEDDING_DIMENSIONS,
            )
            embedded.append(EmbeddedChunk(text=chunk, vector=response.data[0].embedding))

        logger.info(f"→ {len(embedded)} chunks embedded ({EMBEDDING_DIMENSIONS}d each)")
        await ctx.yield_output(embedded)
# One instance per pipeline stage; the id names the node within the workflow.
extract = ExtractExecutor(id="extract")
chunk = ChunkExecutor(id="chunk")
embed = EmbedExecutor(id="embed")

# Wire the stages into a linear workflow: Extract → Chunk → Embed
workflow = (
    WorkflowBuilder(start_executor=extract)
    .add_edge(extract, chunk)
    .add_edge(chunk, embed)
    .build()
)
async def main():
    """Run the workflow on the bundled sample PDF and log a preview of each chunk.

    The sample file is looked up next to this script. For every workflow
    output, the number of chunks is logged, followed by one line per chunk
    showing its vector length and the first 80 characters of its text.
    """
    sample_file = Path(__file__).parent / "sample_document.pdf"
    pdf_path = str(sample_file)
    logger.info("Processing: %s", pdf_path)

    run_result = await workflow.run(pdf_path)

    for result in run_result.get_outputs():
        logger.info(f"Embedded {len(result)} chunks:")
        for chunk in result:
            # Flatten newlines so each preview stays on a single log line.
            preview = chunk.text[:80].replace("\n", " ")
            logger.info(f" [{len(chunk.vector)}d] {preview}…")
if __name__ == "__main__":
    if "--devui" not in sys.argv:
        # Default: run the pipeline once from the command line.
        asyncio.run(main())
    else:
        # Imported only on demand so a plain run does not load the DevUI module.
        from agent_framework.devui import serve

        serve(entities=[workflow], port=8090, auto_open=True)