-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathworkflow_aggregator_ranked.py
More file actions
177 lines (143 loc) · 6.31 KB
/
workflow_aggregator_ranked.py
File metadata and controls
177 lines (143 loc) · 6.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Fan-out/fan-in with LLM-as-judge ranking aggregation.

Three creative agents with different personas (bold, minimalist,
emotional) each propose a marketing slogan. A ranker Executor collects
the candidates, formats them, and calls the LLM client directly as a
judge to score and rank them — letting the LLM evaluate creativity,
memorability, and brand fit.

Aggregation technique: LLM-as-judge (generate N candidates, rank the best).

Run:
    uv run examples/workflow_aggregator_ranked.py
    uv run examples/workflow_aggregator_ranked.py --devui  (opens DevUI at http://localhost:8104)
"""
import asyncio
import os
import sys
from agent_framework import Agent, AgentExecutorResponse, Executor, Message, WorkflowBuilder, WorkflowContext, handler
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing_extensions import Never
# Load environment variables from a local .env file, overriding any already set.
load_dotenv(override=True)

# Which backend to talk to: "azure", "github" (default), or anything else → plain OpenAI.
API_HOST = os.getenv("API_HOST", "github")

# Configure the chat client based on the API host
# Kept at module scope so main() can close the Azure credential on shutdown.
async_credential = None
if API_HOST == "azure":
    # Azure OpenAI with Entra ID auth: a bearer-token provider is passed where
    # the client expects an API key.
    # NOTE(review): assumes OpenAIChatClient accepts a callable token provider
    # as api_key — confirm against the agent_framework OpenAI client docs.
    async_credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    client = OpenAIChatClient(
        base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
        api_key=token_provider,
        model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
    )
elif API_HOST == "github":
    # GitHub Models inference endpoint, authenticated with a GitHub token.
    client = OpenAIChatClient(
        base_url="https://models.github.ai/inference",
        api_key=os.environ["GITHUB_TOKEN"],
        model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4.1-mini"),
    )
else:
    # Plain OpenAI API (default base URL).
    client = OpenAIChatClient(
        api_key=os.environ["OPENAI_API_KEY"], model_id=os.environ.get("OPENAI_MODEL", "gpt-5-mini")
    )
class RankedSlogan(BaseModel):
    """A single ranked slogan entry produced by the LLM judge.

    Field descriptions are part of the JSON schema sent to the model via
    structured output, so they guide what the judge fills in.
    """

    rank: int = Field(description="Rank position, 1 = best.")
    agent_name: str = Field(description="Name of the agent that produced the slogan.")
    slogan: str = Field(description="The marketing slogan text.")
    score: int = Field(description="Score from 1 to 10.")
    justification: str = Field(description="One-sentence justification for the score.")
class RankedSlogans(BaseModel):
    """Typed output: a ranked list of slogans.

    Used as the structured response_format for the judge call and as the
    workflow's final output type.
    """

    rankings: list[RankedSlogan] = Field(description="Slogans ranked from best to worst.")
class DispatchPrompt(Executor):
    """Emit the product brief downstream for fan-out broadcast.

    This is the workflow's start executor: it receives the raw prompt string
    and forwards it unchanged so the fan-out edges can deliver a copy to each
    writer agent.
    """

    @handler
    async def dispatch(self, prompt: str, ctx: WorkflowContext[str]) -> None:
        # Send the brief onward; the fan-out edges duplicate it per writer.
        await ctx.send_message(prompt)
class RankerExecutor(Executor):
    """Fan-in aggregator that formats candidate slogans and ranks them via the LLM client directly.

    Collects one AgentExecutorResponse per writer agent, normalizes each
    response into a single clean slogan line, then asks the chat client —
    with ``RankedSlogans`` as the structured response format — to score and
    rank the candidates.
    """

    def __init__(self, *, client: OpenAIChatClient, id: str = "Ranker") -> None:
        """Store the chat client used to run the judge prompt.

        Args:
            client: Chat client used for the ranking call.
            id: Executor identifier within the workflow graph.
        """
        super().__init__(id=id)
        self._client = client

    @staticmethod
    def _clean_slogan(text: str) -> str:
        """Return the first non-empty line of *text*, stripped of surrounding quotes.

        Writers are instructed to reply with only the slogan, but models may
        still wrap it in quotes or add extra lines; keep just the first
        meaningful line. Returns "" if *text* has no non-blank content.
        """
        for line in text.splitlines():
            cleaned = line.strip().strip("\"'").strip()
            if cleaned:
                return cleaned
        return ""

    @handler
    async def run(
        self,
        results: list[AgentExecutorResponse],
        ctx: WorkflowContext[Never, RankedSlogans],
    ) -> None:
        """Collect slogans, format them, and ask the LLM to rank them."""
        # One bullet per candidate, tagged with the producing executor's id so
        # the judge can attribute each slogan in its structured answer.
        lines = [
            f'- [{result.executor_id}]: "{self._clean_slogan(result.agent_response.text)}"'
            for result in results
        ]
        messages = [
            Message(role="system", text=(
                "You are a senior creative director judging marketing slogans. "
                "Given a list of candidate slogans, rank them from best to worst. "
                "For each slogan, give a 1-10 score and a one-sentence justification "
                "evaluating creativity, memorability, clarity, and brand fit."
            )),
            Message(role="user", text="Candidate slogans:\n" + "\n".join(lines)),
        ]
        # response_format drives structured output: value is a RankedSlogans instance.
        response = await self._client.get_response(messages, options={"response_format": RankedSlogans})
        await ctx.yield_output(response.value)
# Start node: broadcasts the product brief to all writer agents.
dispatcher = DispatchPrompt(id="dispatcher")

# Three writers with deliberately different personas so the fan-out produces
# stylistically distinct candidates for the judge to compare.
bold_writer = Agent(
    client=client,
    name="BoldWriter",
    instructions=(
        "You are a bold, dramatic copywriter. "
        "Given the product brief, propose ONE punchy marketing slogan (max 10 words). "
        "Make it attention-grabbing and confident. Reply with ONLY the slogan."
    ),
)
minimalist_writer = Agent(
    client=client,
    name="MinimalistWriter",
    instructions=(
        "You are a minimalist copywriter who values brevity above all. "
        "Given the product brief, propose ONE ultra-short marketing slogan (max 6 words). "
        "Less is more. Reply with ONLY the slogan."
    ),
)
emotional_writer = Agent(
    client=client,
    name="EmotionalWriter",
    instructions=(
        "You are an empathy-driven copywriter. "
        "Given the product brief, propose ONE marketing slogan (max 10 words) "
        "that connects emotionally with the audience. Reply with ONLY the slogan."
    ),
)
# The ranker Executor calls the LLM client directly to handle fan-in —
# it formats the collected slogans and has the LLM rank them.
ranker = RankerExecutor(client=client)
# Wire the graph: the dispatcher fans the brief out to all three writers, and
# the writers fan back in to the ranker, which yields the workflow output.
_writers = [bold_writer, minimalist_writer, emotional_writer]
_builder = WorkflowBuilder(
    name="FanOutFanInRanked",
    description="Generate slogans in parallel, then LLM-judge ranks them.",
    start_executor=dispatcher,
    output_executors=[ranker],
)
_builder = _builder.add_fan_out_edges(dispatcher, _writers)
_builder = _builder.add_fan_in_edges(_writers, ranker)
workflow = _builder.build()
async def main() -> None:
    """Run the slogan pipeline and print the ranked results.

    The Azure credential (when one was created for the "azure" API host) is
    closed in a ``finally`` block so its underlying session is released even
    if the workflow run raises.
    """
    prompt = "Budget-friendly electric bike for urban commuters. Reliable, affordable, green."
    print(f"Product brief: {prompt}\n")
    try:
        events = await workflow.run(prompt)
        for output in events.get_outputs():
            # Each output is a RankedSlogans; entries arrive best-first.
            for entry in output.rankings:
                print(f"#{entry.rank} (score {entry.score}) [{entry.agent_name}]: \"{entry.slogan}\"")
                print(f"   {entry.justification}\n")
    finally:
        # Original code only closed the credential on the success path, which
        # leaked it when workflow.run raised.
        if async_credential:
            await async_credential.close()
if __name__ == "__main__":
    # --devui serves the workflow in the framework's dev UI on port 8104
    # instead of running it once in the terminal.
    if "--devui" in sys.argv:
        from agent_framework.devui import serve

        serve(entities=[workflow], port=8104, auto_open=True)
    else:
        asyncio.run(main())