@@ -58,6 +58,7 @@ def __init__(self, host: str, session: ClientSession | None = None) -> None:
5858 self ._attempt_reconnection = False
5959 self ._reconnect_task : Optional [Task [Any ]] = None
6060 self .position_last_updated : datetime = datetime .now ()
61+ self ._subscription_tasks : dict [str , asyncio .Task [Any ]] = {}
6162
6263 async def register_state_update_callbacks (self , callback : Any ) -> None :
6364 """Register state update callback."""
@@ -109,6 +110,14 @@ async def disconnect(self) -> None:
109110 except asyncio .CancelledError :
110111 pass
111112 await self .do_state_update_callbacks (CallbackType .CONNECTION )
113+ # Cancel all subscription handler tasks
114+ for task in self ._subscription_tasks .values ():
115+ task .cancel ()
116+ await asyncio .gather (* self ._subscription_tasks .values (), return_exceptions = True )
117+ self ._subscription_tasks .clear ()
118+ # Properly close the aiohttp session if it was created by this client
119+ if self .session is not None and not self .session .closed :
120+ await self .session .close ()
112121
113122 def is_connected (self ) -> bool :
114123 """Return True if device is connected."""
@@ -230,7 +239,6 @@ async def consumer_handler(
230239 ) -> None :
231240 """Callback consumer handler."""
232241 subscription_queues = {}
233- subscription_tasks = {}
234242 try :
235243 async for raw_msg in ws :
236244 if futures or subscriptions :
@@ -244,14 +252,13 @@ async def consumer_handler(
244252 if not future .done ():
245253 future .set_result (msg )
246254 if subscription :
247- if path not in subscription_tasks :
255+ if path not in self . _subscription_tasks :
248256 queue : Queue [dict [str , Any ]] = asyncio .Queue ()
249257 subscription_queues [path ] = queue
250- subscription_tasks [path ] = asyncio .create_task (
258+ self . _subscription_tasks [path ] = asyncio .create_task (
251259 self .subscription_handler (queue , subscription )
252260 )
253261 subscription_queues [path ].put_nowait (msg )
254-
255262 except (asyncio .CancelledError ,):
256263 pass
257264
@@ -645,3 +652,10 @@ async def set_auto_power_down(self, auto_power_down_time_seconds: int) -> None:
645652 await self .request (
646653 ep .POWER , params = {"auto_power_down" : auto_power_down_time_seconds }
647654 )
655+
656+ async def __aenter__ (self ):
657+ await self .connect ()
658+ return self
659+
660+ async def __aexit__ (self , exc_type , exc , tb ):
661+ await self .disconnect ()
0 commit comments