diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py
index c0f29dec..182adb4e 100644
--- a/src/crawlee/_request.py
+++ b/src/crawlee/_request.py
@@ -20,7 +20,7 @@
 )
 from typing_extensions import Self

-from crawlee._types import EnqueueStrategy, HttpMethod, HttpPayload, HttpQueryParams
+from crawlee._types import EnqueueStrategy, HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
 from crawlee._utils.requests import compute_unique_key, unique_key_to_request_id
 from crawlee._utils.urls import extract_query_params, validate_http_url

@@ -119,7 +119,7 @@ class BaseRequestData(BaseModel):
     method: HttpMethod = 'GET'
     """HTTP request method."""

-    headers: Annotated[dict[str, str], Field(default_factory=dict)] = {}
+    headers: Annotated[HttpHeaders, Field(default_factory=HttpHeaders)] = HttpHeaders()
     """HTTP request headers."""

     query_params: Annotated[HttpQueryParams, Field(alias='queryParams', default_factory=dict)] = {}

diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py
index a1b9db33..4f8bf185 100644
--- a/src/crawlee/_types.py
+++ b/src/crawlee/_types.py
@@ -1,16 +1,17 @@
 from __future__ import annotations

-from collections.abc import Mapping
+from collections.abc import Iterator, Mapping
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Literal, Protocol, Union
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, Union

+from pydantic import ConfigDict, Field, PlainValidator, RootModel
 from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack

 if TYPE_CHECKING:
     import logging
     import re
-    from collections.abc import Coroutine, Iterator, Sequence
+    from collections.abc import Coroutine, Sequence

     from crawlee import Glob
     from crawlee._request import BaseRequestData, Request
@@ -32,6 +33,50 @@
 HttpPayload: TypeAlias = Union[str, bytes]


+def _normalize_headers(headers: Mapping[str, str]) -> dict[str, str]:
+    """Converts all header keys to lowercase and returns them sorted by key."""
+    normalized_headers = {k.lower(): v for k, v in headers.items()}
+    sorted_headers = sorted(normalized_headers.items())
+    return dict(sorted_headers)
+
+
+class HttpHeaders(RootModel, Mapping[str, str]):
+    """An immutable, case-insensitive dictionary-like object representing HTTP headers."""
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    root: Annotated[
+        dict[str, str],
+        PlainValidator(_normalize_headers),
+        Field(default_factory=dict),
+    ] = {}  # noqa: RUF012
+
+    def __getitem__(self, key: str) -> str:
+        return self.root[key.lower()]
+
+    def __setitem__(self, key: str, value: str) -> None:
+        raise TypeError(f'{self.__class__.__name__} is immutable')
+
+    def __delitem__(self, key: str) -> None:
+        raise TypeError(f'{self.__class__.__name__} is immutable')
+
+    def __or__(self, other: HttpHeaders) -> HttpHeaders:
+        """Return a new instance of `HttpHeaders` combining this one with another one."""
+        combined_headers = {**self.root, **other}
+        return HttpHeaders(combined_headers)
+
+    def __ror__(self, other: HttpHeaders) -> HttpHeaders:
+        """Support reversed | operation (other | self)."""
+        combined_headers = {**other, **self.root}
+        return HttpHeaders(combined_headers)
+
+    def __iter__(self) -> Iterator[str]:  # type: ignore
+        yield from self.root
+
+    def __len__(self) -> int:
+        return len(self.root)
+
+
 class EnqueueStrategy(str, Enum):
     """Strategy for deciding which links should be followed and which ones should be ignored."""

@@ -222,52 +267,3 @@ async def add_requests(
     ) -> None:
         """Track a call to the `add_requests` context helper."""
         self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))
-
-
-class HttpHeaders(Mapping[str, str]):
-    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""
-
-    def __init__(self, headers: Mapping[str, str] | None = None) -> None:
-        """Create a new instance.
-
-        Args:
-            headers: A mapping of header names to values.
-        """
-        # Ensure immutability by sorting and fixing the order.
-        headers = headers or {}
-        headers = {k.capitalize(): v for k, v in headers.items()}
-        self._headers = dict(sorted(headers.items()))
-
-    def __getitem__(self, key: str) -> str:
-        """Get the value of a header by its name, case-insensitive."""
-        return self._headers[key.capitalize()]
-
-    def __iter__(self) -> Iterator[str]:
-        """Return an iterator over the header names."""
-        return iter(self._headers)
-
-    def __len__(self) -> int:
-        """Return the number of headers."""
-        return len(self._headers)
-
-    def __repr__(self) -> str:
-        """Return a string representation of the object."""
-        return f'{self.__class__.__name__}({self._headers})'
-
-    def __setitem__(self, key: str, value: str) -> None:
-        """Prevent setting a header, as the object is immutable."""
-        raise TypeError(f'{self.__class__.__name__} is immutable')
-
-    def __delitem__(self, key: str) -> None:
-        """Prevent deleting a header, as the object is immutable."""
-        raise TypeError(f'{self.__class__.__name__} is immutable')
-
-    def __or__(self, other: Mapping[str, str]) -> HttpHeaders:
-        """Return a new instance of `HttpHeaders` combining this one with another one."""
-        combined_headers = {**self._headers, **other}
-        return HttpHeaders(combined_headers)
-
-    def __ror__(self, other: Mapping[str, str]) -> HttpHeaders:
-        """Support reversed | operation (other | self)."""
-        combined_headers = {**other, **self._headers}
-        return HttpHeaders(combined_headers)
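Note for reviewers: a minimal usage sketch of the `HttpHeaders` model added above (not part of the patch; it assumes only the behavior defined by the class itself):

from crawlee._types import HttpHeaders

headers = HttpHeaders({'Content-Type': 'text/html', 'Accept': '*/*'})

# Keys are normalized to lowercase and sorted, and lookups are case-insensitive.
assert headers['Content-Type'] == headers['content-type'] == 'text/html'

# Merging with `|` returns a new instance; the right-hand operand wins on conflicts.
merged = headers | HttpHeaders({'accept': 'application/json'})
assert merged['accept'] == 'application/json'

# In-place mutation is rejected.
try:
    headers['accept'] = 'text/plain'
except TypeError:
    pass  # 'HttpHeaders is immutable'

Returning a new instance from `|` keeps the model immutable while still supporting the header-merging pattern the HTTP clients below rely on.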
diff --git a/src/crawlee/http_clients/_base.py b/src/crawlee/http_clients/_base.py
index 6d64a3cb..7edafc10 100644
--- a/src/crawlee/http_clients/_base.py
+++ b/src/crawlee/http_clients/_base.py
@@ -29,7 +29,7 @@ def status_code(self) -> int:
         """The HTTP status code received from the server."""

     @property
-    def headers(self) -> dict[str, str]:
+    def headers(self) -> HttpHeaders:
         """The HTTP headers received in the response."""

     def read(self) -> bytes:

diff --git a/src/crawlee/http_clients/_httpx.py b/src/crawlee/http_clients/_httpx.py
index e3b263cb..2416bb04 100644
--- a/src/crawlee/http_clients/_httpx.py
+++ b/src/crawlee/http_clients/_httpx.py
@@ -39,8 +39,8 @@ def status_code(self) -> int:
         return self._response.status_code

     @property
-    def headers(self) -> dict[str, str]:
-        return dict(self._response.headers.items())
+    def headers(self) -> HttpHeaders:
+        return HttpHeaders(dict(self._response.headers))

     def read(self) -> bytes:
         return self._response.read()
@@ -125,7 +125,7 @@ async def crawl(
         statistics: Statistics | None = None,
     ) -> HttpCrawlingResult:
         client = self._get_client(proxy_info.url if proxy_info else None)
-        headers = self._combine_headers(HttpHeaders(request.headers))
+        headers = self._combine_headers(request.headers)

         http_request = client.build_request(
             url=request.url,
@@ -177,7 +177,7 @@ async def send_request(
         http_request = client.build_request(
             url=url,
             method=method,
-            headers=headers,
+            headers=dict(headers) if headers else None,
             params=query_params,
             data=data,
             extensions={'crawlee_session': session if self._persist_cookies_per_session else None},
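For context, a rough standalone sketch of the header hand-off in `send_request` above; the direct `httpx.Request` construction here is only illustrative, not the crawlee client API:

import httpx

from crawlee._types import HttpHeaders

headers = HttpHeaders({'User-Agent': 'crawlee-test'})

# The immutable model is unwrapped into a plain dict before it reaches httpx;
# `None` keeps the library defaults when no headers were supplied.
request = httpx.Request('GET', 'https://example.com', headers=dict(headers) if headers else None)

# httpx's own Headers mapping is case-insensitive as well, so the lowercased
# keys produced by HttpHeaders round-trip cleanly.
assert request.headers['user-agent'] == 'crawlee-test'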
diff --git a/src/crawlee/http_clients/curl_impersonate.py b/src/crawlee/http_clients/curl_impersonate.py
index 92e016d5..00543164 100644
--- a/src/crawlee/http_clients/curl_impersonate.py
+++ b/src/crawlee/http_clients/curl_impersonate.py
@@ -16,6 +16,7 @@
 from curl_cffi.const import CurlHttpVersion
 from typing_extensions import override

+from crawlee._types import HttpHeaders
 from crawlee._utils.blocked import ROTATE_PROXY_ERRORS
 from crawlee.errors import ProxyError
 from crawlee.http_clients import BaseHttpClient, HttpCrawlingResult, HttpResponse
@@ -25,7 +26,7 @@

     from curl_cffi.requests import Response

-    from crawlee._types import HttpHeaders, HttpMethod, HttpQueryParams
+    from crawlee._types import HttpMethod, HttpQueryParams
     from crawlee.base_storage_client._models import Request
     from crawlee.proxy_configuration import ProxyInfo
     from crawlee.sessions import Session
@@ -62,8 +63,8 @@ def status_code(self) -> int:
         return self._response.status_code

     @property
-    def headers(self) -> dict[str, str]:
-        return dict(self._response.headers.items())
+    def headers(self) -> HttpHeaders:
+        return HttpHeaders(dict(self._response.headers))

     def read(self) -> bytes:
         return self._response.content
@@ -163,7 +164,7 @@ async def send_request(
         response = await client.request(
             url=url,
             method=method.upper(),  # type: ignore # curl-cffi requires uppercase method
-            headers=headers,
+            headers=dict(headers) if headers else None,
             params=query_params,
             data=data,
             cookies=session.cookies if session else None,

diff --git a/tests/unit/basic_crawler/test_basic_crawler.py b/tests/unit/basic_crawler/test_basic_crawler.py
index 32d9cf09..e2cb32cf 100644
--- a/tests/unit/basic_crawler/test_basic_crawler.py
+++ b/tests/unit/basic_crawler/test_basic_crawler.py
@@ -156,11 +156,20 @@ async def handler(context: BasicCrawlingContext) -> None:


 async def test_calls_error_handler() -> None:
+    # Data structure to track the calls to the error handler.
+    @dataclass(frozen=True)
+    class Call:
+        url: str
+        error: Exception
+        custom_retry_count: int
+
+    # List to store information about the calls to the error handler.
+    calls = list[Call]()
+
     crawler = BasicCrawler(
         request_provider=RequestList(['http://a.com/', 'http://b.com/', 'http://c.com/']),
         max_request_retries=3,
     )
-    calls = list[tuple[BasicCrawlingContext, Exception, int]]()

     @crawler.router.default_handler
     async def handler(context: BasicCrawlingContext) -> None:
@@ -169,24 +178,34 @@ async def handler(context: BasicCrawlingContext) -> None:

     @crawler.error_handler
     async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request:
+        # Retrieve or initialize the headers, and extract the current custom retry count.
         headers = context.request.headers or HttpHeaders()
         custom_retry_count = int(headers.get('custom_retry_count', '0'))
-        calls.append((context, error, custom_retry_count))

-        request = context.request.model_dump()
-        request['headers']['custom_retry_count'] = str(custom_retry_count + 1)
+        # Append the current call information.
+        calls.append(Call(context.request.url, error, custom_retry_count))
+
+        # Update the request to include an incremented custom retry count in the headers and return it.
+        request = context.request.model_dump()
+        request['headers'] = HttpHeaders({'custom_retry_count': str(custom_retry_count + 1)})
         return Request.model_validate(request)

     await crawler.run()

-    assert len(calls) == 2  # error handler should be called for each retryable request
-    assert calls[0][0].request.url == 'http://b.com/'
-    assert isinstance(calls[0][1], RuntimeError)
+    # Verify that the error handler was called twice.
+    assert len(calls) == 2
+
+    # Check the first call.
+    first_call = calls[0]
+    assert first_call.url == 'http://b.com/'
+    assert isinstance(first_call.error, RuntimeError)
+    assert first_call.custom_retry_count == 0

-    # Check the contents of the `custom_retry_count` header added by the error handler
-    assert calls[0][2] == 0
-    assert calls[1][2] == 1
+    # Check the second call.
+    second_call = calls[1]
+    assert second_call.url == 'http://b.com/'
+    assert isinstance(second_call.error, RuntimeError)
+    assert second_call.custom_retry_count == 1


 async def test_calls_error_handler_for_sesion_errors() -> None:
@@ -289,7 +308,7 @@ async def handler(context: BasicCrawlingContext) -> None:
         response = await context.send_request('http://b.com/')
         response_body = response.read()
-        response_headers = HttpHeaders(response.headers)
+        response_headers = response.headers

     await crawler.run()
     assert respx_mock['test_endpoint'].called

diff --git a/tests/unit/fingerprint_suite/test_header_generator.py b/tests/unit/fingerprint_suite/test_header_generator.py
index 2b09e238..5cc476bf 100644
--- a/tests/unit/fingerprint_suite/test_header_generator.py
+++ b/tests/unit/fingerprint_suite/test_header_generator.py
@@ -2,7 +2,6 @@

 import pytest

-from crawlee._types import HttpHeaders
 from crawlee.fingerprint_suite import HeaderGenerator
 from crawlee.fingerprint_suite._consts import (
     PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA,
@@ -21,7 +20,6 @@ def test_get_common_headers() -> None:

     assert 'Accept' in headers
     assert 'Accept-Language' in headers
-    assert isinstance(headers, HttpHeaders)


 def test_get_random_user_agent_header() -> None:
@@ -30,7 +28,6 @@ def test_get_random_user_agent_header() -> None:
     headers = header_generator.get_random_user_agent_header()

     assert 'User-Agent' in headers
-    assert isinstance(headers, HttpHeaders)
     assert headers['User-Agent'] in USER_AGENT_POOL


@@ -41,7 +38,6 @@ def test_get_user_agent_header_chromium() -> None:

     assert 'User-Agent' in headers
     assert headers['User-Agent'] == PW_CHROMIUM_HEADLESS_DEFAULT_USER_AGENT
-    assert isinstance(headers, HttpHeaders)


 def test_get_user_agent_header_firefox() -> None:
@@ -51,7 +47,6 @@ def test_get_user_agent_header_firefox() -> None:

     assert 'User-Agent' in headers
     assert headers['User-Agent'] == PW_FIREFOX_HEADLESS_DEFAULT_USER_AGENT
-    assert isinstance(headers, HttpHeaders)


 def test_get_user_agent_header_webkit() -> None:
@@ -61,7 +56,6 @@ def test_get_user_agent_header_webkit() -> None:

     assert 'User-Agent' in headers
     assert headers['User-Agent'] == PW_WEBKIT_HEADLESS_DEFAULT_USER_AGENT
-    assert isinstance(headers, HttpHeaders)


 def test_get_user_agent_header_invalid_browser_type() -> None:
@@ -83,7 +77,6 @@ def test_get_sec_ch_ua_headers_chromium() -> None:
     assert headers['Sec-Ch-Ua-Mobile'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_MOBILE
     assert 'Sec-Ch-Ua-Platform' in headers
     assert headers['Sec-Ch-Ua-Platform'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_PLATFORM
-    assert isinstance(headers, HttpHeaders)


 def test_get_sec_ch_ua_headers_firefox() -> None:
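Finally, a self-contained sketch of the pydantic integration that the reworked `test_calls_error_handler` exercises; `MiniRequest` is a hypothetical stand-in for `BaseRequestData`, not part of the patch:

from pydantic import BaseModel

from crawlee._types import HttpHeaders

class MiniRequest(BaseModel):
    url: str
    headers: HttpHeaders = HttpHeaders()

# Plain dicts validate into the model; keys are normalized to lowercase.
request = MiniRequest(url='https://example.com', headers={'X-Token': 'abc'})
assert request.headers['x-token'] == 'abc'

# The RootModel dumps back to a plain dict, so model_dump()/model_validate()
# round-trips, which is what the error handler in the test relies on.
data = request.model_dump()
assert data == {'url': 'https://example.com', 'headers': {'x-token': 'abc'}}
assert MiniRequest.model_validate(data).headers == request.headers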