From dbef52618e440f4cf348292adb05f6d38487f7a9 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 7 May 2026 17:19:17 -0400 Subject: [PATCH 01/12] python tools requirement Signed-off-by: Akihiko Kuroda --- docs/examples/python_plotting_repair.py | 172 ++++++ mellea/stdlib/requirements/__init__.py | 2 + mellea/stdlib/requirements/python_tools.py | 523 ++++++++++++++++ test/stdlib/requirements/test_python_tools.py | 575 ++++++++++++++++++ 4 files changed, 1272 insertions(+) create mode 100644 docs/examples/python_plotting_repair.py create mode 100644 mellea/stdlib/requirements/python_tools.py create mode 100644 test/stdlib/requirements/test_python_tools.py diff --git a/docs/examples/python_plotting_repair.py b/docs/examples/python_plotting_repair.py new file mode 100644 index 000000000..7347f58b6 --- /dev/null +++ b/docs/examples/python_plotting_repair.py @@ -0,0 +1,172 @@ +# pytest: ollama, e2e, qualitative +"""Granite 4.1 repairs the three canonical plotting failures with Python tool. + +This example demonstrates: +1. Creating a PythonToolRequirements bundle for plotting validation +2. Using SOFAI sampling strategy with repair feedback loop +3. Granite 4.1 repairing through: syntax → imports → headless backend → savefig + +Canonical task: "Create a plot of sin(x) for x in 0..2π and save to /tmp/plot.png" + +The model will encounter and repair: +- Attempt 1: Missing matplotlib.use('Agg') (non-headless backend) +- Attempt 2: Missing plt.savefig() call +- Attempt 3: Success with both fixes applied + +The requirements bundle provides actionable failure messages that guide the model +through each repair iteration without explicit instruction. +""" + +import mellea +from mellea.backends import ModelOption +from mellea.backends.tools import MelleaTool +from mellea.stdlib.components import Instruction +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements import PythonToolRequirements +from mellea.stdlib.sampling import SOFAISamplingStrategy +from mellea.stdlib.tools import local_code_interpreter +from mellea.stdlib.tools.interpreter import ExecutionResult + + +def python(code: str) -> ExecutionResult: + """Execute Python code. + + Args: + code: Python code to execute + + Returns: + Execution result containing stdout, stderr, and success status + """ + return local_code_interpreter(code) + + +async def main(): + """Run the canonical plotting repair example.""" + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + + # Initialize session with local backend + m = mellea.start_session() + + # Create requirements bundle for plotting validation + # Allows matplotlib import (no output_path = skip file creation check) + bundle = PythonToolRequirements(allowed_imports=["numpy", "matplotlib", "math"]) + + # Define SOFAI strategy for repair: S1 (fast) up to 3 times, then S2 (slow) + sampling_strategy = SOFAISamplingStrategy( + s1_solver_backend=m.backend, + s2_solver_backend=m.backend, + s2_solver_mode="fresh_start", + loop_budget=3, + feedback_strategy="first_error", + ) + + # Create the plotting task instruction + description = f"""Create a plot of sin(x) for x in 0..2π and save it to {output_path}. + +Requirements: +- Use the python tool to execute your code +- Import numpy and matplotlib +- Generate x values from 0 to 2π +- Plot sin(x) against x +- Save the plot to the specified file path + +Use the python tool with your complete code.""" + instruction = Instruction(description=description) + + # Create a chat context for multi-turn repair + ctx = ChatContext() + + print("=" * 70) + print("Testing Granite 4.1's ability to repair plotting failures") + print("=" * 70) + print(f"Task: Create a plot of sin(x) and save to {output_path}\n") + + try: + # Run the sampling strategy with requirements + result = await sampling_strategy.sample( + action=instruction, + context=ctx, + backend=m.backend, + requirements=bundle.requirements, + tool_calls=True, + model_options={ModelOption.TOOLS: [MelleaTool.from_callable(python)]}, + ) + + print(f"\nResult: {'SUCCESS' if result.success else 'FAILED'}\n") + + if result.success: + print("✓ Granite 4.1 successfully generated and executed plotting code") + print("\nFinal generated code:") + print("-" * 70) + print(result.result.value) + print("-" * 70) + + # Verify output file exists + from pathlib import Path + + if Path(output_path).exists(): # noqa: ASYNC240 + file_size = Path(output_path).stat().st_size # noqa: ASYNC240 + print(f"\n✓ Output file created: {output_path}") + print(f" File size: {file_size} bytes") + else: + print(f"\n✗ Output file not found: {output_path}") + + # Print repair history + print(f"\nRepair iterations: {len(result.sample_validations)}") + for attempt_idx, validations in enumerate(result.sample_validations, 1): + passed = sum(1 for _, val in validations if val.as_bool()) + total = len(validations) + status = "✓" if passed == total else "✗" + print( + f" {status} Attempt {attempt_idx}: {passed}/{total} " + f"requirements passed" + ) + + # Show which requirements failed + for req, val in validations: + if not val.as_bool(): + print(f" - {req.description}") + if val.reason: + reason_preview = val.reason[:100].replace("\n", " ") + print(f" Error: {reason_preview}...") + + else: + print("✗ Failed to generate working plotting code after all attempts\n") + print("Last attempt output:") + print("-" * 70) + print(result.result.value) + print("-" * 70) + + # Print failure history + print(f"\nFailure history ({len(result.sample_validations)} attempts):") + for attempt_idx, validations in enumerate(result.sample_validations, 1): + failed_count = sum(1 for _, val in validations if not val.as_bool()) + if failed_count > 0: + print(f"\n Attempt {attempt_idx}:") + for req, val in validations: + if not val.as_bool(): + print(f" - {req.description}") + if val.reason: + reason_lines = val.reason.split("\n")[:2] + for line in reason_lines: + print(f" {line}") + + except Exception as e: + print(f"✗ Exception during sampling: {e}") + import traceback + + traceback.print_exc() + + print("\n" + "=" * 70) + print("Test completed") + print("=" * 70) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index c0bd7d3c9..f40287553 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,6 +4,7 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq +from .python_tools import PythonToolRequirements from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -19,6 +20,7 @@ "ALoraRequirement", "LLMaJRequirement", "PythonExecutionReq", + "PythonToolRequirements", "Requirement", "ValidationResult", "as_markdown_list", diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py new file mode 100644 index 000000000..4f78fb43a --- /dev/null +++ b/mellea/stdlib/requirements/python_tools.py @@ -0,0 +1,523 @@ +"""Pre-composed requirement bundles for Python tool invocation and execution. + +This module provides bundled requirements for validating Python code generated +via the Python tool, with focus on reactive failure detection and repair: + +- Tool invocation validation (must call Python tool with code argument) +- Syntax validation (code must parse correctly) +- Import validation (code imports must be in allowlist) +- Matplotlib headless backend detection (plt.show() without backend) +- Plot artifact validation (savefig must be called, output files must exist) +- Output limiting (stdout/stderr must not exceed configured limits) + +Failure messages are written as feedback to the model, not to developers. +They state the failure, include relevant code/stderr, and explain the +correction well enough for the model to act on it. + +FAILURE MATRIX — How each requirement catches the canonical plotting failures: + +Scenario: Model generates plotting code with matplotlib + +Attempt 1: No tool call + → MustInvokePythonTool fails + → Repair: "Call the `python` tool with your code" + +Attempt 2: Tool called but no 'code' arg + → PythonToolHasCodeArg fails + → Repair: "The python tool requires a 'code' argument" + +Attempt 3: Code has syntax error + → PythonCodeParses fails + → Repair: "Your code has a syntax error at line X: {error}" + +Attempt 4: Code imports matplotlib (not in allowed_imports) + → PythonImportsAllowed fails + → Repair: "matplotlib is not allowed. Use only: {allowed_list}" + +Attempt 5: Code uses plt.show() without headless backend + → MatplotlibHeadless fails + → Repair: "Add matplotlib.use('Agg') and replace plt.show() with plt.savefig(...)" + +Attempt 6: Code has plt.plot() but no plt.savefig() + → PlotsAreSaved fails + → Repair: "Add plt.savefig('{output_path}') to save the plot" + +Attempt 7: Code runs, but output file not created + → OutputArtifactsExist fails + → Repair: "File '{output_path}' was not created. Check plt.savefig() call" + +Attempt 8: Success + → All requirements pass + → Result: plot file exists and is non-empty +""" + +import ast +from collections.abc import Callable +from pathlib import Path + +from ...core import Context, Requirement, ValidationResult +from .python_reqs import _has_python_code_listing + + +def _extract_code(ctx: Context) -> str | None: + """Extract Python code from either tool calls or markdown blocks. + + Checks tool_calls dict first (for tool calling), then falls back to + markdown code blocks in response text. + + Returns the code string, or None if no code found. + """ + # Try tool_calls first (tool calling format) + output = ctx.last_output() + if output and output.tool_calls and "python" in output.tool_calls: + tool_call = output.tool_calls["python"] + if hasattr(tool_call, "args") and "code" in tool_call.args: + return tool_call.args["code"] + + # Fall back to markdown code blocks in response text + result = _has_python_code_listing(ctx) + if result.as_bool() and result.reason: + return result.reason + return None + + +def _get_unauthorized_imports( + code: str, allowed_imports: list[str] | None +) -> list[str]: + """Return list of imports in code that are not in allowed_imports. + + Args: + code: Python code to analyze + allowed_imports: Allowlist of permitted top-level modules (None = allow all) + + Returns: + List of unauthorized import module names, or empty list if all allowed. + """ + if allowed_imports is None: + return [] + + unauthorized = [] + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name.split(".")[0] + if module_name not in allowed_imports: + unauthorized.append(module_name) + elif isinstance(node, ast.ImportFrom): + if node.module: + module_name = node.module.split(".")[0] + if module_name not in allowed_imports: + unauthorized.append(module_name) + except (SyntaxError, ValueError): + pass + + return list(set(unauthorized)) + + +def _code_parses(code: str) -> tuple[bool, str | None]: + """Check if code parses as valid Python. + + Returns: + (True, None) if code parses + (False, error_message) if syntax error + """ + try: + ast.parse(code) + return True, None + except SyntaxError as e: + error_msg = f"Syntax error at line {e.lineno}: {e.msg}" + if e.text: + error_msg += f"\n {e.text.rstrip()}" + if e.offset: + error_msg += "\n " + " " * (e.offset - 1) + "^" + return False, error_msg + + +def _uses_pyplot_show(code: str) -> bool: + """Check if code calls plt.show() or matplotlib.pyplot.show().""" + # Simple string checks work for most cases + return "plt.show" in code or ".show()" in code + + +def _sets_headless_backend(code: str) -> bool: + """Check if code sets matplotlib to use a headless backend.""" + headless_backends = ("Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg") + for backend in headless_backends: + if ( + f"matplotlib.use('{backend}')" in code + or f'matplotlib.use("{backend}")' in code + ): + return True + return False + + +def _uses_pyplot_plot(code: str) -> bool: + """Check if code calls pyplot plotting functions.""" + plot_functions = ( + "plt.plot", + "plt.bar", + "plt.scatter", + "plt.hist", + "plt.imshow", + "plt.figure", + "plt.subplot", + ".plot(", + ".bar(", + ".scatter(", + ".hist(", + ) + return any(func in code for func in plot_functions) + + +def _calls_savefig(code: str) -> bool: + """Check if code calls plt.savefig() or fig.savefig().""" + return "savefig" in code + + +# region Individual Requirement Validators + + +def _validate_python_tool_invoked(ctx: Context) -> ValidationResult: + """Requirement: Model must invoke the Python tool.""" + output = ctx.last_output() + if output is None or output.tool_calls is None: + return ValidationResult( + result=False, + reason=( + "You did not invoke any tools. To execute Python code, " + "call the `python` tool with your code." + ), + ) + if "python" not in output.tool_calls: + return ValidationResult( + result=False, + reason=( + "You did not call the `python` tool. Call it with your " + "code to execute it." + ), + ) + return ValidationResult(result=True) + + +def _validate_python_tool_has_code_arg(ctx: Context) -> ValidationResult: + """Requirement: Python tool call must include a 'code' argument.""" + output = ctx.last_output() + if output is None or output.tool_calls is None: + return ValidationResult(result=False, reason="No tool calls found") + + if "python" not in output.tool_calls: + return ValidationResult(result=False, reason="Python tool not called") + + python_call = output.tool_calls["python"] + if "code" not in python_call.args: + return ValidationResult( + result=False, + reason="The `python` tool call must include a `code` argument with your Python code.", + ) + + return ValidationResult(result=True) + + +def _make_code_parses_validator() -> Callable[[Context], ValidationResult]: + """Create a validator that checks if extracted code parses.""" + + def validate(ctx: Context) -> ValidationResult: + code = _extract_code(ctx) + if not code: + return ValidationResult( + result=False, + reason=( + "Could not extract Python code from your response. " + "Make sure to include code in ```python ... ``` blocks." + ), + ) + + parses, error = _code_parses(code) + if not parses: + return ValidationResult( + result=False, + reason=f"Your code contains a syntax error. {error}\n\nPlease fix the syntax and try again.", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_imports_allowed_validator( + allowed_imports: list[str] | None, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if code imports are in allowlist.""" + + def validate(ctx: Context) -> ValidationResult: + if allowed_imports is None: + return ValidationResult(result=True) + + code = _extract_code(ctx) + if not code: + return ValidationResult( + result=False, reason="Could not extract Python code" + ) + + unauthorized = _get_unauthorized_imports(code, allowed_imports) + if unauthorized: + allowed_str = ", ".join(sorted(set(allowed_imports))) + return ValidationResult( + result=False, + reason=( + f"Your code imports forbidden modules: " + f"{', '.join(sorted(set(unauthorized)))}.\n" + f"You may only import: {allowed_str}\n" + f"Please rewrite your code without these imports." + ), + ) + + return ValidationResult(result=True) + + return validate + + +def _make_matplotlib_headless_validator() -> Callable[[Context], ValidationResult]: + """Create a validator that checks matplotlib uses headless backend.""" + + def validate(ctx: Context) -> ValidationResult: + code = _extract_code(ctx) + if not code: + return ValidationResult(result=True) + + if _uses_pyplot_show(code) and not _sets_headless_backend(code): + return ValidationResult( + result=False, + reason="Your code calls `plt.show()` but doesn't set a headless backend.\n" + "This will fail in a headless environment (no display).\n\n" + "Fix this by adding to the top of your code:\n" + " import matplotlib\n" + " matplotlib.use('Agg')\n\n" + "Then replace `plt.show()` with `plt.savefig('{output_path}'); plt.close()`", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_plots_saved_validator() -> Callable[[Context], ValidationResult]: + """Create a validator that checks if code saves plots to a file.""" + + def validate(ctx: Context) -> ValidationResult: + code = _extract_code(ctx) + if not code: + return ValidationResult(result=True) + + if _uses_pyplot_plot(code) and not _calls_savefig(code): + return ValidationResult( + result=False, + reason="Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" + "Add this before your plotting code or at the end:\n" + " plt.savefig('{output_path}')\n" + " plt.close()", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_output_artifacts_validator( + output_path: str, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if output file exists post-execution.""" + + def validate(ctx: Context) -> ValidationResult: + path = Path(output_path) + + if not path.exists(): + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' was not created during execution.\n" + f"Make sure your code calls `plt.savefig('{output_path}')` to save the plot.", + ) + + if path.stat().st_size == 0: + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' exists but is empty.\n" + f"Check that your plot code executed correctly.", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_output_limit_validator( + limit_bytes: int, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks stdout/stderr size limits.""" + + def validate(ctx: Context) -> ValidationResult: + output = ctx.last_output() + if output is None: + return ValidationResult(result=True) + + total_output = "" + if hasattr(output, "stdout") and output.stdout: + total_output += output.stdout + if hasattr(output, "stderr") and output.stderr: + total_output += output.stderr + + size = len(total_output.encode("utf-8")) + if size > limit_bytes: + return ValidationResult( + result=False, + reason=f"Your code produced {size} bytes of output, exceeding the limit of {limit_bytes} bytes.\n" + f"Add output limiting (e.g., redirect to /dev/null) or optimize your code.", + ) + + return ValidationResult(result=True) + + return validate + + +# endregion + + +class PythonToolRequirements: + """Pre-composed bundle of requirements for Python code generation via the tool. + + This bundle validates the complete Python code generation flow: tool invocation, + syntax, imports, execution, and output. It's designed to work with repair loops + (SOFAI, MultiTurnStrategy) to iteratively fix common plotting failures. + + Markers: + - **Deterministic** (unit-testable): tool invocation, syntax, imports, headless backend, + savefig presence, file existence, output limits + - **Qualitative** (needs model to evaluate): execution without error (captured via stderr) + + Args: + output_path (str | None): Path where plots should be saved. If specified, enables + output artifact validation. Defaults to None. + allowed_imports (list[str] | None): Allowlist of importable top-level modules. + None (default) allows any import. Set to list like ["numpy", "matplotlib"] + to restrict imports. + output_limit_bytes (int): Maximum bytes of stdout/stderr allowed. Defaults to 50000. + check_output_artifacts (bool): If True, validate that output file exists and is + non-empty after execution. Defaults to True if output_path is specified. + + Attributes: + requirements (list[Requirement]): The composed list of requirements, suitable + for use with sampling strategies. + """ + + def __init__( + self, + output_path: str | None = None, + allowed_imports: list[str] | None = None, + output_limit_bytes: int = 50_000, + check_output_artifacts: bool | None = None, + ): + """Initialize the Python tool requirements bundle.""" + self.output_path = output_path + self.allowed_imports = allowed_imports + self.output_limit_bytes = output_limit_bytes + + # Auto-enable output artifact checking if output_path is specified + if check_output_artifacts is None: + check_output_artifacts = output_path is not None + + self._check_output_artifacts = check_output_artifacts + + self.requirements = self._build_requirements() + + def _build_requirements(self) -> list[Requirement]: + """Build the list of requirements for this bundle.""" + reqs: list[Requirement] = [] + + # Tool invocation requirements (deterministic) + reqs.append( + Requirement( + description="Use the python tool to execute code.", + validation_fn=_validate_python_tool_invoked, + check_only=False, + ) + ) + + reqs.append( + Requirement( + description="The python tool call must include a code argument.", + validation_fn=_validate_python_tool_has_code_arg, + check_only=False, + ) + ) + + # Code quality requirements (deterministic) + reqs.append( + Requirement( + description="The Python code must parse correctly.", + validation_fn=_make_code_parses_validator(), + check_only=False, + ) + ) + + # Import validation (deterministic) + if self.allowed_imports is not None: + reqs.append( + Requirement( + description=f"Imports must be from allowed list: {', '.join(self.allowed_imports)}", + validation_fn=_make_imports_allowed_validator(self.allowed_imports), + check_only=False, + ) + ) + + # Matplotlib-specific requirements (deterministic) + reqs.append( + Requirement( + description=( + "If using pyplot, must set headless backend and use savefig." + ), + validation_fn=_make_matplotlib_headless_validator(), + check_only=False, + ) + ) + + reqs.append( + Requirement( + description="If creating plots, must call savefig to save them.", + validation_fn=_make_plots_saved_validator(), + check_only=False, + ) + ) + + # Output artifact validation (deterministic, post-execution) + if self._check_output_artifacts and self.output_path: + reqs.append( + Requirement( + description=f"Output file must be created at {self.output_path}", + validation_fn=_make_output_artifacts_validator(self.output_path), + check_only=False, + ) + ) + + # Output limiting (deterministic) + reqs.append( + Requirement( + description=f"Output must not exceed {self.output_limit_bytes} bytes.", + validation_fn=_make_output_limit_validator(self.output_limit_bytes), + check_only=False, + ) + ) + + return reqs + + def __repr__(self) -> str: + """Return a developer-readable representation.""" + return ( + f"PythonToolRequirements(" + f"output_path={self.output_path!r}, " + f"allowed_imports={self.allowed_imports!r}, " + f"output_limit_bytes={self.output_limit_bytes}, " + f"requirements={len(self.requirements)} items" + f")" + ) diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py new file mode 100644 index 000000000..6053d2277 --- /dev/null +++ b/test/stdlib/requirements/test_python_tools.py @@ -0,0 +1,575 @@ +"""Tests for Python tool requirements bundle.""" + +import tempfile +from pathlib import Path + +from mellea.core import Context, ModelOutputThunk +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.python_tools import ( + PythonToolRequirements, + _calls_savefig, + _code_parses, + _get_unauthorized_imports, + _sets_headless_backend, + _uses_pyplot_plot, + _uses_pyplot_show, +) + + +def from_model(content: str) -> Context: + """Helper to create context from model output.""" + ctx = ChatContext() + ctx = ctx.add(ModelOutputThunk(value=content)) + return ctx + + +# region: Helper function tests + + +class TestCodeParses: + """Tests for _code_parses helper.""" + + def test_valid_code_parses(self): + """Valid Python code should parse.""" + code = "x = 1\nprint(x)" + parses, error = _code_parses(code) + assert parses is True + assert error is None + + def test_syntax_error_detected(self): + """Syntax errors should be detected.""" + code = "def foo(\n return 42" + parses, error = _code_parses(code) + assert parses is False + assert error is not None + assert "Syntax error" in error + + def test_missing_colon(self): + """Missing colons should be detected.""" + code = "if True\n print('hello')" + parses, error = _code_parses(code) + assert parses is False + assert error is not None + + +class TestUnauthorizedImports: + """Tests for _get_unauthorized_imports helper.""" + + def test_no_imports_allowed_always(self): + """Code with no imports should always pass.""" + code = "x = 1" + unauthorized = _get_unauthorized_imports(code, ["numpy"]) + assert unauthorized == [] + + def test_allowed_imports(self): + """Allowed imports should not be flagged.""" + code = "import numpy\nimport pandas" + unauthorized = _get_unauthorized_imports(code, ["numpy", "pandas"]) + assert unauthorized == [] + + def test_unauthorized_import_detected(self): + """Unauthorized imports should be detected.""" + code = "import subprocess" + unauthorized = _get_unauthorized_imports(code, ["numpy", "pandas"]) + assert "subprocess" in unauthorized + + def test_none_allows_all(self): + """allowed_imports=None should allow any import.""" + code = "import subprocess\nimport socket\nimport os" + unauthorized = _get_unauthorized_imports(code, None) + assert unauthorized == [] + + def test_nested_imports(self): + """Only top-level module checked for nested imports.""" + code = "import numpy.random" + unauthorized = _get_unauthorized_imports(code, ["numpy"]) + assert unauthorized == [] + + def test_from_import(self): + """from ... import should be checked.""" + code = "from matplotlib import pyplot" + unauthorized = _get_unauthorized_imports(code, ["numpy"]) + assert "matplotlib" in unauthorized + + def test_multiple_unauthorized(self): + """Multiple unauthorized imports should be detected.""" + code = "import subprocess\nimport socket" + unauthorized = _get_unauthorized_imports(code, ["numpy"]) + assert len(unauthorized) == 2 + assert "subprocess" in unauthorized + assert "socket" in unauthorized + + +class TestMatplotlibDetection: + """Tests for matplotlib-related detection functions.""" + + def test_plt_show_detected(self): + """plt.show() should be detected.""" + code = "plt.plot([1, 2, 3])\nplt.show()" + assert _uses_pyplot_show(code) is True + + def test_plt_show_not_detected(self): + """Code without plt.show() should pass.""" + code = "plt.plot([1, 2, 3])\nplt.savefig('plot.png')" + assert _uses_pyplot_show(code) is False + + def test_headless_backend_agg(self): + """Agg backend should be detected.""" + code = "import matplotlib\nmatplotlib.use('Agg')" + assert _sets_headless_backend(code) is True + + def test_headless_backend_variations(self): + """Various headless backends should be detected.""" + for backend in ["Agg", "Svg", "Cairo", "PDF", "PS"]: + code = f"matplotlib.use('{backend}')" + assert _sets_headless_backend(code) is True + + def test_non_headless_backend(self): + """Non-headless backends should not be detected.""" + code = "matplotlib.use('TkAgg')" + assert _sets_headless_backend(code) is False + + def test_uses_pyplot_plot(self): + """Plotting functions should be detected.""" + for func in ["plt.plot", "plt.bar", "plt.scatter", ".plot("]: + code = f"import matplotlib.pyplot as plt\n{func}([1, 2, 3])" + assert _uses_pyplot_plot(code) is True + + def test_savefig_detected(self): + """savefig call should be detected.""" + code = "plt.plot([1, 2, 3])\nplt.savefig('plot.png')" + assert _calls_savefig(code) is True + + def test_savefig_not_detected(self): + """Code without savefig should not be detected.""" + code = "plt.plot([1, 2, 3])\nplt.show()" + assert _calls_savefig(code) is False + + +# endregion + + +# region: PythonToolRequirements tests + + +class TestPythonToolRequirementsBasic: + """Basic tests for PythonToolRequirements bundle.""" + + def test_initialization(self): + """Bundle should initialize with default settings.""" + bundle = PythonToolRequirements() + assert bundle.requirements is not None + assert len(bundle.requirements) > 0 + + def test_with_output_path(self): + """Bundle should accept output_path parameter.""" + bundle = PythonToolRequirements(output_path="/tmp/plot.png") + assert bundle.output_path == "/tmp/plot.png" + + def test_with_allowed_imports(self): + """Bundle should accept allowed_imports parameter.""" + allowed = ["numpy", "matplotlib"] + bundle = PythonToolRequirements(allowed_imports=allowed) + assert bundle.allowed_imports == allowed + + def test_output_artifact_checking_enabled_by_default(self): + """Output artifact checking should be enabled if output_path is set.""" + bundle = PythonToolRequirements(output_path="/tmp/plot.png") + assert bundle._check_output_artifacts is True + + def test_output_artifact_checking_disabled_by_default(self): + """Output artifact checking should be disabled if no output_path.""" + bundle = PythonToolRequirements() + assert bundle._check_output_artifacts is False + + def test_repr(self): + """Bundle should have a readable repr.""" + bundle = PythonToolRequirements(output_path="/tmp/plot.png") + repr_str = repr(bundle) + assert "PythonToolRequirements" in repr_str + assert "/tmp/plot.png" in repr_str + + +# endregion + + +# region: Individual requirement validation tests + + +class TestMustInvokePythonTool: + """Tests for MustInvokePythonTool requirement.""" + + def test_tool_not_called(self): + """Should fail if python tool not called.""" + bundle = PythonToolRequirements() + req = bundle.requirements[0] + + ctx = from_model("Here is the code:\n```python\nprint('hello')\n```") + result = req.validation_fn(ctx) + + assert result.as_bool() is False + reason_lower = result.reason.lower() + assert "did not invoke" in reason_lower or "did not call" in reason_lower + + def test_python_tool_called(self): + """Should pass if python tool is called.""" + bundle = PythonToolRequirements() + req = bundle.requirements[0] + + ctx = ChatContext() + call_obj = type("Call", (), {"args": {"code": "print('hi')"}})() + output = ModelOutputThunk( + value="I'll execute this code", tool_calls={"python": call_obj} + ) + ctx = ctx.add(output) + + result = req.validation_fn(ctx) + assert result.as_bool() is True + + +class TestPythonToolHasCodeArg: + """Tests for PythonToolHasCodeArg requirement.""" + + def test_missing_code_argument(self): + """Should fail if python tool call has no code argument.""" + bundle = PythonToolRequirements() + req = bundle.requirements[1] + + ctx = ChatContext() + call_obj = type("Call", (), {"args": {"other": "value"}})() + output = ModelOutputThunk( + value="I'll execute this", tool_calls={"python": call_obj} + ) + ctx = ctx.add(output) + + result = req.validation_fn(ctx) + assert result.as_bool() is False + assert "code" in result.reason.lower() + + def test_has_code_argument(self): + """Should pass if python tool call has code argument.""" + bundle = PythonToolRequirements() + req = bundle.requirements[1] + + ctx = ChatContext() + call_obj = type("Call", (), {"args": {"code": "print('hi')"}})() + output = ModelOutputThunk( + value="I'll execute this", tool_calls={"python": call_obj} + ) + ctx = ctx.add(output) + + result = req.validation_fn(ctx) + assert result.as_bool() is True + + +class TestCodeParsesRequirement: + """Tests for code parsing requirement.""" + + def test_valid_code(self): + """Valid code should pass.""" + bundle = PythonToolRequirements() + parse_reqs = [ + r for r in bundle.requirements if "parse" in r.description.lower() + ] + parse_req = parse_reqs[0] + + ctx = from_model("```python\nx = 1\nprint(x)\n```") + result = parse_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_syntax_error(self): + """Syntax errors should be caught.""" + bundle = PythonToolRequirements() + parse_reqs = [ + r for r in bundle.requirements if "parse" in r.description.lower() + ] + parse_req = parse_reqs[0] + + ctx = from_model("```python\ndef foo(\n return 42\n```") + result = parse_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "syntax" in result.reason.lower() + + def test_valid_code_from_tool_calls(self): + """Valid code in tool_calls should parse.""" + bundle = PythonToolRequirements() + parse_reqs = [ + r for r in bundle.requirements if "parse" in r.description.lower() + ] + parse_req = parse_reqs[0] + + # Create context with tool_calls instead of markdown + ctx = ChatContext() + tool_call = type("Call", (), {"args": {"code": "x = 1\nprint(x)"}})() + output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + ctx = ctx.add(output) + + result = parse_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_syntax_error_from_tool_calls(self): + """Syntax errors in tool_calls should be caught.""" + bundle = PythonToolRequirements() + parse_reqs = [ + r for r in bundle.requirements if "parse" in r.description.lower() + ] + parse_req = parse_reqs[0] + + # Create context with tool_calls containing syntax error + ctx = ChatContext() + tool_call = type("Call", (), {"args": {"code": "def foo(\n return 42"}})() + output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + ctx = ctx.add(output) + + result = parse_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "syntax" in result.reason.lower() + + +class TestImportAllowlistRequirement: + """Tests for import allowlist requirement.""" + + def test_allowed_imports(self): + """Allowed imports should pass.""" + allowed = ["numpy", "matplotlib"] + bundle = PythonToolRequirements(allowed_imports=allowed) + + import_reqs = [ + r for r in bundle.requirements if "import" in r.description.lower() + ] + assert len(import_reqs) > 0 + import_req = import_reqs[0] + + ctx = from_model( + "```python\nimport numpy\nimport matplotlib.pyplot as plt\n```" + ) + result = import_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_unauthorized_imports(self): + """Unauthorized imports should fail.""" + allowed = ["numpy"] + bundle = PythonToolRequirements(allowed_imports=allowed) + + import_reqs = [ + r for r in bundle.requirements if "import" in r.description.lower() + ] + import_req = import_reqs[0] + + ctx = from_model("```python\nimport subprocess\n```") + result = import_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "subprocess" in result.reason + + def test_allowed_imports_from_tool_calls(self): + """Allowed imports in tool_calls should pass.""" + allowed = ["numpy", "matplotlib"] + bundle = PythonToolRequirements(allowed_imports=allowed) + + import_reqs = [ + r for r in bundle.requirements if "import" in r.description.lower() + ] + import_req = import_reqs[0] + + # Create context with tool_calls + ctx = ChatContext() + code = "import numpy\nimport matplotlib.pyplot as plt\nprint(numpy.pi)" + tool_call = type("Call", (), {"args": {"code": code}})() + output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + ctx = ctx.add(output) + + result = import_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_unauthorized_imports_from_tool_calls(self): + """Unauthorized imports in tool_calls should fail.""" + allowed = ["numpy"] + bundle = PythonToolRequirements(allowed_imports=allowed) + + import_reqs = [ + r for r in bundle.requirements if "import" in r.description.lower() + ] + import_req = import_reqs[0] + + # Create context with tool_calls + ctx = ChatContext() + tool_call = type("Call", (), {"args": {"code": "import subprocess\n"}})() + output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + ctx = ctx.add(output) + + result = import_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "subprocess" in result.reason + + +class TestMatplotlibHeadlessRequirement: + """Tests for matplotlib headless backend requirement.""" + + def test_plt_show_without_backend(self): + """plt.show() without headless backend should fail.""" + bundle = PythonToolRequirements() + matplotlib_reqs = [ + r for r in bundle.requirements if "headless" in r.description.lower() + ] + assert len(matplotlib_reqs) > 0 + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") + result = matplotlib_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "headless" in result.reason.lower() or "Agg" in result.reason + + def test_plt_show_with_backend(self): + """plt.show() with headless backend should pass.""" + bundle = PythonToolRequirements() + matplotlib_reqs = [ + r for r in bundle.requirements if "headless" in r.description.lower() + ] + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model( + "```python\n" + "import matplotlib\n" + "matplotlib.use('Agg')\n" + "plt.plot([1, 2, 3])\n" + "plt.show()\n" + "```" + ) + result = matplotlib_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_no_plt_show(self): + """Code without plt.show() should pass.""" + bundle = PythonToolRequirements() + matplotlib_reqs = [ + r for r in bundle.requirements if "headless" in r.description.lower() + ] + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") + result = matplotlib_req.validation_fn(ctx) + + assert result.as_bool() is True + + +class TestPlotsAreSavedRequirement: + """Tests for plots must be saved requirement.""" + + def test_plot_without_savefig(self): + """Plotting without savefig should fail.""" + bundle = PythonToolRequirements() + plot_reqs = [ + r for r in bundle.requirements if "savefig" in r.description.lower() + ] + assert len(plot_reqs) > 0 + plot_req = plot_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") + result = plot_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "savefig" in result.reason.lower() + + def test_plot_with_savefig(self): + """Plotting with savefig should pass.""" + bundle = PythonToolRequirements() + plot_reqs = [ + r for r in bundle.requirements if "savefig" in r.description.lower() + ] + plot_req = plot_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") + result = plot_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_no_plotting(self): + """Code without plotting should pass.""" + bundle = PythonToolRequirements() + plot_reqs = [ + r for r in bundle.requirements if "savefig" in r.description.lower() + ] + plot_req = plot_reqs[0] + + ctx = from_model("```python\nx = 1\nprint(x)\n```") + result = plot_req.validation_fn(ctx) + + assert result.as_bool() is True + + +class TestOutputArtifactsRequirement: + """Tests for output artifacts requirement.""" + + def test_output_file_not_created(self): + """Should fail if output file not created.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + + bundle = PythonToolRequirements(output_path=output_path) + artifact_reqs = [ + r for r in bundle.requirements if "output file" in r.description.lower() + ] + assert len(artifact_reqs) > 0 + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = artifact_req.validation_fn(ctx) + + assert result.as_bool() is False + assert output_path in result.reason + + def test_output_file_exists_and_nonempty(self): + """Should pass if output file exists and is non-empty.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + Path(output_path).write_bytes(b"fake png data") + + bundle = PythonToolRequirements(output_path=output_path) + artifact_reqs = [ + r for r in bundle.requirements if "output file" in r.description.lower() + ] + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = artifact_req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_output_file_empty(self): + """Should fail if output file exists but is empty.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + Path(output_path).write_bytes(b"") + + bundle = PythonToolRequirements(output_path=output_path) + artifact_reqs = [ + r for r in bundle.requirements if "output file" in r.description.lower() + ] + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = artifact_req.validation_fn(ctx) + + assert result.as_bool() is False + assert "empty" in result.reason.lower() + + def test_output_artifact_disabled_without_output_path(self): + """Output artifact requirement should not be present without output_path.""" + bundle = PythonToolRequirements() + artifact_reqs = [ + r for r in bundle.requirements if "output file" in r.description.lower() + ] + assert len(artifact_reqs) == 0 + + +# endregion From 119699ed42a62d1642124939200c0e8e06ef30be Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Sat, 9 May 2026 08:44:53 -0400 Subject: [PATCH 02/12] review comments Signed-off-by: Akihiko Kuroda --- mellea/helpers/__init__.py | 2 + mellea/helpers/imports.py | 48 ++++++ mellea/stdlib/requirements/python_tools.py | 38 +---- mellea/stdlib/tools/interpreter.py | 37 +---- test/stdlib/requirements/test_python_tools.py | 152 ++++++++++++++++-- 5 files changed, 200 insertions(+), 77 deletions(-) create mode 100644 mellea/helpers/imports.py diff --git a/mellea/helpers/__init__.py b/mellea/helpers/__init__.py index 62b22eb6a..bad10f5b4 100644 --- a/mellea/helpers/__init__.py +++ b/mellea/helpers/__init__.py @@ -16,6 +16,7 @@ wait_for_all_mots, ) from .event_loop_helper import _run_async_in_thread +from .imports import get_unauthorized_imports from .openai_compatible_helpers import ( chat_completion_delta_merge, extract_model_tool_requests, @@ -36,6 +37,7 @@ "chat_completion_delta_merge", "extract_model_tool_requests", "get_current_event_loop", + "get_unauthorized_imports", "is_vllm_server_with_structured_output", "message_to_openai_message", "messages_to_docs", diff --git a/mellea/helpers/imports.py b/mellea/helpers/imports.py new file mode 100644 index 000000000..356bcc5d7 --- /dev/null +++ b/mellea/helpers/imports.py @@ -0,0 +1,48 @@ +"""Utilities for analyzing Python code imports.""" + +import ast + + +def get_unauthorized_imports( + code: str, allowed_imports: list[str] | None = None +) -> list[str]: + r"""Extract unauthorized imports from Python code. + + Parses Python code and returns a sorted list of top-level modules that are + imported but not in the allowed list. Handles both `import X` and `from X import Y` + statements, extracting the root module name (e.g., "numpy" from "numpy.random"). + + Args: + code: Python source code to analyze. + allowed_imports: Allowlist of permitted top-level modules. If None, allows all imports. + + Returns: + Sorted list of unauthorized module names found in code. Empty list if code + has syntax errors or if allowed_imports is None. + + Example: + >>> code = "import numpy\nimport os\nimport forbidden_lib" + >>> get_unauthorized_imports(code, ["os", "sys"]) + ['numpy', 'forbidden_lib'] + """ + if allowed_imports is None: + return [] + + unauthorized: set[str] = set() + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return [] + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module = alias.name.split(".")[0] + if module not in allowed_imports: + unauthorized.add(module) + elif isinstance(node, ast.ImportFrom) and node.module: + module = node.module.split(".")[0] + if module not in allowed_imports: + unauthorized.add(module) + + return sorted(unauthorized) diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 4f78fb43a..73e060560 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -56,6 +56,7 @@ from pathlib import Path from ...core import Context, Requirement, ValidationResult +from ...helpers import get_unauthorized_imports from .python_reqs import _has_python_code_listing @@ -81,41 +82,6 @@ def _extract_code(ctx: Context) -> str | None: return None -def _get_unauthorized_imports( - code: str, allowed_imports: list[str] | None -) -> list[str]: - """Return list of imports in code that are not in allowed_imports. - - Args: - code: Python code to analyze - allowed_imports: Allowlist of permitted top-level modules (None = allow all) - - Returns: - List of unauthorized import module names, or empty list if all allowed. - """ - if allowed_imports is None: - return [] - - unauthorized = [] - try: - tree = ast.parse(code) - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - module_name = alias.name.split(".")[0] - if module_name not in allowed_imports: - unauthorized.append(module_name) - elif isinstance(node, ast.ImportFrom): - if node.module: - module_name = node.module.split(".")[0] - if module_name not in allowed_imports: - unauthorized.append(module_name) - except (SyntaxError, ValueError): - pass - - return list(set(unauthorized)) - - def _code_parses(code: str) -> tuple[bool, str | None]: """Check if code parses as valid Python. @@ -261,7 +227,7 @@ def validate(ctx: Context) -> ValidationResult: result=False, reason="Could not extract Python code" ) - unauthorized = _get_unauthorized_imports(code, allowed_imports) + unauthorized = get_unauthorized_imports(code, allowed_imports) if unauthorized: allowed_str = ", ".join(sorted(set(allowed_imports))) return ValidationResult( diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 4c537cc32..b256f7d52 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -20,6 +20,7 @@ from typing import Any from ...core import MelleaLogger +from ...helpers import get_unauthorized_imports logger = MelleaLogger.get_logger() @@ -148,7 +149,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -185,7 +186,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: unexpected error occurs. """ if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -259,7 +260,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: flag, or a skipped result on import violation or sandbox error. """ if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -301,37 +302,9 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) -def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: - """Get list of unauthorized imports used in code.""" - unauthorized: list[str] = [] - try: - tree = ast.parse(code) - except SyntaxError: - return unauthorized - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - base_module = alias.name.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - elif isinstance(node, ast.ImportFrom): - if node.module: - base_module = node.module.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - return unauthorized - - def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: """Check if code only uses allowed imports.""" - return len(_get_unauthorized_imports(code, allowed_imports)) == 0 + return len(get_unauthorized_imports(code, allowed_imports)) == 0 def code_interpreter(code: str) -> ExecutionResult: diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index 6053d2277..49056d93b 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -4,12 +4,12 @@ from pathlib import Path from mellea.core import Context, ModelOutputThunk +from mellea.helpers import get_unauthorized_imports from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.python_tools import ( PythonToolRequirements, _calls_savefig, _code_parses, - _get_unauthorized_imports, _sets_headless_backend, _uses_pyplot_plot, _uses_pyplot_show, @@ -53,48 +53,48 @@ def test_missing_colon(self): class TestUnauthorizedImports: - """Tests for _get_unauthorized_imports helper.""" + """Tests for get_unauthorized_imports helper.""" def test_no_imports_allowed_always(self): """Code with no imports should always pass.""" code = "x = 1" - unauthorized = _get_unauthorized_imports(code, ["numpy"]) + unauthorized = get_unauthorized_imports(code, ["numpy"]) assert unauthorized == [] def test_allowed_imports(self): """Allowed imports should not be flagged.""" code = "import numpy\nimport pandas" - unauthorized = _get_unauthorized_imports(code, ["numpy", "pandas"]) + unauthorized = get_unauthorized_imports(code, ["numpy", "pandas"]) assert unauthorized == [] def test_unauthorized_import_detected(self): """Unauthorized imports should be detected.""" code = "import subprocess" - unauthorized = _get_unauthorized_imports(code, ["numpy", "pandas"]) + unauthorized = get_unauthorized_imports(code, ["numpy", "pandas"]) assert "subprocess" in unauthorized def test_none_allows_all(self): """allowed_imports=None should allow any import.""" code = "import subprocess\nimport socket\nimport os" - unauthorized = _get_unauthorized_imports(code, None) + unauthorized = get_unauthorized_imports(code, None) assert unauthorized == [] def test_nested_imports(self): """Only top-level module checked for nested imports.""" code = "import numpy.random" - unauthorized = _get_unauthorized_imports(code, ["numpy"]) + unauthorized = get_unauthorized_imports(code, ["numpy"]) assert unauthorized == [] def test_from_import(self): """from ... import should be checked.""" code = "from matplotlib import pyplot" - unauthorized = _get_unauthorized_imports(code, ["numpy"]) + unauthorized = get_unauthorized_imports(code, ["numpy"]) assert "matplotlib" in unauthorized def test_multiple_unauthorized(self): """Multiple unauthorized imports should be detected.""" code = "import subprocess\nimport socket" - unauthorized = _get_unauthorized_imports(code, ["numpy"]) + unauthorized = get_unauthorized_imports(code, ["numpy"]) assert len(unauthorized) == 2 assert "subprocess" in unauthorized assert "socket" in unauthorized @@ -572,4 +572,138 @@ def test_output_artifact_disabled_without_output_path(self): assert len(artifact_reqs) == 0 +class TestOutputLimitValidator: + """Tests for _make_output_limit_validator.""" + + def test_empty_output_passes(self): + """No stdout/stderr should pass.""" + ctx = ChatContext().add(ModelOutputThunk(value="response")) + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + assert len(limit_reqs) > 0 + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is True + + def test_output_within_limit_passes(self): + """Output under limit should pass.""" + output = ModelOutputThunk(value="response") + output.stdout = "x" * 500 + output.stderr = "" + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is True + + def test_stdout_exceeds_limit_fails(self): + """Stdout exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + output.stdout = "x" * 1500 + output.stderr = "" + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is False + assert "exceeding" in result.reason.lower() + assert "1500" in result.reason + + def test_stderr_exceeds_limit_fails(self): + """Stderr exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + output.stdout = "" + output.stderr = "e" * 1500 + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is False + assert "exceeding" in result.reason.lower() + + def test_combined_output_exceeds_limit_fails(self): + """Combined stdout+stderr exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + output.stdout = "x" * 600 + output.stderr = "e" * 600 # Combined: 1200 bytes + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is False + assert "1200" in result.reason + + def test_utf8_multibyte_characters_counted_correctly(self): + """Multibyte UTF-8 characters should be counted in bytes, not chars.""" + output = ModelOutputThunk(value="response") + output.stdout = "🎉" * 100 # 4 bytes per emoji = 400 bytes + output.stderr = "" + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=300) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is False # 400 > 300 + assert "exceeding" in result.reason.lower() + + def test_limit_at_boundary(self): + """Output exactly at limit should pass.""" + output = ModelOutputThunk(value="response") + output.stdout = "x" * 1000 # Exactly 1000 bytes + output.stderr = "" + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is True + + def test_limit_one_byte_over_boundary_fails(self): + """Output one byte over limit should fail.""" + output = ModelOutputThunk(value="response") + output.stdout = "x" * 1001 # 1001 bytes + output.stderr = "" + ctx = ChatContext().add(output) + + bundle = PythonToolRequirements(output_limit_bytes=1000) + limit_reqs = [ + r for r in bundle.requirements if "exceed" in r.description.lower() + ] + limit_req = limit_reqs[0] + + result = limit_req.validation_fn(ctx) + assert result.as_bool() is False + + # endregion From 905510146bb70724acfd56446e9ed9326090c7a2 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Sat, 9 May 2026 09:27:50 -0400 Subject: [PATCH 03/12] review comments Signed-off-by: Akihiko Kuroda --- docs/examples/python_plotting_repair.py | 26 ++++++++-------- mellea/helpers/imports.py | 3 ++ mellea/stdlib/requirements/python_tools.py | 36 +++++++++++++++++++--- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/docs/examples/python_plotting_repair.py b/docs/examples/python_plotting_repair.py index 7347f58b6..5989f9a0c 100644 --- a/docs/examples/python_plotting_repair.py +++ b/docs/examples/python_plotting_repair.py @@ -17,6 +17,11 @@ through each repair iteration without explicit instruction. """ +import asyncio +import tempfile +import traceback +from pathlib import Path + import mellea from mellea.backends import ModelOption from mellea.backends.tools import MelleaTool @@ -42,9 +47,6 @@ def python(code: str) -> ExecutionResult: async def main(): """Run the canonical plotting repair example.""" - import tempfile - from pathlib import Path - with tempfile.TemporaryDirectory() as tmpdir: output_path = str(Path(tmpdir) / "plot.png") @@ -52,7 +54,10 @@ async def main(): m = mellea.start_session() # Create requirements bundle for plotting validation - # Allows matplotlib import (no output_path = skip file creation check) + # Note: We don't pass output_path, so the bundle doesn't enforce artifact validation. + # This example tests code generation and repair logic, not actual code execution. + # The model generates syntactically correct plotting code, but doesn't execute it, + # so the output file is never created. bundle = PythonToolRequirements(allowed_imports=["numpy", "matplotlib", "math"]) # Define SOFAI strategy for repair: S1 (fast) up to 3 times, then S2 (slow) @@ -65,7 +70,10 @@ async def main(): ) # Create the plotting task instruction - description = f"""Create a plot of sin(x) for x in 0..2π and save it to {output_path}. + task_summary = ( + f"Create a plot of sin(x) for x in 0..2π and save it to {output_path}" + ) + description = f"""{task_summary} Requirements: - Use the python tool to execute your code @@ -83,7 +91,7 @@ async def main(): print("=" * 70) print("Testing Granite 4.1's ability to repair plotting failures") print("=" * 70) - print(f"Task: Create a plot of sin(x) and save to {output_path}\n") + print(f"Task: {task_summary}\n") try: # Run the sampling strategy with requirements @@ -106,8 +114,6 @@ async def main(): print("-" * 70) # Verify output file exists - from pathlib import Path - if Path(output_path).exists(): # noqa: ASYNC240 file_size = Path(output_path).stat().st_size # noqa: ASYNC240 print(f"\n✓ Output file created: {output_path}") @@ -157,8 +163,6 @@ async def main(): except Exception as e: print(f"✗ Exception during sampling: {e}") - import traceback - traceback.print_exc() print("\n" + "=" * 70) @@ -167,6 +171,4 @@ async def main(): if __name__ == "__main__": - import asyncio - asyncio.run(main()) diff --git a/mellea/helpers/imports.py b/mellea/helpers/imports.py index 356bcc5d7..9dadc23bd 100644 --- a/mellea/helpers/imports.py +++ b/mellea/helpers/imports.py @@ -32,6 +32,9 @@ def get_unauthorized_imports( try: tree = ast.parse(code) except (SyntaxError, ValueError): + # Syntax errors are validated separately (e.g., in python_tools.py with + # _code_parses()). Returning empty here allows those dedicated validators + # to provide better error messages without double-reporting parse failures. return [] for node in ast.walk(tree): diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 73e060560..ae1f45e68 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -101,19 +101,47 @@ def _code_parses(code: str) -> tuple[bool, str | None]: return False, error_msg +def _strip_comments(code: str) -> str: + """Remove Python comments from code while preserving strings. + + Splits code by lines and removes comments (text after # that's not in a string). + Handles both single and double quoted strings. + """ + lines = code.split("\n") + result = [] + for line in lines: + in_string = False + string_char = None + for i, char in enumerate(line): + if char in ('"', "'") and (i == 0 or line[i - 1] != "\\"): + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + elif char == "#" and not in_string: + result.append(line[:i]) + break + else: + result.append(line) + return "\n".join(result) + + def _uses_pyplot_show(code: str) -> bool: """Check if code calls plt.show() or matplotlib.pyplot.show().""" - # Simple string checks work for most cases - return "plt.show" in code or ".show()" in code + clean_code = _strip_comments(code) + return "plt.show" in clean_code or ".show()" in clean_code def _sets_headless_backend(code: str) -> bool: """Check if code sets matplotlib to use a headless backend.""" + clean_code = _strip_comments(code) headless_backends = ("Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg") for backend in headless_backends: if ( - f"matplotlib.use('{backend}')" in code - or f'matplotlib.use("{backend}")' in code + f"matplotlib.use('{backend}')" in clean_code + or f'matplotlib.use("{backend}")' in clean_code ): return True return False From 036cfee9982a9d4e2c5742c53296718519be6bb3 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 15:40:37 -0400 Subject: [PATCH 04/12] review comments Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/python_tools.py | 41 ++++++++++++++-------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index ae1f45e68..5b3fc0680 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -273,7 +273,9 @@ def validate(ctx: Context) -> ValidationResult: return validate -def _make_matplotlib_headless_validator() -> Callable[[Context], ValidationResult]: +def _make_matplotlib_headless_validator( + output_path: str | None = None, +) -> Callable[[Context], ValidationResult]: """Create a validator that checks matplotlib uses headless backend.""" def validate(ctx: Context) -> ValidationResult: @@ -282,14 +284,19 @@ def validate(ctx: Context) -> ValidationResult: return ValidationResult(result=True) if _uses_pyplot_show(code) and not _sets_headless_backend(code): + savefig_instruction = ( + f"plt.savefig('{output_path}'); plt.close()" + if output_path + else "plt.savefig('{output_path}'); plt.close()" + ) return ValidationResult( result=False, - reason="Your code calls `plt.show()` but doesn't set a headless backend.\n" - "This will fail in a headless environment (no display).\n\n" - "Fix this by adding to the top of your code:\n" - " import matplotlib\n" - " matplotlib.use('Agg')\n\n" - "Then replace `plt.show()` with `plt.savefig('{output_path}'); plt.close()`", + reason=f"Your code calls `plt.show()` but doesn't set a headless backend.\n" + f"This will fail in a headless environment (no display).\n\n" + f"Fix this by adding to the top of your code:\n" + f" import matplotlib\n" + f" matplotlib.use('Agg')\n\n" + f"Then replace `plt.show()` with `{savefig_instruction}`", ) return ValidationResult(result=True) @@ -297,7 +304,9 @@ def validate(ctx: Context) -> ValidationResult: return validate -def _make_plots_saved_validator() -> Callable[[Context], ValidationResult]: +def _make_plots_saved_validator( + output_path: str | None = None, +) -> Callable[[Context], ValidationResult]: """Create a validator that checks if code saves plots to a file.""" def validate(ctx: Context) -> ValidationResult: @@ -306,12 +315,16 @@ def validate(ctx: Context) -> ValidationResult: return ValidationResult(result=True) if _uses_pyplot_plot(code) and not _calls_savefig(code): + savefig_instruction = ( + f"plt.savefig('{output_path}')\n plt.close()" + if output_path + else "plt.savefig('{output_path}')\n plt.close()" + ) return ValidationResult( result=False, - reason="Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" - "Add this before your plotting code or at the end:\n" - " plt.savefig('{output_path}')\n" - " plt.close()", + reason=f"Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" + f"Add this before your plotting code or at the end:\n" + f" {savefig_instruction}", ) return ValidationResult(result=True) @@ -471,7 +484,7 @@ def _build_requirements(self) -> list[Requirement]: description=( "If using pyplot, must set headless backend and use savefig." ), - validation_fn=_make_matplotlib_headless_validator(), + validation_fn=_make_matplotlib_headless_validator(self.output_path), check_only=False, ) ) @@ -479,7 +492,7 @@ def _build_requirements(self) -> list[Requirement]: reqs.append( Requirement( description="If creating plots, must call savefig to save them.", - validation_fn=_make_plots_saved_validator(), + validation_fn=_make_plots_saved_validator(self.output_path), check_only=False, ) ) From e3d00f3bed2bfebfe354784aea5b1a626327848e Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 17:14:52 -0400 Subject: [PATCH 05/12] review comments Signed-off-by: Akihiko Kuroda --- .../{ => tools}/python_plotting_repair.py | 52 +-- mellea/helpers/__init__.py | 2 - mellea/helpers/imports.py | 51 -- mellea/stdlib/requirements/__init__.py | 5 +- mellea/stdlib/requirements/imports.py | 34 ++ mellea/stdlib/requirements/python_reqs.py | 51 +- mellea/stdlib/requirements/python_tools.py | 359 +++++--------- mellea/stdlib/tools/interpreter.py | 2 +- test/stdlib/requirements/test_python_tools.py | 438 +++++++++--------- 9 files changed, 432 insertions(+), 562 deletions(-) rename docs/examples/{ => tools}/python_plotting_repair.py (70%) delete mode 100644 mellea/helpers/imports.py create mode 100644 mellea/stdlib/requirements/imports.py diff --git a/docs/examples/python_plotting_repair.py b/docs/examples/tools/python_plotting_repair.py similarity index 70% rename from docs/examples/python_plotting_repair.py rename to docs/examples/tools/python_plotting_repair.py index 5989f9a0c..26e958d04 100644 --- a/docs/examples/python_plotting_repair.py +++ b/docs/examples/tools/python_plotting_repair.py @@ -1,21 +1,5 @@ # pytest: ollama, e2e, qualitative -"""Granite 4.1 repairs the three canonical plotting failures with Python tool. - -This example demonstrates: -1. Creating a PythonToolRequirements bundle for plotting validation -2. Using SOFAI sampling strategy with repair feedback loop -3. Granite 4.1 repairing through: syntax → imports → headless backend → savefig - -Canonical task: "Create a plot of sin(x) for x in 0..2π and save to /tmp/plot.png" - -The model will encounter and repair: -- Attempt 1: Missing matplotlib.use('Agg') (non-headless backend) -- Attempt 2: Missing plt.savefig() call -- Attempt 3: Success with both fixes applied - -The requirements bundle provides actionable failure messages that guide the model -through each repair iteration without explicit instruction. -""" +"""Repair plotting code with Python-tool and plotting-specific requirements.""" import asyncio import tempfile @@ -27,7 +11,10 @@ from mellea.backends.tools import MelleaTool from mellea.stdlib.components import Instruction from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements import PythonToolRequirements +from mellea.stdlib.requirements import ( + python_plotting_requirements, + python_tool_requirements, +) from mellea.stdlib.sampling import SOFAISamplingStrategy from mellea.stdlib.tools import local_code_interpreter from mellea.stdlib.tools.interpreter import ExecutionResult @@ -46,21 +33,17 @@ def python(code: str) -> ExecutionResult: async def main(): - """Run the canonical plotting repair example.""" + """Run the plotting repair example.""" with tempfile.TemporaryDirectory() as tmpdir: output_path = str(Path(tmpdir) / "plot.png") - # Initialize session with local backend m = mellea.start_session() - # Create requirements bundle for plotting validation - # Note: We don't pass output_path, so the bundle doesn't enforce artifact validation. - # This example tests code generation and repair logic, not actual code execution. - # The model generates syntactically correct plotting code, but doesn't execute it, - # so the output file is never created. - bundle = PythonToolRequirements(allowed_imports=["numpy", "matplotlib", "math"]) + requirements = [ + *python_tool_requirements(allowed_imports=["numpy", "matplotlib", "math"]), + *python_plotting_requirements(output_path=output_path), + ] - # Define SOFAI strategy for repair: S1 (fast) up to 3 times, then S2 (slow) sampling_strategy = SOFAISamplingStrategy( s1_solver_backend=m.backend, s2_solver_backend=m.backend, @@ -69,7 +52,6 @@ async def main(): feedback_strategy="first_error", ) - # Create the plotting task instruction task_summary = ( f"Create a plot of sin(x) for x in 0..2π and save it to {output_path}" ) @@ -85,21 +67,19 @@ async def main(): Use the python tool with your complete code.""" instruction = Instruction(description=description) - # Create a chat context for multi-turn repair ctx = ChatContext() print("=" * 70) - print("Testing Granite 4.1's ability to repair plotting failures") + print("Testing plotting-code repair with Python tool requirements") print("=" * 70) print(f"Task: {task_summary}\n") try: - # Run the sampling strategy with requirements result = await sampling_strategy.sample( action=instruction, context=ctx, backend=m.backend, - requirements=bundle.requirements, + requirements=requirements, tool_calls=True, model_options={ModelOption.TOOLS: [MelleaTool.from_callable(python)]}, ) @@ -107,13 +87,12 @@ async def main(): print(f"\nResult: {'SUCCESS' if result.success else 'FAILED'}\n") if result.success: - print("✓ Granite 4.1 successfully generated and executed plotting code") + print("✓ Model successfully generated and executed plotting code") print("\nFinal generated code:") print("-" * 70) print(result.result.value) print("-" * 70) - # Verify output file exists if Path(output_path).exists(): # noqa: ASYNC240 file_size = Path(output_path).stat().st_size # noqa: ASYNC240 print(f"\n✓ Output file created: {output_path}") @@ -121,7 +100,6 @@ async def main(): else: print(f"\n✗ Output file not found: {output_path}") - # Print repair history print(f"\nRepair iterations: {len(result.sample_validations)}") for attempt_idx, validations in enumerate(result.sample_validations, 1): passed = sum(1 for _, val in validations if val.as_bool()) @@ -132,7 +110,6 @@ async def main(): f"requirements passed" ) - # Show which requirements failed for req, val in validations: if not val.as_bool(): print(f" - {req.description}") @@ -147,7 +124,6 @@ async def main(): print(result.result.value) print("-" * 70) - # Print failure history print(f"\nFailure history ({len(result.sample_validations)} attempts):") for attempt_idx, validations in enumerate(result.sample_validations, 1): failed_count = sum(1 for _, val in validations if not val.as_bool()) @@ -172,3 +148,5 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) + +# Made with Bob diff --git a/mellea/helpers/__init__.py b/mellea/helpers/__init__.py index bad10f5b4..62b22eb6a 100644 --- a/mellea/helpers/__init__.py +++ b/mellea/helpers/__init__.py @@ -16,7 +16,6 @@ wait_for_all_mots, ) from .event_loop_helper import _run_async_in_thread -from .imports import get_unauthorized_imports from .openai_compatible_helpers import ( chat_completion_delta_merge, extract_model_tool_requests, @@ -37,7 +36,6 @@ "chat_completion_delta_merge", "extract_model_tool_requests", "get_current_event_loop", - "get_unauthorized_imports", "is_vllm_server_with_structured_output", "message_to_openai_message", "messages_to_docs", diff --git a/mellea/helpers/imports.py b/mellea/helpers/imports.py deleted file mode 100644 index 9dadc23bd..000000000 --- a/mellea/helpers/imports.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Utilities for analyzing Python code imports.""" - -import ast - - -def get_unauthorized_imports( - code: str, allowed_imports: list[str] | None = None -) -> list[str]: - r"""Extract unauthorized imports from Python code. - - Parses Python code and returns a sorted list of top-level modules that are - imported but not in the allowed list. Handles both `import X` and `from X import Y` - statements, extracting the root module name (e.g., "numpy" from "numpy.random"). - - Args: - code: Python source code to analyze. - allowed_imports: Allowlist of permitted top-level modules. If None, allows all imports. - - Returns: - Sorted list of unauthorized module names found in code. Empty list if code - has syntax errors or if allowed_imports is None. - - Example: - >>> code = "import numpy\nimport os\nimport forbidden_lib" - >>> get_unauthorized_imports(code, ["os", "sys"]) - ['numpy', 'forbidden_lib'] - """ - if allowed_imports is None: - return [] - - unauthorized: set[str] = set() - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - # Syntax errors are validated separately (e.g., in python_tools.py with - # _code_parses()). Returning empty here allows those dedicated validators - # to provide better error messages without double-reporting parse failures. - return [] - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - module = alias.name.split(".")[0] - if module not in allowed_imports: - unauthorized.add(module) - elif isinstance(node, ast.ImportFrom) and node.module: - module = node.module.split(".")[0] - if module not in allowed_imports: - unauthorized.add(module) - - return sorted(unauthorized) diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index f40287553..46177ff88 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,7 +4,7 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq -from .python_tools import PythonToolRequirements +from .python_tools import python_plotting_requirements, python_tool_requirements from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -20,7 +20,6 @@ "ALoraRequirement", "LLMaJRequirement", "PythonExecutionReq", - "PythonToolRequirements", "Requirement", "ValidationResult", "as_markdown_list", @@ -28,6 +27,8 @@ "default_output_to_bool", "is_markdown_list", "is_markdown_table", + "python_plotting_requirements", + "python_tool_requirements", "req", "reqify", "requirement_check_to_bool", diff --git a/mellea/stdlib/requirements/imports.py b/mellea/stdlib/requirements/imports.py new file mode 100644 index 000000000..d95ba16c4 --- /dev/null +++ b/mellea/stdlib/requirements/imports.py @@ -0,0 +1,34 @@ +"""Import analysis helpers for Python requirements and execution environments.""" + +import ast + + +def get_unauthorized_imports( + code: str, allowed_imports: list[str] | None = None +) -> list[str]: + """Extract unauthorized top-level imports from Python code.""" + if allowed_imports is None: + return [] + + unauthorized: set[str] = set() + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + # Syntax errors are validated separately by dedicated validators. + return [] + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module = alias.name.split(".")[0] + if module not in allowed_imports: + unauthorized.add(module) + elif isinstance(node, ast.ImportFrom) and node.module: + module = node.module.split(".")[0] + if module not in allowed_imports: + unauthorized.add(module) + + return sorted(unauthorized) + + +# Made with Bob diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index 3152acb71..4ba2ee8bb 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -58,8 +58,57 @@ def _score_code_block(code: str) -> int: return score +def extract_python_code(ctx: Context) -> ValidationResult: + """Extract Python code from tool calls or markdown code blocks.""" + last_output = ctx.last_output() + if last_output is None: + return ValidationResult(result=False, reason="No output found in context") + + if last_output.tool_calls and "python" in last_output.tool_calls: + tool_call = last_output.tool_calls["python"] + code = tool_call.args.get("code") + if isinstance(code, str) and code.strip(): + return ValidationResult(result=True, reason=code) + + if last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") + + content = last_output.value + + # Look for code blocks with python specifier + import re + + # Pattern for ```python ... ``` blocks + python_blocks = re.findall(r"```python\s*\n(.*?)\n```", content, re.DOTALL) + + # Pattern for generic ``` blocks + generic_blocks = re.findall(r"```\s*\n(.*?)\n```", content, re.DOTALL) + + all_blocks = [] + + # Add python blocks with high priority + for block in python_blocks: + all_blocks.append((block.strip(), _score_code_block(block.strip()) + 10)) + + # Add generic blocks if they look like Python + for block in generic_blocks: + block = block.strip() + if block and any( + keyword in block + for keyword in ["def ", "class ", "import ", "print(", "if __name__"] + ): + all_blocks.append((block, _score_code_block(block))) + + if not all_blocks: + return ValidationResult(result=False, reason="No Python code blocks found") + + # Return the highest scoring block + best_block = max(all_blocks, key=lambda x: x[1]) + return ValidationResult(result=True, reason=best_block[0]) + + def _has_python_code_listing(ctx: Context) -> ValidationResult: - """Extract Python code from context.""" + """Extract Python code from markdown code blocks in context.""" last_output = ctx.last_output() if last_output is None or last_output.value is None: return ValidationResult(result=False, reason="No output found in context") diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 5b3fc0680..8e63b32ed 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -1,54 +1,9 @@ -"""Pre-composed requirement bundles for Python tool invocation and execution. +"""Requirement factories for Python tool invocation and code validation. -This module provides bundled requirements for validating Python code generated -via the Python tool, with focus on reactive failure detection and repair: - -- Tool invocation validation (must call Python tool with code argument) -- Syntax validation (code must parse correctly) -- Import validation (code imports must be in allowlist) -- Matplotlib headless backend detection (plt.show() without backend) -- Plot artifact validation (savefig must be called, output files must exist) -- Output limiting (stdout/stderr must not exceed configured limits) - -Failure messages are written as feedback to the model, not to developers. -They state the failure, include relevant code/stderr, and explain the -correction well enough for the model to act on it. - -FAILURE MATRIX — How each requirement catches the canonical plotting failures: - -Scenario: Model generates plotting code with matplotlib - -Attempt 1: No tool call - → MustInvokePythonTool fails - → Repair: "Call the `python` tool with your code" - -Attempt 2: Tool called but no 'code' arg - → PythonToolHasCodeArg fails - → Repair: "The python tool requires a 'code' argument" - -Attempt 3: Code has syntax error - → PythonCodeParses fails - → Repair: "Your code has a syntax error at line X: {error}" - -Attempt 4: Code imports matplotlib (not in allowed_imports) - → PythonImportsAllowed fails - → Repair: "matplotlib is not allowed. Use only: {allowed_list}" - -Attempt 5: Code uses plt.show() without headless backend - → MatplotlibHeadless fails - → Repair: "Add matplotlib.use('Agg') and replace plt.show() with plt.savefig(...)" - -Attempt 6: Code has plt.plot() but no plt.savefig() - → PlotsAreSaved fails - → Repair: "Add plt.savefig('{output_path}') to save the plot" - -Attempt 7: Code runs, but output file not created - → OutputArtifactsExist fails - → Repair: "File '{output_path}' was not created. Check plt.savefig() call" - -Attempt 8: Success - → All requirements pass - → Result: plot file exists and is non-empty +This module provides generic requirements for Python-tool usage and code +correctness. Plotting-specific checks are exposed separately through +``python_plotting_requirements(...)`` so they are not implied to be universal +Python-tool requirements. """ import ast @@ -56,30 +11,9 @@ from pathlib import Path from ...core import Context, Requirement, ValidationResult -from ...helpers import get_unauthorized_imports -from .python_reqs import _has_python_code_listing - - -def _extract_code(ctx: Context) -> str | None: - """Extract Python code from either tool calls or markdown blocks. - - Checks tool_calls dict first (for tool calling), then falls back to - markdown code blocks in response text. - - Returns the code string, or None if no code found. - """ - # Try tool_calls first (tool calling format) - output = ctx.last_output() - if output and output.tool_calls and "python" in output.tool_calls: - tool_call = output.tool_calls["python"] - if hasattr(tool_call, "args") and "code" in tool_call.args: - return tool_call.args["code"] - - # Fall back to markdown code blocks in response text - result = _has_python_code_listing(ctx) - if result.as_bool() and result.reason: - return result.reason - return None +from .imports import get_unauthorized_imports +from .python_reqs import extract_python_code +from .tool_reqs import tool_arg_validator, uses_tool def _code_parses(code: str) -> tuple[bool, str | None]: @@ -173,62 +107,27 @@ def _calls_savefig(code: str) -> bool: # region Individual Requirement Validators -def _validate_python_tool_invoked(ctx: Context) -> ValidationResult: - """Requirement: Model must invoke the Python tool.""" - output = ctx.last_output() - if output is None or output.tool_calls is None: - return ValidationResult( - result=False, - reason=( - "You did not invoke any tools. To execute Python code, " - "call the `python` tool with your code." - ), - ) - if "python" not in output.tool_calls: - return ValidationResult( - result=False, - reason=( - "You did not call the `python` tool. Call it with your " - "code to execute it." - ), - ) - return ValidationResult(result=True) - - -def _validate_python_tool_has_code_arg(ctx: Context) -> ValidationResult: - """Requirement: Python tool call must include a 'code' argument.""" - output = ctx.last_output() - if output is None or output.tool_calls is None: - return ValidationResult(result=False, reason="No tool calls found") - - if "python" not in output.tool_calls: - return ValidationResult(result=False, reason="Python tool not called") - - python_call = output.tool_calls["python"] - if "code" not in python_call.args: - return ValidationResult( - result=False, - reason="The `python` tool call must include a `code` argument with your Python code.", - ) - - return ValidationResult(result=True) +def _python_code_arg_present(arg_value: object) -> bool: + """Return True when the python tool code argument is present and non-empty.""" + return isinstance(arg_value, str) and bool(arg_value.strip()) def _make_code_parses_validator() -> Callable[[Context], ValidationResult]: """Create a validator that checks if extracted code parses.""" def validate(ctx: Context) -> ValidationResult: - code = _extract_code(ctx) - if not code: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: return ValidationResult( result=False, reason=( "Could not extract Python code from your response. " - "Make sure to include code in ```python ... ``` blocks." + "Make sure to include code in the python tool call or " + "in ```python ... ``` blocks." ), ) - parses, error = _code_parses(code) + parses, error = _code_parses(extraction_result.reason) if not parses: return ValidationResult( result=False, @@ -249,13 +148,15 @@ def validate(ctx: Context) -> ValidationResult: if allowed_imports is None: return ValidationResult(result=True) - code = _extract_code(ctx) - if not code: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: return ValidationResult( result=False, reason="Could not extract Python code" ) - unauthorized = get_unauthorized_imports(code, allowed_imports) + unauthorized = get_unauthorized_imports( + extraction_result.reason, allowed_imports + ) if unauthorized: allowed_str = ", ".join(sorted(set(allowed_imports))) return ValidationResult( @@ -279,10 +180,11 @@ def _make_matplotlib_headless_validator( """Create a validator that checks matplotlib uses headless backend.""" def validate(ctx: Context) -> ValidationResult: - code = _extract_code(ctx) - if not code: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: return ValidationResult(result=True) + code = extraction_result.reason if _uses_pyplot_show(code) and not _sets_headless_backend(code): savefig_instruction = ( f"plt.savefig('{output_path}'); plt.close()" @@ -310,10 +212,11 @@ def _make_plots_saved_validator( """Create a validator that checks if code saves plots to a file.""" def validate(ctx: Context) -> ValidationResult: - code = _extract_code(ctx) - if not code: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: return ValidationResult(result=True) + code = extraction_result.reason if _uses_pyplot_plot(code) and not _calls_savefig(code): savefig_instruction = ( f"plt.savefig('{output_path}')\n plt.close()" @@ -332,6 +235,40 @@ def validate(ctx: Context) -> ValidationResult: return validate +def python_plotting_requirements( + output_path: str | None = None, *, check_output_artifacts: bool | None = None +) -> list[Requirement]: + """Build plotting-specific requirements for Python tool responses.""" + reqs: list[Requirement] = [] + + reqs.append( + Requirement( + description="If using pyplot, must set headless backend and use savefig.", + validation_fn=_make_matplotlib_headless_validator(output_path), + check_only=False, + ) + ) + + reqs.append( + Requirement( + description="If creating plots, must call savefig to save them.", + validation_fn=_make_plots_saved_validator(output_path), + check_only=False, + ) + ) + + if check_output_artifacts and output_path: + reqs.append( + Requirement( + description=f"Output file must be created at {output_path}", + validation_fn=_make_output_artifacts_validator(output_path), + check_only=False, + ) + ) + + return reqs + + def _make_output_artifacts_validator( output_path: str, ) -> Callable[[Context], ValidationResult]: @@ -369,11 +306,13 @@ def validate(ctx: Context) -> ValidationResult: if output is None: return ValidationResult(result=True) + stdout = getattr(output, "stdout", "") + stderr = getattr(output, "stderr", "") total_output = "" - if hasattr(output, "stdout") and output.stdout: - total_output += output.stdout - if hasattr(output, "stderr") and output.stderr: - total_output += output.stderr + if isinstance(stdout, str): + total_output += stdout + if isinstance(stderr, str): + total_output += stderr size = len(total_output.encode("utf-8")) if size > limit_bytes: @@ -391,140 +330,58 @@ def validate(ctx: Context) -> ValidationResult: # endregion -class PythonToolRequirements: - """Pre-composed bundle of requirements for Python code generation via the tool. - - This bundle validates the complete Python code generation flow: tool invocation, - syntax, imports, execution, and output. It's designed to work with repair loops - (SOFAI, MultiTurnStrategy) to iteratively fix common plotting failures. - - Markers: - - **Deterministic** (unit-testable): tool invocation, syntax, imports, headless backend, - savefig presence, file existence, output limits - - **Qualitative** (needs model to evaluate): execution without error (captured via stderr) - - Args: - output_path (str | None): Path where plots should be saved. If specified, enables - output artifact validation. Defaults to None. - allowed_imports (list[str] | None): Allowlist of importable top-level modules. - None (default) allows any import. Set to list like ["numpy", "matplotlib"] - to restrict imports. - output_limit_bytes (int): Maximum bytes of stdout/stderr allowed. Defaults to 50000. - check_output_artifacts (bool): If True, validate that output file exists and is - non-empty after execution. Defaults to True if output_path is specified. - - Attributes: - requirements (list[Requirement]): The composed list of requirements, suitable - for use with sampling strategies. - """ - - def __init__( - self, - output_path: str | None = None, - allowed_imports: list[str] | None = None, - output_limit_bytes: int = 50_000, - check_output_artifacts: bool | None = None, - ): - """Initialize the Python tool requirements bundle.""" - self.output_path = output_path - self.allowed_imports = allowed_imports - self.output_limit_bytes = output_limit_bytes - - # Auto-enable output artifact checking if output_path is specified - if check_output_artifacts is None: - check_output_artifacts = output_path is not None - - self._check_output_artifacts = check_output_artifacts - - self.requirements = self._build_requirements() - - def _build_requirements(self) -> list[Requirement]: - """Build the list of requirements for this bundle.""" - reqs: list[Requirement] = [] - - # Tool invocation requirements (deterministic) - reqs.append( - Requirement( - description="Use the python tool to execute code.", - validation_fn=_validate_python_tool_invoked, - check_only=False, - ) +def python_tool_requirements( + output_path: str | None = None, + allowed_imports: list[str] | None = None, + output_limit_bytes: int = 50_000, + check_output_artifacts: bool | None = None, +) -> list[Requirement]: + """Build requirements for Python code generation via the python tool.""" + reqs: list[Requirement] = [] + + if check_output_artifacts is None: + check_output_artifacts = output_path is not None + + reqs.append(uses_tool("python")) + + reqs.append( + tool_arg_validator( + description="The python tool call must include a code argument.", + tool_name="python", + arg_name="code", + validation_fn=_python_code_arg_present, ) + ) - reqs.append( - Requirement( - description="The python tool call must include a code argument.", - validation_fn=_validate_python_tool_has_code_arg, - check_only=False, - ) + reqs.append( + Requirement( + description="The Python code must parse correctly.", + validation_fn=_make_code_parses_validator(), + check_only=False, ) + ) - # Code quality requirements (deterministic) - reqs.append( - Requirement( - description="The Python code must parse correctly.", - validation_fn=_make_code_parses_validator(), - check_only=False, - ) - ) - - # Import validation (deterministic) - if self.allowed_imports is not None: - reqs.append( - Requirement( - description=f"Imports must be from allowed list: {', '.join(self.allowed_imports)}", - validation_fn=_make_imports_allowed_validator(self.allowed_imports), - check_only=False, - ) - ) - - # Matplotlib-specific requirements (deterministic) + if allowed_imports is not None: reqs.append( Requirement( - description=( - "If using pyplot, must set headless backend and use savefig." - ), - validation_fn=_make_matplotlib_headless_validator(self.output_path), + description=f"Imports must be from allowed list: {', '.join(allowed_imports)}", + validation_fn=_make_imports_allowed_validator(allowed_imports), check_only=False, ) ) - reqs.append( - Requirement( - description="If creating plots, must call savefig to save them.", - validation_fn=_make_plots_saved_validator(self.output_path), - check_only=False, - ) + reqs.extend( + python_plotting_requirements( + output_path=output_path, check_output_artifacts=check_output_artifacts ) + ) - # Output artifact validation (deterministic, post-execution) - if self._check_output_artifacts and self.output_path: - reqs.append( - Requirement( - description=f"Output file must be created at {self.output_path}", - validation_fn=_make_output_artifacts_validator(self.output_path), - check_only=False, - ) - ) - - # Output limiting (deterministic) - reqs.append( - Requirement( - description=f"Output must not exceed {self.output_limit_bytes} bytes.", - validation_fn=_make_output_limit_validator(self.output_limit_bytes), - check_only=False, - ) + reqs.append( + Requirement( + description=f"Output must not exceed {output_limit_bytes} bytes.", + validation_fn=_make_output_limit_validator(output_limit_bytes), + check_only=False, ) + ) - return reqs - - def __repr__(self) -> str: - """Return a developer-readable representation.""" - return ( - f"PythonToolRequirements(" - f"output_path={self.output_path!r}, " - f"allowed_imports={self.allowed_imports!r}, " - f"output_limit_bytes={self.output_limit_bytes}, " - f"requirements={len(self.requirements)} items" - f")" - ) + return reqs diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index b256f7d52..008d4ff30 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -20,7 +20,7 @@ from typing import Any from ...core import MelleaLogger -from ...helpers import get_unauthorized_imports +from ..requirements.imports import get_unauthorized_imports logger = MelleaLogger.get_logger() diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index 49056d93b..6d47e19fe 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -1,18 +1,26 @@ """Tests for Python tool requirements bundle.""" import tempfile +from collections.abc import Callable from pathlib import Path - -from mellea.core import Context, ModelOutputThunk -from mellea.helpers import get_unauthorized_imports +from typing import Any, cast + +from mellea.core import ( + Context, + ModelOutputThunk, + ModelToolCall, + Requirement, + ValidationResult, +) from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.imports import get_unauthorized_imports from mellea.stdlib.requirements.python_tools import ( - PythonToolRequirements, _calls_savefig, _code_parses, _sets_headless_backend, _uses_pyplot_plot, _uses_pyplot_show, + python_tool_requirements, ) @@ -23,6 +31,60 @@ def from_model(content: str) -> Context: return ctx +def requirement_description(requirement: Requirement) -> str: + """Return a non-optional description for test filtering.""" + assert requirement.description is not None + return requirement.description + + +def validation_fn(requirement: Requirement) -> Callable[[Context], ValidationResult]: + """Return a requirement validation function for tests.""" + assert requirement.validation_fn is not None + return requirement.validation_fn + + +def validation_reason(result: ValidationResult) -> str: + """Return a non-optional validation reason for assertions.""" + assert result.reason is not None + return result.reason + + +class _DummyTool: + """Minimal tool stub for constructing ModelToolCall in tests.""" + + def run(self, **kwargs: Any) -> None: + return None + + +def python_tool_call(code: str | None = None, **extra_args: str) -> ModelToolCall: + """Create a typed python tool call for tests.""" + args: dict[str, str] = dict(extra_args) + if code is not None: + args["code"] = code + return ModelToolCall(name="python", func=cast(Any, _DummyTool()), args=args) + + +def requirements_matching( + substring: str, + *, + output_path: str | None = None, + allowed_imports: list[str] | None = None, + output_limit_bytes: int = 50_000, + check_output_artifacts: bool | None = None, +) -> list[Requirement]: + """Return requirements whose description contains the substring.""" + return [ + requirement + for requirement in python_tool_requirements( + output_path=output_path, + allowed_imports=allowed_imports, + output_limit_bytes=output_limit_bytes, + check_output_artifacts=check_output_artifacts, + ) + if substring in requirement_description(requirement).lower() + ] + + # region: Helper function tests @@ -149,45 +211,45 @@ def test_savefig_not_detected(self): # endregion -# region: PythonToolRequirements tests +# region: python_tool_requirements tests class TestPythonToolRequirementsBasic: - """Basic tests for PythonToolRequirements bundle.""" + """Basic tests for python_tool_requirements.""" def test_initialization(self): - """Bundle should initialize with default settings.""" - bundle = PythonToolRequirements() - assert bundle.requirements is not None - assert len(bundle.requirements) > 0 - - def test_with_output_path(self): - """Bundle should accept output_path parameter.""" - bundle = PythonToolRequirements(output_path="/tmp/plot.png") - assert bundle.output_path == "/tmp/plot.png" - - def test_with_allowed_imports(self): - """Bundle should accept allowed_imports parameter.""" - allowed = ["numpy", "matplotlib"] - bundle = PythonToolRequirements(allowed_imports=allowed) - assert bundle.allowed_imports == allowed - - def test_output_artifact_checking_enabled_by_default(self): - """Output artifact checking should be enabled if output_path is set.""" - bundle = PythonToolRequirements(output_path="/tmp/plot.png") - assert bundle._check_output_artifacts is True + """Factory should return requirements with default settings.""" + requirements = python_tool_requirements() + assert requirements + assert len(requirements) > 0 + + def test_with_output_path_enables_artifact_requirement(self): + """Output path should enable artifact validation by default.""" + requirements = python_tool_requirements(output_path="/tmp/plot.png") + artifact_reqs = [ + r + for r in requirements + if "output file" in requirement_description(r).lower() + ] + assert len(artifact_reqs) == 1 - def test_output_artifact_checking_disabled_by_default(self): - """Output artifact checking should be disabled if no output_path.""" - bundle = PythonToolRequirements() - assert bundle._check_output_artifacts is False + def test_without_output_path_disables_artifact_requirement(self): + """Artifact validation should be absent without output_path.""" + requirements = python_tool_requirements() + artifact_reqs = [ + r + for r in requirements + if "output file" in requirement_description(r).lower() + ] + assert len(artifact_reqs) == 0 - def test_repr(self): - """Bundle should have a readable repr.""" - bundle = PythonToolRequirements(output_path="/tmp/plot.png") - repr_str = repr(bundle) - assert "PythonToolRequirements" in repr_str - assert "/tmp/plot.png" in repr_str + def test_with_allowed_imports_adds_import_requirement(self): + """Allowed imports should add an import validation requirement.""" + requirements = python_tool_requirements(allowed_imports=["numpy", "matplotlib"]) + import_reqs = [ + r for r in requirements if "import" in requirement_description(r).lower() + ] + assert len(import_reqs) > 0 # endregion @@ -201,29 +263,27 @@ class TestMustInvokePythonTool: def test_tool_not_called(self): """Should fail if python tool not called.""" - bundle = PythonToolRequirements() - req = bundle.requirements[0] + req = python_tool_requirements()[0] ctx = from_model("Here is the code:\n```python\nprint('hello')\n```") - result = req.validation_fn(ctx) + result = validation_fn(req)(ctx) assert result.as_bool() is False - reason_lower = result.reason.lower() - assert "did not invoke" in reason_lower or "did not call" in reason_lower + reason_lower = validation_reason(result).lower() + assert "no tool calls" in reason_lower or "did not call" in reason_lower def test_python_tool_called(self): """Should pass if python tool is called.""" - bundle = PythonToolRequirements() - req = bundle.requirements[0] + req = python_tool_requirements()[0] ctx = ChatContext() - call_obj = type("Call", (), {"args": {"code": "print('hi')"}})() output = ModelOutputThunk( - value="I'll execute this code", tool_calls={"python": call_obj} + value="I'll execute this code", + tool_calls={"python": python_tool_call("print('hi')")}, ) ctx = ctx.add(output) - result = req.validation_fn(ctx) + result = validation_fn(req)(ctx) assert result.as_bool() is True @@ -232,33 +292,31 @@ class TestPythonToolHasCodeArg: def test_missing_code_argument(self): """Should fail if python tool call has no code argument.""" - bundle = PythonToolRequirements() - req = bundle.requirements[1] + req = python_tool_requirements()[1] ctx = ChatContext() - call_obj = type("Call", (), {"args": {"other": "value"}})() output = ModelOutputThunk( - value="I'll execute this", tool_calls={"python": call_obj} + value="I'll execute this", + tool_calls={"python": python_tool_call(other="value")}, ) ctx = ctx.add(output) - result = req.validation_fn(ctx) + result = validation_fn(req)(ctx) assert result.as_bool() is False - assert "code" in result.reason.lower() + assert "code" in validation_reason(result).lower() def test_has_code_argument(self): """Should pass if python tool call has code argument.""" - bundle = PythonToolRequirements() - req = bundle.requirements[1] + req = python_tool_requirements()[1] ctx = ChatContext() - call_obj = type("Call", (), {"args": {"code": "print('hi')"}})() output = ModelOutputThunk( - value="I'll execute this", tool_calls={"python": call_obj} + value="I'll execute this", + tool_calls={"python": python_tool_call("print('hi')")}, ) ctx = ctx.add(output) - result = req.validation_fn(ctx) + result = validation_fn(req)(ctx) assert result.as_bool() is True @@ -267,67 +325,57 @@ class TestCodeParsesRequirement: def test_valid_code(self): """Valid code should pass.""" - bundle = PythonToolRequirements() - parse_reqs = [ - r for r in bundle.requirements if "parse" in r.description.lower() - ] + parse_reqs = requirements_matching("parse") parse_req = parse_reqs[0] ctx = from_model("```python\nx = 1\nprint(x)\n```") - result = parse_req.validation_fn(ctx) + result = validation_fn(parse_req)(ctx) assert result.as_bool() is True def test_syntax_error(self): """Syntax errors should be caught.""" - bundle = PythonToolRequirements() - parse_reqs = [ - r for r in bundle.requirements if "parse" in r.description.lower() - ] + parse_reqs = requirements_matching("parse") parse_req = parse_reqs[0] ctx = from_model("```python\ndef foo(\n return 42\n```") - result = parse_req.validation_fn(ctx) + result = validation_fn(parse_req)(ctx) assert result.as_bool() is False - assert "syntax" in result.reason.lower() + assert "syntax" in validation_reason(result).lower() def test_valid_code_from_tool_calls(self): """Valid code in tool_calls should parse.""" - bundle = PythonToolRequirements() - parse_reqs = [ - r for r in bundle.requirements if "parse" in r.description.lower() - ] + parse_reqs = requirements_matching("parse") parse_req = parse_reqs[0] # Create context with tool_calls instead of markdown ctx = ChatContext() - tool_call = type("Call", (), {"args": {"code": "x = 1\nprint(x)"}})() - output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("x = 1\nprint(x)")} + ) ctx = ctx.add(output) - result = parse_req.validation_fn(ctx) + result = validation_fn(parse_req)(ctx) assert result.as_bool() is True def test_syntax_error_from_tool_calls(self): """Syntax errors in tool_calls should be caught.""" - bundle = PythonToolRequirements() - parse_reqs = [ - r for r in bundle.requirements if "parse" in r.description.lower() - ] + parse_reqs = requirements_matching("parse") parse_req = parse_reqs[0] # Create context with tool_calls containing syntax error ctx = ChatContext() - tool_call = type("Call", (), {"args": {"code": "def foo(\n return 42"}})() - output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("def foo(\n return 42")} + ) ctx = ctx.add(output) - result = parse_req.validation_fn(ctx) + result = validation_fn(parse_req)(ctx) assert result.as_bool() is False - assert "syntax" in result.reason.lower() + assert "syntax" in validation_reason(result).lower() class TestImportAllowlistRequirement: @@ -336,10 +384,10 @@ class TestImportAllowlistRequirement: def test_allowed_imports(self): """Allowed imports should pass.""" allowed = ["numpy", "matplotlib"] - bundle = PythonToolRequirements(allowed_imports=allowed) - import_reqs = [ - r for r in bundle.requirements if "import" in r.description.lower() + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() ] assert len(import_reqs) > 0 import_req = import_reqs[0] @@ -347,67 +395,69 @@ def test_allowed_imports(self): ctx = from_model( "```python\nimport numpy\nimport matplotlib.pyplot as plt\n```" ) - result = import_req.validation_fn(ctx) + result = validation_fn(import_req)(ctx) assert result.as_bool() is True def test_unauthorized_imports(self): """Unauthorized imports should fail.""" allowed = ["numpy"] - bundle = PythonToolRequirements(allowed_imports=allowed) - import_reqs = [ - r for r in bundle.requirements if "import" in r.description.lower() + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() ] import_req = import_reqs[0] ctx = from_model("```python\nimport subprocess\n```") - result = import_req.validation_fn(ctx) + result = validation_fn(import_req)(ctx) assert result.as_bool() is False - assert "subprocess" in result.reason + assert "subprocess" in validation_reason(result) def test_allowed_imports_from_tool_calls(self): """Allowed imports in tool_calls should pass.""" allowed = ["numpy", "matplotlib"] - bundle = PythonToolRequirements(allowed_imports=allowed) - import_reqs = [ - r for r in bundle.requirements if "import" in r.description.lower() + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() ] import_req = import_reqs[0] # Create context with tool_calls ctx = ChatContext() code = "import numpy\nimport matplotlib.pyplot as plt\nprint(numpy.pi)" - tool_call = type("Call", (), {"args": {"code": code}})() - output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call(code)} + ) ctx = ctx.add(output) - result = import_req.validation_fn(ctx) + result = validation_fn(import_req)(ctx) assert result.as_bool() is True def test_unauthorized_imports_from_tool_calls(self): """Unauthorized imports in tool_calls should fail.""" allowed = ["numpy"] - bundle = PythonToolRequirements(allowed_imports=allowed) - import_reqs = [ - r for r in bundle.requirements if "import" in r.description.lower() + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() ] import_req = import_reqs[0] # Create context with tool_calls ctx = ChatContext() - tool_call = type("Call", (), {"args": {"code": "import subprocess\n"}})() - output = ModelOutputThunk(value="", tool_calls={"python": tool_call}) + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("import subprocess\n")} + ) ctx = ctx.add(output) - result = import_req.validation_fn(ctx) + result = validation_fn(import_req)(ctx) assert result.as_bool() is False - assert "subprocess" in result.reason + assert "subprocess" in validation_reason(result) class TestMatplotlibHeadlessRequirement: @@ -415,25 +465,21 @@ class TestMatplotlibHeadlessRequirement: def test_plt_show_without_backend(self): """plt.show() without headless backend should fail.""" - bundle = PythonToolRequirements() - matplotlib_reqs = [ - r for r in bundle.requirements if "headless" in r.description.lower() - ] + matplotlib_reqs = requirements_matching("headless") assert len(matplotlib_reqs) > 0 matplotlib_req = matplotlib_reqs[0] ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") - result = matplotlib_req.validation_fn(ctx) + result = validation_fn(matplotlib_req)(ctx) assert result.as_bool() is False - assert "headless" in result.reason.lower() or "Agg" in result.reason + assert "headless" in validation_reason( + result + ).lower() or "Agg" in validation_reason(result) def test_plt_show_with_backend(self): """plt.show() with headless backend should pass.""" - bundle = PythonToolRequirements() - matplotlib_reqs = [ - r for r in bundle.requirements if "headless" in r.description.lower() - ] + matplotlib_reqs = requirements_matching("headless") matplotlib_req = matplotlib_reqs[0] ctx = from_model( @@ -444,20 +490,17 @@ def test_plt_show_with_backend(self): "plt.show()\n" "```" ) - result = matplotlib_req.validation_fn(ctx) + result = validation_fn(matplotlib_req)(ctx) assert result.as_bool() is True def test_no_plt_show(self): """Code without plt.show() should pass.""" - bundle = PythonToolRequirements() - matplotlib_reqs = [ - r for r in bundle.requirements if "headless" in r.description.lower() - ] + matplotlib_reqs = requirements_matching("headless") matplotlib_req = matplotlib_reqs[0] ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") - result = matplotlib_req.validation_fn(ctx) + result = validation_fn(matplotlib_req)(ctx) assert result.as_bool() is True @@ -467,42 +510,33 @@ class TestPlotsAreSavedRequirement: def test_plot_without_savefig(self): """Plotting without savefig should fail.""" - bundle = PythonToolRequirements() - plot_reqs = [ - r for r in bundle.requirements if "savefig" in r.description.lower() - ] + plot_reqs = requirements_matching("savefig") assert len(plot_reqs) > 0 plot_req = plot_reqs[0] ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") - result = plot_req.validation_fn(ctx) + result = validation_fn(plot_req)(ctx) assert result.as_bool() is False - assert "savefig" in result.reason.lower() + assert "savefig" in validation_reason(result).lower() def test_plot_with_savefig(self): """Plotting with savefig should pass.""" - bundle = PythonToolRequirements() - plot_reqs = [ - r for r in bundle.requirements if "savefig" in r.description.lower() - ] + plot_reqs = requirements_matching("savefig") plot_req = plot_reqs[0] ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") - result = plot_req.validation_fn(ctx) + result = validation_fn(plot_req)(ctx) assert result.as_bool() is True def test_no_plotting(self): """Code without plotting should pass.""" - bundle = PythonToolRequirements() - plot_reqs = [ - r for r in bundle.requirements if "savefig" in r.description.lower() - ] + plot_reqs = requirements_matching("savefig") plot_req = plot_reqs[0] ctx = from_model("```python\nx = 1\nprint(x)\n```") - result = plot_req.validation_fn(ctx) + result = validation_fn(plot_req)(ctx) assert result.as_bool() is True @@ -515,18 +549,17 @@ def test_output_file_not_created(self): with tempfile.TemporaryDirectory() as tmpdir: output_path = str(Path(tmpdir) / "plot.png") - bundle = PythonToolRequirements(output_path=output_path) - artifact_reqs = [ - r for r in bundle.requirements if "output file" in r.description.lower() - ] + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) assert len(artifact_reqs) > 0 artifact_req = artifact_reqs[0] ctx = from_model("Code ran successfully") - result = artifact_req.validation_fn(ctx) + result = validation_fn(artifact_req)(ctx) assert result.as_bool() is False - assert output_path in result.reason + assert output_path in validation_reason(result) def test_output_file_exists_and_nonempty(self): """Should pass if output file exists and is non-empty.""" @@ -534,14 +567,13 @@ def test_output_file_exists_and_nonempty(self): output_path = str(Path(tmpdir) / "plot.png") Path(output_path).write_bytes(b"fake png data") - bundle = PythonToolRequirements(output_path=output_path) - artifact_reqs = [ - r for r in bundle.requirements if "output file" in r.description.lower() - ] + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) artifact_req = artifact_reqs[0] ctx = from_model("Code ran successfully") - result = artifact_req.validation_fn(ctx) + result = validation_fn(artifact_req)(ctx) assert result.as_bool() is True @@ -551,24 +583,20 @@ def test_output_file_empty(self): output_path = str(Path(tmpdir) / "plot.png") Path(output_path).write_bytes(b"") - bundle = PythonToolRequirements(output_path=output_path) - artifact_reqs = [ - r for r in bundle.requirements if "output file" in r.description.lower() - ] + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) artifact_req = artifact_reqs[0] ctx = from_model("Code ran successfully") - result = artifact_req.validation_fn(ctx) + result = validation_fn(artifact_req)(ctx) assert result.as_bool() is False - assert "empty" in result.reason.lower() + assert "empty" in validation_reason(result).lower() def test_output_artifact_disabled_without_output_path(self): """Output artifact requirement should not be present without output_path.""" - bundle = PythonToolRequirements() - artifact_reqs = [ - r for r in bundle.requirements if "output file" in r.description.lower() - ] + artifact_reqs = requirements_matching("output file") assert len(artifact_reqs) == 0 @@ -578,131 +606,107 @@ class TestOutputLimitValidator: def test_empty_output_passes(self): """No stdout/stderr should pass.""" ctx = ChatContext().add(ModelOutputThunk(value="response")) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) assert len(limit_reqs) > 0 limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is True def test_output_within_limit_passes(self): """Output under limit should pass.""" output = ModelOutputThunk(value="response") - output.stdout = "x" * 500 - output.stderr = "" + setattr(output, "stdout", "x" * 500) + setattr(output, "stderr", "") ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is True def test_stdout_exceeds_limit_fails(self): """Stdout exceeding limit should fail.""" output = ModelOutputThunk(value="response") - output.stdout = "x" * 1500 - output.stderr = "" + setattr(output, "stdout", "x" * 1500) + setattr(output, "stderr", "") ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is False - assert "exceeding" in result.reason.lower() - assert "1500" in result.reason + assert "exceeding" in validation_reason(result).lower() + assert "1500" in validation_reason(result) def test_stderr_exceeds_limit_fails(self): """Stderr exceeding limit should fail.""" output = ModelOutputThunk(value="response") - output.stdout = "" - output.stderr = "e" * 1500 + setattr(output, "stdout", "") + setattr(output, "stderr", "e" * 1500) ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is False - assert "exceeding" in result.reason.lower() + assert "exceeding" in validation_reason(result).lower() def test_combined_output_exceeds_limit_fails(self): """Combined stdout+stderr exceeding limit should fail.""" output = ModelOutputThunk(value="response") - output.stdout = "x" * 600 - output.stderr = "e" * 600 # Combined: 1200 bytes + setattr(output, "stdout", "x" * 600) + setattr(output, "stderr", "e" * 600) # Combined: 1200 bytes ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is False - assert "1200" in result.reason + assert "1200" in validation_reason(result) def test_utf8_multibyte_characters_counted_correctly(self): """Multibyte UTF-8 characters should be counted in bytes, not chars.""" output = ModelOutputThunk(value="response") - output.stdout = "🎉" * 100 # 4 bytes per emoji = 400 bytes - output.stderr = "" + setattr(output, "stdout", "🎉" * 100) # 4 bytes per emoji = 400 bytes + setattr(output, "stderr", "") ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=300) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=300) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is False # 400 > 300 - assert "exceeding" in result.reason.lower() + assert "exceeding" in validation_reason(result).lower() def test_limit_at_boundary(self): """Output exactly at limit should pass.""" output = ModelOutputThunk(value="response") - output.stdout = "x" * 1000 # Exactly 1000 bytes - output.stderr = "" + setattr(output, "stdout", "x" * 1000) # Exactly 1000 bytes + setattr(output, "stderr", "") ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is True def test_limit_one_byte_over_boundary_fails(self): """Output one byte over limit should fail.""" output = ModelOutputThunk(value="response") - output.stdout = "x" * 1001 # 1001 bytes - output.stderr = "" + setattr(output, "stdout", "x" * 1001) # 1001 bytes + setattr(output, "stderr", "") ctx = ChatContext().add(output) - bundle = PythonToolRequirements(output_limit_bytes=1000) - limit_reqs = [ - r for r in bundle.requirements if "exceed" in r.description.lower() - ] + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) limit_req = limit_reqs[0] - result = limit_req.validation_fn(ctx) + result = validation_fn(limit_req)(ctx) assert result.as_bool() is False From 577eb00562bd261678806539ae4b7963df615cd5 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 18:04:01 -0400 Subject: [PATCH 06/12] review comments Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/python_tools.py | 332 ++++++++++++++++++--- 1 file changed, 290 insertions(+), 42 deletions(-) diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 8e63b32ed..3f06082b6 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -11,22 +11,27 @@ from pathlib import Path from ...core import Context, Requirement, ValidationResult +from ..tools.interpreter import StaticAnalysisEnvironment from .imports import get_unauthorized_imports from .python_reqs import extract_python_code from .tool_reqs import tool_arg_validator, uses_tool def _code_parses(code: str) -> tuple[bool, str | None]: - """Check if code parses as valid Python. + """Check if code parses as valid Python using StaticAnalysisEnvironment. + + Validates syntax without executing code. Reuses StaticAnalysisEnvironment + to avoid duplicating AST parsing logic. Returns: (True, None) if code parses (False, error_message) if syntax error """ - try: - ast.parse(code) - return True, None - except SyntaxError as e: + env = StaticAnalysisEnvironment(allowed_imports=None) + result = env.execute(code, timeout=0) + + if not result.success and isinstance(result.analysis_result, SyntaxError): + e = result.analysis_result error_msg = f"Syntax error at line {e.lineno}: {e.msg}" if e.text: error_msg += f"\n {e.text.rstrip()}" @@ -34,6 +39,8 @@ def _code_parses(code: str) -> tuple[bool, str | None]: error_msg += "\n " + " " * (e.offset - 1) + "^" return False, error_msg + return True, None + def _strip_comments(code: str) -> str: """Remove Python comments from code while preserving strings. @@ -62,46 +69,209 @@ def _strip_comments(code: str) -> str: return "\n".join(result) -def _uses_pyplot_show(code: str) -> bool: - """Check if code calls plt.show() or matplotlib.pyplot.show().""" +def _find_attribute_calls(code: str, method_names: list[str]) -> bool: + """Check if code calls any of the specified methods using AST. + + Handles import aliases (e.g., `import matplotlib.pyplot as plt`) and + validates that methods are actually called, not just referenced. + + Args: + code: Python source code to analyze + method_names: Method names to look for (e.g., ["show", "savefig"]) + + Returns: + True if any of the methods are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class CallFinder(ast.NodeVisitor): + def __init__(self, method_names: list[str]): + self.method_names = set(method_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr in self.method_names: + self.found = True + self.generic_visit(node) + + finder = CallFinder(method_names) + finder.visit(tree) + return finder.found + + +def _find_function_calls(code: str, func_names: list[str]) -> bool: + """Check if code calls any of the specified functions using AST. + + Handles qualified names (e.g., `matplotlib.use()`) and detects actual + function calls, not just references. + + Args: + code: Python source code to analyze + func_names: Function names to look for (e.g., ["matplotlib.use"]) + + Returns: + True if any of the functions are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class FunctionCallFinder(ast.NodeVisitor): + def __init__(self, func_names: list[str]): + self.func_names = set(func_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + func_name = self._get_full_name(node.func) + if func_name in self.func_names: + self.found = True + self.generic_visit(node) + + def _get_full_name(self, node: ast.expr) -> str: + """Extract full qualified name from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + value_name = self._get_full_name(node.value) + if value_name: + return f"{value_name}.{node.attr}" + return node.attr + return "" + + finder = FunctionCallFinder(func_names) + finder.visit(tree) + return finder.found + + +def _code_contains_strings(code: str, patterns: list[str]) -> bool: + """Check if code contains any of the given string patterns. + + Args: + code: Python source code to search + patterns: List of string patterns to look for + + Returns: + True if any pattern is found in the code, False otherwise + """ clean_code = _strip_comments(code) - return "plt.show" in clean_code or ".show()" in clean_code + return any(pattern in clean_code for pattern in patterns) -def _sets_headless_backend(code: str) -> bool: - """Check if code sets matplotlib to use a headless backend.""" +def _code_contains_all_strings(code: str, patterns: list[str]) -> bool: + """Check if code contains all of the given string patterns. + + Args: + code: Python source code to search + patterns: List of string patterns that must all be present + + Returns: + True if all patterns are found in the code, False otherwise + """ clean_code = _strip_comments(code) - headless_backends = ("Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg") - for backend in headless_backends: - if ( - f"matplotlib.use('{backend}')" in clean_code - or f'matplotlib.use("{backend}")' in clean_code - ): - return True + return all(pattern in clean_code for pattern in patterns) + + +def _uses_pyplot_show(code: str) -> bool: + """Check if code calls plt.show() or similar show() methods. + + Uses AST analysis to robustly detect show() calls regardless of import + aliases (e.g., `import matplotlib.pyplot as mpl`). AST approach detects + actual method calls, avoiding false positives from string literals. + Falls back to string matching only if code doesn't parse. + """ + if _find_attribute_calls(code, ["show"]): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, ["plt.show", ".show()"]) return False -def _uses_pyplot_plot(code: str) -> bool: - """Check if code calls pyplot plotting functions.""" - plot_functions = ( - "plt.plot", - "plt.bar", - "plt.scatter", - "plt.hist", - "plt.imshow", - "plt.figure", - "plt.subplot", - ".plot(", - ".bar(", - ".scatter(", - ".hist(", +def _sets_headless_backend(code: str) -> bool: + """Check if code sets matplotlib to use a headless backend. + + Uses AST analysis to detect matplotlib.use() calls with headless backends. + Handles various matplotlib import styles and fallback to string matching. + """ + if _find_function_calls(code, ["matplotlib.use"]): + headless_backends = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} + + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings( + code, [f"matplotlib.use('{b}')" for b in headless_backends] + ) + + class BackendFinder(ast.NodeVisitor): + def __init__(self): + self.has_headless = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr == "use": + if isinstance(node.func.value, ast.Name): + if node.func.value.id == "matplotlib": + if node.args and isinstance(node.args[0], ast.Constant): + if node.args[0].value in headless_backends: + self.has_headless = True + self.generic_visit(node) + + finder = BackendFinder() + finder.visit(tree) + if finder.has_headless: + return True + + return _code_contains_strings( + code, + [ + f"matplotlib.use('{b}')" + for b in ["Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"] + ], ) - return any(func in code for func in plot_functions) + + +def _uses_pyplot_plot(code: str) -> bool: + """Check if code calls pyplot plotting functions. + + Uses AST analysis to detect plot-related method calls. Handles import + aliases and detects actual method calls, avoiding false positives from + string literals or method references. Falls back to string matching + only if code doesn't parse. + """ + plot_methods = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} + if _find_attribute_calls(code, list(plot_methods)): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings( + code, [f".{m}(" for m in plot_methods] + [f"plt.{m}" for m in plot_methods] + ) + return False def _calls_savefig(code: str) -> bool: - """Check if code calls plt.savefig() or fig.savefig().""" - return "savefig" in code + """Check if code calls plt.savefig() or fig.savefig(). + + Uses AST analysis to robustly detect savefig() calls regardless of + how matplotlib was imported. Detects actual method calls, avoiding + false positives from string literals. Falls back to string matching + only if code doesn't parse. + """ + if _find_attribute_calls(code, ["savefig"]): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, ["savefig"]) + return False # region Individual Requirement Validators @@ -176,8 +346,35 @@ def validate(ctx: Context) -> ValidationResult: def _make_matplotlib_headless_validator( output_path: str | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, ) -> Callable[[Context], ValidationResult]: - """Create a validator that checks matplotlib uses headless backend.""" + r"""Create a validator that checks matplotlib uses headless backend. + + Args: + output_path: Path where plots should be saved + show_patterns: Patterns indicating plt.show() calls (e.g., ["plt.show", ".show()"]) + backend_patterns: Patterns indicating headless backend setup (e.g., ["matplotlib.use('Agg')", "matplotlib.use(\"Agg\")"]) + """ + if show_patterns is None: + show_patterns = ["plt.show", ".show()"] + if backend_patterns is None: + backend_patterns = [ + "matplotlib.use('Agg')", + 'matplotlib.use("Agg")', + "matplotlib.use('Svg')", + 'matplotlib.use("Svg")', + "matplotlib.use('Cairo')", + 'matplotlib.use("Cairo")', + "matplotlib.use('PDF')", + 'matplotlib.use("PDF")', + "matplotlib.use('PS')", + 'matplotlib.use("PS")', + "matplotlib.use('WebAgg')", + 'matplotlib.use("WebAgg")', + "matplotlib.use('nbAgg')", + 'matplotlib.use("nbAgg")', + ] def validate(ctx: Context) -> ValidationResult: extraction_result = extract_python_code(ctx) @@ -185,7 +382,10 @@ def validate(ctx: Context) -> ValidationResult: return ValidationResult(result=True) code = extraction_result.reason - if _uses_pyplot_show(code) and not _sets_headless_backend(code): + has_show = _code_contains_strings(code, show_patterns) + has_backend = _code_contains_strings(code, backend_patterns) + + if has_show and not has_backend: savefig_instruction = ( f"plt.savefig('{output_path}'); plt.close()" if output_path @@ -208,8 +408,32 @@ def validate(ctx: Context) -> ValidationResult: def _make_plots_saved_validator( output_path: str | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, ) -> Callable[[Context], ValidationResult]: - """Create a validator that checks if code saves plots to a file.""" + """Create a validator that checks if code saves plots to a file. + + Args: + output_path: Path where plots should be saved + plot_patterns: Patterns indicating plot creation (e.g., ["plt.plot", "plt.scatter"]) + save_patterns: Patterns indicating plot saving (e.g., ["savefig"]) + """ + if plot_patterns is None: + plot_patterns = [ + "plt.plot", + "plt.bar", + "plt.scatter", + "plt.hist", + "plt.imshow", + "plt.figure", + "plt.subplot", + ".plot(", + ".bar(", + ".scatter(", + ".hist(", + ] + if save_patterns is None: + save_patterns = ["savefig"] def validate(ctx: Context) -> ValidationResult: extraction_result = extract_python_code(ctx) @@ -217,7 +441,10 @@ def validate(ctx: Context) -> ValidationResult: return ValidationResult(result=True) code = extraction_result.reason - if _uses_pyplot_plot(code) and not _calls_savefig(code): + has_plot = _code_contains_strings(code, plot_patterns) + has_save = _code_contains_strings(code, save_patterns) + + if has_plot and not has_save: savefig_instruction = ( f"plt.savefig('{output_path}')\n plt.close()" if output_path @@ -236,15 +463,34 @@ def validate(ctx: Context) -> ValidationResult: def python_plotting_requirements( - output_path: str | None = None, *, check_output_artifacts: bool | None = None + output_path: str | None = None, + *, + check_output_artifacts: bool | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, ) -> list[Requirement]: - """Build plotting-specific requirements for Python tool responses.""" + """Build plotting-specific requirements for Python tool responses. + + Args: + output_path: Path where plots should be saved + check_output_artifacts: Whether to verify the output file exists + show_patterns: Patterns indicating plt.show() calls + backend_patterns: Patterns indicating headless backend setup + plot_patterns: Patterns indicating plot creation + save_patterns: Patterns indicating plot saving (e.g., savefig) + """ reqs: list[Requirement] = [] reqs.append( Requirement( description="If using pyplot, must set headless backend and use savefig.", - validation_fn=_make_matplotlib_headless_validator(output_path), + validation_fn=_make_matplotlib_headless_validator( + output_path, + show_patterns=show_patterns, + backend_patterns=backend_patterns, + ), check_only=False, ) ) @@ -252,7 +498,9 @@ def python_plotting_requirements( reqs.append( Requirement( description="If creating plots, must call savefig to save them.", - validation_fn=_make_plots_saved_validator(output_path), + validation_fn=_make_plots_saved_validator( + output_path, plot_patterns=plot_patterns, save_patterns=save_patterns + ), check_only=False, ) ) From 41f09ccc4711b5144b60e4ce90fd444d85d6d19d Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 18:26:58 -0400 Subject: [PATCH 07/12] matplotlib restructure Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/__init__.py | 3 +- .../stdlib/requirements/plotting/__init__.py | 9 + .../requirements/plotting/matplotlib.py | 431 +++++++++++++++++ mellea/stdlib/requirements/python_tools.py | 439 +----------------- test/stdlib/requirements/test_python_tools.py | 6 +- 5 files changed, 449 insertions(+), 439 deletions(-) create mode 100644 mellea/stdlib/requirements/plotting/__init__.py create mode 100644 mellea/stdlib/requirements/plotting/matplotlib.py diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index 46177ff88..4c9fec263 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -3,8 +3,9 @@ # Import from core for ergonomics. from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table +from .plotting import python_plotting_requirements from .python_reqs import PythonExecutionReq -from .python_tools import python_plotting_requirements, python_tool_requirements +from .python_tools import python_tool_requirements from .requirement import ( ALoraRequirement, LLMaJRequirement, diff --git a/mellea/stdlib/requirements/plotting/__init__.py b/mellea/stdlib/requirements/plotting/__init__.py new file mode 100644 index 000000000..c10a7b080 --- /dev/null +++ b/mellea/stdlib/requirements/plotting/__init__.py @@ -0,0 +1,9 @@ +"""Plotting-specific requirements for Python tool validation. + +Provides matplotlib and plotting-focused requirement factories separate from +generic Python tool requirements. +""" + +from .matplotlib import python_plotting_requirements + +__all__ = ["python_plotting_requirements"] diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py new file mode 100644 index 000000000..682850d43 --- /dev/null +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -0,0 +1,431 @@ +"""Matplotlib-specific requirement validators for plotting code. + +Provides plotting validation including headless backend detection, plot saving, +and output artifact validation. Uses AST-based analysis for robust detection +of matplotlib operations, with string matching fallback for edge cases. +""" + +import ast +from collections.abc import Callable +from pathlib import Path + +from ....core import Context, Requirement, ValidationResult +from ..python_reqs import extract_python_code + + +def _strip_comments(code: str) -> str: + """Remove Python comments from code while preserving strings. + + Splits code by lines and removes comments (text after # that's not in a string). + Handles both single and double quoted strings. + """ + lines = code.split("\n") + result = [] + for line in lines: + in_string = False + string_char = None + for i, char in enumerate(line): + if char in ('"', "'") and (i == 0 or line[i - 1] != "\\"): + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + elif char == "#" and not in_string: + result.append(line[:i]) + break + else: + result.append(line) + return "\n".join(result) + + +def _find_attribute_calls(code: str, method_names: list[str]) -> bool: + """Check if code calls any of the specified methods using AST. + + Handles import aliases (e.g., `import matplotlib.pyplot as plt`) and + validates that methods are actually called, not just referenced. + + Args: + code: Python source code to analyze + method_names: Method names to look for (e.g., ["show", "savefig"]) + + Returns: + True if any of the methods are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class CallFinder(ast.NodeVisitor): + def __init__(self, method_names: list[str]): + self.method_names = set(method_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr in self.method_names: + self.found = True + self.generic_visit(node) + + finder = CallFinder(method_names) + finder.visit(tree) + return finder.found + + +def _find_function_calls(code: str, func_names: list[str]) -> bool: + """Check if code calls any of the specified functions using AST. + + Handles qualified names (e.g., `matplotlib.use()`) and detects actual + function calls, not just references. + + Args: + code: Python source code to analyze + func_names: Function names to look for (e.g., ["matplotlib.use"]) + + Returns: + True if any of the functions are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class FunctionCallFinder(ast.NodeVisitor): + def __init__(self, func_names: list[str]): + self.func_names = set(func_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + func_name = self._get_full_name(node.func) + if func_name in self.func_names: + self.found = True + self.generic_visit(node) + + def _get_full_name(self, node: ast.expr) -> str: + """Extract full qualified name from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + value_name = self._get_full_name(node.value) + if value_name: + return f"{value_name}.{node.attr}" + return node.attr + return "" + + finder = FunctionCallFinder(func_names) + finder.visit(tree) + return finder.found + + +def _code_contains_strings(code: str, patterns: list[str]) -> bool: + """Check if code contains any of the given string patterns. + + Args: + code: Python source code to search + patterns: List of string patterns to look for + + Returns: + True if any pattern is found in the code, False otherwise + """ + clean_code = _strip_comments(code) + return any(pattern in clean_code for pattern in patterns) + + +def _uses_pyplot_show(code: str) -> bool: + """Check if code calls plt.show() or similar show() methods. + + Uses AST analysis to robustly detect show() calls regardless of import + aliases (e.g., `import matplotlib.pyplot as mpl`). AST approach detects + actual method calls, avoiding false positives from string literals. + Falls back to string matching only if code doesn't parse. + """ + if _find_attribute_calls(code, ["show"]): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, ["plt.show", ".show()"]) + return False + + +def _sets_headless_backend(code: str) -> bool: + """Check if code sets matplotlib to use a headless backend. + + Uses AST analysis to detect matplotlib.use() calls with headless backends. + Handles various matplotlib import styles and fallback to string matching. + """ + if _find_function_calls(code, ["matplotlib.use"]): + headless_backends = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} + + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings( + code, [f"matplotlib.use('{b}')" for b in headless_backends] + ) + + class BackendFinder(ast.NodeVisitor): + def __init__(self): + self.has_headless = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr == "use": + if isinstance(node.func.value, ast.Name): + if node.func.value.id == "matplotlib": + if node.args and isinstance(node.args[0], ast.Constant): + if node.args[0].value in headless_backends: + self.has_headless = True + self.generic_visit(node) + + finder = BackendFinder() + finder.visit(tree) + if finder.has_headless: + return True + + return _code_contains_strings( + code, + [ + f"matplotlib.use('{b}')" + for b in ["Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"] + ], + ) + + +def _uses_pyplot_plot(code: str) -> bool: + """Check if code calls pyplot plotting functions. + + Uses AST analysis to detect plot-related method calls. Handles import + aliases and detects actual method calls, avoiding false positives from + string literals or method references. Falls back to string matching + only if code doesn't parse. + """ + plot_methods = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} + if _find_attribute_calls(code, list(plot_methods)): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings( + code, [f".{m}(" for m in plot_methods] + [f"plt.{m}" for m in plot_methods] + ) + return False + + +def _calls_savefig(code: str) -> bool: + """Check if code calls plt.savefig() or fig.savefig(). + + Uses AST analysis to robustly detect savefig() calls regardless of + how matplotlib was imported. Detects actual method calls, avoiding + false positives from string literals. Falls back to string matching + only if code doesn't parse. + """ + if _find_attribute_calls(code, ["savefig"]): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, ["savefig"]) + return False + + +def _make_matplotlib_headless_validator( + output_path: str | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, +) -> Callable[[Context], ValidationResult]: + r"""Create a validator that checks matplotlib uses headless backend. + + Args: + output_path: Path where plots should be saved + show_patterns: Patterns indicating plt.show() calls (e.g., ["plt.show", ".show()"]) + backend_patterns: Patterns indicating headless backend setup (e.g., ["matplotlib.use('Agg')", "matplotlib.use(\"Agg\")"]) + """ + if show_patterns is None: + show_patterns = ["plt.show", ".show()"] + if backend_patterns is None: + backend_patterns = [ + "matplotlib.use('Agg')", + 'matplotlib.use("Agg")', + "matplotlib.use('Svg')", + 'matplotlib.use("Svg")', + "matplotlib.use('Cairo')", + 'matplotlib.use("Cairo")', + "matplotlib.use('PDF')", + 'matplotlib.use("PDF")', + "matplotlib.use('PS')", + 'matplotlib.use("PS")', + "matplotlib.use('WebAgg')", + 'matplotlib.use("WebAgg")', + "matplotlib.use('nbAgg')", + 'matplotlib.use("nbAgg")', + ] + + def validate(ctx: Context) -> ValidationResult: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult(result=True) + + code = extraction_result.reason + has_show = _code_contains_strings(code, show_patterns) + has_backend = _code_contains_strings(code, backend_patterns) + + if has_show and not has_backend: + savefig_instruction = ( + f"plt.savefig('{output_path}'); plt.close()" + if output_path + else "plt.savefig('{output_path}'); plt.close()" + ) + return ValidationResult( + result=False, + reason=f"Your code calls `plt.show()` but doesn't set a headless backend.\n" + f"This will fail in a headless environment (no display).\n\n" + f"Fix this by adding to the top of your code:\n" + f" import matplotlib\n" + f" matplotlib.use('Agg')\n\n" + f"Then replace `plt.show()` with `{savefig_instruction}`", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_plots_saved_validator( + output_path: str | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if code saves plots to a file. + + Args: + output_path: Path where plots should be saved + plot_patterns: Patterns indicating plot creation (e.g., ["plt.plot", "plt.scatter"]) + save_patterns: Patterns indicating plot saving (e.g., ["savefig"]) + """ + if plot_patterns is None: + plot_patterns = [ + "plt.plot", + "plt.bar", + "plt.scatter", + "plt.hist", + "plt.imshow", + "plt.figure", + "plt.subplot", + ".plot(", + ".bar(", + ".scatter(", + ".hist(", + ] + if save_patterns is None: + save_patterns = ["savefig"] + + def validate(ctx: Context) -> ValidationResult: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult(result=True) + + code = extraction_result.reason + has_plot = _code_contains_strings(code, plot_patterns) + has_save = _code_contains_strings(code, save_patterns) + + if has_plot and not has_save: + savefig_instruction = ( + f"plt.savefig('{output_path}')\n plt.close()" + if output_path + else "plt.savefig('{output_path}')\n plt.close()" + ) + return ValidationResult( + result=False, + reason=f"Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" + f"Add this before your plotting code or at the end:\n" + f" {savefig_instruction}", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_output_artifacts_validator( + output_path: str, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if output file exists post-execution.""" + + def validate(ctx: Context) -> ValidationResult: + path = Path(output_path) + + if not path.exists(): + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' was not created during execution.\n" + f"Make sure your code calls `plt.savefig('{output_path}')` to save the plot.", + ) + + if path.stat().st_size == 0: + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' exists but is empty.\n" + f"Check that your plot code executed correctly.", + ) + + return ValidationResult(result=True) + + return validate + + +def python_plotting_requirements( + output_path: str | None = None, + *, + check_output_artifacts: bool | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, +) -> list[Requirement]: + """Build plotting-specific requirements for Python tool responses. + + Args: + output_path: Path where plots should be saved + check_output_artifacts: Whether to verify the output file exists + show_patterns: Patterns indicating plt.show() calls + backend_patterns: Patterns indicating headless backend setup + plot_patterns: Patterns indicating plot creation + save_patterns: Patterns indicating plot saving (e.g., savefig) + """ + reqs: list[Requirement] = [] + + reqs.append( + Requirement( + description="If using pyplot, must set headless backend and use savefig.", + validation_fn=_make_matplotlib_headless_validator( + output_path, + show_patterns=show_patterns, + backend_patterns=backend_patterns, + ), + check_only=False, + ) + ) + + reqs.append( + Requirement( + description="If creating plots, must call savefig to save them.", + validation_fn=_make_plots_saved_validator( + output_path, plot_patterns=plot_patterns, save_patterns=save_patterns + ), + check_only=False, + ) + ) + + if check_output_artifacts and output_path: + reqs.append( + Requirement( + description=f"Output file must be created at {output_path}", + validation_fn=_make_output_artifacts_validator(output_path), + check_only=False, + ) + ) + + return reqs diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 3f06082b6..618b1bb87 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -2,17 +2,16 @@ This module provides generic requirements for Python-tool usage and code correctness. Plotting-specific checks are exposed separately through -``python_plotting_requirements(...)`` so they are not implied to be universal -Python-tool requirements. +``plotting.python_plotting_requirements(...)`` so they are not implied to be +universal Python-tool requirements. """ -import ast from collections.abc import Callable -from pathlib import Path from ...core import Context, Requirement, ValidationResult from ..tools.interpreter import StaticAnalysisEnvironment from .imports import get_unauthorized_imports +from .plotting import python_plotting_requirements from .python_reqs import extract_python_code from .tool_reqs import tool_arg_validator, uses_tool @@ -42,238 +41,6 @@ def _code_parses(code: str) -> tuple[bool, str | None]: return True, None -def _strip_comments(code: str) -> str: - """Remove Python comments from code while preserving strings. - - Splits code by lines and removes comments (text after # that's not in a string). - Handles both single and double quoted strings. - """ - lines = code.split("\n") - result = [] - for line in lines: - in_string = False - string_char = None - for i, char in enumerate(line): - if char in ('"', "'") and (i == 0 or line[i - 1] != "\\"): - if not in_string: - in_string = True - string_char = char - elif char == string_char: - in_string = False - string_char = None - elif char == "#" and not in_string: - result.append(line[:i]) - break - else: - result.append(line) - return "\n".join(result) - - -def _find_attribute_calls(code: str, method_names: list[str]) -> bool: - """Check if code calls any of the specified methods using AST. - - Handles import aliases (e.g., `import matplotlib.pyplot as plt`) and - validates that methods are actually called, not just referenced. - - Args: - code: Python source code to analyze - method_names: Method names to look for (e.g., ["show", "savefig"]) - - Returns: - True if any of the methods are called, False otherwise - """ - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - return False - - class CallFinder(ast.NodeVisitor): - def __init__(self, method_names: list[str]): - self.method_names = set(method_names) - self.found = False - - def visit_Call(self, node: ast.Call) -> None: - if isinstance(node.func, ast.Attribute): - if node.func.attr in self.method_names: - self.found = True - self.generic_visit(node) - - finder = CallFinder(method_names) - finder.visit(tree) - return finder.found - - -def _find_function_calls(code: str, func_names: list[str]) -> bool: - """Check if code calls any of the specified functions using AST. - - Handles qualified names (e.g., `matplotlib.use()`) and detects actual - function calls, not just references. - - Args: - code: Python source code to analyze - func_names: Function names to look for (e.g., ["matplotlib.use"]) - - Returns: - True if any of the functions are called, False otherwise - """ - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - return False - - class FunctionCallFinder(ast.NodeVisitor): - def __init__(self, func_names: list[str]): - self.func_names = set(func_names) - self.found = False - - def visit_Call(self, node: ast.Call) -> None: - func_name = self._get_full_name(node.func) - if func_name in self.func_names: - self.found = True - self.generic_visit(node) - - def _get_full_name(self, node: ast.expr) -> str: - """Extract full qualified name from an AST node.""" - if isinstance(node, ast.Name): - return node.id - elif isinstance(node, ast.Attribute): - value_name = self._get_full_name(node.value) - if value_name: - return f"{value_name}.{node.attr}" - return node.attr - return "" - - finder = FunctionCallFinder(func_names) - finder.visit(tree) - return finder.found - - -def _code_contains_strings(code: str, patterns: list[str]) -> bool: - """Check if code contains any of the given string patterns. - - Args: - code: Python source code to search - patterns: List of string patterns to look for - - Returns: - True if any pattern is found in the code, False otherwise - """ - clean_code = _strip_comments(code) - return any(pattern in clean_code for pattern in patterns) - - -def _code_contains_all_strings(code: str, patterns: list[str]) -> bool: - """Check if code contains all of the given string patterns. - - Args: - code: Python source code to search - patterns: List of string patterns that must all be present - - Returns: - True if all patterns are found in the code, False otherwise - """ - clean_code = _strip_comments(code) - return all(pattern in clean_code for pattern in patterns) - - -def _uses_pyplot_show(code: str) -> bool: - """Check if code calls plt.show() or similar show() methods. - - Uses AST analysis to robustly detect show() calls regardless of import - aliases (e.g., `import matplotlib.pyplot as mpl`). AST approach detects - actual method calls, avoiding false positives from string literals. - Falls back to string matching only if code doesn't parse. - """ - if _find_attribute_calls(code, ["show"]): - return True - try: - ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings(code, ["plt.show", ".show()"]) - return False - - -def _sets_headless_backend(code: str) -> bool: - """Check if code sets matplotlib to use a headless backend. - - Uses AST analysis to detect matplotlib.use() calls with headless backends. - Handles various matplotlib import styles and fallback to string matching. - """ - if _find_function_calls(code, ["matplotlib.use"]): - headless_backends = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} - - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings( - code, [f"matplotlib.use('{b}')" for b in headless_backends] - ) - - class BackendFinder(ast.NodeVisitor): - def __init__(self): - self.has_headless = False - - def visit_Call(self, node: ast.Call) -> None: - if isinstance(node.func, ast.Attribute): - if node.func.attr == "use": - if isinstance(node.func.value, ast.Name): - if node.func.value.id == "matplotlib": - if node.args and isinstance(node.args[0], ast.Constant): - if node.args[0].value in headless_backends: - self.has_headless = True - self.generic_visit(node) - - finder = BackendFinder() - finder.visit(tree) - if finder.has_headless: - return True - - return _code_contains_strings( - code, - [ - f"matplotlib.use('{b}')" - for b in ["Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"] - ], - ) - - -def _uses_pyplot_plot(code: str) -> bool: - """Check if code calls pyplot plotting functions. - - Uses AST analysis to detect plot-related method calls. Handles import - aliases and detects actual method calls, avoiding false positives from - string literals or method references. Falls back to string matching - only if code doesn't parse. - """ - plot_methods = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} - if _find_attribute_calls(code, list(plot_methods)): - return True - try: - ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings( - code, [f".{m}(" for m in plot_methods] + [f"plt.{m}" for m in plot_methods] - ) - return False - - -def _calls_savefig(code: str) -> bool: - """Check if code calls plt.savefig() or fig.savefig(). - - Uses AST analysis to robustly detect savefig() calls regardless of - how matplotlib was imported. Detects actual method calls, avoiding - false positives from string literals. Falls back to string matching - only if code doesn't parse. - """ - if _find_attribute_calls(code, ["savefig"]): - return True - try: - ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings(code, ["savefig"]) - return False - - # region Individual Requirement Validators @@ -344,206 +111,6 @@ def validate(ctx: Context) -> ValidationResult: return validate -def _make_matplotlib_headless_validator( - output_path: str | None = None, - show_patterns: list[str] | None = None, - backend_patterns: list[str] | None = None, -) -> Callable[[Context], ValidationResult]: - r"""Create a validator that checks matplotlib uses headless backend. - - Args: - output_path: Path where plots should be saved - show_patterns: Patterns indicating plt.show() calls (e.g., ["plt.show", ".show()"]) - backend_patterns: Patterns indicating headless backend setup (e.g., ["matplotlib.use('Agg')", "matplotlib.use(\"Agg\")"]) - """ - if show_patterns is None: - show_patterns = ["plt.show", ".show()"] - if backend_patterns is None: - backend_patterns = [ - "matplotlib.use('Agg')", - 'matplotlib.use("Agg")', - "matplotlib.use('Svg')", - 'matplotlib.use("Svg")', - "matplotlib.use('Cairo')", - 'matplotlib.use("Cairo")', - "matplotlib.use('PDF')", - 'matplotlib.use("PDF")', - "matplotlib.use('PS')", - 'matplotlib.use("PS")', - "matplotlib.use('WebAgg')", - 'matplotlib.use("WebAgg")', - "matplotlib.use('nbAgg')", - 'matplotlib.use("nbAgg")', - ] - - def validate(ctx: Context) -> ValidationResult: - extraction_result = extract_python_code(ctx) - if not extraction_result.as_bool() or extraction_result.reason is None: - return ValidationResult(result=True) - - code = extraction_result.reason - has_show = _code_contains_strings(code, show_patterns) - has_backend = _code_contains_strings(code, backend_patterns) - - if has_show and not has_backend: - savefig_instruction = ( - f"plt.savefig('{output_path}'); plt.close()" - if output_path - else "plt.savefig('{output_path}'); plt.close()" - ) - return ValidationResult( - result=False, - reason=f"Your code calls `plt.show()` but doesn't set a headless backend.\n" - f"This will fail in a headless environment (no display).\n\n" - f"Fix this by adding to the top of your code:\n" - f" import matplotlib\n" - f" matplotlib.use('Agg')\n\n" - f"Then replace `plt.show()` with `{savefig_instruction}`", - ) - - return ValidationResult(result=True) - - return validate - - -def _make_plots_saved_validator( - output_path: str | None = None, - plot_patterns: list[str] | None = None, - save_patterns: list[str] | None = None, -) -> Callable[[Context], ValidationResult]: - """Create a validator that checks if code saves plots to a file. - - Args: - output_path: Path where plots should be saved - plot_patterns: Patterns indicating plot creation (e.g., ["plt.plot", "plt.scatter"]) - save_patterns: Patterns indicating plot saving (e.g., ["savefig"]) - """ - if plot_patterns is None: - plot_patterns = [ - "plt.plot", - "plt.bar", - "plt.scatter", - "plt.hist", - "plt.imshow", - "plt.figure", - "plt.subplot", - ".plot(", - ".bar(", - ".scatter(", - ".hist(", - ] - if save_patterns is None: - save_patterns = ["savefig"] - - def validate(ctx: Context) -> ValidationResult: - extraction_result = extract_python_code(ctx) - if not extraction_result.as_bool() or extraction_result.reason is None: - return ValidationResult(result=True) - - code = extraction_result.reason - has_plot = _code_contains_strings(code, plot_patterns) - has_save = _code_contains_strings(code, save_patterns) - - if has_plot and not has_save: - savefig_instruction = ( - f"plt.savefig('{output_path}')\n plt.close()" - if output_path - else "plt.savefig('{output_path}')\n plt.close()" - ) - return ValidationResult( - result=False, - reason=f"Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" - f"Add this before your plotting code or at the end:\n" - f" {savefig_instruction}", - ) - - return ValidationResult(result=True) - - return validate - - -def python_plotting_requirements( - output_path: str | None = None, - *, - check_output_artifacts: bool | None = None, - show_patterns: list[str] | None = None, - backend_patterns: list[str] | None = None, - plot_patterns: list[str] | None = None, - save_patterns: list[str] | None = None, -) -> list[Requirement]: - """Build plotting-specific requirements for Python tool responses. - - Args: - output_path: Path where plots should be saved - check_output_artifacts: Whether to verify the output file exists - show_patterns: Patterns indicating plt.show() calls - backend_patterns: Patterns indicating headless backend setup - plot_patterns: Patterns indicating plot creation - save_patterns: Patterns indicating plot saving (e.g., savefig) - """ - reqs: list[Requirement] = [] - - reqs.append( - Requirement( - description="If using pyplot, must set headless backend and use savefig.", - validation_fn=_make_matplotlib_headless_validator( - output_path, - show_patterns=show_patterns, - backend_patterns=backend_patterns, - ), - check_only=False, - ) - ) - - reqs.append( - Requirement( - description="If creating plots, must call savefig to save them.", - validation_fn=_make_plots_saved_validator( - output_path, plot_patterns=plot_patterns, save_patterns=save_patterns - ), - check_only=False, - ) - ) - - if check_output_artifacts and output_path: - reqs.append( - Requirement( - description=f"Output file must be created at {output_path}", - validation_fn=_make_output_artifacts_validator(output_path), - check_only=False, - ) - ) - - return reqs - - -def _make_output_artifacts_validator( - output_path: str, -) -> Callable[[Context], ValidationResult]: - """Create a validator that checks if output file exists post-execution.""" - - def validate(ctx: Context) -> ValidationResult: - path = Path(output_path) - - if not path.exists(): - return ValidationResult( - result=False, - reason=f"The output file '{output_path}' was not created during execution.\n" - f"Make sure your code calls `plt.savefig('{output_path}')` to save the plot.", - ) - - if path.stat().st_size == 0: - return ValidationResult( - result=False, - reason=f"The output file '{output_path}' exists but is empty.\n" - f"Check that your plot code executed correctly.", - ) - - return ValidationResult(result=True) - - return validate - - def _make_output_limit_validator( limit_bytes: int, ) -> Callable[[Context], ValidationResult]: diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index 6d47e19fe..e13182fe7 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -14,12 +14,14 @@ ) from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.imports import get_unauthorized_imports -from mellea.stdlib.requirements.python_tools import ( +from mellea.stdlib.requirements.plotting.matplotlib import ( _calls_savefig, - _code_parses, _sets_headless_backend, _uses_pyplot_plot, _uses_pyplot_show, +) +from mellea.stdlib.requirements.python_tools import ( + _code_parses, python_tool_requirements, ) From 85c5c67f824c305de9eab4c415c190c05358cd96 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 19:54:50 -0400 Subject: [PATCH 08/12] review cmments Signed-off-by: Akihiko Kuroda --- docs/examples/tools/python_plotting_repair.py | 34 +++++-------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/docs/examples/tools/python_plotting_repair.py b/docs/examples/tools/python_plotting_repair.py index 26e958d04..b42d16099 100644 --- a/docs/examples/tools/python_plotting_repair.py +++ b/docs/examples/tools/python_plotting_repair.py @@ -1,7 +1,6 @@ # pytest: ollama, e2e, qualitative """Repair plotting code with Python-tool and plotting-specific requirements.""" -import asyncio import tempfile import traceback from pathlib import Path @@ -9,8 +8,6 @@ import mellea from mellea.backends import ModelOption from mellea.backends.tools import MelleaTool -from mellea.stdlib.components import Instruction -from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import ( python_plotting_requirements, python_tool_requirements, @@ -32,12 +29,12 @@ def python(code: str) -> ExecutionResult: return local_code_interpreter(code) -async def main(): +def main(): """Run the plotting repair example.""" with tempfile.TemporaryDirectory() as tmpdir: output_path = str(Path(tmpdir) / "plot.png") - m = mellea.start_session() + m = mellea.start_session(context_type="chat") requirements = [ *python_tool_requirements(allowed_imports=["numpy", "matplotlib", "math"]), @@ -55,19 +52,6 @@ async def main(): task_summary = ( f"Create a plot of sin(x) for x in 0..2π and save it to {output_path}" ) - description = f"""{task_summary} - -Requirements: -- Use the python tool to execute your code -- Import numpy and matplotlib -- Generate x values from 0 to 2π -- Plot sin(x) against x -- Save the plot to the specified file path - -Use the python tool with your complete code.""" - instruction = Instruction(description=description) - - ctx = ChatContext() print("=" * 70) print("Testing plotting-code repair with Python tool requirements") @@ -75,11 +59,11 @@ async def main(): print(f"Task: {task_summary}\n") try: - result = await sampling_strategy.sample( - action=instruction, - context=ctx, - backend=m.backend, + result = m.instruct( + task_summary, requirements=requirements, + strategy=sampling_strategy, + return_sampling_results=True, tool_calls=True, model_options={ModelOption.TOOLS: [MelleaTool.from_callable(python)]}, ) @@ -93,8 +77,8 @@ async def main(): print(result.result.value) print("-" * 70) - if Path(output_path).exists(): # noqa: ASYNC240 - file_size = Path(output_path).stat().st_size # noqa: ASYNC240 + if Path(output_path).exists(): + file_size = Path(output_path).stat().st_size print(f"\n✓ Output file created: {output_path}") print(f" File size: {file_size} bytes") else: @@ -147,6 +131,6 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + main() # Made with Bob From 8f391da971d34c5ba553b76d8426aa9f86849c43 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 11 May 2026 21:05:07 -0400 Subject: [PATCH 09/12] fix doc strings Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/plotting/matplotlib.py | 3 +++ mellea/stdlib/requirements/python_reqs.py | 10 +++++++++- mellea/stdlib/requirements/python_tools.py | 12 +++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py index 682850d43..b42b98388 100644 --- a/mellea/stdlib/requirements/plotting/matplotlib.py +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -394,6 +394,9 @@ def python_plotting_requirements( backend_patterns: Patterns indicating headless backend setup plot_patterns: Patterns indicating plot creation save_patterns: Patterns indicating plot saving (e.g., savefig) + + Returns: + List of Requirement objects that validate matplotlib usage and plot output. """ reqs: list[Requirement] = [] diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index 4ba2ee8bb..79fbcf002 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -59,7 +59,15 @@ def _score_code_block(code: str) -> int: def extract_python_code(ctx: Context) -> ValidationResult: - """Extract Python code from tool calls or markdown code blocks.""" + """Extract Python code from tool calls or markdown code blocks. + + Args: + ctx: Context object containing the LLM output to extract code from. + + Returns: + ValidationResult with result=True and the code as reason if extraction succeeds, + or result=False with an error message if no code is found. + """ last_output = ctx.last_output() if last_output is None: return ValidationResult(result=False, reason="No output found in context") diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 618b1bb87..8b86b0be0 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -151,7 +151,17 @@ def python_tool_requirements( output_limit_bytes: int = 50_000, check_output_artifacts: bool | None = None, ) -> list[Requirement]: - """Build requirements for Python code generation via the python tool.""" + """Build requirements for Python code generation via the python tool. + + Args: + output_path: Path where plotting output should be saved; enables plot-related checks. + allowed_imports: List of allowed import module names; if provided, code must only import these. + output_limit_bytes: Maximum bytes for stdout/stderr combined; defaults to 50KB. + check_output_artifacts: Whether to verify output file exists after execution; auto-enabled if output_path is set. + + Returns: + List of Requirement objects that validate python tool usage and code correctness. + """ reqs: list[Requirement] = [] if check_output_artifacts is None: From 8e24e97eaf4ecd72647abd6c8d15c88743947fd1 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 12 May 2026 18:29:00 -0400 Subject: [PATCH 10/12] review comments Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/imports.py | 34 ----- .../requirements/plotting/matplotlib.py | 141 ++++++++---------- mellea/stdlib/requirements/python_reqs.py | 93 ++++++------ mellea/stdlib/requirements/python_tools.py | 24 ++- mellea/stdlib/tools/interpreter.py | 43 +++++- test/stdlib/requirements/test_python_tools.py | 2 +- 6 files changed, 168 insertions(+), 169 deletions(-) delete mode 100644 mellea/stdlib/requirements/imports.py diff --git a/mellea/stdlib/requirements/imports.py b/mellea/stdlib/requirements/imports.py deleted file mode 100644 index d95ba16c4..000000000 --- a/mellea/stdlib/requirements/imports.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Import analysis helpers for Python requirements and execution environments.""" - -import ast - - -def get_unauthorized_imports( - code: str, allowed_imports: list[str] | None = None -) -> list[str]: - """Extract unauthorized top-level imports from Python code.""" - if allowed_imports is None: - return [] - - unauthorized: set[str] = set() - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - # Syntax errors are validated separately by dedicated validators. - return [] - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - module = alias.name.split(".")[0] - if module not in allowed_imports: - unauthorized.add(module) - elif isinstance(node, ast.ImportFrom) and node.module: - module = node.module.split(".")[0] - if module not in allowed_imports: - unauthorized.add(module) - - return sorted(unauthorized) - - -# Made with Bob diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py index b42b98388..9cc71befa 100644 --- a/mellea/stdlib/requirements/plotting/matplotlib.py +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -12,6 +12,12 @@ from ....core import Context, Requirement, ValidationResult from ..python_reqs import extract_python_code +# Headless matplotlib backends that don't require a display server +HEADLESS_BACKENDS = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} + +# Pyplot plot creation methods +PYPLOT_PLOT_METHODS = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} + def _strip_comments(code: str) -> str: """Remove Python comments from code while preserving strings. @@ -140,6 +146,8 @@ def _uses_pyplot_show(code: str) -> bool: aliases (e.g., `import matplotlib.pyplot as mpl`). AST approach detects actual method calls, avoiding false positives from string literals. Falls back to string matching only if code doesn't parse. + + Note: May have false positives if code calls .show() on non-matplotlib objects. """ if _find_attribute_calls(code, ["show"]): return True @@ -156,42 +164,34 @@ def _sets_headless_backend(code: str) -> bool: Uses AST analysis to detect matplotlib.use() calls with headless backends. Handles various matplotlib import styles and fallback to string matching. """ - if _find_function_calls(code, ["matplotlib.use"]): - headless_backends = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} - - try: - tree = ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings( - code, [f"matplotlib.use('{b}')" for b in headless_backends] - ) + if not _find_function_calls(code, ["matplotlib.use"]): + return False - class BackendFinder(ast.NodeVisitor): - def __init__(self): - self.has_headless = False - - def visit_Call(self, node: ast.Call) -> None: - if isinstance(node.func, ast.Attribute): - if node.func.attr == "use": - if isinstance(node.func.value, ast.Name): - if node.func.value.id == "matplotlib": - if node.args and isinstance(node.args[0], ast.Constant): - if node.args[0].value in headless_backends: - self.has_headless = True - self.generic_visit(node) - - finder = BackendFinder() - finder.visit(tree) - if finder.has_headless: - return True - - return _code_contains_strings( - code, - [ - f"matplotlib.use('{b}')" - for b in ["Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"] - ], - ) + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + backend_patterns = [f"matplotlib.use('{b}')" for b in HEADLESS_BACKENDS] + [ + f'matplotlib.use("{b}")' for b in HEADLESS_BACKENDS + ] + return _code_contains_strings(code, backend_patterns) + + class BackendFinder(ast.NodeVisitor): + def __init__(self): + self.has_headless = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr == "use": + if isinstance(node.func.value, ast.Name): + if node.func.value.id == "matplotlib": + if node.args and isinstance(node.args[0], ast.Constant): + if node.args[0].value in HEADLESS_BACKENDS: + self.has_headless = True + self.generic_visit(node) + + finder = BackendFinder() + finder.visit(tree) + return finder.has_headless def _uses_pyplot_plot(code: str) -> bool: @@ -201,15 +201,19 @@ def _uses_pyplot_plot(code: str) -> bool: aliases and detects actual method calls, avoiding false positives from string literals or method references. Falls back to string matching only if code doesn't parse. + + Note: May have false positives if code calls these methods on non-pyplot objects. + For accuracy, combine with matplotlib import checks. """ - plot_methods = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} - if _find_attribute_calls(code, list(plot_methods)): + if _find_attribute_calls(code, list(PYPLOT_PLOT_METHODS)): return True try: ast.parse(code) except (SyntaxError, ValueError): return _code_contains_strings( - code, [f".{m}(" for m in plot_methods] + [f"plt.{m}" for m in plot_methods] + code, + [f".{m}(" for m in PYPLOT_PLOT_METHODS] + + [f"plt.{m}" for m in PYPLOT_PLOT_METHODS], ) return False @@ -221,6 +225,8 @@ def _calls_savefig(code: str) -> bool: how matplotlib was imported. Detects actual method calls, avoiding false positives from string literals. Falls back to string matching only if code doesn't parse. + + Note: May have false positives if code calls savefig() on non-matplotlib objects. """ if _find_attribute_calls(code, ["savefig"]): return True @@ -236,31 +242,23 @@ def _make_matplotlib_headless_validator( show_patterns: list[str] | None = None, backend_patterns: list[str] | None = None, ) -> Callable[[Context], ValidationResult]: - r"""Create a validator that checks matplotlib uses headless backend. + """Create a validator that checks matplotlib uses headless backend. + + This validator checks if code calls plt.show() without setting a headless + backend. It differs from _sets_headless_backend in that it combines both + checks into a single validation that flags the specific error condition. Args: output_path: Path where plots should be saved - show_patterns: Patterns indicating plt.show() calls (e.g., ["plt.show", ".show()"]) - backend_patterns: Patterns indicating headless backend setup (e.g., ["matplotlib.use('Agg')", "matplotlib.use(\"Agg\")"]) + show_patterns: Patterns indicating plt.show() calls; defaults to ["plt.show", ".show()"] + backend_patterns: Patterns indicating headless backend setup; defaults to all + matplotlib.use() calls with HEADLESS_BACKENDS """ if show_patterns is None: show_patterns = ["plt.show", ".show()"] if backend_patterns is None: - backend_patterns = [ - "matplotlib.use('Agg')", - 'matplotlib.use("Agg")', - "matplotlib.use('Svg')", - 'matplotlib.use("Svg")', - "matplotlib.use('Cairo')", - 'matplotlib.use("Cairo")', - "matplotlib.use('PDF')", - 'matplotlib.use("PDF")', - "matplotlib.use('PS')", - 'matplotlib.use("PS")', - "matplotlib.use('WebAgg')", - 'matplotlib.use("WebAgg")', - "matplotlib.use('nbAgg')", - 'matplotlib.use("nbAgg")', + backend_patterns = [f"matplotlib.use('{b}')" for b in HEADLESS_BACKENDS] + [ + f'matplotlib.use("{b}")' for b in HEADLESS_BACKENDS ] def validate(ctx: Context) -> ValidationResult: @@ -302,22 +300,13 @@ def _make_plots_saved_validator( Args: output_path: Path where plots should be saved - plot_patterns: Patterns indicating plot creation (e.g., ["plt.plot", "plt.scatter"]) - save_patterns: Patterns indicating plot saving (e.g., ["savefig"]) + plot_patterns: Patterns indicating plot creation; defaults to all PYPLOT_PLOT_METHODS + prefixed with "plt." and "." + save_patterns: Patterns indicating plot saving; defaults to ["savefig"] """ if plot_patterns is None: - plot_patterns = [ - "plt.plot", - "plt.bar", - "plt.scatter", - "plt.hist", - "plt.imshow", - "plt.figure", - "plt.subplot", - ".plot(", - ".bar(", - ".scatter(", - ".hist(", + plot_patterns = [f"plt.{m}" for m in PYPLOT_PLOT_METHODS] + [ + f".{m}(" for m in PYPLOT_PLOT_METHODS ] if save_patterns is None: save_patterns = ["savefig"] @@ -389,11 +378,13 @@ def python_plotting_requirements( Args: output_path: Path where plots should be saved - check_output_artifacts: Whether to verify the output file exists - show_patterns: Patterns indicating plt.show() calls - backend_patterns: Patterns indicating headless backend setup - plot_patterns: Patterns indicating plot creation - save_patterns: Patterns indicating plot saving (e.g., savefig) + check_output_artifacts: Whether to verify the output file exists; defaults to False + show_patterns: Patterns indicating plt.show() calls; defaults to ["plt.show", ".show()"] + backend_patterns: Patterns indicating headless backend setup; defaults to all + matplotlib.use() calls with HEADLESS_BACKENDS + plot_patterns: Patterns indicating plot creation; defaults to all PYPLOT_PLOT_METHODS + prefixed with "plt." and "." + save_patterns: Patterns indicating plot saving; defaults to ["savefig"] Returns: List of Requirement objects that validate matplotlib usage and plot output. diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index 79fbcf002..5efa223a2 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -58,32 +58,19 @@ def _score_code_block(code: str) -> int: return score -def extract_python_code(ctx: Context) -> ValidationResult: - """Extract Python code from tool calls or markdown code blocks. +def _extract_markdown_python_code(content: str) -> ValidationResult: + """Extract best Python code block from markdown content. + + Searches for both ```python ... ``` and generic ``` ... ``` blocks, + scores them by code quality, and returns the highest-scoring block. Args: - ctx: Context object containing the LLM output to extract code from. + content: Text content to search for code blocks. Returns: - ValidationResult with result=True and the code as reason if extraction succeeds, - or result=False with an error message if no code is found. + ValidationResult with result=True and the code as reason if blocks found, + or result=False if no code blocks found. """ - last_output = ctx.last_output() - if last_output is None: - return ValidationResult(result=False, reason="No output found in context") - - if last_output.tool_calls and "python" in last_output.tool_calls: - tool_call = last_output.tool_calls["python"] - code = tool_call.args.get("code") - if isinstance(code, str) and code.strip(): - return ValidationResult(result=True, reason=code) - - if last_output.value is None: - return ValidationResult(result=False, reason="No output found in context") - - content = last_output.value - - # Look for code blocks with python specifier import re # Pattern for ```python ... ``` blocks @@ -115,44 +102,50 @@ def extract_python_code(ctx: Context) -> ValidationResult: return ValidationResult(result=True, reason=best_block[0]) -def _has_python_code_listing(ctx: Context) -> ValidationResult: - """Extract Python code from markdown code blocks in context.""" - last_output = ctx.last_output() - if last_output is None or last_output.value is None: - return ValidationResult(result=False, reason="No output found in context") +def extract_python_code(ctx: Context) -> ValidationResult: + """Extract Python code from tool calls or markdown code blocks. - content = last_output.value + Checks for code in two places (in order of priority): + 1. Direct python tool calls (used by python interpreter tool) + 2. Markdown ```python or ``` code blocks in text responses - # Look for code blocks with python specifier - import re + This function is used by requirements validators that may be called before + or after tool invocation, so it checks both sources. - # Pattern for ```python ... ``` blocks - python_blocks = re.findall(r"```python\s*\n(.*?)\n```", content, re.DOTALL) + Args: + ctx: Context object containing the LLM output to extract code from. - # Pattern for generic ``` blocks - generic_blocks = re.findall(r"```\s*\n(.*?)\n```", content, re.DOTALL) + Returns: + ValidationResult with result=True and the code as reason if extraction succeeds, + or result=False with an error message if no code is found. + """ + last_output = ctx.last_output() + if last_output is None: + return ValidationResult(result=False, reason="No output found in context") - all_blocks = [] + if last_output.tool_calls and "python" in last_output.tool_calls: + tool_call = last_output.tool_calls["python"] + code = tool_call.args.get("code") + if isinstance(code, str) and code.strip(): + return ValidationResult(result=True, reason=code) - # Add python blocks with high priority - for block in python_blocks: - all_blocks.append((block.strip(), _score_code_block(block.strip()) + 10)) + if last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") - # Add generic blocks if they look like Python - for block in generic_blocks: - block = block.strip() - if block and any( - keyword in block - for keyword in ["def ", "class ", "import ", "print(", "if __name__"] - ): - all_blocks.append((block, _score_code_block(block))) + return _extract_markdown_python_code(last_output.value) - if not all_blocks: - return ValidationResult(result=False, reason="No Python code blocks found") - # Return the highest scoring block - best_block = max(all_blocks, key=lambda x: x[1]) - return ValidationResult(result=True, reason=best_block[0]) +def _has_python_code_listing(ctx: Context) -> ValidationResult: + """Extract Python code from markdown code blocks in context only. + + Similar to extract_python_code but does not check tool calls. Used internally + by execution validators that always work with markdown code blocks. + """ + last_output = ctx.last_output() + if last_output is None or last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") + + return _extract_markdown_python_code(last_output.value) # endregion diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 8b86b0be0..57f6b2265 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -1,16 +1,14 @@ """Requirement factories for Python tool invocation and code validation. -This module provides generic requirements for Python-tool usage and code -correctness. Plotting-specific checks are exposed separately through -``plotting.python_plotting_requirements(...)`` so they are not implied to be -universal Python-tool requirements. +This module provides requirements for validating Python code, including syntax, +imports, and plotting. The python_tool_requirements() function bundles these +together, while specialized validators can be used independently. """ from collections.abc import Callable from ...core import Context, Requirement, ValidationResult -from ..tools.interpreter import StaticAnalysisEnvironment -from .imports import get_unauthorized_imports +from ..tools.interpreter import StaticAnalysisEnvironment, get_unauthorized_imports from .plotting import python_plotting_requirements from .python_reqs import extract_python_code from .tool_reqs import tool_arg_validator, uses_tool @@ -50,7 +48,13 @@ def _python_code_arg_present(arg_value: object) -> bool: def _make_code_parses_validator() -> Callable[[Context], ValidationResult]: - """Create a validator that checks if extracted code parses.""" + """Create a validator that checks if extracted code parses. + + This validator searches for Python code in the context, checking both + direct python tool calls and markdown code blocks. The python tool is + invoked synchronously as part of the LLM's response generation, so code + is available in the context for validation at check time. + """ def validate(ctx: Context) -> ValidationResult: extraction_result = extract_python_code(ctx) @@ -79,7 +83,11 @@ def validate(ctx: Context) -> ValidationResult: def _make_imports_allowed_validator( allowed_imports: list[str] | None, ) -> Callable[[Context], ValidationResult]: - """Create a validator that checks if code imports are in allowlist.""" + """Create a validator that checks if code imports are in allowlist. + + This validator extracts Python code from the context (tool calls or markdown + blocks) and checks that all imports are in the allowed list. + """ def validate(ctx: Context) -> ValidationResult: if allowed_imports is None: diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 008d4ff30..f9ad46fe2 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -20,7 +20,6 @@ from typing import Any from ...core import MelleaLogger -from ..requirements.imports import get_unauthorized_imports logger = MelleaLogger.get_logger() @@ -302,6 +301,48 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) +def get_unauthorized_imports(code: str, allowed_imports: list[str] | None) -> list[str]: + """Get list of unauthorized imports used in code. + + Analyzes Python code to extract imports and checks them against an allowlist. + Handles both `import X` and `from X import Y` style imports. + + Args: + code: Python source code to analyze + allowed_imports: List of allowed top-level module names; if None, allows any import + + Returns: + Sorted list of unauthorized top-level modules, or empty list if all imports allowed + """ + if allowed_imports is None: + return [] + + unauthorized: list[str] = [] + try: + tree = ast.parse(code) + except SyntaxError: + return unauthorized + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + base_module = alias.name.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + elif isinstance(node, ast.ImportFrom): + if node.module: + base_module = node.module.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + return sorted(unauthorized) + + def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: """Check if code only uses allowed imports.""" return len(get_unauthorized_imports(code, allowed_imports)) == 0 diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index e13182fe7..13f3000ae 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -13,7 +13,6 @@ ValidationResult, ) from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.imports import get_unauthorized_imports from mellea.stdlib.requirements.plotting.matplotlib import ( _calls_savefig, _sets_headless_backend, @@ -24,6 +23,7 @@ _code_parses, python_tool_requirements, ) +from mellea.stdlib.tools.interpreter import get_unauthorized_imports def from_model(content: str) -> Context: From b7261eaee9b9f0b243a5f2c477f558ef5b1f2f9b Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 14 May 2026 12:18:37 -0400 Subject: [PATCH 11/12] review comment Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/plotting/matplotlib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py index 9cc71befa..bc2dad747 100644 --- a/mellea/stdlib/requirements/plotting/matplotlib.py +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -274,7 +274,7 @@ def validate(ctx: Context) -> ValidationResult: savefig_instruction = ( f"plt.savefig('{output_path}'); plt.close()" if output_path - else "plt.savefig('{output_path}'); plt.close()" + else "plt.savefig(''); plt.close()" ) return ValidationResult( result=False, @@ -324,7 +324,7 @@ def validate(ctx: Context) -> ValidationResult: savefig_instruction = ( f"plt.savefig('{output_path}')\n plt.close()" if output_path - else "plt.savefig('{output_path}')\n plt.close()" + else "plt.savefig('')\n plt.close()" ) return ValidationResult( result=False, From 94051c9d6d4bb83c15a05b823b347a10ab2e8579 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 14 May 2026 17:03:26 -0400 Subject: [PATCH 12/12] review comments Signed-off-by: Akihiko Kuroda --- docs/examples/tools/python_plotting_repair.py | 56 +++++++++++++++---- .../requirements/plotting/matplotlib.py | 12 +--- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/docs/examples/tools/python_plotting_repair.py b/docs/examples/tools/python_plotting_repair.py index b42d16099..c2e042878 100644 --- a/docs/examples/tools/python_plotting_repair.py +++ b/docs/examples/tools/python_plotting_repair.py @@ -1,5 +1,20 @@ # pytest: ollama, e2e, qualitative -"""Repair plotting code with Python-tool and plotting-specific requirements.""" +"""Repair plotting code with Python-tool and plotting-specific requirements. + +This example demonstrates the full tool lifecycle: +1. Model generates code and creates tool calls +2. Sampling validation checks code quality without execution +3. Tool is explicitly invoked after sampling succeeds (via _call_tools) +4. Results are returned to caller for inspection/handling + +Key insight: Tool execution is explicit and controlled by the caller, +not automatic within the sampling pipeline. This allows fine-grained control +over when/if tools are invoked, and enables safety checks (see tool_hooks.py). + +Prerequisites: + matplotlib must be installed for code execution to succeed: + $ uv pip install matplotlib +""" import tempfile import traceback @@ -8,6 +23,7 @@ import mellea from mellea.backends import ModelOption from mellea.backends.tools import MelleaTool +from mellea.stdlib.functional import _call_tools from mellea.stdlib.requirements import ( python_plotting_requirements, python_tool_requirements, @@ -71,18 +87,34 @@ def main(): print(f"\nResult: {'SUCCESS' if result.success else 'FAILED'}\n") if result.success: - print("✓ Model successfully generated and executed plotting code") - print("\nFinal generated code:") - print("-" * 70) - print(result.result.value) - print("-" * 70) - - if Path(output_path).exists(): - file_size = Path(output_path).stat().st_size - print(f"\n✓ Output file created: {output_path}") - print(f" File size: {file_size} bytes") + print("✓ Model successfully generated plotting code") + + # Invoke the generated tools from the final result + if ( + result.result + and hasattr(result.result, "tool_calls") + and result.result.tool_calls + ): + # Print the generated code + for tool_name, tool_call in result.result.tool_calls.items(): + if tool_call.args and "code" in tool_call.args: + code = tool_call.args["code"] + print(f"\n{'=' * 70}") + print(f"Generated Python code for tool '{tool_name}':") + print(f"{'=' * 70}") + print(code) + print(f"{'=' * 70}\n") + + tool_outputs = _call_tools(result.result, m.backend) + + if tool_outputs: + print("✓ Tool executed successfully") + for i, output in enumerate(tool_outputs, 1): + print(f" Output {i}: {output.content}") else: - print(f"\n✗ Output file not found: {output_path}") + print("ℹ No tool calls in final result") + + print(f"\nCode saved to: {output_path}") print(f"\nRepair iterations: {len(result.sample_validations)}") for attempt_idx, validations in enumerate(result.sample_validations, 1): diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py index bc2dad747..76292e255 100644 --- a/mellea/stdlib/requirements/plotting/matplotlib.py +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -210,11 +210,7 @@ def _uses_pyplot_plot(code: str) -> bool: try: ast.parse(code) except (SyntaxError, ValueError): - return _code_contains_strings( - code, - [f".{m}(" for m in PYPLOT_PLOT_METHODS] - + [f"plt.{m}" for m in PYPLOT_PLOT_METHODS], - ) + return _code_contains_strings(code, [f".{m}(" for m in PYPLOT_PLOT_METHODS]) return False @@ -230,11 +226,7 @@ def _calls_savefig(code: str) -> bool: """ if _find_attribute_calls(code, ["savefig"]): return True - try: - ast.parse(code) - except (SyntaxError, ValueError): - return _code_contains_strings(code, ["savefig"]) - return False + return _code_contains_strings(code, ["savefig"]) def _make_matplotlib_headless_validator(