diff --git a/docs/examples/tools/python_plotting_repair.py b/docs/examples/tools/python_plotting_repair.py new file mode 100644 index 000000000..c2e042878 --- /dev/null +++ b/docs/examples/tools/python_plotting_repair.py @@ -0,0 +1,168 @@ +# pytest: ollama, e2e, qualitative +"""Repair plotting code with Python-tool and plotting-specific requirements. + +This example demonstrates the full tool lifecycle: +1. Model generates code and creates tool calls +2. Sampling validation checks code quality without execution +3. Tool is explicitly invoked after sampling succeeds (via _call_tools) +4. Results are returned to caller for inspection/handling + +Key insight: Tool execution is explicit and controlled by the caller, +not automatic within the sampling pipeline. This allows fine-grained control +over when/if tools are invoked, and enables safety checks (see tool_hooks.py). + +Prerequisites: + matplotlib must be installed for code execution to succeed: + $ uv pip install matplotlib +""" + +import tempfile +import traceback +from pathlib import Path + +import mellea +from mellea.backends import ModelOption +from mellea.backends.tools import MelleaTool +from mellea.stdlib.functional import _call_tools +from mellea.stdlib.requirements import ( + python_plotting_requirements, + python_tool_requirements, +) +from mellea.stdlib.sampling import SOFAISamplingStrategy +from mellea.stdlib.tools import local_code_interpreter +from mellea.stdlib.tools.interpreter import ExecutionResult + + +def python(code: str) -> ExecutionResult: + """Execute Python code. + + Args: + code: Python code to execute + + Returns: + Execution result containing stdout, stderr, and success status + """ + return local_code_interpreter(code) + + +def main(): + """Run the plotting repair example.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + + m = mellea.start_session(context_type="chat") + + requirements = [ + *python_tool_requirements(allowed_imports=["numpy", "matplotlib", "math"]), + *python_plotting_requirements(output_path=output_path), + ] + + sampling_strategy = SOFAISamplingStrategy( + s1_solver_backend=m.backend, + s2_solver_backend=m.backend, + s2_solver_mode="fresh_start", + loop_budget=3, + feedback_strategy="first_error", + ) + + task_summary = ( + f"Create a plot of sin(x) for x in 0..2π and save it to {output_path}" + ) + + print("=" * 70) + print("Testing plotting-code repair with Python tool requirements") + print("=" * 70) + print(f"Task: {task_summary}\n") + + try: + result = m.instruct( + task_summary, + requirements=requirements, + strategy=sampling_strategy, + return_sampling_results=True, + tool_calls=True, + model_options={ModelOption.TOOLS: [MelleaTool.from_callable(python)]}, + ) + + print(f"\nResult: {'SUCCESS' if result.success else 'FAILED'}\n") + + if result.success: + print("✓ Model successfully generated plotting code") + + # Invoke the generated tools from the final result + if ( + result.result + and hasattr(result.result, "tool_calls") + and result.result.tool_calls + ): + # Print the generated code + for tool_name, tool_call in result.result.tool_calls.items(): + if tool_call.args and "code" in tool_call.args: + code = tool_call.args["code"] + print(f"\n{'=' * 70}") + print(f"Generated Python code for tool '{tool_name}':") + print(f"{'=' * 70}") + print(code) + print(f"{'=' * 70}\n") + + tool_outputs = _call_tools(result.result, m.backend) + + if tool_outputs: + print("✓ Tool executed successfully") + for i, output in enumerate(tool_outputs, 
1): + print(f" Output {i}: {output.content}") + else: + print("ℹ No tool calls in final result") + + print(f"\nCode saved to: {output_path}") + + print(f"\nRepair iterations: {len(result.sample_validations)}") + for attempt_idx, validations in enumerate(result.sample_validations, 1): + passed = sum(1 for _, val in validations if val.as_bool()) + total = len(validations) + status = "✓" if passed == total else "✗" + print( + f" {status} Attempt {attempt_idx}: {passed}/{total} " + f"requirements passed" + ) + + for req, val in validations: + if not val.as_bool(): + print(f" - {req.description}") + if val.reason: + reason_preview = val.reason[:100].replace("\n", " ") + print(f" Error: {reason_preview}...") + + else: + print("✗ Failed to generate working plotting code after all attempts\n") + print("Last attempt output:") + print("-" * 70) + print(result.result.value) + print("-" * 70) + + print(f"\nFailure history ({len(result.sample_validations)} attempts):") + for attempt_idx, validations in enumerate(result.sample_validations, 1): + failed_count = sum(1 for _, val in validations if not val.as_bool()) + if failed_count > 0: + print(f"\n Attempt {attempt_idx}:") + for req, val in validations: + if not val.as_bool(): + print(f" - {req.description}") + if val.reason: + reason_lines = val.reason.split("\n")[:2] + for line in reason_lines: + print(f" {line}") + + except Exception as e: + print(f"✗ Exception during sampling: {e}") + traceback.print_exc() + + print("\n" + "=" * 70) + print("Test completed") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index c0bd7d3c9..4c9fec263 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -3,7 +3,9 @@ # Import from core for ergonomics. from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table +from .plotting import python_plotting_requirements from .python_reqs import PythonExecutionReq +from .python_tools import python_tool_requirements from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -26,6 +28,8 @@ "default_output_to_bool", "is_markdown_list", "is_markdown_table", + "python_plotting_requirements", + "python_tool_requirements", "req", "reqify", "requirement_check_to_bool", diff --git a/mellea/stdlib/requirements/plotting/__init__.py b/mellea/stdlib/requirements/plotting/__init__.py new file mode 100644 index 000000000..c10a7b080 --- /dev/null +++ b/mellea/stdlib/requirements/plotting/__init__.py @@ -0,0 +1,9 @@ +"""Plotting-specific requirements for Python tool validation. + +Provides matplotlib and plotting-focused requirement factories separate from +generic Python tool requirements. +""" + +from .matplotlib import python_plotting_requirements + +__all__ = ["python_plotting_requirements"] diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py new file mode 100644 index 000000000..76292e255 --- /dev/null +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -0,0 +1,417 @@ +"""Matplotlib-specific requirement validators for plotting code. + +Provides plotting validation including headless backend detection, plot saving, +and output artifact validation. Uses AST-based analysis for robust detection +of matplotlib operations, with string matching fallback for edge cases.
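+
+Example (illustrative; uses the factory exported from mellea.stdlib.requirements):
+
+    from mellea.stdlib.requirements import python_plotting_requirements
+
+    reqs = python_plotting_requirements(output_path="plot.png")
+    # Each entry is a Requirement whose validation_fn inspects the last
+    # model output in a Context and returns a ValidationResult.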
+""" + +import ast +from collections.abc import Callable +from pathlib import Path + +from ....core import Context, Requirement, ValidationResult +from ..python_reqs import extract_python_code + +# Headless matplotlib backends that don't require a display server +HEADLESS_BACKENDS = {"Agg", "Svg", "Cairo", "PDF", "PS", "WebAgg", "nbAgg"} + +# Pyplot plot creation methods +PYPLOT_PLOT_METHODS = {"plot", "bar", "scatter", "hist", "imshow", "figure", "subplot"} + + +def _strip_comments(code: str) -> str: + """Remove Python comments from code while preserving strings. + + Splits code by lines and removes comments (text after # that's not in a string). + Handles both single and double quoted strings. + """ + lines = code.split("\n") + result = [] + for line in lines: + in_string = False + string_char = None + for i, char in enumerate(line): + if char in ('"', "'") and (i == 0 or line[i - 1] != "\\"): + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + elif char == "#" and not in_string: + result.append(line[:i]) + break + else: + result.append(line) + return "\n".join(result) + + +def _find_attribute_calls(code: str, method_names: list[str]) -> bool: + """Check if code calls any of the specified methods using AST. + + Handles import aliases (e.g., `import matplotlib.pyplot as plt`) and + validates that methods are actually called, not just referenced. + + Args: + code: Python source code to analyze + method_names: Method names to look for (e.g., ["show", "savefig"]) + + Returns: + True if any of the methods are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class CallFinder(ast.NodeVisitor): + def __init__(self, method_names: list[str]): + self.method_names = set(method_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr in self.method_names: + self.found = True + self.generic_visit(node) + + finder = CallFinder(method_names) + finder.visit(tree) + return finder.found + + +def _find_function_calls(code: str, func_names: list[str]) -> bool: + """Check if code calls any of the specified functions using AST. + + Handles qualified names (e.g., `matplotlib.use()`) and detects actual + function calls, not just references. + + Args: + code: Python source code to analyze + func_names: Function names to look for (e.g., ["matplotlib.use"]) + + Returns: + True if any of the functions are called, False otherwise + """ + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + return False + + class FunctionCallFinder(ast.NodeVisitor): + def __init__(self, func_names: list[str]): + self.func_names = set(func_names) + self.found = False + + def visit_Call(self, node: ast.Call) -> None: + func_name = self._get_full_name(node.func) + if func_name in self.func_names: + self.found = True + self.generic_visit(node) + + def _get_full_name(self, node: ast.expr) -> str: + """Extract full qualified name from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + value_name = self._get_full_name(node.value) + if value_name: + return f"{value_name}.{node.attr}" + return node.attr + return "" + + finder = FunctionCallFinder(func_names) + finder.visit(tree) + return finder.found + + +def _code_contains_strings(code: str, patterns: list[str]) -> bool: + """Check if code contains any of the given string patterns. 
+ + Args: + code: Python source code to search + patterns: List of string patterns to look for + + Returns: + True if any pattern is found in the code, False otherwise + """ + clean_code = _strip_comments(code) + return any(pattern in clean_code for pattern in patterns) + + +def _uses_pyplot_show(code: str) -> bool: + """Check if code calls plt.show() or similar show() methods. + + Uses AST analysis to robustly detect show() calls regardless of import + aliases (e.g., `import matplotlib.pyplot as mpl`). AST approach detects + actual method calls, avoiding false positives from string literals. + Falls back to string matching only if code doesn't parse. + + Note: May have false positives if code calls .show() on non-matplotlib objects. + """ + if _find_attribute_calls(code, ["show"]): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, ["plt.show", ".show()"]) + return False + + +def _sets_headless_backend(code: str) -> bool: + """Check if code sets matplotlib to use a headless backend. + + Uses AST analysis to detect matplotlib.use() calls with headless backends. + Handles various matplotlib import styles and fallback to string matching. + """ + if not _find_function_calls(code, ["matplotlib.use"]): + return False + + try: + tree = ast.parse(code) + except (SyntaxError, ValueError): + backend_patterns = [f"matplotlib.use('{b}')" for b in HEADLESS_BACKENDS] + [ + f'matplotlib.use("{b}")' for b in HEADLESS_BACKENDS + ] + return _code_contains_strings(code, backend_patterns) + + class BackendFinder(ast.NodeVisitor): + def __init__(self): + self.has_headless = False + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute): + if node.func.attr == "use": + if isinstance(node.func.value, ast.Name): + if node.func.value.id == "matplotlib": + if node.args and isinstance(node.args[0], ast.Constant): + if node.args[0].value in HEADLESS_BACKENDS: + self.has_headless = True + self.generic_visit(node) + + finder = BackendFinder() + finder.visit(tree) + return finder.has_headless + + +def _uses_pyplot_plot(code: str) -> bool: + """Check if code calls pyplot plotting functions. + + Uses AST analysis to detect plot-related method calls. Handles import + aliases and detects actual method calls, avoiding false positives from + string literals or method references. Falls back to string matching + only if code doesn't parse. + + Note: May have false positives if code calls these methods on non-pyplot objects. + For accuracy, combine with matplotlib import checks. + """ + if _find_attribute_calls(code, list(PYPLOT_PLOT_METHODS)): + return True + try: + ast.parse(code) + except (SyntaxError, ValueError): + return _code_contains_strings(code, [f".{m}(" for m in PYPLOT_PLOT_METHODS]) + return False + + +def _calls_savefig(code: str) -> bool: + """Check if code calls plt.savefig() or fig.savefig(). + + Uses AST analysis to robustly detect savefig() calls regardless of + how matplotlib was imported. Detects actual method calls, avoiding + false positives from string literals. Falls back to string matching + only if code doesn't parse. + + Note: May have false positives if code calls savefig() on non-matplotlib objects. 
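+
+    Example:
+        >>> _calls_savefig("plt.savefig('out.png')")
+        True
+        >>> _calls_savefig("plt.show()")
+        False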
+ """ + if _find_attribute_calls(code, ["savefig"]): + return True + return _code_contains_strings(code, ["savefig"]) + + +def _make_matplotlib_headless_validator( + output_path: str | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks matplotlib uses headless backend. + + This validator checks if code calls plt.show() without setting a headless + backend. It differs from _sets_headless_backend in that it combines both + checks into a single validation that flags the specific error condition. + + Args: + output_path: Path where plots should be saved + show_patterns: Patterns indicating plt.show() calls; defaults to ["plt.show", ".show()"] + backend_patterns: Patterns indicating headless backend setup; defaults to all + matplotlib.use() calls with HEADLESS_BACKENDS + """ + if show_patterns is None: + show_patterns = ["plt.show", ".show()"] + if backend_patterns is None: + backend_patterns = [f"matplotlib.use('{b}')" for b in HEADLESS_BACKENDS] + [ + f'matplotlib.use("{b}")' for b in HEADLESS_BACKENDS + ] + + def validate(ctx: Context) -> ValidationResult: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult(result=True) + + code = extraction_result.reason + has_show = _code_contains_strings(code, show_patterns) + has_backend = _code_contains_strings(code, backend_patterns) + + if has_show and not has_backend: + savefig_instruction = ( + f"plt.savefig('{output_path}'); plt.close()" + if output_path + else "plt.savefig(''); plt.close()" + ) + return ValidationResult( + result=False, + reason=f"Your code calls `plt.show()` but doesn't set a headless backend.\n" + f"This will fail in a headless environment (no display).\n\n" + f"Fix this by adding to the top of your code:\n" + f" import matplotlib\n" + f" matplotlib.use('Agg')\n\n" + f"Then replace `plt.show()` with `{savefig_instruction}`", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_plots_saved_validator( + output_path: str | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if code saves plots to a file. + + Args: + output_path: Path where plots should be saved + plot_patterns: Patterns indicating plot creation; defaults to all PYPLOT_PLOT_METHODS + prefixed with "plt." and "." 
+ save_patterns: Patterns indicating plot saving; defaults to ["savefig"] + """ + if plot_patterns is None: + plot_patterns = [f"plt.{m}" for m in PYPLOT_PLOT_METHODS] + [ + f".{m}(" for m in PYPLOT_PLOT_METHODS + ] + if save_patterns is None: + save_patterns = ["savefig"] + + def validate(ctx: Context) -> ValidationResult: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult(result=True) + + code = extraction_result.reason + has_plot = _code_contains_strings(code, plot_patterns) + has_save = _code_contains_strings(code, save_patterns) + + if has_plot and not has_save: + savefig_instruction = ( + f"plt.savefig('{output_path}')\n plt.close()" + if output_path + else "plt.savefig('')\n plt.close()" + ) + return ValidationResult( + result=False, + reason=f"Your code creates plots with pyplot but never calls `plt.savefig()` to save them.\n\n" + f"Add this before your plotting code or at the end:\n" + f" {savefig_instruction}", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_output_artifacts_validator( + output_path: str, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if output file exists post-execution.""" + + def validate(ctx: Context) -> ValidationResult: + path = Path(output_path) + + if not path.exists(): + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' was not created during execution.\n" + f"Make sure your code calls `plt.savefig('{output_path}')` to save the plot.", + ) + + if path.stat().st_size == 0: + return ValidationResult( + result=False, + reason=f"The output file '{output_path}' exists but is empty.\n" + f"Check that your plot code executed correctly.", + ) + + return ValidationResult(result=True) + + return validate + + +def python_plotting_requirements( + output_path: str | None = None, + *, + check_output_artifacts: bool | None = None, + show_patterns: list[str] | None = None, + backend_patterns: list[str] | None = None, + plot_patterns: list[str] | None = None, + save_patterns: list[str] | None = None, +) -> list[Requirement]: + """Build plotting-specific requirements for Python tool responses. + + Args: + output_path: Path where plots should be saved + check_output_artifacts: Whether to verify the output file exists; defaults to False + show_patterns: Patterns indicating plt.show() calls; defaults to ["plt.show", ".show()"] + backend_patterns: Patterns indicating headless backend setup; defaults to all + matplotlib.use() calls with HEADLESS_BACKENDS + plot_patterns: Patterns indicating plot creation; defaults to all PYPLOT_PLOT_METHODS + prefixed with "plt." and "." + save_patterns: Patterns indicating plot saving; defaults to ["savefig"] + + Returns: + List of Requirement objects that validate matplotlib usage and plot output. 
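+
+    Example:
+        reqs = python_plotting_requirements(
+            output_path="plot.png", check_output_artifacts=True
+        )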
+ """ + reqs: list[Requirement] = [] + + reqs.append( + Requirement( + description="If using pyplot, must set headless backend and use savefig.", + validation_fn=_make_matplotlib_headless_validator( + output_path, + show_patterns=show_patterns, + backend_patterns=backend_patterns, + ), + check_only=False, + ) + ) + + reqs.append( + Requirement( + description="If creating plots, must call savefig to save them.", + validation_fn=_make_plots_saved_validator( + output_path, plot_patterns=plot_patterns, save_patterns=save_patterns + ), + check_only=False, + ) + ) + + if check_output_artifacts and output_path: + reqs.append( + Requirement( + description=f"Output file must be created at {output_path}", + validation_fn=_make_output_artifacts_validator(output_path), + check_only=False, + ) + ) + + return reqs diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index 3152acb71..5efa223a2 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -58,15 +58,19 @@ def _score_code_block(code: str) -> int: return score -def _has_python_code_listing(ctx: Context) -> ValidationResult: - """Extract Python code from context.""" - last_output = ctx.last_output() - if last_output is None or last_output.value is None: - return ValidationResult(result=False, reason="No output found in context") +def _extract_markdown_python_code(content: str) -> ValidationResult: + """Extract best Python code block from markdown content. - content = last_output.value + Searches for both ```python ... ``` and generic ``` ... ``` blocks, + scores them by code quality, and returns the highest-scoring block. - # Look for code blocks with python specifier + Args: + content: Text content to search for code blocks. + + Returns: + ValidationResult with result=True and the code as reason if blocks found, + or result=False if no code blocks found. + """ import re # Pattern for ```python ... ``` blocks @@ -98,6 +102,52 @@ def _has_python_code_listing(ctx: Context) -> ValidationResult: return ValidationResult(result=True, reason=best_block[0]) +def extract_python_code(ctx: Context) -> ValidationResult: + """Extract Python code from tool calls or markdown code blocks. + + Checks for code in two places (in order of priority): + 1. Direct python tool calls (used by python interpreter tool) + 2. Markdown ```python or ``` code blocks in text responses + + This function is used by requirements validators that may be called before + or after tool invocation, so it checks both sources. + + Args: + ctx: Context object containing the LLM output to extract code from. + + Returns: + ValidationResult with result=True and the code as reason if extraction succeeds, + or result=False with an error message if no code is found. + """ + last_output = ctx.last_output() + if last_output is None: + return ValidationResult(result=False, reason="No output found in context") + + if last_output.tool_calls and "python" in last_output.tool_calls: + tool_call = last_output.tool_calls["python"] + code = tool_call.args.get("code") + if isinstance(code, str) and code.strip(): + return ValidationResult(result=True, reason=code) + + if last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") + + return _extract_markdown_python_code(last_output.value) + + +def _has_python_code_listing(ctx: Context) -> ValidationResult: + """Extract Python code from markdown code blocks in context only. 
+ + Similar to extract_python_code but does not check tool calls. Used internally + by execution validators that always work with markdown code blocks. + """ + last_output = ctx.last_output() + if last_output is None or last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") + + return _extract_markdown_python_code(last_output.value) + + # endregion # region execution validation diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py new file mode 100644 index 000000000..57f6b2265 --- /dev/null +++ b/mellea/stdlib/requirements/python_tools.py @@ -0,0 +1,220 @@ +"""Requirement factories for Python tool invocation and code validation. + +This module provides requirements for validating Python code, including syntax, +imports, and plotting. The python_tool_requirements() function bundles these +together, while specialized validators can be used independently. +""" + +from collections.abc import Callable + +from ...core import Context, Requirement, ValidationResult +from ..tools.interpreter import StaticAnalysisEnvironment, get_unauthorized_imports +from .plotting import python_plotting_requirements +from .python_reqs import extract_python_code +from .tool_reqs import tool_arg_validator, uses_tool + + +def _code_parses(code: str) -> tuple[bool, str | None]: + """Check if code parses as valid Python using StaticAnalysisEnvironment. + + Validates syntax without executing code. Reuses StaticAnalysisEnvironment + to avoid duplicating AST parsing logic. + + Returns: + (True, None) if code parses + (False, error_message) if syntax error + """ + env = StaticAnalysisEnvironment(allowed_imports=None) + result = env.execute(code, timeout=0) + + if not result.success and isinstance(result.analysis_result, SyntaxError): + e = result.analysis_result + error_msg = f"Syntax error at line {e.lineno}: {e.msg}" + if e.text: + error_msg += f"\n {e.text.rstrip()}" + if e.offset: + error_msg += "\n " + " " * (e.offset - 1) + "^" + return False, error_msg + + return True, None + + +# region Individual Requirement Validators + + +def _python_code_arg_present(arg_value: object) -> bool: + """Return True when the python tool code argument is present and non-empty.""" + return isinstance(arg_value, str) and bool(arg_value.strip()) + + +def _make_code_parses_validator() -> Callable[[Context], ValidationResult]: + """Create a validator that checks if extracted code parses. + + This validator searches for Python code in the context, checking both + direct python tool calls and markdown code blocks. The python tool is + invoked synchronously as part of the LLM's response generation, so code + is available in the context for validation at check time. + """ + + def validate(ctx: Context) -> ValidationResult: + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult( + result=False, + reason=( + "Could not extract Python code from your response. " + "Make sure to include code in the python tool call or " + "in ```python ... ``` blocks." + ), + ) + + parses, error = _code_parses(extraction_result.reason) + if not parses: + return ValidationResult( + result=False, + reason=f"Your code contains a syntax error. 
{error}\n\nPlease fix the syntax and try again.", + ) + + return ValidationResult(result=True) + + return validate + + +def _make_imports_allowed_validator( + allowed_imports: list[str] | None, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks if code imports are in allowlist. + + This validator extracts Python code from the context (tool calls or markdown + blocks) and checks that all imports are in the allowed list. + """ + + def validate(ctx: Context) -> ValidationResult: + if allowed_imports is None: + return ValidationResult(result=True) + + extraction_result = extract_python_code(ctx) + if not extraction_result.as_bool() or extraction_result.reason is None: + return ValidationResult( + result=False, reason="Could not extract Python code" + ) + + unauthorized = get_unauthorized_imports( + extraction_result.reason, allowed_imports + ) + if unauthorized: + allowed_str = ", ".join(sorted(set(allowed_imports))) + return ValidationResult( + result=False, + reason=( + f"Your code imports forbidden modules: " + f"{', '.join(sorted(set(unauthorized)))}.\n" + f"You may only import: {allowed_str}\n" + f"Please rewrite your code without these imports." + ), + ) + + return ValidationResult(result=True) + + return validate + + +def _make_output_limit_validator( + limit_bytes: int, +) -> Callable[[Context], ValidationResult]: + """Create a validator that checks stdout/stderr size limits.""" + + def validate(ctx: Context) -> ValidationResult: + output = ctx.last_output() + if output is None: + return ValidationResult(result=True) + + stdout = getattr(output, "stdout", "") + stderr = getattr(output, "stderr", "") + total_output = "" + if isinstance(stdout, str): + total_output += stdout + if isinstance(stderr, str): + total_output += stderr + + size = len(total_output.encode("utf-8")) + if size > limit_bytes: + return ValidationResult( + result=False, + reason=f"Your code produced {size} bytes of output, exceeding the limit of {limit_bytes} bytes.\n" + f"Add output limiting (e.g., redirect to /dev/null) or optimize your code.", + ) + + return ValidationResult(result=True) + + return validate + + +# endregion + + +def python_tool_requirements( + output_path: str | None = None, + allowed_imports: list[str] | None = None, + output_limit_bytes: int = 50_000, + check_output_artifacts: bool | None = None, +) -> list[Requirement]: + """Build requirements for Python code generation via the python tool. + + Args: + output_path: Path where plotting output should be saved; enables plot-related checks. + allowed_imports: List of allowed import module names; if provided, code must only import these. + output_limit_bytes: Maximum bytes for stdout/stderr combined; defaults to 50KB. + check_output_artifacts: Whether to verify output file exists after execution; auto-enabled if output_path is set. + + Returns: + List of Requirement objects that validate python tool usage and code correctness. 
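+
+    Example:
+        reqs = python_tool_requirements(
+            output_path="plot.png",
+            allowed_imports=["numpy", "matplotlib", "math"],
+        )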
+ """ + reqs: list[Requirement] = [] + + if check_output_artifacts is None: + check_output_artifacts = output_path is not None + + reqs.append(uses_tool("python")) + + reqs.append( + tool_arg_validator( + description="The python tool call must include a code argument.", + tool_name="python", + arg_name="code", + validation_fn=_python_code_arg_present, + ) + ) + + reqs.append( + Requirement( + description="The Python code must parse correctly.", + validation_fn=_make_code_parses_validator(), + check_only=False, + ) + ) + + if allowed_imports is not None: + reqs.append( + Requirement( + description=f"Imports must be from allowed list: {', '.join(allowed_imports)}", + validation_fn=_make_imports_allowed_validator(allowed_imports), + check_only=False, + ) + ) + + reqs.extend( + python_plotting_requirements( + output_path=output_path, check_output_artifacts=check_output_artifacts + ) + ) + + reqs.append( + Requirement( + description=f"Output must not exceed {output_limit_bytes} bytes.", + validation_fn=_make_output_limit_validator(output_limit_bytes), + check_only=False, + ) + ) + + return reqs diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 4c537cc32..f9ad46fe2 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -148,7 +148,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -185,7 +185,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: unexpected error occurs. """ if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -259,7 +259,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: flag, or a skipped result on import violation or sandbox error. """ if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + unauthorized = get_unauthorized_imports(code, self.allowed_imports) if unauthorized: return ExecutionResult( success=False, @@ -301,8 +301,22 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) -def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: - """Get list of unauthorized imports used in code.""" +def get_unauthorized_imports(code: str, allowed_imports: list[str] | None) -> list[str]: + """Get list of unauthorized imports used in code. + + Analyzes Python code to extract imports and checks them against an allowlist. + Handles both `import X` and `from X import Y` style imports. 
+ + Args: + code: Python source code to analyze + allowed_imports: List of allowed top-level module names; if None, allows any import + + Returns: + Sorted list of unauthorized top-level modules, or empty list if all imports allowed + """ + if allowed_imports is None: + return [] + unauthorized: list[str] = [] try: tree = ast.parse(code) @@ -326,12 +340,12 @@ def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str and base_module not in unauthorized ): unauthorized.append(base_module) - return unauthorized + return sorted(unauthorized) def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: """Check if code only uses allowed imports.""" - return len(_get_unauthorized_imports(code, allowed_imports)) == 0 + return len(get_unauthorized_imports(code, allowed_imports)) == 0 def code_interpreter(code: str) -> ExecutionResult: diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py new file mode 100644 index 000000000..13f3000ae --- /dev/null +++ b/test/stdlib/requirements/test_python_tools.py @@ -0,0 +1,715 @@ +"""Tests for Python tool requirements bundle.""" + +import tempfile +from collections.abc import Callable +from pathlib import Path +from typing import Any, cast + +from mellea.core import ( + Context, + ModelOutputThunk, + ModelToolCall, + Requirement, + ValidationResult, +) +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.plotting.matplotlib import ( + _calls_savefig, + _sets_headless_backend, + _uses_pyplot_plot, + _uses_pyplot_show, +) +from mellea.stdlib.requirements.python_tools import ( + _code_parses, + python_tool_requirements, +) +from mellea.stdlib.tools.interpreter import get_unauthorized_imports + + +def from_model(content: str) -> Context: + """Helper to create context from model output.""" + ctx = ChatContext() + ctx = ctx.add(ModelOutputThunk(value=content)) + return ctx + + +def requirement_description(requirement: Requirement) -> str: + """Return a non-optional description for test filtering.""" + assert requirement.description is not None + return requirement.description + + +def validation_fn(requirement: Requirement) -> Callable[[Context], ValidationResult]: + """Return a requirement validation function for tests.""" + assert requirement.validation_fn is not None + return requirement.validation_fn + + +def validation_reason(result: ValidationResult) -> str: + """Return a non-optional validation reason for assertions.""" + assert result.reason is not None + return result.reason + + +class _DummyTool: + """Minimal tool stub for constructing ModelToolCall in tests.""" + + def run(self, **kwargs: Any) -> None: + return None + + +def python_tool_call(code: str | None = None, **extra_args: str) -> ModelToolCall: + """Create a typed python tool call for tests.""" + args: dict[str, str] = dict(extra_args) + if code is not None: + args["code"] = code + return ModelToolCall(name="python", func=cast(Any, _DummyTool()), args=args) + + +def requirements_matching( + substring: str, + *, + output_path: str | None = None, + allowed_imports: list[str] | None = None, + output_limit_bytes: int = 50_000, + check_output_artifacts: bool | None = None, +) -> list[Requirement]: + """Return requirements whose description contains the substring.""" + return [ + requirement + for requirement in python_tool_requirements( + output_path=output_path, + allowed_imports=allowed_imports, + output_limit_bytes=output_limit_bytes, + check_output_artifacts=check_output_artifacts, + 
) + if substring in requirement_description(requirement).lower() + ] + + +# region: Helper function tests + + +class TestCodeParses: + """Tests for _code_parses helper.""" + + def test_valid_code_parses(self): + """Valid Python code should parse.""" + code = "x = 1\nprint(x)" + parses, error = _code_parses(code) + assert parses is True + assert error is None + + def test_syntax_error_detected(self): + """Syntax errors should be detected.""" + code = "def foo(\n return 42" + parses, error = _code_parses(code) + assert parses is False + assert error is not None + assert "Syntax error" in error + + def test_missing_colon(self): + """Missing colons should be detected.""" + code = "if True\n print('hello')" + parses, error = _code_parses(code) + assert parses is False + assert error is not None + + +class TestUnauthorizedImports: + """Tests for get_unauthorized_imports helper.""" + + def test_no_imports_allowed_always(self): + """Code with no imports should always pass.""" + code = "x = 1" + unauthorized = get_unauthorized_imports(code, ["numpy"]) + assert unauthorized == [] + + def test_allowed_imports(self): + """Allowed imports should not be flagged.""" + code = "import numpy\nimport pandas" + unauthorized = get_unauthorized_imports(code, ["numpy", "pandas"]) + assert unauthorized == [] + + def test_unauthorized_import_detected(self): + """Unauthorized imports should be detected.""" + code = "import subprocess" + unauthorized = get_unauthorized_imports(code, ["numpy", "pandas"]) + assert "subprocess" in unauthorized + + def test_none_allows_all(self): + """allowed_imports=None should allow any import.""" + code = "import subprocess\nimport socket\nimport os" + unauthorized = get_unauthorized_imports(code, None) + assert unauthorized == [] + + def test_nested_imports(self): + """Only top-level module checked for nested imports.""" + code = "import numpy.random" + unauthorized = get_unauthorized_imports(code, ["numpy"]) + assert unauthorized == [] + + def test_from_import(self): + """from ... 
import should be checked.""" + code = "from matplotlib import pyplot" + unauthorized = get_unauthorized_imports(code, ["numpy"]) + assert "matplotlib" in unauthorized + + def test_multiple_unauthorized(self): + """Multiple unauthorized imports should be detected.""" + code = "import subprocess\nimport socket" + unauthorized = get_unauthorized_imports(code, ["numpy"]) + assert len(unauthorized) == 2 + assert "subprocess" in unauthorized + assert "socket" in unauthorized + + +class TestMatplotlibDetection: + """Tests for matplotlib-related detection functions.""" + + def test_plt_show_detected(self): + """plt.show() should be detected.""" + code = "plt.plot([1, 2, 3])\nplt.show()" + assert _uses_pyplot_show(code) is True + + def test_plt_show_not_detected(self): + """Code without plt.show() should pass.""" + code = "plt.plot([1, 2, 3])\nplt.savefig('plot.png')" + assert _uses_pyplot_show(code) is False + + def test_headless_backend_agg(self): + """Agg backend should be detected.""" + code = "import matplotlib\nmatplotlib.use('Agg')" + assert _sets_headless_backend(code) is True + + def test_headless_backend_variations(self): + """Various headless backends should be detected.""" + for backend in ["Agg", "Svg", "Cairo", "PDF", "PS"]: + code = f"matplotlib.use('{backend}')" + assert _sets_headless_backend(code) is True + + def test_non_headless_backend(self): + """Non-headless backends should not be detected.""" + code = "matplotlib.use('TkAgg')" + assert _sets_headless_backend(code) is False + + def test_uses_pyplot_plot(self): + """Plotting functions should be detected.""" + for func in ["plt.plot", "plt.bar", "plt.scatter", ".plot("]: + code = f"import matplotlib.pyplot as plt\n{func}([1, 2, 3])" + assert _uses_pyplot_plot(code) is True + + def test_savefig_detected(self): + """savefig call should be detected.""" + code = "plt.plot([1, 2, 3])\nplt.savefig('plot.png')" + assert _calls_savefig(code) is True + + def test_savefig_not_detected(self): + """Code without savefig should not be detected.""" + code = "plt.plot([1, 2, 3])\nplt.show()" + assert _calls_savefig(code) is False + + +# endregion + + +# region: python_tool_requirements tests + + +class TestPythonToolRequirementsBasic: + """Basic tests for python_tool_requirements.""" + + def test_initialization(self): + """Factory should return requirements with default settings.""" + requirements = python_tool_requirements() + assert requirements + assert len(requirements) > 0 + + def test_with_output_path_enables_artifact_requirement(self): + """Output path should enable artifact validation by default.""" + requirements = python_tool_requirements(output_path="/tmp/plot.png") + artifact_reqs = [ + r + for r in requirements + if "output file" in requirement_description(r).lower() + ] + assert len(artifact_reqs) == 1 + + def test_without_output_path_disables_artifact_requirement(self): + """Artifact validation should be absent without output_path.""" + requirements = python_tool_requirements() + artifact_reqs = [ + r + for r in requirements + if "output file" in requirement_description(r).lower() + ] + assert len(artifact_reqs) == 0 + + def test_with_allowed_imports_adds_import_requirement(self): + """Allowed imports should add an import validation requirement.""" + requirements = python_tool_requirements(allowed_imports=["numpy", "matplotlib"]) + import_reqs = [ + r for r in requirements if "import" in requirement_description(r).lower() + ] + assert len(import_reqs) > 0 + + +# endregion + + +# region: Individual requirement validation tests + 
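+# Note: the tests below index into python_tool_requirements() positionally:
+# requirements[0] is the uses_tool("python") requirement and
+# requirements[1] is the code-argument requirement, matching the order
+# in which the factory appends them.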
+ +class TestMustInvokePythonTool: + """Tests for MustInvokePythonTool requirement.""" + + def test_tool_not_called(self): + """Should fail if python tool not called.""" + req = python_tool_requirements()[0] + + ctx = from_model("Here is the code:\n```python\nprint('hello')\n```") + result = validation_fn(req)(ctx) + + assert result.as_bool() is False + reason_lower = validation_reason(result).lower() + assert "no tool calls" in reason_lower or "did not call" in reason_lower + + def test_python_tool_called(self): + """Should pass if python tool is called.""" + req = python_tool_requirements()[0] + + ctx = ChatContext() + output = ModelOutputThunk( + value="I'll execute this code", + tool_calls={"python": python_tool_call("print('hi')")}, + ) + ctx = ctx.add(output) + + result = validation_fn(req)(ctx) + assert result.as_bool() is True + + +class TestPythonToolHasCodeArg: + """Tests for PythonToolHasCodeArg requirement.""" + + def test_missing_code_argument(self): + """Should fail if python tool call has no code argument.""" + req = python_tool_requirements()[1] + + ctx = ChatContext() + output = ModelOutputThunk( + value="I'll execute this", + tool_calls={"python": python_tool_call(other="value")}, + ) + ctx = ctx.add(output) + + result = validation_fn(req)(ctx) + assert result.as_bool() is False + assert "code" in validation_reason(result).lower() + + def test_has_code_argument(self): + """Should pass if python tool call has code argument.""" + req = python_tool_requirements()[1] + + ctx = ChatContext() + output = ModelOutputThunk( + value="I'll execute this", + tool_calls={"python": python_tool_call("print('hi')")}, + ) + ctx = ctx.add(output) + + result = validation_fn(req)(ctx) + assert result.as_bool() is True + + +class TestCodeParsesRequirement: + """Tests for code parsing requirement.""" + + def test_valid_code(self): + """Valid code should pass.""" + parse_reqs = requirements_matching("parse") + parse_req = parse_reqs[0] + + ctx = from_model("```python\nx = 1\nprint(x)\n```") + result = validation_fn(parse_req)(ctx) + + assert result.as_bool() is True + + def test_syntax_error(self): + """Syntax errors should be caught.""" + parse_reqs = requirements_matching("parse") + parse_req = parse_reqs[0] + + ctx = from_model("```python\ndef foo(\n return 42\n```") + result = validation_fn(parse_req)(ctx) + + assert result.as_bool() is False + assert "syntax" in validation_reason(result).lower() + + def test_valid_code_from_tool_calls(self): + """Valid code in tool_calls should parse.""" + parse_reqs = requirements_matching("parse") + parse_req = parse_reqs[0] + + # Create context with tool_calls instead of markdown + ctx = ChatContext() + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("x = 1\nprint(x)")} + ) + ctx = ctx.add(output) + + result = validation_fn(parse_req)(ctx) + + assert result.as_bool() is True + + def test_syntax_error_from_tool_calls(self): + """Syntax errors in tool_calls should be caught.""" + parse_reqs = requirements_matching("parse") + parse_req = parse_reqs[0] + + # Create context with tool_calls containing syntax error + ctx = ChatContext() + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("def foo(\n return 42")} + ) + ctx = ctx.add(output) + + result = validation_fn(parse_req)(ctx) + + assert result.as_bool() is False + assert "syntax" in validation_reason(result).lower() + + +class TestImportAllowlistRequirement: + """Tests for import allowlist requirement.""" + + def test_allowed_imports(self): + 
"""Allowed imports should pass.""" + allowed = ["numpy", "matplotlib"] + import_reqs = [ + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() + ] + assert len(import_reqs) > 0 + import_req = import_reqs[0] + + ctx = from_model( + "```python\nimport numpy\nimport matplotlib.pyplot as plt\n```" + ) + result = validation_fn(import_req)(ctx) + + assert result.as_bool() is True + + def test_unauthorized_imports(self): + """Unauthorized imports should fail.""" + allowed = ["numpy"] + import_reqs = [ + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() + ] + import_req = import_reqs[0] + + ctx = from_model("```python\nimport subprocess\n```") + result = validation_fn(import_req)(ctx) + + assert result.as_bool() is False + assert "subprocess" in validation_reason(result) + + def test_allowed_imports_from_tool_calls(self): + """Allowed imports in tool_calls should pass.""" + allowed = ["numpy", "matplotlib"] + import_reqs = [ + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() + ] + import_req = import_reqs[0] + + # Create context with tool_calls + ctx = ChatContext() + code = "import numpy\nimport matplotlib.pyplot as plt\nprint(numpy.pi)" + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call(code)} + ) + ctx = ctx.add(output) + + result = validation_fn(import_req)(ctx) + + assert result.as_bool() is True + + def test_unauthorized_imports_from_tool_calls(self): + """Unauthorized imports in tool_calls should fail.""" + allowed = ["numpy"] + import_reqs = [ + r + for r in python_tool_requirements(allowed_imports=allowed) + if "import" in requirement_description(r).lower() + ] + import_req = import_reqs[0] + + # Create context with tool_calls + ctx = ChatContext() + output = ModelOutputThunk( + value="", tool_calls={"python": python_tool_call("import subprocess\n")} + ) + ctx = ctx.add(output) + + result = validation_fn(import_req)(ctx) + + assert result.as_bool() is False + assert "subprocess" in validation_reason(result) + + +class TestMatplotlibHeadlessRequirement: + """Tests for matplotlib headless backend requirement.""" + + def test_plt_show_without_backend(self): + """plt.show() without headless backend should fail.""" + matplotlib_reqs = requirements_matching("headless") + assert len(matplotlib_reqs) > 0 + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") + result = validation_fn(matplotlib_req)(ctx) + + assert result.as_bool() is False + assert "headless" in validation_reason( + result + ).lower() or "Agg" in validation_reason(result) + + def test_plt_show_with_backend(self): + """plt.show() with headless backend should pass.""" + matplotlib_reqs = requirements_matching("headless") + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model( + "```python\n" + "import matplotlib\n" + "matplotlib.use('Agg')\n" + "plt.plot([1, 2, 3])\n" + "plt.show()\n" + "```" + ) + result = validation_fn(matplotlib_req)(ctx) + + assert result.as_bool() is True + + def test_no_plt_show(self): + """Code without plt.show() should pass.""" + matplotlib_reqs = requirements_matching("headless") + matplotlib_req = matplotlib_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") + result = validation_fn(matplotlib_req)(ctx) + + assert result.as_bool() is True + + +class TestPlotsAreSavedRequirement: + 
"""Tests for plots must be saved requirement.""" + + def test_plot_without_savefig(self): + """Plotting without savefig should fail.""" + plot_reqs = requirements_matching("savefig") + assert len(plot_reqs) > 0 + plot_req = plot_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.show()\n```") + result = validation_fn(plot_req)(ctx) + + assert result.as_bool() is False + assert "savefig" in validation_reason(result).lower() + + def test_plot_with_savefig(self): + """Plotting with savefig should pass.""" + plot_reqs = requirements_matching("savefig") + plot_req = plot_reqs[0] + + ctx = from_model("```python\nplt.plot([1, 2, 3])\nplt.savefig('plot.png')\n```") + result = validation_fn(plot_req)(ctx) + + assert result.as_bool() is True + + def test_no_plotting(self): + """Code without plotting should pass.""" + plot_reqs = requirements_matching("savefig") + plot_req = plot_reqs[0] + + ctx = from_model("```python\nx = 1\nprint(x)\n```") + result = validation_fn(plot_req)(ctx) + + assert result.as_bool() is True + + +class TestOutputArtifactsRequirement: + """Tests for output artifacts requirement.""" + + def test_output_file_not_created(self): + """Should fail if output file not created.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) + assert len(artifact_reqs) > 0 + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = validation_fn(artifact_req)(ctx) + + assert result.as_bool() is False + assert output_path in validation_reason(result) + + def test_output_file_exists_and_nonempty(self): + """Should pass if output file exists and is non-empty.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + Path(output_path).write_bytes(b"fake png data") + + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = validation_fn(artifact_req)(ctx) + + assert result.as_bool() is True + + def test_output_file_empty(self): + """Should fail if output file exists but is empty.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = str(Path(tmpdir) / "plot.png") + Path(output_path).write_bytes(b"") + + artifact_reqs = requirements_matching( + "output file", output_path=output_path + ) + artifact_req = artifact_reqs[0] + + ctx = from_model("Code ran successfully") + result = validation_fn(artifact_req)(ctx) + + assert result.as_bool() is False + assert "empty" in validation_reason(result).lower() + + def test_output_artifact_disabled_without_output_path(self): + """Output artifact requirement should not be present without output_path.""" + artifact_reqs = requirements_matching("output file") + assert len(artifact_reqs) == 0 + + +class TestOutputLimitValidator: + """Tests for _make_output_limit_validator.""" + + def test_empty_output_passes(self): + """No stdout/stderr should pass.""" + ctx = ChatContext().add(ModelOutputThunk(value="response")) + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + assert len(limit_reqs) > 0 + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is True + + def test_output_within_limit_passes(self): + """Output under limit should pass.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "x" * 500) + setattr(output, "stderr", "") + ctx = 
ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is True + + def test_stdout_exceeds_limit_fails(self): + """Stdout exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "x" * 1500) + setattr(output, "stderr", "") + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is False + assert "exceeding" in validation_reason(result).lower() + assert "1500" in validation_reason(result) + + def test_stderr_exceeds_limit_fails(self): + """Stderr exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "") + setattr(output, "stderr", "e" * 1500) + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is False + assert "exceeding" in validation_reason(result).lower() + + def test_combined_output_exceeds_limit_fails(self): + """Combined stdout+stderr exceeding limit should fail.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "x" * 600) + setattr(output, "stderr", "e" * 600) # Combined: 1200 bytes + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is False + assert "1200" in validation_reason(result) + + def test_utf8_multibyte_characters_counted_correctly(self): + """Multibyte UTF-8 characters should be counted in bytes, not chars.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "🎉" * 100) # 4 bytes per emoji = 400 bytes + setattr(output, "stderr", "") + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=300) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is False # 400 > 300 + assert "exceeding" in validation_reason(result).lower() + + def test_limit_at_boundary(self): + """Output exactly at limit should pass.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "x" * 1000) # Exactly 1000 bytes + setattr(output, "stderr", "") + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is True + + def test_limit_one_byte_over_boundary_fails(self): + """Output one byte over limit should fail.""" + output = ModelOutputThunk(value="response") + setattr(output, "stdout", "x" * 1001) # 1001 bytes + setattr(output, "stderr", "") + ctx = ChatContext().add(output) + + limit_reqs = requirements_matching("exceed", output_limit_bytes=1000) + limit_req = limit_reqs[0] + + result = validation_fn(limit_req)(ctx) + assert result.as_bool() is False + + +# endregion
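
A minimal usage sketch (not part of the patch above): it exercises the bundled
requirements directly against a hand-built context, using the same public
imports as test_python_tools.py, and prints each requirement's pass/fail result.

    from mellea.core import ModelOutputThunk
    from mellea.stdlib.context import ChatContext
    from mellea.stdlib.requirements import python_tool_requirements

    # Simulate a model response that answers with a markdown code block.
    ctx = ChatContext().add(
        ModelOutputThunk(value="```python\nimport numpy\nprint(numpy.pi)\n```")
    )

    # Run every bundled validator and report pass/fail per requirement.
    for req in python_tool_requirements(allowed_imports=["numpy"]):
        if req.validation_fn is not None:
            print(req.description, req.validation_fn(ctx).as_bool())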