Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/bub/channels/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable
from typing import ClassVar

from republic import StreamEvent
Expand Down Expand Up @@ -35,7 +36,6 @@ async def send(self, message: ChannelMessage) -> None:
# Do nothing by default
return

async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None:
"""Handle an event from the agent. Optional to implement."""
# Do nothing by default
return
def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
"""Optionally wrap the output stream for this channel."""
return stream
58 changes: 26 additions & 32 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import contextlib
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from collections.abc import AsyncGenerator, AsyncIterable
from datetime import datetime
from hashlib import md5
from pathlib import Path
Expand All @@ -20,19 +19,12 @@
from bub.builtin.tape import TapeInfo
from bub.channels.base import Channel
from bub.channels.cli.renderer import CliRenderer
from bub.channels.message import ChannelMessage, MessageKind
from bub.channels.message import ChannelMessage
from bub.envelope import field_of
from bub.tools import REGISTRY
from bub.types import MessageHandler


@dataclass
class _StreamRenderState:
live: Live
kind: MessageKind
text: str = ""


class CliChannel(Channel):
"""A simple CLI channel for testing and debugging."""

Expand Down Expand Up @@ -75,6 +67,11 @@ async def stop(self) -> None:
with contextlib.suppress(asyncio.CancelledError):
await self._main_task

async def send(self, message: ChannelMessage) -> None:
if message.kind != "error":
return
self._renderer.error(message.content)

async def _main_loop(self) -> None:
self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace))
await self._refresh_tape_info()
Expand Down Expand Up @@ -131,21 +128,25 @@ def _prompt_message(self) -> FormattedText:
symbol = ">" if self._mode == "agent" else ","
return FormattedText([("bold", f"{cwd} {symbol} ")])

async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None:
streams = self._stream_render_states()
state = streams.get(message.session_id)
if event.kind == "text":
if state is None:
state = _StreamRenderState(live=self._renderer.start_stream(message.kind), kind=message.kind)
streams[message.session_id] = state
content = str(event.data.get("delta", ""))
state.text += content
self._renderer.update_stream(state.live, kind=message.kind, text=state.text)
elif event.kind == "final":
if state is None:
return
self._renderer.finish_stream(state.live, kind=state.kind, text=state.text)
streams.pop(message.session_id, None)
async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
) -> AsyncIterable[StreamEvent]:
live: Live | None = None
text = ""
try:
async for event in stream:
if event.kind == "text":
content = str(event.data.get("delta", ""))
if not content.strip() and not text:
continue # skip leading whitespace-only events
if live is None:
live = self._renderer.start_stream(message.kind)
text += content
self._renderer.update_stream(live, kind=message.kind, text=text)
yield event
finally:
if live is not None:
self._renderer.finish_stream(live, kind=message.kind, text=text)

def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
Expand Down Expand Up @@ -188,10 +189,3 @@ def _render_bottom_toolbar(self) -> FormattedText:
def _history_file(home: Path, workspace: Path) -> Path:
workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest()
return home / "history" / f"{workspace_hash}.history"

def _stream_render_states(self) -> dict[str, _StreamRenderState]:
states = getattr(self, "_active_stream_renders", None)
if states is None:
states = {}
self._active_stream_renders = states
return states
10 changes: 5 additions & 5 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import contextlib
import functools
from collections.abc import Collection
from collections.abc import AsyncIterable, Collection

from loguru import logger
from pydantic import Field
Expand Down Expand Up @@ -94,17 +94,17 @@ async def dispatch_output(self, message: Envelope) -> bool:
await channel.send(outbound)
return True

async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None:
def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
channel_name = field_of(message, "output_channel", field_of(message, "channel"))
if channel_name is None:
return
return stream

channel_key = str(channel_name)
channel = self.get_channel(channel_key)
if channel is None:
return
return stream

await channel.on_event(event, message)
return channel.stream_events(message, stream)

async def quit(self, session_id: str) -> None:
tasks = self._ongoing_tasks.pop(session_id, set())
Expand Down
9 changes: 2 additions & 7 deletions src/bub/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ async def _run_model(
return prompt if isinstance(prompt, str) else content_of(inbound)
else:
parts: list[str] = []
if self._outbound_router is not None:
stream = self._outbound_router.wrap_stream(inbound, stream)
async for event in stream:
await self.dispatch_event_via_router(event, inbound)
if event.kind == "text":
parts.append(str(event.data.get("delta", "")))
elif event.kind == "error":
Expand All @@ -163,12 +164,6 @@ async def dispatch_via_router(self, message: Envelope) -> bool:
return False
return await self._outbound_router.dispatch_output(message)

async def dispatch_event_via_router(self, event: Any, message: Envelope) -> bool:
if self._outbound_router is not None:
await self._outbound_router.dispatch_event(event, message)
return True
return False

async def quit_via_router(self, session_id: str) -> None:
if self._outbound_router is not None:
await self._outbound_router.quit(session_id)
Expand Down
4 changes: 2 additions & 2 deletions src/bub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Callable, Coroutine
from collections.abc import AsyncIterable, Callable, Coroutine
from dataclasses import dataclass, field
from typing import Any, Protocol

Expand All @@ -16,7 +16,7 @@

class OutboundChannelRouter(Protocol):
async def dispatch_output(self, message: Envelope) -> bool: ...
async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: ...
def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ...
async def quit(self, session_id: str) -> None: ...


Expand Down
13 changes: 8 additions & 5 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_cli_channel_normalize_input_prefixes_shell_commands() -> None:


@pytest.mark.asyncio
async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send() -> None:
async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> None:
channel = CliChannel.__new__(CliChannel)
events: list[tuple[str, str, str]] = []
live_handle = object()
Expand All @@ -222,17 +222,20 @@ async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send(

message = _message("ignored", channel="cli", kind="command", session_id="cli:1")

await channel.on_event(StreamEvent("text", {"delta": "hel"}), message)
await channel.on_event(StreamEvent("text", {"delta": "lo"}), message)
await channel.on_event(StreamEvent("final", {}), message)
await channel.send(_message("hello", channel="cli", kind="command", session_id="cli:1"))
async def source() -> asyncio.AsyncIterator[StreamEvent]:
yield StreamEvent("text", {"delta": "hel"})
yield StreamEvent("text", {"delta": "lo"})
yield StreamEvent("final", {})

yielded = [event async for event in channel.stream_events(message, source())]

assert events == [
("start", "command", ""),
("update", "command", "hel"),
("update", "command", "hello"),
("finish", "command", "hello"),
]
assert [event.kind for event in yielded] == ["text", "text", "final"]


def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
Expand Down
Loading