22import threading
33from enum import IntFlag
44from typing import Callable , Optional , Any
5+ import json
56import logging
67
8+ import aiohttp
79from aiohttp import ClientSession , ClientTimeout , web
810
911
@@ -26,6 +28,9 @@ def __init__(
2628 self .host = host
2729 self .port = port
2830 self .mode = mode
31+
32+ self .routes : dict [str , list [str ]] = {}
33+
2934 self .app = web .Application ()
3035 self .runner = web .AppRunner (self .app )
3136 self .loop = asyncio .new_event_loop ()
@@ -34,6 +39,9 @@ def __init__(
3439 self ._started_event = threading .Event ()
3540 self .unresolved_futures = []
3641
42+ self .websockets : dict [str , set [web .WebSocketResponse ]] = {}
43+ self .ws_clients : dict [str , aiohttp .ClientWebSocketResponse ] = {}
44+
3745 def _start_loop (self ):
3846 asyncio .set_event_loop (self .loop )
3947 self .loop .run_until_complete (self ._start_server ())
@@ -86,21 +94,117 @@ def register():
8694 self .app .router .add_route (method .upper (), path , handler )
8795
8896 self .loop .call_soon_threadsafe (register )
97+ if self .routes .get (path ) is not None :
98+ self .routes [path ] = [method ]
99+ else :
100+ self .routes [path ].append (method )
89101
90102 def add_websocket (self , path : str , handler_lambda ):
91- if not (self .mode & HTTPConnectorMode .server ):
92- return
103+ """
104+ In server mode:
105+ `path` is the HTTP path (e.g. "/ws").
106+ `handler_lambda(ws, request)` is called for each connection.
107+
108+ In client mode:
109+ `path` is the full WebSocket URL (e.g. "ws://example.com/ws").
110+ `handler_lambda(ws, msg)` is called for each incoming message.
111+ """
112+ # SERVER SIDE
113+ if self .mode & HTTPConnectorMode .server :
114+ if path not in self .websockets :
115+ self .websockets [path ] = set ()
93116
94- async def ws_handler (request ):
95- ws = web .WebSocketResponse ()
96- await ws .prepare (request )
97- await handler_lambda (ws , request )
98- return ws
117+ async def ws_handler (request ):
118+ ws = web .WebSocketResponse ()
119+ await ws .prepare (request )
99120
100- def register ():
101- self .app . router . add_get ( path , ws_handler )
121+ # register this connection
122+ self .websockets [ path ]. add ( ws )
102123
103- self .loop .call_soon_threadsafe (register )
124+ try :
125+ # user handler can read/write freely, e.g.:
126+ # async for msg in ws: ...
127+ await handler_lambda (ws , request )
128+ finally :
129+ # ensure it is removed on close
130+ self .websockets [path ].discard (ws )
131+ await ws .close ()
132+
133+ return ws
134+
135+ def register_server ():
136+ self .app .router .add_get (path , ws_handler )
137+
138+ self .loop .call_soon_threadsafe (register_server )
139+
140+ # CLIENT SIDE
141+ if self .mode & HTTPConnectorMode .client :
142+ async def connect_client_ws ():
143+ assert self .client_session is not None , "ClientSession not initialized"
144+ ws = await self .client_session .ws_connect (path )
145+ self .ws_clients [path ] = ws
146+
147+ try :
148+ async for msg in ws :
149+ # let user handler inspect/read messages and optionally write
150+ await handler_lambda (ws , msg )
151+ finally :
152+ # clean up on close
153+ if self .ws_clients .get (path ) is ws :
154+ del self .ws_clients [path ]
155+ await ws .close ()
156+
157+ def start_client ():
158+ asyncio .create_task (connect_client_ws ())
159+
160+ self .loop .call_soon_threadsafe (start_client )
161+
162+ def publish_websocket (
163+ self ,
164+ path : str ,
165+ payload : Optional [str | dict ],
166+ ):
167+ """
168+ Send `payload` over all WebSocket connections associated with `path`.
169+
170+ - For server mode: broadcasts to all connected clients on that route.
171+ - For client mode: sends to the single client WebSocket created for that URL.
172+ """
173+ if payload is None :
174+ msg = ""
175+ elif isinstance (payload , dict ):
176+ msg = json .dumps (payload )
177+ else :
178+ msg = str (payload )
179+
180+ async def _publish ():
181+ # collect all websockets (server + client) associated with this key
182+ server_conns = list (self .websockets .get (path , []))
183+ client_ws = self .ws_clients .get (path )
184+ all_conns = server_conns + ([client_ws ] if client_ws is not None else [])
185+
186+ dead_server = []
187+ dead_client = False
188+
189+ for ws in all_conns :
190+ try :
191+ await ws .send_str (msg )
192+ except Exception :
193+ # mark broken ones to be removed
194+ if ws in server_conns :
195+ dead_server .append (ws )
196+ elif ws is client_ws :
197+ dead_client = True
198+
199+ # cleanup broken server connections
200+ for ws in dead_server :
201+ self .websockets .get (path , set ()).discard (ws )
202+
203+ # cleanup broken client connection
204+ if dead_client and self .ws_clients .get (path ) is client_ws :
205+ del self .ws_clients [path ]
206+
207+ asyncio .run_coroutine_threadsafe (_publish (), self .loop )
104208
105209 def send_request (
106210 self ,
0 commit comments