Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
StreamEvent,
StreamState,
TapeContext,
Tool,
ToolAutoResult,
ToolContext,
)
Expand All @@ -34,7 +35,7 @@
from bub.builtin.tape import TapeService
from bub.framework import BubFramework
from bub.skills import discover_skills, render_skills_prompt
from bub.tools import REGISTRY, model_tools, render_tools_prompt
from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names
from bub.types import State
from bub.utils import workspace_from_state

Expand Down Expand Up @@ -530,20 +531,20 @@ async def _run_once(
) -> AsyncStreamEvents | ToolAutoResult:
prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt)
if allowed_tools is not None:
allowed_tools = {name.casefold() for name in allowed_tools}
allowed_tools = resolve_tool_names(allowed_tools)
if allowed_skills is not None:
allowed_skills = {name.casefold() for name in allowed_skills}
tape.context.state["allowed_skills"] = list(allowed_skills)
if allowed_tools is not None:
tools = [tool for tool in REGISTRY.values() if tool.name.casefold() in allowed_tools]
tools = [tool for tool in REGISTRY.values() if tool.name in allowed_tools]
else:
tools = list(REGISTRY.values())
async with asyncio.timeout(self.settings.model_timeout_seconds):
if stream_output:
return await tape.stream_events_async(
prompt=prompt,
system_prompt=self._system_prompt(
prompt_text, state=tape.context.state, allowed_skills=allowed_skills
prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
Expand All @@ -553,18 +554,20 @@ async def _run_once(
return await tape.run_tools_async(
prompt=prompt,
system_prompt=self._system_prompt(
prompt_text, state=tape.context.state, allowed_skills=allowed_skills
prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
model=model,
)

def _system_prompt(self, prompt: str, state: State, allowed_skills: set[str] | None = None) -> str:
def _system_prompt(
self, prompt: str, state: State, allowed_skills: set[str] | None = None, tools: Iterable[Tool] | None = None
) -> str:
blocks: list[str] = []
if result := self.framework.get_system_prompt(prompt=prompt, state=state):
blocks.append(result)
tools_prompt = render_tools_prompt(REGISTRY.values())
tools_prompt = render_tools_prompt(tools if tools is not None else REGISTRY.values())
if tools_prompt:
blocks.append(tools_prompt)
workspace = workspace_from_state(state)
Expand Down
15 changes: 13 additions & 2 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,32 @@ async def stream_events(
) -> AsyncIterable[StreamEvent]:
live: Live | None = None
text = ""

def finish_live() -> None:
nonlocal live
if live is not None:
self._renderer.finish_stream(live, kind=message.kind, text=text)
live = None

try:
async for event in stream:
if event.kind == "text":
content = str(event.data.get("delta", ""))
if not content:
yield event
continue
if not content.strip() and not text:
continue # skip leading whitespace-only events
text += content
if live is None:
live = self._renderer.start_stream(message.kind, text)
else:
self._renderer.update_stream(live, kind=message.kind, text=text)
elif event.kind == "final":
finish_live()
yield event
finally:
if live is not None:
self._renderer.finish_stream(live, kind=message.kind, text=text)
finish_live()

def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
Expand Down
4 changes: 2 additions & 2 deletions src/bub/channels/cli/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def start_stream(self, kind: MessageKind, text: str) -> Live:
console=self.console,
auto_refresh=False,
transient=False,
vertical_overflow="visible",
vertical_overflow="ellipsis",
)
live.start(refresh=True)
return live
Expand All @@ -72,7 +72,7 @@ def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:
live.update(self.panel(kind, text), refresh=True)

def finish_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:
live.update(self.panel(kind, text), refresh=True)
live.update(self.panel(kind, text), refresh=False)
live.stop()

@staticmethod
Expand Down
13 changes: 12 additions & 1 deletion src/bub/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,24 @@ def model_tools(tools: Iterable[Tool]) -> list[Tool]:
return [replace(tool, name=_to_model_name(tool.name)) for tool in tools]


def _tool_signature(tool: Tool) -> str:
properties = tool.parameters.get("properties", {})
if not isinstance(properties, dict) or not properties:
return f"{_to_model_name(tool.name)}()"

required = tool.parameters.get("required", [])
required_names = set(required) if isinstance(required, list) else set()
params = [name if name in required_names else f"{name}?" for name in properties]
return f"{_to_model_name(tool.name)}({', '.join(params)})"


def render_tools_prompt(tools: Iterable[Tool]) -> str:
"""Render a human-readable description of tools for model prompts."""
if not tools:
return ""
lines = []
for tool in tools:
line = f"- {_to_model_name(tool.name)}"
line = f"- {_tool_signature(tool)}"
if tool.description:
line += f": {tool.description}"
lines.append(line)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_builtin_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import bub.builtin.agent as agent_module
from bub.builtin.agent import Agent
from bub.builtin.settings import AgentSettings
from bub.tools import REGISTRY, tool


def test_build_llm_passes_codex_resolver_to_republic(monkeypatch) -> None:
Expand Down Expand Up @@ -84,6 +85,7 @@ class _FakeTapeService:
def __init__(self, fork_capture: _ForkCapture) -> None:
self._fork = fork_capture
self.run_tools_model: str | None = None
self.stream_kwargs: dict[str, Any] | None = None

def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
tape = MagicMock()
Expand All @@ -92,6 +94,7 @@ def session_tape(self, session_id: str, workspace: Any) -> MagicMock:

async def fake_stream_events_async(**kwargs: Any) -> AsyncStreamEvents:
self.run_tools_model = kwargs.get("model")
self.stream_kwargs = kwargs

async def iterator():
yield StreamEvent("final", {"text": "done"})
Expand Down Expand Up @@ -184,3 +187,56 @@ async def test_agent_run_model_defaults_to_none() -> None:
[event async for event in result]

assert fake_tapes.run_tools_model is None


@pytest.mark.asyncio
async def test_agent_run_resolves_allowed_tool_aliases_and_limits_prompt() -> None:
allowed_name = "tests.allowed_agent_tool"
denied_name = "tests.denied_agent_tool"
REGISTRY.pop(allowed_name, None)
REGISTRY.pop(denied_name, None)

@tool(name=allowed_name, description="Allowed tool")
def allowed_agent_tool() -> str:
return "allowed"

@tool(name=denied_name, description="Denied tool")
def denied_agent_tool() -> str:
return "denied"

agent = _make_agent()
fork_capture = _ForkCapture()
fake_tapes = _FakeTapeService(fork_capture)
agent.tapes = fake_tapes # type: ignore[assignment]

result = await agent.run_stream(
session_id="user/s1",
prompt="hello",
state={"_runtime_workspace": "/tmp"}, # noqa: S108
allowed_tools=[" tests_allowed_agent_tool "],
)
[event async for event in result]

assert fake_tapes.stream_kwargs is not None
assert [tool.name for tool in fake_tapes.stream_kwargs["tools"]] == ["tests_allowed_agent_tool"]
system_prompt = fake_tapes.stream_kwargs["system_prompt"]
assert "- tests_allowed_agent_tool(): Allowed tool" in system_prompt
assert "tests_denied_agent_tool" not in system_prompt


@pytest.mark.asyncio
async def test_agent_run_rejects_unknown_allowed_tools() -> None:
agent = _make_agent()
fork_capture = _ForkCapture()
fake_tapes = _FakeTapeService(fork_capture)
agent.tapes = fake_tapes # type: ignore[assignment]

stream = await agent.run_stream(
session_id="user/s1",
prompt="hello",
state={"_runtime_workspace": "/tmp"}, # noqa: S108
allowed_tools=["tests_missing_agent_tool"],
)

with pytest.raises(ValueError, match="tests_missing_agent_tool"):
[event async for event in stream]
62 changes: 61 additions & 1 deletion tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,61 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:
assert [event.kind for event in yielded] == ["text", "text", "final"]


@pytest.mark.asyncio
async def test_cli_channel_stream_events_skips_empty_delta_rerenders() -> None:
channel = CliChannel.__new__(CliChannel)
events: list[tuple[str, str, str]] = []
live_handle = object()
channel._renderer = SimpleNamespace(
start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
)

message = _message("ignored", channel="cli", kind="normal", session_id="cli:1")

async def source() -> asyncio.AsyncIterator[StreamEvent]:
yield StreamEvent("text", {"delta": "hello"})
yield StreamEvent("text", {"delta": ""})
yield StreamEvent("final", {})

yielded = [event async for event in channel.stream_events(message, source())]

assert events == [
("start", "normal", "hello"),
("finish", "normal", "hello"),
]
assert [event.kind for event in yielded] == ["text", "text", "final"]


@pytest.mark.asyncio
async def test_cli_channel_finishes_live_before_stream_cleanup_side_effects() -> None:
channel = CliChannel.__new__(CliChannel)
events: list[tuple[str, str, str]] = []
live_handle = object()
channel._renderer = SimpleNamespace(
start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
)

message = _message("ignored", channel="cli", kind="normal", session_id="cli:1")

async def source() -> asyncio.AsyncIterator[StreamEvent]:
yield StreamEvent("text", {"delta": "hello"})
yield StreamEvent("final", {})
events.append(("cleanup", "source", "after-final"))

yielded = [event async for event in channel.stream_events(message, source())]

assert events == [
("start", "normal", "hello"),
("finish", "normal", "hello"),
("cleanup", "source", "after-final"),
]
assert [event.kind for event in yielded] == ["text", "final"]


def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
home = tmp_path / "home"
workspace = tmp_path / "workspace"
Expand All @@ -497,6 +552,7 @@ class FakeLive:
def __init__(self, renderable, **kwargs) -> None:
live_calls.append(("init", renderable))
live_calls.append(("transient", kwargs["transient"]))
live_calls.append(("vertical_overflow", kwargs["vertical_overflow"]))
self.renderable = renderable

def start(self, *, refresh: bool = False) -> None:
Expand All @@ -519,8 +575,12 @@ def stop(self) -> None:
renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type]

assert ("transient", False) in live_calls
assert ("vertical_overflow", "ellipsis") in live_calls
assert ("start_refresh", True) in live_calls
assert ("update_refresh", True) in live_calls
assert [call for call in live_calls if call[0] == "update_refresh"] == [
("update_refresh", True),
("update_refresh", False),
]
assert not printed


Expand Down
28 changes: 25 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None:
REGISTRY.pop(tool_name, None)

@tool(name=tool_name, description="rename")
def rename_me() -> str:
def rename_me(value: str) -> str:
return "ok"

rewritten = model_tools([rename_me])

assert [item.name for item in rewritten] == ["tests_rename_me"]
assert rewritten[0].parameters == rename_me.parameters
assert rename_me.name == tool_name
assert "additionalProperties" not in rename_me.parameters


def test_render_tools_prompt_renders_available_tools_block() -> None:
Expand All @@ -118,7 +120,20 @@ def prompt_two() -> str:

rendered = render_tools_prompt([prompt_one, prompt_two])

assert rendered == "<available_tools>\n- tests_prompt_one: First tool\n- tests_prompt_two\n</available_tools>"
assert rendered == "<available_tools>\n- tests_prompt_one(): First tool\n- tests_prompt_two()\n</available_tools>"


def test_render_tools_prompt_includes_model_name_and_parameter_signature() -> None:
tool_name = "tests.prompt_signature"
REGISTRY.pop(tool_name, None)

@tool(name=tool_name, description="Read a file")
def prompt_signature(path: str, offset: int = 0) -> str:
return f"{path}:{offset}"

rendered = render_tools_prompt([prompt_signature])

assert rendered == "<available_tools>\n- tests_prompt_signature(path, offset?): Read a file\n</available_tools>"


def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
Expand All @@ -128,8 +143,10 @@ def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None:
dotted_name = "tests.resolve_alias"
underscored_name = "tests_with_underscore"
excluded_name = "tests.excluded_tool"
REGISTRY.pop(dotted_name, None)
REGISTRY.pop(underscored_name, None)
REGISTRY.pop(excluded_name, None)

@tool(name=dotted_name)
def resolve_alias() -> str:
Expand All @@ -139,11 +156,16 @@ def resolve_alias() -> str:
def resolve_runtime_name() -> str:
return "runtime"

assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == {
@tool(name=excluded_name)
def excluded_tool() -> str:
return "excluded"

assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" tests_excluded_tool "}) == {
dotted_name,
underscored_name,
}
assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "})
assert excluded_name not in resolve_tool_names(None, exclude={" tests_excluded_tool "})
assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name}


Expand Down
Loading