Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion backend/app/api/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,18 @@ async def send_to_user(self, agent_id: str, user_id: str, message: dict):

async def get_active_session_ids(self, agent_id: str) -> list[str]:
"""Return distinct session IDs for all active WS connections of an agent."""
return await realtime_router.get_active_session_ids(agent_id)
seen: set[str] = set()
for _ws, session_id, _user_id in self._local_connections(agent_id):
if session_id:
seen.add(session_id)
seen.update(await realtime_router.get_active_session_ids(agent_id))
return list(seen)

async def is_user_viewing_session(self, agent_id: str, session_id: str, user_id: str) -> bool:
"""Return True if the given platform user currently has this exact session open."""
for _ws, local_session_id, local_user_id in self._local_connections(agent_id):
if local_session_id == session_id and local_user_id == user_id:
return True
return await realtime_router.is_user_viewing_session(
agent_id=agent_id,
session_id=session_id,
Expand Down
163 changes: 97 additions & 66 deletions backend/app/services/realtime_runtime/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

PRESENCE_TTL_SECONDS = 180
PUBSUB_PREFIX = "realtime:ws"
SUBSCRIBER_RETRY_SECONDS = 5


class RealtimeRouter:
Expand All @@ -42,31 +43,37 @@ async def register_connection(
user_id: str | None,
) -> str:
connection_id = uuid.uuid4().hex
redis = await get_redis()
setattr(websocket.state, "realtime_connection_id", connection_id)
payload = {
"agent_id": agent_id,
"session_id": session_id or "",
"user_id": user_id or "",
"instance_id": self.instance_id,
}
async with redis.pipeline(transaction=True) as pipe:
pipe.sadd(self._agent_index_key(agent_id), connection_id)
pipe.hset(self._connection_key(connection_id), mapping=payload)
pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS)
pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS)
await pipe.execute()
setattr(websocket.state, "realtime_connection_id", connection_id)
try:
redis = await get_redis()
async with redis.pipeline(transaction=True) as pipe:
pipe.sadd(self._agent_index_key(agent_id), connection_id)
pipe.hset(self._connection_key(connection_id), mapping=payload)
pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS)
pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS)
await pipe.execute()
except Exception as exc:
logger.warning(f"[Realtime] Redis presence unavailable; using local websocket only: {exc}")
return connection_id
Comment on lines +61 to 63
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Retry presence registration after Redis returns

When Redis is unavailable during register_connection, the socket is kept only in active_connections and this branch returns without ever writing its presence record later. If Redis recovers while the WebSocket stays open, remote instances still cannot discover this viewer in _list_presence, so cross-instance messages for that user/session are skipped until the user reconnects. Consider retrying the presence write or refreshing local connections once Redis is reachable again.

Useful? React with 👍 / 👎.


async def unregister_connection(self, *, agent_id: str, websocket: WebSocket) -> None:
connection_id = getattr(websocket.state, "realtime_connection_id", None)
if not connection_id:
return
redis = await get_redis()
async with redis.pipeline(transaction=True) as pipe:
pipe.srem(self._agent_index_key(agent_id), connection_id)
pipe.delete(self._connection_key(connection_id))
await pipe.execute()
try:
redis = await get_redis()
async with redis.pipeline(transaction=True) as pipe:
pipe.srem(self._agent_index_key(agent_id), connection_id)
pipe.delete(self._connection_key(connection_id))
await pipe.execute()
except Exception as exc:
logger.warning(f"[Realtime] Redis presence cleanup failed: {exc}")

async def is_user_viewing_session(self, *, agent_id: str, session_id: str, user_id: str) -> bool:
for record in await self._list_presence(agent_id):
Expand Down Expand Up @@ -118,21 +125,25 @@ async def route_message(
if not remote_targets:
return

redis = await get_redis()
envelope = json.dumps(
{
"message": message,
"agent_id": agent_id,
"session_id": session_id,
"user_id": user_id,
"origin_instance_id": self.instance_id,
}
)
publish_tasks = [
redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope)
for instance_id in remote_targets
]
await asyncio.gather(*publish_tasks, return_exceptions=True)
try:
redis = await get_redis()
envelope = json.dumps(
{
"message": message,
"agent_id": agent_id,
"session_id": session_id,
"user_id": user_id,
"origin_instance_id": self.instance_id,
}
)
publish_tasks = [
redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope)
for instance_id in remote_targets
]
await asyncio.gather(*publish_tasks, return_exceptions=True)
except Exception as exc:
logger.warning(f"[Realtime] Redis pubsub unavailable; skipped remote websocket routing: {exc}")
return
logger.debug(
f"[Realtime] Routed agent={agent_id} local={local_sent} remote_instances={list(remote_targets.keys())}"
)
Expand All @@ -154,47 +165,67 @@ async def stop(self) -> None:
self._started = False

async def _subscriber_loop(self, deliver_local) -> None:
redis = await get_redis()
pubsub = redis.pubsub()
await pubsub.subscribe(self._instance_channel())
try:
while True:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if not message:
await asyncio.sleep(0.05)
continue
try:
data = json.loads(message["data"])
await deliver_local(
agent_id=data["agent_id"],
payload=data["message"],
session_id=data.get("session_id"),
user_id=data.get("user_id"),
)
except Exception as exc:
logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}")
except asyncio.CancelledError:
raise
finally:
await pubsub.unsubscribe(self._instance_channel())
await pubsub.aclose()
while True:
pubsub = None
try:
redis = await get_redis()
pubsub = redis.pubsub()
await pubsub.subscribe(self._instance_channel())
logger.info("[Realtime] Redis pubsub subscriber connected")
while True:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if not message:
await asyncio.sleep(0.05)
continue
try:
data = json.loads(message["data"])
await deliver_local(
agent_id=data["agent_id"],
payload=data["message"],
session_id=data.get("session_id"),
user_id=data.get("user_id"),
)
except Exception as exc:
logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}")
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning(
f"[Realtime] Redis pubsub subscriber unavailable; retrying in "
f"{SUBSCRIBER_RETRY_SECONDS}s: {exc}"
)
await asyncio.sleep(SUBSCRIBER_RETRY_SECONDS)
finally:
if pubsub is not None:
try:
await pubsub.unsubscribe(self._instance_channel())
except Exception as exc:
logger.warning(f"[Realtime] Failed to unsubscribe pubsub channel: {exc}")
try:
await pubsub.aclose()
except Exception as exc:
logger.warning(f"[Realtime] Failed to close pubsub connection: {exc}")

async def _list_presence(self, agent_id: str) -> list[dict[str, str]]:
redis = await get_redis()
connection_ids = await redis.smembers(self._agent_index_key(agent_id))
if not connection_ids:
try:
redis = await get_redis()
connection_ids = await redis.smembers(self._agent_index_key(agent_id))
if not connection_ids:
return []
records: list[dict[str, str]] = []
stale_ids: list[str] = []
for connection_id in connection_ids:
data = await redis.hgetall(self._connection_key(connection_id))
if not data:
stale_ids.append(connection_id)
continue
records.append(data)
if stale_ids:
await redis.srem(self._agent_index_key(agent_id), *stale_ids)
return records
except Exception as exc:
logger.warning(f"[Realtime] Redis presence lookup failed: {exc}")
return []
Comment on lines +226 to 228
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve local viewer checks when Redis is unavailable

With Redis unavailable, _list_presence() now returns an empty list, so manager.is_user_viewing_session() always reports false even though register_connection() now allows local WebSockets to stay connected without Redis. In that local-only fallback scenario, maybe_mark_session_read_for_active_viewer() in backend/app/api/websocket.py stops marking messages read for the user currently viewing the session, leaving unread state stale while chat delivery works. The fallback needs to include local active connections or avoid treating Redis failure as no local viewers.

Useful? React with 👍 / 👎.

records: list[dict[str, str]] = []
stale_ids: list[str] = []
for connection_id in connection_ids:
data = await redis.hgetall(self._connection_key(connection_id))
if not data:
stale_ids.append(connection_id)
continue
records.append(data)
if stale_ids:
await redis.srem(self._agent_index_key(agent_id), *stale_ids)
return records


realtime_router = RealtimeRouter()
108 changes: 108 additions & 0 deletions backend/tests/test_realtime_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import asyncio
from types import SimpleNamespace

import pytest

from app.services.realtime_runtime import router as realtime_router_module
from app.services.realtime_runtime.router import RealtimeRouter


class DummyWebSocket:
def __init__(self) -> None:
self.state = SimpleNamespace()
self.sent: list[dict] = []

async def send_json(self, message: dict) -> None:
self.sent.append(message)


@pytest.mark.asyncio
async def test_realtime_router_falls_back_to_local_connections_when_redis_unavailable(monkeypatch):
async def unavailable_redis():
raise ConnectionError("redis is down")

monkeypatch.setattr(realtime_router_module, "get_redis", unavailable_redis)

router = RealtimeRouter()
websocket = DummyWebSocket()

connection_id = await router.register_connection(
agent_id="agent-1",
websocket=websocket,
session_id="session-1",
user_id="user-1",
)

assert connection_id
assert websocket.state.realtime_connection_id == connection_id

await router.route_message(
agent_id="agent-1",
message={"type": "chunk", "content": "hello"},
local_connections=[(websocket, "session-1", "user-1")],
session_id="session-1",
user_id="user-1",
)

assert websocket.sent == [{"type": "chunk", "content": "hello"}]

await router.unregister_connection(agent_id="agent-1", websocket=websocket)


@pytest.mark.asyncio
async def test_realtime_subscriber_retries_after_initial_redis_failure(monkeypatch):
subscribed = asyncio.Event()
calls = 0

class FakePubSub:
async def subscribe(self, _channel: str) -> None:
subscribed.set()

async def get_message(self, *, ignore_subscribe_messages: bool, timeout: float):
await asyncio.sleep(0.01)
return None

async def unsubscribe(self, _channel: str) -> None:
pass

async def aclose(self) -> None:
pass

class FakeRedis:
def pubsub(self):
return FakePubSub()

async def flaky_redis():
nonlocal calls
calls += 1
if calls == 1:
raise ConnectionError("redis is down")
return FakeRedis()

monkeypatch.setattr(realtime_router_module, "get_redis", flaky_redis)
monkeypatch.setattr(realtime_router_module, "SUBSCRIBER_RETRY_SECONDS", 0)

router = RealtimeRouter()
task = asyncio.create_task(router._subscriber_loop(lambda **_kwargs: None))
await asyncio.wait_for(subscribed.wait(), timeout=1)

task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

assert calls >= 2


@pytest.mark.asyncio
async def test_connection_manager_viewing_session_uses_local_connections_without_redis(monkeypatch):
from app.api import websocket as websocket_api

manager = websocket_api.ConnectionManager()
manager.active_connections["agent-1"] = [(DummyWebSocket(), "session-1", "user-1")]

async def redis_should_not_be_checked(*_args, **_kwargs):
raise AssertionError("local viewer should be detected before Redis presence lookup")

monkeypatch.setattr(websocket_api.realtime_router, "is_user_viewing_session", redis_should_not_be_checked)

assert await manager.is_user_viewing_session("agent-1", "session-1", "user-1") is True