from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING, Any, cast
from scrapy.http import Request, Response, XmlResponse
from scrapy.spiders import Spider
from scrapy.utils._compression import _DecompressionMaxSizeExceeded
from scrapy.utils.gz import gunzip, gzip_magic_number
from scrapy.utils.sitemap import Sitemap, sitemap_urls_from_robots
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
# typing.Self requires Python 3.11
from typing_extensions import Self
from scrapy.crawler import Crawler
from scrapy.http.request import CallbackT
logger = logging.getLogger(__name__)
[docs]class SitemapSpider(Spider):
sitemap_urls: Sequence[str] = ()
sitemap_rules: Sequence[tuple[re.Pattern[str] | str, str | CallbackT]] = [
("", "parse")
]
sitemap_follow: Sequence[re.Pattern[str] | str] = [""]
sitemap_alternate_links: bool = False
_max_size: int
_warn_size: int
@classmethod
def from_crawler(cls, crawler: Crawler, *args: Any, **kwargs: Any) -> Self:
spider = super().from_crawler(crawler, *args, **kwargs)
spider._max_size = getattr(
spider, "download_maxsize", spider.settings.getint("DOWNLOAD_MAXSIZE")
)
spider._warn_size = getattr(
spider, "download_warnsize", spider.settings.getint("DOWNLOAD_WARNSIZE")
)
return spider
def __init__(self, *a: Any, **kw: Any):
super().__init__(*a, **kw)
self._cbs: list[tuple[re.Pattern[str], CallbackT]] = []
for r, c in self.sitemap_rules:
if isinstance(c, str):
c = cast("CallbackT", getattr(self, c))
self._cbs.append((regex(r), c))
self._follow: list[re.Pattern[str]] = [regex(x) for x in self.sitemap_follow]
def start_requests(self) -> Iterable[Request]:
for url in self.sitemap_urls:
yield Request(url, self._parse_sitemap)
[docs] def sitemap_filter(
self, entries: Iterable[dict[str, Any]]
) -> Iterable[dict[str, Any]]:
"""This method can be used to filter sitemap entries by their
attributes, for example, you can filter locs with lastmod greater
than a given date (see docs).
"""
yield from entries
def _parse_sitemap(self, response: Response) -> Iterable[Request]:
if response.url.endswith("/robots.txt"):
for url in sitemap_urls_from_robots(response.text, base_url=response.url):
yield Request(url, callback=self._parse_sitemap)
else:
body = self._get_sitemap_body(response)
if body is None:
logger.warning(
"Ignoring invalid sitemap: %(response)s",
{"response": response},
extra={"spider": self},
)
return
s = Sitemap(body)
it = self.sitemap_filter(s)
if s.type == "sitemapindex":
for loc in iterloc(it, self.sitemap_alternate_links):
if any(x.search(loc) for x in self._follow):
yield Request(loc, callback=self._parse_sitemap)
elif s.type == "urlset":
for loc in iterloc(it, self.sitemap_alternate_links):
for r, c in self._cbs:
if r.search(loc):
yield Request(loc, callback=c)
break
def _get_sitemap_body(self, response: Response) -> bytes | None:
"""Return the sitemap body contained in the given response,
or None if the response is not a sitemap.
"""
if isinstance(response, XmlResponse):
return response.body
if gzip_magic_number(response):
uncompressed_size = len(response.body)
max_size = response.meta.get("download_maxsize", self._max_size)
warn_size = response.meta.get("download_warnsize", self._warn_size)
try:
body = gunzip(response.body, max_size=max_size)
except _DecompressionMaxSizeExceeded:
return None
if uncompressed_size < warn_size <= len(body):
logger.warning(
f"{response} body size after decompression ({len(body)} B) "
f"is larger than the download warning size ({warn_size} B)."
)
return body
# actual gzipped sitemap files are decompressed above ;
# if we are here (response body is not gzipped)
# and have a response for .xml.gz,
# it usually means that it was already gunzipped
# by HttpCompression middleware,
# the HTTP response being sent with "Content-Encoding: gzip"
# without actually being a .xml.gz file in the first place,
# merely XML gzip-compressed on the fly,
# in other word, here, we have plain XML
if response.url.endswith(".xml") or response.url.endswith(".xml.gz"):
return response.body
return None
def regex(x: re.Pattern[str] | str) -> re.Pattern[str]:
if isinstance(x, str):
return re.compile(x)
return x
def iterloc(it: Iterable[dict[str, Any]], alt: bool = False) -> Iterable[str]:
for d in it:
yield d["loc"]
# Also consider alternate URLs (xhtml:link rel="alternate")
if alt and "alternate" in d:
yield from d["alternate"]