Concurrent CDK: catch exceptions from worker thread and add integration test scenarios (#31245)

Co-authored-by: girarda <[email protected]>
girarda and girarda authored Oct 23, 2023
1 parent 9835f6b commit 7da2822
Showing 13 changed files with 1,130 additions and 41 deletions.
@@ -27,11 +27,16 @@ def generate_partitions(self, partition_generator: PartitionGenerator, sync_mode
Generate partitions from a partition generator and put them in a queue.
When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated.
+ If an exception is encountered, the exception will be caught and put in the queue.
This method is meant to be called in a separate thread.
:param partition_generator: The partition Generator
:param sync_mode: The sync mode used
:return:
"""
- for partition in partition_generator.generate(sync_mode=sync_mode):
-     self._queue.put(partition)
- self._queue.put(self._sentinel)
+ try:
+     for partition in partition_generator.generate(sync_mode=sync_mode):
+         self._queue.put(partition)
+     self._queue.put(self._sentinel)
+ except Exception as e:
+     self._queue.put(e)
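
Note on the hunk above: generate_partitions runs on a worker thread, so an uncaught exception would otherwise die with that thread while the main loop keeps blocking on the queue. The following hedged, self-contained sketch illustrates the same idea with stand-in names (enqueue_items, SENTINEL, and failing_generator are illustrative, not CDK APIs):

    import queue
    import threading

    SENTINEL = object()  # stand-in for the CDK's "all partitions generated" sentinel

    def enqueue_items(items, out_queue):
        # Runs on a worker thread: forward items, then the sentinel; never let an
        # exception escape silently -- hand it to the consumer through the queue.
        try:
            for item in items:
                out_queue.put(item)
            out_queue.put(SENTINEL)
        except Exception as exc:
            out_queue.put(exc)

    def failing_generator():
        yield "partition-1"
        raise RuntimeError("partition generation failed")

    q = queue.Queue()
    threading.Thread(target=enqueue_items, args=(failing_generator(), q)).start()

    while True:
        item = q.get(timeout=5)
        if isinstance(item, Exception):
            raise item  # the main thread surfaces the worker's failure
        if item is SENTINEL:
            break
        print("got", item)

Putting the exception on the queue lets the consumer re-raise it promptly instead of waiting for the queue read to time out.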
@@ -10,7 +10,7 @@

class PartitionReader:
"""
- Generates records from a partition and puts them in a queuea.
+ Generates records from a partition and puts them in a queue.
"""

def __init__(self, queue: Queue[QueueItem]) -> None:
@@ -24,10 +24,15 @@ def process_partition(self, partition: Partition) -> None:
Process a partition and put the records in the output queue.
When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated.
+ If an exception is encountered, the exception will be caught and put in the queue.
This method is meant to be called from a thread.
:param partition: The partition to read data from
:return: None
"""
- for record in partition.read():
-     self._queue.put(record)
- self._queue.put(PartitionCompleteSentinel(partition))
+ try:
+     for record in partition.read():
+         self._queue.put(record)
+     self._queue.put(PartitionCompleteSentinel(partition))
+ except Exception as e:
+     self._queue.put(e)
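
The record-reading side follows the same rule: anything that escapes Partition.read() is forwarded on the queue instead of being lost in the worker thread. A small test-style sketch of that behaviour, using stand-ins rather than the CDK's Partition and PartitionCompleteSentinel classes:

    from queue import Queue

    class CompleteSentinel:
        # stand-in for PartitionCompleteSentinel
        def __init__(self, partition):
            self.partition = partition

    def process_partition(partition, out_queue):
        # Mirrors the diff: records, then a completion sentinel, or the exception itself.
        try:
            for record in partition.read():
                out_queue.put(record)
            out_queue.put(CompleteSentinel(partition))
        except Exception as exc:
            out_queue.put(exc)

    class FailingPartition:
        def read(self):
            yield {"id": 1}
            raise ValueError("connection reset while reading")

    q = Queue()
    process_partition(FailingPartition(), q)
    assert q.get() == {"id": 1}             # records emitted before the failure survive
    assert isinstance(q.get(), ValueError)  # the failure itself reaches the consumer

Because errors travel on the same queue as records, the consumer can handle both in a single polling loop.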
@@ -26,4 +26,4 @@ def __init__(self, partition: Partition):
"""
Typedef representing the items that can be added to the ThreadBasedConcurrentStream
"""
- QueueItem = Union[Record, Partition, PartitionCompleteSentinel, PARTITIONS_GENERATED_SENTINEL, Partition]
+ QueueItem = Union[Record, Partition, PartitionCompleteSentinel, PARTITIONS_GENERATED_SENTINEL, Partition, Exception]
@@ -69,15 +69,16 @@ def read(self) -> Iterable[Record]:
  Algorithm:
  1. Submit a future to generate the stream's partition to process.
    - This has to be done asynchronously because we sometimes need to submit requests to the API to generate all partitions (eg for substreams).
-   - The future will add the partitions to process on a work queue
+   - The future will add the partitions to process on a work queue.
  2. Continuously poll work from the work queue until all partitions are generated and processed
+   - If the next work item is an Exception, stop the threadpool and raise it.
    - If the next work item is a partition, submit a future to process it.
-     - The future will add the records to emit on the work queue
-     - Add the partitions to the partitions_to_done dict so we know it needs to complete for the sync to succeed
-   - If the next work item is a record, yield the record
-   - If the next work item is PARTITIONS_GENERATED_SENTINEL, all the partitions were generated
-   - If the next work item is a PartitionCompleteSentinel, a partition is done processing
-     - Update the value in partitions_to_done to True so we know the partition is completed
+     - The future will add the records to emit on the work queue.
+     - Add the partitions to the partitions_to_done dict so we know it needs to complete for the sync to succeed.
+   - If the next work item is a record, yield the record.
+   - If the next work item is PARTITIONS_GENERATED_SENTINEL, all the partitions were generated.
+   - If the next work item is a PartitionCompleteSentinel, a partition is done processing.
+     - Update the value in partitions_to_done to True so we know the partition is completed.
"""
self._logger.debug(f"Processing stream slices for {self.name} (sync_mode: full_refresh)")
futures: List[Future[Any]] = []
@@ -93,26 +94,32 @@ def read(self) -> Iterable[Record]:
partitions_to_done: Dict[Partition, bool] = {}

finished_partitions = False
- while record_or_partition := queue.get(block=True, timeout=self._timeout_seconds):
-     if record_or_partition == PARTITIONS_GENERATED_SENTINEL:
+ while record_or_partition_or_exception := queue.get(block=True, timeout=self._timeout_seconds):
+     if isinstance(record_or_partition_or_exception, Exception):
+         # An exception was raised while processing the stream
+         # Stop the threadpool and raise it
+         self._stop_and_raise_exception(record_or_partition_or_exception)
+     elif record_or_partition_or_exception == PARTITIONS_GENERATED_SENTINEL:
          # All partitions were generated
          finished_partitions = True
-     elif isinstance(record_or_partition, PartitionCompleteSentinel):
+     elif isinstance(record_or_partition_or_exception, PartitionCompleteSentinel):
          # All records for a partition were generated
-         if record_or_partition.partition not in partitions_to_done:
+         if record_or_partition_or_exception.partition not in partitions_to_done:
              raise RuntimeError(
-                 f"Received sentinel for partition {record_or_partition.partition} that was not in partitions. This is indicative of a bug in the CDK. Please contact support.partitions:\n{partitions_to_done}"
+                 f"Received sentinel for partition {record_or_partition_or_exception.partition} that was not in partitions. This is indicative of a bug in the CDK. Please contact support.partitions:\n{partitions_to_done}"
              )
-         partitions_to_done[record_or_partition.partition] = True
-     elif isinstance(record_or_partition, Record):
+         partitions_to_done[record_or_partition_or_exception.partition] = True
+     elif isinstance(record_or_partition_or_exception, Record):
          # Emit records
-         yield record_or_partition
-     elif isinstance(record_or_partition, Partition):
+         yield record_or_partition_or_exception
+     elif isinstance(record_or_partition_or_exception, Partition):
          # A new partition was generated and must be processed
-         partitions_to_done[record_or_partition] = False
+         partitions_to_done[record_or_partition_or_exception] = False
          if self._slice_logger.should_log_slice_message(self._logger):
-             self._message_repository.emit_message(self._slice_logger.create_slice_log_message(record_or_partition.to_slice()))
-         self._submit_task(futures, partition_reader.process_partition, record_or_partition)
+             self._message_repository.emit_message(
+                 self._slice_logger.create_slice_log_message(record_or_partition_or_exception.to_slice())
+             )
+         self._submit_task(futures, partition_reader.process_partition, record_or_partition_or_exception)
      if finished_partitions and all(partitions_to_done.values()):
          # All partitions were generated and process. We're done here
          break
@@ -135,10 +142,17 @@ def _wait_while_too_many_pending_futures(self, futures: List[Future[Any]]) -> No
def _check_for_errors(self, futures: List[Future[Any]]) -> None:
exceptions_from_futures = [f for f in [future.exception() for future in futures] if f is not None]
if exceptions_from_futures:
raise RuntimeError(f"Failed reading from stream {self.name} with errors: {exceptions_from_futures}")
futures_not_done = [f for f in futures if not f.done()]
if futures_not_done:
raise RuntimeError(f"Failed reading from stream {self.name} with futures not done: {futures_not_done}")
exception = RuntimeError(f"Failed reading from stream {self.name} with errors: {exceptions_from_futures}")
self._stop_and_raise_exception(exception)
else:
futures_not_done = [f for f in futures if not f.done()]
if futures_not_done:
exception = RuntimeError(f"Failed reading from stream {self.name} with futures not done: {futures_not_done}")
self._stop_and_raise_exception(exception)

def _stop_and_raise_exception(self, exception: BaseException) -> None:
self._threadpool.shutdown(wait=False, cancel_futures=True)
raise exception

@property
def name(self) -> str:
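
The new _stop_and_raise_exception helper leans on ThreadPoolExecutor.shutdown(wait=False, cancel_futures=True), available since Python 3.9: the call returns immediately and cancels work that has not started yet, so no further partitions are processed once a fatal error is seen. A standalone sketch of that shutdown behaviour (timings and names are illustrative only, unrelated to the CDK classes):

    import time
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=1)
    futures = [pool.submit(time.sleep, 0.5) for _ in range(5)]

    # Returns immediately; the task already running finishes, queued tasks are cancelled.
    pool.shutdown(wait=False, cancel_futures=True)

    time.sleep(1)
    print([f.cancelled() for f in futures])  # typically [False, True, True, True, True]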
@@ -1,11 +1,10 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field
- from typing import Any, Generic, List, Mapping, Optional, Tuple, Type, TypeVar
+ from typing import Any, Generic, List, Mapping, Optional, Set, Tuple, Type, TypeVar

from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode
from airbyte_cdk.sources import AbstractSource
@@ -46,7 +45,10 @@ def __init__(
expected_read_error: Tuple[Optional[Type[Exception]], Optional[str]],
incremental_scenario_config: Optional[IncrementalScenarioConfig],
expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] = None,
+ log_levels: Optional[Set[str]] = None,
):
+ if log_levels is None:
+     log_levels = {"ERROR", "WARN", "WARNING"}
self.name = name
self.config = config
self.source = source
@@ -60,6 +62,7 @@ def __init__(
self.expected_read_error = expected_read_error
self.incremental_scenario_config = incremental_scenario_config
self.expected_analytics = expected_analytics
+ self.log_levels = log_levels
self.validate()

def validate(self) -> None:
@@ -112,6 +115,7 @@ def __init__(self) -> None:
self._incremental_scenario_config: Optional[IncrementalScenarioConfig] = None
self._expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] = None
self.source_builder: Optional[SourceBuilder[SourceType]] = None
+ self._log_levels = None

def set_name(self, name: str) -> "TestScenarioBuilder[SourceType]":
self._name = name
@@ -157,6 +161,10 @@ def set_expected_read_error(self, error: Type[Exception], message: str) -> "Test
self._expected_read_error = error, message
return self

+ def set_log_levels(self, levels: Set[str]) -> "TestScenarioBuilder":
+     self._log_levels = levels
+     return self

def set_source_builder(self, source_builder: SourceBuilder[SourceType]) -> "TestScenarioBuilder[SourceType]":
self.source_builder = source_builder
return self
@@ -188,6 +196,7 @@ def build(self) -> "TestScenario[SourceType]":
self._expected_read_error,
self._incremental_scenario_config,
self._expected_analytics,
+ self._log_levels,
)

def _configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]:
@@ -75,7 +75,7 @@ def run_test_read_incremental(

def _verify_read_output(output: Dict[str, Any], scenario: TestScenario[AbstractSource]) -> None:
records, logs = output["records"], output["logs"]
- logs = [log for log in logs if log.get("level") in ("ERROR", "WARN", "WARNING")]
+ logs = [log for log in logs if log.get("level") in scenario.log_levels]
expected_records = scenario.expected_records
assert len(records) == len(expected_records)
for actual, expected in zip(records, expected_records):
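
Taken together, the scenario-builder changes and this filter make the asserted log levels configurable per scenario: a test calls set_log_levels(...) on the builder, and _verify_read_output only compares log lines whose level is in that set, keeping {"ERROR", "WARN", "WARNING"} as the default. A minimal sketch of the default-resolution idiom the constructor uses (the function name here is illustrative):

    from typing import Optional, Set

    def resolve_log_levels(log_levels: Optional[Set[str]] = None) -> Set[str]:
        # None-then-assign avoids sharing one mutable default set across scenarios.
        if log_levels is None:
            log_levels = {"ERROR", "WARN", "WARNING"}
        return log_levels

    print(resolve_log_levels())                   # {'ERROR', 'WARN', 'WARNING'} -- old behaviour
    print(resolve_log_levels({"ERROR", "INFO"}))  # a scenario that also asserts INFO lines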
@@ -0,0 +1,3 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
@@ -0,0 +1,62 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from typing import Any, List, Mapping, Optional, Tuple, Union

from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConnectorSpecification, DestinationSyncMode, SyncMode
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_protocol.models import ConfiguredAirbyteStream
from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder


class StreamFacadeSource(AbstractSource):
def __init__(self, streams: List[Stream], max_workers: int):
self._streams = streams
self._max_workers = max_workers

def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
return True, None

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
return [StreamFacade.create_from_stream(stream, self, stream.logger, self._max_workers) for stream in self._streams]

@property
def message_repository(self) -> Union[None, MessageRepository]:
return InMemoryMessageRepository()

def spec(self, logger: logging.Logger) -> ConnectorSpecification:
return ConnectorSpecification(connectionSpecification={})

def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog(
streams=[
ConfiguredAirbyteStream(
stream=s.as_airbyte_stream(),
sync_mode=SyncMode.full_refresh,
destination_sync_mode=DestinationSyncMode.overwrite,
)
for s in self._streams
]
)


class StreamFacadeSourceBuilder(SourceBuilder[StreamFacadeSource]):
def __init__(self):
self._source = None
self._streams = []
self._max_workers = 1

def set_streams(self, streams: List[Stream]) -> "StreamFacadeSourceBuilder":
self._streams = streams
return self

def set_max_workers(self, max_workers: int):
self._max_workers = max_workers
return self

def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> StreamFacadeSource:
return StreamFacadeSource(self._streams, self._max_workers)
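
A hedged usage sketch for the helper above: even with no streams configured, the builder yields a StreamFacadeSource whose connection check trivially succeeds. Real scenarios would call set_streams with concrete Stream implementations before building.

    import logging

    # StreamFacadeSourceBuilder is the class defined in the new module above.
    source = StreamFacadeSourceBuilder().set_max_workers(2).build(configured_catalog=None)

    assert source.streams(config={}) == []
    assert source.check_connection(logging.getLogger("test"), config={}) == (True, None)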