Restart the PTY on unexpected shell exit; keep multiline bash out of the shared PTY.

* src/aish/shell/runtime/app.py: on a "shell_exiting" backend event the shell
  no longer stops unconditionally. It stops when _should_exit_on_pty_close()
  is true, or when _restart_pty() returns a falsy value; otherwise _running is
  left untouched.
* src/aish/tools/code_exec.py: add _is_multiline_bash(), true when the command
  contains "\n" or "\r". __call__ now takes the shared-PTY branch only when the
  command has no line break.
* tests: cover the three shell_exiting outcomes (restart, explicit exit, failed
  restart); replace the multiline shared-PTY test with one asserting that
  execute_command() is never called and that tool.executor.execute receives the
  script; relax the "hello" stdout check from equality to containment.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks, indentation and blank context lines were reconstructed from the hunk
line counts and from Python syntax. The context indentation depth in the
__call__ hunk and the position of the single blank context line in the
test_shell_pty_core.py hunk are best guesses; verify against the tree, or apply
with `git apply --ignore-whitespace`.
NOTE(review): the reason for loosening the "hello" assertion is not visible in
this patch; confirm it is intentional.

diff --git a/src/aish/shell/runtime/app.py b/src/aish/shell/runtime/app.py
index bc9b832..af8ddcc 100644
--- a/src/aish/shell/runtime/app.py
+++ b/src/aish/shell/runtime/app.py
@@ -1802,7 +1802,10 @@ def _track_backend_event(self, event: BackendControlEvent) -> None:
             self._shell_phase = "editing"
         elif event.type == "shell_exiting":
             self._shell_phase = "recovery_exit"
-            self._running = False
+            if self._should_exit_on_pty_close():
+                self._running = False
+            elif not self._restart_pty():
+                self._running = False
 
     def _handle_control_event(self) -> None:
         if not self._pty_manager or self._pty_manager.control_fd is None:
diff --git a/src/aish/tools/code_exec.py b/src/aish/tools/code_exec.py
index f618c68..07bc438 100644
--- a/src/aish/tools/code_exec.py
+++ b/src/aish/tools/code_exec.py
@@ -101,6 +101,10 @@ def _needs_interactive_bash(command: str) -> bool:
     return False
 
 
+def _is_multiline_bash(command: str) -> bool:
+    return "\n" in command or "\r" in command
+
+
 def _collapse_output_lines(text: str, max_lines: int = DISPLAY_MAX_LINES) -> str:
     lines = text.splitlines()
     if len(lines) <= max_lines:
@@ -634,7 +638,11 @@ async def __call__(self, code: str) -> ToolResult:
                 )
                 pty_rc = returncode
                 used_interactive_executor = True
-            elif self.pty_manager and self.pty_manager.is_running:
+            elif (
+                self.pty_manager
+                and self.pty_manager.is_running
+                and not _is_multiline_bash(code)
+            ):
                 # PTY execution: share user's bash session
                 pty_stdout, pty_rc = self.pty_manager.execute_command(code)
                 stdout = pty_stdout
diff --git a/tests/shell/runtime/test_shell_pty_core.py b/tests/shell/runtime/test_shell_pty_core.py
index 3f3620f..9afb960 100644
--- a/tests/shell/runtime/test_shell_pty_core.py
+++ b/tests/shell/runtime/test_shell_pty_core.py
@@ -1326,6 +1326,102 @@ def test_shell_does_not_restart_after_explicit_quit_when_flag_was_not_set(monkey
 
     assert shell._running is False
     shell._restart_pty.assert_not_called()
+
+
+def test_shell_restarts_on_unexpected_shell_exiting_event():
+    shell = object.__new__(PTYAIShell)
+    shell._pty_manager = _FakePTYManager(last_command="")
+    shell._backend_protocol_events = []
+    shell._backend_protocol_errors = []
+    shell._last_backend_event = None
+    shell._backend_session_ready = True
+    shell._shell_phase = "running_passthrough"
+    shell._next_command_seq = 1
+    shell._pending_command_seq = None
+    shell._pending_command_text = None
+    shell._running = True
+    shell._output_processor = Mock()
+    shell._user_requested_exit = False
+    shell._restart_pty = Mock(return_value=True)
+
+    PTYAIShell._track_backend_event(
+        shell,
+        BackendControlEvent(
+            version=1,
+            type="shell_exiting",
+            ts=1,
+            payload={"exit_code": 2},
+        ),
+    )
+
+    assert shell._running is True
+    assert shell._shell_phase == "recovery_exit"
+    shell._restart_pty.assert_called_once_with()
+
+
+def test_shell_exiting_event_honors_explicit_user_exit():
+    shell = object.__new__(PTYAIShell)
+    shell._pty_manager = SimpleNamespace(
+        last_command="exit",
+        handle_backend_event=Mock(return_value=None),
+    )
+    shell._backend_protocol_events = []
+    shell._backend_protocol_errors = []
+    shell._last_backend_event = None
+    shell._backend_session_ready = True
+    shell._shell_phase = "running_passthrough"
+    shell._next_command_seq = 1
+    shell._pending_command_seq = None
+    shell._pending_command_text = None
+    shell._running = True
+    shell._output_processor = Mock()
+    shell._user_requested_exit = False
+    shell._restart_pty = Mock(return_value=True)
+
+    PTYAIShell._track_backend_event(
+        shell,
+        BackendControlEvent(
+            version=1,
+            type="shell_exiting",
+            ts=1,
+            payload={"exit_code": 0},
+        ),
+    )
+
+    assert shell._running is False
+    shell._restart_pty.assert_not_called()
+
+
+def test_shell_stops_when_shell_exiting_recovery_fails():
+    shell = object.__new__(PTYAIShell)
+    shell._pty_manager = _FakePTYManager(last_command="")
+    shell._backend_protocol_events = []
+    shell._backend_protocol_errors = []
+    shell._last_backend_event = None
+    shell._backend_session_ready = True
+    shell._shell_phase = "running_passthrough"
+    shell._next_command_seq = 1
+    shell._pending_command_seq = None
+    shell._pending_command_text = None
+    shell._running = True
+    shell._output_processor = Mock()
+    shell._user_requested_exit = False
+    shell._restart_pty = Mock(return_value=False)
+
+    PTYAIShell._track_backend_event(
+        shell,
+        BackendControlEvent(
+            version=1,
+            type="shell_exiting",
+            ts=1,
+            payload={"exit_code": 2},
+        ),
+    )
+
+    assert shell._running is False
+    shell._restart_pty.assert_called_once_with()
+
+
 def test_backend_error_suppressed_prevents_repeated_hints(capsys):
     pty_manager = _FakePTYManager()
     processor = OutputProcessor(pty_manager)
diff --git a/tests/tools/test_bash_output_offload.py b/tests/tools/test_bash_output_offload.py
index 3a2ea43..2c24698 100644
--- a/tests/tools/test_bash_output_offload.py
+++ b/tests/tools/test_bash_output_offload.py
@@ -182,7 +182,7 @@ async def test_bash_exec_uses_shared_pty_without_metadata_leak(tmp_path: Path):
             result = await tool("printf 'hello\\n'")
 
         assert result.ok is True
-        assert _extract_tag(result.output, "stdout") == "hello"
+        assert "hello" in _extract_tag(result.output, "stdout")
         assert "__AISH_ACTIVE_COMMAND_SEQ" not in result.output
         assert "__AISH_ACTIVE_COMMAND_TEXT" not in result.output
     finally:
@@ -224,10 +224,15 @@ def execute_command(self, _code: str):
 
 
 @pytest.mark.asyncio
-async def test_bash_exec_uses_shared_pty_for_multiline_command_without_echo(tmp_path: Path):
-    manager = PTYManager(use_output_thread=False, env={"HISTFILE": str(tmp_path / "bash_history")})
+async def test_bash_exec_bypasses_shared_pty_for_multiline_scripts():
+    class _FakePTYManager:
+        is_running = True
+
+        def execute_command(self, _code: str):
+            raise AssertionError("multiline scripts should not use shared PTY")
+
     tool = BashTool(
-        pty_manager=manager,
+        pty_manager=_FakePTYManager(),
         offload_settings=BashOutputOffloadSettings(
             enabled=True,
             threshold_bytes=1024,
@@ -235,17 +240,23 @@ async def test_bash_exec_uses_shared_pty_for_multiline_command_without_echo(tmp_
         ),
     )
 
-    manager.start()
-    try:
-        with patch.object(tool.security_manager, "decide", return_value=_allow_decision()):
-            result = await tool("printf 'hello\\n' && \\\nprintf 'world\\n'")
+    script = 'set -euo pipefail\nprintf "--- bad\\n"'
+    with (
+        patch.object(tool.security_manager, "decide", return_value=_allow_decision()),
+        patch.object(
+            tool.executor,
+            "execute",
+            return_value=(False, "", "boom\n", 2, {}),
+        ) as execute_mock,
+        patch("builtins.print"),
+    ):
+        result = await tool(script)
 
-        assert result.ok is True
-        assert _extract_tag(result.output, "stdout") == "hello\nworld"
-        assert "__AISH_ACTIVE_COMMAND_SEQ" not in result.output
-        assert "__AISH_ACTIVE_COMMAND_TEXT" not in result.output
-    finally:
-        manager.stop()
+    assert result.ok is False
+    assert result.code == 2
+    execute_mock.assert_called_once()
+    assert execute_mock.call_args.args[0] == script
+    assert "return_code" in result.output
 
 
 @pytest.mark.asyncio