Skip to content

Commit 84a8f06

Browse files
authored
Close session on disconnect and add context manager (noahhusby#151)
1 parent 2101bcb commit 84a8f06

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

aiostreammagic/stream_magic.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self, host: str, session: ClientSession | None = None) -> None:
5858
self._attempt_reconnection = False
5959
self._reconnect_task: Optional[Task[Any]] = None
6060
self.position_last_updated: datetime = datetime.now()
61+
self._subscription_tasks: dict[str, asyncio.Task[Any]] = {}
6162

6263
async def register_state_update_callbacks(self, callback: Any) -> None:
6364
"""Register state update callback."""
@@ -109,6 +110,14 @@ async def disconnect(self) -> None:
109110
except asyncio.CancelledError:
110111
pass
111112
await self.do_state_update_callbacks(CallbackType.CONNECTION)
113+
# Cancel all subscription handler tasks
114+
for task in self._subscription_tasks.values():
115+
task.cancel()
116+
await asyncio.gather(*self._subscription_tasks.values(), return_exceptions=True)
117+
self._subscription_tasks.clear()
118+
# Properly close the aiohttp session if it was created by this client
119+
if self.session is not None and not self.session.closed:
120+
await self.session.close()
112121

113122
def is_connected(self) -> bool:
114123
"""Return True if device is connected."""
@@ -230,7 +239,6 @@ async def consumer_handler(
230239
) -> None:
231240
"""Callback consumer handler."""
232241
subscription_queues = {}
233-
subscription_tasks = {}
234242
try:
235243
async for raw_msg in ws:
236244
if futures or subscriptions:
@@ -244,14 +252,13 @@ async def consumer_handler(
244252
if not future.done():
245253
future.set_result(msg)
246254
if subscription:
247-
if path not in subscription_tasks:
255+
if path not in self._subscription_tasks:
248256
queue: Queue[dict[str, Any]] = asyncio.Queue()
249257
subscription_queues[path] = queue
250-
subscription_tasks[path] = asyncio.create_task(
258+
self._subscription_tasks[path] = asyncio.create_task(
251259
self.subscription_handler(queue, subscription)
252260
)
253261
subscription_queues[path].put_nowait(msg)
254-
255262
except (asyncio.CancelledError,):
256263
pass
257264

@@ -645,3 +652,10 @@ async def set_auto_power_down(self, auto_power_down_time_seconds: int) -> None:
645652
await self.request(
646653
ep.POWER, params={"auto_power_down": auto_power_down_time_seconds}
647654
)
655+
656+
async def __aenter__(self):
657+
await self.connect()
658+
return self
659+
660+
async def __aexit__(self, exc_type, exc, tb):
661+
await self.disconnect()

0 commit comments

Comments
 (0)