@@ -68,13 +68,17 @@ async def get_favicon():
6868 # Registry of served styles to prevent duplicate route registration
6969 self .served_styles : Dict [str , StyleSheet ] = {}
7070
71- # Registry for client- and serverside functions
71+ # Registry for client- and server-side functions
7272 self .client_functions : Dict [str , Callable ] = {}
7373 self .server_functions : Dict [str , Callable ] = {}
7474
7575 # Names of client-side functions to run on load
7676 self .startup_functions : List [str ] = []
7777
78+ # Registry for connect and disconnect handlers
79+ self .connect_functions : list [Callable ] = []
80+ self .disconnect_functions : list [Callable ] = []
81+
7882 # PWA Registry: route_scope_hash -> (Manifest, ServiceWorker)
7983 self .pwa_registry : Dict [str , tuple [Manifest , ServiceWorker ]] = {}
8084
@@ -108,18 +112,18 @@ def get_service_worker(scope_hash: str):
108112 )
109113
110114 # In App.__init__
111- self .socket_manager = SocketManager ()
115+ self .socket_manager = SocketManager (self )
112116
113117 @self .api .websocket ("/_violetear/ws" )
114- async def websocket_endpoint (websocket : WebSocket ):
115- await self .socket_manager .connect (websocket )
118+ async def websocket_endpoint (websocket : WebSocket , client_id : str ):
119+ await self .socket_manager .connect (client_id , websocket )
116120 try :
117121 while True :
118122 # Keep the connection alive.
119123 # We can also listen for client-to-server messages here if needed later.
120124 await websocket .receive ()
121125 except (WebSocketDisconnect , RuntimeError ):
122- self .socket_manager .disconnect (websocket )
126+ await self .socket_manager .disconnect (client_id )
123127
124128 def client (self , func : Callable ):
125129 """Decorator to mark a function to be compiled to the client."""
@@ -190,6 +194,30 @@ async def wrapper(body: BodyModel): # type: ignore
190194
191195 return func
192196
197+ def connect (self , func : Callable ):
198+ """
199+ Decorator to register a function to run
200+ everytime a client connects.
201+
202+ The callable receives a single parameter `client_id:str`.
203+ """
204+ if not inspect .iscoroutinefunction (func ):
205+ raise ValueError ("Connect callbacks must be async" )
206+
207+ self .connect_functions .append (func )
208+
209+ def disconnect (self , func : Callable ):
210+ """
211+ Decorator to register a function to run
212+ everytime a client disconnects.
213+
214+ The callable receives a single parameter `client_id:str`.
215+ """
216+ if not inspect .iscoroutinefunction (func ):
217+ raise ValueError ("Connect callbacks must be async" )
218+
219+ self .disconnect_functions .append (func )
220+
193221 def style (self , path : str , sheet : StyleSheet ):
194222 """
195223 Registers a stylesheet to be served by the app at a specific path.
@@ -620,16 +648,23 @@ async def broadcast(self, *args, **kwargs):
620648
621649
622650class SocketManager :
623- def __init__ (self ):
651+ def __init__ (self , app : App ):
624652 # Keep track of active connections
625- self .active_connections : List [WebSocket ] = []
653+ self .active_connections : dict [str , WebSocket ] = {}
654+ self .app = app
626655
627- async def connect (self , websocket : WebSocket ):
656+ async def connect (self , client_id : str , websocket : WebSocket ):
628657 await websocket .accept ()
629- self .active_connections .append (websocket )
658+ self .active_connections [client_id ] = websocket
659+
660+ for func in self .app .connect_functions :
661+ await func (client_id )
662+
663+ async def disconnect (self , client_id : str ):
664+ self .active_connections .pop (client_id )
630665
631- def disconnect ( self , websocket : WebSocket ) :
632- self . active_connections . remove ( websocket )
666+ for func in self . app . disconnect_functions :
667+ await func ( client_id )
633668
634669 async def broadcast (self , func_name : str , args : tuple , kwargs : dict ):
635670 """
@@ -641,9 +676,9 @@ async def broadcast(self, func_name: str, args: tuple, kwargs: dict):
641676
642677 # Iterate over all connections and send the message
643678 # We use a copy of the list to avoid modification errors during iteration
644- for connection in self .active_connections [:] :
679+ for id , connection in list ( self .active_connections . items ()) :
645680 try :
646681 await connection .send_text (payload )
647682 except Exception :
648683 # If sending fails (e.g. client disconnected), remove it
649- self .disconnect (connection )
684+ await self .disconnect (id )
0 commit comments