Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 102 additions & 25 deletions src/bub/channels/handler.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,124 @@
import asyncio
from collections import deque
from dataclasses import dataclass

from loguru import logger

from bub.channels.message import ChannelMessage
from bub.types import MessageHandler


@dataclass
class _CommandItem:
message: ChannelMessage


@dataclass
class _BatchItem:
"""A batch of consecutive non-command messages."""

loop: asyncio.AbstractEventLoop
messages: list[ChannelMessage]
ready: asyncio.Event
timer: asyncio.TimerHandle | None = None
sealed: bool = False

def schedule(self, timeout: float) -> None:
if self.sealed:
return
self.ready.clear()
if self.timer is not None:
self.timer.cancel()
self.timer = self.loop.call_later(timeout, self._fire)

def _fire(self) -> None:
self.timer = None
self.sealed = True
self.ready.set()


class BufferedMessageHandler:
"""A message handler that buffers incoming messages and processes them in batch with debounce and active time window."""
"""Per-session message buffer with batching and strict arrival-order serialization."""

def __init__(
self, handler: MessageHandler, *, active_time_window: float, max_wait_seconds: float, debounce_seconds: float
) -> None:
self._handler = handler
self._pending_messages: list[ChannelMessage] = []
self._last_active_time: float | None = None
self._event = asyncio.Event()
self._timer: asyncio.TimerHandle | None = None
self._in_processing: asyncio.Task | None = None
self._loop = asyncio.get_running_loop()

self._work: deque[_CommandItem | _BatchItem] = deque()
self._in_processing: asyncio.Task | None = None

self._last_active_time: float | None = None
self.active_time_window = active_time_window
self.max_wait_seconds = max_wait_seconds
self.debounce_seconds = debounce_seconds

def _reset_timer(self, timeout: float) -> None:
self._event.clear()
if self._timer:
self._timer.cancel()
self._timer = self._loop.call_later(timeout, self._event.set)
@staticmethod
def _is_command(message: ChannelMessage) -> bool:
return message.content.startswith(",")

def _ensure_worker(self) -> None:
if self._in_processing is None:
self._in_processing = asyncio.create_task(self._process())

def _append_to_tail_batch(self, message: ChannelMessage) -> _BatchItem:
if self._work and isinstance(self._work[-1], _BatchItem) and not self._work[-1].sealed:
batch = self._work[-1]
batch.messages.append(message)
return batch

batch = _BatchItem(loop=self._loop, messages=[message], ready=asyncio.Event())
self._work.append(batch)
return batch

async def _process(self) -> None:
await self._event.wait()
message = ChannelMessage.from_batch(self._pending_messages)
self._pending_messages.clear()
self._in_processing = None
await self._handler(message)
try:
while True:
if not self._work:
return

item = self._work[0]
if isinstance(item, _CommandItem):
self._work.popleft()
try:
await self._handler(item.message)
except asyncio.CancelledError:
raise
except Exception:
logger.exception(
"session.message command handler failed session_id={}, content={}",
item.message.session_id,
item.message.content,
)
continue

await item.ready.wait()
self._work.popleft()
try:
merged = ChannelMessage.from_batch(item.messages)
await self._handler(merged)
except asyncio.CancelledError:
raise
except Exception:
session_id = item.messages[-1].session_id if item.messages else "unknown"
logger.exception("session.message batch handler failed session_id={}", session_id)
finally:
self._in_processing = None
if self._work:
self._ensure_worker()

async def __call__(self, message: ChannelMessage) -> None:
now = self._loop.time()
if message.content.startswith(","):

if self._is_command(message):
logger.info(
"session.message received command session_id={}, content={}", message.session_id, message.content
)
await self._handler(message)
self._work.append(_CommandItem(message))
self._ensure_worker()
return

if not message.is_active and (
self._last_active_time is None or now - self._last_active_time > self.active_time_window
):
Expand All @@ -53,16 +127,19 @@ async def __call__(self, message: ChannelMessage) -> None:
"session.message received ignored session_id={}, content={}", message.session_id, message.content
)
return
self._pending_messages.append(message)

batch = self._append_to_tail_batch(message)

if message.is_active:
self._last_active_time = now
logger.info(
"session.message received active session_id={}, content={}", message.session_id, message.content
)
self._reset_timer(self.debounce_seconds)
if self._in_processing is None:
self._in_processing = asyncio.create_task(self._process())
elif self._last_active_time is not None and self._in_processing is None:
batch.schedule(self.debounce_seconds)
self._ensure_worker()
return

if self._last_active_time is not None:
logger.info("session.receive followup session_id={} message={}", message.session_id, message.content)
self._reset_timer(self.max_wait_seconds)
self._in_processing = asyncio.create_task(self._process())
batch.schedule(self.max_wait_seconds)
self._ensure_worker()
118 changes: 118 additions & 0 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ def _message(
@pytest.mark.asyncio
async def test_buffered_handler_passes_commands_through_immediately() -> None:
handled: list[str] = []
done = asyncio.Event()

async def receive(message: ChannelMessage) -> None:
handled.append(message.content)
done.set()

handler = BufferedMessageHandler(
receive,
Expand All @@ -90,10 +92,116 @@ async def receive(message: ChannelMessage) -> None:
)

await handler(_message(",help"))
await asyncio.wait_for(done.wait(), timeout=1)

assert handled == [",help"]


@pytest.mark.asyncio
async def test_buffered_handler_batches_active_and_followup_messages() -> None:
handled: list[str] = []
done = asyncio.Event()

async def receive(message: ChannelMessage) -> None:
handled.append(message.content)
done.set()

handler = BufferedMessageHandler(
receive,
active_time_window=10,
max_wait_seconds=0.01,
debounce_seconds=0.05,
)

await handler(_message("a", is_active=True))
await handler(_message("b"))
await asyncio.wait_for(done.wait(), timeout=1)

assert handled == ["a\nb"]


@pytest.mark.asyncio
async def test_buffered_handler_serializes_messages_while_handler_is_running() -> None:
handled: list[str] = []
first_started = asyncio.Event()
release_first = asyncio.Event()
max_inflight = 0
inflight = 0

async def receive(message: ChannelMessage) -> None:
nonlocal inflight, max_inflight
inflight += 1
max_inflight = max(max_inflight, inflight)
handled.append(message.content)
if message.content == "a":
first_started.set()
await release_first.wait()
inflight -= 1

handler = BufferedMessageHandler(
receive,
active_time_window=10,
max_wait_seconds=0.01,
debounce_seconds=0.01,
)

await handler(_message("a", is_active=True))
await asyncio.wait_for(first_started.wait(), timeout=1)

await handler(_message("b", is_active=True))
await asyncio.sleep(0.03)

assert handled == ["a"]
assert max_inflight == 1

release_first.set()
await _wait_for(lambda: handled == ["a", "b"])
assert max_inflight == 1


@pytest.mark.asyncio
async def test_buffered_handler_preserves_order_around_commands() -> None:
handled: list[str] = []

async def receive(message: ChannelMessage) -> None:
handled.append(message.content)

handler = BufferedMessageHandler(
receive,
active_time_window=10,
max_wait_seconds=0.01,
debounce_seconds=0.01,
)

await handler(_message("a", is_active=True))
await handler(_message(",help"))
await handler(_message("b", is_active=True))

await _wait_for(lambda: handled == ["a", ",help", "b"])


@pytest.mark.asyncio
async def test_buffered_handler_continues_after_handler_error() -> None:
handled: list[str] = []

async def receive(message: ChannelMessage) -> None:
if message.content == ",boom":
raise RuntimeError("boom")
handled.append(message.content)

handler = BufferedMessageHandler(
receive,
active_time_window=10,
max_wait_seconds=0.01,
debounce_seconds=0.01,
)

await handler(_message(",boom"))
await handler(_message(",ok"))

await _wait_for(lambda: handled == [",ok"])


@pytest.mark.asyncio
async def test_channel_manager_dispatch_uses_output_channel_and_preserves_metadata() -> None:
cli_channel = FakeChannel("cli")
Expand Down Expand Up @@ -155,6 +263,16 @@ async def __call__(self, message: ChannelMessage) -> None:
assert isinstance(manager._session_handlers[message.session_id], StubBufferedMessageHandler)


async def _wait_for(predicate, timeout: float = 1.0) -> None:
deadline = asyncio.get_running_loop().time() + timeout
while True:
if predicate():
return
if asyncio.get_running_loop().time() >= deadline:
raise AssertionError("condition not met before timeout")
await asyncio.sleep(0.001)


@pytest.mark.asyncio
async def test_channel_manager_shutdown_cancels_tasks_and_stops_enabled_channels() -> None:
telegram = FakeChannel("telegram")
Expand Down
Loading