Pass crawler configuration to storages
janbuchar committed Jul 29, 2024
1 parent d8c5495 commit 2fa79fb
Showing 2 changed files with 22 additions and 5 deletions.
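In short, `BasicCrawler` now passes its own `Configuration` instance to every storage it opens (`RequestQueue`, `Dataset`, `KeyValueStore`) rather than letting each storage fall back to the global default. A minimal sketch of the user-facing effect, using only the classes and parameters that appear in this diff (the `push_data` payload is illustrative):

import asyncio

from crawlee.basic_crawler import BasicCrawler
from crawlee.configuration import Configuration


async def main() -> None:
    # Storage-related settings are declared once, on the crawler.
    configuration = Configuration(persist_storage=False, purge_on_start=True)
    crawler = BasicCrawler(configuration=configuration)

    # Storages opened through the crawler inherit that configuration
    # instead of the process-wide default.
    dataset = await crawler.get_dataset()
    await dataset.push_data({'url': 'https://example.com'})


asyncio.run(main())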
10 changes: 5 additions & 5 deletions src/crawlee/basic_crawler/basic_crawler.py
@@ -288,7 +288,7 @@ async def get_request_provider(
     ) -> RequestProvider:
         """Return the configured request provider. If none is configured, open and return the default request queue."""
         if not self._request_provider:
-            self._request_provider = await RequestQueue.open(id=id, name=name)
+            self._request_provider = await RequestQueue.open(id=id, name=name, configuration=self._configuration)

         return self._request_provider

@@ -299,7 +299,7 @@ async def get_dataset(
         name: str | None = None,
     ) -> Dataset:
         """Return the dataset with the given ID or name. If none is provided, return the default dataset."""
-        return await Dataset.open(id=id, name=name)
+        return await Dataset.open(id=id, name=name, configuration=self._configuration)

     async def get_key_value_store(
         self,
@@ -308,7 +308,7 @@ async def get_key_value_store(
         name: str | None = None,
     ) -> KeyValueStore:
         """Return the key-value store with the given ID or name. If none is provided, return the default KVS."""
-        return await KeyValueStore.open(id=id, name=name)
+        return await KeyValueStore.open(id=id, name=name, configuration=self._configuration)

     def error_handler(
         self, handler: ErrorHandler[TCrawlingContext | BasicCrawlingContext]
@@ -468,7 +468,7 @@ async def export_data(
             dataset_id: The ID of the dataset.
             dataset_name: The name of the dataset.
         """
-        dataset = await Dataset.open(id=dataset_id, name=dataset_name)
+        dataset = await self.get_dataset(id=dataset_id, name=dataset_name)
         path = path if isinstance(path, Path) else Path(path)

         if content_type is None:
@@ -494,7 +494,7 @@ async def _push_data(
             dataset_name: The name of the dataset.
             kwargs: Keyword arguments to be passed to the dataset's `push_data` method.
         """
-        dataset = await Dataset.open(id=dataset_id, name=dataset_name)
+        dataset = await self.get_dataset(id=dataset_id, name=dataset_name)
         await dataset.push_data(data, **kwargs)

     def _should_retry_request(self, crawling_context: BasicCrawlingContext, error: Exception) -> bool:
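The same `configuration=` keyword works when a storage is opened directly. A hedged sketch mirroring the calls the crawler now makes internally; only the `open(id=..., name=..., configuration=...)` signature visible in this diff is assumed, and the dataset name `'results'` is illustrative:

from crawlee.configuration import Configuration
from crawlee.storages import Dataset


async def open_results_dataset(configuration: Configuration) -> Dataset:
    # Passing the configuration explicitly keeps this storage consistent
    # with a particular crawler rather than with the process-wide defaults.
    return await Dataset.open(name='results', configuration=configuration)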
17 changes: 17 additions & 0 deletions tests/unit/basic_crawler/test_basic_crawler.py
@@ -17,6 +17,7 @@
 from crawlee.basic_crawler import BasicCrawler
 from crawlee.basic_crawler.errors import SessionError, UserDefinedErrorHandlerError
 from crawlee.basic_crawler.types import AddRequestsKwargs, BasicCrawlingContext
+from crawlee.configuration import Configuration
 from crawlee.enqueue_strategy import EnqueueStrategy
 from crawlee.models import BaseRequestData, Request
 from crawlee.storages import Dataset, KeyValueStore, RequestList, RequestQueue
@@ -586,3 +587,19 @@ def test_crawler_log() -> None:
     crawler = BasicCrawler()
     assert isinstance(crawler.log, logging.Logger)
     crawler.log.info('Test log message')
+
+
+async def test_passes_configuration_to_storages() -> None:
+    configuration = Configuration(persist_storage=False, purge_on_start=True)
+
+    crawler = BasicCrawler(configuration=configuration)
+
+    dataset = await crawler.get_dataset()
+    assert dataset._configuration is configuration
+
+    key_value_store = await crawler.get_key_value_store()
+    assert key_value_store._configuration is configuration
+
+    request_provider = await crawler.get_request_provider()
+    assert isinstance(request_provider, RequestQueue)
+    assert request_provider._configuration is configuration
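A related consequence, as a hedged sketch: because `export_data` and `_push_data` now go through `get_dataset`, exports also see the crawler's configuration. The output path `'results.json'` is illustrative; per the diff above, `export_data` accepts a plain string and converts it to a `Path`:

from crawlee.basic_crawler import BasicCrawler
from crawlee.configuration import Configuration


async def run_and_export() -> None:
    crawler = BasicCrawler(configuration=Configuration(persist_storage=False))
    # ... register handlers and run the crawler here ...

    # export_data opens its dataset via get_dataset, so the export sees
    # the same configuration as the crawler itself.
    await crawler.export_data('results.json')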
