diff --git a/docs/guides/request_loaders.mdx b/docs/guides/request_loaders.mdx index 2c5607c8ff..a63393142d 100644 --- a/docs/guides/request_loaders.mdx +++ b/docs/guides/request_loaders.mdx @@ -136,7 +136,7 @@ The `SitemapRequestLoader` is The `SitemapRequestLoader` is designed specifically for sitemaps that follow the standard Sitemaps protocol. HTML pages containing links are not supported by this loader - those should be handled by regular crawlers using the `enqueue_links` functionality. ::: -The loader supports filtering URLs using glob patterns and regular expressions, allowing you to include or exclude specific types of URLs. The `SitemapRequestLoader` provides streaming processing of sitemaps, ensuring efficient memory usage without loading the entire sitemap into memory. +The loader supports filtering URLs using glob patterns and regular expressions, allowing you to include or exclude specific types of URLs. By default, the loader also keeps only URLs whose host matches their parent sitemap (`enqueue_strategy='same-hostname'`), matching the `enqueue_links` default. Pass `enqueue_strategy='all'` to disable this filter, or `'same-domain'` / `'same-origin'` for other scopes. The `SitemapRequestLoader` provides streaming processing of sitemaps, ensuring efficient memory usage without loading the entire sitemap into memory. {SitemapExample} diff --git a/src/crawlee/_utils/robots.py b/src/crawlee/_utils/robots.py index 67583c90eb..d4a5e00cbc 100644 --- a/src/crawlee/_utils/robots.py +++ b/src/crawlee/_utils/robots.py @@ -7,11 +7,13 @@ from yarl import URL from crawlee._utils.sitemap import Sitemap +from crawlee._utils.urls import filter_url from crawlee._utils.web import is_status_code_client_error if TYPE_CHECKING: from typing_extensions import Self + from crawlee._types import EnqueueStrategy from crawlee.http_clients import HttpClient from crawlee.proxy_configuration import ProxyInfo @@ -21,7 +23,11 @@ class RobotsTxtFile: def __init__( - self, url: str, robots: Protego, http_client: HttpClient | None = None, proxy_info: ProxyInfo | None = None + self, + url: str, + robots: Protego, + http_client: HttpClient | None = None, + proxy_info: ProxyInfo | None = None, ) -> None: self._robots = robots self._original_url = URL(url).origin() @@ -39,18 +45,6 @@ async def from_content(cls, url: str, content: str) -> Self: robots = Protego.parse(content) return cls(url, robots) - @classmethod - async def find(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | None = None) -> Self: - """Determine the location of a robots.txt file for a URL and fetch it. - - Args: - url: The URL whose domain will be used to find the corresponding robots.txt file. - http_client: Optional `ProxyInfo` to be used when fetching the robots.txt file. If None, no proxy is used. - proxy_info: The `HttpClient` instance used to perform the network request for fetching the robots.txt file. - """ - robots_url = URL(url).with_path('/robots.txt') - return await cls.load(str(robots_url), http_client, proxy_info) - @classmethod async def load(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | None = None) -> Self: """Load the robots.txt file for a given URL. 
@@ -77,6 +71,18 @@ async def load(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | N return cls(url, robots, http_client=http_client, proxy_info=proxy_info) + @classmethod + async def find(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | None = None) -> Self: + """Determine the location of a robots.txt file for a URL and fetch it. + + Args: + url: The URL whose domain will be used to find the corresponding robots.txt file. + http_client: The `HttpClient` instance used to perform the network request for fetching the robots.txt file. + proxy_info: Optional `ProxyInfo` to be used when fetching the robots.txt file. If None, no proxy is used. + """ + robots_url = URL(url).with_path('/robots.txt') + return await cls.load(str(robots_url), http_client, proxy_info) + def is_allowed(self, url: str, user_agent: str = '*') -> bool: """Check if the given URL is allowed for the given user agent. @@ -89,9 +95,25 @@ def is_allowed(self, url: str, user_agent: str = '*') -> bool: return True return bool(self._robots.can_fetch(str(check_url), user_agent)) - def get_sitemaps(self) -> list[str]: - """Get the list of sitemaps urls from the robots.txt file.""" - return list(self._robots.sitemaps) + def get_sitemaps(self, *, enqueue_strategy: EnqueueStrategy) -> list[str]: + """Get the list of sitemap URLs from the robots.txt file, filtered by enqueue strategy. + + Args: + enqueue_strategy: Strategy used to filter sitemap entries relative to the robots.txt URL's host. + Pass `'same-hostname'` to match the sitemap protocol's same-host expectation, or `'all'` to + disable host filtering. Regardless of the strategy, entries with non-`http(s)` schemes are + always filtered out. + """ + sitemaps: list[str] = [] + for sitemap_url in self._robots.sitemaps: + ok, reason = filter_url(target=sitemap_url, strategy=enqueue_strategy, origin=self._original_url) + if not ok: + logger.warning( + f'Skipping sitemap {sitemap_url!r} listed in robots.txt at {str(self._original_url)!r}: {reason}.' + ) + continue + sitemaps.append(sitemap_url) + return sitemaps def get_crawl_delay(self, user_agent: str = '*') -> int | None: """Get the crawl delay for the given user agent. @@ -103,15 +125,23 @@ def get_crawl_delay(self, user_agent: str = '*') -> int | None: crawl_delay = self._robots.crawl_delay(user_agent) return int(crawl_delay) if crawl_delay is not None else None - async def parse_sitemaps(self) -> Sitemap: - """Parse the sitemaps from the robots.txt file and return a `Sitemap` instance.""" - sitemaps = self.get_sitemaps() + async def parse_sitemaps(self, *, enqueue_strategy: EnqueueStrategy) -> Sitemap: + """Parse the sitemaps from the robots.txt file and return a `Sitemap` instance. + + Args: + enqueue_strategy: Forwarded to `get_sitemaps`; see that method for details. + """ + sitemaps = self.get_sitemaps(enqueue_strategy=enqueue_strategy) if not self._http_client: raise ValueError('HTTP client is required to parse sitemaps.') return await Sitemap.load(sitemaps, self._http_client, self._proxy_info) - async def parse_urls_from_sitemaps(self) -> list[str]: - """Parse the sitemaps in the robots.txt file and return a list URLs.""" - sitemap = await self.parse_sitemaps() + async def parse_urls_from_sitemaps(self, *, enqueue_strategy: EnqueueStrategy) -> list[str]: + """Parse the sitemaps in the robots.txt file and return a list of URLs. + + Args: + enqueue_strategy: Forwarded to `get_sitemaps`; see that method for details.
+ """ + sitemap = await self.parse_sitemaps(enqueue_strategy=enqueue_strategy) return sitemap.urls diff --git a/src/crawlee/_utils/sitemap.py b/src/crawlee/_utils/sitemap.py index b90d2e6935..2632ed98e5 100644 --- a/src/crawlee/_utils/sitemap.py +++ b/src/crawlee/_utils/sitemap.py @@ -546,7 +546,7 @@ def _check_and_add(url: str) -> bool: # Try getting sitemaps from robots.txt first robots = await RobotsTxtFile.find(url=hostname_urls[0], http_client=http_client, proxy_info=proxy_info) - for sitemap_url in robots.get_sitemaps(): + for sitemap_url in robots.get_sitemaps(enqueue_strategy='same-hostname'): if _check_and_add(sitemap_url): yield sitemap_url diff --git a/src/crawlee/_utils/urls.py b/src/crawlee/_utils/urls.py index 0bc5a051c7..43a52e8385 100644 --- a/src/crawlee/_utils/urls.py +++ b/src/crawlee/_utils/urls.py @@ -1,14 +1,30 @@ from __future__ import annotations +import tempfile +from functools import lru_cache from typing import TYPE_CHECKING from pydantic import AnyHttpUrl, TypeAdapter +from tldextract import TLDExtract +from typing_extensions import assert_never from yarl import URL if TYPE_CHECKING: from collections.abc import Iterator from logging import Logger + from crawlee._types import EnqueueStrategy + + +_ALLOWED_SCHEMES: frozenset[str] = frozenset({'http', 'https'}) +"""URL schemes Crawlee accepts for fetching and enqueuing.""" + +UNSUPPORTED_SCHEME_MESSAGE = 'unsupported URL scheme (only http and https are allowed).' +"""Reusable suffix for log messages explaining why a non-`http(s)` URL was rejected.""" + +_HTTP_URL_ADAPTER: TypeAdapter[AnyHttpUrl] = TypeAdapter(AnyHttpUrl) +"""Pydantic validator for HTTP and HTTPS URLs.""" + def is_url_absolute(url: str) -> bool: """Check if a URL is absolute.""" @@ -38,16 +54,102 @@ def to_absolute_url_iterator(base_url: str, urls: Iterator[str], logger: Logger yield converted_url -_http_url_adapter = TypeAdapter(AnyHttpUrl) - - def validate_http_url(value: str | None) -> str | None: """Validate the given HTTP URL. + Args: + value: The URL to validate, or `None` to skip validation. + Raises: - pydantic.ValidationError: If the URL is not valid. + pydantic.ValidationError: If the URL is malformed or its scheme is not `http`/`https`. """ if value is not None: - _http_url_adapter.validate_python(value) + _HTTP_URL_ADAPTER.validate_python(value) return value + + +def filter_url( + *, + target: str | URL, + strategy: EnqueueStrategy, + origin: str | URL, +) -> tuple[bool, str | None]: + """Check whether `target` is eligible to be enqueued under `strategy` relative to `origin`. + + Combines the two checks every enqueue site needs: the URL must use a supported scheme + (`http` or `https`), and it must match `strategy` relative to `origin`. Callers that need to + distinguish a scheme rejection from a strategy mismatch (for different log levels or dedup) + can compare the returned reason against `UNSUPPORTED_SCHEME_MESSAGE`. + + Args: + target: The URL being evaluated. + strategy: The enqueue strategy to apply. + origin: The reference URL the target is compared against. + + Returns: + `(True, None)` if `target` is eligible. Otherwise `(False, reason)` where `reason` is + a human-readable rejection message suitable for log output. 
+ """ + target_url = _to_url(target) + + if not _is_supported_url_scheme(target_url): + return False, UNSUPPORTED_SCHEME_MESSAGE + + if not _matches_enqueue_strategy(strategy, target_url=target_url, origin_url=_to_url(origin)): + return False, f'does not match enqueue strategy {strategy!r}' + + return True, None + + +def _is_supported_url_scheme(url: str | URL) -> bool: + """Return whether `url` uses a scheme Crawlee accepts (http or https).""" + return _to_url(url).scheme in _ALLOWED_SCHEMES + + +def _matches_enqueue_strategy( + strategy: EnqueueStrategy, + *, + target_url: URL, + origin_url: URL, +) -> bool: + """Check whether `target_url` matches `origin_url` under `strategy`. Scheme is not considered.""" + if strategy == 'all': + return True + + if origin_url.host is None or target_url.host is None: + return False + + if strategy == 'same-hostname': + return target_url.host == origin_url.host + + if strategy == 'same-domain': + return _domain_under_public_suffix(origin_url.host) == _domain_under_public_suffix(target_url.host) + + if strategy == 'same-origin': + return ( + target_url.host == origin_url.host + and target_url.scheme == origin_url.scheme + and target_url.port == origin_url.port + ) + + assert_never(strategy) + + +def _to_url(value: str | URL) -> URL: + return URL(value) if isinstance(value, str) else value + + +@lru_cache(maxsize=1) +def _get_tld_extractor() -> TLDExtract: + """Return a lazily-initialized `TLDExtract` instance shared across the module.""" + # `mkdtemp` (vs `TemporaryDirectory`) returns a path whose lifetime is tied to the process — `TemporaryDirectory` + # is collected immediately when its return value is discarded, which would race the directory out from under + # tldextract. + return TLDExtract(cache_dir=tempfile.mkdtemp()) + + +@lru_cache(maxsize=2048) +def _domain_under_public_suffix(host: str) -> str: + """Return the registrable domain for `host`, cached to avoid re-running the PSL lookup.""" + return _get_tld_extractor().extract_str(host).top_domain_under_public_suffix diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 636d38775c..ae58f60d82 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -6,7 +6,6 @@ import logging import signal import sys -import tempfile import threading import traceback from asyncio import CancelledError @@ -17,15 +16,13 @@ from io import StringIO from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast -from urllib.parse import ParseResult, urlparse from weakref import WeakKeyDictionary from cachetools import LRUCache -from tldextract import TLDExtract -from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never +from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack from yarl import URL -from crawlee import EnqueueStrategy, Glob, RequestTransformAction, service_locator +from crawlee import Glob, RequestTransformAction, service_locator from crawlee._autoscaling import AutoscaledPool, Snapshotter, SystemStatus from crawlee._log_config import configure_logger, get_configured_log_level, string_to_log_level from crawlee._request import Request, RequestOptions, RequestState @@ -48,7 +45,7 @@ from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.robots import RobotsTxtFile -from crawlee._utils.urls import 
convert_to_absolute_url, is_url_absolute +from crawlee._utils.urls import UNSUPPORTED_SCHEME_MESSAGE, convert_to_absolute_url, filter_url, is_url_absolute from crawlee._utils.wait import wait_for from crawlee._utils.web import is_status_code_client_error, is_status_code_server_error from crawlee.errors import ( @@ -485,7 +482,6 @@ async def persist_state_factory() -> KeyValueStore: # Internal, not explicitly configurable components self._robots_txt_file_cache: LRUCache[str, RobotsTxtFile] = LRUCache(maxsize=1000) self._robots_txt_lock = asyncio.Lock() - self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) self._snapshotter = Snapshotter.from_config(config) self._autoscaled_pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), @@ -978,16 +974,18 @@ def _should_retry_request(self, context: BasicCrawlingContext, error: Exception) async def _check_url_after_redirects(self, context: TCrawlingContext) -> AsyncGenerator[TCrawlingContext, None]: """Ensure that the `loaded_url` still matches the enqueue strategy after redirects. - Filter out links that redirect outside of the crawled domain. + Filter out links that redirect outside of the crawled domain or to unsupported URL schemes. """ - if context.request.loaded_url is not None and not self._check_enqueue_strategy( - context.request.enqueue_strategy, - origin_url=urlparse(context.request.url), - target_url=urlparse(context.request.loaded_url), - ): - raise ContextPipelineInterruptedError( - f'Skipping URL {context.request.loaded_url} (redirected from {context.request.url})' + if context.request.loaded_url is not None: + ok, reason = filter_url( + target=context.request.loaded_url, + strategy=context.request.enqueue_strategy, + origin=context.request.url, ) + if not ok: + raise ContextPipelineInterruptedError( + f'Skipping URL {context.request.loaded_url} (redirected from {context.request.url}): {reason}' + ) yield context @@ -1054,15 +1052,16 @@ def _enqueue_links_filter_iterator( ) -> Iterator[TRequestIterator]: """Filter requests based on the enqueue strategy and URL patterns.""" limit = kwargs.get('limit') - parsed_origin_url = urlparse(origin_url) + parsed_origin_url = URL(origin_url) strategy = kwargs.get('strategy', 'all') - if strategy == 'all' and not parsed_origin_url.hostname: + if strategy == 'all' and not parsed_origin_url.host: self.log.warning(f'Skipping enqueue: Missing hostname in origin_url = {origin_url}.') return - # Emit a `warning` message to the log, only once per call - warning_flag = True + # Each warning is emitted at most once per call. + host_warned = False + scheme_warned = False for request in request_iterator: if isinstance(request, Request): @@ -1071,15 +1070,22 @@ def _enqueue_links_filter_iterator( target_url = request.url else: target_url = request - parsed_target_url = urlparse(target_url) - - if warning_flag and strategy != 'all' and not parsed_target_url.hostname: - self.log.warning(f'Skipping enqueue url: Missing hostname in target_url = {target_url}.') - warning_flag = False - - if self._check_enqueue_strategy( - strategy, target_url=parsed_target_url, origin_url=parsed_origin_url - ) and self._check_url_patterns(target_url, kwargs.get('include'), kwargs.get('exclude')): + parsed_target_url = URL(target_url) + + ok, reason = filter_url(target=parsed_target_url, strategy=strategy, origin=parsed_origin_url) + if not ok: + # Strategy mismatches are expected (most extracted links are external) so stay silent. 
+ # Scheme rejections and missing hostnames signal a misconfiguration upstream, so warn. + if reason == UNSUPPORTED_SCHEME_MESSAGE: + if not scheme_warned: + self.log.warning(f'Skipping enqueue url {target_url!r}: {reason}') + scheme_warned = True + elif not parsed_target_url.host and not host_warned: + self.log.warning(f'Skipping enqueue url: Missing hostname in target_url = {target_url}.') + host_warned = True + continue + + if self._check_url_patterns(target_url, kwargs.get('include'), kwargs.get('exclude')): yield request if limit is not None: @@ -1087,41 +1093,6 @@ def _enqueue_links_filter_iterator( if limit <= 0: break - def _check_enqueue_strategy( - self, - strategy: EnqueueStrategy, - *, - target_url: ParseResult, - origin_url: ParseResult, - ) -> bool: - """Check if a URL matches the enqueue_strategy.""" - if strategy == 'all': - return True - - if origin_url.hostname is None or target_url.hostname is None: - self.log.debug( - f'Skipping enqueue: Missing hostname in origin_url = {origin_url.geturl()} or ' - f'target_url = {target_url.geturl()}' - ) - return False - - if strategy == 'same-hostname': - return target_url.hostname == origin_url.hostname - - if strategy == 'same-domain': - origin_domain = self._tld_extractor.extract_str(origin_url.hostname).top_domain_under_public_suffix - target_domain = self._tld_extractor.extract_str(target_url.hostname).top_domain_under_public_suffix - return origin_domain == target_domain - - if strategy == 'same-origin': - return ( - target_url.hostname == origin_url.hostname - and target_url.scheme == origin_url.scheme - and target_url.port == origin_url.port - ) - - assert_never(strategy) - def _check_url_patterns( self, target_url: str, diff --git a/src/crawlee/request_loaders/_sitemap_request_loader.py b/src/crawlee/request_loaders/_sitemap_request_loader.py index e1c4c4d2e1..664686b23f 100644 --- a/src/crawlee/request_loaders/_sitemap_request_loader.py +++ b/src/crawlee/request_loaders/_sitemap_request_loader.py @@ -8,12 +8,14 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import override +from yarl import URL from crawlee import Request, RequestOptions from crawlee._utils.docs import docs_group from crawlee._utils.globs import Glob from crawlee._utils.recoverable_state import RecoverableState from crawlee._utils.sitemap import NestedSitemap, ParseSitemapOptions, SitemapSource, SitemapUrl, parse_sitemap +from crawlee._utils.urls import filter_url from crawlee.request_loaders._request_loader import RequestLoader if TYPE_CHECKING: @@ -22,6 +24,7 @@ from types import TracebackType from crawlee import RequestTransformAction + from crawlee._types import EnqueueStrategy from crawlee.http_clients import HttpClient from crawlee.proxy_configuration import ProxyInfo from crawlee.storage_clients.models import ProcessedRequest @@ -111,6 +114,7 @@ def __init__( proxy_info: ProxyInfo | None = None, include: list[re.Pattern[Any] | Glob] | None = None, exclude: list[re.Pattern[Any] | Glob] | None = None, + enqueue_strategy: EnqueueStrategy = 'same-hostname', max_buffer_size: int = 200, persist_state_key: str | None = None, transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None, @@ -122,6 +126,11 @@ def __init__( proxy_info: Optional proxy to use for fetching sitemaps. include: List of glob or regex patterns to include URLs. exclude: List of glob or regex patterns to exclude URLs. 
+ enqueue_strategy: Strategy used to decide which sitemap-derived URLs (both nested-sitemap entries and + URL entries) are kept relative to the parent sitemap URL. Defaults to `'same-hostname'`, matching + the sitemap protocol's same-host expectation and the `enqueue_links` default; pass `'all'` to + disable filtering. Note: regardless of `enqueue_strategy`, entries with non-`http(s)` schemes are + always filtered out. max_buffer_size: Maximum number of URLs to buffer in memory. http_client: the instance of `HttpClient` to use for fetching sitemaps. persist_state_key: A key for persisting the loader's state in the KeyValueStore. @@ -135,6 +144,7 @@ def __init__( self._sitemap_urls = sitemap_urls self._include = include self._exclude = exclude + self._enqueue_strategy = enqueue_strategy self._proxy_info = proxy_info self._max_buffer_size = max_buffer_size self._transform_request_function = transform_request_function @@ -158,6 +168,106 @@ def __init__( # Start background loading self._loading_task = asyncio.create_task(self._load_sitemaps()) + async def __aenter__(self) -> SitemapRequestLoader: + """Enter the context manager.""" + await self.start() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + """Exit the context manager.""" + await self.close() + + @override + async def get_total_count(self) -> int: + """Return the total number of URLs found so far.""" + state = await self._get_state() + return state.total_count + + @override + async def get_handled_count(self) -> int: + """Return the number of URLs that have been handled.""" + state = await self._get_state() + return state.handled_count + + @override + async def is_empty(self) -> bool: + """Check if there are no more URLs to process.""" + state = await self._get_state() + return not state.url_queue + + @override + async def is_finished(self) -> bool: + """Check if all URLs have been processed.""" + state = await self._get_state() + return not state.url_queue and len(state.in_progress) == 0 and self._loading_task.done() + + @override + async def fetch_next_request(self) -> Request | None: + """Fetch the next request to process.""" + while not (await self.is_finished()): + state = await self._get_state() + if not state.url_queue: + await asyncio.sleep(0.1) + continue + + async with self._queue_lock: + # Double-check if the queue is still not empty after acquiring the lock + if not state.url_queue: + continue + + url = state.url_queue.popleft() + request_option = RequestOptions(url=url, enqueue_strategy=self._enqueue_strategy) + + if len(state.url_queue) < self._max_buffer_size: + self._queue_has_capacity.set() + + if self._transform_request_function: + transform_request_option = self._transform_request_function(request_option) + if transform_request_option == 'skip': + state.total_count -= 1 + continue + if transform_request_option != 'unchanged': + request_option = transform_request_option + + request = Request.from_url(**request_option) + state.in_progress.add(request.url) + + return request + + return None + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as successfully handled.""" + state = await self._get_state() + if request.url in state.in_progress: + state.in_progress.remove(request.url) + state.handled_count += 1 + return None + + async def start(self) -> None: + """Start the sitemap loading process.""" + if self._loading_task and 
not self._loading_task.done(): + return + self._loading_task = asyncio.create_task(self._load_sitemaps()) + + async def abort_loading(self) -> None: + """Abort the sitemap loading process.""" + if self._loading_task and not self._loading_task.done(): + self._loading_task.cancel() + with suppress(asyncio.CancelledError): + await self._loading_task + + async def close(self) -> None: + """Close the request loader.""" + await self.abort_loading() + await self._state.teardown() + async def _get_state(self) -> SitemapRequestLoaderState: """Initialize and return the current state.""" if self._state.is_initialized: @@ -190,6 +300,14 @@ async def _get_state(self) -> SitemapRequestLoaderState: return self._state.current_value + def _passes_filters(self, target: str, parent: URL, kind: str) -> bool: + """Filter `target` by URL scheme and enqueue strategy, logging the reason if rejected.""" + ok, reason = filter_url(target=target, strategy=self._enqueue_strategy, origin=parent) + if not ok: + logger.warning(f'Skipping {kind} {target!r} (parent {str(parent)!r}): {reason}.') + return False + return True + def _check_url_patterns( self, target_url: str, @@ -235,6 +353,7 @@ async def _load_sitemaps(self) -> None: state.in_progress_sitemap_url = sitemap_url parse_options = ParseSitemapOptions(max_depth=0, emit_nested_sitemaps=True, sitemap_retries=3) + parsed_sitemap_url = URL(sitemap_url) async for item in parse_sitemap( [SitemapSource(type='url', url=sitemap_url)], @@ -245,6 +364,8 @@ async def _load_sitemaps(self) -> None: if isinstance(item, NestedSitemap): # Add nested sitemap to queue if item.loc not in state.pending_sitemap_urls and item.loc not in state.processed_sitemap_urls: + if not self._passes_filters(item.loc, parsed_sitemap_url, 'nested sitemap'): + continue state.pending_sitemap_urls.append(item.loc) continue @@ -261,6 +382,9 @@ async def _load_sitemaps(self) -> None: if not self._check_url_patterns(url, self._include, self._exclude): continue + if not self._passes_filters(url, parsed_sitemap_url, 'sitemap URL'): + continue + # Check if we have capacity in the queue await self._queue_has_capacity.wait() @@ -286,100 +410,3 @@ async def _load_sitemaps(self) -> None: except Exception: logger.exception('Error loading sitemaps') raise - - @override - async def get_total_count(self) -> int: - """Return the total number of URLs found so far.""" - state = await self._get_state() - return state.total_count - - @override - async def get_handled_count(self) -> int: - """Return the number of URLs that have been handled.""" - state = await self._get_state() - return state.handled_count - - @override - async def is_empty(self) -> bool: - """Check if there are no more URLs to process.""" - state = await self._get_state() - return not state.url_queue - - @override - async def is_finished(self) -> bool: - """Check if all URLs have been processed.""" - state = await self._get_state() - return not state.url_queue and len(state.in_progress) == 0 and self._loading_task.done() - - @override - async def fetch_next_request(self) -> Request | None: - """Fetch the next request to process.""" - while not (await self.is_finished()): - state = await self._get_state() - if not state.url_queue: - await asyncio.sleep(0.1) - continue - - async with self._queue_lock: - # Double-check if the queue is still not empty after acquiring the lock - if not state.url_queue: - continue - - url = state.url_queue.popleft() - request_option = RequestOptions(url=url) - - if len(state.url_queue) < self._max_buffer_size: - 
self._queue_has_capacity.set() - - if self._transform_request_function: - transform_request_option = self._transform_request_function(request_option) - if transform_request_option == 'skip': - state.total_count -= 1 - continue - if transform_request_option != 'unchanged': - request_option = transform_request_option - - request = Request.from_url(**request_option) - state.in_progress.add(request.url) - - return request - - return None - - @override - async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: - """Mark a request as successfully handled.""" - state = await self._get_state() - if request.url in state.in_progress: - state.in_progress.remove(request.url) - state.handled_count += 1 - return None - - async def abort_loading(self) -> None: - """Abort the sitemap loading process.""" - if self._loading_task and not self._loading_task.done(): - self._loading_task.cancel() - with suppress(asyncio.CancelledError): - await self._loading_task - - async def start(self) -> None: - """Start the sitemap loading process.""" - if self._loading_task and not self._loading_task.done(): - return - self._loading_task = asyncio.create_task(self._load_sitemaps()) - - async def close(self) -> None: - """Close the request loader.""" - await self.abort_loading() - await self._state.teardown() - - async def __aenter__(self) -> SitemapRequestLoader: - """Enter the context manager.""" - await self.start() - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None - ) -> None: - """Exit the context manager.""" - await self.close() diff --git a/tests/unit/_utils/test_robots.py b/tests/unit/_utils/test_robots.py index 61dc60daa5..5f4ad28f22 100644 --- a/tests/unit/_utils/test_robots.py +++ b/tests/unit/_utils/test_robots.py @@ -11,8 +11,10 @@ async def test_generation_robots_txt_url(server_url: URL, http_client: HttpClient) -> None: + """`RobotsTxtFile.find` constructs the correct /robots.txt URL and successfully parses the response.""" robots_file = await RobotsTxtFile.find(str(server_url), http_client) - assert len(robots_file.get_sitemaps()) > 0 + # The fixture's robots.txt disallows /deny_all/ — proves the file was fetched and parsed. + assert not robots_file.is_allowed(str(server_url / 'deny_all/page.html')) async def test_allow_disallow_robots_txt(server_url: URL, http_client: HttpClient) -> None: @@ -24,9 +26,49 @@ async def test_allow_disallow_robots_txt(server_url: URL, http_client: HttpClien async def test_extract_sitemaps_urls(server_url: URL, http_client: HttpClient) -> None: + """Cross-host sitemap entries are dropped under the `'same-hostname'` enqueue strategy.""" robots = await RobotsTxtFile.find(str(server_url), http_client) - assert len(robots.get_sitemaps()) == 2 - assert set(robots.get_sitemaps()) == {'http://not-exists.com/sitemap_1.xml', 'http://not-exists.com/sitemap_2.xml'} + # The fixture lists `http://not-exists.com/sitemap_*.xml`, which is cross-host relative to `server_url`. 
+ assert robots.get_sitemaps(enqueue_strategy='same-hostname') == [] + + +async def test_extract_same_host_sitemaps_urls() -> None: + """Sitemap entries on the same host as the robots.txt are returned.""" + content = 'User-agent: *\nSitemap: http://example.com/sitemap_1.xml\nSitemap: http://example.com/sitemap_2.xml\n' + robots = await RobotsTxtFile.from_content('http://example.com/robots.txt', content) + assert set(robots.get_sitemaps(enqueue_strategy='same-hostname')) == { + 'http://example.com/sitemap_1.xml', + 'http://example.com/sitemap_2.xml', + } + + +async def test_extract_sitemaps_urls_filters_cross_host_and_non_http() -> None: + """Cross-host and non-http(s) `Sitemap:` directives in robots.txt are silently filtered.""" + content = ( + 'User-agent: *\n' + 'Sitemap: http://example.com/legit.xml\n' + 'Sitemap: http://other.test/payload.xml\n' + 'Sitemap: gopher://internal:6379/_PING\n' + 'Sitemap: ftp://example.com/payload.xml\n' + ) + robots = await RobotsTxtFile.from_content('http://example.com/robots.txt', content) + assert robots.get_sitemaps(enqueue_strategy='same-hostname') == ['http://example.com/legit.xml'] + + +async def test_get_sitemaps_with_strategy_all_returns_cross_host() -> None: + """`enqueue_strategy='all'` disables host filtering but still rejects non-http(s) schemes.""" + content = ( + 'User-agent: *\n' + 'Sitemap: http://example.com/legit.xml\n' + 'Sitemap: http://other.test/payload.xml\n' + 'Sitemap: gopher://internal:6379/_PING\n' + 'Sitemap: ftp://example.com/payload.xml\n' + ) + robots = await RobotsTxtFile.from_content('http://example.com/robots.txt', content) + assert set(robots.get_sitemaps(enqueue_strategy='all')) == { + 'http://example.com/legit.xml', + 'http://other.test/payload.xml', + } async def test_parse_from_content() -> None: diff --git a/tests/unit/_utils/test_urls.py b/tests/unit/_utils/test_urls.py index 659ef09226..f42343aae3 100644 --- a/tests/unit/_utils/test_urls.py +++ b/tests/unit/_utils/test_urls.py @@ -1,9 +1,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from pydantic import ValidationError -from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute, validate_http_url +from crawlee._utils.urls import ( + convert_to_absolute_url, + filter_url, + is_url_absolute, + validate_http_url, +) + +if TYPE_CHECKING: + from crawlee._types import EnqueueStrategy def test_is_url_absolute() -> None: @@ -55,3 +65,37 @@ def test_validate_http_url() -> None: def test_validate_http_url_rejects_non_http_scheme(invalid_url: str) -> None: with pytest.raises(ValidationError): validate_http_url(invalid_url) + + +@pytest.mark.parametrize( + ('strategy', 'origin', 'target', 'expected'), + [ + # 'all' lets http(s) through across hosts, but rejects non-http(s) schemes + ('all', 'https://example.com/', 'https://other.test/', True), + ('all', 'https://example.com/', 'gopher://internal:6379/_PING', False), + ('all', 'https://example.com/', 'mailto:foo@bar.com', False), + ('all', 'https://example.com/', 'javascript:alert(1)', False), + ('all', 'https://example.com/', 'ftp://example.com/', False), + # 'same-hostname' is exact host equality + ('same-hostname', 'https://example.com/a', 'https://example.com/b', True), + ('same-hostname', 'https://example.com/', 'https://www.example.com/', False), + ('same-hostname', 'https://example.com/', 'https://other.test/', False), + ('same-hostname', 'https://example.com/', 'mailto:foo@example.com', False), + # 'same-domain' allows subdomains under the same registrable 
domain + ('same-domain', 'https://example.com/', 'https://www.example.com/', True), + ('same-domain', 'https://example.com/', 'https://api.example.com/', True), + ('same-domain', 'https://example.com/', 'https://other.test/', False), + ('same-domain', 'https://example.com/', 'ftp://www.example.com/', False), + # 'same-origin' requires scheme + host + port match + ('same-origin', 'https://example.com/', 'https://example.com/path', True), + ('same-origin', 'https://example.com/', 'http://example.com/', False), + ('same-origin', 'https://example.com/', 'https://example.com:8443/', False), + # missing hostname rejects everything except 'all' + ('same-hostname', 'https://example.com/', 'not-a-url', False), + ('same-domain', 'not-a-url', 'https://example.com/', False), + ], +) +def test_filter_url(strategy: EnqueueStrategy, origin: str, target: str, *, expected: bool) -> None: + ok, reason = filter_url(target=target, strategy=strategy, origin=origin) + assert ok is expected + assert (reason is None) is expected diff --git a/tests/unit/request_loaders/test_sitemap_request_loader.py b/tests/unit/request_loaders/test_sitemap_request_loader.py index d3742dfd6f..98741578f1 100644 --- a/tests/unit/request_loaders/test_sitemap_request_loader.py +++ b/tests/unit/request_loaders/test_sitemap_request_loader.py @@ -57,7 +57,7 @@ def encode_base64(data: bytes) -> str: async def test_sitemap_traversal(server_url: URL, http_client: HttpClient) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client) + sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='all') while not await sitemap_loader.is_finished(): item = await sitemap_loader.fetch_next_request() @@ -73,7 +73,7 @@ async def test_sitemap_traversal(server_url: URL, http_client: HttpClient) -> No async def test_is_empty_does_not_depend_on_fetch_next_request(server_url: URL, http_client: HttpClient) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client) + sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='all') items = [] @@ -98,7 +98,9 @@ async def test_is_empty_does_not_depend_on_fetch_next_request(server_url: URL, h async def test_abort_sitemap_loading(server_url: URL, http_client: HttpClient) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], max_buffer_size=2, http_client=http_client) + sitemap_loader = SitemapRequestLoader( + [str(sitemap_url)], max_buffer_size=2, http_client=http_client, enqueue_strategy='all' + ) item = await sitemap_loader.fetch_next_request() assert item is not None @@ -121,7 +123,9 @@ async def test_create_persist_state_for_sitemap_loading( ) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) persist_key = 'create_persist_state' - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, persist_state_key=persist_key) + sitemap_loader = SitemapRequestLoader( + [str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all' + ) assert await sitemap_loader.is_finished() is False await sitemap_loader.close() @@ -141,7 +145,9 @@ async 
def wait_for_sitemap_loader_not_empty(sitemap_loader: SitemapRequestLoader sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) persist_key = 'data_persist_state' - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, persist_state_key=persist_key) + sitemap_loader = SitemapRequestLoader( + [str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all' + ) # Give time to load await asyncio.wait_for(wait_for_sitemap_loader_not_empty(sitemap_loader), timeout=10) @@ -161,7 +167,9 @@ async def test_recovery_data_persistence_for_sitemap_loading( ) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) persist_key = 'recovery_persist_state' - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, persist_state_key=persist_key) + sitemap_loader = SitemapRequestLoader( + [str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all' + ) item = await sitemap_loader.fetch_next_request() @@ -175,7 +183,9 @@ async def test_recovery_data_persistence_for_sitemap_loading( assert state_data is not None next_item_in_kvs = state_data['urlQueue'][0] - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, persist_state_key=persist_key) + sitemap_loader = SitemapRequestLoader( + [str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all' + ) item = await sitemap_loader.fetch_next_request() @@ -195,6 +205,7 @@ def transform_request(request_options: RequestOptions) -> RequestOptions | Reque [str(sitemap_url)], http_client=http_client, transform_request_function=transform_request, + enqueue_strategy='all', ) extracted_urls = set() @@ -229,6 +240,7 @@ def transform_request(_request_options: RequestOptions) -> RequestOptions | Requ [str(sitemap_url)], http_client=http_client, transform_request_function=transform_request, + enqueue_strategy='all', ) while not await sitemap_loader.is_finished(): @@ -251,7 +263,7 @@ async def test_sitemap_loader_to_tandem( ) -> None: sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) - sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client) + sitemap_loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='all') request_manager = await sitemap_loader.to_tandem() while not await sitemap_loader.is_finished(): @@ -276,6 +288,7 @@ async def test_sitemap_loader_to_tandem_with_request_dropped( sitemap_loader = SitemapRequestLoader( [str(sitemap_url)], http_client=http_client, + enqueue_strategy='all', ) request_manager = await sitemap_loader.to_tandem() @@ -293,3 +306,117 @@ async def test_sitemap_loader_to_tandem_with_request_dropped( assert await request_manager.is_empty() assert await request_manager.is_finished() + + +def _make_urlset(urls: list[str]) -> str: + """Build a `` sitemap XML containing the given URLs.""" + url_blocks = '\n'.join(f'{url}' for url in urls) + return ( + '\n' + '\n' + f'{url_blocks}\n' + '' + ) + + +def _make_sitemapindex(sitemap_urls: list[str]) -> str: + """Build a `` XML pointing at the given nested sitemap URLs.""" + sitemap_blocks = '\n'.join(f'{url}' for url in sitemap_urls) + return ( + '\n' + '\n' + f'{sitemap_blocks}\n' + '' + ) + + +async def test_sitemap_loader_filters_cross_host_urls(server_url: URL, http_client: HttpClient) -> None: + """Default 
strategy `same-hostname` filters out content URLs that are not on the sitemap's host.""" + same_host_url = str(server_url / 'page') + cross_host_url = 'http://other.test/payload' + sitemap_content = _make_urlset([same_host_url, cross_host_url]) + sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(sitemap_content.encode())) + + loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client) + + fetched: list[str] = [] + while not await loader.is_finished(): + request = await loader.fetch_next_request() + if request is not None: + fetched.append(request.url) + await loader.mark_request_as_handled(request) + + assert fetched == [same_host_url] + + +async def test_sitemap_loader_filters_cross_host_nested_sitemap(server_url: URL, http_client: HttpClient) -> None: + """Nested `` entries on a different host are dropped before fetching them.""" + child_content = _make_urlset([str(server_url / 'inner')]) + same_host_child_url = str((server_url / 'sitemap.xml').with_query(base64=encode_base64(child_content.encode()))) + cross_host_child_url = 'http://other.test/child.xml' + index_content = _make_sitemapindex([same_host_child_url, cross_host_child_url]) + index_url = str((server_url / 'sitemap.xml').with_query(base64=encode_base64(index_content.encode()))) + + loader = SitemapRequestLoader([index_url], http_client=http_client) + + fetched: list[str] = [] + while not await loader.is_finished(): + request = await loader.fetch_next_request() + if request is not None: + fetched.append(request.url) + await loader.mark_request_as_handled(request) + + assert fetched == [str(server_url / 'inner')] + + +async def test_sitemap_loader_stamps_request_enqueue_strategy(server_url: URL, http_client: HttpClient) -> None: + """Emitted `Request` objects carry the loader's enqueue strategy so redirects are policed downstream.""" + same_host_url = str(server_url / 'page') + sitemap_content = _make_urlset([same_host_url]) + sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(sitemap_content.encode())) + + loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='same-domain') + request = await loader.fetch_next_request() + assert request is not None + assert request.enqueue_strategy == 'same-domain' + await loader.mark_request_as_handled(request) + + +async def test_sitemap_loader_strategy_all_disables_filtering(server_url: URL, http_client: HttpClient) -> None: + """Passing `enqueue_strategy='all'` keeps the pre-fix permissive behavior for opt-in callers.""" + cross_host_url = 'http://other.test/payload' + sitemap_content = _make_urlset([cross_host_url]) + sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(sitemap_content.encode())) + + loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='all') + + fetched: list[str] = [] + while not await loader.is_finished(): + request = await loader.fetch_next_request() + if request is not None: + fetched.append(request.url) + await loader.mark_request_as_handled(request) + + assert fetched == [cross_host_url] + + +async def test_sitemap_loader_drops_non_http_scheme_under_strategy_all( + server_url: URL, http_client: HttpClient +) -> None: + """Even with `enqueue_strategy='all'`, sitemap entries with non-http(s) schemes are dropped.""" + http_url = 'http://other.test/page' + sitemap_content = _make_urlset( + [http_url, 'mailto:foo@bar.com', 'javascript:alert(1)', 'ftp://example.com/file.txt'] + ) + sitemap_url = (server_url / 
'sitemap.xml').with_query(base64=encode_base64(sitemap_content.encode())) + + loader = SitemapRequestLoader([str(sitemap_url)], http_client=http_client, enqueue_strategy='all') + + fetched: list[str] = [] + while not await loader.is_finished(): + request = await loader.fetch_next_request() + if request is not None: + fetched.append(request.url) + await loader.mark_request_as_handled(request) + + assert fetched == [http_url]
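A quick way to exercise the new filtering end to end is sketched below. This is not part of the patch: it assumes the APIs land exactly as in the hunks above, uses `HttpxHttpClient` and `https://example.com/sitemap.xml` as placeholders, and imports `filter_url` from the private `crawlee._utils.urls` module purely for illustration.

```python
import asyncio

from crawlee._utils.urls import filter_url
from crawlee.http_clients import HttpxHttpClient
from crawlee.request_loaders import SitemapRequestLoader


def check_filter_url() -> None:
    # A cross-host target is rejected under 'same-hostname'...
    ok, reason = filter_url(
        target='http://other.test/page',
        strategy='same-hostname',
        origin='https://example.com/sitemap.xml',
    )
    print(ok, reason)  # False, "does not match enqueue strategy 'same-hostname'"

    # ...while non-http(s) schemes are rejected even under 'all'.
    ok, reason = filter_url(target='ftp://example.com/file.txt', strategy='all', origin='https://example.com/')
    print(ok, reason)  # False, 'unsupported URL scheme (only http and https are allowed).'


async def drain_sitemap_loader() -> None:
    # Placeholder sitemap URL; point it at a real sitemap to try this out.
    async with SitemapRequestLoader(
        ['https://example.com/sitemap.xml'],
        http_client=HttpxHttpClient(),
        enqueue_strategy='same-domain',  # keep subdomains, drop foreign hosts and non-http(s) entries
    ) as loader:
        while not await loader.is_finished():
            request = await loader.fetch_next_request()
            if request is None:
                continue
            # Each emitted request carries the loader's strategy.
            print(request.url, request.enqueue_strategy)
            await loader.mark_request_as_handled(request)


if __name__ == '__main__':
    check_filter_url()
    asyncio.run(drain_sitemap_loader())
```

The loader half also shows that every emitted `Request` carries the configured `enqueue_strategy`, which is what lets `_check_url_after_redirects` reject cross-host redirects later in the pipeline.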