@@ -168,6 +168,106 @@ def __init__(
168168 # Start background loading
169169 self ._loading_task = asyncio .create_task (self ._load_sitemaps ())
170170
async def __aenter__(self) -> SitemapRequestLoader:
    """Start the loader via `start()` and return it for use in an `async with` block."""
    await self.start()
    return self
175+
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_value: BaseException | None,
    exc_traceback: TracebackType | None,
) -> None:
    """Close the loader when the `async with` block ends.

    The exception details are accepted but not inspected. Nothing is returned, so an
    exception raised inside the block is not suppressed by this method.
    """
    await self.close()
184+
@override
async def get_total_count(self) -> int:
    """Return the `total_count` value recorded in the loader state."""
    return (await self._get_state()).total_count
190+
@override
async def get_handled_count(self) -> int:
    """Return the `handled_count` value recorded in the loader state."""
    return (await self._get_state()).handled_count
196+
@override
async def is_empty(self) -> bool:
    """Tell whether the URL queue in the loader state currently has no entries."""
    current = await self._get_state()
    if current.url_queue:
        return False
    return True
202+
@override
async def is_finished(self) -> bool:
    """Tell whether the loader has nothing left to hand out or wait for.

    True only when the URL queue is empty, no request is marked as in progress,
    and the background loading task is done.
    """
    current = await self._get_state()
    # Checks run in the same order as a chained `and`, so the task is only
    # consulted once both collections are empty.
    if current.url_queue:
        return False
    if len(current.in_progress) != 0:
        return False
    return self._loading_task.done()
208+
@override
async def fetch_next_request(self) -> Request | None:
    """Return the next request built from the URL queue, or `None` once finished.

    While `is_finished()` is false, the method polls the queue (sleeping 0.1 s when
    it is empty), pops the next URL under the queue lock, optionally passes the
    request options through the transform function, and records the resulting
    request URL as in progress before returning the request.
    """
    while True:
        if await self.is_finished():
            return None

        current = await self._get_state()
        if not current.url_queue:
            # Nothing buffered yet; give the background loader time to add URLs.
            await asyncio.sleep(0.1)
            continue

        async with self._queue_lock:
            # Another consumer may have drained the queue while we waited for the lock.
            if not current.url_queue:
                continue

            next_url = current.url_queue.popleft()
            options = RequestOptions(url=next_url, enqueue_strategy=self._enqueue_strategy)

            # Signal the producer as soon as the buffer has room again.
            if len(current.url_queue) < self._max_buffer_size:
                self._queue_has_capacity.set()

            transform = self._transform_request_function
            if transform:
                outcome = transform(options)
                if outcome == 'skip':
                    # A skipped URL no longer counts towards the total.
                    current.total_count -= 1
                    continue
                if outcome != 'unchanged':
                    options = outcome

            next_request = Request.from_url(**options)
            current.in_progress.add(next_request.url)

            return next_request
243+
@override
async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
    """Move a request from the in-progress set to the handled tally.

    If the request URL is not currently in progress, nothing changes. The method
    always returns `None`.
    """
    current = await self._get_state()
    handled_url = request.url
    if handled_url not in current.in_progress:
        return None
    current.in_progress.remove(handled_url)
    current.handled_count += 1
    return None
252+
async def start(self) -> None:
    """Launch the background sitemap loading task unless one is still running."""
    task = self._loading_task
    # Only spawn a new task when there is none or the previous one is done.
    if not task or task.done():
        self._loading_task = asyncio.create_task(self._load_sitemaps())
258+
async def abort_loading(self) -> None:
    """Cancel the background sitemap loading task if it is still running.

    After requesting cancellation the task is awaited, and the resulting
    `asyncio.CancelledError` is suppressed.
    """
    task = self._loading_task
    if not task or task.done():
        return
    task.cancel()
    with suppress(asyncio.CancelledError):
        await task
265+
async def close(self) -> None:
    """Abort any running sitemap loading, then tear down the loader state."""
    await self.abort_loading()
    state_holder = self._state
    await state_holder.teardown()
270+
171271 async def _get_state (self ) -> SitemapRequestLoaderState :
172272 """Initialize and return the current state."""
173273 if self ._state .is_initialized :
@@ -310,100 +410,3 @@ async def _load_sitemaps(self) -> None:
310410 except Exception :
311411 logger .exception ('Error loading sitemaps' )
312412 raise
313-
@override
async def get_total_count(self) -> int:
    """Return the `total_count` value recorded in the loader state."""
    return (await self._get_state()).total_count
319-
@override
async def get_handled_count(self) -> int:
    """Return the `handled_count` value recorded in the loader state."""
    return (await self._get_state()).handled_count
325-
@override
async def is_empty(self) -> bool:
    """Tell whether the URL queue in the loader state currently has no entries."""
    current = await self._get_state()
    if current.url_queue:
        return False
    return True
331-
@override
async def is_finished(self) -> bool:
    """Tell whether the loader has nothing left to hand out or wait for.

    True only when the URL queue is empty, no request is marked as in progress,
    and the background loading task is done.
    """
    current = await self._get_state()
    # Checks run in the same order as a chained `and`, so the task is only
    # consulted once both collections are empty.
    if current.url_queue:
        return False
    if len(current.in_progress) != 0:
        return False
    return self._loading_task.done()
337-
@override
async def fetch_next_request(self) -> Request | None:
    """Return the next request built from the URL queue, or `None` once finished.

    While `is_finished()` is false, the method polls the queue (sleeping 0.1 s when
    it is empty), pops the next URL under the queue lock, optionally passes the
    request options through the transform function, and records the resulting
    request URL as in progress before returning the request.
    """
    while True:
        if await self.is_finished():
            return None

        current = await self._get_state()
        if not current.url_queue:
            # Nothing buffered yet; give the background loader time to add URLs.
            await asyncio.sleep(0.1)
            continue

        async with self._queue_lock:
            # Another consumer may have drained the queue while we waited for the lock.
            if not current.url_queue:
                continue

            next_url = current.url_queue.popleft()
            options = RequestOptions(url=next_url, enqueue_strategy=self._enqueue_strategy)

            # Signal the producer as soon as the buffer has room again.
            if len(current.url_queue) < self._max_buffer_size:
                self._queue_has_capacity.set()

            transform = self._transform_request_function
            if transform:
                outcome = transform(options)
                if outcome == 'skip':
                    # A skipped URL no longer counts towards the total.
                    current.total_count -= 1
                    continue
                if outcome != 'unchanged':
                    options = outcome

            next_request = Request.from_url(**options)
            current.in_progress.add(next_request.url)

            return next_request
372-
@override
async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
    """Move a request from the in-progress set to the handled tally.

    If the request URL is not currently in progress, nothing changes. The method
    always returns `None`.
    """
    current = await self._get_state()
    handled_url = request.url
    if handled_url not in current.in_progress:
        return None
    current.in_progress.remove(handled_url)
    current.handled_count += 1
    return None
381-
async def abort_loading(self) -> None:
    """Cancel the background sitemap loading task if it is still running.

    After requesting cancellation the task is awaited, and the resulting
    `asyncio.CancelledError` is suppressed.
    """
    task = self._loading_task
    if not task or task.done():
        return
    task.cancel()
    with suppress(asyncio.CancelledError):
        await task
388-
async def start(self) -> None:
    """Launch the background sitemap loading task unless one is still running."""
    task = self._loading_task
    # Only spawn a new task when there is none or the previous one is done.
    if not task or task.done():
        self._loading_task = asyncio.create_task(self._load_sitemaps())
394-
async def close(self) -> None:
    """Abort any running sitemap loading, then tear down the loader state."""
    await self.abort_loading()
    state_holder = self._state
    await state_holder.teardown()
399-
async def __aenter__(self) -> SitemapRequestLoader:
    """Start the loader via `start()` and return it for use in an `async with` block."""
    await self.start()
    return self
404-
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_value: BaseException | None,
    exc_traceback: TracebackType | None,
) -> None:
    """Close the loader when the `async with` block ends.

    The exception details are accepted but not inspected. Nothing is returned, so an
    exception raised inside the block is not suppressed by this method.
    """
    await self.close()
0 commit comments