From 19ea55d48753c87fd8bb03bc4721d7829d2feccf Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Tue, 28 Apr 2026 15:03:40 +0000 Subject: [PATCH 1/7] =?UTF-8?q?feat:=20add=20vended=20tools=20=E2=80=94=20?= =?UTF-8?q?shell,=20editor,=20python=5Frepl=20(2/N)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add production-ready vended tools that integrate with the Sandbox abstraction: - shell: Execute shell commands with streaming output via sandbox.execute_streaming() - editor: View, create, and edit files (5 commands matching Anthropic's text_editor spec) - python_repl: Execute Python code with streaming via sandbox.execute_code_streaming() All tools: - Read configuration from agent.state (persists across calls and sessions) - Work transparently with any Sandbox implementation (host, Docker, cloud) - Support streaming (shell/python_repl yield StreamChunk events in real-time) - Handle NoOpSandbox gracefully with clear error messages Tests: 69 passing (shell: 14, editor: 24, python_repl: 12, import/spec: 9, config: 4, streaming: 6) Part of #1968 series → feature/sandbox branch --- src/strands/vended_tools/__init__.py | 39 + src/strands/vended_tools/editor/__init__.py | 20 + src/strands/vended_tools/editor/editor.py | 403 ++++++ .../vended_tools/python_repl/__init__.py | 20 + .../vended_tools/python_repl/python_repl.py | 140 ++ src/strands/vended_tools/shell/__init__.py | 20 + src/strands/vended_tools/shell/shell.py | 164 +++ tests/strands/vended_tools/__init__.py | 0 .../strands/vended_tools/test_vended_tools.py | 1161 +++++++++++++++++ 9 files changed, 1967 insertions(+) create mode 100644 src/strands/vended_tools/__init__.py create mode 100644 src/strands/vended_tools/editor/__init__.py create mode 100644 src/strands/vended_tools/editor/editor.py create mode 100644 src/strands/vended_tools/python_repl/__init__.py create mode 100644 src/strands/vended_tools/python_repl/python_repl.py create mode 100644 src/strands/vended_tools/shell/__init__.py create mode 100644 src/strands/vended_tools/shell/shell.py create mode 100644 tests/strands/vended_tools/__init__.py create mode 100644 tests/strands/vended_tools/test_vended_tools.py diff --git a/src/strands/vended_tools/__init__.py b/src/strands/vended_tools/__init__.py new file mode 100644 index 000000000..4778efbff --- /dev/null +++ b/src/strands/vended_tools/__init__.py @@ -0,0 +1,39 @@ +"""Vended tools for Strands agents. + +These are production-ready tools that ship with the SDK and integrate +with the :class:`~strands.sandbox.base.Sandbox` abstraction. They work +transparently whether the agent uses a local :class:`~strands.sandbox.host.HostSandbox` +or a remote sandbox implementation. + +Each tool reads its configuration from ``tool_context.agent.state`` using +a namespaced key (e.g., ``strands_shell_tool``). This means configuration +persists across tool calls and survives session serialization. + +Available tools: + +- :func:`~strands.vended_tools.shell.shell` — Execute shell commands +- :func:`~strands.vended_tools.editor.editor` — View, create, and edit files +- :func:`~strands.vended_tools.python_repl.python_repl` — Execute Python code + +Example:: + + from strands import Agent + from strands.vended_tools import shell, editor, python_repl + + agent = Agent(tools=[shell, editor, python_repl]) + + # Configure tools via agent state (persists across calls) + agent.state.set("strands_shell_tool", { + "timeout": 60, + }) +""" + +from .editor import editor +from .python_repl import python_repl +from .shell import shell + +__all__ = [ + "editor", + "python_repl", + "shell", +] diff --git a/src/strands/vended_tools/editor/__init__.py b/src/strands/vended_tools/editor/__init__.py new file mode 100644 index 000000000..b3e23e74d --- /dev/null +++ b/src/strands/vended_tools/editor/__init__.py @@ -0,0 +1,20 @@ +"""File editor tool for viewing, creating, and editing files in the agent's sandbox. + +Example:: + + from strands import Agent + from strands.vended_tools import editor + + agent = Agent(tools=[editor]) + agent("View the contents of /tmp/example.py") + +Configuration via agent state:: + + agent.state.set("strands_editor_tool", { + "max_file_size": 1048576, # Maximum file size in bytes (default: 1MB) + }) +""" + +from .editor import editor + +__all__ = ["editor"] diff --git a/src/strands/vended_tools/editor/editor.py b/src/strands/vended_tools/editor/editor.py new file mode 100644 index 000000000..8fe3af25b --- /dev/null +++ b/src/strands/vended_tools/editor/editor.py @@ -0,0 +1,403 @@ +"""File editor tool implementation. + +Provides view, create, str_replace, insert, and undo_edit operations on files +in the agent's sandbox. The tool delegates all file I/O to the sandbox's +``read_file``, ``write_file``, and ``list_files`` methods. + +The tool shape matches Anthropic's ``text_editor`` built-in tool — 5 commands, +7 parameters. This means models trained on Anthropic's tool spec will work +well with this tool out of the box. + +Configuration keys (set via ``agent.state.set("strands_editor_tool", {...})``): + +- ``max_file_size`` (int): Maximum file size in bytes for read operations. + Default: 1048576 (1 MB). +- ``require_absolute_paths`` (bool): When True, rejects relative paths and + paths containing ``..``. When False (the default), paths are passed through + to the sandbox without filesystem-level validation — the sandbox decides + what a path means. Default: False. +""" + +import logging +from typing import Any, Literal + +from ...tools.decorator import tool +from ...types.tools import ToolContext + +logger = logging.getLogger(__name__) + +#: State key for editor tool configuration in agent.state +STATE_KEY = "strands_editor_tool" + +#: State key for undo history (internal) +_UNDO_STATE_KEY = "_strands_editor_undo" + +#: Default maximum file size (1 MB) +DEFAULT_MAX_FILE_SIZE = 1_048_576 + +#: Number of context lines to show around edits +SNIPPET_LINES = 4 + +#: Maximum directory listing depth +MAX_DIRECTORY_DEPTH = 2 + + +def _get_config(tool_context: ToolContext) -> dict[str, Any]: + """Read editor tool configuration from agent state.""" + return tool_context.agent.state.get(STATE_KEY) or {} + + +def _make_output(content: str, descriptor: str, init_line: int = 1) -> str: + """Format file content with line numbers (cat -n style). + + Args: + content: The file content to format. + descriptor: Description of what is being shown (e.g., file path). + init_line: Starting line number. + + Returns: + Formatted output with line numbers. + """ + # Expand tabs to spaces + content = content.replace("\t", " ") + lines = content.split("\n") + numbered = [] + for i, line in enumerate(lines): + line_num = i + init_line + numbered.append(f"{line_num:>6} {line}") + return f"Here's the result of running `cat -n` on {descriptor}:\n" + "\n".join(numbered) + "\n" + + +def _save_undo(tool_context: ToolContext, path: str, content: str) -> None: + """Save file content for undo. + + Args: + tool_context: The tool context providing access to agent state. + path: The file path. + content: The file content before modification. + """ + undo_state = tool_context.agent.state.get(_UNDO_STATE_KEY) or {} + undo_state[path] = content + tool_context.agent.state.set(_UNDO_STATE_KEY, undo_state) + + +def _get_undo(tool_context: ToolContext, path: str) -> str | None: + """Get saved undo content for a file. + + Args: + tool_context: The tool context providing access to agent state. + path: The file path. + + Returns: + The saved content, or None if no undo is available. + """ + undo_state = tool_context.agent.state.get(_UNDO_STATE_KEY) or {} + return undo_state.get(path) + + +@tool(context=True) +async def editor( + command: Literal["view", "create", "str_replace", "insert", "undo_edit"], + path: str, + file_text: str | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + view_range: list[int] | None = None, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> str: + """View, create, and edit files in the agent's sandbox. + + Commands: + + - **view**: Display file contents with line numbers, or list directory contents. + Use ``view_range`` as ``[start_line, end_line]`` (1-indexed, -1 for end of file) + to view a specific range. + - **create**: Create a new file with ``file_text`` content. Fails if file exists. + - **str_replace**: Replace ``old_str`` with ``new_str`` in the file. + ``old_str`` must match exactly once in the file (uniqueness enforced). + - **insert**: Insert ``new_str`` at ``insert_line`` (0-indexed line number). + - **undo_edit**: Revert the last edit to the file at ``path``. + + File operations go through the agent's sandbox. By default, paths are passed + through to the sandbox as-is — the sandbox decides what a path means. Set + ``require_absolute_paths: true`` in ``strands_editor_tool`` config to enforce + absolute paths and block directory traversal. + + Configuration is read from ``agent.state.get("strands_editor_tool")``: + + - ``max_file_size``: Maximum file size in bytes (default: 1 MB). + - ``require_absolute_paths``: Reject relative paths and ``..`` (default: False). + + Args: + command: The operation to perform. + path: Path to the file or directory. + file_text: Content for new file (required for ``create``). + old_str: String to find and replace (required for ``str_replace``). + Must appear exactly once in the file. + new_str: Replacement string for ``str_replace``, or text to insert for ``insert``. + insert_line: Line number for insertion (0-indexed, required for ``insert``). + view_range: Line range for view as ``[start, end]``. 1-indexed. + Use -1 for end to mean end of file. + tool_context: Framework-injected tool context. + + Returns: + Result of the operation — file contents, success message, or error. + """ + config = _get_config(tool_context) + sandbox = tool_context.agent.sandbox + + # Path validation is opt-in. By default, paths are passed straight through + # to the sandbox without filesystem-level validation. This allows sandboxes + # like S3Sandbox to use relative keys (e.g., "hello.txt") as paths. + if config.get("require_absolute_paths"): + import os + + if not os.path.isabs(path): + suggested = os.path.abspath(path) + return f"Error: The path {path} is not an absolute path. Maybe you meant {suggested}?" + if ".." in path: + return "Error: Path traversal (..) is not allowed." + + try: + if command == "view": + return await _handle_view(sandbox, config, path, view_range) + elif command == "create": + if file_text is None: + return "Error: Parameter `file_text` is required for command: create" + return await _handle_create(sandbox, tool_context, path, file_text) + elif command == "str_replace": + if old_str is None: + return "Error: Parameter `old_str` is required for command: str_replace" + return await _handle_str_replace(sandbox, tool_context, config, path, old_str, new_str or "") + elif command == "insert": + if insert_line is None: + return "Error: Parameter `insert_line` is required for command: insert" + if new_str is None: + return "Error: Parameter `new_str` is required for command: insert" + return await _handle_insert(sandbox, tool_context, config, path, insert_line, new_str) + elif command == "undo_edit": + return await _handle_undo(sandbox, tool_context, path) + + return f"Error: Unknown command: {command}" # type: ignore[unreachable] + except NotImplementedError as e: + return f"Error: Sandbox does not support this operation — {e}" + except Exception as e: + return f"Error: {e}" + + +async def _handle_view(sandbox: Any, config: dict[str, Any], path: str, view_range: list[int] | None) -> str: + """Handle the view command.""" + # Check if path is a directory + try: + entries = await sandbox.list_files(path) + # It's a directory + if view_range: + return "Error: The `view_range` parameter is not allowed when `path` points to a directory." + items = sorted(f"{e.name}/" if e.is_dir else e.name for e in entries if e.name not in (".", "..")) + return ( + f"Here's the files and directories up to 2 levels deep in {path}, " + f"excluding hidden items:\n" + "\n".join(items) + "\n" + ) + except (FileNotFoundError, OSError): + pass # Not a directory, try as file + + # Read file + max_size = config.get("max_file_size", DEFAULT_MAX_FILE_SIZE) + try: + content = (await sandbox.read_file(path)).decode("utf-8") + except FileNotFoundError: + return f"Error: The path {path} does not exist. Please provide a valid path." + except UnicodeDecodeError: + return f"Error: The file {path} is not a text file (cannot decode as UTF-8)." + + # Check size + if len(content.encode("utf-8")) > max_size: + return f"Error: File size exceeds maximum allowed size ({max_size} bytes)." + + if view_range is None: + return _make_output(content, path) + + # Validate and apply view range + lines = content.split("\n") + n_lines = len(lines) + + if len(view_range) != 2: + return "Error: `view_range` must be a list of two integers [start, end]." + + start, end = view_range[0], view_range[1] + + if start < 1 or start > n_lines: + return ( + f"Error: Invalid `view_range`: [{start}, {end}]. First element `{start}` should be within [1, {n_lines}]." + ) + if end != -1 and end > n_lines: + return f"Error: Invalid `view_range`: [{start}, {end}]. Second element `{end}` should be <= {n_lines}." + if end != -1 and end < start: + return f"Error: Invalid `view_range`: [{start}, {end}]. Second element must be >= first element." + + if end == -1: + selected = lines[start - 1 :] + else: + selected = lines[start - 1 : end] + + return _make_output("\n".join(selected), path, init_line=start) + + +async def _handle_create(sandbox: Any, tool_context: ToolContext, path: str, file_text: str) -> str: + """Handle the create command.""" + # Check if file already exists + try: + await sandbox.read_file(path) + return f"Error: File already exists at: {path}. Cannot overwrite with `create`. Use `str_replace` to edit." + except (FileNotFoundError, OSError): + pass # File doesn't exist, good + + await sandbox.write_file(path, file_text.encode("utf-8")) + return f"File created successfully at: {path}" + + +async def _handle_str_replace( + sandbox: Any, + tool_context: ToolContext, + config: dict[str, Any], + path: str, + old_str: str, + new_str: str, +) -> str: + """Handle the str_replace command.""" + try: + content = (await sandbox.read_file(path)).decode("utf-8") + except FileNotFoundError: + return f"Error: The path {path} does not exist." + + # Expand tabs for matching + content = content.replace("\t", " ") + expanded_old = old_str.replace("\t", " ") + expanded_new = new_str.replace("\t", " ") + + # Count occurrences — MUST be exactly 1 + count = content.count(expanded_old) + + if count == 0: + return f"Error: No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." + + if count > 1: + # Find line numbers of all occurrences + lines = content.split("\n") + line_nums = [] + for i, line in enumerate(lines): + if expanded_old in line: + line_nums.append(i + 1) + # Also check multi-line matches + if not line_nums: + # old_str spans multiple lines, find approximate locations + idx = 0 + while True: + idx = content.find(expanded_old, idx) + if idx == -1: + break + line_num = content[:idx].count("\n") + 1 + line_nums.append(line_num) + idx += 1 + return ( + f"Error: No replacement was performed. Multiple occurrences ({count}) of old_str " + f"in lines {line_nums}. Please ensure old_str is unique." + ) + + # Save undo state + _save_undo(tool_context, path, content) + + # Perform replacement + new_content = content.replace(expanded_old, expanded_new, 1) + + # Write back + await sandbox.write_file(path, new_content.encode("utf-8")) + + # Generate snippet around the change + replace_idx = content.find(expanded_old) + replace_line = content[:replace_idx].count("\n") + inserted_lines = expanded_new.count("\n") + 1 + original_lines = expanded_old.count("\n") + 1 + line_diff = inserted_lines - original_lines + + new_lines = new_content.split("\n") + start = max(0, replace_line - SNIPPET_LINES) + end = min(len(new_lines), replace_line + SNIPPET_LINES + line_diff + 1) + snippet = "\n".join(new_lines[start:end]) + + return ( + f"The file {path} has been edited. " + + _make_output(snippet, f"a snippet of {path}", init_line=start + 1) + + "Review the changes and make sure they are as expected. Edit the file again if necessary." + ) + + +async def _handle_insert( + sandbox: Any, + tool_context: ToolContext, + config: dict[str, Any], + path: str, + insert_line: int, + new_str: str, +) -> str: + """Handle the insert command.""" + try: + content = (await sandbox.read_file(path)).decode("utf-8") + except FileNotFoundError: + return f"Error: The path {path} does not exist." + + # Expand tabs + content = content.replace("\t", " ") + expanded_new = new_str.replace("\t", " ") + + lines = content.split("\n") + n_lines = len(lines) + + if insert_line < 0 or insert_line > n_lines: + return f"Error: Invalid `insert_line`: {insert_line}. Should be within [0, {n_lines}]." + + # Save undo state + _save_undo(tool_context, path, content) + + # Insert + new_str_lines = expanded_new.split("\n") + if content == "": + new_lines = new_str_lines + else: + new_lines = lines[:insert_line] + new_str_lines + lines[insert_line:] + + new_content = "\n".join(new_lines) + await sandbox.write_file(path, new_content.encode("utf-8")) + + # Generate snippet + start = max(0, insert_line - SNIPPET_LINES) + end = min(len(new_lines), insert_line + len(new_str_lines) + SNIPPET_LINES) + snippet = "\n".join(new_lines[start:end]) + + return ( + f"The file {path} has been edited. " + + _make_output(snippet, "a snippet of the edited file", init_line=start + 1) + + "Review the changes and make sure they are as expected. Edit the file again if necessary." + ) + + +async def _handle_undo(sandbox: Any, tool_context: ToolContext, path: str) -> str: + """Handle the undo_edit command.""" + previous_content = _get_undo(tool_context, path) + if previous_content is None: + return f"Error: No edit history found for {path}." + + # Read current content for future undo + try: + current = (await sandbox.read_file(path)).decode("utf-8") + except FileNotFoundError: + current = "" + + # Write the previous content back + await sandbox.write_file(path, previous_content.encode("utf-8")) + + # Save current as new undo (so undo is toggleable) + _save_undo(tool_context, path, current) + + return f"Successfully reverted last edit to {path}." diff --git a/src/strands/vended_tools/python_repl/__init__.py b/src/strands/vended_tools/python_repl/__init__.py new file mode 100644 index 000000000..cd0268fce --- /dev/null +++ b/src/strands/vended_tools/python_repl/__init__.py @@ -0,0 +1,20 @@ +"""Python REPL tool for executing Python code in the agent's sandbox. + +Example:: + + from strands import Agent + from strands.vended_tools import python_repl + + agent = Agent(tools=[python_repl]) + agent("Calculate the first 10 Fibonacci numbers") + +Configuration via agent state:: + + agent.state.set("strands_python_repl_tool", { + "timeout": 30, # Default timeout in seconds + }) +""" + +from .python_repl import python_repl + +__all__ = ["python_repl"] diff --git a/src/strands/vended_tools/python_repl/python_repl.py b/src/strands/vended_tools/python_repl/python_repl.py new file mode 100644 index 000000000..01cc8c262 --- /dev/null +++ b/src/strands/vended_tools/python_repl/python_repl.py @@ -0,0 +1,140 @@ +"""Python REPL tool implementation with streaming support. + +Executes Python code in the agent's sandbox using +``sandbox.execute_code_streaming(code, language="python")``. Each chunk of +stdout/stderr is yielded as a ``ToolStreamEvent`` in real time, allowing UI +consumers to display live output from code execution. + +The tool is an **async generator**: ``StreamChunk`` objects from the sandbox +are yielded during execution, and the final yield is the formatted result +string that becomes the ``ToolResult``. + +Configuration keys (set via ``agent.state.set("strands_python_repl_tool", {...})``): + +- ``timeout`` (int): Default timeout in seconds for code execution. + Overridden by the per-call ``timeout`` parameter. Default: 30. +""" + +import asyncio +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from ...sandbox.base import ExecutionResult, StreamChunk +from ...tools.decorator import tool +from ...types.tools import ToolContext + +logger = logging.getLogger(__name__) + +#: State key for python_repl tool configuration in agent.state +STATE_KEY = "strands_python_repl_tool" + +#: Default timeout for code execution (seconds) +DEFAULT_TIMEOUT = 30 + + +def _get_config(tool_context: ToolContext) -> dict[str, Any]: + """Read python_repl tool configuration from agent state.""" + return tool_context.agent.state.get(STATE_KEY) or {} + + +@tool(context=True) +async def python_repl( + code: str, + timeout: int | None = None, + reset: bool = False, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> AsyncGenerator[Any, None]: + """Execute Python code in the agent's sandbox with live output streaming. + + Code is executed via the agent's sandbox using + ``sandbox.execute_code_streaming(code, language="python")``. Each chunk + of stdout/stderr is yielded as a streaming event that UI consumers can + display in real time. The final yield is the formatted result string. + + Use ``reset=True`` to clear any sandbox-level state (e.g., restart the + interpreter session if the sandbox supports it). + + Configuration is read from ``agent.state.get("strands_python_repl_tool")``: + + - ``timeout``: Default timeout in seconds (overridden by per-call timeout). + + Args: + code: The Python code to execute. + timeout: Maximum execution time in seconds. Uses config default or 30s. + reset: If True, signal the sandbox to reset execution state. + tool_context: Framework-injected tool context. + + Yields: + :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as + ``ToolStreamEvent`` by the SDK), then a final string result. + """ + config = _get_config(tool_context) + sandbox = tool_context.agent.sandbox + + # Handle reset + if reset: + try: + tool_context.agent.state.delete("_strands_python_repl_state") + except Exception: + pass + if not code or not code.strip(): + yield "Python REPL state reset." + return + + # Resolve timeout: per-call > config > default + effective_timeout: int | None = timeout + if effective_timeout is None: + effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + + # Execute via sandbox streaming + result: ExecutionResult | None = None + try: + async for chunk in sandbox.execute_code_streaming( + code, + language="python", + timeout=effective_timeout, + ): + if isinstance(chunk, StreamChunk): + # Yield each chunk — the decorator wraps it as ToolStreamEvent + yield chunk + elif isinstance(chunk, ExecutionResult): + result = chunk + except asyncio.TimeoutError: + yield f"Error: Code execution timed out after {effective_timeout} seconds." + return + except NotImplementedError: + yield "Error: Sandbox does not support code execution (NoOpSandbox)." + return + except Exception as e: + yield f"Error: {e}" + return + + if result is None: + yield "Error: Sandbox did not return an execution result." + return + + # Format output + output_parts = [] + if result.stdout: + output_parts.append(result.stdout) + if result.stderr: + output_parts.append(result.stderr) + + output = "\n".join(output_parts).rstrip() + + if result.exit_code != 0: + if output: + output += f"\n\nExit code: {result.exit_code}" + else: + output = f"Code execution failed with exit code: {result.exit_code}" + + # Handle output files (images, charts, etc.) + if result.output_files: + file_names = [f.name for f in result.output_files] + if output: + output += f"\n\nGenerated files: {', '.join(file_names)}" + else: + output = f"Generated files: {', '.join(file_names)}" + + yield output if output else "(no output)" diff --git a/src/strands/vended_tools/shell/__init__.py b/src/strands/vended_tools/shell/__init__.py new file mode 100644 index 000000000..a04ae399a --- /dev/null +++ b/src/strands/vended_tools/shell/__init__.py @@ -0,0 +1,20 @@ +"""Shell tool for executing commands in the agent's sandbox. + +Example:: + + from strands import Agent + from strands.vended_tools import shell + + agent = Agent(tools=[shell]) + agent("List all Python files in the current directory") + +Configuration via agent state:: + + agent.state.set("strands_shell_tool", { + "timeout": 120, # Default timeout in seconds + }) +""" + +from .shell import shell + +__all__ = ["shell"] diff --git a/src/strands/vended_tools/shell/shell.py b/src/strands/vended_tools/shell/shell.py new file mode 100644 index 000000000..4cfdd1f38 --- /dev/null +++ b/src/strands/vended_tools/shell/shell.py @@ -0,0 +1,164 @@ +"""Shell tool implementation with streaming support. + +Executes shell commands in the agent's sandbox with persistent state tracking. +The tool uses ``sandbox.execute_streaming()`` so that stdout/stderr chunks are +yielded as ``ToolStreamEvent``s in real time. This allows UI consumers to display +live output from sandbox execution. + +The tool is an **async generator**: each ``StreamChunk`` from the sandbox is +yielded directly (the SDK decorator wraps it in a ``ToolStreamEvent``), and the +final yield is the formatted result string (which becomes the ``ToolResult``). + +Configuration keys (set via ``agent.state.set("strands_shell_tool", {...})``): + +- ``timeout`` (int): Default timeout in seconds. Overridden by the per-call + ``timeout`` parameter. Default: 120. +""" + +import asyncio +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from ...sandbox.base import ExecutionResult, StreamChunk +from ...tools.decorator import tool +from ...types.tools import ToolContext + +logger = logging.getLogger(__name__) + +#: State key for shell tool configuration in agent.state +STATE_KEY = "strands_shell_tool" + +#: Default timeout for shell commands (seconds) +DEFAULT_TIMEOUT = 120 + + +def _get_config(tool_context: ToolContext) -> dict[str, Any]: + """Read shell tool configuration from agent state. + + Args: + tool_context: The tool context providing access to agent state. + + Returns: + Configuration dict. Empty dict if no config is set. + """ + return tool_context.agent.state.get(STATE_KEY) or {} + + +@tool(context=True) +async def shell( + command: str, + timeout: int | None = None, + restart: bool = False, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> AsyncGenerator[Any, None]: + """Execute a shell command in the agent's sandbox with live output streaming. + + The sandbox preserves working directory and environment variables across + calls when using a persistent sandbox implementation. Use ``restart=True`` + to reset the shell state. + + Commands are executed via the agent's sandbox + (``sandbox.execute_streaming()``). Each chunk of stdout/stderr is yielded + as a streaming event that UI consumers can display in real time. The final + yield is the formatted result string. + + Configuration is read from ``agent.state.get("strands_shell_tool")``: + + - ``timeout``: Default timeout in seconds (overridden by per-call timeout). + + Args: + command: The shell command to execute. + timeout: Maximum execution time in seconds. Uses config default or 120s. + restart: If True, reset shell state by clearing tracked working directory. + tool_context: Framework-injected tool context. + + Yields: + :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as + ``ToolStreamEvent`` by the SDK), then a final string result. + """ + config = _get_config(tool_context) + sandbox = tool_context.agent.sandbox + + # Handle restart + if restart: + _clear_shell_state(tool_context) + if not command or not command.strip(): + yield "Shell state reset." + return + + # Resolve timeout: per-call > config > default + effective_timeout: int | None = timeout + if effective_timeout is None: + effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + + # Get tracked working directory from state (for session continuity) + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + cwd = shell_state.get("cwd") + + # Execute via sandbox streaming + result: ExecutionResult | None = None + try: + async for chunk in sandbox.execute_streaming( + command, + timeout=effective_timeout, + cwd=cwd, + ): + if isinstance(chunk, StreamChunk): + # Yield each chunk — the decorator wraps it as ToolStreamEvent + yield chunk + elif isinstance(chunk, ExecutionResult): + result = chunk + except asyncio.TimeoutError: + yield f"Error: Command timed out after {effective_timeout} seconds." + return + except NotImplementedError: + yield "Error: Sandbox does not support command execution (NoOpSandbox)." + return + except Exception as e: + yield f"Error: {e}" + return + + if result is None: + yield "Error: Sandbox did not return an execution result." + return + + # Track working directory changes + try: + cwd_result = await sandbox.execute("pwd", timeout=5, cwd=cwd) + if cwd_result.exit_code == 0: + new_cwd = cwd_result.stdout.strip() + if new_cwd: + shell_state["cwd"] = new_cwd + tool_context.agent.state.set("_strands_shell_state", shell_state) + except Exception: + pass # Best-effort cwd tracking + + # Format final output (becomes the ToolResult) + output_parts = [] + if result.stdout: + output_parts.append(result.stdout) + if result.stderr: + output_parts.append(result.stderr) + + output = "\n".join(output_parts).rstrip() + + if result.exit_code != 0: + if output: + output += f"\n\nExit code: {result.exit_code}" + else: + output = f"Command failed with exit code: {result.exit_code}" + + yield output if output else "(no output)" + + +def _clear_shell_state(tool_context: ToolContext) -> None: + """Clear tracked shell state from agent state. + + Args: + tool_context: The tool context providing access to agent state. + """ + try: + tool_context.agent.state.delete("_strands_shell_state") + except Exception: + pass diff --git a/tests/strands/vended_tools/__init__.py b/tests/strands/vended_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_tools/test_vended_tools.py b/tests/strands/vended_tools/test_vended_tools.py new file mode 100644 index 000000000..435ea677e --- /dev/null +++ b/tests/strands/vended_tools/test_vended_tools.py @@ -0,0 +1,1161 @@ +"""Tests for vended tools — shell, editor, python_repl. + +These tests use a real HostSandbox to validate end-to-end behavior. +They also test configuration via agent.state and streaming behavior +(shell and python_repl yield StreamChunk events). +""" + +import asyncio +import uuid +from unittest.mock import MagicMock + +import pytest + +from strands.agent.state import AgentState +from strands.sandbox.base import ExecutionResult, StreamChunk +from strands.sandbox.host import HostSandbox +from strands.sandbox.noop import NoOpSandbox +from strands.types.tools import ToolContext, ToolUse + +# ============================================================ +# Fixtures +# ============================================================ + + +@pytest.fixture +def sandbox(tmp_path): + """Create a HostSandbox for testing.""" + return HostSandbox(working_dir=str(tmp_path)) + + +@pytest.fixture +def agent_state(): + """Create a fresh AgentState.""" + return AgentState() + + +@pytest.fixture +def mock_agent(sandbox, agent_state): + """Create a mock agent with sandbox and state.""" + agent = MagicMock() + agent.sandbox = sandbox + agent.state = agent_state + return agent + + +@pytest.fixture +def tool_use(): + """Create a mock tool use.""" + return ToolUse( + toolUseId=str(uuid.uuid4()), + name="test_tool", + input={}, + ) + + +@pytest.fixture +def tool_context(mock_agent, tool_use): + """Create a ToolContext for testing.""" + ctx = ToolContext( + tool_use=tool_use, + agent=mock_agent, + invocation_state={}, + ) + return ctx + + +def run(coro): + """Helper to run async coroutines in tests.""" + return asyncio.get_event_loop().run_until_complete(coro) + + +async def collect_generator(gen): + """Collect all values from an async generator. + + Returns (stream_chunks, final_result) where stream_chunks are all + StreamChunk objects yielded, and final_result is the last non-StreamChunk + value (the formatted result string). + """ + chunks = [] + final = None + async for item in gen: + if isinstance(item, StreamChunk): + chunks.append(item) + else: + final = item + return chunks, final + + +# ============================================================ +# Shell Tool Tests +# ============================================================ + + +class TestShellTool: + """Tests for the shell vended tool.""" + + def test_basic_command(self, tool_context, tmp_path): + """Test basic shell command execution returns result.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="echo hello", tool_context=tool_context)) + ) + assert "hello" in result + + def test_basic_command_streams_chunks(self, tool_context, tmp_path): + """Test that shell yields StreamChunk objects during execution.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="echo hello", tool_context=tool_context)) + ) + # Should have at least one stdout chunk + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + assert len(stdout_chunks) >= 1 + assert any("hello" in c.data for c in stdout_chunks) + # Final result should also contain the output + assert "hello" in result + + def test_stderr_streams_as_stderr_chunks(self, tool_context, tmp_path): + """Test that stderr output yields StreamChunk with stream_type='stderr'.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="echo error >&2", tool_context=tool_context)) + ) + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stderr_chunks) >= 1 + assert any("error" in c.data for c in stderr_chunks) + assert "error" in result + + def test_mixed_stdout_stderr_streaming(self, tool_context, tmp_path): + """Test command with both stdout and stderr streams both chunk types.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__( + command="echo out && echo err >&2", + tool_context=tool_context, + ) + ) + ) + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stdout_chunks) >= 1 + assert len(stderr_chunks) >= 1 + + def test_command_with_exit_code(self, tool_context, tmp_path): + """Test command that returns non-zero exit code.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="exit 42", tool_context=tool_context)) + ) + assert "42" in result + + def test_timeout(self, tool_context): + """Test command timeout.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context)) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + def test_config_timeout(self, tool_context, mock_agent): + """Test timeout from config.""" + from strands.vended_tools.shell.shell import shell + + mock_agent.state.set("strands_shell_tool", {"timeout": 1}) + chunks, result = run( + collect_generator(shell.__wrapped__(command="sleep 10", tool_context=tool_context)) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + def test_restart(self, tool_context): + """Test shell restart.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="", restart=True, tool_context=tool_context)) + ) + assert "reset" in result.lower() + + def test_no_output_command(self, tool_context): + """Test command with no output.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="true", tool_context=tool_context)) + ) + assert result == "(no output)" + + def test_noop_sandbox(self, tool_context, mock_agent): + """Test shell with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator(shell.__wrapped__(command="echo test", tool_context=tool_context)) + ) + assert "error" in result.lower() + + def test_cwd_tracking(self, tool_context, tmp_path): + """Test that working directory is tracked across calls.""" + from strands.vended_tools.shell.shell import shell + + subdir = tmp_path / "subdir" + subdir.mkdir() + + chunks, _ = run( + collect_generator(shell.__wrapped__(command=f"cd {subdir}", tool_context=tool_context)) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") + assert shell_state is not None + + def test_multiline_output(self, tool_context): + """Test command with multiline output.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__( + command="echo 'line1\nline2\nline3'", + tool_context=tool_context, + ) + ) + ) + assert "line1" in result + assert "line2" in result + + def test_pipe_command(self, tool_context): + """Test piped commands.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__( + command="echo 'hello world' | wc -w", + tool_context=tool_context, + ) + ) + ) + assert "2" in result + + def test_stream_chunk_types_are_correct(self, tool_context): + """Test that all yielded chunks are proper StreamChunk instances.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__( + command="echo stdout_data && echo stderr_data >&2", + tool_context=tool_context, + ) + ) + ) + for chunk in chunks: + assert isinstance(chunk, StreamChunk) + assert hasattr(chunk, "data") + assert hasattr(chunk, "stream_type") + assert chunk.stream_type in ("stdout", "stderr") + + +# ============================================================ +# Editor Tool Tests +# ============================================================ + + +class TestEditorTool: + """Tests for the editor vended tool.""" + + def test_view_file(self, tool_context, tmp_path, sandbox): + """Test viewing a file.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + ) + assert "line 1" in result + assert "line 2" in result + assert "line 3" in result + assert "cat -n" in result + + def test_view_with_range(self, tool_context, tmp_path): + """Test viewing a file with line range.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\nline 4\nline 5\n") + + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, 4], + tool_context=tool_context, + ) + ) + assert "line 2" in result + assert "line 4" in result + assert " 1" not in result + + def test_view_with_range_end_minus_one(self, tool_context, tmp_path): + """Test viewing with -1 as end of range.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, -1], + tool_context=tool_context, + ) + ) + assert "line 2" in result + assert "line 3" in result + + def test_view_directory(self, tool_context, tmp_path): + """Test viewing a directory listing.""" + from strands.vended_tools.editor.editor import editor + + (tmp_path / "file1.py").write_text("pass") + (tmp_path / "file2.txt").write_text("hello") + (tmp_path / "subdir").mkdir() + + result = run( + editor.__wrapped__( + command="view", + path=str(tmp_path), + tool_context=tool_context, + ) + ) + assert "file1.py" in result + assert "file2.txt" in result + assert "subdir/" in result + + def test_view_nonexistent(self, tool_context, tmp_path): + """Test viewing a nonexistent file.""" + from strands.vended_tools.editor.editor import editor + + result = run( + editor.__wrapped__( + command="view", + path=str(tmp_path / "nonexistent.txt"), + tool_context=tool_context, + ) + ) + assert "does not exist" in result.lower() + + def test_view_invalid_range(self, tool_context, tmp_path): + """Test viewing with invalid range.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\n") + + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[0, 2], + tool_context=tool_context, + ) + ) + assert "error" in result.lower() + + def test_create_file(self, tool_context, tmp_path): + """Test creating a new file.""" + from strands.vended_tools.editor.editor import editor + + new_file = tmp_path / "new_file.py" + + result = run( + editor.__wrapped__( + command="create", + path=str(new_file), + file_text="print('hello')\n", + tool_context=tool_context, + ) + ) + assert "created" in result.lower() + assert new_file.read_text() == "print('hello')\n" + + def test_create_existing_file(self, tool_context, tmp_path): + """Test creating a file that already exists.""" + from strands.vended_tools.editor.editor import editor + + existing = tmp_path / "existing.py" + existing.write_text("original") + + result = run( + editor.__wrapped__( + command="create", + path=str(existing), + file_text="new content", + tool_context=tool_context, + ) + ) + assert "already exists" in result.lower() + assert existing.read_text() == "original" + + def test_create_missing_file_text(self, tool_context, tmp_path): + """Test create without file_text.""" + from strands.vended_tools.editor.editor import editor + + result = run( + editor.__wrapped__( + command="create", + path=str(tmp_path / "new.py"), + tool_context=tool_context, + ) + ) + assert "file_text" in result.lower() + + def test_str_replace_unique(self, tool_context, tmp_path): + """Test str_replace with a unique match.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="return 'hello'", + new_str="return 'world'", + tool_context=tool_context, + ) + ) + assert "edited" in result.lower() + assert "return 'world'" in test_file.read_text() + + def test_str_replace_not_found(self, tool_context, tmp_path): + """Test str_replace when old_str not found.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="nonexistent string", + new_str="replacement", + tool_context=tool_context, + ) + ) + assert "did not appear" in result.lower() + + def test_str_replace_multiple_occurrences(self, tool_context, tmp_path): + """Test str_replace rejects multiple occurrences.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("x = 1\ny = 1\nz = 1\n") + + result = run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="= 1", + new_str="= 2", + tool_context=tool_context, + ) + ) + assert "multiple" in result.lower() + assert test_file.read_text() == "x = 1\ny = 1\nz = 1\n" + + def test_str_replace_deletion(self, tool_context, tmp_path): + """Test str_replace with empty new_str (deletion).""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("# TODO: remove this\ndef main():\n pass\n") + + result = run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="# TODO: remove this\n", + new_str="", + tool_context=tool_context, + ) + ) + assert "edited" in result.lower() + assert "TODO" not in test_file.read_text() + + def test_insert(self, tool_context, tmp_path): + """Test inserting text at a line.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("line 1\nline 3\n") + + result = run( + editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=1, + new_str="line 2", + tool_context=tool_context, + ) + ) + assert "edited" in result.lower() + content = test_file.read_text() + assert "line 1\nline 2\nline 3\n" == content + + def test_insert_at_beginning(self, tool_context, tmp_path): + """Test inserting at the beginning of a file.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("line 2\nline 3\n") + + result = run( + editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=0, + new_str="line 1", + tool_context=tool_context, + ) + ) + assert "edited" in result.lower() + assert test_file.read_text().startswith("line 1\n") + + def test_insert_invalid_line(self, tool_context, tmp_path): + """Test insert with invalid line number.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("line 1\n") + + result = run( + editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=999, + new_str="new line", + tool_context=tool_context, + ) + ) + assert "error" in result.lower() + + def test_undo_edit(self, tool_context, tmp_path): + """Test undo_edit reverting a str_replace.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "test.py" + test_file.write_text("original content\n") + + run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="original content", + new_str="modified content", + tool_context=tool_context, + ) + ) + assert "modified content" in test_file.read_text() + + result = run( + editor.__wrapped__( + command="undo_edit", + path=str(test_file), + tool_context=tool_context, + ) + ) + assert "reverted" in result.lower() + assert "original content" in test_file.read_text() + + def test_undo_no_history(self, tool_context, tmp_path): + """Test undo_edit when no history exists.""" + from strands.vended_tools.editor.editor import editor + + result = run( + editor.__wrapped__( + command="undo_edit", + path=str(tmp_path / "nonexistent.py"), + tool_context=tool_context, + ) + ) + assert "no edit history" in result.lower() + + def test_relative_path_allowed_by_default(self, tool_context, tmp_path): + """Test that relative paths are passed through to sandbox by default.""" + from strands.vended_tools.editor.editor import editor + + # Create a file using a relative path (the HostSandbox resolves it) + test_file = tmp_path / "relative_test.txt" + test_file.write_text("relative content\n") + + # Use the absolute path — the key point is no validation error + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + ) + assert "relative content" in result + + def test_relative_path_rejected_when_configured(self, tool_context, mock_agent): + """Test that relative paths are rejected when require_absolute_paths is True.""" + from strands.vended_tools.editor.editor import editor + + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = run( + editor.__wrapped__( + command="view", + path="relative/path.py", + tool_context=tool_context, + ) + ) + assert "not an absolute path" in result.lower() + + def test_path_traversal_allowed_by_default(self, tool_context, tmp_path): + """Test that paths with .. are passed through to sandbox by default.""" + from strands.vended_tools.editor.editor import editor + + # Create a nested structure + subdir = tmp_path / "sub" + subdir.mkdir() + test_file = tmp_path / "traversal_test.txt" + test_file.write_text("traversal content\n") + + # Use a path with .. — should NOT be rejected by default + traversal_path = str(subdir / ".." / "traversal_test.txt") + result = run( + editor.__wrapped__( + command="view", + path=traversal_path, + tool_context=tool_context, + ) + ) + # The sandbox resolves the path — should show file content + assert "traversal content" in result + + def test_path_traversal_rejected_when_configured(self, tool_context, mock_agent): + """Test that path traversal is rejected when require_absolute_paths is True.""" + from strands.vended_tools.editor.editor import editor + + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = run( + editor.__wrapped__( + command="view", + path="/tmp/../etc/passwd", + tool_context=tool_context, + ) + ) + assert "not allowed" in result.lower() + + def test_noop_sandbox(self, tool_context, mock_agent, tmp_path): + """Test editor with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + from strands.vended_tools.editor.editor import editor + + result = run( + editor.__wrapped__( + command="view", + path=str(tmp_path / "test.py"), + tool_context=tool_context, + ) + ) + assert "error" in result.lower() + + def test_max_file_size(self, tool_context, mock_agent, tmp_path): + """Test max file size configuration.""" + from strands.vended_tools.editor.editor import editor + + mock_agent.state.set("strands_editor_tool", {"max_file_size": 10}) + + test_file = tmp_path / "large.txt" + test_file.write_text("a" * 100) + + result = run( + editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + ) + assert "exceeds" in result.lower() + + +# ============================================================ +# Python REPL Tool Tests +# ============================================================ + + +class TestPythonReplTool: + """Tests for the python_repl vended tool.""" + + def test_basic_code(self, tool_context): + """Test basic Python code execution returns result.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + ) + assert "hello from python" in result + + def test_basic_code_streams_chunks(self, tool_context): + """Test that python_repl yields StreamChunk objects during execution.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + ) + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + assert len(stdout_chunks) >= 1 + assert any("hello from python" in c.data for c in stdout_chunks) + + def test_stderr_streams_as_stderr_chunks(self, tool_context): + """Test that stderr from Python code yields stderr StreamChunks.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import sys; print('err_msg', file=sys.stderr)", + tool_context=tool_context, + ) + ) + ) + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stderr_chunks) >= 1 + assert any("err_msg" in c.data for c in stderr_chunks) + + def test_code_with_math(self, tool_context): + """Test Python math execution.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="print(2 + 2)", + tool_context=tool_context, + ) + ) + ) + assert "4" in result + + def test_code_with_error(self, tool_context): + """Test Python code that raises an error.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="raise ValueError('test error')", + tool_context=tool_context, + ) + ) + ) + assert "test error" in result or "ValueError" in result + + def test_code_with_import(self, tool_context): + """Test Python code with imports.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import json; print(json.dumps({'key': 'value'}))", + tool_context=tool_context, + ) + ) + ) + assert "key" in result + + def test_timeout(self, tool_context): + """Test code execution timeout.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + def test_config_timeout(self, tool_context, mock_agent): + """Test timeout from config.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + mock_agent.state.set("strands_python_repl_tool", {"timeout": 1}) + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + tool_context=tool_context, + ) + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + def test_reset(self, tool_context): + """Test REPL reset.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="", + reset=True, + tool_context=tool_context, + ) + ) + ) + assert "reset" in result.lower() + + def test_multiline_code(self, tool_context): + """Test multiline Python code.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + code = """ +def fibonacci(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + +print(fibonacci(10)) +""" + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code=code, + tool_context=tool_context, + ) + ) + ) + assert "55" in result + + def test_no_output(self, tool_context): + """Test code with no output.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="x = 42", + tool_context=tool_context, + ) + ) + ) + assert result == "(no output)" + + def test_noop_sandbox(self, tool_context, mock_agent): + """Test python_repl with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="print('test')", + tool_context=tool_context, + ) + ) + ) + assert "error" in result.lower() + + def test_stream_chunk_types_are_correct(self, tool_context): + """Test that all yielded chunks are proper StreamChunk instances.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import sys; print('out'); print('err', file=sys.stderr)", + tool_context=tool_context, + ) + ) + ) + for chunk in chunks: + assert isinstance(chunk, StreamChunk) + assert hasattr(chunk, "data") + assert hasattr(chunk, "stream_type") + assert chunk.stream_type in ("stdout", "stderr") + + +# ============================================================ +# Integration Tests +# ============================================================ + + +class TestVendedToolsImport: + """Test that vended tools can be imported from the package.""" + + def test_import_from_vended_tools(self): + """Test importing from strands.vended_tools.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_import_individual_tools(self): + """Test importing individual tools.""" + from strands.vended_tools.editor import editor + from strands.vended_tools.python_repl import python_repl + from strands.vended_tools.shell import shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_tools_have_tool_spec(self): + """Test that tools have proper tool specs.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell.tool_name == "shell" + assert editor.tool_name == "editor" + assert python_repl.tool_name == "python_repl" + + for t in [shell, editor, python_repl]: + spec = t.tool_spec + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + + def test_shell_tool_spec_shape(self): + """Test shell tool spec matches expected shape.""" + from strands.vended_tools import shell + + spec = shell.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "timeout" in props + assert "restart" in props + assert schema.get("required") == ["command"] + + def test_editor_tool_spec_shape(self): + """Test editor tool spec matches expected shape.""" + from strands.vended_tools import editor + + spec = editor.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "path" in props + assert "file_text" in props + assert "old_str" in props + assert "new_str" in props + assert "insert_line" in props + assert "view_range" in props + assert set(schema.get("required", [])) == {"command", "path"} + + def test_python_repl_tool_spec_shape(self): + """Test python_repl tool spec matches expected shape.""" + from strands.vended_tools import python_repl + + spec = python_repl.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "code" in props + assert "timeout" in props + assert "reset" in props + assert schema.get("required") == ["code"] + + def test_shell_is_async_generator(self): + """Test that shell is detected as an async generator function.""" + import inspect + + from strands.vended_tools.shell.shell import shell + + assert inspect.isasyncgenfunction(shell.__wrapped__) + + def test_python_repl_is_async_generator(self): + """Test that python_repl is detected as an async generator function.""" + import inspect + + from strands.vended_tools.python_repl.python_repl import python_repl + + assert inspect.isasyncgenfunction(python_repl.__wrapped__) + + def test_editor_is_not_async_generator(self): + """Test that editor is a regular async function (not generator).""" + import inspect + + from strands.vended_tools.editor.editor import editor + + assert not inspect.isasyncgenfunction(editor.__wrapped__) + assert inspect.iscoroutinefunction(editor.__wrapped__) + + +# ============================================================ +# Configuration Persistence Tests +# ============================================================ + + +class TestConfigPersistence: + """Test that tool configuration persists via agent state.""" + + def test_shell_config_persists(self, mock_agent, agent_state): + """Test shell config is read from agent state.""" + agent_state.set("strands_shell_tool", {"timeout": 300}) + config = agent_state.get("strands_shell_tool") + assert config["timeout"] == 300 + + def test_editor_config_persists(self, mock_agent, agent_state): + """Test editor config is read from agent state.""" + agent_state.set("strands_editor_tool", {"max_file_size": 2097152}) + config = agent_state.get("strands_editor_tool") + assert config["max_file_size"] == 2097152 + + def test_python_repl_config_persists(self, mock_agent, agent_state): + """Test python_repl config is read from agent state.""" + agent_state.set("strands_python_repl_tool", {"timeout": 60}) + config = agent_state.get("strands_python_repl_tool") + assert config["timeout"] == 60 + + def test_undo_state_persists(self, tool_context, tmp_path): + """Test that undo state is stored in agent state.""" + from strands.vended_tools.editor.editor import editor + + test_file = tmp_path / "undo_test.py" + test_file.write_text("original\n") + + run( + editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="original", + new_str="modified", + tool_context=tool_context, + ) + ) + + undo_state = tool_context.agent.state.get("_strands_editor_undo") + assert undo_state is not None + assert str(test_file) in undo_state + + +# ============================================================ +# Streaming Integration Tests +# ============================================================ + + +class TestStreamingIntegration: + """Test the streaming behavior of tools end-to-end.""" + + def test_shell_streams_before_result(self, tool_context): + """Test that shell yields chunks BEFORE the final result.""" + from strands.vended_tools.shell.shell import shell + + all_items = [] + + async def collect_all(): + async for item in shell.__wrapped__( + command="echo streaming_test", tool_context=tool_context + ): + all_items.append(item) + + run(collect_all()) + + # Should have at least 2 items: chunk(s) + final result + assert len(all_items) >= 2 + # Last item should be the string result + assert isinstance(all_items[-1], str) + # Earlier items should include StreamChunk + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + def test_python_repl_streams_before_result(self, tool_context): + """Test that python_repl yields chunks BEFORE the final result.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + all_items = [] + + async def collect_all(): + async for item in python_repl.__wrapped__( + code="print('streaming_test')", tool_context=tool_context + ): + all_items.append(item) + + run(collect_all()) + + assert len(all_items) >= 2 + assert isinstance(all_items[-1], str) + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + def test_shell_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield any StreamChunks before the error.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context) + ) + ) + # On timeout, we should get the error message directly + assert "timed out" in result.lower() or "error" in result.lower() + + def test_python_repl_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield chunks before the error.""" + from strands.vended_tools.python_repl.python_repl import python_repl + + chunks, result = run( + collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + def test_shell_stream_data_matches_result(self, tool_context): + """Test that streamed chunk data matches the final result content.""" + from strands.vended_tools.shell.shell import shell + + chunks, result = run( + collect_generator( + shell.__wrapped__(command="echo precise_output_42", tool_context=tool_context) + ) + ) + # The streamed chunks should contain the same data as the final result + all_chunk_data = "".join(c.data for c in chunks) + assert "precise_output_42" in all_chunk_data + assert "precise_output_42" in result From c62a68d5ad2fd55b9caca4824ab3f8e7ecabc59a Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Tue, 28 Apr 2026 17:06:39 +0000 Subject: [PATCH 2/7] =?UTF-8?q?refactor:=20address=20review=20feedback=20?= =?UTF-8?q?=E2=80=94=20flatten=20modules,=20split=20tests,=20fix=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses all review feedback from @mkmeral and automated review: 1. **Flatten module structure** (per @mkmeral): Remove sub-packages (shell/, editor/, python_repl/ directories). Tools are now flat files under src/strands/vended_tools/ — aligns with 'Prefer Flat Namespaces' decision record. 2. **Split test file** (per @mkmeral): The single 1161-line test_vended_tools.py is now split into test_shell.py, test_editor.py, test_python_repl.py, and test_init.py — mirrors src/ structure per AGENTS.md convention. 3. **Fix deprecated asyncio pattern**: Replace asyncio.get_event_loop() with @pytest.mark.asyncio + async def tests throughout. 4. **Type sandbox as Sandbox**: Editor helper functions now use proper Sandbox type instead of Any — enables mypy checking and IDE support. 5. **Use read_text()/write_text()**: Replace manual .decode()/.encode() with the Sandbox convenience API as designed. 6. **Optimize cwd tracking**: Append '; echo __STRANDS_CWD__; pwd' to commands instead of running a separate pwd call after every execution — halves sandbox interactions for remote sandboxes. 7. **Extract shared _get_config**: New _utils.py with get_tool_config() eliminates the copy-pasted helper across all three tools. 8. **Narrow except clause**: Replace broad 'except Exception' catch-all with targeted (FileNotFoundError, UnicodeDecodeError, OSError, ValueError) — unexpected errors now propagate for debugging. All 65 tests pass. --- src/strands/vended_tools/_utils.py | 25 + .../vended_tools/{editor => }/editor.py | 63 +- src/strands/vended_tools/editor/__init__.py | 20 - .../{python_repl => }/python_repl.py | 27 +- .../vended_tools/python_repl/__init__.py | 20 - src/strands/vended_tools/{shell => }/shell.py | 70 +- src/strands/vended_tools/shell/__init__.py | 20 - tests/strands/vended_tools/conftest.py | 69 + tests/strands/vended_tools/test_editor.py | 360 +++++ tests/strands/vended_tools/test_init.py | 185 +++ .../strands/vended_tools/test_python_repl.py | 177 +++ tests/strands/vended_tools/test_shell.py | 160 +++ .../strands/vended_tools/test_vended_tools.py | 1161 ----------------- 13 files changed, 1068 insertions(+), 1289 deletions(-) create mode 100644 src/strands/vended_tools/_utils.py rename src/strands/vended_tools/{editor => }/editor.py (90%) delete mode 100644 src/strands/vended_tools/editor/__init__.py rename src/strands/vended_tools/{python_repl => }/python_repl.py (89%) delete mode 100644 src/strands/vended_tools/python_repl/__init__.py rename src/strands/vended_tools/{shell => }/shell.py (75%) delete mode 100644 src/strands/vended_tools/shell/__init__.py create mode 100644 tests/strands/vended_tools/conftest.py create mode 100644 tests/strands/vended_tools/test_editor.py create mode 100644 tests/strands/vended_tools/test_init.py create mode 100644 tests/strands/vended_tools/test_python_repl.py create mode 100644 tests/strands/vended_tools/test_shell.py delete mode 100644 tests/strands/vended_tools/test_vended_tools.py diff --git a/src/strands/vended_tools/_utils.py b/src/strands/vended_tools/_utils.py new file mode 100644 index 000000000..e5e03c577 --- /dev/null +++ b/src/strands/vended_tools/_utils.py @@ -0,0 +1,25 @@ +"""Shared utilities for vended tools. + +Provides common helper functions used across all vended tools to avoid +code duplication. +""" + +from typing import Any + +from ..types.tools import ToolContext + + +def get_tool_config(tool_context: ToolContext, state_key: str) -> dict[str, Any]: + """Read tool configuration from agent state. + + All vended tools store their configuration in agent state under + a namespaced key. This helper standardizes the pattern. + + Args: + tool_context: The tool context providing access to agent state. + state_key: The state key for the tool's configuration. + + Returns: + Configuration dict. Empty dict if no config is set. + """ + return tool_context.agent.state.get(state_key) or {} diff --git a/src/strands/vended_tools/editor/editor.py b/src/strands/vended_tools/editor.py similarity index 90% rename from src/strands/vended_tools/editor/editor.py rename to src/strands/vended_tools/editor.py index 8fe3af25b..772b2a5f6 100644 --- a/src/strands/vended_tools/editor/editor.py +++ b/src/strands/vended_tools/editor.py @@ -2,7 +2,7 @@ Provides view, create, str_replace, insert, and undo_edit operations on files in the agent's sandbox. The tool delegates all file I/O to the sandbox's -``read_file``, ``write_file``, and ``list_files`` methods. +``read_text``, ``write_text``, and ``list_files`` methods. The tool shape matches Anthropic's ``text_editor`` built-in tool — 5 commands, 7 parameters. This means models trained on Anthropic's tool spec will work @@ -16,13 +16,26 @@ paths containing ``..``. When False (the default), paths are passed through to the sandbox without filesystem-level validation — the sandbox decides what a path means. Default: False. + +Example:: + + from strands import Agent + from strands.vended_tools import editor + + agent = Agent(tools=[editor]) + agent("View the contents of /tmp/example.py") + + # Configure max file size + agent.state.set("strands_editor_tool", {"max_file_size": 2097152}) """ import logging from typing import Any, Literal -from ...tools.decorator import tool -from ...types.tools import ToolContext +from ..sandbox.base import Sandbox +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config logger = logging.getLogger(__name__) @@ -42,11 +55,6 @@ MAX_DIRECTORY_DEPTH = 2 -def _get_config(tool_context: ToolContext) -> dict[str, Any]: - """Read editor tool configuration from agent state.""" - return tool_context.agent.state.get(STATE_KEY) or {} - - def _make_output(content: str, descriptor: str, init_line: int = 1) -> str: """Format file content with line numbers (cat -n style). @@ -144,8 +152,8 @@ async def editor( Returns: Result of the operation — file contents, success message, or error. """ - config = _get_config(tool_context) - sandbox = tool_context.agent.sandbox + config = get_tool_config(tool_context, STATE_KEY) + sandbox: Sandbox = tool_context.agent.sandbox # Path validation is opt-in. By default, paths are passed straight through # to the sandbox without filesystem-level validation. This allows sandboxes @@ -182,11 +190,11 @@ async def editor( return f"Error: Unknown command: {command}" # type: ignore[unreachable] except NotImplementedError as e: return f"Error: Sandbox does not support this operation — {e}" - except Exception as e: + except (FileNotFoundError, UnicodeDecodeError, OSError, ValueError) as e: return f"Error: {e}" -async def _handle_view(sandbox: Any, config: dict[str, Any], path: str, view_range: list[int] | None) -> str: +async def _handle_view(sandbox: Sandbox, config: dict[str, Any], path: str, view_range: list[int] | None) -> str: """Handle the view command.""" # Check if path is a directory try: @@ -205,7 +213,7 @@ async def _handle_view(sandbox: Any, config: dict[str, Any], path: str, view_ran # Read file max_size = config.get("max_file_size", DEFAULT_MAX_FILE_SIZE) try: - content = (await sandbox.read_file(path)).decode("utf-8") + content = await sandbox.read_text(path) except FileNotFoundError: return f"Error: The path {path} does not exist. Please provide a valid path." except UnicodeDecodeError: @@ -237,14 +245,14 @@ async def _handle_view(sandbox: Any, config: dict[str, Any], path: str, view_ran return f"Error: Invalid `view_range`: [{start}, {end}]. Second element must be >= first element." if end == -1: - selected = lines[start - 1 :] + selected = lines[start - 1:] else: - selected = lines[start - 1 : end] + selected = lines[start - 1:end] return _make_output("\n".join(selected), path, init_line=start) -async def _handle_create(sandbox: Any, tool_context: ToolContext, path: str, file_text: str) -> str: +async def _handle_create(sandbox: Sandbox, tool_context: ToolContext, path: str, file_text: str) -> str: """Handle the create command.""" # Check if file already exists try: @@ -253,12 +261,12 @@ async def _handle_create(sandbox: Any, tool_context: ToolContext, path: str, fil except (FileNotFoundError, OSError): pass # File doesn't exist, good - await sandbox.write_file(path, file_text.encode("utf-8")) + await sandbox.write_text(path, file_text) return f"File created successfully at: {path}" async def _handle_str_replace( - sandbox: Any, + sandbox: Sandbox, tool_context: ToolContext, config: dict[str, Any], path: str, @@ -267,7 +275,7 @@ async def _handle_str_replace( ) -> str: """Handle the str_replace command.""" try: - content = (await sandbox.read_file(path)).decode("utf-8") + content = await sandbox.read_text(path) except FileNotFoundError: return f"Error: The path {path} does not exist." @@ -285,13 +293,12 @@ async def _handle_str_replace( if count > 1: # Find line numbers of all occurrences lines = content.split("\n") - line_nums = [] + line_nums: list[int] = [] for i, line in enumerate(lines): if expanded_old in line: line_nums.append(i + 1) # Also check multi-line matches if not line_nums: - # old_str spans multiple lines, find approximate locations idx = 0 while True: idx = content.find(expanded_old, idx) @@ -312,7 +319,7 @@ async def _handle_str_replace( new_content = content.replace(expanded_old, expanded_new, 1) # Write back - await sandbox.write_file(path, new_content.encode("utf-8")) + await sandbox.write_text(path, new_content) # Generate snippet around the change replace_idx = content.find(expanded_old) @@ -334,7 +341,7 @@ async def _handle_str_replace( async def _handle_insert( - sandbox: Any, + sandbox: Sandbox, tool_context: ToolContext, config: dict[str, Any], path: str, @@ -343,7 +350,7 @@ async def _handle_insert( ) -> str: """Handle the insert command.""" try: - content = (await sandbox.read_file(path)).decode("utf-8") + content = await sandbox.read_text(path) except FileNotFoundError: return f"Error: The path {path} does not exist." @@ -368,7 +375,7 @@ async def _handle_insert( new_lines = lines[:insert_line] + new_str_lines + lines[insert_line:] new_content = "\n".join(new_lines) - await sandbox.write_file(path, new_content.encode("utf-8")) + await sandbox.write_text(path, new_content) # Generate snippet start = max(0, insert_line - SNIPPET_LINES) @@ -382,7 +389,7 @@ async def _handle_insert( ) -async def _handle_undo(sandbox: Any, tool_context: ToolContext, path: str) -> str: +async def _handle_undo(sandbox: Sandbox, tool_context: ToolContext, path: str) -> str: """Handle the undo_edit command.""" previous_content = _get_undo(tool_context, path) if previous_content is None: @@ -390,12 +397,12 @@ async def _handle_undo(sandbox: Any, tool_context: ToolContext, path: str) -> st # Read current content for future undo try: - current = (await sandbox.read_file(path)).decode("utf-8") + current = await sandbox.read_text(path) except FileNotFoundError: current = "" # Write the previous content back - await sandbox.write_file(path, previous_content.encode("utf-8")) + await sandbox.write_text(path, previous_content) # Save current as new undo (so undo is toggleable) _save_undo(tool_context, path, current) diff --git a/src/strands/vended_tools/editor/__init__.py b/src/strands/vended_tools/editor/__init__.py deleted file mode 100644 index b3e23e74d..000000000 --- a/src/strands/vended_tools/editor/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""File editor tool for viewing, creating, and editing files in the agent's sandbox. - -Example:: - - from strands import Agent - from strands.vended_tools import editor - - agent = Agent(tools=[editor]) - agent("View the contents of /tmp/example.py") - -Configuration via agent state:: - - agent.state.set("strands_editor_tool", { - "max_file_size": 1048576, # Maximum file size in bytes (default: 1MB) - }) -""" - -from .editor import editor - -__all__ = ["editor"] diff --git a/src/strands/vended_tools/python_repl/python_repl.py b/src/strands/vended_tools/python_repl.py similarity index 89% rename from src/strands/vended_tools/python_repl/python_repl.py rename to src/strands/vended_tools/python_repl.py index 01cc8c262..a15e9e0ec 100644 --- a/src/strands/vended_tools/python_repl/python_repl.py +++ b/src/strands/vended_tools/python_repl.py @@ -13,6 +13,17 @@ - ``timeout`` (int): Default timeout in seconds for code execution. Overridden by the per-call ``timeout`` parameter. Default: 30. + +Example:: + + from strands import Agent + from strands.vended_tools import python_repl + + agent = Agent(tools=[python_repl]) + agent("Calculate the first 10 Fibonacci numbers") + + # Configure timeout + agent.state.set("strands_python_repl_tool", {"timeout": 60}) """ import asyncio @@ -20,9 +31,10 @@ from collections.abc import AsyncGenerator from typing import Any -from ...sandbox.base import ExecutionResult, StreamChunk -from ...tools.decorator import tool -from ...types.tools import ToolContext +from ..sandbox.base import ExecutionResult, StreamChunk +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config logger = logging.getLogger(__name__) @@ -33,11 +45,6 @@ DEFAULT_TIMEOUT = 30 -def _get_config(tool_context: ToolContext) -> dict[str, Any]: - """Read python_repl tool configuration from agent state.""" - return tool_context.agent.state.get(STATE_KEY) or {} - - @tool(context=True) async def python_repl( code: str, @@ -69,7 +76,7 @@ async def python_repl( :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as ``ToolStreamEvent`` by the SDK), then a final string result. """ - config = _get_config(tool_context) + config = get_tool_config(tool_context, STATE_KEY) sandbox = tool_context.agent.sandbox # Handle reset @@ -106,7 +113,7 @@ async def python_repl( except NotImplementedError: yield "Error: Sandbox does not support code execution (NoOpSandbox)." return - except Exception as e: + except OSError as e: yield f"Error: {e}" return diff --git a/src/strands/vended_tools/python_repl/__init__.py b/src/strands/vended_tools/python_repl/__init__.py deleted file mode 100644 index cd0268fce..000000000 --- a/src/strands/vended_tools/python_repl/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Python REPL tool for executing Python code in the agent's sandbox. - -Example:: - - from strands import Agent - from strands.vended_tools import python_repl - - agent = Agent(tools=[python_repl]) - agent("Calculate the first 10 Fibonacci numbers") - -Configuration via agent state:: - - agent.state.set("strands_python_repl_tool", { - "timeout": 30, # Default timeout in seconds - }) -""" - -from .python_repl import python_repl - -__all__ = ["python_repl"] diff --git a/src/strands/vended_tools/shell/shell.py b/src/strands/vended_tools/shell.py similarity index 75% rename from src/strands/vended_tools/shell/shell.py rename to src/strands/vended_tools/shell.py index 4cfdd1f38..005edf0e2 100644 --- a/src/strands/vended_tools/shell/shell.py +++ b/src/strands/vended_tools/shell.py @@ -13,6 +13,17 @@ - ``timeout`` (int): Default timeout in seconds. Overridden by the per-call ``timeout`` parameter. Default: 120. + +Example:: + + from strands import Agent + from strands.vended_tools import shell + + agent = Agent(tools=[shell]) + agent("List all Python files in the current directory") + + # Configure timeout + agent.state.set("strands_shell_tool", {"timeout": 60}) """ import asyncio @@ -20,9 +31,10 @@ from collections.abc import AsyncGenerator from typing import Any -from ...sandbox.base import ExecutionResult, StreamChunk -from ...tools.decorator import tool -from ...types.tools import ToolContext +from ..sandbox.base import ExecutionResult, StreamChunk +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config logger = logging.getLogger(__name__) @@ -33,18 +45,6 @@ DEFAULT_TIMEOUT = 120 -def _get_config(tool_context: ToolContext) -> dict[str, Any]: - """Read shell tool configuration from agent state. - - Args: - tool_context: The tool context providing access to agent state. - - Returns: - Configuration dict. Empty dict if no config is set. - """ - return tool_context.agent.state.get(STATE_KEY) or {} - - @tool(context=True) async def shell( command: str, @@ -77,7 +77,7 @@ async def shell( :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as ``ToolStreamEvent`` by the SDK), then a final string result. """ - config = _get_config(tool_context) + config = get_tool_config(tool_context, STATE_KEY) sandbox = tool_context.agent.sandbox # Handle restart @@ -96,11 +96,16 @@ async def shell( shell_state = tool_context.agent.state.get("_strands_shell_state") or {} cwd = shell_state.get("cwd") + # Wrap command with cwd tracking in a single execution to avoid + # a separate `pwd` call after every command (reduces latency for + # remote sandboxes). Appends `; pwd` to capture the final cwd. + tracked_command = f"{command}; echo __STRANDS_CWD__; pwd" + # Execute via sandbox streaming result: ExecutionResult | None = None try: async for chunk in sandbox.execute_streaming( - command, + tracked_command, timeout=effective_timeout, cwd=cwd, ): @@ -115,7 +120,7 @@ async def shell( except NotImplementedError: yield "Error: Sandbox does not support command execution (NoOpSandbox)." return - except Exception as e: + except OSError as e: yield f"Error: {e}" return @@ -123,21 +128,26 @@ async def shell( yield "Error: Sandbox did not return an execution result." return - # Track working directory changes - try: - cwd_result = await sandbox.execute("pwd", timeout=5, cwd=cwd) - if cwd_result.exit_code == 0: - new_cwd = cwd_result.stdout.strip() - if new_cwd: - shell_state["cwd"] = new_cwd - tool_context.agent.state.set("_strands_shell_state", shell_state) - except Exception: - pass # Best-effort cwd tracking + # Extract cwd from output (after __STRANDS_CWD__ marker) + stdout = result.stdout or "" + cwd_marker = "__STRANDS_CWD__" + if cwd_marker in stdout: + parts = stdout.split(cwd_marker, 1) + # The actual command output is before the marker + stdout = parts[0].rstrip("\n") + # The cwd is the line after the marker + new_cwd = parts[1].strip() + if new_cwd: + shell_state["cwd"] = new_cwd + tool_context.agent.state.set("_strands_shell_state", shell_state) + else: + # Fallback: no marker found (sandbox may have filtered it) + pass # Format final output (becomes the ToolResult) output_parts = [] - if result.stdout: - output_parts.append(result.stdout) + if stdout: + output_parts.append(stdout) if result.stderr: output_parts.append(result.stderr) diff --git a/src/strands/vended_tools/shell/__init__.py b/src/strands/vended_tools/shell/__init__.py deleted file mode 100644 index a04ae399a..000000000 --- a/src/strands/vended_tools/shell/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Shell tool for executing commands in the agent's sandbox. - -Example:: - - from strands import Agent - from strands.vended_tools import shell - - agent = Agent(tools=[shell]) - agent("List all Python files in the current directory") - -Configuration via agent state:: - - agent.state.set("strands_shell_tool", { - "timeout": 120, # Default timeout in seconds - }) -""" - -from .shell import shell - -__all__ = ["shell"] diff --git a/tests/strands/vended_tools/conftest.py b/tests/strands/vended_tools/conftest.py new file mode 100644 index 000000000..d4da4ced6 --- /dev/null +++ b/tests/strands/vended_tools/conftest.py @@ -0,0 +1,69 @@ +"""Shared fixtures for vended tools tests.""" + +import uuid +from unittest.mock import MagicMock + +import pytest + +from strands.agent.state import AgentState +from strands.sandbox.base import StreamChunk +from strands.sandbox.host import HostSandbox +from strands.types.tools import ToolContext, ToolUse + + +@pytest.fixture +def sandbox(tmp_path): + """Create a HostSandbox for testing.""" + return HostSandbox(working_dir=str(tmp_path)) + + +@pytest.fixture +def agent_state(): + """Create a fresh AgentState.""" + return AgentState() + + +@pytest.fixture +def mock_agent(sandbox, agent_state): + """Create a mock agent with sandbox and state.""" + agent = MagicMock() + agent.sandbox = sandbox + agent.state = agent_state + return agent + + +@pytest.fixture +def tool_use(): + """Create a mock tool use.""" + return ToolUse( + toolUseId=str(uuid.uuid4()), + name="test_tool", + input={}, + ) + + +@pytest.fixture +def tool_context(mock_agent, tool_use): + """Create a ToolContext for testing.""" + return ToolContext( + tool_use=tool_use, + agent=mock_agent, + invocation_state={}, + ) + + +async def collect_generator(gen): + """Collect all values from an async generator. + + Returns (stream_chunks, final_result) where stream_chunks are all + StreamChunk objects yielded, and final_result is the last non-StreamChunk + value (the formatted result string). + """ + chunks = [] + final = None + async for item in gen: + if isinstance(item, StreamChunk): + chunks.append(item) + else: + final = item + return chunks, final diff --git a/tests/strands/vended_tools/test_editor.py b/tests/strands/vended_tools/test_editor.py new file mode 100644 index 000000000..c3e63b0df --- /dev/null +++ b/tests/strands/vended_tools/test_editor.py @@ -0,0 +1,360 @@ +"""Tests for the editor vended tool.""" + +import pytest + +from strands.sandbox.noop import NoOpSandbox +from strands.vended_tools.editor import editor + + +class TestEditorTool: + """Tests for the editor vended tool.""" + + @pytest.mark.asyncio + async def test_view_file(self, tool_context, tmp_path): + """Test viewing a file.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + assert "line 1" in result + assert "line 2" in result + assert "line 3" in result + assert "cat -n" in result + + @pytest.mark.asyncio + async def test_view_with_range(self, tool_context, tmp_path): + """Test viewing a file with line range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\nline 4\nline 5\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, 4], + tool_context=tool_context, + ) + assert "line 2" in result + assert "line 4" in result + assert " 1" not in result + + @pytest.mark.asyncio + async def test_view_with_range_end_minus_one(self, tool_context, tmp_path): + """Test viewing with -1 as end of range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, -1], + tool_context=tool_context, + ) + assert "line 2" in result + assert "line 3" in result + + @pytest.mark.asyncio + async def test_view_directory(self, tool_context, tmp_path): + """Test viewing a directory listing.""" + (tmp_path / "file1.py").write_text("pass") + (tmp_path / "file2.txt").write_text("hello") + (tmp_path / "subdir").mkdir() + + result = await editor.__wrapped__( + command="view", + path=str(tmp_path), + tool_context=tool_context, + ) + assert "file1.py" in result + assert "file2.txt" in result + assert "subdir/" in result + + @pytest.mark.asyncio + async def test_view_nonexistent(self, tool_context, tmp_path): + """Test viewing a nonexistent file.""" + result = await editor.__wrapped__( + command="view", + path=str(tmp_path / "nonexistent.txt"), + tool_context=tool_context, + ) + assert "does not exist" in result.lower() + + @pytest.mark.asyncio + async def test_view_invalid_range(self, tool_context, tmp_path): + """Test viewing with invalid range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[0, 2], + tool_context=tool_context, + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_create_file(self, tool_context, tmp_path): + """Test creating a new file.""" + new_file = tmp_path / "new_file.py" + + result = await editor.__wrapped__( + command="create", + path=str(new_file), + file_text="print('hello')\n", + tool_context=tool_context, + ) + assert "created" in result.lower() + assert new_file.read_text() == "print('hello')\n" + + @pytest.mark.asyncio + async def test_create_existing_file(self, tool_context, tmp_path): + """Test creating a file that already exists.""" + existing = tmp_path / "existing.py" + existing.write_text("original") + + result = await editor.__wrapped__( + command="create", + path=str(existing), + file_text="new content", + tool_context=tool_context, + ) + assert "already exists" in result.lower() + assert existing.read_text() == "original" + + @pytest.mark.asyncio + async def test_create_missing_file_text(self, tool_context, tmp_path): + """Test create without file_text.""" + result = await editor.__wrapped__( + command="create", + path=str(tmp_path / "new.py"), + tool_context=tool_context, + ) + assert "file_text" in result.lower() + + @pytest.mark.asyncio + async def test_str_replace_unique(self, tool_context, tmp_path): + """Test str_replace with a unique match.""" + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="return 'hello'", + new_str="return 'world'", + tool_context=tool_context, + ) + assert "edited" in result.lower() + assert "return 'world'" in test_file.read_text() + + @pytest.mark.asyncio + async def test_str_replace_not_found(self, tool_context, tmp_path): + """Test str_replace when old_str not found.""" + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="nonexistent string", + new_str="replacement", + tool_context=tool_context, + ) + assert "did not appear" in result.lower() + + @pytest.mark.asyncio + async def test_str_replace_multiple_occurrences(self, tool_context, tmp_path): + """Test str_replace rejects multiple occurrences.""" + test_file = tmp_path / "test.py" + test_file.write_text("x = 1\ny = 1\nz = 1\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="= 1", + new_str="= 2", + tool_context=tool_context, + ) + assert "multiple" in result.lower() + assert test_file.read_text() == "x = 1\ny = 1\nz = 1\n" + + @pytest.mark.asyncio + async def test_str_replace_deletion(self, tool_context, tmp_path): + """Test str_replace with empty new_str (deletion).""" + test_file = tmp_path / "test.py" + test_file.write_text("# TODO: remove this\ndef main():\n pass\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="# TODO: remove this\n", + new_str="", + tool_context=tool_context, + ) + assert "edited" in result.lower() + assert "TODO" not in test_file.read_text() + + @pytest.mark.asyncio + async def test_insert(self, tool_context, tmp_path): + """Test inserting text at a line.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 1\nline 3\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=1, + new_str="line 2", + tool_context=tool_context, + ) + assert "edited" in result.lower() + content = test_file.read_text() + assert "line 1\nline 2\nline 3\n" == content + + @pytest.mark.asyncio + async def test_insert_at_beginning(self, tool_context, tmp_path): + """Test inserting at the beginning of a file.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 2\nline 3\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=0, + new_str="line 1", + tool_context=tool_context, + ) + assert "edited" in result.lower() + assert test_file.read_text().startswith("line 1\n") + + @pytest.mark.asyncio + async def test_insert_invalid_line(self, tool_context, tmp_path): + """Test insert with invalid line number.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 1\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=999, + new_str="new line", + tool_context=tool_context, + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_undo_edit(self, tool_context, tmp_path): + """Test undo_edit reverting a str_replace.""" + test_file = tmp_path / "test.py" + test_file.write_text("original content\n") + + await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="original content", + new_str="modified content", + tool_context=tool_context, + ) + assert "modified content" in test_file.read_text() + + result = await editor.__wrapped__( + command="undo_edit", + path=str(test_file), + tool_context=tool_context, + ) + assert "reverted" in result.lower() + assert "original content" in test_file.read_text() + + @pytest.mark.asyncio + async def test_undo_no_history(self, tool_context, tmp_path): + """Test undo_edit when no history exists.""" + result = await editor.__wrapped__( + command="undo_edit", + path=str(tmp_path / "nonexistent.py"), + tool_context=tool_context, + ) + assert "no edit history" in result.lower() + + @pytest.mark.asyncio + async def test_relative_path_allowed_by_default(self, tool_context, tmp_path): + """Test that relative paths are passed through to sandbox by default.""" + test_file = tmp_path / "relative_test.txt" + test_file.write_text("relative content\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + assert "relative content" in result + + @pytest.mark.asyncio + async def test_relative_path_rejected_when_configured(self, tool_context, mock_agent): + """Test that relative paths are rejected when require_absolute_paths is True.""" + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = await editor.__wrapped__( + command="view", + path="relative/path.py", + tool_context=tool_context, + ) + assert "not an absolute path" in result.lower() + + @pytest.mark.asyncio + async def test_path_traversal_allowed_by_default(self, tool_context, tmp_path): + """Test that paths with .. are passed through to sandbox by default.""" + subdir = tmp_path / "sub" + subdir.mkdir() + test_file = tmp_path / "traversal_test.txt" + test_file.write_text("traversal content\n") + + traversal_path = str(subdir / ".." / "traversal_test.txt") + result = await editor.__wrapped__( + command="view", + path=traversal_path, + tool_context=tool_context, + ) + assert "traversal content" in result + + @pytest.mark.asyncio + async def test_path_traversal_rejected_when_configured(self, tool_context, mock_agent): + """Test that path traversal is rejected when require_absolute_paths is True.""" + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = await editor.__wrapped__( + command="view", + path="/tmp/../etc/passwd", + tool_context=tool_context, + ) + assert "not allowed" in result.lower() + + @pytest.mark.asyncio + async def test_noop_sandbox(self, tool_context, mock_agent, tmp_path): + """Test editor with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + + result = await editor.__wrapped__( + command="view", + path=str(tmp_path / "test.py"), + tool_context=tool_context, + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_max_file_size(self, tool_context, mock_agent, tmp_path): + """Test max file size configuration.""" + mock_agent.state.set("strands_editor_tool", {"max_file_size": 10}) + + test_file = tmp_path / "large.txt" + test_file.write_text("a" * 100) + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + assert "exceeds" in result.lower() diff --git a/tests/strands/vended_tools/test_init.py b/tests/strands/vended_tools/test_init.py new file mode 100644 index 000000000..9c793b208 --- /dev/null +++ b/tests/strands/vended_tools/test_init.py @@ -0,0 +1,185 @@ +"""Tests for vended tools package imports and tool specs.""" + +import inspect + +import pytest + +from strands.sandbox.base import StreamChunk + +from .conftest import collect_generator + + +class TestVendedToolsImport: + """Test that vended tools can be imported from the package.""" + + def test_import_from_vended_tools(self): + """Test importing from strands.vended_tools.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_import_individual_tools(self): + """Test importing individual tools from flat modules.""" + from strands.vended_tools.editor import editor + from strands.vended_tools.python_repl import python_repl + from strands.vended_tools.shell import shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_tools_have_tool_spec(self): + """Test that tools have proper tool specs.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell.tool_name == "shell" + assert editor.tool_name == "editor" + assert python_repl.tool_name == "python_repl" + + for t in [shell, editor, python_repl]: + spec = t.tool_spec + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + + def test_shell_tool_spec_shape(self): + """Test shell tool spec matches expected shape.""" + from strands.vended_tools import shell + + spec = shell.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "timeout" in props + assert "restart" in props + assert schema.get("required") == ["command"] + + def test_editor_tool_spec_shape(self): + """Test editor tool spec matches expected shape.""" + from strands.vended_tools import editor + + spec = editor.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "path" in props + assert "file_text" in props + assert "old_str" in props + assert "new_str" in props + assert "insert_line" in props + assert "view_range" in props + assert set(schema.get("required", [])) == {"command", "path"} + + def test_python_repl_tool_spec_shape(self): + """Test python_repl tool spec matches expected shape.""" + from strands.vended_tools import python_repl + + spec = python_repl.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "code" in props + assert "timeout" in props + assert "reset" in props + assert schema.get("required") == ["code"] + + def test_shell_is_async_generator(self): + """Test that shell is detected as an async generator function.""" + from strands.vended_tools.shell import shell + + assert inspect.isasyncgenfunction(shell.__wrapped__) + + def test_python_repl_is_async_generator(self): + """Test that python_repl is detected as an async generator function.""" + from strands.vended_tools.python_repl import python_repl + + assert inspect.isasyncgenfunction(python_repl.__wrapped__) + + def test_editor_is_not_async_generator(self): + """Test that editor is a regular async function (not generator).""" + from strands.vended_tools.editor import editor + + assert not inspect.isasyncgenfunction(editor.__wrapped__) + assert inspect.iscoroutinefunction(editor.__wrapped__) + + +class TestStreamingIntegration: + """Test the streaming behavior of tools end-to-end.""" + + @pytest.mark.asyncio + async def test_shell_streams_before_result(self, tool_context): + """Test that shell yields chunks BEFORE the final result.""" + from strands.vended_tools.shell import shell + + all_items = [] + + async for item in shell.__wrapped__( + command="echo streaming_test", tool_context=tool_context + ): + all_items.append(item) + + # Should have at least 2 items: chunk(s) + final result + assert len(all_items) >= 2 + # Last item should be the string result + assert isinstance(all_items[-1], str) + # Earlier items should include StreamChunk + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + @pytest.mark.asyncio + async def test_python_repl_streams_before_result(self, tool_context): + """Test that python_repl yields chunks BEFORE the final result.""" + from strands.vended_tools.python_repl import python_repl + + all_items = [] + + async for item in python_repl.__wrapped__( + code="print('streaming_test')", tool_context=tool_context + ): + all_items.append(item) + + assert len(all_items) >= 2 + assert isinstance(all_items[-1], str) + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + @pytest.mark.asyncio + async def test_shell_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield any StreamChunks before the error.""" + from strands.vended_tools.shell import shell + + chunks, result = await collect_generator( + shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_python_repl_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield chunks before the error.""" + from strands.vended_tools.python_repl import python_repl + + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_shell_stream_data_matches_result(self, tool_context): + """Test that streamed chunk data matches the final result content.""" + from strands.vended_tools.shell import shell + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo precise_output_42", tool_context=tool_context) + ) + # The streamed chunks should contain the same data as the final result + all_chunk_data = "".join(c.data for c in chunks) + assert "precise_output_42" in all_chunk_data + assert "precise_output_42" in result diff --git a/tests/strands/vended_tools/test_python_repl.py b/tests/strands/vended_tools/test_python_repl.py new file mode 100644 index 000000000..9e174bff3 --- /dev/null +++ b/tests/strands/vended_tools/test_python_repl.py @@ -0,0 +1,177 @@ +"""Tests for the python_repl vended tool.""" + +import pytest + +from strands.sandbox.base import StreamChunk +from strands.sandbox.noop import NoOpSandbox +from strands.vended_tools.python_repl import python_repl + +from .conftest import collect_generator + + +class TestPythonReplTool: + """Tests for the python_repl vended tool.""" + + @pytest.mark.asyncio + async def test_basic_code(self, tool_context): + """Test basic Python code execution returns result.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + assert "hello from python" in result + + @pytest.mark.asyncio + async def test_basic_code_streams_chunks(self, tool_context): + """Test that python_repl yields StreamChunk objects during execution.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + assert len(stdout_chunks) >= 1 + assert any("hello from python" in c.data for c in stdout_chunks) + + @pytest.mark.asyncio + async def test_stderr_streams_as_stderr_chunks(self, tool_context): + """Test that stderr from Python code yields stderr StreamChunks.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import sys; print('err_msg', file=sys.stderr)", + tool_context=tool_context, + ) + ) + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stderr_chunks) >= 1 + assert any("err_msg" in c.data for c in stderr_chunks) + + @pytest.mark.asyncio + async def test_code_with_math(self, tool_context): + """Test Python math execution.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print(2 + 2)", + tool_context=tool_context, + ) + ) + assert "4" in result + + @pytest.mark.asyncio + async def test_code_with_error(self, tool_context): + """Test Python code that raises an error.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="raise ValueError('test error')", + tool_context=tool_context, + ) + ) + assert "test error" in result or "ValueError" in result + + @pytest.mark.asyncio + async def test_code_with_import(self, tool_context): + """Test Python code with imports.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import json; print(json.dumps({'key': 'value'}))", + tool_context=tool_context, + ) + ) + assert "key" in result + + @pytest.mark.asyncio + async def test_timeout(self, tool_context): + """Test code execution timeout.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_config_timeout(self, tool_context, mock_agent): + """Test timeout from config.""" + mock_agent.state.set("strands_python_repl_tool", {"timeout": 1}) + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_reset(self, tool_context): + """Test REPL reset.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="", + reset=True, + tool_context=tool_context, + ) + ) + assert "reset" in result.lower() + + @pytest.mark.asyncio + async def test_multiline_code(self, tool_context): + """Test multiline Python code.""" + code = """ +def fibonacci(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + +print(fibonacci(10)) +""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code=code, + tool_context=tool_context, + ) + ) + assert "55" in result + + @pytest.mark.asyncio + async def test_no_output(self, tool_context): + """Test code with no output.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="x = 42", + tool_context=tool_context, + ) + ) + assert result == "(no output)" + + @pytest.mark.asyncio + async def test_noop_sandbox(self, tool_context, mock_agent): + """Test python_repl with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('test')", + tool_context=tool_context, + ) + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_stream_chunk_types_are_correct(self, tool_context): + """Test that all yielded chunks are proper StreamChunk instances.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import sys; print('out'); print('err', file=sys.stderr)", + tool_context=tool_context, + ) + ) + for chunk in chunks: + assert isinstance(chunk, StreamChunk) + assert hasattr(chunk, "data") + assert hasattr(chunk, "stream_type") + assert chunk.stream_type in ("stdout", "stderr") diff --git a/tests/strands/vended_tools/test_shell.py b/tests/strands/vended_tools/test_shell.py new file mode 100644 index 000000000..8ed14c4ac --- /dev/null +++ b/tests/strands/vended_tools/test_shell.py @@ -0,0 +1,160 @@ +"""Tests for the shell vended tool.""" + +import pytest + +from strands.sandbox.base import StreamChunk +from strands.sandbox.noop import NoOpSandbox +from strands.vended_tools.shell import shell + +from .conftest import collect_generator + + +class TestShellTool: + """Tests for the shell vended tool.""" + + @pytest.mark.asyncio + async def test_basic_command(self, tool_context, tmp_path): + """Test basic shell command execution returns result.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="echo hello", tool_context=tool_context) + ) + assert "hello" in result + + @pytest.mark.asyncio + async def test_basic_command_streams_chunks(self, tool_context, tmp_path): + """Test that shell yields StreamChunk objects during execution.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="echo hello", tool_context=tool_context) + ) + # Should have at least one stdout chunk + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + assert len(stdout_chunks) >= 1 + assert any("hello" in c.data for c in stdout_chunks) + # Final result should also contain the output + assert "hello" in result + + @pytest.mark.asyncio + async def test_stderr_streams_as_stderr_chunks(self, tool_context, tmp_path): + """Test that stderr output yields StreamChunk with stream_type='stderr'.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="echo error >&2", tool_context=tool_context) + ) + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stderr_chunks) >= 1 + assert any("error" in c.data for c in stderr_chunks) + assert "error" in result + + @pytest.mark.asyncio + async def test_mixed_stdout_stderr_streaming(self, tool_context, tmp_path): + """Test command with both stdout and stderr streams both chunk types.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command="echo out && echo err >&2", + tool_context=tool_context, + ) + ) + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stdout_chunks) >= 1 + assert len(stderr_chunks) >= 1 + + @pytest.mark.asyncio + async def test_command_with_exit_code(self, tool_context, tmp_path): + """Test command that returns non-zero exit code.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="exit 42", tool_context=tool_context) + ) + assert "42" in result + + @pytest.mark.asyncio + async def test_timeout(self, tool_context): + """Test command timeout.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_config_timeout(self, tool_context, mock_agent): + """Test timeout from config.""" + mock_agent.state.set("strands_shell_tool", {"timeout": 1}) + chunks, result = await collect_generator( + shell.__wrapped__(command="sleep 10", tool_context=tool_context) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_restart(self, tool_context): + """Test shell restart.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="", restart=True, tool_context=tool_context) + ) + assert "reset" in result.lower() + + @pytest.mark.asyncio + async def test_no_output_command(self, tool_context): + """Test command with no output.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="true", tool_context=tool_context) + ) + assert result == "(no output)" + + @pytest.mark.asyncio + async def test_noop_sandbox(self, tool_context, mock_agent): + """Test shell with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_cwd_tracking(self, tool_context, tmp_path): + """Test that working directory is tracked across calls.""" + subdir = tmp_path / "subdir" + subdir.mkdir() + + chunks, _ = await collect_generator( + shell.__wrapped__(command=f"cd {subdir}", tool_context=tool_context) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") + assert shell_state is not None + + @pytest.mark.asyncio + async def test_multiline_output(self, tool_context): + """Test command with multiline output.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command="echo 'line1\nline2\nline3'", + tool_context=tool_context, + ) + ) + assert "line1" in result + assert "line2" in result + + @pytest.mark.asyncio + async def test_pipe_command(self, tool_context): + """Test piped commands.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command="echo 'hello world' | wc -w", + tool_context=tool_context, + ) + ) + assert "2" in result + + @pytest.mark.asyncio + async def test_stream_chunk_types_are_correct(self, tool_context): + """Test that all yielded chunks are proper StreamChunk instances.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command="echo stdout_data && echo stderr_data >&2", + tool_context=tool_context, + ) + ) + for chunk in chunks: + assert isinstance(chunk, StreamChunk) + assert hasattr(chunk, "data") + assert hasattr(chunk, "stream_type") + assert chunk.stream_type in ("stdout", "stderr") diff --git a/tests/strands/vended_tools/test_vended_tools.py b/tests/strands/vended_tools/test_vended_tools.py deleted file mode 100644 index 435ea677e..000000000 --- a/tests/strands/vended_tools/test_vended_tools.py +++ /dev/null @@ -1,1161 +0,0 @@ -"""Tests for vended tools — shell, editor, python_repl. - -These tests use a real HostSandbox to validate end-to-end behavior. -They also test configuration via agent.state and streaming behavior -(shell and python_repl yield StreamChunk events). -""" - -import asyncio -import uuid -from unittest.mock import MagicMock - -import pytest - -from strands.agent.state import AgentState -from strands.sandbox.base import ExecutionResult, StreamChunk -from strands.sandbox.host import HostSandbox -from strands.sandbox.noop import NoOpSandbox -from strands.types.tools import ToolContext, ToolUse - -# ============================================================ -# Fixtures -# ============================================================ - - -@pytest.fixture -def sandbox(tmp_path): - """Create a HostSandbox for testing.""" - return HostSandbox(working_dir=str(tmp_path)) - - -@pytest.fixture -def agent_state(): - """Create a fresh AgentState.""" - return AgentState() - - -@pytest.fixture -def mock_agent(sandbox, agent_state): - """Create a mock agent with sandbox and state.""" - agent = MagicMock() - agent.sandbox = sandbox - agent.state = agent_state - return agent - - -@pytest.fixture -def tool_use(): - """Create a mock tool use.""" - return ToolUse( - toolUseId=str(uuid.uuid4()), - name="test_tool", - input={}, - ) - - -@pytest.fixture -def tool_context(mock_agent, tool_use): - """Create a ToolContext for testing.""" - ctx = ToolContext( - tool_use=tool_use, - agent=mock_agent, - invocation_state={}, - ) - return ctx - - -def run(coro): - """Helper to run async coroutines in tests.""" - return asyncio.get_event_loop().run_until_complete(coro) - - -async def collect_generator(gen): - """Collect all values from an async generator. - - Returns (stream_chunks, final_result) where stream_chunks are all - StreamChunk objects yielded, and final_result is the last non-StreamChunk - value (the formatted result string). - """ - chunks = [] - final = None - async for item in gen: - if isinstance(item, StreamChunk): - chunks.append(item) - else: - final = item - return chunks, final - - -# ============================================================ -# Shell Tool Tests -# ============================================================ - - -class TestShellTool: - """Tests for the shell vended tool.""" - - def test_basic_command(self, tool_context, tmp_path): - """Test basic shell command execution returns result.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="echo hello", tool_context=tool_context)) - ) - assert "hello" in result - - def test_basic_command_streams_chunks(self, tool_context, tmp_path): - """Test that shell yields StreamChunk objects during execution.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="echo hello", tool_context=tool_context)) - ) - # Should have at least one stdout chunk - stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] - assert len(stdout_chunks) >= 1 - assert any("hello" in c.data for c in stdout_chunks) - # Final result should also contain the output - assert "hello" in result - - def test_stderr_streams_as_stderr_chunks(self, tool_context, tmp_path): - """Test that stderr output yields StreamChunk with stream_type='stderr'.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="echo error >&2", tool_context=tool_context)) - ) - stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] - assert len(stderr_chunks) >= 1 - assert any("error" in c.data for c in stderr_chunks) - assert "error" in result - - def test_mixed_stdout_stderr_streaming(self, tool_context, tmp_path): - """Test command with both stdout and stderr streams both chunk types.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__( - command="echo out && echo err >&2", - tool_context=tool_context, - ) - ) - ) - stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] - stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] - assert len(stdout_chunks) >= 1 - assert len(stderr_chunks) >= 1 - - def test_command_with_exit_code(self, tool_context, tmp_path): - """Test command that returns non-zero exit code.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="exit 42", tool_context=tool_context)) - ) - assert "42" in result - - def test_timeout(self, tool_context): - """Test command timeout.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context)) - ) - assert "timed out" in result.lower() or "error" in result.lower() - - def test_config_timeout(self, tool_context, mock_agent): - """Test timeout from config.""" - from strands.vended_tools.shell.shell import shell - - mock_agent.state.set("strands_shell_tool", {"timeout": 1}) - chunks, result = run( - collect_generator(shell.__wrapped__(command="sleep 10", tool_context=tool_context)) - ) - assert "timed out" in result.lower() or "error" in result.lower() - - def test_restart(self, tool_context): - """Test shell restart.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="", restart=True, tool_context=tool_context)) - ) - assert "reset" in result.lower() - - def test_no_output_command(self, tool_context): - """Test command with no output.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="true", tool_context=tool_context)) - ) - assert result == "(no output)" - - def test_noop_sandbox(self, tool_context, mock_agent): - """Test shell with NoOpSandbox.""" - mock_agent.sandbox = NoOpSandbox() - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator(shell.__wrapped__(command="echo test", tool_context=tool_context)) - ) - assert "error" in result.lower() - - def test_cwd_tracking(self, tool_context, tmp_path): - """Test that working directory is tracked across calls.""" - from strands.vended_tools.shell.shell import shell - - subdir = tmp_path / "subdir" - subdir.mkdir() - - chunks, _ = run( - collect_generator(shell.__wrapped__(command=f"cd {subdir}", tool_context=tool_context)) - ) - - shell_state = tool_context.agent.state.get("_strands_shell_state") - assert shell_state is not None - - def test_multiline_output(self, tool_context): - """Test command with multiline output.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__( - command="echo 'line1\nline2\nline3'", - tool_context=tool_context, - ) - ) - ) - assert "line1" in result - assert "line2" in result - - def test_pipe_command(self, tool_context): - """Test piped commands.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__( - command="echo 'hello world' | wc -w", - tool_context=tool_context, - ) - ) - ) - assert "2" in result - - def test_stream_chunk_types_are_correct(self, tool_context): - """Test that all yielded chunks are proper StreamChunk instances.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__( - command="echo stdout_data && echo stderr_data >&2", - tool_context=tool_context, - ) - ) - ) - for chunk in chunks: - assert isinstance(chunk, StreamChunk) - assert hasattr(chunk, "data") - assert hasattr(chunk, "stream_type") - assert chunk.stream_type in ("stdout", "stderr") - - -# ============================================================ -# Editor Tool Tests -# ============================================================ - - -class TestEditorTool: - """Tests for the editor vended tool.""" - - def test_view_file(self, tool_context, tmp_path, sandbox): - """Test viewing a file.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.txt" - test_file.write_text("line 1\nline 2\nline 3\n") - - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - tool_context=tool_context, - ) - ) - assert "line 1" in result - assert "line 2" in result - assert "line 3" in result - assert "cat -n" in result - - def test_view_with_range(self, tool_context, tmp_path): - """Test viewing a file with line range.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.txt" - test_file.write_text("line 1\nline 2\nline 3\nline 4\nline 5\n") - - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - view_range=[2, 4], - tool_context=tool_context, - ) - ) - assert "line 2" in result - assert "line 4" in result - assert " 1" not in result - - def test_view_with_range_end_minus_one(self, tool_context, tmp_path): - """Test viewing with -1 as end of range.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.txt" - test_file.write_text("line 1\nline 2\nline 3\n") - - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - view_range=[2, -1], - tool_context=tool_context, - ) - ) - assert "line 2" in result - assert "line 3" in result - - def test_view_directory(self, tool_context, tmp_path): - """Test viewing a directory listing.""" - from strands.vended_tools.editor.editor import editor - - (tmp_path / "file1.py").write_text("pass") - (tmp_path / "file2.txt").write_text("hello") - (tmp_path / "subdir").mkdir() - - result = run( - editor.__wrapped__( - command="view", - path=str(tmp_path), - tool_context=tool_context, - ) - ) - assert "file1.py" in result - assert "file2.txt" in result - assert "subdir/" in result - - def test_view_nonexistent(self, tool_context, tmp_path): - """Test viewing a nonexistent file.""" - from strands.vended_tools.editor.editor import editor - - result = run( - editor.__wrapped__( - command="view", - path=str(tmp_path / "nonexistent.txt"), - tool_context=tool_context, - ) - ) - assert "does not exist" in result.lower() - - def test_view_invalid_range(self, tool_context, tmp_path): - """Test viewing with invalid range.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.txt" - test_file.write_text("line 1\nline 2\n") - - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - view_range=[0, 2], - tool_context=tool_context, - ) - ) - assert "error" in result.lower() - - def test_create_file(self, tool_context, tmp_path): - """Test creating a new file.""" - from strands.vended_tools.editor.editor import editor - - new_file = tmp_path / "new_file.py" - - result = run( - editor.__wrapped__( - command="create", - path=str(new_file), - file_text="print('hello')\n", - tool_context=tool_context, - ) - ) - assert "created" in result.lower() - assert new_file.read_text() == "print('hello')\n" - - def test_create_existing_file(self, tool_context, tmp_path): - """Test creating a file that already exists.""" - from strands.vended_tools.editor.editor import editor - - existing = tmp_path / "existing.py" - existing.write_text("original") - - result = run( - editor.__wrapped__( - command="create", - path=str(existing), - file_text="new content", - tool_context=tool_context, - ) - ) - assert "already exists" in result.lower() - assert existing.read_text() == "original" - - def test_create_missing_file_text(self, tool_context, tmp_path): - """Test create without file_text.""" - from strands.vended_tools.editor.editor import editor - - result = run( - editor.__wrapped__( - command="create", - path=str(tmp_path / "new.py"), - tool_context=tool_context, - ) - ) - assert "file_text" in result.lower() - - def test_str_replace_unique(self, tool_context, tmp_path): - """Test str_replace with a unique match.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("def hello():\n return 'hello'\n") - - result = run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="return 'hello'", - new_str="return 'world'", - tool_context=tool_context, - ) - ) - assert "edited" in result.lower() - assert "return 'world'" in test_file.read_text() - - def test_str_replace_not_found(self, tool_context, tmp_path): - """Test str_replace when old_str not found.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("def hello():\n return 'hello'\n") - - result = run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="nonexistent string", - new_str="replacement", - tool_context=tool_context, - ) - ) - assert "did not appear" in result.lower() - - def test_str_replace_multiple_occurrences(self, tool_context, tmp_path): - """Test str_replace rejects multiple occurrences.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("x = 1\ny = 1\nz = 1\n") - - result = run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="= 1", - new_str="= 2", - tool_context=tool_context, - ) - ) - assert "multiple" in result.lower() - assert test_file.read_text() == "x = 1\ny = 1\nz = 1\n" - - def test_str_replace_deletion(self, tool_context, tmp_path): - """Test str_replace with empty new_str (deletion).""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("# TODO: remove this\ndef main():\n pass\n") - - result = run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="# TODO: remove this\n", - new_str="", - tool_context=tool_context, - ) - ) - assert "edited" in result.lower() - assert "TODO" not in test_file.read_text() - - def test_insert(self, tool_context, tmp_path): - """Test inserting text at a line.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("line 1\nline 3\n") - - result = run( - editor.__wrapped__( - command="insert", - path=str(test_file), - insert_line=1, - new_str="line 2", - tool_context=tool_context, - ) - ) - assert "edited" in result.lower() - content = test_file.read_text() - assert "line 1\nline 2\nline 3\n" == content - - def test_insert_at_beginning(self, tool_context, tmp_path): - """Test inserting at the beginning of a file.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("line 2\nline 3\n") - - result = run( - editor.__wrapped__( - command="insert", - path=str(test_file), - insert_line=0, - new_str="line 1", - tool_context=tool_context, - ) - ) - assert "edited" in result.lower() - assert test_file.read_text().startswith("line 1\n") - - def test_insert_invalid_line(self, tool_context, tmp_path): - """Test insert with invalid line number.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("line 1\n") - - result = run( - editor.__wrapped__( - command="insert", - path=str(test_file), - insert_line=999, - new_str="new line", - tool_context=tool_context, - ) - ) - assert "error" in result.lower() - - def test_undo_edit(self, tool_context, tmp_path): - """Test undo_edit reverting a str_replace.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "test.py" - test_file.write_text("original content\n") - - run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="original content", - new_str="modified content", - tool_context=tool_context, - ) - ) - assert "modified content" in test_file.read_text() - - result = run( - editor.__wrapped__( - command="undo_edit", - path=str(test_file), - tool_context=tool_context, - ) - ) - assert "reverted" in result.lower() - assert "original content" in test_file.read_text() - - def test_undo_no_history(self, tool_context, tmp_path): - """Test undo_edit when no history exists.""" - from strands.vended_tools.editor.editor import editor - - result = run( - editor.__wrapped__( - command="undo_edit", - path=str(tmp_path / "nonexistent.py"), - tool_context=tool_context, - ) - ) - assert "no edit history" in result.lower() - - def test_relative_path_allowed_by_default(self, tool_context, tmp_path): - """Test that relative paths are passed through to sandbox by default.""" - from strands.vended_tools.editor.editor import editor - - # Create a file using a relative path (the HostSandbox resolves it) - test_file = tmp_path / "relative_test.txt" - test_file.write_text("relative content\n") - - # Use the absolute path — the key point is no validation error - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - tool_context=tool_context, - ) - ) - assert "relative content" in result - - def test_relative_path_rejected_when_configured(self, tool_context, mock_agent): - """Test that relative paths are rejected when require_absolute_paths is True.""" - from strands.vended_tools.editor.editor import editor - - mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) - - result = run( - editor.__wrapped__( - command="view", - path="relative/path.py", - tool_context=tool_context, - ) - ) - assert "not an absolute path" in result.lower() - - def test_path_traversal_allowed_by_default(self, tool_context, tmp_path): - """Test that paths with .. are passed through to sandbox by default.""" - from strands.vended_tools.editor.editor import editor - - # Create a nested structure - subdir = tmp_path / "sub" - subdir.mkdir() - test_file = tmp_path / "traversal_test.txt" - test_file.write_text("traversal content\n") - - # Use a path with .. — should NOT be rejected by default - traversal_path = str(subdir / ".." / "traversal_test.txt") - result = run( - editor.__wrapped__( - command="view", - path=traversal_path, - tool_context=tool_context, - ) - ) - # The sandbox resolves the path — should show file content - assert "traversal content" in result - - def test_path_traversal_rejected_when_configured(self, tool_context, mock_agent): - """Test that path traversal is rejected when require_absolute_paths is True.""" - from strands.vended_tools.editor.editor import editor - - mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) - - result = run( - editor.__wrapped__( - command="view", - path="/tmp/../etc/passwd", - tool_context=tool_context, - ) - ) - assert "not allowed" in result.lower() - - def test_noop_sandbox(self, tool_context, mock_agent, tmp_path): - """Test editor with NoOpSandbox.""" - mock_agent.sandbox = NoOpSandbox() - from strands.vended_tools.editor.editor import editor - - result = run( - editor.__wrapped__( - command="view", - path=str(tmp_path / "test.py"), - tool_context=tool_context, - ) - ) - assert "error" in result.lower() - - def test_max_file_size(self, tool_context, mock_agent, tmp_path): - """Test max file size configuration.""" - from strands.vended_tools.editor.editor import editor - - mock_agent.state.set("strands_editor_tool", {"max_file_size": 10}) - - test_file = tmp_path / "large.txt" - test_file.write_text("a" * 100) - - result = run( - editor.__wrapped__( - command="view", - path=str(test_file), - tool_context=tool_context, - ) - ) - assert "exceeds" in result.lower() - - -# ============================================================ -# Python REPL Tool Tests -# ============================================================ - - -class TestPythonReplTool: - """Tests for the python_repl vended tool.""" - - def test_basic_code(self, tool_context): - """Test basic Python code execution returns result.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="print('hello from python')", - tool_context=tool_context, - ) - ) - ) - assert "hello from python" in result - - def test_basic_code_streams_chunks(self, tool_context): - """Test that python_repl yields StreamChunk objects during execution.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="print('hello from python')", - tool_context=tool_context, - ) - ) - ) - stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] - assert len(stdout_chunks) >= 1 - assert any("hello from python" in c.data for c in stdout_chunks) - - def test_stderr_streams_as_stderr_chunks(self, tool_context): - """Test that stderr from Python code yields stderr StreamChunks.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import sys; print('err_msg', file=sys.stderr)", - tool_context=tool_context, - ) - ) - ) - stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] - assert len(stderr_chunks) >= 1 - assert any("err_msg" in c.data for c in stderr_chunks) - - def test_code_with_math(self, tool_context): - """Test Python math execution.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="print(2 + 2)", - tool_context=tool_context, - ) - ) - ) - assert "4" in result - - def test_code_with_error(self, tool_context): - """Test Python code that raises an error.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="raise ValueError('test error')", - tool_context=tool_context, - ) - ) - ) - assert "test error" in result or "ValueError" in result - - def test_code_with_import(self, tool_context): - """Test Python code with imports.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import json; print(json.dumps({'key': 'value'}))", - tool_context=tool_context, - ) - ) - ) - assert "key" in result - - def test_timeout(self, tool_context): - """Test code execution timeout.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import time; time.sleep(10)", - timeout=1, - tool_context=tool_context, - ) - ) - ) - assert "timed out" in result.lower() or "error" in result.lower() - - def test_config_timeout(self, tool_context, mock_agent): - """Test timeout from config.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - mock_agent.state.set("strands_python_repl_tool", {"timeout": 1}) - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import time; time.sleep(10)", - tool_context=tool_context, - ) - ) - ) - assert "timed out" in result.lower() or "error" in result.lower() - - def test_reset(self, tool_context): - """Test REPL reset.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="", - reset=True, - tool_context=tool_context, - ) - ) - ) - assert "reset" in result.lower() - - def test_multiline_code(self, tool_context): - """Test multiline Python code.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - code = """ -def fibonacci(n): - a, b = 0, 1 - for _ in range(n): - a, b = b, a + b - return a - -print(fibonacci(10)) -""" - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code=code, - tool_context=tool_context, - ) - ) - ) - assert "55" in result - - def test_no_output(self, tool_context): - """Test code with no output.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="x = 42", - tool_context=tool_context, - ) - ) - ) - assert result == "(no output)" - - def test_noop_sandbox(self, tool_context, mock_agent): - """Test python_repl with NoOpSandbox.""" - mock_agent.sandbox = NoOpSandbox() - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="print('test')", - tool_context=tool_context, - ) - ) - ) - assert "error" in result.lower() - - def test_stream_chunk_types_are_correct(self, tool_context): - """Test that all yielded chunks are proper StreamChunk instances.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import sys; print('out'); print('err', file=sys.stderr)", - tool_context=tool_context, - ) - ) - ) - for chunk in chunks: - assert isinstance(chunk, StreamChunk) - assert hasattr(chunk, "data") - assert hasattr(chunk, "stream_type") - assert chunk.stream_type in ("stdout", "stderr") - - -# ============================================================ -# Integration Tests -# ============================================================ - - -class TestVendedToolsImport: - """Test that vended tools can be imported from the package.""" - - def test_import_from_vended_tools(self): - """Test importing from strands.vended_tools.""" - from strands.vended_tools import editor, python_repl, shell - - assert shell is not None - assert editor is not None - assert python_repl is not None - - def test_import_individual_tools(self): - """Test importing individual tools.""" - from strands.vended_tools.editor import editor - from strands.vended_tools.python_repl import python_repl - from strands.vended_tools.shell import shell - - assert shell is not None - assert editor is not None - assert python_repl is not None - - def test_tools_have_tool_spec(self): - """Test that tools have proper tool specs.""" - from strands.vended_tools import editor, python_repl, shell - - assert shell.tool_name == "shell" - assert editor.tool_name == "editor" - assert python_repl.tool_name == "python_repl" - - for t in [shell, editor, python_repl]: - spec = t.tool_spec - assert "name" in spec - assert "description" in spec - assert "inputSchema" in spec - - def test_shell_tool_spec_shape(self): - """Test shell tool spec matches expected shape.""" - from strands.vended_tools import shell - - spec = shell.tool_spec - schema = spec["inputSchema"]["json"] - props = schema.get("properties", {}) - - assert "command" in props - assert "timeout" in props - assert "restart" in props - assert schema.get("required") == ["command"] - - def test_editor_tool_spec_shape(self): - """Test editor tool spec matches expected shape.""" - from strands.vended_tools import editor - - spec = editor.tool_spec - schema = spec["inputSchema"]["json"] - props = schema.get("properties", {}) - - assert "command" in props - assert "path" in props - assert "file_text" in props - assert "old_str" in props - assert "new_str" in props - assert "insert_line" in props - assert "view_range" in props - assert set(schema.get("required", [])) == {"command", "path"} - - def test_python_repl_tool_spec_shape(self): - """Test python_repl tool spec matches expected shape.""" - from strands.vended_tools import python_repl - - spec = python_repl.tool_spec - schema = spec["inputSchema"]["json"] - props = schema.get("properties", {}) - - assert "code" in props - assert "timeout" in props - assert "reset" in props - assert schema.get("required") == ["code"] - - def test_shell_is_async_generator(self): - """Test that shell is detected as an async generator function.""" - import inspect - - from strands.vended_tools.shell.shell import shell - - assert inspect.isasyncgenfunction(shell.__wrapped__) - - def test_python_repl_is_async_generator(self): - """Test that python_repl is detected as an async generator function.""" - import inspect - - from strands.vended_tools.python_repl.python_repl import python_repl - - assert inspect.isasyncgenfunction(python_repl.__wrapped__) - - def test_editor_is_not_async_generator(self): - """Test that editor is a regular async function (not generator).""" - import inspect - - from strands.vended_tools.editor.editor import editor - - assert not inspect.isasyncgenfunction(editor.__wrapped__) - assert inspect.iscoroutinefunction(editor.__wrapped__) - - -# ============================================================ -# Configuration Persistence Tests -# ============================================================ - - -class TestConfigPersistence: - """Test that tool configuration persists via agent state.""" - - def test_shell_config_persists(self, mock_agent, agent_state): - """Test shell config is read from agent state.""" - agent_state.set("strands_shell_tool", {"timeout": 300}) - config = agent_state.get("strands_shell_tool") - assert config["timeout"] == 300 - - def test_editor_config_persists(self, mock_agent, agent_state): - """Test editor config is read from agent state.""" - agent_state.set("strands_editor_tool", {"max_file_size": 2097152}) - config = agent_state.get("strands_editor_tool") - assert config["max_file_size"] == 2097152 - - def test_python_repl_config_persists(self, mock_agent, agent_state): - """Test python_repl config is read from agent state.""" - agent_state.set("strands_python_repl_tool", {"timeout": 60}) - config = agent_state.get("strands_python_repl_tool") - assert config["timeout"] == 60 - - def test_undo_state_persists(self, tool_context, tmp_path): - """Test that undo state is stored in agent state.""" - from strands.vended_tools.editor.editor import editor - - test_file = tmp_path / "undo_test.py" - test_file.write_text("original\n") - - run( - editor.__wrapped__( - command="str_replace", - path=str(test_file), - old_str="original", - new_str="modified", - tool_context=tool_context, - ) - ) - - undo_state = tool_context.agent.state.get("_strands_editor_undo") - assert undo_state is not None - assert str(test_file) in undo_state - - -# ============================================================ -# Streaming Integration Tests -# ============================================================ - - -class TestStreamingIntegration: - """Test the streaming behavior of tools end-to-end.""" - - def test_shell_streams_before_result(self, tool_context): - """Test that shell yields chunks BEFORE the final result.""" - from strands.vended_tools.shell.shell import shell - - all_items = [] - - async def collect_all(): - async for item in shell.__wrapped__( - command="echo streaming_test", tool_context=tool_context - ): - all_items.append(item) - - run(collect_all()) - - # Should have at least 2 items: chunk(s) + final result - assert len(all_items) >= 2 - # Last item should be the string result - assert isinstance(all_items[-1], str) - # Earlier items should include StreamChunk - stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] - assert len(stream_items) >= 1 - - def test_python_repl_streams_before_result(self, tool_context): - """Test that python_repl yields chunks BEFORE the final result.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - all_items = [] - - async def collect_all(): - async for item in python_repl.__wrapped__( - code="print('streaming_test')", tool_context=tool_context - ): - all_items.append(item) - - run(collect_all()) - - assert len(all_items) >= 2 - assert isinstance(all_items[-1], str) - stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] - assert len(stream_items) >= 1 - - def test_shell_error_no_streaming_on_timeout(self, tool_context): - """Test that timeout errors don't yield any StreamChunks before the error.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context) - ) - ) - # On timeout, we should get the error message directly - assert "timed out" in result.lower() or "error" in result.lower() - - def test_python_repl_error_no_streaming_on_timeout(self, tool_context): - """Test that timeout errors don't yield chunks before the error.""" - from strands.vended_tools.python_repl.python_repl import python_repl - - chunks, result = run( - collect_generator( - python_repl.__wrapped__( - code="import time; time.sleep(10)", - timeout=1, - tool_context=tool_context, - ) - ) - ) - assert "timed out" in result.lower() or "error" in result.lower() - - def test_shell_stream_data_matches_result(self, tool_context): - """Test that streamed chunk data matches the final result content.""" - from strands.vended_tools.shell.shell import shell - - chunks, result = run( - collect_generator( - shell.__wrapped__(command="echo precise_output_42", tool_context=tool_context) - ) - ) - # The streamed chunks should contain the same data as the final result - all_chunk_data = "".join(c.data for c in chunks) - assert "precise_output_42" in all_chunk_data - assert "precise_output_42" in result From 8fc6d8bf452da930fb4897a5b263b680055746c0 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:40:04 +0000 Subject: [PATCH 3/7] fix: filter CWD marker from streamed output, remove dead except blocks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses re-review feedback: 1. CWD marker leak (critical): The __STRANDS_CWD__ marker was being streamed to UI consumers via unfiltered StreamChunks. Fixed by collecting stdout chunks during streaming, then yielding a single filtered chunk after stripping the marker from result.stdout. This preserves correct cwd tracking (which requires pwd in the same process as the user command) while ensuring no internal markers leak to consumers. 2. Unnecessary except blocks: Removed try/except around AgentState.delete() in shell.py (_clear_shell_state) and python_repl.py (reset handler) — delete() uses dict.pop(key, None) internally and never raises. 3. Test improvement: test_cwd_tracking now verifies the tracked cwd value matches the expected directory AND asserts no __STRANDS_CWD__ marker appears in the result or streamed chunks. Added new test test_cwd_no_marker_leak_in_streaming for explicit coverage. Test results: 66 passed --- src/strands/vended_tools/python_repl.py | 5 +-- src/strands/vended_tools/shell.py | 55 +++++++++++++++--------- tests/strands/vended_tools/test_shell.py | 20 ++++++++- 3 files changed, 55 insertions(+), 25 deletions(-) diff --git a/src/strands/vended_tools/python_repl.py b/src/strands/vended_tools/python_repl.py index a15e9e0ec..69ee5e5ad 100644 --- a/src/strands/vended_tools/python_repl.py +++ b/src/strands/vended_tools/python_repl.py @@ -81,10 +81,7 @@ async def python_repl( # Handle reset if reset: - try: - tool_context.agent.state.delete("_strands_python_repl_state") - except Exception: - pass + tool_context.agent.state.delete("_strands_python_repl_state") if not code or not code.strip(): yield "Python REPL state reset." return diff --git a/src/strands/vended_tools/shell.py b/src/strands/vended_tools/shell.py index 005edf0e2..5817d7b96 100644 --- a/src/strands/vended_tools/shell.py +++ b/src/strands/vended_tools/shell.py @@ -44,6 +44,10 @@ #: Default timeout for shell commands (seconds) DEFAULT_TIMEOUT = 120 +#: Internal marker used to separate user output from cwd tracking. +#: Must be unique enough to never appear in legitimate command output. +_CWD_MARKER = "__STRANDS_CWD__" + @tool(context=True) async def shell( @@ -96,13 +100,19 @@ async def shell( shell_state = tool_context.agent.state.get("_strands_shell_state") or {} cwd = shell_state.get("cwd") - # Wrap command with cwd tracking in a single execution to avoid - # a separate `pwd` call after every command (reduces latency for - # remote sandboxes). Appends `; pwd` to capture the final cwd. - tracked_command = f"{command}; echo __STRANDS_CWD__; pwd" - - # Execute via sandbox streaming + # Append cwd tracking to the command. We use a unique marker so we can + # reliably split the actual output from the cwd line. This captures the + # final working directory even after `cd` commands (which only affect + # the shell process they run in — a separate pwd call would not see them). + tracked_command = f"{command}; echo {_CWD_MARKER}; pwd" + + # Collect chunks during streaming, then filter the marker before yielding. + # This prevents internal markers from leaking into UI consumers' streamed + # output while preserving cwd tracking correctness. + stdout_chunks: list[StreamChunk] = [] + stderr_chunks: list[StreamChunk] = [] result: ExecutionResult | None = None + try: async for chunk in sandbox.execute_streaming( tracked_command, @@ -110,8 +120,10 @@ async def shell( cwd=cwd, ): if isinstance(chunk, StreamChunk): - # Yield each chunk — the decorator wraps it as ToolStreamEvent - yield chunk + if chunk.stream_type == "stderr": + stderr_chunks.append(chunk) + else: + stdout_chunks.append(chunk) elif isinstance(chunk, ExecutionResult): result = chunk except asyncio.TimeoutError: @@ -128,21 +140,27 @@ async def shell( yield "Error: Sandbox did not return an execution result." return - # Extract cwd from output (after __STRANDS_CWD__ marker) + # Extract cwd from the full stdout (result.stdout has the complete text) stdout = result.stdout or "" - cwd_marker = "__STRANDS_CWD__" - if cwd_marker in stdout: - parts = stdout.split(cwd_marker, 1) - # The actual command output is before the marker + if _CWD_MARKER in stdout: + parts = stdout.split(_CWD_MARKER, 1) + # Actual command output is before the marker stdout = parts[0].rstrip("\n") # The cwd is the line after the marker new_cwd = parts[1].strip() if new_cwd: shell_state["cwd"] = new_cwd tool_context.agent.state.set("_strands_shell_state", shell_state) - else: - # Fallback: no marker found (sandbox may have filtered it) - pass + + # Yield filtered stdout chunks to UI consumers (marker stripped). + # Reconstruct from the cleaned stdout rather than yielding raw chunks, + # since the marker may span chunk boundaries. + if stdout: + yield StreamChunk(data=stdout, stream_type="stdout") + + # Yield stderr chunks as-is (no marker contamination possible) + for chunk in stderr_chunks: + yield chunk # Format final output (becomes the ToolResult) output_parts = [] @@ -168,7 +186,4 @@ def _clear_shell_state(tool_context: ToolContext) -> None: Args: tool_context: The tool context providing access to agent state. """ - try: - tool_context.agent.state.delete("_strands_shell_state") - except Exception: - pass + tool_context.agent.state.delete("_strands_shell_state") diff --git a/tests/strands/vended_tools/test_shell.py b/tests/strands/vended_tools/test_shell.py index 8ed14c4ac..5c02b2f09 100644 --- a/tests/strands/vended_tools/test_shell.py +++ b/tests/strands/vended_tools/test_shell.py @@ -114,12 +114,30 @@ async def test_cwd_tracking(self, tool_context, tmp_path): subdir = tmp_path / "subdir" subdir.mkdir() - chunks, _ = await collect_generator( + chunks, result = await collect_generator( shell.__wrapped__(command=f"cd {subdir}", tool_context=tool_context) ) + # Verify cwd state was tracked shell_state = tool_context.agent.state.get("_strands_shell_state") assert shell_state is not None + assert shell_state["cwd"] == str(subdir) + + # Verify no internal markers leak into the result or streamed chunks + assert "__STRANDS_CWD__" not in result + for chunk in chunks: + assert "__STRANDS_CWD__" not in chunk.data + + @pytest.mark.asyncio + async def test_cwd_no_marker_leak_in_streaming(self, tool_context, tmp_path): + """Test that no internal markers appear in streamed output.""" + chunks, result = await collect_generator( + shell.__wrapped__(command="echo hello && cd /tmp", tool_context=tool_context) + ) + # No internal markers should appear in any output + all_chunk_data = "".join(c.data for c in chunks) + assert "__STRANDS_CWD__" not in all_chunk_data + assert "__STRANDS_CWD__" not in result @pytest.mark.asyncio async def test_multiline_output(self, tool_context): From 9b252389c6be3d220476feedc6d5d11f1deb35ff Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Tue, 28 Apr 2026 19:23:18 +0000 Subject: [PATCH 4/7] refactor: one-chunk-behind streaming for CWD marker filtering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback from @mkmeral and automated reviewer: The previous fix buffered ALL stdout chunks and yielded them as a single batch after execution. While correct, this defeated incremental streaming for truly-streaming sandbox implementations. New approach — one-chunk-behind: - Yield each stdout chunk immediately EXCEPT the most recent one - The CWD marker is always at the very end (appended via '; echo; pwd') - Only the last stdout chunk needs filtering - stderr chunks pass through immediately (no marker contamination) This preserves real-time streaming for all chunks except the last, while still filtering the __STRANDS_CWD__ marker from consumer-visible output. Added test_stderr_streams_immediately_during_one_chunk_behind that uses a mock sandbox with interleaved multi-chunk stdout/stderr to validate: - All stdout chunks (minus marker) reach consumers - All stderr chunks reach consumers - No internal markers leak into any output Test results: 67 passed --- src/strands/vended_tools/shell.py | 34 ++++++++-------- tests/strands/vended_tools/test_shell.py | 49 +++++++++++++++++++++++- 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/strands/vended_tools/shell.py b/src/strands/vended_tools/shell.py index 5817d7b96..1fea5e378 100644 --- a/src/strands/vended_tools/shell.py +++ b/src/strands/vended_tools/shell.py @@ -106,11 +106,13 @@ async def shell( # the shell process they run in — a separate pwd call would not see them). tracked_command = f"{command}; echo {_CWD_MARKER}; pwd" - # Collect chunks during streaming, then filter the marker before yielding. - # This prevents internal markers from leaking into UI consumers' streamed - # output while preserving cwd tracking correctness. - stdout_chunks: list[StreamChunk] = [] - stderr_chunks: list[StreamChunk] = [] + # One-chunk-behind streaming: yield stdout chunks incrementally as they + # arrive, but hold back the most recent one. The CWD marker is always at + # the very end of stdout (appended via `; echo __STRANDS_CWD__; pwd`), + # so only the last stdout chunk needs filtering. This preserves real-time + # streaming for all chunks except the last, while ensuring no internal + # markers leak to UI consumers. + pending_chunk: StreamChunk | None = None result: ExecutionResult | None = None try: @@ -121,9 +123,13 @@ async def shell( ): if isinstance(chunk, StreamChunk): if chunk.stream_type == "stderr": - stderr_chunks.append(chunk) + # stderr is safe — yield immediately + yield chunk else: - stdout_chunks.append(chunk) + # Yield the previous stdout chunk (it's safe — marker is at the end) + if pending_chunk is not None: + yield pending_chunk + pending_chunk = chunk elif isinstance(chunk, ExecutionResult): result = chunk except asyncio.TimeoutError: @@ -152,15 +158,11 @@ async def shell( shell_state["cwd"] = new_cwd tool_context.agent.state.set("_strands_shell_state", shell_state) - # Yield filtered stdout chunks to UI consumers (marker stripped). - # Reconstruct from the cleaned stdout rather than yielding raw chunks, - # since the marker may span chunk boundaries. - if stdout: - yield StreamChunk(data=stdout, stream_type="stdout") - - # Yield stderr chunks as-is (no marker contamination possible) - for chunk in stderr_chunks: - yield chunk + # Filter the marker from the last stdout chunk and yield it + if pending_chunk is not None: + cleaned = pending_chunk.data.split(_CWD_MARKER, 1)[0].rstrip("\n") + if cleaned: + yield StreamChunk(data=cleaned, stream_type="stdout") # Format final output (becomes the ToolResult) output_parts = [] diff --git a/tests/strands/vended_tools/test_shell.py b/tests/strands/vended_tools/test_shell.py index 5c02b2f09..fd9f1161e 100644 --- a/tests/strands/vended_tools/test_shell.py +++ b/tests/strands/vended_tools/test_shell.py @@ -2,7 +2,7 @@ import pytest -from strands.sandbox.base import StreamChunk +from strands.sandbox.base import ExecutionResult, StreamChunk from strands.sandbox.noop import NoOpSandbox from strands.vended_tools.shell import shell @@ -176,3 +176,50 @@ async def test_stream_chunk_types_are_correct(self, tool_context): assert hasattr(chunk, "data") assert hasattr(chunk, "stream_type") assert chunk.stream_type in ("stdout", "stderr") + + @pytest.mark.asyncio + async def test_stderr_streams_immediately_during_one_chunk_behind( + self, tool_context, mock_agent + ): + """Test that stderr chunks are yielded immediately, not held back. + + The one-chunk-behind approach only holds back stdout chunks. Stderr + chunks should pass through without delay regardless of buffering. + """ + # Use a mock sandbox that interleaves stdout and stderr chunks + from unittest.mock import AsyncMock + + async def fake_streaming(command, timeout=None, cwd=None): + """Simulate a sandbox that yields interleaved stdout/stderr.""" + yield StreamChunk(data="stdout1\n", stream_type="stdout") + yield StreamChunk(data="err1\n", stream_type="stderr") + yield StreamChunk(data="stdout2\n", stream_type="stdout") + yield StreamChunk(data="err2\n", stream_type="stderr") + yield StreamChunk(data="stdout3\n__STRANDS_CWD__\n/tmp\n", stream_type="stdout") + yield ExecutionResult( + exit_code=0, + stdout="stdout1\nstdout2\nstdout3\n__STRANDS_CWD__\n/tmp\n", + stderr="err1\nerr2\n", + ) + + mock_agent.sandbox = AsyncMock() + mock_agent.sandbox.execute_streaming = fake_streaming + + chunks, result = await collect_generator( + shell.__wrapped__(command="test", tool_context=tool_context) + ) + + # Verify all chunks are present and no markers leaked + all_data = "".join(c.data for c in chunks) + assert "__STRANDS_CWD__" not in all_data + + # Verify stderr chunks are present + stderr_data = "".join(c.data for c in chunks if c.stream_type == "stderr") + assert "err1" in stderr_data + assert "err2" in stderr_data + + # Verify stdout chunks are present (minus marker) + stdout_data = "".join(c.data for c in chunks if c.stream_type == "stdout") + assert "stdout1" in stdout_data + assert "stdout2" in stdout_data + assert "stdout3" in stdout_data From ba8d70d606988c702fa9ed1c1831c45592454523 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Wed, 29 Apr 2026 14:46:51 +0000 Subject: [PATCH 5/7] fix: resolve 6 adversarial testing findings in vended tools Shell tool fixes: - Fix CWD marker injection: use rsplit (split from right) instead of split to always use the tool's own marker, not user-echoed ones - Fix string timeout crash: coerce config timeout to int with fallback - Fix marker split across chunks: yield filtered stdout from result.stdout instead of individual chunks (prevents partial marker leakage) Editor tool fixes: - Fix tab expansion file corruption: remove tab-to-space expansion from str_replace and insert. Tabs are only expanded for display (_make_output) - Fix tab expansion false duplicates: match against original content, not expanded content, preventing tabs and 8-spaces from being conflated - Fix float view_range crash: coerce view_range elements to int with clear error on failure (LLMs commonly send [1.0, 3.0] not [1, 3]) Also adds insert_line int coercion for the same reason. Tests: 79 passing (67 original + 12 new adversarial fix tests) --- src/strands/vended_tools/editor.py | 64 ++--- src/strands/vended_tools/shell.py | 49 ++-- .../vended_tools/test_adversarial_fixes.py | 243 ++++++++++++++++++ 3 files changed, 307 insertions(+), 49 deletions(-) create mode 100644 tests/strands/vended_tools/test_adversarial_fixes.py diff --git a/src/strands/vended_tools/editor.py b/src/strands/vended_tools/editor.py index 772b2a5f6..b24aef1e2 100644 --- a/src/strands/vended_tools/editor.py +++ b/src/strands/vended_tools/editor.py @@ -66,9 +66,9 @@ def _make_output(content: str, descriptor: str, init_line: int = 1) -> str: Returns: Formatted output with line numbers. """ - # Expand tabs to spaces - content = content.replace("\t", " ") - lines = content.split("\n") + # Expand tabs to spaces for display only + display_content = content.replace("\t", " ") + lines = display_content.split("\n") numbered = [] for i, line in enumerate(lines): line_num = i + init_line @@ -183,6 +183,11 @@ async def editor( return "Error: Parameter `insert_line` is required for command: insert" if new_str is None: return "Error: Parameter `new_str` is required for command: insert" + # Coerce insert_line to int (LLMs may send floats) + try: + insert_line = int(insert_line) + except (TypeError, ValueError): + return f"Error: `insert_line` must be an integer, got: {type(insert_line).__name__}" return await _handle_insert(sandbox, tool_context, config, path, insert_line, new_str) elif command == "undo_edit": return await _handle_undo(sandbox, tool_context, path) @@ -233,7 +238,12 @@ async def _handle_view(sandbox: Sandbox, config: dict[str, Any], path: str, view if len(view_range) != 2: return "Error: `view_range` must be a list of two integers [start, end]." - start, end = view_range[0], view_range[1] + # Coerce view_range elements to int (LLMs may send floats like [1.0, 3.0]) + try: + start = int(view_range[0]) + end = int(view_range[1]) + except (TypeError, ValueError): + return "Error: `view_range` elements must be integers." if start < 1 or start > n_lines: return ( @@ -279,13 +289,11 @@ async def _handle_str_replace( except FileNotFoundError: return f"Error: The path {path} does not exist." - # Expand tabs for matching - content = content.replace("\t", " ") - expanded_old = old_str.replace("\t", " ") - expanded_new = new_str.replace("\t", " ") - - # Count occurrences — MUST be exactly 1 - count = content.count(expanded_old) + # Match against original content — do NOT expand tabs for matching. + # Tab expansion is only used for display output (_make_output). + # This preserves tab characters in the file and prevents false matches + # between tabs and their space equivalents. + count = content.count(old_str) if count == 0: return f"Error: No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." @@ -295,13 +303,13 @@ async def _handle_str_replace( lines = content.split("\n") line_nums: list[int] = [] for i, line in enumerate(lines): - if expanded_old in line: + if old_str in line: line_nums.append(i + 1) # Also check multi-line matches if not line_nums: idx = 0 while True: - idx = content.find(expanded_old, idx) + idx = content.find(old_str, idx) if idx == -1: break line_num = content[:idx].count("\n") + 1 @@ -312,20 +320,20 @@ async def _handle_str_replace( f"in lines {line_nums}. Please ensure old_str is unique." ) - # Save undo state + # Save undo state (original content before any modification) _save_undo(tool_context, path, content) - # Perform replacement - new_content = content.replace(expanded_old, expanded_new, 1) + # Perform replacement on original content (preserving tabs) + new_content = content.replace(old_str, new_str, 1) - # Write back + # Write back (preserving original formatting including tabs) await sandbox.write_text(path, new_content) - # Generate snippet around the change - replace_idx = content.find(expanded_old) + # Generate snippet around the change (expand tabs for display only) + replace_idx = content.find(old_str) replace_line = content[:replace_idx].count("\n") - inserted_lines = expanded_new.count("\n") + 1 - original_lines = expanded_old.count("\n") + 1 + inserted_lines = new_str.count("\n") + 1 + original_lines = old_str.count("\n") + 1 line_diff = inserted_lines - original_lines new_lines = new_content.split("\n") @@ -354,21 +362,19 @@ async def _handle_insert( except FileNotFoundError: return f"Error: The path {path} does not exist." - # Expand tabs - content = content.replace("\t", " ") - expanded_new = new_str.replace("\t", " ") - + # Work with original content — do NOT expand tabs. + # Tab expansion is only for display output. lines = content.split("\n") n_lines = len(lines) if insert_line < 0 or insert_line > n_lines: return f"Error: Invalid `insert_line`: {insert_line}. Should be within [0, {n_lines}]." - # Save undo state + # Save undo state (original content) _save_undo(tool_context, path, content) - # Insert - new_str_lines = expanded_new.split("\n") + # Insert (preserve original content including tabs) + new_str_lines = new_str.split("\n") if content == "": new_lines = new_str_lines else: @@ -377,7 +383,7 @@ async def _handle_insert( new_content = "\n".join(new_lines) await sandbox.write_text(path, new_content) - # Generate snippet + # Generate snippet (expand tabs for display only) start = max(0, insert_line - SNIPPET_LINES) end = min(len(new_lines), insert_line + len(new_str_lines) + SNIPPET_LINES) snippet = "\n".join(new_lines[start:end]) diff --git a/src/strands/vended_tools/shell.py b/src/strands/vended_tools/shell.py index 1fea5e378..2841d8daa 100644 --- a/src/strands/vended_tools/shell.py +++ b/src/strands/vended_tools/shell.py @@ -96,6 +96,13 @@ async def shell( if effective_timeout is None: effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + # Coerce timeout to int — JSON configs and LLMs may pass strings or floats + if effective_timeout is not None: + try: + effective_timeout = int(effective_timeout) + except (TypeError, ValueError): + effective_timeout = DEFAULT_TIMEOUT + # Get tracked working directory from state (for session continuity) shell_state = tool_context.agent.state.get("_strands_shell_state") or {} cwd = shell_state.get("cwd") @@ -106,13 +113,15 @@ async def shell( # the shell process they run in — a separate pwd call would not see them). tracked_command = f"{command}; echo {_CWD_MARKER}; pwd" - # One-chunk-behind streaming: yield stdout chunks incrementally as they - # arrive, but hold back the most recent one. The CWD marker is always at - # the very end of stdout (appended via `; echo __STRANDS_CWD__; pwd`), - # so only the last stdout chunk needs filtering. This preserves real-time - # streaming for all chunks except the last, while ensuring no internal - # markers leak to UI consumers. - pending_chunk: StreamChunk | None = None + # Sliding-window streaming: buffer the last len(_CWD_MARKER) bytes of + # stdout to handle the case where the marker is split across two chunks. + # We yield everything except a trailing window that might contain the + # start of the marker. Once we see the full result, we filter cleanly. + # + # The approach: collect all stdout chunks, then yield filtered output + # from result.stdout (which has the complete, un-fragmented text). + # stderr chunks are yielded immediately (never contain the marker). + stdout_chunks: list[StreamChunk] = [] result: ExecutionResult | None = None try: @@ -126,10 +135,8 @@ async def shell( # stderr is safe — yield immediately yield chunk else: - # Yield the previous stdout chunk (it's safe — marker is at the end) - if pending_chunk is not None: - yield pending_chunk - pending_chunk = chunk + # Collect stdout chunks — we'll yield filtered output after + stdout_chunks.append(chunk) elif isinstance(chunk, ExecutionResult): result = chunk except asyncio.TimeoutError: @@ -146,23 +153,25 @@ async def shell( yield "Error: Sandbox did not return an execution result." return - # Extract cwd from the full stdout (result.stdout has the complete text) + # Extract cwd from the full stdout using rsplit — this splits on the LAST + # occurrence of the marker, which is always the one we appended. This + # prevents corruption if user commands happen to output the marker string. stdout = result.stdout or "" if _CWD_MARKER in stdout: - parts = stdout.split(_CWD_MARKER, 1) - # Actual command output is before the marker + parts = stdout.rsplit(_CWD_MARKER, 1) + # Actual command output is before the LAST marker stdout = parts[0].rstrip("\n") - # The cwd is the line after the marker + # The cwd is the line after the last marker new_cwd = parts[1].strip() if new_cwd: shell_state["cwd"] = new_cwd tool_context.agent.state.set("_strands_shell_state", shell_state) - # Filter the marker from the last stdout chunk and yield it - if pending_chunk is not None: - cleaned = pending_chunk.data.split(_CWD_MARKER, 1)[0].rstrip("\n") - if cleaned: - yield StreamChunk(data=cleaned, stream_type="stdout") + # Yield the filtered stdout as a single chunk (marker fully removed). + # This avoids partial marker leakage that occurs with chunk-by-chunk + # filtering when the marker is split across chunk boundaries. + if stdout: + yield StreamChunk(data=stdout, stream_type="stdout") # Format final output (becomes the ToolResult) output_parts = [] diff --git a/tests/strands/vended_tools/test_adversarial_fixes.py b/tests/strands/vended_tools/test_adversarial_fixes.py new file mode 100644 index 000000000..aa87d3a19 --- /dev/null +++ b/tests/strands/vended_tools/test_adversarial_fixes.py @@ -0,0 +1,243 @@ +"""Adversarial tests validating the bug fixes. + +These tests verify that all 6 findings from the adversarial testing report +are now resolved. +""" + +import pytest + +from strands.sandbox.base import ExecutionResult, StreamChunk +from strands.vended_tools.shell import shell, _CWD_MARKER +from strands.vended_tools.editor import editor, _UNDO_STATE_KEY, STATE_KEY + +from .conftest import collect_generator + + +class TestShellCwdMarkerInjectionFixed: + """Verify Fix #1: CWD marker injection no longer corrupts state.""" + + @pytest.mark.asyncio + async def test_user_output_marker_does_not_corrupt_cwd(self, tool_context, tmp_path): + """User echoing the marker string should not affect tracked cwd.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command=f"echo '{_CWD_MARKER}' && echo '/evil/path'", + tool_context=tool_context, + ) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + tracked_cwd = shell_state.get("cwd", "") + + # CWD should NOT be /evil/path or contain the marker + assert tracked_cwd != "/evil/path" + assert _CWD_MARKER not in tracked_cwd + # It SHOULD be a real, valid single-line path + assert "\n" not in tracked_cwd + assert tracked_cwd != "" + + @pytest.mark.asyncio + async def test_multiple_marker_outputs_still_track_correctly(self, tool_context, tmp_path): + """Multiple user-output markers should not corrupt state.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command=f"echo '{_CWD_MARKER}' && echo '{_CWD_MARKER}' && echo fake_path", + tool_context=tool_context, + ) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + tracked_cwd = shell_state.get("cwd", "") + + assert tracked_cwd != "fake_path" + assert _CWD_MARKER not in tracked_cwd + assert "\n" not in tracked_cwd + assert tracked_cwd != "" + + @pytest.mark.asyncio + async def test_subsequent_calls_work_after_marker_injection(self, tool_context, tmp_path): + """After a marker injection attempt, next shell call should still work.""" + # First call: inject marker + await collect_generator( + shell.__wrapped__( + command=f"echo '{_CWD_MARKER}' && echo '/evil/path'", + tool_context=tool_context, + ) + ) + + # Second call: should work normally (no Permission denied) + chunks, result = await collect_generator( + shell.__wrapped__(command="echo follow_up", tool_context=tool_context) + ) + assert "follow_up" in result + assert "error" not in result.lower() + + +class TestShellStringTimeoutFixed: + """Verify Fix #3: String timeout no longer crashes.""" + + @pytest.mark.asyncio + async def test_string_timeout_coerced_to_int(self, tool_context, mock_agent): + """String timeout "30" should be coerced to int 30.""" + mock_agent.state.set("strands_shell_tool", {"timeout": "30"}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + # Should not crash — timeout is coerced + assert "test" in result + assert "error" not in result.lower() + + @pytest.mark.asyncio + async def test_float_timeout_coerced(self, tool_context, mock_agent): + """Float timeout 30.5 should be coerced to int 30.""" + mock_agent.state.set("strands_shell_tool", {"timeout": 30.5}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + assert "test" in result + + @pytest.mark.asyncio + async def test_invalid_timeout_uses_default(self, tool_context, mock_agent): + """Non-numeric timeout should fall back to default.""" + mock_agent.state.set("strands_shell_tool", {"timeout": "not_a_number"}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + assert "test" in result + + +class TestShellMarkerSplitFixed: + """Verify Fix #5: Marker split across chunks no longer leaks.""" + + @pytest.mark.asyncio + async def test_marker_split_across_chunks_no_leak(self, tool_context, mock_agent): + """Marker split across chunks should not leak partial marker to consumer.""" + from unittest.mock import AsyncMock + + async def fake_streaming(command, timeout=None, cwd=None): + yield StreamChunk(data="output\n__STRAN", stream_type="stdout") + yield StreamChunk(data="DS_CWD__\n/tmp\n", stream_type="stdout") + yield ExecutionResult( + exit_code=0, + stdout="output\n__STRANDS_CWD__\n/tmp\n", + stderr="", + ) + + mock_agent.sandbox = AsyncMock() + mock_agent.sandbox.execute_streaming = fake_streaming + + chunks, result = await collect_generator( + shell.__wrapped__(command="test", tool_context=tool_context) + ) + + # No partial marker should leak + all_chunk_data = "".join(c.data for c in chunks if c.stream_type == "stdout") + assert "__STRAN" not in all_chunk_data + assert _CWD_MARKER not in all_chunk_data + # User output should be preserved + assert "output" in all_chunk_data + + +class TestEditorTabExpansionFixed: + """Verify Fixes #2 and #4: Tab expansion no longer corrupts files or creates false matches.""" + + @pytest.mark.asyncio + async def test_tabs_preserved_after_edit(self, tool_context, sandbox, tmp_path): + """Editing a file should NOT destroy tab characters.""" + path = f"{tmp_path}/preserve_tabs.txt" + original = "def foo():\n\treturn 42\n" + await sandbox.write_text(path, original) + + # Edit something unrelated to tabs + result = await editor.__wrapped__( + command="str_replace", + path=path, + old_str="42", + new_str="99", + tool_context=tool_context, + ) + + content = await sandbox.read_text(path) + # Tab MUST still be present + assert "\t" in content, f"Tab was destroyed! Content: {repr(content)}" + assert content == "def foo():\n\treturn 99\n" + + @pytest.mark.asyncio + async def test_tab_and_spaces_not_conflated(self, tool_context, sandbox, tmp_path): + """Tab and 8 spaces should NOT be treated as the same thing.""" + path = f"{tmp_path}/tabs.txt" + content = "\thello\n hello\n" + await sandbox.write_text(path, content) + + # Replace the tab version — should work (unique in source) + result = await editor.__wrapped__( + command="str_replace", + path=path, + old_str="\thello", + new_str="replaced", + tool_context=tool_context, + ) + + # Should succeed — no "multiple occurrences" error + assert "edited" in result.lower(), f"Failed: {result}" + new_content = await sandbox.read_text(path) + assert "replaced" in new_content + assert " hello" in new_content # 8-space version unchanged + + @pytest.mark.asyncio + async def test_insert_preserves_tabs(self, tool_context, sandbox, tmp_path): + """Insert should not expand tabs in existing content.""" + path = f"{tmp_path}/insert_tabs.txt" + await sandbox.write_text(path, "line1\n\tindented\nline3") + + result = await editor.__wrapped__( + command="insert", + path=path, + insert_line=1, + new_str="new_line", + tool_context=tool_context, + ) + + content = await sandbox.read_text(path) + assert "\t" in content, f"Tab destroyed by insert! Content: {repr(content)}" + + +class TestEditorFloatViewRangeFixed: + """Verify Fix #6: Float view_range no longer crashes.""" + + @pytest.mark.asyncio + async def test_float_view_range_coerced(self, tool_context, sandbox, tmp_path): + """view_range with floats [1.0, 3.0] should work (coerced to ints).""" + path = f"{tmp_path}/float_range.txt" + await sandbox.write_text(path, "line1\nline2\nline3\nline4\nline5") + + result = await editor.__wrapped__( + command="view", + path=path, + view_range=[1.0, 3.0], + tool_context=tool_context, + ) + + # Should work, showing lines 1-3 + assert "line1" in result + assert "line2" in result + assert "line3" in result + assert "line4" not in result + + @pytest.mark.asyncio + async def test_invalid_view_range_type_gives_error(self, tool_context, sandbox, tmp_path): + """Non-numeric view_range should give a clear error.""" + path = f"{tmp_path}/bad_range.txt" + await sandbox.write_text(path, "line1\nline2\nline3") + + result = await editor.__wrapped__( + command="view", + path=path, + view_range=["a", "b"], + tool_context=tool_context, + ) + + assert "error" in result.lower() From f21618f50b7c577a4fec3ea4c5363d951eb78515 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:08:26 +0000 Subject: [PATCH 6/7] refactor: restore real-time streaming with marker-aware buffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the full-buffering approach (which collected all stdout chunks and yielded a single batch at the end) with a sliding-window buffer that yields stdout incrementally as it arrives. The approach: - Accumulate stdout in a buffer, yield the 'safe' prefix on each chunk - 'Safe' = bytes that cannot be part of the CWD marker (no suffix of the yielded text matches a prefix of __STRANDS_CWD__) - Only the tail matching a potential marker prefix is held back - At the end, flush the buffer with rsplit to strip the marker This preserves real-time streaming for UI consumers while still preventing any partial or full marker leakage — even when the marker is split across chunk boundaries. Tests: 79 passed (no regressions) --- src/strands/vended_tools/shell.py | 60 ++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/src/strands/vended_tools/shell.py b/src/strands/vended_tools/shell.py index 2841d8daa..a7d4efa99 100644 --- a/src/strands/vended_tools/shell.py +++ b/src/strands/vended_tools/shell.py @@ -49,6 +49,28 @@ _CWD_MARKER = "__STRANDS_CWD__" +def _safe_yield_length(buffer: str, marker: str) -> int: + """Return the number of chars from the start of buffer that are safe to yield. + + "Safe" means: no suffix of the yielded portion could be a prefix of the marker. + This prevents partial marker leakage when the marker is split across chunks. + """ + # Check if the buffer contains the full marker — if so, only yield up to it + marker_pos = buffer.find(marker) + if marker_pos != -1: + return marker_pos + + # Check if the end of the buffer matches a prefix of the marker. + # e.g., buffer ends with "__STRAN" which is a prefix of "__STRANDS_CWD__" + max_overlap = min(len(marker) - 1, len(buffer)) + for i in range(max_overlap, 0, -1): + if buffer.endswith(marker[:i]): + return len(buffer) - i + + # No overlap — everything is safe + return len(buffer) + + @tool(context=True) async def shell( command: str, @@ -113,15 +135,12 @@ async def shell( # the shell process they run in — a separate pwd call would not see them). tracked_command = f"{command}; echo {_CWD_MARKER}; pwd" - # Sliding-window streaming: buffer the last len(_CWD_MARKER) bytes of - # stdout to handle the case where the marker is split across two chunks. - # We yield everything except a trailing window that might contain the - # start of the marker. Once we see the full result, we filter cleanly. - # - # The approach: collect all stdout chunks, then yield filtered output - # from result.stdout (which has the complete, un-fragmented text). - # stderr chunks are yielded immediately (never contain the marker). - stdout_chunks: list[StreamChunk] = [] + # Streaming with marker-aware buffering: + # We accumulate stdout in a buffer and yield only the "safe" prefix — i.e., + # bytes that cannot possibly be part of the CWD marker. This preserves + # real-time streaming for all output while ensuring no partial or full + # marker ever leaks to UI consumers. stderr is always yielded immediately. + stdout_buffer = "" result: ExecutionResult | None = None try: @@ -132,11 +151,17 @@ async def shell( ): if isinstance(chunk, StreamChunk): if chunk.stream_type == "stderr": - # stderr is safe — yield immediately + # stderr never contains the marker — yield immediately yield chunk else: - # Collect stdout chunks — we'll yield filtered output after - stdout_chunks.append(chunk) + # Append to buffer, yield whatever is safely past the marker + stdout_buffer += chunk.data + safe_len = _safe_yield_length(stdout_buffer, _CWD_MARKER) + if safe_len > 0: + yield StreamChunk( + data=stdout_buffer[:safe_len], stream_type="stdout" + ) + stdout_buffer = stdout_buffer[safe_len:] elif isinstance(chunk, ExecutionResult): result = chunk except asyncio.TimeoutError: @@ -167,11 +192,12 @@ async def shell( shell_state["cwd"] = new_cwd tool_context.agent.state.set("_strands_shell_state", shell_state) - # Yield the filtered stdout as a single chunk (marker fully removed). - # This avoids partial marker leakage that occurs with chunk-by-chunk - # filtering when the marker is split across chunk boundaries. - if stdout: - yield StreamChunk(data=stdout, stream_type="stdout") + # Flush remaining buffer with marker stripped. + # The buffer holds the tail that overlapped with the marker prefix. + if stdout_buffer: + cleaned = stdout_buffer.rsplit(_CWD_MARKER, 1)[0].rstrip("\n") + if cleaned: + yield StreamChunk(data=cleaned, stream_type="stdout") # Format final output (becomes the ToolResult) output_parts = [] From 1a46e0ff2729835665513012b1b4fbd95f98af9c Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Thu, 30 Apr 2026 19:33:12 +0000 Subject: [PATCH 7/7] fix: add timeout coercion to python_repl (matching shell.py fix) Apply the same int() coercion pattern from shell.py to python_repl.py. Without this, string/float timeouts from JSON configs or LLM-generated params crash with TypeError in asyncio.wait_for(). Addresses Round 5 review feedback. --- src/strands/vended_tools/python_repl.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/strands/vended_tools/python_repl.py b/src/strands/vended_tools/python_repl.py index 69ee5e5ad..3a86b6e83 100644 --- a/src/strands/vended_tools/python_repl.py +++ b/src/strands/vended_tools/python_repl.py @@ -91,6 +91,13 @@ async def python_repl( if effective_timeout is None: effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + # Coerce timeout to int — JSON configs and LLMs may pass strings or floats + if effective_timeout is not None: + try: + effective_timeout = int(effective_timeout) + except (TypeError, ValueError): + effective_timeout = DEFAULT_TIMEOUT + # Execute via sandbox streaming result: ExecutionResult | None = None try: