@@ -45,14 +45,24 @@ class SubscribeRequest:
4545
4646
4747class Subscription :
48- def __init__ (self , subscription_id : int , session : AsyncSession ):
48+ def __init__ (
49+ self , subscription_id : int , session : AsyncSession , event_handler : Callable [[types .Event ], Awaitable [None ]]
50+ ):
4951 self .subscription_id = subscription_id
5052 self ._session = session
53+ self ._event_handler = event_handler
5154
5255 async def unsubscribe (self ) -> None :
5356 if not await self ._session ._base_session .transport .is_connected ():
5457 raise Exception ("cannot unsubscribe topic: session not established" )
5558
59+ subscriptions = self ._session ._subscriptions .get (self .subscription_id , None )
60+ if subscriptions is not None :
61+ subscriptions .pop (self , None )
62+ if len (subscriptions ) != 0 :
63+ self ._session ._subscriptions [self .subscription_id ] = subscriptions
64+ return None
65+
5666 unsubscribe = messages .Unsubscribe (
5767 messages .UnsubscribeFields (self ._session ._idgen .next (), self .subscription_id )
5868 )
@@ -79,7 +89,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7989 # PubSub data structures
8090 self ._publish_requests : dict [int , Future [None ]] = {}
8191 self ._subscribe_requests : dict [int , SubscribeRequest ] = {}
82- self ._subscriptions : dict [int , Callable [[ types . Event ], Awaitable [ None ] ]] = {}
92+ self ._subscriptions : dict [int , dict [ Subscription , Subscription ]] = {}
8393 self ._unsubscribe_requests : dict [int , types .UnsubscribeRequest ] = {}
8494
8595 self ._goodbye_request = Future ()
@@ -155,8 +165,14 @@ async def _process_incoming_message(self, msg: messages.Message):
155165 await self ._base_session .send (data )
156166 elif isinstance (msg , messages .Subscribed ):
157167 request = self ._subscribe_requests .pop (msg .request_id )
158- self ._subscriptions [msg .subscription_id ] = request .endpoint
159- request .future .set_result (Subscription (msg .subscription_id , self ))
168+ sub = Subscription (msg .subscription_id , self , request .endpoint )
169+ subscriptions = self ._subscriptions .get (msg .subscription_id , None )
170+ if subscriptions is None :
171+ self ._subscriptions [msg .subscription_id ] = {sub : sub }
172+ else :
173+ subscriptions [sub ] = sub
174+
175+ request .future .set_result (sub )
160176 elif isinstance (msg , messages .Unsubscribed ):
161177 request = self ._unsubscribe_requests .pop (msg .request_id )
162178 del self ._subscriptions [request .subscription_id ]
@@ -165,9 +181,11 @@ async def _process_incoming_message(self, msg: messages.Message):
165181 request = self ._publish_requests .pop (msg .request_id )
166182 request .set_result (None )
167183 elif isinstance (msg , messages .Event ):
168- endpoint = self ._subscriptions [msg .subscription_id ]
169184 try :
170- await endpoint (types .Event (msg .args , msg .kwargs , msg .details ))
185+ subscriptions = self ._subscriptions [msg .subscription_id ]
186+ event = types .Event (msg .args , msg .kwargs , msg .details )
187+ for subscription in subscriptions .keys ():
188+ await subscription ._event_handler (event )
171189 except Exception as e :
172190 print (e )
173191 elif isinstance (msg , messages .Error ):
0 commit comments