From 204ef4045f747777133e3422f707fdf4bff5120f Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 21 May 2026 15:49:13 -0400 Subject: [PATCH 1/3] python requrements Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/__init__.py | 12 + mellea/stdlib/requirements/python_tools.py | 318 ++++++++++++++ test/stdlib/requirements/test_python_tools.py | 399 ++++++++++++++++++ 3 files changed, 729 insertions(+) create mode 100644 mellea/stdlib/requirements/python_tools.py create mode 100644 test/stdlib/requirements/test_python_tools.py diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index c0bd7d3c9..8b932b0ad 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,6 +4,13 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq +from .python_tools import ( + ImportRestrictions, + OutputSizeLimit, + PythonCodeExtraction, + PythonSyntaxValid, + python_tool_requirements, +) from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -17,8 +24,12 @@ __all__ = [ "ALoraRequirement", + "ImportRestrictions", "LLMaJRequirement", + "OutputSizeLimit", + "PythonCodeExtraction", "PythonExecutionReq", + "PythonSyntaxValid", "Requirement", "ValidationResult", "as_markdown_list", @@ -26,6 +37,7 @@ "default_output_to_bool", "is_markdown_list", "is_markdown_table", + "python_tool_requirements", "req", "reqify", "requirement_check_to_bool", diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py new file mode 100644 index 000000000..48d1e0200 --- /dev/null +++ b/mellea/stdlib/requirements/python_tools.py @@ -0,0 +1,318 @@ +"""Generic Python tool requirements for code generation validation. + +This module provides a set of composable requirements for validating Python code +generated by language models. Requirements can be used individually or bundled +via the python_tool_requirements() factory function. + +The requirement pipeline validates code in this order: +1. PythonCodeExtraction — code blocks are present and extractable +2. PythonSyntaxValid — code parses without syntax errors +3. PythonExecutesWithoutError — code runs without exceptions +4. OutputSizeLimit — captured output stays within bounds +5. ImportRestrictions — only whitelisted modules are imported (optional) +""" + +import ast + +from mellea.stdlib.tools.interpreter import ( + ExecutionEnvironment, + LLMSandboxEnvironment, + StaticAnalysisEnvironment, + UnsafeEnvironment, +) + +from ...core import Context, MelleaLogger, Requirement, ValidationResult +from .python_reqs import ( + PythonExecutionReq, + _has_python_code_listing, + _python_executes_without_error, +) + +logger = MelleaLogger.get_logger() + + +class PythonCodeExtraction(Requirement): + """Code blocks are present and extractable from model output. + + This requirement checks whether the model's response contains Python code + blocks that can be extracted for further validation or execution. + """ + + def __init__(self) -> None: + """Initialize PythonCodeExtraction requirement.""" + super().__init__( + description="Code blocks are present and extractable.", + validation_fn=_has_python_code_listing, + check_only=True, + ) + + +class PythonSyntaxValid(Requirement): + """Python code is syntactically valid (parses without AST errors). + + Uses Python's ast.parse() to validate syntax without executing code. + Useful for catching malformed code early in the validation pipeline. + """ + + def __init__(self) -> None: + """Initialize PythonSyntaxValid requirement.""" + super().__init__( + description="Python code is syntactically valid.", + validation_fn=self._validate_syntax, + check_only=True, + ) + + def _validate_syntax(self, ctx: Context) -> ValidationResult: + """Validate that extracted code has valid Python syntax. + + Args: + ctx: Context containing model output with code blocks. + + Returns: + ValidationResult with pass/fail and extracted code or error details. + """ + extraction_result = _has_python_code_listing(ctx) + if not extraction_result.as_bool(): + return ValidationResult( + result=False, + reason=f"Could not extract code for syntax validation: {extraction_result.reason}", + ) + + code = extraction_result.reason + assert code is not None + + try: + ast.parse(code) + return ValidationResult(result=True, reason="Syntax is valid.") + except SyntaxError as e: + return ValidationResult( + result=False, reason=f"Syntax error at line {e.lineno}: {e.msg}" + ) + + +class OutputSizeLimit(Requirement): + """Captured output does not exceed size limit (in characters). + + Executes code and verifies that the captured stdout does not exceed + the configured character limit. Useful for preventing excessive logging + or infinite output loops. + + Args: + limit_chars: Maximum allowed output size in characters. Defaults to 10,000. + """ + + def __init__(self, limit_chars: int = 10_000) -> None: + """Initialize OutputSizeLimit requirement. + + Raises: + ValueError: If limit_chars is not positive. + """ + if limit_chars <= 0: + raise ValueError(f"limit_chars must be positive, got {limit_chars}") + + self.limit_chars = limit_chars + super().__init__( + description=f"Output does not exceed {limit_chars} characters.", + validation_fn=self._validate_output_size, + check_only=True, + ) + + def _validate_output_size(self, ctx: Context) -> ValidationResult: + """Validate that executed code's output stays within size limit. + + Args: + ctx: Context containing model output with code blocks. + + Returns: + ValidationResult with pass/fail and output size details. + """ + extraction_result = _has_python_code_listing(ctx) + if not extraction_result.as_bool(): + return ValidationResult( + result=False, + reason="Could not extract code for output size validation.", + ) + + code = extraction_result.reason + assert code is not None + + try: + result = _python_executes_without_error(ctx, timeout=5, allow_unsafe=False) + if not result.as_bool(): + return ValidationResult( + result=False, + reason=f"Code execution failed during output size check: {result.reason}", + ) + + captured_output = getattr(result, "captured_output", "") or "" + output_size = len(str(captured_output)) + + if output_size <= self.limit_chars: + return ValidationResult( + result=True, + reason=f"Output size ({output_size} chars) within limit ({self.limit_chars}).", + ) + else: + return ValidationResult( + result=False, + reason=f"Output size ({output_size} chars) exceeds limit ({self.limit_chars}).", + ) + except Exception as e: + return ValidationResult( + result=False, reason=f"Error checking output size: {e!s}" + ) + + +class ImportRestrictions(Requirement): + """Only whitelisted modules are imported in the code. + + Uses AST analysis to find all imports (Import and ImportFrom nodes) + and validates them against an optional allowlist. If no allowlist is + provided, all imports are accepted. + + Args: + allowed_imports: List of module names that are allowed to be imported. + If None, all imports are accepted. + """ + + def __init__(self, allowed_imports: list[str] | None = None) -> None: + """Initialize ImportRestrictions requirement.""" + self.allowed_imports = allowed_imports or [] + imports_str = ", ".join(self.allowed_imports) if self.allowed_imports else "all" + description = f"Only imports from [{imports_str}] are used." + + super().__init__( + description=description, + validation_fn=self._validate_imports, + check_only=True, + ) + + def _validate_imports(self, ctx: Context) -> ValidationResult: + """Validate that imports in extracted code match allowlist. + + Args: + ctx: Context containing model output with code blocks. + + Returns: + ValidationResult with pass/fail and forbidden imports if any. + """ + extraction_result = _has_python_code_listing(ctx) + if not extraction_result.as_bool(): + return ValidationResult( + result=False, reason="Could not extract code for import validation." + ) + + code = extraction_result.reason + assert code is not None + + if not self.allowed_imports: + return ValidationResult( + result=True, reason="No import restrictions configured." + ) + + try: + tree = ast.parse(code) + except SyntaxError as e: + return ValidationResult( + result=False, + reason=f"Could not parse code for import analysis: {e.msg}", + ) + + forbidden_imports: list[str] = [] + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name.split(".")[0] + if module_name not in self.allowed_imports: + forbidden_imports.append(module_name) + + elif isinstance(node, ast.ImportFrom): + if node.module is not None: + module_name = node.module.split(".")[0] + if module_name not in self.allowed_imports: + forbidden_imports.append(module_name) + + if forbidden_imports: + unique_forbidden = list(set(forbidden_imports)) + return ValidationResult( + result=False, + reason=f"Forbidden imports detected: {', '.join(unique_forbidden)}", + ) + + return ValidationResult(result=True, reason="All imports are whitelisted.") + + +def python_tool_requirements( + allowed_imports: list[str] | None = None, + output_limit_chars: int = 10_000, + timeout_seconds: int = 5, + use_sandbox: bool = False, +) -> list[Requirement]: + """Bundle generic Python tool requirements with configurable parameters. + + Factory function that creates a complete set of requirements for validating + Python code generation, from extraction through execution and output checks. + + Args: + allowed_imports: Whitelist of importable top-level modules. If None, all + imports are allowed. Default None. + output_limit_chars: Maximum allowed characters of captured stdout. + Default 10,000. + timeout_seconds: Maximum execution time in seconds. Default 5. + use_sandbox: Use llm-sandbox for Docker-isolated execution. Default False. + + Returns: + list[Requirement]: Requirement instances in validation order: + 1. PythonCodeExtraction + 2. PythonSyntaxValid + 3. PythonExecutesWithoutError (configured with timeout and sandbox settings) + 4. OutputSizeLimit (configured with output_limit_chars) + 5. ImportRestrictions (only included if allowed_imports is provided) + + Raises: + ValueError: If timeout_seconds is not positive. + ValueError: If output_limit_chars is not positive. + + Examples: + >>> # Unrestricted execution with defaults + >>> reqs = python_tool_requirements() + >>> len(reqs) + 4 + + >>> # Restricted to safe modules only + >>> reqs = python_tool_requirements( + ... allowed_imports=["os", "sys", "json"], + ... output_limit_chars=5_000, + ... ) + >>> len(reqs) # includes ImportRestrictions + 5 + + >>> # Sandbox mode for untrusted code + >>> reqs = python_tool_requirements( + ... use_sandbox=True, + ... timeout_seconds=10, + ... ) + """ + if timeout_seconds <= 0: + raise ValueError(f"timeout_seconds must be positive, got {timeout_seconds}") + if output_limit_chars <= 0: + raise ValueError( + f"output_limit_chars must be positive, got {output_limit_chars}" + ) + + reqs: list[Requirement] = [ + PythonCodeExtraction(), + PythonSyntaxValid(), + PythonExecutionReq( + timeout=timeout_seconds, + allowed_imports=allowed_imports, + use_sandbox=use_sandbox, + ), + OutputSizeLimit(limit_chars=output_limit_chars), + ] + + if allowed_imports is not None: + reqs.append(ImportRestrictions(allowed_imports=allowed_imports)) + + return reqs diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py new file mode 100644 index 000000000..9607da910 --- /dev/null +++ b/test/stdlib/requirements/test_python_tools.py @@ -0,0 +1,399 @@ +"""Tests for Python tool requirements from python_tools module.""" + +import pytest + +from mellea.core import Context, ModelOutputThunk +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.python_tools import ( + ImportRestrictions, + OutputSizeLimit, + PythonCodeExtraction, + PythonSyntaxValid, + python_tool_requirements, +) + + +def from_model(content: str) -> Context: + """Helper to create context from model output.""" + ctx = ChatContext() + ctx = ctx.add(ModelOutputThunk(value=content)) + return ctx + + +# Test fixtures +VALID_PYTHON_CODE = """```python +def hello_world(): + return "Hello, World!" + +print(hello_world()) +```""" + +PYTHON_WITH_SYNTAX_ERROR = """```python +def hello_world( + return "Hello, World!" +```""" + +PYTHON_WITH_IMPORTS = """```python +import os +import sys +from pathlib import Path + +print("Hello from imports!") +```""" + +PYTHON_WITH_FORBIDDEN_IMPORTS = """```python +import subprocess +import socket +import urllib + +print("Dangerous imports!") +```""" + +NO_PYTHON_CODE = """ +This is just text without any Python code blocks. +It contains no executable content. +""" + + +class TestPythonCodeExtraction: + """Tests for PythonCodeExtraction requirement.""" + + def test_extract_valid_code_block(self): + """Test extraction of valid Python code.""" + req = PythonCodeExtraction() + ctx = from_model(VALID_PYTHON_CODE) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + assert "hello_world" in (result.reason or "") + + def test_extract_no_code_blocks(self): + """Test extraction when no code blocks present.""" + req = PythonCodeExtraction() + ctx = from_model(NO_PYTHON_CODE) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + assert result.reason is not None + + def test_extract_multiple_code_blocks(self): + """Test extraction when multiple code blocks present (should return highest scoring).""" + code = """ +Here's a simple one: +```python +print("simple") +``` + +And a more complex one: +```python +def fibonacci(n): + if n <= 1: + return n + return fibonacci(n-1) + fibonacci(n-2) + +for i in range(10): + print(fibonacci(i)) +``` +""" + req = PythonCodeExtraction() + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + assert "fibonacci" in (result.reason or "") + + +class TestPythonSyntaxValid: + """Tests for PythonSyntaxValid requirement.""" + + def test_valid_syntax(self): + """Test validation of syntactically valid code.""" + req = PythonSyntaxValid() + ctx = from_model(VALID_PYTHON_CODE) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + assert "valid" in (result.reason or "").lower() + + def test_syntax_error(self): + """Test validation of code with syntax errors.""" + req = PythonSyntaxValid() + ctx = from_model(PYTHON_WITH_SYNTAX_ERROR) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + assert "syntax error" in (result.reason or "").lower() + + def test_syntax_error_unclosed_paren(self): + """Test validation of code with unclosed parenthesis.""" + code = """```python +def foo( + pass +```""" + req = PythonSyntaxValid() + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + + def test_syntax_error_bad_indentation(self): + """Test validation of code with indentation errors.""" + code = """```python +def foo(): +return "bad indent" +```""" + req = PythonSyntaxValid() + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + + def test_syntax_valid_no_code_extraction(self): + """Test validation when no code can be extracted.""" + req = PythonSyntaxValid() + ctx = from_model(NO_PYTHON_CODE) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + + +class TestOutputSizeLimit: + """Tests for OutputSizeLimit requirement.""" + + def test_init_valid_limit(self): + """Test initialization with valid limit.""" + req = OutputSizeLimit(limit_chars=5000) + assert req.limit_chars == 5000 + + def test_init_invalid_limit_zero(self): + """Test initialization with zero limit raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + OutputSizeLimit(limit_chars=0) + + def test_init_invalid_limit_negative(self): + """Test initialization with negative limit raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + OutputSizeLimit(limit_chars=-100) + + def test_init_default_limit(self): + """Test initialization with default limit.""" + req = OutputSizeLimit() + assert req.limit_chars == 10_000 + + def test_output_within_limit(self): + """Test validation when output stays within limit.""" + req = OutputSizeLimit(limit_chars=1000) + code = """```python +print("Hello, World!") +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + # Result depends on execution, but size check logic is present + assert isinstance(result.as_bool(), bool) + + +class TestImportRestrictions: + """Tests for ImportRestrictions requirement.""" + + def test_init_with_allowlist(self): + """Test initialization with import allowlist.""" + req = ImportRestrictions(allowed_imports=["os", "sys", "json"]) + assert req.allowed_imports == ["os", "sys", "json"] + + def test_init_with_none(self): + """Test initialization with None allowlist.""" + req = ImportRestrictions(allowed_imports=None) + assert req.allowed_imports == [] + + def test_init_default(self): + """Test initialization with default (None) allowlist.""" + req = ImportRestrictions() + assert req.allowed_imports == [] + + def test_allowed_imports_pass(self): + """Test validation when imports are in allowlist.""" + req = ImportRestrictions(allowed_imports=["os", "sys", "pathlib"]) + ctx = from_model(PYTHON_WITH_IMPORTS) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_forbidden_imports_fail(self): + """Test validation when forbidden imports are detected.""" + req = ImportRestrictions(allowed_imports=["os", "sys"]) + ctx = from_model(PYTHON_WITH_FORBIDDEN_IMPORTS) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + assert "forbidden" in (result.reason or "").lower() + assert any( + m in (result.reason or "") for m in ["subprocess", "socket", "urllib"] + ) + + def test_no_imports_pass(self): + """Test validation when code has no imports.""" + req = ImportRestrictions(allowed_imports=["os"]) + code = """```python +def add(a, b): + return a + b + +print(add(2, 3)) +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_no_allowlist_passes_all(self): + """Test validation with no allowlist (None) passes all imports.""" + req = ImportRestrictions(allowed_imports=None) + ctx = from_model(PYTHON_WITH_FORBIDDEN_IMPORTS) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + assert "No import restrictions" in (result.reason or "") + + def test_syntax_error_in_imports_check(self): + """Test import validation when code has syntax errors.""" + req = ImportRestrictions(allowed_imports=["os"]) + ctx = from_model(PYTHON_WITH_SYNTAX_ERROR) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + + def test_submodule_imports(self): + """Test validation of submodule imports.""" + req = ImportRestrictions(allowed_imports=["pathlib"]) + code = """```python +from pathlib.posixpath import join +import pathlib.pure + +print("submodules") +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is True + + def test_forbidden_submodule(self): + """Test validation when submodule is forbidden.""" + req = ImportRestrictions(allowed_imports=["os"]) + code = """```python +from urllib.request import urlopen + +print("fetch") +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + + +class TestPythonToolRequirementsFactory: + """Tests for python_tool_requirements() factory function.""" + + def test_factory_default_returns_four_requirements(self): + """Test factory with defaults returns 4 requirements (no import restrictions).""" + reqs = python_tool_requirements() + assert len(reqs) == 4 + assert isinstance(reqs[0], PythonCodeExtraction) + assert isinstance(reqs[1], PythonSyntaxValid) + assert isinstance(reqs[3], OutputSizeLimit) + + def test_factory_with_allowed_imports_returns_five(self): + """Test factory with allowed_imports returns 5 requirements.""" + reqs = python_tool_requirements(allowed_imports=["os", "sys"]) + assert len(reqs) == 5 + assert isinstance(reqs[4], ImportRestrictions) + + def test_factory_parameter_propagation_output_limit(self): + """Test factory propagates output_limit_chars to OutputSizeLimit.""" + reqs = python_tool_requirements(output_limit_chars=5000) + output_limit_req = reqs[3] + assert isinstance(output_limit_req, OutputSizeLimit) + assert output_limit_req.limit_chars == 5000 + + def test_factory_parameter_propagation_imports(self): + """Test factory propagates allowed_imports to ImportRestrictions.""" + imports = ["os", "sys", "json"] + reqs = python_tool_requirements(allowed_imports=imports) + import_req = reqs[4] + assert isinstance(import_req, ImportRestrictions) + assert import_req.allowed_imports == imports + + def test_factory_timeout_parameter(self): + """Test factory accepts and uses timeout_seconds parameter.""" + reqs = python_tool_requirements(timeout_seconds=10) + assert len(reqs) == 4 + + def test_factory_sandbox_parameter(self): + """Test factory accepts and uses use_sandbox parameter.""" + reqs = python_tool_requirements(use_sandbox=True) + assert len(reqs) == 4 + + def test_factory_all_parameters(self): + """Test factory with all parameters configured.""" + reqs = python_tool_requirements( + allowed_imports=["os", "sys"], + output_limit_chars=8000, + timeout_seconds=15, + use_sandbox=True, + ) + assert len(reqs) == 5 + assert isinstance(reqs[3], OutputSizeLimit) + assert reqs[3].limit_chars == 8000 + assert isinstance(reqs[4], ImportRestrictions) + + def test_factory_invalid_timeout(self): + """Test factory with invalid timeout raises ValueError.""" + with pytest.raises(ValueError, match="timeout_seconds must be positive"): + python_tool_requirements(timeout_seconds=0) + + with pytest.raises(ValueError, match="timeout_seconds must be positive"): + python_tool_requirements(timeout_seconds=-5) + + def test_factory_invalid_output_limit(self): + """Test factory with invalid output_limit raises ValueError.""" + with pytest.raises(ValueError, match="output_limit_chars must be positive"): + python_tool_requirements(output_limit_chars=0) + + with pytest.raises(ValueError, match="output_limit_chars must be positive"): + python_tool_requirements(output_limit_chars=-1000) + + def test_factory_requirement_order(self): + """Test factory returns requirements in correct validation order.""" + reqs = python_tool_requirements(allowed_imports=["os"]) + + assert isinstance(reqs[0], PythonCodeExtraction) + assert isinstance(reqs[1], PythonSyntaxValid) + assert isinstance(reqs[2], type(reqs[2])) # PythonExecutionReq + assert isinstance(reqs[3], OutputSizeLimit) + assert isinstance(reqs[4], ImportRestrictions) + + def test_factory_timeout_propagation_to_execution_req(self): + """Test factory propagates timeout_seconds to PythonExecutionReq.""" + from mellea.stdlib.requirements.python_reqs import PythonExecutionReq + + reqs = python_tool_requirements(timeout_seconds=15) + execution_req = reqs[2] + assert isinstance(execution_req, PythonExecutionReq) + assert execution_req._timeout == 15 + + def test_factory_sandbox_propagation_to_execution_req(self): + """Test factory propagates use_sandbox to PythonExecutionReq.""" + from mellea.stdlib.requirements.python_reqs import PythonExecutionReq + + reqs = python_tool_requirements(use_sandbox=True) + execution_req = reqs[2] + assert isinstance(execution_req, PythonExecutionReq) + assert execution_req._use_sandbox is True + + def test_factory_allowed_imports_propagation_to_execution_req(self): + """Test factory propagates allowed_imports to PythonExecutionReq.""" + from mellea.stdlib.requirements.python_reqs import PythonExecutionReq + + imports = ["os", "sys", "json"] + reqs = python_tool_requirements(allowed_imports=imports) + execution_req = reqs[2] + assert isinstance(execution_req, PythonExecutionReq) + assert execution_req._allowed_imports == imports From a804c8836cd8831ae3d890f368f5d20c638cc717 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 26 May 2026 10:21:19 -0400 Subject: [PATCH 2/3] review comments Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/python_tools.py | 66 ++++++++++++++----- test/stdlib/requirements/test_python_tools.py | 50 ++++++++++++-- 2 files changed, 95 insertions(+), 21 deletions(-) diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 48d1e0200..42c2a2e66 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -7,7 +7,7 @@ The requirement pipeline validates code in this order: 1. PythonCodeExtraction — code blocks are present and extractable 2. PythonSyntaxValid — code parses without syntax errors -3. PythonExecutesWithoutError — code runs without exceptions +3. PythonExecutionReq — code runs without exceptions 4. OutputSizeLimit — captured output stays within bounds 5. ImportRestrictions — only whitelisted modules are imported (optional) """ @@ -17,7 +17,6 @@ from mellea.stdlib.tools.interpreter import ( ExecutionEnvironment, LLMSandboxEnvironment, - StaticAnalysisEnvironment, UnsafeEnvironment, ) @@ -99,9 +98,18 @@ class OutputSizeLimit(Requirement): Args: limit_chars: Maximum allowed output size in characters. Defaults to 10,000. + timeout: Maximum execution time in seconds. Defaults to 5. + use_sandbox: Use llm-sandbox for Docker-isolated execution. Defaults to False. + allowed_imports: Whitelist of importable top-level modules. None allows all. """ - def __init__(self, limit_chars: int = 10_000) -> None: + def __init__( + self, + limit_chars: int = 10_000, + timeout: int = 5, + use_sandbox: bool = False, + allowed_imports: list[str] | None = None, + ) -> None: """Initialize OutputSizeLimit requirement. Raises: @@ -111,6 +119,9 @@ def __init__(self, limit_chars: int = 10_000) -> None: raise ValueError(f"limit_chars must be positive, got {limit_chars}") self.limit_chars = limit_chars + self.timeout = timeout + self.use_sandbox = use_sandbox + self.allowed_imports = allowed_imports super().__init__( description=f"Output does not exceed {limit_chars} characters.", validation_fn=self._validate_output_size, @@ -137,15 +148,22 @@ def _validate_output_size(self, ctx: Context) -> ValidationResult: assert code is not None try: - result = _python_executes_without_error(ctx, timeout=5, allow_unsafe=False) - if not result.as_bool(): + environment: ExecutionEnvironment + if self.use_sandbox: + environment = LLMSandboxEnvironment( + allowed_imports=self.allowed_imports + ) + else: + environment = UnsafeEnvironment(allowed_imports=self.allowed_imports) + + exec_result = environment.execute(code, timeout=self.timeout) + if not exec_result.success: return ValidationResult( result=False, - reason=f"Code execution failed during output size check: {result.reason}", + reason=f"Code execution failed during output size check: {exec_result.to_validationresult_reason()}", ) - captured_output = getattr(result, "captured_output", "") or "" - output_size = len(str(captured_output)) + output_size = len(exec_result.stdout or "") if output_size <= self.limit_chars: return ValidationResult( @@ -167,18 +185,21 @@ class ImportRestrictions(Requirement): """Only whitelisted modules are imported in the code. Uses AST analysis to find all imports (Import and ImportFrom nodes) - and validates them against an optional allowlist. If no allowlist is - provided, all imports are accepted. + and validates them against an optional allowlist. If an empty list is + provided, all imports are blocked. If None is provided, all imports are accepted. Args: allowed_imports: List of module names that are allowed to be imported. - If None, all imports are accepted. + If None, all imports are accepted. If an empty list, all imports are blocked. """ def __init__(self, allowed_imports: list[str] | None = None) -> None: """Initialize ImportRestrictions requirement.""" - self.allowed_imports = allowed_imports or [] - imports_str = ", ".join(self.allowed_imports) if self.allowed_imports else "all" + self.allowed_imports: list[str] | None = allowed_imports + if allowed_imports is None: + imports_str = "all" + else: + imports_str = ", ".join(allowed_imports) if allowed_imports else "none" description = f"Only imports from [{imports_str}] are used." super().__init__( @@ -205,7 +226,7 @@ def _validate_imports(self, ctx: Context) -> ValidationResult: code = extraction_result.reason assert code is not None - if not self.allowed_imports: + if self.allowed_imports is None: return ValidationResult( result=True, reason="No import restrictions configured." ) @@ -228,13 +249,19 @@ def _validate_imports(self, ctx: Context) -> ValidationResult: forbidden_imports.append(module_name) elif isinstance(node, ast.ImportFrom): - if node.module is not None: + if node.module is None: + # Relative-only imports like "from . import x" + for alias in node.names: + module_name = alias.name.split(".")[0] + if module_name not in self.allowed_imports: + forbidden_imports.append(module_name) + else: module_name = node.module.split(".")[0] if module_name not in self.allowed_imports: forbidden_imports.append(module_name) if forbidden_imports: - unique_forbidden = list(set(forbidden_imports)) + unique_forbidden = sorted(set(forbidden_imports)) return ValidationResult( result=False, reason=f"Forbidden imports detected: {', '.join(unique_forbidden)}", @@ -309,7 +336,12 @@ def python_tool_requirements( allowed_imports=allowed_imports, use_sandbox=use_sandbox, ), - OutputSizeLimit(limit_chars=output_limit_chars), + OutputSizeLimit( + limit_chars=output_limit_chars, + timeout=timeout_seconds, + use_sandbox=use_sandbox, + allowed_imports=allowed_imports, + ), ] if allowed_imports is not None: diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index 9607da910..2865ad821 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -188,8 +188,20 @@ def test_output_within_limit(self): ```""" ctx = from_model(code) result = req.validation_fn(ctx) - # Result depends on execution, but size check logic is present - assert isinstance(result.as_bool(), bool) + # Should pass: "Hello, World!" is much less than 1000 chars + assert result.as_bool() is True + + def test_output_exceeds_limit(self): + """Test validation when output exceeds limit.""" + req = OutputSizeLimit(limit_chars=10) + code = """```python +print("Hello, World! This is a long message.") +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + # Should fail: output is more than 10 chars + assert result.as_bool() is False + assert "exceeds" in (result.reason or "").lower() class TestImportRestrictions: @@ -203,11 +215,16 @@ def test_init_with_allowlist(self): def test_init_with_none(self): """Test initialization with None allowlist.""" req = ImportRestrictions(allowed_imports=None) - assert req.allowed_imports == [] + assert req.allowed_imports is None def test_init_default(self): """Test initialization with default (None) allowlist.""" req = ImportRestrictions() + assert req.allowed_imports is None + + def test_init_with_empty_list(self): + """Test initialization with empty allowlist (blocks all imports).""" + req = ImportRestrictions(allowed_imports=[]) assert req.allowed_imports == [] def test_allowed_imports_pass(self): @@ -253,6 +270,15 @@ def test_no_allowlist_passes_all(self): assert result.as_bool() is True assert "No import restrictions" in (result.reason or "") + def test_empty_allowlist_blocks_all(self): + """Test validation with empty allowlist blocks all imports.""" + req = ImportRestrictions(allowed_imports=[]) + ctx = from_model(PYTHON_WITH_IMPORTS) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + assert "forbidden" in (result.reason or "").lower() + def test_syntax_error_in_imports_check(self): """Test import validation when code has syntax errors.""" req = ImportRestrictions(allowed_imports=["os"]) @@ -288,6 +314,20 @@ def test_forbidden_submodule(self): assert result.as_bool() is False + def test_relative_import_forbidden(self): + """Test validation catches relative-only imports like 'from . import x'.""" + req = ImportRestrictions(allowed_imports=["os"]) + code = """```python +from . import subprocess as sp + +print("relative import") +```""" + ctx = from_model(code) + result = req.validation_fn(ctx) + + assert result.as_bool() is False + assert "subprocess" in (result.reason or "") + class TestPythonToolRequirementsFactory: """Tests for python_tool_requirements() factory function.""" @@ -362,11 +402,13 @@ def test_factory_invalid_output_limit(self): def test_factory_requirement_order(self): """Test factory returns requirements in correct validation order.""" + from mellea.stdlib.requirements.python_reqs import PythonExecutionReq + reqs = python_tool_requirements(allowed_imports=["os"]) assert isinstance(reqs[0], PythonCodeExtraction) assert isinstance(reqs[1], PythonSyntaxValid) - assert isinstance(reqs[2], type(reqs[2])) # PythonExecutionReq + assert isinstance(reqs[2], PythonExecutionReq) assert isinstance(reqs[3], OutputSizeLimit) assert isinstance(reqs[4], ImportRestrictions) From d7724bca8c8e3089a1b39a61997989553dc8a324 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Wed, 27 May 2026 10:53:02 -0400 Subject: [PATCH 3/3] review commnets Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/python_tools.py | 12 ++++++------ test/stdlib/requirements/test_python_tools.py | 3 +++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mellea/stdlib/requirements/python_tools.py b/mellea/stdlib/requirements/python_tools.py index 42c2a2e66..16779b8bd 100644 --- a/mellea/stdlib/requirements/python_tools.py +++ b/mellea/stdlib/requirements/python_tools.py @@ -21,11 +21,7 @@ ) from ...core import Context, MelleaLogger, Requirement, ValidationResult -from .python_reqs import ( - PythonExecutionReq, - _has_python_code_listing, - _python_executes_without_error, -) +from .python_reqs import PythonExecutionReq, _has_python_code_listing logger = MelleaLogger.get_logger() @@ -281,6 +277,10 @@ def python_tool_requirements( Factory function that creates a complete set of requirements for validating Python code generation, from extraction through execution and output checks. + Note: OutputSizeLimit requires actual code execution to capture stdout size. + Control execution safety via use_sandbox (Docker isolation) or by using + untrusted LLM sources only with use_sandbox=True. + Args: allowed_imports: Whitelist of importable top-level modules. If None, all imports are allowed. Default None. @@ -293,7 +293,7 @@ def python_tool_requirements( list[Requirement]: Requirement instances in validation order: 1. PythonCodeExtraction 2. PythonSyntaxValid - 3. PythonExecutesWithoutError (configured with timeout and sandbox settings) + 3. PythonExecutionReq (configured with timeout and sandbox settings) 4. OutputSizeLimit (configured with output_limit_chars) 5. ImportRestrictions (only included if allowed_imports is provided) diff --git a/test/stdlib/requirements/test_python_tools.py b/test/stdlib/requirements/test_python_tools.py index 2865ad821..a65ff198e 100644 --- a/test/stdlib/requirements/test_python_tools.py +++ b/test/stdlib/requirements/test_python_tools.py @@ -334,10 +334,13 @@ class TestPythonToolRequirementsFactory: def test_factory_default_returns_four_requirements(self): """Test factory with defaults returns 4 requirements (no import restrictions).""" + from mellea.stdlib.requirements.python_reqs import PythonExecutionReq + reqs = python_tool_requirements() assert len(reqs) == 4 assert isinstance(reqs[0], PythonCodeExtraction) assert isinstance(reqs[1], PythonSyntaxValid) + assert isinstance(reqs[2], PythonExecutionReq) assert isinstance(reqs[3], OutputSizeLimit) def test_factory_with_allowed_imports_returns_five(self):