Commit 2862122

Replace race_models example with cascaded_fix LLM demo
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 0111ad4

2 files changed: 207 additions & 347 deletions

examples/cascaded_fix.py: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Fix a buggy module with cascaded speculation and an LLM.

Cascaded starts with one cheap LLM call and escalates on failure,
feeding accumulated test errors into the next wave. Each attempt
asks the model to fix the code, informed by all prior errors.

Wave 0: model sees buggy code, no errors → might miss subtle bugs
Wave 1: model sees code + wave 0's error → more targeted fix
Wave 2: model sees all accumulated errors → most likely succeeds

Most bugs get fixed on wave 0 (one LLM call). You only pay for
extra calls when the first attempt fails.

Supports any OpenAI-compatible API (OpenAI, Ollama, etc.).
Use --dry-run to test the pattern without API keys.

Usage:
    # With OpenAI
    export OPENAI_API_KEY=sk-...
    python examples/cascaded_fix.py /mnt/workspace

    # With a local model (Ollama)
    export OPENAI_BASE_URL=http://localhost:11434/v1
    export OPENAI_API_KEY=unused
    python examples/cascaded_fix.py /mnt/workspace --model llama3

    # Without API keys (scripted fixes for demo)
    python examples/cascaded_fix.py /mnt/workspace --dry-run
"""

import argparse
import re
import subprocess
import sys
from pathlib import Path
from textwrap import dedent

from branching import Workspace, Cascaded
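# NOTE: semantics assumed from this example's docstring: Workspace wraps a
# snapshot-capable directory, and Cascaded reruns `task` on branch copies in
# escalating waves, committing the first branch whose task succeeds.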

# ---------------------------------------------------------------------------
# Buggy module + test suite
# ---------------------------------------------------------------------------

BUGGY_CODE = dedent("""\
    def safe_divide(a, b):
        \"\"\"Divide a by b, returning None if b is zero.\"\"\"
        return a / b

    def average(nums):
        \"\"\"Return the mean of nums, or 0.0 for an empty list.\"\"\"
        return sum(nums) / len(nums)

    def clamp(x, lo, hi):
        \"\"\"Clamp x to the range [lo, hi].\"\"\"
        if x < lo:
            return lo
        if x < hi:
            return hi
        return x
    """)

TEST_CODE = dedent("""\
    from mathutil import safe_divide, average, clamp

    def test_divide():
        assert safe_divide(10, 2) == 5.0

    def test_divide_zero():
        assert safe_divide(10, 0) is None

    def test_average():
        assert average([1, 2, 3]) == 2.0

    def test_average_empty():
        assert average([]) == 0.0

    def test_clamp_low():
        assert clamp(1, 5, 10) == 5

    def test_clamp_high():
        assert clamp(15, 5, 10) == 10

    def test_clamp_mid():
        assert clamp(7, 5, 10) == 7
    """)

# ---------------------------------------------------------------------------
# LLM + dry-run backends
# ---------------------------------------------------------------------------


def _call_llm(code: str, feedback: list[str], model: str) -> str:
    """Ask an LLM to fix the code, informed by prior test errors."""
    import openai

    prompt = f"Fix the bugs in this Python module:\n\n```python\n{code}```\n"
    if feedback:
        prompt += "\nPrevious fix attempts failed with these test errors:\n"
        for i, err in enumerate(feedback):
            prompt += f"\n--- attempt {i} ---\n{err}\n"
    prompt += "\nReturn ONLY the corrected Python code. No markdown fences, no explanation."

    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    text = resp.choices[0].message.content
    # Strip markdown fences if the model wraps its output anyway.
    m = re.search(r"```(?:python)?\s*\n(.+?)```", text, re.DOTALL)
    return m.group(1) if m else text


def _scripted_fix(code: str, feedback: list[str]) -> str:
    """Dry-run: apply deterministic fixes based on error keywords."""
    code = code.replace(
        "    return a / b",
        "    if b == 0:\n        return None\n    return a / b",
    )
    if any("average" in e or "ZeroDivisionError" in e for e in feedback):
        code = code.replace(
            "    return sum(nums) / len(nums)",
            "    if not nums:\n        return 0.0\n    return sum(nums) / len(nums)",
        )
    if any("clamp" in e for e in feedback):
        code = code.replace("    if x < hi:", "    if x > hi:")
    return code
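
# The divide fix is unconditional (wave 0 has no feedback to key on); the
# average and clamp fixes unlock only once an earlier wave's test error names
# them, mirroring how the LLM path sharpens as errors accumulate.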


# ---------------------------------------------------------------------------
# Cascaded task
# ---------------------------------------------------------------------------


def make_task(model: str, dry_run: bool):
    """Build a Cascaded task that fixes code via LLM (or scripted fallback)."""

    def task(path: Path, feedback: list[str]) -> tuple[bool, str]:
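        # Contract assumed from the pattern above: `path` is this branch's
        # private copy of the workspace, `feedback` holds the test errors from
        # every prior wave, and we return (success, error text for next wave).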
        code = (path / "mathutil.py").read_text()

        if dry_run:
            fixed = _scripted_fix(code, feedback)
        else:
            fixed = _call_llm(code, feedback, model)

        (path / "mathutil.py").write_text(fixed)

        result = subprocess.run(
            [sys.executable, "-m", "pytest", "-x", "--tb=short", "test_mathutil.py"],
            cwd=path, capture_output=True, text=True, timeout=30,
        )
        if result.returncode == 0:
            return True, ""
        return False, (result.stdout + result.stderr).strip()

    return task


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    parser = argparse.ArgumentParser(
        description="Fix a buggy module with Cascaded speculation + LLM.",
    )
    parser.add_argument("workspace", help="BranchFS/DaxFS workspace path")
    parser.add_argument("--model", default="gpt-4o-mini",
                        help="LLM model name (default: gpt-4o-mini)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Use scripted fixes instead of LLM (no API keys needed)")
    args = parser.parse_args()

    ws_path = Path(args.workspace)
    ws = Workspace(ws_path)

    # Seed workspace with buggy code and tests.
    (ws_path / "mathutil.py").write_text(BUGGY_CODE)
    (ws_path / "test_mathutil.py").write_text(TEST_CODE)

    mode = "dry-run (scripted fixes)" if args.dry_run else f"LLM ({args.model})"
    print(f"Workspace: {ws}")
    print(f"Mode: {mode}")
    print("Bug: mathutil.py has 3 bugs (divide-by-zero, empty list, wrong comparison)")
    print("Pattern: Cascaded — try cheap, escalate with error feedback\n")

    task = make_task(model=args.model, dry_run=args.dry_run)
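    # fan_out=(1, 1, 1): three waves of one branch each, so at most three
    # LLM calls; a success in any wave commits and stops the escalation.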
    outcome = Cascaded(task, fan_out=(1, 1, 1), timeout=120)(ws)

    for r in outcome.all_results:
        status = "pass" if r.success else "FAIL"
        committed = " (committed)" if r.success and outcome.committed else ""
        print(f"  attempt {r.branch_index}: {status}{committed}")

    if outcome.committed:
        print(f"\nFixed on attempt {outcome.winner.branch_index}")
    else:
        print("\nAll attempts failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
