-
Notifications
You must be signed in to change notification settings - Fork 646
fix: tolerate redis outages in realtime websockets #612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
|
|
||
| PRESENCE_TTL_SECONDS = 180 | ||
| PUBSUB_PREFIX = "realtime:ws" | ||
| SUBSCRIBER_RETRY_SECONDS = 5 | ||
|
|
||
|
|
||
| class RealtimeRouter: | ||
|
|
@@ -42,31 +43,37 @@ async def register_connection( | |
| user_id: str | None, | ||
| ) -> str: | ||
| connection_id = uuid.uuid4().hex | ||
| redis = await get_redis() | ||
| setattr(websocket.state, "realtime_connection_id", connection_id) | ||
| payload = { | ||
| "agent_id": agent_id, | ||
| "session_id": session_id or "", | ||
| "user_id": user_id or "", | ||
| "instance_id": self.instance_id, | ||
| } | ||
| async with redis.pipeline(transaction=True) as pipe: | ||
| pipe.sadd(self._agent_index_key(agent_id), connection_id) | ||
| pipe.hset(self._connection_key(connection_id), mapping=payload) | ||
| pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS) | ||
| pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS) | ||
| await pipe.execute() | ||
| setattr(websocket.state, "realtime_connection_id", connection_id) | ||
| try: | ||
| redis = await get_redis() | ||
| async with redis.pipeline(transaction=True) as pipe: | ||
| pipe.sadd(self._agent_index_key(agent_id), connection_id) | ||
| pipe.hset(self._connection_key(connection_id), mapping=payload) | ||
| pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS) | ||
| pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS) | ||
| await pipe.execute() | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Redis presence unavailable; using local websocket only: {exc}") | ||
| return connection_id | ||
|
|
||
| async def unregister_connection(self, *, agent_id: str, websocket: WebSocket) -> None: | ||
| connection_id = getattr(websocket.state, "realtime_connection_id", None) | ||
| if not connection_id: | ||
| return | ||
| redis = await get_redis() | ||
| async with redis.pipeline(transaction=True) as pipe: | ||
| pipe.srem(self._agent_index_key(agent_id), connection_id) | ||
| pipe.delete(self._connection_key(connection_id)) | ||
| await pipe.execute() | ||
| try: | ||
| redis = await get_redis() | ||
| async with redis.pipeline(transaction=True) as pipe: | ||
| pipe.srem(self._agent_index_key(agent_id), connection_id) | ||
| pipe.delete(self._connection_key(connection_id)) | ||
| await pipe.execute() | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Redis presence cleanup failed: {exc}") | ||
|
|
||
| async def is_user_viewing_session(self, *, agent_id: str, session_id: str, user_id: str) -> bool: | ||
| for record in await self._list_presence(agent_id): | ||
|
|
@@ -118,21 +125,25 @@ async def route_message( | |
| if not remote_targets: | ||
| return | ||
|
|
||
| redis = await get_redis() | ||
| envelope = json.dumps( | ||
| { | ||
| "message": message, | ||
| "agent_id": agent_id, | ||
| "session_id": session_id, | ||
| "user_id": user_id, | ||
| "origin_instance_id": self.instance_id, | ||
| } | ||
| ) | ||
| publish_tasks = [ | ||
| redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope) | ||
| for instance_id in remote_targets | ||
| ] | ||
| await asyncio.gather(*publish_tasks, return_exceptions=True) | ||
| try: | ||
| redis = await get_redis() | ||
| envelope = json.dumps( | ||
| { | ||
| "message": message, | ||
| "agent_id": agent_id, | ||
| "session_id": session_id, | ||
| "user_id": user_id, | ||
| "origin_instance_id": self.instance_id, | ||
| } | ||
| ) | ||
| publish_tasks = [ | ||
| redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope) | ||
| for instance_id in remote_targets | ||
| ] | ||
| await asyncio.gather(*publish_tasks, return_exceptions=True) | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Redis pubsub unavailable; skipped remote websocket routing: {exc}") | ||
| return | ||
| logger.debug( | ||
| f"[Realtime] Routed agent={agent_id} local={local_sent} remote_instances={list(remote_targets.keys())}" | ||
| ) | ||
|
|
@@ -154,47 +165,67 @@ async def stop(self) -> None: | |
| self._started = False | ||
|
|
||
| async def _subscriber_loop(self, deliver_local) -> None: | ||
| redis = await get_redis() | ||
| pubsub = redis.pubsub() | ||
| await pubsub.subscribe(self._instance_channel()) | ||
| try: | ||
| while True: | ||
| message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) | ||
| if not message: | ||
| await asyncio.sleep(0.05) | ||
| continue | ||
| try: | ||
| data = json.loads(message["data"]) | ||
| await deliver_local( | ||
| agent_id=data["agent_id"], | ||
| payload=data["message"], | ||
| session_id=data.get("session_id"), | ||
| user_id=data.get("user_id"), | ||
| ) | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}") | ||
| except asyncio.CancelledError: | ||
| raise | ||
| finally: | ||
| await pubsub.unsubscribe(self._instance_channel()) | ||
| await pubsub.aclose() | ||
| while True: | ||
| pubsub = None | ||
| try: | ||
| redis = await get_redis() | ||
| pubsub = redis.pubsub() | ||
| await pubsub.subscribe(self._instance_channel()) | ||
| logger.info("[Realtime] Redis pubsub subscriber connected") | ||
| while True: | ||
| message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) | ||
| if not message: | ||
| await asyncio.sleep(0.05) | ||
| continue | ||
| try: | ||
| data = json.loads(message["data"]) | ||
| await deliver_local( | ||
| agent_id=data["agent_id"], | ||
| payload=data["message"], | ||
| session_id=data.get("session_id"), | ||
| user_id=data.get("user_id"), | ||
| ) | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}") | ||
| except asyncio.CancelledError: | ||
| raise | ||
| except Exception as exc: | ||
| logger.warning( | ||
| f"[Realtime] Redis pubsub subscriber unavailable; retrying in " | ||
| f"{SUBSCRIBER_RETRY_SECONDS}s: {exc}" | ||
| ) | ||
| await asyncio.sleep(SUBSCRIBER_RETRY_SECONDS) | ||
| finally: | ||
| if pubsub is not None: | ||
| try: | ||
| await pubsub.unsubscribe(self._instance_channel()) | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Failed to unsubscribe pubsub channel: {exc}") | ||
| try: | ||
| await pubsub.aclose() | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Failed to close pubsub connection: {exc}") | ||
|
|
||
| async def _list_presence(self, agent_id: str) -> list[dict[str, str]]: | ||
| redis = await get_redis() | ||
| connection_ids = await redis.smembers(self._agent_index_key(agent_id)) | ||
| if not connection_ids: | ||
| try: | ||
| redis = await get_redis() | ||
| connection_ids = await redis.smembers(self._agent_index_key(agent_id)) | ||
| if not connection_ids: | ||
| return [] | ||
| records: list[dict[str, str]] = [] | ||
| stale_ids: list[str] = [] | ||
| for connection_id in connection_ids: | ||
| data = await redis.hgetall(self._connection_key(connection_id)) | ||
| if not data: | ||
| stale_ids.append(connection_id) | ||
| continue | ||
| records.append(data) | ||
| if stale_ids: | ||
| await redis.srem(self._agent_index_key(agent_id), *stale_ids) | ||
| return records | ||
| except Exception as exc: | ||
| logger.warning(f"[Realtime] Redis presence lookup failed: {exc}") | ||
| return [] | ||
|
Comment on lines
+226
to
228
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
With Redis unavailable, Useful? React with 👍 / 👎. |
||
| records: list[dict[str, str]] = [] | ||
| stale_ids: list[str] = [] | ||
| for connection_id in connection_ids: | ||
| data = await redis.hgetall(self._connection_key(connection_id)) | ||
| if not data: | ||
| stale_ids.append(connection_id) | ||
| continue | ||
| records.append(data) | ||
| if stale_ids: | ||
| await redis.srem(self._agent_index_key(agent_id), *stale_ids) | ||
| return records | ||
|
|
||
|
|
||
| realtime_router = RealtimeRouter() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| import asyncio | ||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
|
|
||
| from app.services.realtime_runtime import router as realtime_router_module | ||
| from app.services.realtime_runtime.router import RealtimeRouter | ||
|
|
||
|
|
||
| class DummyWebSocket: | ||
| def __init__(self) -> None: | ||
| self.state = SimpleNamespace() | ||
| self.sent: list[dict] = [] | ||
|
|
||
| async def send_json(self, message: dict) -> None: | ||
| self.sent.append(message) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_realtime_router_falls_back_to_local_connections_when_redis_unavailable(monkeypatch): | ||
| async def unavailable_redis(): | ||
| raise ConnectionError("redis is down") | ||
|
|
||
| monkeypatch.setattr(realtime_router_module, "get_redis", unavailable_redis) | ||
|
|
||
| router = RealtimeRouter() | ||
| websocket = DummyWebSocket() | ||
|
|
||
| connection_id = await router.register_connection( | ||
| agent_id="agent-1", | ||
| websocket=websocket, | ||
| session_id="session-1", | ||
| user_id="user-1", | ||
| ) | ||
|
|
||
| assert connection_id | ||
| assert websocket.state.realtime_connection_id == connection_id | ||
|
|
||
| await router.route_message( | ||
| agent_id="agent-1", | ||
| message={"type": "chunk", "content": "hello"}, | ||
| local_connections=[(websocket, "session-1", "user-1")], | ||
| session_id="session-1", | ||
| user_id="user-1", | ||
| ) | ||
|
|
||
| assert websocket.sent == [{"type": "chunk", "content": "hello"}] | ||
|
|
||
| await router.unregister_connection(agent_id="agent-1", websocket=websocket) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_realtime_subscriber_retries_after_initial_redis_failure(monkeypatch): | ||
| subscribed = asyncio.Event() | ||
| calls = 0 | ||
|
|
||
| class FakePubSub: | ||
| async def subscribe(self, _channel: str) -> None: | ||
| subscribed.set() | ||
|
|
||
| async def get_message(self, *, ignore_subscribe_messages: bool, timeout: float): | ||
| await asyncio.sleep(0.01) | ||
| return None | ||
|
|
||
| async def unsubscribe(self, _channel: str) -> None: | ||
| pass | ||
|
|
||
| async def aclose(self) -> None: | ||
| pass | ||
|
|
||
| class FakeRedis: | ||
| def pubsub(self): | ||
| return FakePubSub() | ||
|
|
||
| async def flaky_redis(): | ||
| nonlocal calls | ||
| calls += 1 | ||
| if calls == 1: | ||
| raise ConnectionError("redis is down") | ||
| return FakeRedis() | ||
|
|
||
| monkeypatch.setattr(realtime_router_module, "get_redis", flaky_redis) | ||
| monkeypatch.setattr(realtime_router_module, "SUBSCRIBER_RETRY_SECONDS", 0) | ||
|
|
||
| router = RealtimeRouter() | ||
| task = asyncio.create_task(router._subscriber_loop(lambda **_kwargs: None)) | ||
| await asyncio.wait_for(subscribed.wait(), timeout=1) | ||
|
|
||
| task.cancel() | ||
| with pytest.raises(asyncio.CancelledError): | ||
| await task | ||
|
|
||
| assert calls >= 2 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_connection_manager_viewing_session_uses_local_connections_without_redis(monkeypatch): | ||
| from app.api import websocket as websocket_api | ||
|
|
||
| manager = websocket_api.ConnectionManager() | ||
| manager.active_connections["agent-1"] = [(DummyWebSocket(), "session-1", "user-1")] | ||
|
|
||
| async def redis_should_not_be_checked(*_args, **_kwargs): | ||
| raise AssertionError("local viewer should be detected before Redis presence lookup") | ||
|
|
||
| monkeypatch.setattr(websocket_api.realtime_router, "is_user_viewing_session", redis_should_not_be_checked) | ||
|
|
||
| assert await manager.is_user_viewing_session("agent-1", "session-1", "user-1") is True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When Redis is unavailable during
register_connection, the socket is kept only inactive_connectionsand this branch returns without ever writing its presence record later. If Redis recovers while the WebSocket stays open, remote instances still cannot discover this viewer in_list_presence, so cross-instance messages for that user/session are skipped until the user reconnects. Consider retrying the presence write or refreshing local connections once Redis is reachable again.Useful? React with 👍 / 👎.