Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mellea/stdlib/requirements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from ...core import Requirement, ValidationResult, default_output_to_bool
from .md import as_markdown_list, is_markdown_list, is_markdown_table
from .python_reqs import PythonExecutionReq
from .python_tools import (
ImportRestrictions,
OutputSizeLimit,
PythonCodeExtraction,
PythonSyntaxValid,
python_tool_requirements,
)
from .requirement import (
ALoraRequirement,
LLMaJRequirement,
Expand All @@ -17,15 +24,20 @@

__all__ = [
"ALoraRequirement",
"ImportRestrictions",
"LLMaJRequirement",
"OutputSizeLimit",
"PythonCodeExtraction",
"PythonExecutionReq",
"PythonSyntaxValid",
"Requirement",
"ValidationResult",
"as_markdown_list",
"check",
"default_output_to_bool",
"is_markdown_list",
"is_markdown_table",
"python_tool_requirements",
"req",
"reqify",
"requirement_check_to_bool",
Expand Down
318 changes: 318 additions & 0 deletions mellea/stdlib/requirements/python_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
"""Generic Python tool requirements for code generation validation.

This module provides a set of composable requirements for validating Python code
generated by language models. Requirements can be used individually or bundled
via the python_tool_requirements() factory function.

The requirement pipeline validates code in this order:
1. PythonCodeExtraction — code blocks are present and extractable
2. PythonSyntaxValid — code parses without syntax errors
3. PythonExecutionReq — code runs without exceptions
4. OutputSizeLimit — captured output stays within bounds
5. ImportRestrictions — only whitelisted modules are imported (optional)
"""

import ast

from mellea.stdlib.tools.interpreter import (
ExecutionEnvironment,
LLMSandboxEnvironment,
StaticAnalysisEnvironment,
UnsafeEnvironment,
)

from ...core import Context, MelleaLogger, Requirement, ValidationResult
from .python_reqs import (
PythonExecutionReq,
_has_python_code_listing,
_python_executes_without_error,
)

# Module-level logger obtained from the project's MelleaLogger.
logger = MelleaLogger.get_logger()


class PythonCodeExtraction(Requirement):
    """Require that the model output contains an extractable Python code block.

    The actual check is delegated to ``_has_python_code_listing``; this class
    only wires that helper into a ``Requirement`` with a fixed description.
    """

    def __init__(self) -> None:
        """Register the extraction helper as this requirement's validator."""
        summary = "Code blocks are present and extractable."
        validator = _has_python_code_listing
        super().__init__(
            check_only=True, validation_fn=validator, description=summary
        )


class PythonSyntaxValid(Requirement):
    """Python code is syntactically valid (parses without AST errors).

    Uses Python's ``ast.parse()`` to validate syntax without executing code.
    Useful for catching malformed code early in the validation pipeline.
    """

    def __init__(self) -> None:
        """Initialize PythonSyntaxValid requirement."""
        super().__init__(
            description="Python code is syntactically valid.",
            validation_fn=self._validate_syntax,
            check_only=True,
        )

    def _validate_syntax(self, ctx: Context) -> ValidationResult:
        """Validate that extracted code has valid Python syntax.

        Args:
            ctx: Context containing model output with code blocks.

        Returns:
            ValidationResult that passes when the extracted code parses, and
            otherwise fails with the extraction or parse error in ``reason``.
        """
        extraction_result = _has_python_code_listing(ctx)
        if not extraction_result.as_bool():
            return ValidationResult(
                result=False,
                reason=f"Could not extract code for syntax validation: {extraction_result.reason}",
            )

        # On success the extraction helper's ``reason`` is used as the source
        # text to parse. Check it explicitly instead of asserting: ``assert``
        # is stripped under ``python -O`` and ``ast.parse(None)`` would then
        # raise ``TypeError`` out of the validator.
        code = extraction_result.reason
        if code is None:
            return ValidationResult(
                result=False,
                reason="Could not extract code for syntax validation: no code was returned.",
            )

        try:
            ast.parse(code)
        except SyntaxError as e:
            return ValidationResult(
                result=False, reason=f"Syntax error at line {e.lineno}: {e.msg}"
            )
        except (ValueError, RecursionError) as e:
            # ``ValueError``: source contains null bytes (Python < 3.12).
            # ``RecursionError``: source is nested too deeply for the parser.
            # Either way the code is unusable, so report a failed validation
            # rather than letting the exception escape.
            return ValidationResult(
                result=False, reason=f"Code could not be parsed: {e}"
            )

        return ValidationResult(result=True, reason="Syntax is valid.")


class OutputSizeLimit(Requirement):
    """Captured output does not exceed size limit (in characters).

    Runs the extracted code through ``_python_executes_without_error`` and
    compares the length of the result's ``captured_output`` (if any) with the
    configured character limit.

    Args:
        limit_chars: Maximum allowed output size in characters. Defaults to 10,000.
        timeout_seconds: Timeout forwarded to the execution helper. Keyword-only;
            defaults to 5, the value that was previously hard-coded.
    """

    def __init__(self, limit_chars: int = 10_000, *, timeout_seconds: int = 5) -> None:
        """Initialize OutputSizeLimit requirement.

        Raises:
            ValueError: If limit_chars or timeout_seconds is not positive.
        """
        if limit_chars <= 0:
            raise ValueError(f"limit_chars must be positive, got {limit_chars}")
        if timeout_seconds <= 0:
            raise ValueError(
                f"timeout_seconds must be positive, got {timeout_seconds}"
            )

        self.limit_chars = limit_chars
        self.timeout_seconds = timeout_seconds
        super().__init__(
            description=f"Output does not exceed {limit_chars} characters.",
            validation_fn=self._validate_output_size,
            check_only=True,
        )

    def _validate_output_size(self, ctx: Context) -> ValidationResult:
        """Validate that executed code's output stays within size limit.

        Args:
            ctx: Context containing model output with code blocks.

        Returns:
            ValidationResult with pass/fail and output size details.
        """
        extraction_result = _has_python_code_listing(ctx)
        if not extraction_result.as_bool():
            return ValidationResult(
                result=False,
                reason="Could not extract code for output size validation.",
            )

        # Only the execution helper is guarded: any failure inside it is
        # reported as a failed validation instead of propagating.
        try:
            result = _python_executes_without_error(
                ctx, timeout=self.timeout_seconds, allow_unsafe=False
            )
        except Exception as e:
            return ValidationResult(
                result=False, reason=f"Error checking output size: {e!s}"
            )

        if not result.as_bool():
            return ValidationResult(
                result=False,
                reason=f"Code execution failed during output size check: {result.reason}",
            )

        # NOTE(review): ``captured_output`` is read defensively; if the result
        # object never carries that attribute the size is always 0 and this
        # check passes vacuously — confirm against the execution helper.
        captured_output = getattr(result, "captured_output", "") or ""
        output_size = len(str(captured_output))

        if output_size > self.limit_chars:
            return ValidationResult(
                result=False,
                reason=f"Output size ({output_size} chars) exceeds limit ({self.limit_chars}).",
            )
        return ValidationResult(
            result=True,
            reason=f"Output size ({output_size} chars) within limit ({self.limit_chars}).",
        )


class ImportRestrictions(Requirement):
    """Only whitelisted modules are imported in the code.

    Uses AST analysis to find all imports (``Import`` and ``ImportFrom`` nodes)
    and validates the top-level package of each against an allowlist.

    Args:
        allowed_imports: List of module names that are allowed to be imported.
            If None, all imports are accepted. An explicit empty list is a
            real (empty) allowlist and therefore rejects every import.
    """

    def __init__(self, allowed_imports: list[str] | None = None) -> None:
        """Initialize ImportRestrictions requirement."""
        # ``None`` is the only "unrestricted" sentinel. Tracking it separately
        # keeps an explicit ``[]`` from silently disabling the check.
        self._unrestricted = allowed_imports is None
        self.allowed_imports = [] if allowed_imports is None else allowed_imports

        if self._unrestricted:
            description = "Only imports from [all] are used."
        elif self.allowed_imports:
            imports_str = ", ".join(self.allowed_imports)
            description = f"Only imports from [{imports_str}] are used."
        else:
            description = "No imports are used."

        super().__init__(
            description=description,
            validation_fn=self._validate_imports,
            check_only=True,
        )

    def _validate_imports(self, ctx: Context) -> ValidationResult:
        """Validate that imports in extracted code match allowlist.

        Args:
            ctx: Context containing model output with code blocks.

        Returns:
            ValidationResult with pass/fail; on failure ``reason`` lists the
            forbidden top-level modules in sorted order.
        """
        extraction_result = _has_python_code_listing(ctx)
        if not extraction_result.as_bool():
            return ValidationResult(
                result=False, reason="Could not extract code for import validation."
            )

        # Explicit check rather than ``assert`` (which ``python -O`` removes).
        code = extraction_result.reason
        if code is None:
            return ValidationResult(
                result=False, reason="Could not extract code for import validation."
            )

        if self._unrestricted:
            return ValidationResult(
                result=True, reason="No import restrictions configured."
            )

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError) as e:
            # ``ValueError`` covers null bytes in the source on Python < 3.12.
            detail = e.msg if isinstance(e, SyntaxError) else str(e)
            return ValidationResult(
                result=False,
                reason=f"Could not parse code for import analysis: {detail}",
            )

        # Build the lookup set once; membership is tested per import below.
        allowed = set(self.allowed_imports)
        forbidden: set[str] = set()

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                roots = [alias.name.split(".")[0] for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module is not None:
                # ``from . import x`` has no module name and is skipped.
                roots = [node.module.split(".")[0]]
            else:
                continue
            forbidden.update(root for root in roots if root not in allowed)

        if forbidden:
            # Sorted so the failure message is deterministic across runs.
            return ValidationResult(
                result=False,
                reason=f"Forbidden imports detected: {', '.join(sorted(forbidden))}",
            )

        return ValidationResult(result=True, reason="All imports are whitelisted.")


def python_tool_requirements(
    allowed_imports: list[str] | None = None,
    output_limit_chars: int = 10_000,
    timeout_seconds: int = 5,
    use_sandbox: bool = False,
) -> list[Requirement]:
    """Build the standard list of Python code-generation requirements.

    Both numeric settings are validated up front, then the requirements are
    instantiated in pipeline order.

    Args:
        allowed_imports: Allowlist of importable top-level modules. It is
            forwarded to ``PythonExecutionReq`` and, when not None, also used
            to append an ``ImportRestrictions`` requirement. Default None.
        output_limit_chars: Character limit given to ``OutputSizeLimit``.
            Default 10,000.
        timeout_seconds: Timeout forwarded to ``PythonExecutionReq``. Default 5.
        use_sandbox: Forwarded to ``PythonExecutionReq``. Default False.

    Returns:
        list[Requirement]: Requirement instances in this order:
            1. PythonCodeExtraction
            2. PythonSyntaxValid
            3. PythonExecutionReq (timeout, allowed imports, sandbox flag)
            4. OutputSizeLimit (configured with output_limit_chars)
            5. ImportRestrictions (only when allowed_imports is not None)

    Raises:
        ValueError: If timeout_seconds is not positive.
        ValueError: If output_limit_chars is not positive.

    Examples:
        >>> # Defaults: no import allowlist
        >>> reqs = python_tool_requirements()
        >>> len(reqs)
        4

        >>> # With an import allowlist and a tighter output limit
        >>> reqs = python_tool_requirements(
        ...     allowed_imports=["math", "json"],
        ...     output_limit_chars=5_000,
        ... )
        >>> len(reqs)  # includes ImportRestrictions
        5

        >>> # Sandbox mode with a longer timeout
        >>> reqs = python_tool_requirements(
        ...     use_sandbox=True,
        ...     timeout_seconds=10,
        ... )
    """
    # Checked in this order so a bad timeout is reported before a bad limit.
    for label, value in (
        ("timeout_seconds", timeout_seconds),
        ("output_limit_chars", output_limit_chars),
    ):
        if value <= 0:
            raise ValueError(f"{label} must be positive, got {value}")

    pipeline: list[Requirement] = [
        PythonCodeExtraction(),
        PythonSyntaxValid(),
        PythonExecutionReq(
            timeout=timeout_seconds,
            allowed_imports=allowed_imports,
            use_sandbox=use_sandbox,
        ),
        OutputSizeLimit(limit_chars=output_limit_chars),
    ]

    if allowed_imports is None:
        return pipeline
    return [*pipeline, ImportRestrictions(allowed_imports=allowed_imports)]
Loading
Loading