Skip to content

Commit 4b2ac99

Browse files
committed
fix: allow multiple subscriptions to the same URI in a single session
1 parent 5c51923 commit 4b2ac99

2 files changed

Lines changed: 46 additions & 12 deletions

File tree

xconn/async_session.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,24 @@ class SubscribeRequest:
4545

4646

4747
class Subscription:
48-
def __init__(self, subscription_id: int, session: AsyncSession):
48+
def __init__(
49+
self, subscription_id: int, session: AsyncSession, event_handler: Callable[[types.Event], Awaitable[None]]
50+
):
4951
self.subscription_id = subscription_id
5052
self._session = session
53+
self._event_handler = event_handler
5154

5255
async def unsubscribe(self) -> None:
5356
if not await self._session._base_session.transport.is_connected():
5457
raise Exception("cannot unsubscribe topic: session not established")
5558

59+
subscriptions = self._session._subscriptions.get(self.subscription_id, None)
60+
if subscriptions is not None:
61+
subscriptions.pop(self, None)
62+
if len(subscriptions) != 0:
63+
self._session._subscriptions[self.subscription_id] = subscriptions
64+
return None
65+
5666
unsubscribe = messages.Unsubscribe(
5767
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
5868
)
@@ -79,7 +89,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7989
# PubSub data structures
8090
self._publish_requests: dict[int, Future[None]] = {}
8191
self._subscribe_requests: dict[int, SubscribeRequest] = {}
82-
self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
92+
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
8393
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
8494

8595
self._goodbye_request = Future()
@@ -155,8 +165,14 @@ async def _process_incoming_message(self, msg: messages.Message):
155165
await self._base_session.send(data)
156166
elif isinstance(msg, messages.Subscribed):
157167
request = self._subscribe_requests.pop(msg.request_id)
158-
self._subscriptions[msg.subscription_id] = request.endpoint
159-
request.future.set_result(Subscription(msg.subscription_id, self))
168+
sub = Subscription(msg.subscription_id, self, request.endpoint)
169+
subscriptions = self._subscriptions.get(msg.subscription_id, None)
170+
if subscriptions is None:
171+
self._subscriptions[msg.subscription_id] = {sub: sub}
172+
else:
173+
subscriptions[sub] = sub
174+
175+
request.future.set_result(sub)
160176
elif isinstance(msg, messages.Unsubscribed):
161177
request = self._unsubscribe_requests.pop(msg.request_id)
162178
del self._subscriptions[request.subscription_id]
@@ -165,9 +181,11 @@ async def _process_incoming_message(self, msg: messages.Message):
165181
request = self._publish_requests.pop(msg.request_id)
166182
request.set_result(None)
167183
elif isinstance(msg, messages.Event):
168-
endpoint = self._subscriptions[msg.subscription_id]
169184
try:
170-
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
185+
subscriptions = self._subscriptions[msg.subscription_id]
186+
event = types.Event(msg.args, msg.kwargs, msg.details)
187+
for subscription in subscriptions.keys():
188+
await subscription._event_handler(event)
171189
except Exception as e:
172190
print(e)
173191
elif isinstance(msg, messages.Error):

xconn/session.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,22 @@ class SubscribeRequest:
4444

4545

4646
class Subscription:
47-
def __init__(self, subscription_id: int, session: Session):
47+
def __init__(self, subscription_id: int, session: Session, event_handler: Callable[[types.Event], None]):
4848
self.subscription_id = subscription_id
4949
self._session = session
50+
self._event_handler = event_handler
5051

5152
def unsubscribe(self) -> None:
5253
if not self._session._base_session.transport.is_connected():
5354
raise Exception("cannot unsubscribe topic: session not established")
5455

56+
subscriptions = self._session._subscriptions.get(self.subscription_id, None)
57+
if subscriptions is not None:
58+
subscriptions.pop(self, None)
59+
if len(subscriptions) != 0:
60+
self._session._subscriptions[self.subscription_id] = subscriptions
61+
return None
62+
5563
unsubscribe = messages.Unsubscribe(
5664
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
5765
)
@@ -75,7 +83,7 @@ def __init__(self, base_session: types.BaseSession):
7583
# PubSub data structures
7684
self._publish_requests: dict[int, Future[None]] = {}
7785
self._subscribe_requests: dict[int, SubscribeRequest] = {}
78-
self._subscriptions: dict[int, Callable[[types.Event], None]] = {}
86+
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
7987
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
8088

8189
self._goodbye_request = Future()
@@ -150,8 +158,14 @@ def _process_incoming_message(self, msg: messages.Message):
150158
self._base_session.send(data)
151159
elif isinstance(msg, messages.Subscribed):
152160
request = self._subscribe_requests.pop(msg.request_id)
153-
self._subscriptions[msg.subscription_id] = request.endpoint
154-
request.future.set_result(Subscription(msg.subscription_id, self))
161+
sub = Subscription(msg.subscription_id, self, request.endpoint)
162+
subscriptions = self._subscriptions.get(msg.subscription_id, None)
163+
if subscriptions is None:
164+
self._subscriptions[msg.subscription_id] = {sub: sub}
165+
else:
166+
subscriptions[sub] = sub
167+
168+
request.future.set_result(sub)
155169
elif isinstance(msg, messages.Unsubscribed):
156170
request = self._unsubscribe_requests.pop(msg.request_id)
157171
del self._subscriptions[request.subscription_id]
@@ -161,8 +175,10 @@ def _process_incoming_message(self, msg: messages.Message):
161175
request.set_result(None)
162176
elif isinstance(msg, messages.Event):
163177
try:
164-
endpoint = self._subscriptions[msg.subscription_id]
165-
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
178+
subscriptions = self._subscriptions[msg.subscription_id]
179+
event = types.Event(msg.args, msg.kwargs, msg.details)
180+
for subscription in subscriptions.keys():
181+
subscription._event_handler(event)
166182
except Exception as e:
167183
print(e)
168184
elif isinstance(msg, messages.Error):

0 commit comments

Comments
 (0)