diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 5772e5e..665ee66 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -1,5 +1,6 @@ import asyncio from abc import ABC, abstractmethod +from collections.abc import AsyncIterable from typing import ClassVar from republic import StreamEvent @@ -35,7 +36,6 @@ async def send(self, message: ChannelMessage) -> None: # Do nothing by default return - async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None: - """Handle an event from the agent. Optional to implement.""" - # Do nothing by default - return + def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: + """Optionally wrap the output stream for this channel.""" + return stream diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index ea137f8..ac627df 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -1,7 +1,6 @@ import asyncio import contextlib -from collections.abc import AsyncGenerator -from dataclasses import dataclass +from collections.abc import AsyncGenerator, AsyncIterable from datetime import datetime from hashlib import md5 from pathlib import Path @@ -20,19 +19,12 @@ from bub.builtin.tape import TapeInfo from bub.channels.base import Channel from bub.channels.cli.renderer import CliRenderer -from bub.channels.message import ChannelMessage, MessageKind +from bub.channels.message import ChannelMessage from bub.envelope import field_of from bub.tools import REGISTRY from bub.types import MessageHandler -@dataclass -class _StreamRenderState: - live: Live - kind: MessageKind - text: str = "" - - class CliChannel(Channel): """A simple CLI channel for testing and debugging.""" @@ -75,6 +67,11 @@ async def stop(self) -> None: with contextlib.suppress(asyncio.CancelledError): await self._main_task + async def send(self, message: ChannelMessage) -> None: + if message.kind != "error": + return + self._renderer.error(message.content) + async def _main_loop(self) -> None: self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace)) await self._refresh_tape_info() @@ -131,21 +128,25 @@ def _prompt_message(self) -> FormattedText: symbol = ">" if self._mode == "agent" else "," return FormattedText([("bold", f"{cwd} {symbol} ")]) - async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None: - streams = self._stream_render_states() - state = streams.get(message.session_id) - if event.kind == "text": - if state is None: - state = _StreamRenderState(live=self._renderer.start_stream(message.kind), kind=message.kind) - streams[message.session_id] = state - content = str(event.data.get("delta", "")) - state.text += content - self._renderer.update_stream(state.live, kind=message.kind, text=state.text) - elif event.kind == "final": - if state is None: - return - self._renderer.finish_stream(state.live, kind=state.kind, text=state.text) - streams.pop(message.session_id, None) + async def stream_events( + self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] + ) -> AsyncIterable[StreamEvent]: + live: Live | None = None + text = "" + try: + async for event in stream: + if event.kind == "text": + content = str(event.data.get("delta", "")) + if not content.strip() and not text: + continue # skip leading whitespace-only events + if live is None: + live = self._renderer.start_stream(message.kind) + text += content + self._renderer.update_stream(live, kind=message.kind, text=text) + yield event + finally: + if live is not None: + self._renderer.finish_stream(live, kind=message.kind, text=text) def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() @@ -188,10 +189,3 @@ def _render_bottom_toolbar(self) -> FormattedText: def _history_file(home: Path, workspace: Path) -> Path: workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest() return home / "history" / f"{workspace_hash}.history" - - def _stream_render_states(self) -> dict[str, _StreamRenderState]: - states = getattr(self, "_active_stream_renders", None) - if states is None: - states = {} - self._active_stream_renders = states - return states diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 8643e69..589d5d0 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -1,7 +1,7 @@ import asyncio import contextlib import functools -from collections.abc import Collection +from collections.abc import AsyncIterable, Collection from loguru import logger from pydantic import Field @@ -94,17 +94,17 @@ async def dispatch_output(self, message: Envelope) -> bool: await channel.send(outbound) return True - async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: + def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: channel_name = field_of(message, "output_channel", field_of(message, "channel")) if channel_name is None: - return + return stream channel_key = str(channel_name) channel = self.get_channel(channel_key) if channel is None: - return + return stream - await channel.on_event(event, message) + return channel.stream_events(message, stream) async def quit(self, session_id: str) -> None: tasks = self._ongoing_tasks.pop(session_id, set()) diff --git a/src/bub/framework.py b/src/bub/framework.py index 55bc8e5..85cbb6e 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -140,8 +140,9 @@ async def _run_model( return prompt if isinstance(prompt, str) else content_of(inbound) else: parts: list[str] = [] + if self._outbound_router is not None: + stream = self._outbound_router.wrap_stream(inbound, stream) async for event in stream: - await self.dispatch_event_via_router(event, inbound) if event.kind == "text": parts.append(str(event.data.get("delta", ""))) elif event.kind == "error": @@ -163,12 +164,6 @@ async def dispatch_via_router(self, message: Envelope) -> bool: return False return await self._outbound_router.dispatch_output(message) - async def dispatch_event_via_router(self, event: Any, message: Envelope) -> bool: - if self._outbound_router is not None: - await self._outbound_router.dispatch_event(event, message) - return True - return False - async def quit_via_router(self, session_id: str) -> None: if self._outbound_router is not None: await self._outbound_router.quit(session_id) diff --git a/src/bub/types.py b/src/bub/types.py index 0f84ea3..b85bf2c 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Callable, Coroutine +from collections.abc import AsyncIterable, Callable, Coroutine from dataclasses import dataclass, field from typing import Any, Protocol @@ -16,7 +16,7 @@ class OutboundChannelRouter(Protocol): async def dispatch_output(self, message: Envelope) -> bool: ... - async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: ... + def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ... async def quit(self, session_id: str) -> None: ... diff --git a/tests/test_channels.py b/tests/test_channels.py index db09562..7e5bee7 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -207,7 +207,7 @@ def test_cli_channel_normalize_input_prefixes_shell_commands() -> None: @pytest.mark.asyncio -async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send() -> None: +async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> None: channel = CliChannel.__new__(CliChannel) events: list[tuple[str, str, str]] = [] live_handle = object() @@ -222,10 +222,12 @@ async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send( message = _message("ignored", channel="cli", kind="command", session_id="cli:1") - await channel.on_event(StreamEvent("text", {"delta": "hel"}), message) - await channel.on_event(StreamEvent("text", {"delta": "lo"}), message) - await channel.on_event(StreamEvent("final", {}), message) - await channel.send(_message("hello", channel="cli", kind="command", session_id="cli:1")) + async def source() -> asyncio.AsyncIterator[StreamEvent]: + yield StreamEvent("text", {"delta": "hel"}) + yield StreamEvent("text", {"delta": "lo"}) + yield StreamEvent("final", {}) + + yielded = [event async for event in channel.stream_events(message, source())] assert events == [ ("start", "command", ""), @@ -233,6 +235,7 @@ async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send( ("update", "command", "hello"), ("finish", "command", "hello"), ] + assert [event.kind for event in yielded] == ["text", "text", "final"] def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: