Skip to content

Commit 563ccc7

Browse files
committed
feat: add websockets client/server
1 parent 0bee669 commit 563ccc7

3 files changed

Lines changed: 155 additions & 18 deletions

File tree

src/rai_core/rai/communication/http/api.py

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import threading
33
from enum import IntFlag
44
from typing import Callable, Optional, Any
5+
import json
56
import logging
67

8+
import aiohttp
79
from aiohttp import ClientSession, ClientTimeout, web
810

911

@@ -26,6 +28,9 @@ def __init__(
2628
self.host = host
2729
self.port = port
2830
self.mode = mode
31+
32+
self.routes: dict[str, list[str]] = {}
33+
2934
self.app = web.Application()
3035
self.runner = web.AppRunner(self.app)
3136
self.loop = asyncio.new_event_loop()
@@ -34,6 +39,9 @@ def __init__(
3439
self._started_event = threading.Event()
3540
self.unresolved_futures = []
3641

42+
self.websockets: dict[str, set[web.WebSocketResponse]] = {}
43+
self.ws_clients: dict[str, aiohttp.ClientWebSocketResponse] = {}
44+
3745
def _start_loop(self):
3846
asyncio.set_event_loop(self.loop)
3947
self.loop.run_until_complete(self._start_server())
@@ -86,21 +94,117 @@ def register():
8694
self.app.router.add_route(method.upper(), path, handler)
8795

8896
self.loop.call_soon_threadsafe(register)
97+
if self.routes.get(path) is not None:
98+
self.routes[path] = [method]
99+
else:
100+
self.routes[path].append(method)
89101

90102
def add_websocket(self, path: str, handler_lambda):
91-
if not (self.mode & HTTPConnectorMode.server):
92-
return
103+
"""
104+
In server mode:
105+
`path` is the HTTP path (e.g. "/ws").
106+
`handler_lambda(ws, request)` is called for each connection.
107+
108+
In client mode:
109+
`path` is the full WebSocket URL (e.g. "ws://example.com/ws").
110+
`handler_lambda(ws, msg)` is called for each incoming message.
111+
"""
112+
# SERVER SIDE
113+
if self.mode & HTTPConnectorMode.server:
114+
if path not in self.websockets:
115+
self.websockets[path] = set()
93116

94-
async def ws_handler(request):
95-
ws = web.WebSocketResponse()
96-
await ws.prepare(request)
97-
await handler_lambda(ws, request)
98-
return ws
117+
async def ws_handler(request):
118+
ws = web.WebSocketResponse()
119+
await ws.prepare(request)
99120

100-
def register():
101-
self.app.router.add_get(path, ws_handler)
121+
# register this connection
122+
self.websockets[path].add(ws)
102123

103-
self.loop.call_soon_threadsafe(register)
124+
try:
125+
# user handler can read/write freely, e.g.:
126+
# async for msg in ws: ...
127+
await handler_lambda(ws, request)
128+
finally:
129+
# ensure it is removed on close
130+
self.websockets[path].discard(ws)
131+
await ws.close()
132+
133+
return ws
134+
135+
def register_server():
136+
self.app.router.add_get(path, ws_handler)
137+
138+
self.loop.call_soon_threadsafe(register_server)
139+
140+
# CLIENT SIDE
141+
if self.mode & HTTPConnectorMode.client:
142+
async def connect_client_ws():
143+
assert self.client_session is not None, "ClientSession not initialized"
144+
ws = await self.client_session.ws_connect(path)
145+
self.ws_clients[path] = ws
146+
147+
try:
148+
async for msg in ws:
149+
# let user handler inspect/read messages and optionally write
150+
await handler_lambda(ws, msg)
151+
finally:
152+
# clean up on close
153+
if self.ws_clients.get(path) is ws:
154+
del self.ws_clients[path]
155+
await ws.close()
156+
157+
def start_client():
158+
asyncio.create_task(connect_client_ws())
159+
160+
self.loop.call_soon_threadsafe(start_client)
161+
162+
def publish_websocket(
163+
self,
164+
path: str,
165+
payload: Optional[str | dict],
166+
):
167+
"""
168+
Send `payload` over all WebSocket connections associated with `path`.
169+
170+
- For server mode: broadcasts to all connected clients on that route.
171+
- For client mode: sends to the single client WebSocket created for that URL.
172+
"""
173+
if payload is None:
174+
msg = ""
175+
elif isinstance(payload, dict):
176+
msg = json.dumps(payload)
177+
else:
178+
msg = str(payload)
179+
180+
async def _publish():
181+
# collect all websockets (server + client) associated with this key
182+
server_conns = list(self.websockets.get(path, []))
183+
client_ws = self.ws_clients.get(path)
184+
all_conns = server_conns + ([client_ws] if client_ws is not None else [])
185+
186+
dead_server = []
187+
dead_client = False
188+
189+
for ws in all_conns:
190+
try:
191+
await ws.send_str(msg)
192+
except Exception:
193+
# mark broken ones to be removed
194+
if ws in server_conns:
195+
dead_server.append(ws)
196+
elif ws is client_ws:
197+
dead_client = True
198+
199+
# cleanup broken server connections
200+
for ws in dead_server:
201+
self.websockets.get(path, set()).discard(ws)
202+
203+
# cleanup broken client connection
204+
if dead_client and self.ws_clients.get(path) is client_ws:
205+
del self.ws_clients[path]
206+
207+
asyncio.run_coroutine_threadsafe(_publish(), self.loop)
104208

105209
def send_request(
106210
self,

src/rai_core/rai/communication/http/connectors/base.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Callable, Optional, TypeVar
1+
import time
2+
import logging
3+
from typing import Any, Callable, Optional, TypeVar, Dict
24

35
from rai.communication.base_connector import BaseConnector, BaseMessage
46
from rai.communication.http.messages import HTTPMessage
@@ -19,19 +21,45 @@ def __init__(
1921
self._api = HTTPAPI(mode, host, port)
2022
self._api.run()
2123
self._services = []
24+
self.last_msg: Dict[str, T] = {}
2225

2326
def send_message(self, message: T, target: str, **kwargs: Optional[Any]) -> None:
24-
self._api.send_request(
25-
message.method,
26-
target,
27-
None,
28-
payload=message.payload,
29-
headers=message.headers,
30-
)
27+
if message.protocol == "http":
28+
self._api.send_request(
29+
message.method,
30+
target,
31+
None,
32+
payload=message.payload,
33+
headers=message.headers,
34+
)
35+
else:
36+
# self._api.
3137

3238
def receive_message(
3339
self, source: str, timeout_sec: float, **kwargs: Optional[Any]
3440
) -> T:
41+
msg = None
42+
if self._api.routes.get(source, None) is not None:
43+
# a GET method has already been added...
44+
else:
45+
def local_callback(payload: Any) -> None:
46+
msg = payload
47+
self._api.add_route(
48+
"GET",
49+
source,
50+
self.general_callback
51+
)
52+
53+
start_time = time.time()
54+
# wait for the message to be received
55+
while time.time() - start_time < timeout_sec:
56+
if source in self.last_msg:
57+
return self.last_msg[source]
58+
time.sleep(0.1)
59+
else:
60+
raise TimeoutError(
61+
f"Message from {source} not received in {timeout_sec} seconds"
62+
)
3563
raise NotImplementedError("This method should be implemented by the subclass.")
3664

3765
def _safe_callback_wrapper(self, callback: Callable[[T], None], message: T) -> None:
@@ -74,6 +102,10 @@ def create_service(
74102
**kwargs: Optional[Any],
75103
) -> str:
76104
id_str = f"{method.upper()}_{service_name}"
105+
if on_done is not None:
106+
logging.warning(
107+
f"not None on_done argument passed to create_service of {self.__class__}; will have no effect!"
108+
)
77109
if id_str in self._services:
78110
raise HTTPAPIError(
79111
f"Service {service_name} already has a {method.upper()} handler"

src/rai_core/rai/communication/http/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
class HTTPMessage(BaseMessage):
21+
protocol: Literal["http", "websocket"]
2122
headers: Optional[dict] = None
2223
method: Literal[
2324
"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE", "POST", "PATCH", "CONNECT"

0 commit comments

Comments
 (0)