feat: add new model for HttpHeaders (#544)
### Description

- Unify the HTTP headers type across the project (see the usage sketch below).
- ~~I used just a type alias plus a custom validator in the `Request` model to reach the same result (lowercase & sorted keys).~~
- ~~Of course, it then applied only to `Request`, but I believe that is not a problem.~~
- Split out from #542.
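
A minimal usage sketch of the new `HttpHeaders` model (illustrative, not part of the commit; the behavior shown follows the diff below):

```python
from crawlee._types import HttpHeaders

# Keys are lowercased and sorted on construction.
headers = HttpHeaders({'User-Agent': 'crawlee', 'Accept': 'text/html'})
assert dict(headers) == {'accept': 'text/html', 'user-agent': 'crawlee'}

# Lookups are routed through the lowercased keys.
assert headers['Accept'] == 'text/html'

# The mapping is immutable: item assignment and deletion raise TypeError.
try:
    headers['accept'] = 'application/json'
except TypeError:
    pass

# Merging with | returns a new, re-normalized instance; the right-hand side wins.
merged = headers | HttpHeaders({'accept': 'application/json'})
assert merged['accept'] == 'application/json'
```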

### Issues

- N/A

### Testing

- N/A

### Checklist

- [x] CI passed
vdusek authored Oct 2, 2024
1 parent 4891b73 commit 854f2c1
Showing 7 changed files with 90 additions and 81 deletions.
4 changes: 2 additions & 2 deletions src/crawlee/_request.py
@@ -20,7 +20,7 @@
)
from typing_extensions import Self

from crawlee._types import EnqueueStrategy, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._types import EnqueueStrategy, HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._utils.requests import compute_unique_key, unique_key_to_request_id
from crawlee._utils.urls import extract_query_params, validate_http_url

@@ -119,7 +119,7 @@ class BaseRequestData(BaseModel):
method: HttpMethod = 'GET'
"""HTTP request method."""

headers: Annotated[dict[str, str], Field(default_factory=dict)] = {}
headers: Annotated[HttpHeaders, Field(default_factory=HttpHeaders)] = HttpHeaders()
"""HTTP request headers."""

query_params: Annotated[HttpQueryParams, Field(alias='queryParams', default_factory=dict)] = {}
100 changes: 48 additions & 52 deletions src/crawlee/_types.py
@@ -1,16 +1,17 @@
from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal, Protocol, Union
from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, Union

from pydantic import ConfigDict, Field, PlainValidator, RootModel
from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack

if TYPE_CHECKING:
import logging
import re
from collections.abc import Coroutine, Iterator, Sequence
from collections.abc import Coroutine, Sequence

from crawlee import Glob
from crawlee._request import BaseRequestData, Request
@@ -32,6 +33,50 @@
HttpPayload: TypeAlias = Union[str, bytes]


def _normalize_headers(headers: Mapping[str, str]) -> dict[str, str]:
"""Converts all header keys to lowercase and returns them sorted by key."""
normalized_headers = {k.lower(): v for k, v in headers.items()}
sorted_headers = sorted(normalized_headers.items())
return dict(sorted_headers)


class HttpHeaders(RootModel, Mapping[str, str]):
"""A dictionary-like object representing HTTP headers."""

model_config = ConfigDict(populate_by_name=True)

root: Annotated[
dict[str, str],
PlainValidator(lambda value: _normalize_headers(value)),
Field(default_factory=dict),
] = {} # noqa: RUF012

def __getitem__(self, key: str) -> str:
return self.root[key.lower()]

def __setitem__(self, key: str, value: str) -> None:
raise TypeError(f'{self.__class__.__name__} is immutable')

def __delitem__(self, key: str) -> None:
raise TypeError(f'{self.__class__.__name__} is immutable')

def __or__(self, other: HttpHeaders) -> HttpHeaders:
"""Return a new instance of `HttpHeaders` combining this one with another one."""
combined_headers = {**self.root, **other}
return HttpHeaders(combined_headers)

def __ror__(self, other: HttpHeaders) -> HttpHeaders:
"""Support reversed | operation (other | self)."""
combined_headers = {**other, **self.root}
return HttpHeaders(combined_headers)

def __iter__(self) -> Iterator[str]: # type: ignore
yield from self.root

def __len__(self) -> int:
return len(self.root)


class EnqueueStrategy(str, Enum):
"""Strategy for deciding which links should be followed and which ones should be ignored."""

@@ -222,52 +267,3 @@ async def add_requests(
) -> None:
"""Track a call to the `add_requests` context helper."""
self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))


class HttpHeaders(Mapping[str, str]):
"""An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""

def __init__(self, headers: Mapping[str, str] | None = None) -> None:
"""Create a new instance.
Args:
headers: A mapping of header names to values.
"""
# Ensure immutability by sorting and fixing the order.
headers = headers or {}
headers = {k.capitalize(): v for k, v in headers.items()}
self._headers = dict(sorted(headers.items()))

def __getitem__(self, key: str) -> str:
"""Get the value of a header by its name, case-insensitive."""
return self._headers[key.capitalize()]

def __iter__(self) -> Iterator[str]:
"""Return an iterator over the header names."""
return iter(self._headers)

def __len__(self) -> int:
"""Return the number of headers."""
return len(self._headers)

def __repr__(self) -> str:
"""Return a string representation of the object."""
return f'{self.__class__.__name__}({self._headers})'

def __setitem__(self, key: str, value: str) -> None:
"""Prevent setting a header, as the object is immutable."""
raise TypeError(f'{self.__class__.__name__} is immutable')

def __delitem__(self, key: str) -> None:
"""Prevent deleting a header, as the object is immutable."""
raise TypeError(f'{self.__class__.__name__} is immutable')

def __or__(self, other: Mapping[str, str]) -> HttpHeaders:
"""Return a new instance of `HttpHeaders` combining this one with another one."""
combined_headers = {**self._headers, **other}
return HttpHeaders(combined_headers)

def __ror__(self, other: Mapping[str, str]) -> HttpHeaders:
"""Support reversed | operation (other | self)."""
combined_headers = {**other, **self._headers}
return HttpHeaders(combined_headers)
2 changes: 1 addition & 1 deletion src/crawlee/http_clients/_base.py
@@ -29,7 +29,7 @@ def status_code(self) -> int:
"""The HTTP status code received from the server."""

@property
def headers(self) -> dict[str, str]:
def headers(self) -> HttpHeaders:
"""The HTTP headers received in the response."""

def read(self) -> bytes:
8 changes: 4 additions & 4 deletions src/crawlee/http_clients/_httpx.py
@@ -39,8 +39,8 @@ def status_code(self) -> int:
return self._response.status_code

@property
def headers(self) -> dict[str, str]:
return dict(self._response.headers.items())
def headers(self) -> HttpHeaders:
return HttpHeaders(dict(self._response.headers))

def read(self) -> bytes:
return self._response.read()
@@ -125,7 +125,7 @@ async def crawl(
statistics: Statistics | None = None,
) -> HttpCrawlingResult:
client = self._get_client(proxy_info.url if proxy_info else None)
headers = self._combine_headers(HttpHeaders(request.headers))
headers = self._combine_headers(request.headers)

http_request = client.build_request(
url=request.url,
@@ -177,7 +177,7 @@ async def send_request(
http_request = client.build_request(
url=url,
method=method,
headers=headers,
headers=dict(headers) if headers else None,
params=query_params,
data=data,
extensions={'crawlee_session': session if self._persist_cookies_per_session else None},
9 changes: 5 additions & 4 deletions src/crawlee/http_clients/curl_impersonate.py
@@ -16,6 +16,7 @@
from curl_cffi.const import CurlHttpVersion
from typing_extensions import override

from crawlee._types import HttpHeaders
from crawlee._utils.blocked import ROTATE_PROXY_ERRORS
from crawlee.errors import ProxyError
from crawlee.http_clients import BaseHttpClient, HttpCrawlingResult, HttpResponse
@@ -25,7 +26,7 @@

from curl_cffi.requests import Response

from crawlee._types import HttpHeaders, HttpMethod, HttpQueryParams
from crawlee._types import HttpMethod, HttpQueryParams
from crawlee.base_storage_client._models import Request
from crawlee.proxy_configuration import ProxyInfo
from crawlee.sessions import Session
@@ -62,8 +63,8 @@ def status_code(self) -> int:
return self._response.status_code

@property
def headers(self) -> dict[str, str]:
return dict(self._response.headers.items())
def headers(self) -> HttpHeaders:
return HttpHeaders(dict(self._response.headers))

def read(self) -> bytes:
return self._response.content
@@ -163,7 +164,7 @@ async def send_request(
response = await client.request(
url=url,
method=method.upper(), # type: ignore # curl-cffi requires uppercase method
headers=headers,
headers=dict(headers) if headers else None,
params=query_params,
data=data,
cookies=session.cookies if session else None,
41 changes: 30 additions & 11 deletions tests/unit/basic_crawler/test_basic_crawler.py
@@ -156,11 +156,20 @@ async def handler(context: BasicCrawlingContext) -> None:


async def test_calls_error_handler() -> None:
# Data structure to better track the calls to the error handler.
@dataclass(frozen=True)
class Call:
url: str
error: Exception
custom_retry_count: int

# List to store the information of calls to the error handler.
calls = list[Call]()

crawler = BasicCrawler(
request_provider=RequestList(['http://a.com/', 'http://b.com/', 'http://c.com/']),
max_request_retries=3,
)
calls = list[tuple[BasicCrawlingContext, Exception, int]]()

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
@@ -169,24 +178,34 @@ async def handler(context: BasicCrawlingContext) -> None:

@crawler.error_handler
async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request:
# Retrieve or initialize the headers, and extract the current custom retry count.
headers = context.request.headers or HttpHeaders()
custom_retry_count = int(headers.get('custom_retry_count', '0'))
calls.append((context, error, custom_retry_count))

request = context.request.model_dump()
request['headers']['custom_retry_count'] = str(custom_retry_count + 1)
# Append the current call information.
calls.append(Call(context.request.url, error, custom_retry_count))

# Update the request to include an incremented custom retry count in the headers and return it.
request = context.request.model_dump()
request['headers'] = HttpHeaders({'custom_retry_count': str(custom_retry_count + 1)})
return Request.model_validate(request)

await crawler.run()

assert len(calls) == 2 # error handler should be called for each retryable request
assert calls[0][0].request.url == 'http://b.com/'
assert isinstance(calls[0][1], RuntimeError)
# Verify that the error handler was called twice
assert len(calls) == 2

# Check the first call...
first_call = calls[0]
assert first_call.url == 'http://b.com/'
assert isinstance(first_call.error, RuntimeError)
assert first_call.custom_retry_count == 0

# Check the contents of the `custom_retry_count` header added by the error handler
assert calls[0][2] == 0
assert calls[1][2] == 1
# Check the second call...
second_call = calls[1]
assert second_call.url == 'http://b.com/'
assert isinstance(second_call.error, RuntimeError)
assert second_call.custom_retry_count == 1


async def test_calls_error_handler_for_sesion_errors() -> None:
@@ -289,7 +308,7 @@ async def handler(context: BasicCrawlingContext) -> None:

response = await context.send_request('http://b.com/')
response_body = response.read()
response_headers = HttpHeaders(response.headers)
response_headers = response.headers

await crawler.run()
assert respx_mock['test_endpoint'].called
7 changes: 0 additions & 7 deletions tests/unit/fingerprint_suite/test_header_generator.py
@@ -2,7 +2,6 @@

import pytest

from crawlee._types import HttpHeaders
from crawlee.fingerprint_suite import HeaderGenerator
from crawlee.fingerprint_suite._consts import (
PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA,
@@ -21,7 +20,6 @@ def test_get_common_headers() -> None:

assert 'Accept' in headers
assert 'Accept-Language' in headers
assert isinstance(headers, HttpHeaders)


def test_get_random_user_agent_header() -> None:
@@ -30,7 +28,6 @@ def test_get_random_user_agent_header() -> None:
headers = header_generator.get_random_user_agent_header()

assert 'User-Agent' in headers
assert isinstance(headers, HttpHeaders)
assert headers['User-Agent'] in USER_AGENT_POOL


@@ -41,7 +38,6 @@ def test_get_user_agent_header_chromium() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_CHROMIUM_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_firefox() -> None:
@@ -51,7 +47,6 @@ def test_get_user_agent_header_firefox() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_FIREFOX_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_webkit() -> None:
@@ -61,7 +56,6 @@ def test_get_user_agent_header_webkit() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_WEBKIT_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_invalid_browser_type() -> None:
@@ -83,7 +77,6 @@ def test_get_sec_ch_ua_headers_chromium() -> None:
assert headers['Sec-Ch-Ua-Mobile'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_MOBILE
assert 'Sec-Ch-Ua-Platform' in headers
assert headers['Sec-Ch-Ua-Platform'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_PLATFORM
assert isinstance(headers, HttpHeaders)


def test_get_sec_ch_ua_headers_firefox() -> None:
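
For implementors of custom HTTP clients, the protocol change in `src/crawlee/http_clients/_base.py` means a response wrapper's `headers` property now returns `HttpHeaders` instead of `dict[str, str]`. A minimal sketch of a conforming wrapper (class and attribute names are illustrative, not from this commit):

```python
from crawlee._types import HttpHeaders


class MyHttpResponse:
    """Illustrative response wrapper satisfying the updated HttpResponse protocol."""

    def __init__(self, status_code: int, raw_headers: dict[str, str], body: bytes) -> None:
        self._status_code = status_code
        self._raw_headers = raw_headers
        self._body = body

    @property
    def status_code(self) -> int:
        return self._status_code

    @property
    def headers(self) -> HttpHeaders:
        # Wrapping the raw dict normalizes the headers (lowercased, sorted keys).
        return HttpHeaders(self._raw_headers)

    def read(self) -> bytes:
        return self._body
```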
