Skip to content

Commit b496cc0

Browse files
committed
feat: implement session runner for managing message processing and debounce logic
Signed-off-by: Frost Ming <me@frostming.com>
1 parent c52fa65 commit b496cc0

11 files changed

Lines changed: 138 additions & 81 deletions

src/bub/channels/base.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from collections.abc import Awaitable, Callable
77
from typing import TYPE_CHECKING, Any
88

9-
from loguru import logger
10-
119
from bub.app.runtime import AppRuntime
1210

1311
if TYPE_CHECKING:
@@ -30,26 +28,25 @@ def __init__(self, runtime: AppRuntime) -> None:
3028
async def start(self, on_receive: Callable[[T], Awaitable[None]]) -> None:
3129
"""Start the channel and set up the receive callback."""
3230

31+
@property
32+
def output_channel(self) -> str:
33+
"""The name of the channel to send outputs to. Defaults to the same channel."""
34+
return self.name
35+
36+
@abstractmethod
37+
def is_mentioned(self, message: T) -> bool:
38+
"""Determine if the message is relevant to this channel."""
39+
3340
@abstractmethod
34-
async def get_session_prompt(self, message: T) -> tuple[str, str] | None:
35-
"""Get the session id and prompt text for the given message.
36-
If None is returned, the message will be ignored.
37-
"""
41+
async def get_session_prompt(self, message: T) -> tuple[str, str]:
42+
"""Get the session id and prompt text for the given message."""
3843
pass
3944

45+
async def run_prompt(self, session_id: str, prompt: str) -> LoopResult:
46+
"""Run the given prompt through the runtime and return the result."""
47+
return await self.runtime.handle_input(session_id, prompt)
48+
4049
@abstractmethod
4150
async def process_output(self, session_id: str, output: LoopResult) -> None:
4251
"""Process the output returned by the LLM."""
4352
pass
44-
45-
async def run_prompt(self, message: T) -> None:
46-
"""Run a prompt based on the received message."""
47-
try:
48-
result = await self.get_session_prompt(message)
49-
if result is None:
50-
return
51-
session_id, prompt = result
52-
output = await self.runtime.handle_input(session_id, prompt)
53-
await self.process_output(session_id, output)
54-
except Exception:
55-
logger.exception("{}.agent.error", self.name)

src/bub/channels/discord.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async def on_message(message: discord.Message) -> None:
9393
self._bot = None
9494
logger.info("discord.stopped")
9595

96-
async def get_session_prompt(self, message: discord.Message) -> tuple[str, str] | None:
96+
async def get_session_prompt(self, message: discord.Message) -> tuple[str, str]:
9797
channel_id = str(message.channel.id)
9898
session_id = f"{self.name}:{channel_id}"
9999
content, media = self._parse_message(message)
@@ -125,14 +125,10 @@ async def get_session_prompt(self, message: discord.Message) -> tuple[str, str]
125125
metadata["reply_to_message"] = reply_meta
126126

127127
metadata_json = json.dumps(
128-
{"channel": "discord", "channel_id": channel_id, **exclude_none(metadata)}, ensure_ascii=False
129-
)
130-
prompt = (
131-
"IMPORTANT: Please reply to this $discord message unless otherwise instructed.\n\n"
132-
f"{content}\n———————\n{metadata_json}"
128+
{"message": content, "channel_id": channel_id, **exclude_none(metadata)}, ensure_ascii=False
133129
)
134130
self._latest_message_by_session[session_id] = message
135-
return session_id, prompt
131+
return session_id, metadata_json
136132

137133
async def process_output(self, session_id: str, output: LoopResult) -> None:
138134
parts = [part for part in (output.immediate_output, output.assistant_output) if part]
@@ -163,8 +159,6 @@ async def process_output(self, session_id: str, output: LoopResult) -> None:
163159
async def _on_message(self, message: discord.Message) -> None:
164160
if message.author.bot:
165161
return
166-
if not self._allow_message(message):
167-
return
168162
if self._on_receive is None:
169163
logger.warning("discord.inbound no handler for received messages")
170164
return
@@ -194,7 +188,7 @@ async def _resolve_channel(self, session_id: str) -> discord.abc.Messageable | N
194188
return fetched
195189
return None
196190

197-
def _allow_message(self, message: discord.Message) -> bool:
191+
def is_mentioned(self, message: discord.Message) -> bool:
198192
channel_id = str(message.channel.id)
199193
if self._config.allow_channels and channel_id not in self._config.allow_channels:
200194
return False

src/bub/channels/manager.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
import asyncio
66
import contextlib
7+
import functools
78

89
from loguru import logger
910

1011
from bub.app.runtime import AppRuntime
1112
from bub.channels.base import BaseChannel
13+
from bub.channels.runner import SessionRunner
1214

1315

1416
class ChannelManager:
@@ -18,6 +20,7 @@ def __init__(self, runtime: AppRuntime) -> None:
1820
self.runtime = runtime
1921
self._channels: dict[str, BaseChannel] = {}
2022
self._channel_tasks: list[asyncio.Task[None]] = []
23+
self._session_runners: dict[str, SessionRunner] = {}
2124
for channel_cls in self.default_channels():
2225
self.register(channel_cls)
2326
runtime.install_hooks(self)
@@ -33,9 +36,7 @@ def channels(self) -> dict[str, BaseChannel]:
3336
async def run(self) -> None:
3437
logger.info("channel.manager.start channels={}", self.enabled_channels())
3538
for channel in self._channels.values():
36-
# XXX: Currently we just call the same message handler with itself.
37-
# But it will be likely decoupled later
38-
task = asyncio.create_task(channel.start(channel.run_prompt))
39+
task = asyncio.create_task(channel.start(functools.partial(self._process_input, channel)))
3940
self._channel_tasks.append(task)
4041
try:
4142
await asyncio.gather(*self._channel_tasks)
@@ -63,3 +64,14 @@ def default_channels(self) -> list[type[BaseChannel]]:
6364

6465
result.append(DiscordChannel)
6566
return result
67+
68+
async def _process_input[T](self, channel: BaseChannel[T], message: T) -> None:
69+
session_id, _ = await channel.get_session_prompt(message)
70+
if session_id not in self._session_runners:
71+
self._session_runners[session_id] = SessionRunner(
72+
channel,
73+
session_id,
74+
self.runtime.settings.message_debounce_seconds,
75+
self.runtime.settings.message_delay_seconds,
76+
)
77+
await self._session_runners[session_id].process_message(message)

src/bub/channels/runner.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import asyncio
2+
from typing import Any
3+
4+
from loguru import logger
5+
6+
from bub.channels.base import BaseChannel
7+
8+
9+
class SessionRunner:
10+
def __init__(
11+
self, channel: BaseChannel, session_id: str, debounce_seconds: int, message_delay_seconds: int
12+
) -> None:
13+
self.session_id = session_id
14+
self.channel = channel
15+
self.debounce_seconds = debounce_seconds
16+
self.message_delay_seconds = message_delay_seconds
17+
self._prompts: list[str] = []
18+
self._event = asyncio.Event()
19+
self._timer: asyncio.TimerHandle | None = None
20+
self._last_received_at: float | None = None
21+
self._running_task: asyncio.Task[None] | None = None
22+
self._loop = asyncio.get_running_loop()
23+
24+
async def _run(self) -> None:
25+
await self._event.wait()
26+
prompt = f"channel: ${self.channel.output_channel}\n" + "\n".join(self._prompts)
27+
self._prompts.clear()
28+
self._last_received_at = None
29+
self._running_task = None
30+
try:
31+
result = await self.channel.run_prompt(self.session_id, prompt)
32+
await self.channel.process_output(self.session_id, result)
33+
except Exception:
34+
logger.exception("session.run.error session_id={}", self.session_id)
35+
36+
def reset_timer(self, timeout: int) -> None:
37+
self._event.clear()
38+
if self._timer:
39+
self._timer.cancel()
40+
self._timer = self._loop.call_later(timeout, self._event.set)
41+
42+
async def process_message(self, message: Any) -> None:
43+
is_mentioned = self.channel.is_mentioned(message)
44+
_, prompt = await self.channel.get_session_prompt(message)
45+
now = self._loop.time()
46+
if not is_mentioned and self._last_received_at is None:
47+
return
48+
self._prompts.append(prompt)
49+
if is_mentioned:
50+
# wait at most 1 second to reply to mentioned messages.
51+
self._last_received_at = now
52+
self.reset_timer(self.debounce_seconds)
53+
if self._running_task is None:
54+
self._running_task = asyncio.create_task(self._run())
55+
return await self._running_task
56+
elif self._last_received_at is not None and self._running_task is None:
57+
# Otherwise if bot is mentioned before, we will keep reading messages for at most 30 seconds.
58+
self.reset_timer(self.message_delay_seconds)
59+
self._running_task = asyncio.create_task(self._run())
60+
return await self._running_task

src/bub/channels/telegram.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import contextlib
77
import json
8-
from collections.abc import Awaitable, Callable
8+
from collections.abc import AsyncGenerator, Awaitable, Callable
99
from dataclasses import dataclass
1010
from typing import Any, ClassVar
1111

@@ -124,9 +124,11 @@ def __init__(self, runtime: AppRuntime) -> None:
124124
proxy=settings.telegram_proxy,
125125
)
126126
self._app: Application | None = None
127-
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
128127
self._on_receive: Callable[[Message], Awaitable[None]] | None = None
129128

129+
def is_mentioned(self, message: Message) -> bool:
130+
return bool(MESSAGE_FILTER.filter(message))
131+
130132
async def start(self, on_receive: Callable[[Message], Awaitable[None]]) -> None:
131133
self._on_receive = on_receive
132134
proxy, _ = resolve_proxy(self._config.proxy)
@@ -153,11 +155,6 @@ async def start(self, on_receive: Callable[[Message], Awaitable[None]]) -> None:
153155
try:
154156
await asyncio.Event().wait() # Keep running until stopped
155157
finally:
156-
for task in self._typing_tasks.values():
157-
task.cancel()
158-
with contextlib.suppress(asyncio.CancelledError):
159-
await asyncio.gather(*self._typing_tasks.values())
160-
self._typing_tasks.clear()
161158
updater = self._app.updater
162159
with contextlib.suppress(Exception):
163160
if updater is not None and updater.running:
@@ -167,9 +164,7 @@ async def start(self, on_receive: Callable[[Message], Awaitable[None]]) -> None:
167164
self._app = None
168165
logger.info("telegram.stopped")
169166

170-
async def get_session_prompt(self, message: Message) -> tuple[str, str] | None:
171-
if MESSAGE_FILTER.filter(message) is False:
172-
return None
167+
async def get_session_prompt(self, message: Message) -> tuple[str, str]:
173168
chat_id = str(message.chat_id)
174169
session_id = f"{self.name}:{chat_id}"
175170
content, media = self._parse_message(message)
@@ -208,9 +203,8 @@ async def get_session_prompt(self, message: Message) -> tuple[str, str] | None:
208203
if reply_meta:
209204
metadata["reply_to_message"] = reply_meta
210205

211-
metadata_json = json.dumps({"channel": f"${self.name}", "chat_id": chat_id, **metadata}, ensure_ascii=False)
212-
prompt = f"{content}\n———————\n{metadata_json}"
213-
return session_id, prompt
206+
metadata_json = json.dumps({"message": content, "chat_id": chat_id, **metadata}, ensure_ascii=False)
207+
return session_id, metadata_json
214208

215209
async def process_output(self, session_id: str, output: LoopResult) -> None:
216210
parts = [part for part in (output.immediate_output, output.assistant_output) if part]
@@ -259,30 +253,24 @@ async def _on_text(self, update: Update, _context: ContextTypes.DEFAULT_TYPE) ->
259253
if self._on_receive is None:
260254
logger.warning("telegram.inbound no handler for received messages")
261255
return
262-
await self._start_typing(chat_id)
263-
try:
256+
async with self._start_typing(chat_id):
264257
await self._on_receive(update.message)
265-
finally:
266-
await self._stop_typing(chat_id)
267258

268-
async def _start_typing(self, chat_id: str) -> None:
269-
await self._stop_typing(chat_id)
270-
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
271-
272-
async def _stop_typing(self, chat_id: str) -> None:
273-
task = self._typing_tasks.pop(chat_id, None)
274-
if task is not None:
275-
task.cancel()
259+
@contextlib.asynccontextmanager
260+
async def _start_typing(self, chat_id: str) -> AsyncGenerator[None, None]:
261+
typing_task = asyncio.create_task(self._typing_loop(chat_id))
262+
try:
263+
yield
264+
finally:
265+
typing_task.cancel()
276266
with contextlib.suppress(asyncio.CancelledError):
277-
await task
267+
await typing_task
278268

279269
async def _typing_loop(self, chat_id: str) -> None:
280270
try:
281271
while self._app is not None:
282272
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
283273
await asyncio.sleep(4)
284-
except asyncio.CancelledError:
285-
return
286274
except Exception:
287275
logger.exception("telegram.typing_loop.error chat_id={}", chat_id)
288276
return

src/bub/config/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class Settings(BaseSettings):
3838
max_steps: int = Field(default=20, ge=1)
3939

4040
proactive_response: bool = False
41+
message_delay_seconds: int = 10
42+
message_debounce_seconds: int = 1
4143

4244
telegram_enabled: bool = False
4345
telegram_token: str | None = None

tests/test_channels.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ async def get_session_prompt(self, message: object) -> tuple[str, str]:
4040
_ = message
4141
return "session", "prompt"
4242

43+
def is_mentioned(self, message: object) -> bool:
44+
_ = message
45+
return True
46+
4347
async def process_output(self, session_id: str, output) -> None:
4448
_ = (session_id, output)
4549

tests/test_discord_filter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,31 +43,31 @@ def _build_channel() -> DiscordChannel:
4343
def test_allow_message_when_content_contains_bub() -> None:
4444
channel = _build_channel()
4545
message = DummyMessage(content="please ask Bub to check this", channel=SimpleNamespace(id=100, name="general"))
46-
assert channel._allow_message(message) is True # type: ignore[arg-type]
46+
assert channel.is_mentioned(message) is True # type: ignore[arg-type]
4747

4848

4949
def test_allow_message_when_thread_name_starts_with_bub() -> None:
5050
channel = _build_channel()
5151
thread = SimpleNamespace(id=101, name="bub-help", parent=SimpleNamespace(name="forum"))
5252
message = DummyMessage(content="hello", channel=thread)
53-
assert channel._allow_message(message) is True # type: ignore[arg-type]
53+
assert channel.is_mentioned(message) is True # type: ignore[arg-type]
5454

5555

5656
def test_reject_message_when_only_parent_name_starts_with_bub() -> None:
5757
channel = _build_channel()
5858
thread = SimpleNamespace(id=102, name="question-1", parent=SimpleNamespace(name="bub-forum"))
5959
message = DummyMessage(content="hello", channel=thread)
60-
assert channel._allow_message(message) is False # type: ignore[arg-type]
60+
assert channel.is_mentioned(message) is False # type: ignore[arg-type]
6161

6262

6363
def test_reject_unrelated_message_without_bot_context() -> None:
6464
channel = _build_channel()
6565
message = DummyMessage(content="hello world", channel=SimpleNamespace(id=103, name="general"))
66-
assert channel._allow_message(message) is False # type: ignore[arg-type]
66+
assert channel.is_mentioned(message) is False # type: ignore[arg-type]
6767

6868

6969
def test_reject_empty_content_even_in_bub_thread() -> None:
7070
channel = _build_channel()
7171
thread = SimpleNamespace(id=104, name="bub-help", parent=SimpleNamespace(name="forum"))
7272
message = DummyMessage(content=" ", channel=thread)
73-
assert channel._allow_message(message) is False # type: ignore[arg-type]
73+
assert channel.is_mentioned(message) is False # type: ignore[arg-type]

tests/test_graceful_shutdown.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ async def get_session_prompt(self, message: object) -> tuple[str, str]:
3434
_ = message
3535
return "s", "p"
3636

37+
def is_mentioned(self, message: object) -> bool:
38+
_ = message
39+
return True
40+
3741
async def process_output(self, session_id: str, output):
3842
_ = (session_id, output)
3943

0 commit comments

Comments
 (0)