Skip to content

Commit 7ddfcaa

Browse files
committed
Add better broadcast handling with client ID
1 parent 16e3565 commit 7ddfcaa

3 files changed

Lines changed: 77 additions & 16 deletions

File tree

examples/broadcast.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ async def update_counter(count: int):
1818
el.style(color = "red" if count % 2 == 0 else "blue")
1919

2020

21+
@app.connect
22+
async def enter(id: str):
23+
print(f"Client connected {id}")
24+
25+
26+
@app.disconnect
27+
async def exit(id: str):
28+
print(f"Client disconnected {id}")
29+
30+
2131
# --- 2. Server Side (Background Task) ---
2232
async def background_pinger():
2333
"""Simulates a server event happening every second."""

violetear/app.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,17 @@ async def get_favicon():
6868
# Registry of served styles to prevent duplicate route registration
6969
self.served_styles: Dict[str, StyleSheet] = {}
7070

71-
# Registry for client- and serverside functions
71+
# Registry for client- and server-side functions
7272
self.client_functions: Dict[str, Callable] = {}
7373
self.server_functions: Dict[str, Callable] = {}
7474

7575
# Names of client-side functions to run on load
7676
self.startup_functions: List[str] = []
7777

78+
# Registry for connect and disconnect handlers
79+
self.connect_functions: list[Callable] = []
80+
self.disconnect_functions: list[Callable] = []
81+
7882
# PWA Registry: route_scope_hash -> (Manifest, ServiceWorker)
7983
self.pwa_registry: Dict[str, tuple[Manifest, ServiceWorker]] = {}
8084

@@ -108,18 +112,18 @@ def get_service_worker(scope_hash: str):
108112
)
109113

110114
# In App.__init__
111-
self.socket_manager = SocketManager()
115+
self.socket_manager = SocketManager(self)
112116

113117
@self.api.websocket("/_violetear/ws")
114-
async def websocket_endpoint(websocket: WebSocket):
115-
await self.socket_manager.connect(websocket)
118+
async def websocket_endpoint(websocket: WebSocket, client_id: str):
119+
await self.socket_manager.connect(client_id, websocket)
116120
try:
117121
while True:
118122
# Keep the connection alive.
119123
# We can also listen for client-to-server messages here if needed later.
120124
await websocket.receive()
121125
except (WebSocketDisconnect, RuntimeError):
122-
self.socket_manager.disconnect(websocket)
126+
await self.socket_manager.disconnect(client_id)
123127

124128
def client(self, func: Callable):
125129
"""Decorator to mark a function to be compiled to the client."""
@@ -190,6 +194,30 @@ async def wrapper(body: BodyModel): # type: ignore
190194

191195
return func
192196

197+
def connect(self, func: Callable):
198+
"""
199+
Decorator to register a function to run
200+
everytime a client connects.
201+
202+
The callable receives a single parameter `client_id:str`.
203+
"""
204+
if not inspect.iscoroutinefunction(func):
205+
raise ValueError("Connect callbacks must be async")
206+
207+
self.connect_functions.append(func)
208+
209+
def disconnect(self, func: Callable):
210+
"""
211+
Decorator to register a function to run
212+
everytime a client disconnects.
213+
214+
The callable receives a single parameter `client_id:str`.
215+
"""
216+
if not inspect.iscoroutinefunction(func):
217+
raise ValueError("Connect callbacks must be async")
218+
219+
self.disconnect_functions.append(func)
220+
193221
def style(self, path: str, sheet: StyleSheet):
194222
"""
195223
Registers a stylesheet to be served by the app at a specific path.
@@ -620,16 +648,23 @@ async def broadcast(self, *args, **kwargs):
620648

621649

622650
class SocketManager:
623-
def __init__(self):
651+
def __init__(self, app: App):
624652
# Keep track of active connections
625-
self.active_connections: List[WebSocket] = []
653+
self.active_connections: dict[str, WebSocket] = {}
654+
self.app = app
626655

627-
async def connect(self, websocket: WebSocket):
656+
async def connect(self, client_id: str, websocket: WebSocket):
628657
await websocket.accept()
629-
self.active_connections.append(websocket)
658+
self.active_connections[client_id] = websocket
659+
660+
for func in self.app.connect_functions:
661+
await func(client_id)
662+
663+
async def disconnect(self, client_id: str):
664+
self.active_connections.pop(client_id)
630665

631-
def disconnect(self, websocket: WebSocket):
632-
self.active_connections.remove(websocket)
666+
for func in self.app.disconnect_functions:
667+
await func(client_id)
633668

634669
async def broadcast(self, func_name: str, args: tuple, kwargs: dict):
635670
"""
@@ -641,9 +676,9 @@ async def broadcast(self, func_name: str, args: tuple, kwargs: dict):
641676

642677
# Iterate over all connections and send the message
643678
# We use a copy of the list to avoid modification errors during iteration
644-
for connection in self.active_connections[:]:
679+
for id, connection in list(self.active_connections.items()):
645680
try:
646681
await connection.send_text(payload)
647682
except Exception:
648683
# If sending fails (e.g. client disconnected), remove it
649-
self.disconnect(connection)
684+
await self.disconnect(id)

violetear/client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
11
import asyncio
22
import json
3-
from js import document, window, WebSocket, console
4-
from pyodide.ffi import create_proxy
3+
import uuid
4+
5+
from violetear.storage import session
6+
7+
# Pyiodide-specific imports that don't work in the IDE
8+
from js import document, window, WebSocket, console # type: ignore
9+
from pyodide.ffi import create_proxy # type: ignore
10+
11+
12+
def get_client_id():
13+
client_id = session.get("VIOLETEAR_ID")
14+
15+
if client_id is None:
16+
client_id = str(uuid.uuid4())
17+
18+
session["VIOLETEAR_ID"] = client_id
19+
return client_id
520

621

722
def get_socket_url():
823
"""Calculates the correct WebSocket URL based on the current page."""
924
protocol = "wss" if window.location.protocol == "https:" else "ws"
1025
host = window.location.host
11-
return f"{protocol}://{host}/_violetear/ws"
26+
client_id = get_client_id()
27+
return f"{protocol}://{host}/_violetear/ws?client_id={client_id}"
1228

1329

1430
def setup_socket_listener(scope):

0 commit comments

Comments
 (0)