fix: Improve Request.user_data serialization (#540)
- resolves #524

This adds validation to `Request.user_data` so that the user cannot pass
in data that is not JSON-serializable.
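
In practice, a `user_data` value that cannot be represented as JSON is now rejected eagerly,
when the request is created, instead of failing later during storage. A minimal sketch of the
new behavior, distilled from the tests added in this commit (key names are illustrative):

from datetime import datetime, timezone

import pytest
from pydantic import ValidationError

from crawlee import Request

# JSON-compatible values are accepted as before.
Request.from_url('https://crawlee.dev', user_data={'hello': 'world', 'foo': 42})

# A datetime is not JSON-serializable, so validation now fails up front.
with pytest.raises(ValidationError):
    Request.from_url(
        'https://crawlee.dev',
        user_data={'when': datetime(2020, 7, 4, tzinfo=timezone.utc)},
    )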
janbuchar authored Sep 23, 2024
1 parent e8fc644 commit de29c0e
Showing 3 changed files with 145 additions and 42 deletions.
138 changes: 97 additions & 41 deletions src/crawlee/_request.py
@@ -2,12 +2,22 @@

 from __future__ import annotations

+from collections.abc import Iterator, MutableMapping
 from datetime import datetime
 from decimal import Decimal
 from enum import Enum
-from typing import Annotated, Any
+from typing import Annotated, Any, cast

-from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
+from pydantic import (
+    BaseModel,
+    BeforeValidator,
+    ConfigDict,
+    Field,
+    JsonValue,
+    PlainSerializer,
+    PlainValidator,
+    TypeAdapter,
+)
 from typing_extensions import Self

 from crawlee._types import EnqueueStrategy, HttpMethod
@@ -28,6 +38,64 @@ class RequestState(Enum):
     SKIPPED = 7


+class CrawleeRequestData(BaseModel):
+    """Crawlee-specific configuration stored in the `user_data`."""
+
+    max_retries: Annotated[int | None, Field(alias='maxRetries')] = None
+    """Maximum number of retries for this request. Allows to override the global `max_request_retries` option of
+    `BasicCrawler`."""
+
+    enqueue_strategy: Annotated[str | None, Field(alias='enqueueStrategy')] = None
+
+    state: RequestState | None = None
+    """Describes the request's current lifecycle state."""
+
+    session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None
+
+    skip_navigation: Annotated[bool, Field(alias='skipNavigation')] = False
+
+    last_proxy_tier: Annotated[int | None, Field(alias='lastProxyTier')] = None
+
+    forefront: Annotated[bool, Field()] = False
+
+
+class UserData(BaseModel, MutableMapping[str, JsonValue]):
+    """Represents the `user_data` part of a Request.
+
+    Apart from the well-known attributes (`label` and `__crawlee`), it can also contain arbitrary JSON-compatible
+    values.
+    """
+
+    model_config = ConfigDict(extra='allow')
+    __pydantic_extra__: dict[str, JsonValue] = Field(init=False)  # pyright: ignore
+
+    crawlee_data: Annotated[CrawleeRequestData | None, Field(alias='__crawlee')] = None
+    label: Annotated[str | None, Field()] = None
+
+    def __getitem__(self, key: str) -> JsonValue:
+        return self.__pydantic_extra__[key]
+
+    def __setitem__(self, key: str, value: JsonValue) -> None:
+        if key == 'label':
+            if value is not None and not isinstance(value, str):
+                raise ValueError('`label` must be str or None')
+
+            self.label = value
+
+        self.__pydantic_extra__[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.__pydantic_extra__[key]
+
+    def __iter__(self) -> Iterator[str]:  # type: ignore
+        yield from self.__pydantic_extra__
+
+    def __len__(self) -> int:
+        return len(self.__pydantic_extra__)
+
+
+user_data_adapter = TypeAdapter(UserData)
+
+
 class BaseRequestData(BaseModel):
     """Data needed to create a new crawling request."""
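
The `UserData` model above combines a pydantic `BaseModel` with the `MutableMapping` protocol:
`label` and `__crawlee` stay typed, validated fields, while every other key is routed into
pydantic's `__pydantic_extra__` store, whose values are constrained to `JsonValue`. A
self-contained sketch of the same pattern outside crawlee (the class and field names here are
illustrative, not part of the commit):

from collections.abc import Iterator, MutableMapping

from pydantic import BaseModel, ConfigDict, Field, JsonValue


class DictLikeModel(BaseModel, MutableMapping[str, JsonValue]):
    """One typed field plus arbitrary JSON-compatible extras, usable like a dict."""

    model_config = ConfigDict(extra='allow')
    __pydantic_extra__: dict[str, JsonValue] = Field(init=False)  # same trick as UserData

    label: str | None = None

    def __getitem__(self, key: str) -> JsonValue:
        return self.__pydantic_extra__[key]

    def __setitem__(self, key: str, value: JsonValue) -> None:
        self.__pydantic_extra__[key] = value

    def __delitem__(self, key: str) -> None:
        del self.__pydantic_extra__[key]

    def __iter__(self) -> Iterator[str]:  # type: ignore  # BaseModel.__iter__ yields pairs
        yield from self.__pydantic_extra__

    def __len__(self) -> int:
        return len(self.__pydantic_extra__)


data = DictLikeModel(label='DETAIL')
data['foo'] = 42  # lands in __pydantic_extra__
assert data.label == 'DETAIL'
assert dict(data) == {'foo': 42}  # the mapping view exposes only the extras

Because the extras are typed as `JsonValue`, pydantic rejects anything that cannot be
represented as JSON whenever the model is validated, which is exactly the guarantee this
commit needs for `user_data`.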
@@ -58,7 +126,20 @@ class BaseRequestData(BaseModel):

     data: Annotated[dict[str, Any] | None, Field(default_factory=dict)] = None

-    user_data: Annotated[dict[str, Any], Field(alias='userData', default_factory=dict)]
+    user_data: Annotated[
+        dict[str, JsonValue],  # Internally, the model contains `UserData`, this is just for convenience
+        Field(alias='userData', default_factory=lambda: UserData()),
+        PlainValidator(user_data_adapter.validate_python),
+        PlainSerializer(
+            lambda instance: user_data_adapter.dump_python(
+                instance,
+                by_alias=True,
+                exclude_none=True,
+                exclude_unset=True,
+                exclude_defaults=True,
+            )
+        ),
+    ]
     """Custom user data assigned to the request. Use this to save any request related data to the
     request's scope, keeping them accessible on retries, failures etc.
     """
@@ -216,14 +297,16 @@ def from_base_request_data(cls, base_request_data: BaseRequestData, *, id: str |
     @property
     def label(self) -> str | None:
         """A string used to differentiate between arbitrary request types."""
-        if 'label' in self.user_data:
-            return str(self.user_data['label'])
-        return None
+        return cast(UserData, self.user_data).label

     @property
     def crawlee_data(self) -> CrawleeRequestData:
         """Crawlee-specific configuration stored in the user_data."""
-        return CrawleeRequestData.model_validate(self.user_data.get('__crawlee', {}))
+        user_data = cast(UserData, self.user_data)
+        if user_data.crawlee_data is None:
+            user_data.crawlee_data = CrawleeRequestData()
+
+        return user_data.crawlee_data

     @property
     def state(self) -> RequestState | None:
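
Zooming out: the `PlainValidator`/`PlainSerializer` pair attached to `user_data` above
replaces pydantic's default handling of the annotated `dict` type. On the way in, any mapping
is coerced into `UserData` (so non-JSON values fail validation); on the way out, only
explicitly set, non-default fields survive, so an untouched `user_data` serializes to an empty
object. A round-trip sketch using the `user_data_adapter` defined earlier (the exact output
follows pydantic's set/default tracking, so treat this as an illustration):

# Validation: a plain dict is upgraded to `UserData`; non-JSON values would raise.
user_data = user_data_adapter.validate_python({'label': 'DETAIL', 'foo': 42})
assert user_data.label == 'DETAIL'
assert user_data['foo'] == 42

# Serialization: unset defaults such as `__crawlee` are dropped from the output.
dumped = user_data_adapter.dump_python(
    user_data,
    by_alias=True,
    exclude_none=True,
    exclude_unset=True,
    exclude_defaults=True,
)
assert dumped == {'label': 'DETAIL', 'foo': 42}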
@@ -232,8 +315,7 @@ def state(self) -> RequestState | None:

     @state.setter
     def state(self, new_state: RequestState) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['state'] = new_state
+        self.crawlee_data.state = new_state

     @property
     def max_retries(self) -> int | None:
@@ -242,8 +324,7 @@ def max_retries(self) -> int | None:

     @max_retries.setter
     def max_retries(self, new_max_retries: int) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['maxRetries'] = new_max_retries
+        self.crawlee_data.max_retries = new_max_retries

     @property
     def session_rotation_count(self) -> int | None:
@@ -252,8 +333,7 @@ def session_rotation_count(self) -> int | None:

     @session_rotation_count.setter
     def session_rotation_count(self, new_session_rotation_count: int) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['sessionRotationCount'] = new_session_rotation_count
+        self.crawlee_data.session_rotation_count = new_session_rotation_count

     @property
     def enqueue_strategy(self) -> EnqueueStrategy:
@@ -266,8 +346,7 @@ def enqueue_strategy(self) -> EnqueueStrategy:

     @enqueue_strategy.setter
     def enqueue_strategy(self, new_enqueue_strategy: EnqueueStrategy) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['enqueueStrategy'] = str(new_enqueue_strategy)
+        self.crawlee_data.enqueue_strategy = new_enqueue_strategy

     @property
     def last_proxy_tier(self) -> int | None:
@@ -276,8 +355,7 @@ def last_proxy_tier(self) -> int | None:

     @last_proxy_tier.setter
     def last_proxy_tier(self, new_value: int) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['lastProxyTier'] = new_value
+        self.crawlee_data.last_proxy_tier = new_value

     @property
     def forefront(self) -> bool:
@@ -286,32 +364,10 @@ def forefront(self) -> bool:

     @forefront.setter
     def forefront(self, new_value: bool) -> None:
-        self.user_data.setdefault('__crawlee', {})
-        self.user_data['__crawlee']['forefront'] = new_value
+        self.crawlee_data.forefront = new_value


 class RequestWithLock(Request):
     """A crawling request with information about locks."""

     lock_expires_at: Annotated[datetime, Field(alias='lockExpiresAt')]
-
-
-class CrawleeRequestData(BaseModel):
-    """Crawlee-specific configuration stored in the user_data."""
-
-    max_retries: Annotated[int | None, Field(alias='maxRetries')] = None
-    """Maximum number of retries for this request. Allows to override the global `max_request_retries` option of
-    `BasicCrawler`."""
-
-    enqueue_strategy: Annotated[str | None, Field(alias='enqueueStrategy')] = None
-
-    state: RequestState | None = None
-    """Describes the request's current lifecycle state."""
-
-    session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None
-
-    skip_navigation: Annotated[bool, Field(alias='skipNavigation')] = False
-
-    last_proxy_tier: Annotated[int | None, Field(alias='lastProxyTier')] = None
-
-    forefront: Annotated[bool, Field()] = False
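
With `crawlee_data` now a lazily created, typed sub-model, the setters above assign attributes
instead of poking raw `__crawlee` dictionaries, and the values survive a serialization round
trip. A short sketch mirroring `test_complex_user_data_serialization` from the test file below:

from crawlee import Request

request = Request.from_url('https://crawlee.dev')
request.user_data['hello'] = 'world'
request.crawlee_data.max_retries = 1  # creates the `__crawlee` sub-model on first access

data = request.model_dump(by_alias=True)
assert data['userData']['hello'] == 'world'
assert data['userData']['__crawlee'] == {'maxRetries': 1}  # only explicitly set fields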
2 changes: 1 addition & 1 deletion tests/unit/basic_crawler/test_basic_crawler.py
@@ -390,7 +390,7 @@ async def test_enqueue_strategy(test_input: AddRequestsTestInput) -> None:
     crawler = BasicCrawler(request_provider=RequestList([Request.from_url('https://someplace.com/', label='start')]))

     @crawler.router.handler('start')
-    async def default_handler(context: BasicCrawlingContext) -> None:
+    async def start_handler(context: BasicCrawlingContext) -> None:
         await context.add_requests(
             test_input.requests,
             **test_input.kwargs,
47 changes: 47 additions & 0 deletions tests/unit/storages/test_request_queue.py
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING

 import pytest
+from pydantic import ValidationError

 from crawlee import Request
 from crawlee.storages import RequestQueue
@@ -162,3 +163,49 @@ async def test_add_batched_requests(

     # Confirm the queue is empty after processing all requests
     assert await request_queue.is_empty() is True
+
+
+async def test_invalid_user_data_serialization() -> None:
+    with pytest.raises(ValidationError):
+        Request.from_url(
+            'https://crawlee.dev',
+            user_data={
+                'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc),
+                'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)},
+            },
+        )
+
+
+async def test_user_data_serialization(request_queue: RequestQueue) -> None:
+    request = Request.from_url(
+        'https://crawlee.dev',
+        user_data={
+            'hello': 'world',
+            'foo': 42,
+        },
+    )
+
+    await request_queue.add_request(request)
+
+    dequeued_request = await request_queue.fetch_next_request()
+    assert dequeued_request is not None
+
+    assert dequeued_request.user_data['hello'] == 'world'
+    assert dequeued_request.user_data['foo'] == 42
+
+
+async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None:
+    request = Request.from_url('https://crawlee.dev')
+    request.user_data['hello'] = 'world'
+    request.user_data['foo'] = 42
+    request.crawlee_data.max_retries = 1
+
+    await request_queue.add_request(request)
+
+    dequeued_request = await request_queue.fetch_next_request()
+    assert dequeued_request is not None
+
+    data = dequeued_request.model_dump(by_alias=True)
+    assert data['userData']['hello'] == 'world'
+    assert data['userData']['foo'] == 42
+    assert data['userData']['__crawlee'] == {'maxRetries': 1}