diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 9d5362e9..a7c3311d 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -24,6 +24,7 @@ StreamEvent, StreamState, TapeContext, + Tool, ToolAutoResult, ToolContext, ) @@ -34,7 +35,7 @@ from bub.builtin.tape import TapeService from bub.framework import BubFramework from bub.skills import discover_skills, render_skills_prompt -from bub.tools import REGISTRY, model_tools, render_tools_prompt +from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names from bub.types import State from bub.utils import workspace_from_state @@ -530,12 +531,12 @@ async def _run_once( ) -> AsyncStreamEvents | ToolAutoResult: prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt) if allowed_tools is not None: - allowed_tools = {name.casefold() for name in allowed_tools} + allowed_tools = resolve_tool_names(allowed_tools) if allowed_skills is not None: allowed_skills = {name.casefold() for name in allowed_skills} tape.context.state["allowed_skills"] = list(allowed_skills) if allowed_tools is not None: - tools = [tool for tool in REGISTRY.values() if tool.name.casefold() in allowed_tools] + tools = [tool for tool in REGISTRY.values() if tool.name in allowed_tools] else: tools = list(REGISTRY.values()) async with asyncio.timeout(self.settings.model_timeout_seconds): @@ -543,7 +544,7 @@ async def _run_once( return await tape.stream_events_async( prompt=prompt, system_prompt=self._system_prompt( - prompt_text, state=tape.context.state, allowed_skills=allowed_skills + prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools ), max_tokens=self.settings.max_tokens, tools=model_tools(tools), @@ -553,18 +554,20 @@ async def _run_once( return await tape.run_tools_async( prompt=prompt, system_prompt=self._system_prompt( - prompt_text, state=tape.context.state, allowed_skills=allowed_skills + prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools ), max_tokens=self.settings.max_tokens, tools=model_tools(tools), model=model, ) - def _system_prompt(self, prompt: str, state: State, allowed_skills: set[str] | None = None) -> str: + def _system_prompt( + self, prompt: str, state: State, allowed_skills: set[str] | None = None, tools: Iterable[Tool] | None = None + ) -> str: blocks: list[str] = [] if result := self.framework.get_system_prompt(prompt=prompt, state=state): blocks.append(result) - tools_prompt = render_tools_prompt(REGISTRY.values()) + tools_prompt = render_tools_prompt(tools if tools is not None else REGISTRY.values()) if tools_prompt: blocks.append(tools_prompt) workspace = workspace_from_state(state) diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 9a3064f3..84881974 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -143,10 +143,20 @@ async def stream_events( ) -> AsyncIterable[StreamEvent]: live: Live | None = None text = "" + + def finish_live() -> None: + nonlocal live + if live is not None: + self._renderer.finish_stream(live, kind=message.kind, text=text) + live = None + try: async for event in stream: if event.kind == "text": content = str(event.data.get("delta", "")) + if not content: + yield event + continue if not content.strip() and not text: continue # skip leading whitespace-only events text += content @@ -154,10 +164,11 @@ async def stream_events( live = self._renderer.start_stream(message.kind, text) else: self._renderer.update_stream(live, kind=message.kind, text=text) + elif event.kind == "final": + finish_live() yield event finally: - if live is not None: - self._renderer.finish_stream(live, kind=message.kind, text=text) + finish_live() def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py index 8f4581f9..2ec5b836 100644 --- a/src/bub/channels/cli/renderer.py +++ b/src/bub/channels/cli/renderer.py @@ -63,7 +63,7 @@ def start_stream(self, kind: MessageKind, text: str) -> Live: console=self.console, auto_refresh=False, transient=False, - vertical_overflow="visible", + vertical_overflow="ellipsis", ) live.start(refresh=True) return live @@ -72,7 +72,7 @@ def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None: live.update(self.panel(kind, text), refresh=True) def finish_stream(self, live: Live, *, kind: MessageKind, text: str) -> None: - live.update(self.panel(kind, text), refresh=True) + live.update(self.panel(kind, text), refresh=False) live.stop() @staticmethod diff --git a/src/bub/tools.py b/src/bub/tools.py index 046cc4b1..7cdbd799 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -186,13 +186,24 @@ def model_tools(tools: Iterable[Tool]) -> list[Tool]: return [replace(tool, name=_to_model_name(tool.name)) for tool in tools] +def _tool_signature(tool: Tool) -> str: + properties = tool.parameters.get("properties", {}) + if not isinstance(properties, dict) or not properties: + return f"{_to_model_name(tool.name)}()" + + required = tool.parameters.get("required", []) + required_names = set(required) if isinstance(required, list) else set() + params = [name if name in required_names else f"{name}?" for name in properties] + return f"{_to_model_name(tool.name)}({', '.join(params)})" + + def render_tools_prompt(tools: Iterable[Tool]) -> str: """Render a human-readable description of tools for model prompts.""" if not tools: return "" lines = [] for tool in tools: - line = f"- {_to_model_name(tool.name)}" + line = f"- {_tool_signature(tool)}" if tool.description: line += f": {tool.description}" lines.append(line) diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index 4adeb169..2ad86f76 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -12,6 +12,7 @@ import bub.builtin.agent as agent_module from bub.builtin.agent import Agent from bub.builtin.settings import AgentSettings +from bub.tools import REGISTRY, tool def test_build_llm_passes_codex_resolver_to_republic(monkeypatch) -> None: @@ -84,6 +85,7 @@ class _FakeTapeService: def __init__(self, fork_capture: _ForkCapture) -> None: self._fork = fork_capture self.run_tools_model: str | None = None + self.stream_kwargs: dict[str, Any] | None = None def session_tape(self, session_id: str, workspace: Any) -> MagicMock: tape = MagicMock() @@ -92,6 +94,7 @@ def session_tape(self, session_id: str, workspace: Any) -> MagicMock: async def fake_stream_events_async(**kwargs: Any) -> AsyncStreamEvents: self.run_tools_model = kwargs.get("model") + self.stream_kwargs = kwargs async def iterator(): yield StreamEvent("final", {"text": "done"}) @@ -184,3 +187,56 @@ async def test_agent_run_model_defaults_to_none() -> None: [event async for event in result] assert fake_tapes.run_tools_model is None + + +@pytest.mark.asyncio +async def test_agent_run_resolves_allowed_tool_aliases_and_limits_prompt() -> None: + allowed_name = "tests.allowed_agent_tool" + denied_name = "tests.denied_agent_tool" + REGISTRY.pop(allowed_name, None) + REGISTRY.pop(denied_name, None) + + @tool(name=allowed_name, description="Allowed tool") + def allowed_agent_tool() -> str: + return "allowed" + + @tool(name=denied_name, description="Denied tool") + def denied_agent_tool() -> str: + return "denied" + + agent = _make_agent() + fork_capture = _ForkCapture() + fake_tapes = _FakeTapeService(fork_capture) + agent.tapes = fake_tapes # type: ignore[assignment] + + result = await agent.run_stream( + session_id="user/s1", + prompt="hello", + state={"_runtime_workspace": "/tmp"}, # noqa: S108 + allowed_tools=[" tests_allowed_agent_tool "], + ) + [event async for event in result] + + assert fake_tapes.stream_kwargs is not None + assert [tool.name for tool in fake_tapes.stream_kwargs["tools"]] == ["tests_allowed_agent_tool"] + system_prompt = fake_tapes.stream_kwargs["system_prompt"] + assert "- tests_allowed_agent_tool(): Allowed tool" in system_prompt + assert "tests_denied_agent_tool" not in system_prompt + + +@pytest.mark.asyncio +async def test_agent_run_rejects_unknown_allowed_tools() -> None: + agent = _make_agent() + fork_capture = _ForkCapture() + fake_tapes = _FakeTapeService(fork_capture) + agent.tapes = fake_tapes # type: ignore[assignment] + + stream = await agent.run_stream( + session_id="user/s1", + prompt="hello", + state={"_runtime_workspace": "/tmp"}, # noqa: S108 + allowed_tools=["tests_missing_agent_tool"], + ) + + with pytest.raises(ValueError, match="tests_missing_agent_tool"): + [event async for event in stream] diff --git a/tests/test_channels.py b/tests/test_channels.py index 7fd2ab02..9790318a 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -480,6 +480,61 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]: assert [event.kind for event in yielded] == ["text", "text", "final"] +@pytest.mark.asyncio +async def test_cli_channel_stream_events_skips_empty_delta_rerenders() -> None: + channel = CliChannel.__new__(CliChannel) + events: list[tuple[str, str, str]] = [] + live_handle = object() + channel._renderer = SimpleNamespace( + start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle, + update_stream=lambda live, *, kind, text: events.append(("update", kind, text)), + finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)), + ) + + message = _message("ignored", channel="cli", kind="normal", session_id="cli:1") + + async def source() -> asyncio.AsyncIterator[StreamEvent]: + yield StreamEvent("text", {"delta": "hello"}) + yield StreamEvent("text", {"delta": ""}) + yield StreamEvent("final", {}) + + yielded = [event async for event in channel.stream_events(message, source())] + + assert events == [ + ("start", "normal", "hello"), + ("finish", "normal", "hello"), + ] + assert [event.kind for event in yielded] == ["text", "text", "final"] + + +@pytest.mark.asyncio +async def test_cli_channel_finishes_live_before_stream_cleanup_side_effects() -> None: + channel = CliChannel.__new__(CliChannel) + events: list[tuple[str, str, str]] = [] + live_handle = object() + channel._renderer = SimpleNamespace( + start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle, + update_stream=lambda live, *, kind, text: events.append(("update", kind, text)), + finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)), + ) + + message = _message("ignored", channel="cli", kind="normal", session_id="cli:1") + + async def source() -> asyncio.AsyncIterator[StreamEvent]: + yield StreamEvent("text", {"delta": "hello"}) + yield StreamEvent("final", {}) + events.append(("cleanup", "source", "after-final")) + + yielded = [event async for event in channel.stream_events(message, source())] + + assert events == [ + ("start", "normal", "hello"), + ("finish", "normal", "hello"), + ("cleanup", "source", "after-final"), + ] + assert [event.kind for event in yielded] == ["text", "final"] + + def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: home = tmp_path / "home" workspace = tmp_path / "workspace" @@ -497,6 +552,7 @@ class FakeLive: def __init__(self, renderable, **kwargs) -> None: live_calls.append(("init", renderable)) live_calls.append(("transient", kwargs["transient"])) + live_calls.append(("vertical_overflow", kwargs["vertical_overflow"])) self.renderable = renderable def start(self, *, refresh: bool = False) -> None: @@ -519,8 +575,12 @@ def stop(self) -> None: renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type] assert ("transient", False) in live_calls + assert ("vertical_overflow", "ellipsis") in live_calls assert ("start_refresh", True) in live_calls - assert ("update_refresh", True) in live_calls + assert [call for call in live_calls if call[0] == "update_refresh"] == [ + ("update_refresh", True), + ("update_refresh", False), + ] assert not printed diff --git a/tests/test_tools.py b/tests/test_tools.py index e30588fa..d349c8bf 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -93,13 +93,15 @@ def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None: REGISTRY.pop(tool_name, None) @tool(name=tool_name, description="rename") - def rename_me() -> str: + def rename_me(value: str) -> str: return "ok" rewritten = model_tools([rename_me]) assert [item.name for item in rewritten] == ["tests_rename_me"] + assert rewritten[0].parameters == rename_me.parameters assert rename_me.name == tool_name + assert "additionalProperties" not in rename_me.parameters def test_render_tools_prompt_renders_available_tools_block() -> None: @@ -118,7 +120,20 @@ def prompt_two() -> str: rendered = render_tools_prompt([prompt_one, prompt_two]) - assert rendered == "\n- tests_prompt_one: First tool\n- tests_prompt_two\n" + assert rendered == "\n- tests_prompt_one(): First tool\n- tests_prompt_two()\n" + + +def test_render_tools_prompt_includes_model_name_and_parameter_signature() -> None: + tool_name = "tests.prompt_signature" + REGISTRY.pop(tool_name, None) + + @tool(name=tool_name, description="Read a file") + def prompt_signature(path: str, offset: int = 0) -> str: + return f"{path}:{offset}" + + rendered = render_tools_prompt([prompt_signature]) + + assert rendered == "\n- tests_prompt_signature(path, offset?): Read a file\n" def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None: @@ -128,8 +143,10 @@ def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None: def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None: dotted_name = "tests.resolve_alias" underscored_name = "tests_with_underscore" + excluded_name = "tests.excluded_tool" REGISTRY.pop(dotted_name, None) REGISTRY.pop(underscored_name, None) + REGISTRY.pop(excluded_name, None) @tool(name=dotted_name) def resolve_alias() -> str: @@ -139,11 +156,16 @@ def resolve_alias() -> str: def resolve_runtime_name() -> str: return "runtime" - assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == { + @tool(name=excluded_name) + def excluded_tool() -> str: + return "excluded" + + assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" tests_excluded_tool "}) == { dotted_name, underscored_name, } assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "}) + assert excluded_name not in resolve_tool_names(None, exclude={" tests_excluded_tool "}) assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name}