diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py
index 9d5362e9..a7c3311d 100644
--- a/src/bub/builtin/agent.py
+++ b/src/bub/builtin/agent.py
@@ -24,6 +24,7 @@
StreamEvent,
StreamState,
TapeContext,
+ Tool,
ToolAutoResult,
ToolContext,
)
@@ -34,7 +35,7 @@
from bub.builtin.tape import TapeService
from bub.framework import BubFramework
from bub.skills import discover_skills, render_skills_prompt
-from bub.tools import REGISTRY, model_tools, render_tools_prompt
+from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names
from bub.types import State
from bub.utils import workspace_from_state
@@ -530,12 +531,12 @@ async def _run_once(
) -> AsyncStreamEvents | ToolAutoResult:
prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt)
if allowed_tools is not None:
- allowed_tools = {name.casefold() for name in allowed_tools}
+ allowed_tools = resolve_tool_names(allowed_tools)
if allowed_skills is not None:
allowed_skills = {name.casefold() for name in allowed_skills}
tape.context.state["allowed_skills"] = list(allowed_skills)
if allowed_tools is not None:
- tools = [tool for tool in REGISTRY.values() if tool.name.casefold() in allowed_tools]
+ tools = [tool for tool in REGISTRY.values() if tool.name in allowed_tools]
else:
tools = list(REGISTRY.values())
async with asyncio.timeout(self.settings.model_timeout_seconds):
@@ -543,7 +544,7 @@ async def _run_once(
return await tape.stream_events_async(
prompt=prompt,
system_prompt=self._system_prompt(
- prompt_text, state=tape.context.state, allowed_skills=allowed_skills
+ prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
@@ -553,18 +554,20 @@ async def _run_once(
return await tape.run_tools_async(
prompt=prompt,
system_prompt=self._system_prompt(
- prompt_text, state=tape.context.state, allowed_skills=allowed_skills
+ prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
model=model,
)
- def _system_prompt(self, prompt: str, state: State, allowed_skills: set[str] | None = None) -> str:
+ def _system_prompt(
+ self, prompt: str, state: State, allowed_skills: set[str] | None = None, tools: Iterable[Tool] | None = None
+ ) -> str:
blocks: list[str] = []
if result := self.framework.get_system_prompt(prompt=prompt, state=state):
blocks.append(result)
- tools_prompt = render_tools_prompt(REGISTRY.values())
+ tools_prompt = render_tools_prompt(tools if tools is not None else REGISTRY.values())
if tools_prompt:
blocks.append(tools_prompt)
workspace = workspace_from_state(state)
diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py
index 9a3064f3..84881974 100644
--- a/src/bub/channels/cli/__init__.py
+++ b/src/bub/channels/cli/__init__.py
@@ -143,10 +143,20 @@ async def stream_events(
) -> AsyncIterable[StreamEvent]:
live: Live | None = None
text = ""
+
+ def finish_live() -> None:
+ nonlocal live
+ if live is not None:
+ self._renderer.finish_stream(live, kind=message.kind, text=text)
+ live = None  # cleared so a later call (e.g. from the finally block) does nothing
+
try:
async for event in stream:
if event.kind == "text":
content = str(event.data.get("delta", ""))
+ if not content:
+ yield event
+ continue
if not content.strip() and not text:
continue # skip leading whitespace-only events
text += content
@@ -154,10 +164,11 @@ async def stream_events(
live = self._renderer.start_stream(message.kind, text)
else:
self._renderer.update_stream(live, kind=message.kind, text=text)
+ elif event.kind == "final":
+ finish_live()
yield event
finally:
- if live is not None:
- self._renderer.finish_stream(live, kind=message.kind, text=text)
+ finish_live()
def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py
index 8f4581f9..2ec5b836 100644
--- a/src/bub/channels/cli/renderer.py
+++ b/src/bub/channels/cli/renderer.py
@@ -63,7 +63,7 @@ def start_stream(self, kind: MessageKind, text: str) -> Live:
console=self.console,
auto_refresh=False,
transient=False,
- vertical_overflow="visible",
+ vertical_overflow="ellipsis",
)
live.start(refresh=True)
return live
@@ -72,7 +72,7 @@ def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:
live.update(self.panel(kind, text), refresh=True)
def finish_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:
- live.update(self.panel(kind, text), refresh=True)
+ live.update(self.panel(kind, text), refresh=False)  # presumably stop() draws the final frame; TODO confirm
live.stop()
@staticmethod
diff --git a/src/bub/tools.py b/src/bub/tools.py
index 046cc4b1..7cdbd799 100644
--- a/src/bub/tools.py
+++ b/src/bub/tools.py
@@ -186,13 +186,24 @@ def model_tools(tools: Iterable[Tool]) -> list[Tool]:
return [replace(tool, name=_to_model_name(tool.name)) for tool in tools]
+def _tool_signature(tool: Tool) -> str:
+    """Render ``name(a, b?)`` from the tool's parameter schema; ``?`` marks names not in ``required``."""
+    props = tool.parameters.get("properties", {})
+    if isinstance(props, dict) and props:
+        declared = tool.parameters.get("required", [])
+        mandatory = set(declared) if isinstance(declared, list) else set()
+        args = ", ".join(key if key in mandatory else f"{key}?" for key in props)
+        return f"{_to_model_name(tool.name)}({args})"
+    return f"{_to_model_name(tool.name)}()"
+
+
def render_tools_prompt(tools: Iterable[Tool]) -> str:
"""Render a human-readable description of tools for model prompts."""
if not tools:
return ""
lines = []
for tool in tools:
- line = f"- {_to_model_name(tool.name)}"
+ line = f"- {_tool_signature(tool)}"
if tool.description:
line += f": {tool.description}"
lines.append(line)
diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py
index 4adeb169..2ad86f76 100644
--- a/tests/test_builtin_agent.py
+++ b/tests/test_builtin_agent.py
@@ -12,6 +12,7 @@
import bub.builtin.agent as agent_module
from bub.builtin.agent import Agent
from bub.builtin.settings import AgentSettings
+from bub.tools import REGISTRY, tool
def test_build_llm_passes_codex_resolver_to_republic(monkeypatch) -> None:
@@ -84,6 +85,7 @@ class _FakeTapeService:
def __init__(self, fork_capture: _ForkCapture) -> None:
self._fork = fork_capture
self.run_tools_model: str | None = None
+ self.stream_kwargs: dict[str, Any] | None = None  # kwargs of the last fake stream_events_async call
def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
tape = MagicMock()
@@ -92,6 +94,7 @@ def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
async def fake_stream_events_async(**kwargs: Any) -> AsyncStreamEvents:
self.run_tools_model = kwargs.get("model")
+ self.stream_kwargs = kwargs
async def iterator():
yield StreamEvent("final", {"text": "done"})
@@ -184,3 +187,56 @@ async def test_agent_run_model_defaults_to_none() -> None:
[event async for event in result]
assert fake_tapes.run_tools_model is None
+
+
+@pytest.mark.asyncio
+async def test_agent_run_resolves_allowed_tool_aliases_and_limits_prompt() -> None:
+ allowed_name = "tests.allowed_agent_tool"
+ denied_name = "tests.denied_agent_tool"
+ REGISTRY.pop(allowed_name, None)
+ REGISTRY.pop(denied_name, None)
+ # Two decorated tools are defined; only the first is named in allowed_tools further down.
+ @tool(name=allowed_name, description="Allowed tool")
+ def allowed_agent_tool() -> str:
+ return "allowed"
+
+ @tool(name=denied_name, description="Denied tool")
+ def denied_agent_tool() -> str:
+ return "denied"
+ # The fake tape service records the kwargs handed to stream_events_async.
+ agent = _make_agent()
+ fork_capture = _ForkCapture()
+ fake_tapes = _FakeTapeService(fork_capture)
+ agent.tapes = fake_tapes # type: ignore[assignment]
+ # The allow-list entry is the underscore form of the dotted name, padded with spaces.
+ result = await agent.run_stream(
+ session_id="user/s1",
+ prompt="hello",
+ state={"_runtime_workspace": "/tmp"}, # noqa: S108
+ allowed_tools=[" tests_allowed_agent_tool "],
+ )
+ [event async for event in result]
+ # Only the allowed tool is forwarded, and the system prompt lists it without mentioning the denied one.
+ assert fake_tapes.stream_kwargs is not None
+ assert [tool.name for tool in fake_tapes.stream_kwargs["tools"]] == ["tests_allowed_agent_tool"]
+ system_prompt = fake_tapes.stream_kwargs["system_prompt"]
+ assert "- tests_allowed_agent_tool(): Allowed tool" in system_prompt
+ assert "tests_denied_agent_tool" not in system_prompt
+
+
+@pytest.mark.asyncio
+async def test_agent_run_rejects_unknown_allowed_tools() -> None:
+ agent = _make_agent()
+ fork_capture = _ForkCapture()
+ fake_tapes = _FakeTapeService(fork_capture)
+ agent.tapes = fake_tapes # type: ignore[assignment]
+ # This await sits outside pytest.raises, so it is expected to return a stream without raising.
+ stream = await agent.run_stream(
+ session_id="user/s1",
+ prompt="hello",
+ state={"_runtime_workspace": "/tmp"}, # noqa: S108
+ allowed_tools=["tests_missing_agent_tool"],
+ )
+ # The ValueError is expected while consuming the stream, and its message must contain the unknown name.
+ with pytest.raises(ValueError, match="tests_missing_agent_tool"):
+ [event async for event in stream]
diff --git a/tests/test_channels.py b/tests/test_channels.py
index 7fd2ab02..9790318a 100644
--- a/tests/test_channels.py
+++ b/tests/test_channels.py
@@ -480,6 +480,61 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:
assert [event.kind for event in yielded] == ["text", "text", "final"]
+@pytest.mark.asyncio
+async def test_cli_channel_stream_events_skips_empty_delta_rerenders() -> None:
+ channel = CliChannel.__new__(CliChannel)
+ events: list[tuple[str, str, str]] = []
+ live_handle = object()  # opaque stand-in returned by the start_stream stub
+ channel._renderer = SimpleNamespace(
+ start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
+ update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
+ finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
+ )
+
+ message = _message("ignored", channel="cli", kind="normal", session_id="cli:1")
+ # Source stream: one non-empty text delta, one empty delta, then the final event.
+ async def source() -> asyncio.AsyncIterator[StreamEvent]:
+ yield StreamEvent("text", {"delta": "hello"})
+ yield StreamEvent("text", {"delta": ""})
+ yield StreamEvent("final", {})
+
+ yielded = [event async for event in channel.stream_events(message, source())]
+ # The empty delta must not produce an "update" entry, yet all three events are still yielded.
+ assert events == [
+ ("start", "normal", "hello"),
+ ("finish", "normal", "hello"),
+ ]
+ assert [event.kind for event in yielded] == ["text", "text", "final"]
+
+
+@pytest.mark.asyncio
+async def test_cli_channel_finishes_live_before_stream_cleanup_side_effects() -> None:
+ channel = CliChannel.__new__(CliChannel)
+ events: list[tuple[str, str, str]] = []
+ live_handle = object()  # opaque stand-in returned by the start_stream stub
+ channel._renderer = SimpleNamespace(
+ start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
+ update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
+ finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
+ )
+
+ message = _message("ignored", channel="cli", kind="normal", session_id="cli:1")
+ # The append after the last yield stands in for work the source does after emitting "final".
+ async def source() -> asyncio.AsyncIterator[StreamEvent]:
+ yield StreamEvent("text", {"delta": "hello"})
+ yield StreamEvent("final", {})
+ events.append(("cleanup", "source", "after-final"))
+
+ yielded = [event async for event in channel.stream_events(message, source())]
+ # The "finish" entry must be recorded before the source's post-final "cleanup" entry.
+ assert events == [
+ ("start", "normal", "hello"),
+ ("finish", "normal", "hello"),
+ ("cleanup", "source", "after-final"),
+ ]
+ assert [event.kind for event in yielded] == ["text", "final"]
+
+
def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
home = tmp_path / "home"
workspace = tmp_path / "workspace"
@@ -497,6 +552,7 @@ class FakeLive:
def __init__(self, renderable, **kwargs) -> None:
live_calls.append(("init", renderable))
live_calls.append(("transient", kwargs["transient"]))
+ live_calls.append(("vertical_overflow", kwargs["vertical_overflow"]))
self.renderable = renderable
def start(self, *, refresh: bool = False) -> None:
@@ -519,8 +575,12 @@ def stop(self) -> None:
renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type]
assert ("transient", False) in live_calls
+ assert ("vertical_overflow", "ellipsis") in live_calls
assert ("start_refresh", True) in live_calls
- assert ("update_refresh", True) in live_calls
+ assert [call for call in live_calls if call[0] == "update_refresh"] == [
+ ("update_refresh", True),
+ ("update_refresh", False),
+ ]
assert not printed
diff --git a/tests/test_tools.py b/tests/test_tools.py
index e30588fa..d349c8bf 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -93,13 +93,15 @@ def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None:
REGISTRY.pop(tool_name, None)
@tool(name=tool_name, description="rename")
- def rename_me() -> str:
+ def rename_me(value: str) -> str:
return "ok"
rewritten = model_tools([rename_me])
assert [item.name for item in rewritten] == ["tests_rename_me"]
+ assert rewritten[0].parameters == rename_me.parameters
assert rename_me.name == tool_name
+ assert "additionalProperties" not in rename_me.parameters
def test_render_tools_prompt_renders_available_tools_block() -> None:
@@ -118,7 +120,20 @@ def prompt_two() -> str:
rendered = render_tools_prompt([prompt_one, prompt_two])
- assert rendered == "\n- tests_prompt_one: First tool\n- tests_prompt_two\n"
+ assert rendered == "\n- tests_prompt_one(): First tool\n- tests_prompt_two()\n"
+
+
+def test_render_tools_prompt_includes_model_name_and_parameter_signature() -> None:
+ tool_name = "tests.prompt_signature"
+ REGISTRY.pop(tool_name, None)
+ # path has no default; offset defaults to 0.
+ @tool(name=tool_name, description="Read a file")
+ def prompt_signature(path: str, offset: int = 0) -> str:
+ return f"{path}:{offset}"
+
+ rendered = render_tools_prompt([prompt_signature])
+ # The dotted name is rendered with an underscore, and the defaulted parameter carries a "?" suffix.
+ assert rendered == "\n- tests_prompt_signature(path, offset?): Read a file\n"
def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
@@ -128,8 +143,10 @@ def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None:
dotted_name = "tests.resolve_alias"
underscored_name = "tests_with_underscore"
+ excluded_name = "tests.excluded_tool"
REGISTRY.pop(dotted_name, None)
REGISTRY.pop(underscored_name, None)
+ REGISTRY.pop(excluded_name, None)
@tool(name=dotted_name)
def resolve_alias() -> str:
@@ -139,11 +156,16 @@ def resolve_alias() -> str:
def resolve_runtime_name() -> str:
return "runtime"
- assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == {
+ @tool(name=excluded_name)
+ def excluded_tool() -> str:
+ return "excluded"
+
+ assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" tests_excluded_tool "}) == {
dotted_name,
underscored_name,
}
assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "})
+ assert excluded_name not in resolve_tool_names(None, exclude={" tests_excluded_tool "})
assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name}