Skip to content

Commit 8ce09ac

Browse files
committed
fix(runner): refactor SessionRunner to pass channel explicitly in process_message and _run methods
Signed-off-by: Frost Ming <me@frostming.com>
1 parent 216998e commit 8ce09ac

2 files changed

Lines changed: 13 additions & 17 deletions

File tree

src/bub/channels/manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ async def _process_input[T](self, channel: BaseChannel[T], message: T) -> None:
6969
session_id, _ = await channel.get_session_prompt(message)
7070
if session_id not in self._session_runners:
7171
self._session_runners[session_id] = SessionRunner(
72-
channel,
7372
session_id,
7473
self.runtime.settings.message_debounce_seconds,
7574
self.runtime.settings.message_delay_seconds,
7675
)
77-
await self._session_runners[session_id].process_message(message)
76+
await self._session_runners[session_id].process_message(channel, message)

src/bub/channels/runner.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@
77

88

99
class SessionRunner:
10-
def __init__(
11-
self, channel: BaseChannel, session_id: str, debounce_seconds: int, message_delay_seconds: int
12-
) -> None:
10+
def __init__(self, session_id: str, debounce_seconds: int, message_delay_seconds: int) -> None:
1311
self.session_id = session_id
14-
self.channel = channel
1512
self.debounce_seconds = debounce_seconds
1613
self.message_delay_seconds = message_delay_seconds
1714
self._prompts: list[str] = []
@@ -21,9 +18,9 @@ def __init__(
2118
self._running_task: asyncio.Task[None] | None = None
2219
self._loop = asyncio.get_running_loop()
2320

24-
async def _run(self) -> None:
21+
async def _run(self, channel: BaseChannel) -> None:
2522
await self._event.wait()
26-
prompt = f"channel: ${self.channel.output_channel}\n" + "\n".join(self._prompts)
23+
prompt = f"channel: ${channel.output_channel}\n" + "\n".join(self._prompts)
2724
self._prompts.clear()
2825
self._running_task = None
2926
if (
@@ -32,8 +29,8 @@ async def _run(self) -> None:
3229
):
3330
self._last_received_at = None
3431
try:
35-
result = await self.channel.run_prompt(self.session_id, prompt)
36-
await self.channel.process_output(self.session_id, result)
32+
result = await channel.run_prompt(self.session_id, prompt)
33+
await channel.process_output(self.session_id, result)
3734
except Exception:
3835
logger.exception("session.run.error session_id={}", self.session_id)
3936

@@ -43,9 +40,9 @@ def reset_timer(self, timeout: int) -> None:
4340
self._timer.cancel()
4441
self._timer = self._loop.call_later(timeout, self._event.set)
4542

46-
async def process_message(self, message: Any) -> None:
47-
is_mentioned = self.channel.is_mentioned(message)
48-
_, prompt = await self.channel.get_session_prompt(message)
43+
async def process_message(self, channel: BaseChannel, message: Any) -> None:
44+
is_mentioned = channel.is_mentioned(message)
45+
_, prompt = await channel.get_session_prompt(message)
4946
now = self._loop.time()
5047
if not is_mentioned and self._last_received_at is None:
5148
logger.info("session.receive ignored session_id={} message={}", self.session_id, prompt)
@@ -54,8 +51,8 @@ async def process_message(self, message: Any) -> None:
5451
if prompt.startswith(","):
5552
logger.info("session.receive.command session_id={} message={}", self.session_id, prompt)
5653
try:
57-
result = await self.channel.run_prompt(self.session_id, prompt)
58-
await self.channel.process_output(self.session_id, result)
54+
result = await channel.run_prompt(self.session_id, prompt)
55+
await channel.process_output(self.session_id, result)
5956
except Exception:
6057
logger.exception("session.run.error session_id={}", self.session_id)
6158
elif is_mentioned:
@@ -64,11 +61,11 @@ async def process_message(self, message: Any) -> None:
6461
logger.info("session.receive.mentioned session_id={} message={}", self.session_id, prompt)
6562
self.reset_timer(self.debounce_seconds)
6663
if self._running_task is None:
67-
self._running_task = asyncio.create_task(self._run())
64+
self._running_task = asyncio.create_task(self._run(channel))
6865
return await self._running_task
6966
elif self._last_received_at is not None and self._running_task is None:
7067
# Otherwise if bot is mentioned before, we will keep reading messages for at most 30 seconds.
7168
logger.info("session.receive followup session_id={} message={}", self.session_id, prompt)
7269
self.reset_timer(self.message_delay_seconds)
73-
self._running_task = asyncio.create_task(self._run())
70+
self._running_task = asyncio.create_task(self._run(channel))
7471
return await self._running_task

0 commit comments

Comments
 (0)