diff --git a/src/bub/channels/handler.py b/src/bub/channels/handler.py index 24e664d..7d9056b 100644 --- a/src/bub/channels/handler.py +++ b/src/bub/channels/handler.py @@ -1,4 +1,6 @@ import asyncio +from collections import deque +from dataclasses import dataclass from loguru import logger @@ -6,45 +8,117 @@ from bub.types import MessageHandler +@dataclass +class _CommandItem: + message: ChannelMessage + + +@dataclass +class _BatchItem: + """A batch of consecutive non-command messages.""" + + loop: asyncio.AbstractEventLoop + messages: list[ChannelMessage] + ready: asyncio.Event + timer: asyncio.TimerHandle | None = None + sealed: bool = False + + def schedule(self, timeout: float) -> None: + if self.sealed: + return + self.ready.clear() + if self.timer is not None: + self.timer.cancel() + self.timer = self.loop.call_later(timeout, self._fire) + + def _fire(self) -> None: + self.timer = None + self.sealed = True + self.ready.set() + + class BufferedMessageHandler: - """A message handler that buffers incoming messages and processes them in batch with debounce and active time window.""" + """Per-session message buffer with batching and strict arrival-order serialization.""" def __init__( self, handler: MessageHandler, *, active_time_window: float, max_wait_seconds: float, debounce_seconds: float ) -> None: self._handler = handler - self._pending_messages: list[ChannelMessage] = [] - self._last_active_time: float | None = None - self._event = asyncio.Event() - self._timer: asyncio.TimerHandle | None = None - self._in_processing: asyncio.Task | None = None self._loop = asyncio.get_running_loop() + self._work: deque[_CommandItem | _BatchItem] = deque() + self._in_processing: asyncio.Task | None = None + + self._last_active_time: float | None = None self.active_time_window = active_time_window self.max_wait_seconds = max_wait_seconds self.debounce_seconds = debounce_seconds - def _reset_timer(self, timeout: float) -> None: - self._event.clear() - if self._timer: - self._timer.cancel() - self._timer = self._loop.call_later(timeout, self._event.set) + @staticmethod + def _is_command(message: ChannelMessage) -> bool: + return message.content.startswith(",") + + def _ensure_worker(self) -> None: + if self._in_processing is None: + self._in_processing = asyncio.create_task(self._process()) + + def _append_to_tail_batch(self, message: ChannelMessage) -> _BatchItem: + if self._work and isinstance(self._work[-1], _BatchItem) and not self._work[-1].sealed: + batch = self._work[-1] + batch.messages.append(message) + return batch + + batch = _BatchItem(loop=self._loop, messages=[message], ready=asyncio.Event()) + self._work.append(batch) + return batch async def _process(self) -> None: - await self._event.wait() - message = ChannelMessage.from_batch(self._pending_messages) - self._pending_messages.clear() - self._in_processing = None - await self._handler(message) + try: + while True: + if not self._work: + return + + item = self._work[0] + if isinstance(item, _CommandItem): + self._work.popleft() + try: + await self._handler(item.message) + except asyncio.CancelledError: + raise + except Exception: + logger.exception( + "session.message command handler failed session_id={}, content={}", + item.message.session_id, + item.message.content, + ) + continue + + await item.ready.wait() + self._work.popleft() + try: + merged = ChannelMessage.from_batch(item.messages) + await self._handler(merged) + except asyncio.CancelledError: + raise + except Exception: + session_id = item.messages[-1].session_id if item.messages else "unknown" + logger.exception("session.message batch handler failed session_id={}", session_id) + finally: + self._in_processing = None + if self._work: + self._ensure_worker() async def __call__(self, message: ChannelMessage) -> None: now = self._loop.time() - if message.content.startswith(","): + + if self._is_command(message): logger.info( "session.message received command session_id={}, content={}", message.session_id, message.content ) - await self._handler(message) + self._work.append(_CommandItem(message)) + self._ensure_worker() return + if not message.is_active and ( self._last_active_time is None or now - self._last_active_time > self.active_time_window ): @@ -53,16 +127,19 @@ async def __call__(self, message: ChannelMessage) -> None: "session.message received ignored session_id={}, content={}", message.session_id, message.content ) return - self._pending_messages.append(message) + + batch = self._append_to_tail_batch(message) + if message.is_active: self._last_active_time = now logger.info( "session.message received active session_id={}, content={}", message.session_id, message.content ) - self._reset_timer(self.debounce_seconds) - if self._in_processing is None: - self._in_processing = asyncio.create_task(self._process()) - elif self._last_active_time is not None and self._in_processing is None: + batch.schedule(self.debounce_seconds) + self._ensure_worker() + return + + if self._last_active_time is not None: logger.info("session.receive followup session_id={} message={}", message.session_id, message.content) - self._reset_timer(self.max_wait_seconds) - self._in_processing = asyncio.create_task(self._process()) + batch.schedule(self.max_wait_seconds) + self._ensure_worker() diff --git a/tests/test_channels.py b/tests/test_channels.py index 7e5bee7..a1adf9f 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -78,9 +78,11 @@ def _message( @pytest.mark.asyncio async def test_buffered_handler_passes_commands_through_immediately() -> None: handled: list[str] = [] + done = asyncio.Event() async def receive(message: ChannelMessage) -> None: handled.append(message.content) + done.set() handler = BufferedMessageHandler( receive, @@ -90,10 +92,116 @@ async def receive(message: ChannelMessage) -> None: ) await handler(_message(",help")) + await asyncio.wait_for(done.wait(), timeout=1) assert handled == [",help"] +@pytest.mark.asyncio +async def test_buffered_handler_batches_active_and_followup_messages() -> None: + handled: list[str] = [] + done = asyncio.Event() + + async def receive(message: ChannelMessage) -> None: + handled.append(message.content) + done.set() + + handler = BufferedMessageHandler( + receive, + active_time_window=10, + max_wait_seconds=0.01, + debounce_seconds=0.05, + ) + + await handler(_message("a", is_active=True)) + await handler(_message("b")) + await asyncio.wait_for(done.wait(), timeout=1) + + assert handled == ["a\nb"] + + +@pytest.mark.asyncio +async def test_buffered_handler_serializes_messages_while_handler_is_running() -> None: + handled: list[str] = [] + first_started = asyncio.Event() + release_first = asyncio.Event() + max_inflight = 0 + inflight = 0 + + async def receive(message: ChannelMessage) -> None: + nonlocal inflight, max_inflight + inflight += 1 + max_inflight = max(max_inflight, inflight) + handled.append(message.content) + if message.content == "a": + first_started.set() + await release_first.wait() + inflight -= 1 + + handler = BufferedMessageHandler( + receive, + active_time_window=10, + max_wait_seconds=0.01, + debounce_seconds=0.01, + ) + + await handler(_message("a", is_active=True)) + await asyncio.wait_for(first_started.wait(), timeout=1) + + await handler(_message("b", is_active=True)) + await asyncio.sleep(0.03) + + assert handled == ["a"] + assert max_inflight == 1 + + release_first.set() + await _wait_for(lambda: handled == ["a", "b"]) + assert max_inflight == 1 + + +@pytest.mark.asyncio +async def test_buffered_handler_preserves_order_around_commands() -> None: + handled: list[str] = [] + + async def receive(message: ChannelMessage) -> None: + handled.append(message.content) + + handler = BufferedMessageHandler( + receive, + active_time_window=10, + max_wait_seconds=0.01, + debounce_seconds=0.01, + ) + + await handler(_message("a", is_active=True)) + await handler(_message(",help")) + await handler(_message("b", is_active=True)) + + await _wait_for(lambda: handled == ["a", ",help", "b"]) + + +@pytest.mark.asyncio +async def test_buffered_handler_continues_after_handler_error() -> None: + handled: list[str] = [] + + async def receive(message: ChannelMessage) -> None: + if message.content == ",boom": + raise RuntimeError("boom") + handled.append(message.content) + + handler = BufferedMessageHandler( + receive, + active_time_window=10, + max_wait_seconds=0.01, + debounce_seconds=0.01, + ) + + await handler(_message(",boom")) + await handler(_message(",ok")) + + await _wait_for(lambda: handled == [",ok"]) + + @pytest.mark.asyncio async def test_channel_manager_dispatch_uses_output_channel_and_preserves_metadata() -> None: cli_channel = FakeChannel("cli") @@ -155,6 +263,16 @@ async def __call__(self, message: ChannelMessage) -> None: assert isinstance(manager._session_handlers[message.session_id], StubBufferedMessageHandler) +async def _wait_for(predicate, timeout: float = 1.0) -> None: + deadline = asyncio.get_running_loop().time() + timeout + while True: + if predicate(): + return + if asyncio.get_running_loop().time() >= deadline: + raise AssertionError("condition not met before timeout") + await asyncio.sleep(0.001) + + @pytest.mark.asyncio async def test_channel_manager_shutdown_cancels_tasks_and_stops_enabled_channels() -> None: telegram = FakeChannel("telegram")