#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Fix a buggy module with cascaded speculation and an LLM.

Cascaded starts with one cheap LLM call and escalates on failure,
feeding accumulated test errors into the next wave. Each attempt
asks the model to fix the code, informed by all prior errors.

    Wave 0: model sees buggy code, no errors → might miss subtle bugs
    Wave 1: model sees code + wave 0's error → more targeted fix
    Wave 2: model sees all accumulated errors → most likely succeeds

Most bugs get fixed on wave 0 (one LLM call). You only pay for
extra calls when the first attempt fails.

Supports any OpenAI-compatible API (OpenAI, Ollama, etc.).
Use --dry-run to test the pattern without API keys.

Usage:
    # With OpenAI
    export OPENAI_API_KEY=sk-...
    python examples/cascaded_fix.py /mnt/workspace

    # With a local model (Ollama)
    export OPENAI_BASE_URL=http://localhost:11434/v1
    export OPENAI_API_KEY=unused
    python examples/cascaded_fix.py /mnt/workspace --model llama3

    # Without API keys (scripted fixes for demo)
    python examples/cascaded_fix.py /mnt/workspace --dry-run
"""

import argparse
import re
import subprocess
import sys
from pathlib import Path
from textwrap import dedent

from branching import Workspace, Cascaded

# ---------------------------------------------------------------------------
# Buggy module + test suite
# ---------------------------------------------------------------------------

BUGGY_CODE = dedent("""\
    def safe_divide(a, b):
        \"\"\"Divide a by b, returning None if b is zero.\"\"\"
        return a / b

    def average(nums):
        \"\"\"Return the mean of nums, or 0.0 for an empty list.\"\"\"
        return sum(nums) / len(nums)

    def clamp(x, lo, hi):
        \"\"\"Clamp x to the range [lo, hi].\"\"\"
        if x < lo:
            return lo
        if x < hi:
            return hi
        return x
""")

TEST_CODE = dedent("""\
    from mathutil import safe_divide, average, clamp

    def test_divide():
        assert safe_divide(10, 2) == 5.0

    def test_divide_zero():
        assert safe_divide(10, 0) is None

    def test_average():
        assert average([1, 2, 3]) == 2.0

    def test_average_empty():
        assert average([]) == 0.0

    def test_clamp_low():
        assert clamp(1, 5, 10) == 5

    def test_clamp_high():
        assert clamp(15, 5, 10) == 10

    def test_clamp_mid():
        assert clamp(7, 5, 10) == 7
""")

# ---------------------------------------------------------------------------
# LLM + dry-run backends
# ---------------------------------------------------------------------------


def _call_llm(code: str, feedback: list[str], model: str) -> str:
    """Ask an LLM to fix the code, informed by prior test errors."""
    import openai

    prompt = f"Fix the bugs in this Python module:\n\n```python\n{code}```\n"
    if feedback:
        prompt += "\nPrevious fix attempts failed with these test errors:\n"
        for i, err in enumerate(feedback):
            prompt += f"\n--- attempt {i} ---\n{err}\n"
    prompt += "\nReturn ONLY the corrected Python code. No markdown fences, no explanation."

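    # openai.OpenAI() reads OPENAI_API_KEY and OPENAI_BASE_URL from the
    # environment, which is what makes the Ollama setup in the docstring work.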
    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    text = resp.choices[0].message.content or ""  # content can be None
    # Strip markdown fences if the model wraps its output anyway.
    m = re.search(r"```(?:python)?\s*\n(.+?)```", text, re.DOTALL)
    return m.group(1) if m else text


def _scripted_fix(code: str, feedback: list[str]) -> str:
    """Dry-run: apply deterministic fixes based on error keywords."""
    code = code.replace(
        "    return a / b",
        "    if b == 0:\n        return None\n    return a / b",
    )
    if any("average" in e or "ZeroDivisionError" in e for e in feedback):
        code = code.replace(
            "    return sum(nums) / len(nums)",
            "    if not nums:\n        return 0.0\n    return sum(nums) / len(nums)",
        )
    if any("clamp" in e for e in feedback):
        code = code.replace("    if x < hi:", "    if x > hi:")
    return code
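
# In dry-run mode the cascade plays out deterministically: wave 0 applies only
# the safe_divide fix and fails on test_average_empty (ZeroDivisionError);
# wave 1 sees that error, also fixes average, then fails on test_clamp_high;
# wave 2 sees both errors and fixes all three bugs.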


# ---------------------------------------------------------------------------
# Cascaded task
# ---------------------------------------------------------------------------


def make_task(model: str, dry_run: bool):
    """Build a Cascaded task that fixes code via LLM (or scripted fallback)."""

    def task(path: Path, feedback: list[str]) -> tuple[bool, str]:
        code = (path / "mathutil.py").read_text()

        if dry_run:
            fixed = _scripted_fix(code, feedback)
        else:
            fixed = _call_llm(code, feedback, model)

        (path / "mathutil.py").write_text(fixed)

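        # Run the tests in this attempt's working copy; pytest -x stops at the
        # first failure, so the feedback passed to the next wave stays short.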
        result = subprocess.run(
            [sys.executable, "-m", "pytest", "-x", "--tb=short", "test_mathutil.py"],
            cwd=path, capture_output=True, text=True, timeout=30,
        )
        if result.returncode == 0:
            return True, ""
        return False, (result.stdout + result.stderr).strip()

    return task


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    parser = argparse.ArgumentParser(
        description="Fix a buggy module with Cascaded speculation + LLM.",
    )
    parser.add_argument("workspace", help="BranchFS/DaxFS workspace path")
    parser.add_argument("--model", default="gpt-4o-mini",
                        help="LLM model name (default: gpt-4o-mini)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Use scripted fixes instead of an LLM (no API keys needed)")
    args = parser.parse_args()

    ws_path = Path(args.workspace)
    ws = Workspace(ws_path)

    # Seed the workspace with the buggy code and its tests.
    (ws_path / "mathutil.py").write_text(BUGGY_CODE)
    (ws_path / "test_mathutil.py").write_text(TEST_CODE)

    mode = "dry-run (scripted fixes)" if args.dry_run else f"LLM ({args.model})"
    print(f"Workspace: {ws}")
    print(f"Mode: {mode}")
    print("Bug: mathutil.py has 3 bugs (divide-by-zero, empty list, wrong comparison)")
    print("Pattern: Cascaded — try cheap, escalate with error feedback\n")

    task = make_task(model=args.model, dry_run=args.dry_run)
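    # fan_out=(1, 1, 1): three sequential waves of one attempt each, matching
    # the wave diagram in the module docstring. Later waves see earlier errors.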
    outcome = Cascaded(task, fan_out=(1, 1, 1), timeout=120)(ws)

    for r in outcome.all_results:
        status = "pass" if r.success else "FAIL"
        committed = " (committed)" if r.success and outcome.committed else ""
        print(f"  attempt {r.branch_index}: {status}{committed}")

    if outcome.committed:
        print(f"\nFixed on attempt {outcome.winner.branch_index}")
    else:
        print("\nAll attempts failed")
        sys.exit(1)


if __name__ == "__main__":
    main()