diff --git a/airbyte/_connector_base.py b/airbyte/_connector_base.py
index 9de6faff..f1db8b21 100644
--- a/airbyte/_connector_base.py
+++ b/airbyte/_connector_base.py
@@ -61,15 +61,20 @@ def __init__(
         If config is provided, it will be validated against the spec if validate is True.
         """
         self.executor = executor
-        self.name = name
+        self._name = name
         self._config_dict: dict[str, Any] | None = None
         self._last_log_messages: list[str] = []
         self._spec: ConnectorSpecification | None = None
         self._selected_stream_names: list[str] = []
-        self._file_logger: logging.Logger = new_passthrough_file_logger(self.name)
+        self._file_logger: logging.Logger = new_passthrough_file_logger(self._name)
         if config is not None:
             self.set_config(config, validate=validate)
 
+    @property
+    def name(self) -> str:
+        """Get the name of the connector."""
+        return self._name
+
     def _print_info_message(
         self,
         message: str,
diff --git a/airbyte/_future_cdk/catalog_providers.py b/airbyte/_future_cdk/catalog_providers.py
index b8ea6137..dd0bb7d7 100644
--- a/airbyte/_future_cdk/catalog_providers.py
+++ b/airbyte/_future_cdk/catalog_providers.py
@@ -8,6 +8,7 @@
 
 from __future__ import annotations
 
+import copy
 from typing import TYPE_CHECKING, Any, final
 
 from airbyte_protocol.models import (
@@ -15,6 +16,7 @@
 )
 
 from airbyte import exceptions as exc
+from airbyte.strategies import WriteMethod, WriteStrategy
 
 
 if TYPE_CHECKING:
@@ -135,3 +137,70 @@ def from_read_result(
                 ]
             )
         )
+
+    def get_primary_keys(
+        self,
+        stream_name: str,
+    ) -> list[str]:
+        pks = self.get_configured_stream_info(stream_name).primary_key
+        if not pks:
+            return []
+
+        joined_pks = [".".join(pk) for pk in pks]
+        for pk in joined_pks:
+            if "." in pk:
+                msg = f"Nested primary keys are not yet supported. Found: {pk}"
+                raise NotImplementedError(msg)
+
+        return joined_pks
+
+    def get_cursor_key(
+        self,
+        stream_name: str,
+    ) -> str | None:
+        return self.get_configured_stream_info(stream_name).cursor_field
+
+    def resolve_write_method(
+        self,
+        stream_name: str,
+        write_strategy: WriteStrategy,
+    ) -> WriteMethod:
+        """Return the write method for the given stream."""
+        has_pks: bool = bool(self.get_primary_keys(stream_name))
+        has_incremental_key: bool = bool(self.get_cursor_key(stream_name))
+        if write_strategy == WriteStrategy.MERGE and not has_pks:
+            raise exc.PyAirbyteInputError(
+                message="Cannot use merge strategy on a stream with no primary keys.",
+                context={
+                    "stream_name": stream_name,
+                },
+            )
+
+        if write_strategy != WriteStrategy.AUTO:
+            return WriteMethod(write_strategy)
+
+        if has_pks:
+            return WriteMethod.MERGE
+
+        if has_incremental_key:
+            return WriteMethod.APPEND
+
+        return WriteMethod.REPLACE
+
+    def with_write_strategy(
+        self,
+        write_strategy: WriteStrategy,
+    ) -> CatalogProvider:
+        """Return a new catalog provider with the specified write strategy applied.
+
+        The original catalog provider is not modified.
+        """
+        new_catalog: ConfiguredAirbyteCatalog = copy.deepcopy(self.configured_catalog)
+        for stream in new_catalog.streams:
+            write_method = self.resolve_write_method(
+                stream_name=stream.stream.name,
+                write_strategy=write_strategy,
+            )
+            stream.destination_sync_mode = write_method.destination_sync_mode
+
+        return CatalogProvider(new_catalog)
diff --git a/airbyte/_future_cdk/record_processor.py b/airbyte/_future_cdk/record_processor.py
deleted file mode 100644
index dbe21baa..00000000
--- a/airbyte/_future_cdk/record_processor.py
+++ /dev/null
@@ -1,300 +0,0 @@
-# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-"""Abstract base class for Processors, including SQL processors. - -Processors accept Airbyte messages as input from STDIN or from another input stream. -""" - -from __future__ import annotations - -import abc -import io -import sys -from collections import defaultdict -from typing import IO, TYPE_CHECKING, cast, final - -from airbyte_cdk import AirbyteMessage -from airbyte_protocol.models import ( - AirbyteRecordMessage, - AirbyteStateMessage, - AirbyteStateType, - AirbyteStreamState, - AirbyteTraceMessage, - Type, -) - -from airbyte import exceptions as exc -from airbyte._future_cdk.state_writers import StdOutStateWriter -from airbyte._message_iterators import AirbyteMessageIterator -from airbyte.records import StreamRecordHandler -from airbyte.strategies import WriteStrategy - - -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - - from airbyte._batch_handles import BatchHandle - from airbyte._future_cdk.catalog_providers import CatalogProvider - from airbyte._future_cdk.state_writers import StateWriterBase - from airbyte.progress import ProgressTracker - - -class AirbyteMessageParsingError(Exception): - """Raised when an Airbyte message is invalid or cannot be parsed.""" - - -class RecordProcessorBase(abc.ABC): - """Abstract base class for classes which can process Airbyte messages from a source. - - This class is responsible for all aspects of handling Airbyte protocol. - - The class should be passed a catalog manager and stream manager class to handle the - catalog and state aspects of the protocol. - """ - - def __init__( - self, - *, - catalog_provider: CatalogProvider, - state_writer: StateWriterBase | None = None, - ) -> None: - """Initialize the processor. - - If a state writer is not provided, the processor will use the default (STDOUT) state writer. - """ - self._catalog_provider: CatalogProvider | None = catalog_provider - self._state_writer: StateWriterBase | None = state_writer or StdOutStateWriter() - - self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {}) - self._finalized_state_messages: dict[ - str, - list[AirbyteStateMessage], - ] = defaultdict(list, {}) - - self._setup() - - @property - def expected_streams(self) -> set[str]: - """Return the expected stream names.""" - return set(self.catalog_provider.stream_names) - - @property - def catalog_provider( - self, - ) -> CatalogProvider: - """Return the catalog manager. - - Subclasses should set this property to a valid catalog manager instance if one - is not explicitly passed to the constructor. - - Raises: - PyAirbyteInternalError: If the catalog manager is not set. - """ - if not self._catalog_provider: - raise exc.PyAirbyteInternalError( - message="Catalog manager should exist but does not.", - ) - - return self._catalog_provider - - @property - def state_writer( - self, - ) -> StateWriterBase: - """Return the state writer instance. - - Subclasses should set this property to a valid state manager instance if one - is not explicitly passed to the constructor. - - Raises: - PyAirbyteInternalError: If the state manager is not set. - """ - if not self._state_writer: - raise exc.PyAirbyteInternalError( - message="State manager should exist but does not.", - ) - - return self._state_writer - - @final - def process_stdin( - self, - *, - write_strategy: WriteStrategy = WriteStrategy.AUTO, - progress_tracker: ProgressTracker, - ) -> None: - """Process the input stream from stdin. - - Return a list of summaries for testing. 
- """ - input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8") - self.process_input_stream( - input_stream, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @final - def _airbyte_messages_from_buffer( - self, - buffer: io.TextIOBase, - ) -> Iterator[AirbyteMessage]: - """Yield messages from a buffer.""" - yield from (AirbyteMessage.model_validate_json(line) for line in buffer) - - @final - def process_input_stream( - self, - input_stream: IO[str], - *, - write_strategy: WriteStrategy = WriteStrategy.AUTO, - progress_tracker: ProgressTracker, - ) -> None: - """Parse the input stream and process data in batches. - - Return a list of summaries for testing. - """ - messages = AirbyteMessageIterator.from_str_buffer(input_stream) - self.process_airbyte_messages( - messages, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @abc.abstractmethod - def process_record_message( - self, - record_msg: AirbyteRecordMessage, - stream_record_handler: StreamRecordHandler, - progress_tracker: ProgressTracker, - ) -> None: - """Write a record. - - This method is called for each record message. - - In most cases, the SQL processor will not perform any action, but will pass this along to to - the file processor. - """ - - @final - def process_airbyte_messages( - self, - messages: Iterable[AirbyteMessage], - *, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> None: - """Process a stream of Airbyte messages.""" - if not isinstance(write_strategy, WriteStrategy): - raise exc.AirbyteInternalError( - message="Invalid `write_strategy` argument. Expected instance of WriteStrategy.", - context={"write_strategy": write_strategy}, - ) - - stream_record_handlers: dict[str, StreamRecordHandler] = {} - - # Process messages, writing to batches as we go - for message in messages: - if message.type is Type.RECORD: - record_msg = cast(AirbyteRecordMessage, message.record) - stream_name = record_msg.stream - - if stream_name not in stream_record_handlers: - stream_record_handlers[stream_name] = StreamRecordHandler( - json_schema=self.catalog_provider.get_stream_json_schema( - stream_name=stream_name, - ), - normalize_keys=True, - prune_extra_fields=True, - ) - - self.process_record_message( - record_msg, - stream_record_handler=stream_record_handlers[stream_name], - progress_tracker=progress_tracker, - ) - - elif message.type is Type.STATE: - state_msg = cast(AirbyteStateMessage, message.state) - if state_msg.type in {AirbyteStateType.GLOBAL, AirbyteStateType.LEGACY}: - self._pending_state_messages[f"_{state_msg.type}"].append(state_msg) - else: - stream_state = cast(AirbyteStreamState, state_msg.stream) - stream_name = stream_state.stream_descriptor.name - self._pending_state_messages[stream_name].append(state_msg) - - elif message.type is Type.TRACE: - trace_msg: AirbyteTraceMessage = cast(AirbyteTraceMessage, message.trace) - if trace_msg.stream_status and trace_msg.stream_status.status == "SUCCEEDED": - # This stream has completed successfully, so go ahead and write the data. - # This will also finalize any pending state messages. - self.write_stream_data( - stream_name=trace_msg.stream_status.stream_descriptor.name, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - else: - # Ignore unexpected or unhandled message types: - # Type.LOG, Type.CONTROL, etc. - pass - - # We've finished processing input data. 
- # Finalize all received records and state messages: - self.write_all_stream_data( - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - self.cleanup_all() - - def write_all_stream_data( - self, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> None: - """Finalize any pending writes. - - Streams are processed in alphabetical order, so that order is deterministic and opaque, - without resorting to knowledge about catalog declaration order. - """ - for stream_name in sorted(self.catalog_provider.stream_names): - self.write_stream_data( - stream_name, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @abc.abstractmethod - def write_stream_data( - self, - stream_name: str, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> list[BatchHandle]: - """Write pending stream data to the cache.""" - ... - - def _finalize_state_messages( - self, - state_messages: list[AirbyteStateMessage], - ) -> None: - """Handle state messages by passing them to the catalog manager.""" - if state_messages: - self.state_writer.write_state( - state_message=state_messages[-1], - ) - - def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract - """Create the database. - - By default this is a no-op but subclasses can override this method to prepare - any necessary resources. - """ - pass - - def cleanup_all(self) -> None: # noqa: B027 # Intentionally empty, not abstract - """Clean up all resources. - - The default implementation is a no-op. - """ - pass diff --git a/airbyte/_future_cdk/sql_processor.py b/airbyte/_future_cdk/sql_processor.py index 04bc72cb..57442912 100644 --- a/airbyte/_future_cdk/sql_processor.py +++ b/airbyte/_future_cdk/sql_processor.py @@ -6,6 +6,7 @@ import abc import contextlib import enum +from collections import defaultdict from contextlib import contextmanager from functools import cached_property from pathlib import Path @@ -29,8 +30,17 @@ ) from sqlalchemy.sql.elements import TextClause +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + AirbyteTraceMessage, + Type, +) + from airbyte import exceptions as exc -from airbyte._future_cdk.record_processor import RecordProcessorBase from airbyte._future_cdk.state_writers import StdOutStateWriter from airbyte._util.name_normalizers import LowerCaseNormalizer from airbyte.constants import ( @@ -39,12 +49,13 @@ AB_RAW_ID_COLUMN, DEBUG_MODE, ) -from airbyte.strategies import WriteStrategy +from airbyte.records import StreamRecordHandler +from airbyte.strategies import WriteMethod, WriteStrategy from airbyte.types import SQLTypeConverter if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Iterable from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.cursor import CursorResult @@ -52,17 +63,11 @@ from sqlalchemy.sql.base import Executable from sqlalchemy.sql.type_api import TypeEngine - from airbyte_protocol.models import ( - AirbyteRecordMessage, - AirbyteStateMessage, - ) - from airbyte._batch_handles import BatchHandle from airbyte._future_cdk.catalog_providers import CatalogProvider from airbyte._future_cdk.state_writers import StateWriterBase - from airbyte._processors.file.base import FileWriterBase + from airbyte._writers.jsonl import FileWriterBase from airbyte.progress import ProgressTracker - from airbyte.records import StreamRecordHandler from airbyte.secrets.base 
import SecretString @@ -116,7 +121,7 @@ def get_vendor_client(self) -> object: ) -class SqlProcessorBase(RecordProcessorBase): +class SqlProcessorBase(abc.ABC): """A base class to be used for SQL Caches.""" type_converter_class: type[SQLTypeConverter] = SQLTypeConverter @@ -131,8 +136,6 @@ class SqlProcessorBase(RecordProcessorBase): supports_merge_insert = False """True if the database supports the MERGE INTO syntax.""" - # Constructor: - def __init__( self, *, @@ -152,10 +155,16 @@ def __init__( self._sql_config: SqlConfig = sql_config - super().__init__( - state_writer=state_writer, - catalog_provider=catalog_provider, - ) + self._catalog_provider: CatalogProvider | None = catalog_provider + self._state_writer: StateWriterBase | None = state_writer or StdOutStateWriter() + + self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {}) + self._finalized_state_messages: dict[ + str, + list[AirbyteStateMessage], + ] = defaultdict(list, {}) + + self._setup() self.file_writer = file_writer or self.file_writer_class( cache_dir=cast(Path, temp_dir), cleanup=temp_file_cleanup, @@ -166,6 +175,150 @@ def __init__( self._known_schemas_list: list[str] = [] self._ensure_schema_exists() + @property + def catalog_provider( + self, + ) -> CatalogProvider: + """Return the catalog manager. + + Subclasses should set this property to a valid catalog manager instance if one + is not explicitly passed to the constructor. + + Raises: + PyAirbyteInternalError: If the catalog manager is not set. + """ + if not self._catalog_provider: + raise exc.PyAirbyteInternalError( + message="Catalog manager should exist but does not.", + ) + + return self._catalog_provider + + @property + def state_writer( + self, + ) -> StateWriterBase: + """Return the state writer instance. + + Subclasses should set this property to a valid state manager instance if one + is not explicitly passed to the constructor. + + Raises: + PyAirbyteInternalError: If the state manager is not set. + """ + if not self._state_writer: + raise exc.PyAirbyteInternalError( + message="State manager should exist but does not.", + ) + + return self._state_writer + + @final + def process_airbyte_messages( + self, + messages: Iterable[AirbyteMessage], + *, + write_strategy: WriteStrategy = WriteStrategy.AUTO, + progress_tracker: ProgressTracker, + ) -> None: + """Process a stream of Airbyte messages. + + This method assumes that the catalog is already registered with the processor. + """ + if not isinstance(write_strategy, WriteStrategy): + raise exc.AirbyteInternalError( + message="Invalid `write_strategy` argument. 
Expected instance of WriteStrategy.", + context={"write_strategy": write_strategy}, + ) + + stream_record_handlers: dict[str, StreamRecordHandler] = {} + + # Process messages, writing to batches as we go + for message in messages: + if message.type is Type.RECORD: + record_msg = cast(AirbyteRecordMessage, message.record) + stream_name = record_msg.stream + + if stream_name not in stream_record_handlers: + stream_record_handlers[stream_name] = StreamRecordHandler( + json_schema=self.catalog_provider.get_stream_json_schema( + stream_name=stream_name, + ), + normalize_keys=True, + prune_extra_fields=True, + ) + + self.process_record_message( + record_msg, + stream_record_handler=stream_record_handlers[stream_name], + progress_tracker=progress_tracker, + ) + + elif message.type is Type.STATE: + state_msg = cast(AirbyteStateMessage, message.state) + if state_msg.type in {AirbyteStateType.GLOBAL, AirbyteStateType.LEGACY}: + self._pending_state_messages[f"_{state_msg.type}"].append(state_msg) + else: + stream_state = cast(AirbyteStreamState, state_msg.stream) + stream_name = stream_state.stream_descriptor.name + self._pending_state_messages[stream_name].append(state_msg) + + elif message.type is Type.TRACE: + trace_msg: AirbyteTraceMessage = cast(AirbyteTraceMessage, message.trace) + if trace_msg.stream_status and trace_msg.stream_status.status == "SUCCEEDED": + # This stream has completed successfully, so go ahead and write the data. + # This will also finalize any pending state messages. + self.write_stream_data( + stream_name=trace_msg.stream_status.stream_descriptor.name, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + else: + # Ignore unexpected or unhandled message types: + # Type.LOG, Type.CONTROL, etc. + pass + + # We've finished processing input data. + # Finalize all received records and state messages: + self._write_all_stream_data( + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + self.cleanup_all() + + def _write_all_stream_data( + self, + write_strategy: WriteStrategy, + progress_tracker: ProgressTracker, + ) -> None: + """Finalize any pending writes.""" + for stream_name in self.catalog_provider.stream_names: + self.write_stream_data( + stream_name, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + def _finalize_state_messages( + self, + state_messages: list[AirbyteStateMessage], + ) -> None: + """Handle state messages by passing them to the catalog manager.""" + if state_messages: + self.state_writer.write_state( + state_message=state_messages[-1], + ) + + def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract + """Create the database. + + By default this is a no-op but subclasses can override this method to prepare + any necessary resources. + """ + pass + # Public interface: @property @@ -243,7 +396,7 @@ def process_record_message( # Protected members (non-public interface): - def _init_connection_settings(self, connection: Connection) -> None: + def _init_connection_settings(self, connection: Connection) -> None: # noqa: B027 # Intentionally empty, not abstract """This is called automatically whenever a new connection is created. By default this is a no-op. 
Subclasses can use this to set connection settings, such as @@ -478,7 +631,9 @@ def _get_sql_column_definitions( def write_stream_data( self, stream_name: str, - write_strategy: WriteStrategy, + *, + write_method: WriteMethod | None = None, + write_strategy: WriteStrategy | None = None, progress_tracker: ProgressTracker, ) -> list[BatchHandle]: """Finalize all uncommitted batches. @@ -491,6 +646,18 @@ def write_stream_data( Some sources will send us duplicate records within the same stream, although this is a fairly rare edge case we can ignore in V1. """ + if write_method and write_strategy and write_strategy != WriteStrategy.AUTO: + raise exc.PyAirbyteInternalError( + message=( + "Both `write_method` and `write_strategy` were provided. " + "Only one should be set." + ), + ) + if not write_method: + write_method = self.catalog_provider.resolve_write_method( + stream_name=stream_name, + write_strategy=write_strategy or WriteStrategy.AUTO, + ) # Flush any pending writes self.file_writer.flush_active_batches( progress_tracker=progress_tracker, @@ -528,7 +695,7 @@ def write_stream_data( stream_name=stream_name, temp_table_name=temp_table_name, final_table_name=final_table_name, - write_strategy=write_strategy, + write_method=write_method, ) finally: self._drop_temp_table(temp_table_name, if_exists=True) @@ -705,28 +872,10 @@ def _write_temp_table_to_final_table( stream_name: str, temp_table_name: str, final_table_name: str, - write_strategy: WriteStrategy, + write_method: WriteMethod, ) -> None: """Write the temp table into the final table using the provided write strategy.""" - has_pks: bool = bool(self._get_primary_keys(stream_name)) - has_incremental_key: bool = bool(self._get_incremental_key(stream_name)) - if write_strategy == WriteStrategy.MERGE and not has_pks: - raise exc.PyAirbyteInputError( - message="Cannot use merge strategy on a stream with no primary keys.", - context={ - "stream_name": stream_name, - }, - ) - - if write_strategy == WriteStrategy.AUTO: - if has_pks: - write_strategy = WriteStrategy.MERGE - elif has_incremental_key: - write_strategy = WriteStrategy.APPEND - else: - write_strategy = WriteStrategy.REPLACE - - if write_strategy == WriteStrategy.REPLACE: + if write_method == WriteMethod.REPLACE: # Note: No need to check for schema compatibility # here, because we are fully replacing the table. self._swap_temp_table_with_final_table( @@ -736,7 +885,7 @@ def _write_temp_table_to_final_table( ) return - if write_strategy == WriteStrategy.APPEND: + if write_method == WriteMethod.APPEND: self._ensure_compatible_table_schema( stream_name=stream_name, table_name=final_table_name, @@ -748,7 +897,7 @@ def _write_temp_table_to_final_table( ) return - if write_strategy == WriteStrategy.MERGE: + if write_method == WriteMethod.MERGE: self._ensure_compatible_table_schema( stream_name=stream_name, table_name=final_table_name, @@ -770,9 +919,9 @@ def _write_temp_table_to_final_table( return raise exc.PyAirbyteInternalError( - message="Write strategy is not supported.", + message="Write method is not supported.", context={ - "write_strategy": write_strategy, + "write_method": write_method, }, ) @@ -795,28 +944,6 @@ def _append_temp_table_to_final_table( """, ) - def _get_primary_keys( - self, - stream_name: str, - ) -> list[str]: - pks = self.catalog_provider.get_configured_stream_info(stream_name).primary_key - if not pks: - return [] - - joined_pks = [".".join(pk) for pk in pks] - for pk in joined_pks: - if "." in pk: - msg = f"Nested primary keys are not yet supported. 
Found: {pk}" - raise NotImplementedError(msg) - - return joined_pks - - def _get_incremental_key( - self, - stream_name: str, - ) -> str | None: - return self.catalog_provider.get_configured_stream_info(stream_name).cursor_field - def _swap_temp_table_with_final_table( self, stream_name: str, @@ -859,7 +986,9 @@ def _merge_temp_table_to_final_table( """ nl = "\n" columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)} - pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)} + pk_columns = { + self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name) + } non_pk_columns = columns - pk_columns join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns) set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns) @@ -915,7 +1044,7 @@ def _emulated_merge_temp_table_to_final_table( """ final_table = self._get_table_by_name(final_table_name) temp_table = self._get_table_by_name(temp_table_name) - pk_columns = self._get_primary_keys(stream_name) + pk_columns = self.catalog_provider.get_primary_keys(stream_name) columns_to_update: set[str] = self._get_sql_column_definitions( stream_name=stream_name diff --git a/airbyte/_processors/sql/bigquery.py b/airbyte/_processors/sql/bigquery.py index 1a01b980..652cc49f 100644 --- a/airbyte/_processors/sql/bigquery.py +++ b/airbyte/_processors/sql/bigquery.py @@ -19,7 +19,7 @@ from airbyte import exceptions as exc from airbyte._future_cdk import SqlProcessorBase from airbyte._future_cdk.sql_processor import SqlConfig -from airbyte._processors.file.jsonl import JsonlWriter +from airbyte._writers.jsonl import JsonlWriter from airbyte.constants import DEFAULT_CACHE_SCHEMA_NAME from airbyte.secrets.base import SecretString from airbyte.types import SQLTypeConverter diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py index f8416a9f..0fde86d8 100644 --- a/airbyte/_processors/sql/duckdb.py +++ b/airbyte/_processors/sql/duckdb.py @@ -14,7 +14,7 @@ from airbyte._future_cdk import SqlProcessorBase from airbyte._future_cdk.sql_processor import SqlConfig -from airbyte._processors.file import JsonlWriter +from airbyte._writers.jsonl import JsonlWriter from airbyte.secrets.base import SecretString diff --git a/airbyte/_processors/sql/motherduck.py b/airbyte/_processors/sql/motherduck.py index edece91e..345cf487 100644 --- a/airbyte/_processors/sql/motherduck.py +++ b/airbyte/_processors/sql/motherduck.py @@ -9,8 +9,8 @@ from duckdb_engine import DuckDBEngineWarning from overrides import overrides -from airbyte._processors.file import JsonlWriter from airbyte._processors.sql.duckdb import DuckDBSqlProcessor +from airbyte._writers.jsonl import JsonlWriter if TYPE_CHECKING: diff --git a/airbyte/_processors/sql/postgres.py b/airbyte/_processors/sql/postgres.py index 2944dd84..ab4dbb0c 100644 --- a/airbyte/_processors/sql/postgres.py +++ b/airbyte/_processors/sql/postgres.py @@ -6,7 +6,7 @@ from overrides import overrides from airbyte._future_cdk.sql_processor import SqlConfig, SqlProcessorBase -from airbyte._processors.file import JsonlWriter +from airbyte._writers.jsonl import JsonlWriter from airbyte.secrets.base import SecretString diff --git a/airbyte/_processors/sql/snowflake.py b/airbyte/_processors/sql/snowflake.py index b4b4daa3..0da6f505 100644 --- a/airbyte/_processors/sql/snowflake.py +++ b/airbyte/_processors/sql/snowflake.py @@ -16,7 +16,7 @@ from airbyte import exceptions as exc from 
airbyte._future_cdk import SqlProcessorBase
 from airbyte._future_cdk.sql_processor import SqlConfig
-from airbyte._processors.file.jsonl import JsonlWriter
+from airbyte._writers.jsonl import JsonlWriter
 from airbyte.constants import DEFAULT_CACHE_SCHEMA_NAME
 from airbyte.secrets.base import SecretString
 from airbyte.types import SQLTypeConverter
diff --git a/airbyte/_processors/sql/snowflakecortex.py b/airbyte/_processors/sql/snowflakecortex.py
index ca42948f..7271aef2 100644
--- a/airbyte/_processors/sql/snowflakecortex.py
+++ b/airbyte/_processors/sql/snowflakecortex.py
@@ -24,7 +24,7 @@
 
     from airbyte._future_cdk.catalog_providers import CatalogProvider
     from airbyte._future_cdk.state_writers import StateWriterBase
-    from airbyte._processors.file.base import FileWriterBase
+    from airbyte._writers.jsonl import FileWriterBase
 
 
 class SnowflakeCortexTypeConverter(SnowflakeTypeConverter):
diff --git a/airbyte/_util/telemetry.py b/airbyte/_util/telemetry.py
index 2927742e..73376d40 100644
--- a/airbyte/_util/telemetry.py
+++ b/airbyte/_util/telemetry.py
@@ -51,6 +51,7 @@
 
 
 if TYPE_CHECKING:
+    from airbyte._writers.base import AirbyteWriterInterface
     from airbyte.caches.base import CacheBase
     from airbyte.destinations.base import Destination
     from airbyte.sources.base import Source
@@ -226,18 +227,27 @@ class DestinationTelemetryInfo:
     version: str | None
 
     @classmethod
-    def from_destination(cls, destination: Destination | str | None) -> DestinationTelemetryInfo:
+    def from_destination(
+        cls,
+        destination: Destination | AirbyteWriterInterface | str | None,
+    ) -> DestinationTelemetryInfo:
         if not destination:
             return cls(name=UNKNOWN, executor_type=UNKNOWN, version=UNKNOWN)
 
         if isinstance(destination, str):
            return cls(name=destination, executor_type=UNKNOWN, version=UNKNOWN)
 
-        # Else, `destination` should be a `Destination` at this point
+        if hasattr(destination, "executor"):
+            return cls(
+                name=destination.name,
+                executor_type=type(destination.executor).__name__,
+                version=destination.executor.reported_version,
+            )
+
         return cls(
-            name=destination.name,
-            executor_type=type(destination.executor).__name__,
-            version=destination.executor.reported_version,
+            name=repr(destination),
+            executor_type=UNKNOWN,
+            version=UNKNOWN,
         )
 
 
@@ -274,7 +284,7 @@ def get_env_flags() -> dict[str, Any]:
 def send_telemetry(
     *,
     source: Source | str | None,
-    destination: Destination | str | None,
+    destination: Destination | AirbyteWriterInterface | str | None,
     cache: CacheBase | None,
     state: EventState,
     event_type: EventType,
diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_writers/__init__.py
similarity index 68%
rename from airbyte/_processors/file/__init__.py
rename to airbyte/_writers/__init__.py
index 2ef9b9a4..fd2c0072 100644
--- a/airbyte/_processors/file/__init__.py
+++ b/airbyte/_writers/__init__.py
@@ -4,8 +4,7 @@
 from __future__ import annotations
 
 from airbyte._batch_handles import BatchHandle
-from airbyte._processors.file.base import FileWriterBase
-from airbyte._processors.file.jsonl import JsonlWriter
+from airbyte._writers.jsonl import FileWriterBase, JsonlWriter
 
 
 __all__ = [
diff --git a/airbyte/_writers/base.py b/airbyte/_writers/base.py
new file mode 100644
index 00000000..39690058
--- /dev/null
+++ b/airbyte/_writers/base.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+"""Write interfaces for PyAirbyte."""
+
+from __future__ import annotations
+
+import abc
+from typing import IO, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from airbyte._future_cdk.catalog_providers import CatalogProvider
+    from airbyte._future_cdk.state_writers import StateWriterBase
+    from airbyte._message_iterators import AirbyteMessageIterator
+    from airbyte.progress import ProgressTracker
+    from airbyte.strategies import WriteStrategy
+
+
+class AirbyteWriterInterface(abc.ABC):
+    """An interface for writing Airbyte messages."""
+
+    @property
+    def name(self) -> str:
+        """Return the name of the writer.
+
+        This is used for logging and state tracking.
+        """
+        if hasattr(self, "_name"):
+            return self._name
+
+        return self.__class__.__name__
+
+    def _write_airbyte_io_stream(
+        self,
+        stdin: IO[str],
+        *,
+        catalog_provider: CatalogProvider,
+        write_strategy: WriteStrategy,
+        state_writer: StateWriterBase | None = None,
+        progress_tracker: ProgressTracker,
+    ) -> None:
+        """Read from the connector and write to the cache.
+
+        This is a specialized version of `_write_airbyte_message_stream` that reads from an IO
+        stream. Writers can override this method to provide custom behavior for reading from an IO
+        stream, without paying the cost of converting the stream to an AirbyteMessageIterator.
+        """
+        self._write_airbyte_message_stream(
+            stdin,
+            catalog_provider=catalog_provider,
+            write_strategy=write_strategy,
+            state_writer=state_writer,
+            progress_tracker=progress_tracker,
+        )
+
+    @abc.abstractmethod
+    def _write_airbyte_message_stream(
+        self,
+        stdin: IO[str] | AirbyteMessageIterator,
+        *,
+        catalog_provider: CatalogProvider,
+        write_strategy: WriteStrategy,
+        state_writer: StateWriterBase | None = None,
+        progress_tracker: ProgressTracker,
+    ) -> None:
+        """Write the incoming data.
+
+        Note: Callers should use `_write_airbyte_io_stream` instead of this method if
+        `stdin` is always an IO stream. This ensures that the most efficient method is used for
+        writing the incoming stream.
+        """
+        ...
diff --git a/airbyte/_processors/file/base.py b/airbyte/_writers/file_writers.py similarity index 86% rename from airbyte/_processors/file/base.py rename to airbyte/_writers/file_writers.py index 7951fa07..744379e1 100644 --- a/airbyte/_processors/file/base.py +++ b/airbyte/_writers/file_writers.py @@ -13,6 +13,7 @@ from airbyte import exceptions as exc from airbyte import progress from airbyte._batch_handles import BatchHandle +from airbyte._writers.base import AirbyteWriterInterface from airbyte.records import StreamRecord, StreamRecordHandler @@ -21,14 +22,18 @@ AirbyteRecordMessage, ) + from airbyte._future_cdk.catalog_providers import CatalogProvider + from airbyte._future_cdk.state_writers import StateWriterBase + from airbyte._message_iterators import AirbyteMessageIterator from airbyte.progress import ProgressTracker + from airbyte.strategies import WriteStrategy DEFAULT_BATCH_SIZE = 100_000 -class FileWriterBase(abc.ABC): - """A generic base implementation for a file-based cache.""" +class FileWriterBase(AirbyteWriterInterface): + """A generic abstract implementation for a file-based writer.""" default_cache_file_suffix: str = ".batch" prune_extra_fields: bool = False @@ -186,6 +191,25 @@ def process_record_message( ) batch_handle.increment_record_count() + def _write_airbyte_message_stream( + self, + stdin: IO[str] | AirbyteMessageIterator, + *, + catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, + state_writer: StateWriterBase | None = None, + progress_tracker: ProgressTracker, + ) -> None: + """Read from the connector and write to the cache. + + This is not implemented for file writers, as they should be wrapped by another writer that + handles state tracking and other logic. + """ + _ = stdin, catalog_provider, write_strategy, state_writer, progress_tracker + raise exc.PyAirbyteInternalError from NotImplementedError( + "File writers should be wrapped by another AirbyteWriterInterface." 
+ ) + def flush_active_batches( self, progress_tracker: ProgressTracker, diff --git a/airbyte/_processors/file/jsonl.py b/airbyte/_writers/jsonl.py similarity index 97% rename from airbyte/_processors/file/jsonl.py rename to airbyte/_writers/jsonl.py index 4f935945..39bf198b 100644 --- a/airbyte/_processors/file/jsonl.py +++ b/airbyte/_writers/jsonl.py @@ -10,7 +10,7 @@ import orjson from overrides import overrides -from airbyte._processors.file.base import ( +from airbyte._writers.file_writers import ( FileWriterBase, ) diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index ca7d5b37..96ac559d 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, final +from typing import IO, TYPE_CHECKING, Any, final import pandas as pd import pyarrow as pa @@ -19,6 +19,7 @@ SqlProcessorBase, ) from airbyte._future_cdk.state_writers import StdOutStateWriter +from airbyte._writers.base import AirbyteWriterInterface from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend from airbyte.caches._state_backend import SqlStateBackend from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE, TEMP_FILE_CLEANUP @@ -31,11 +32,14 @@ from airbyte._future_cdk.sql_processor import SqlProcessorBase from airbyte._future_cdk.state_providers import StateProviderBase from airbyte._future_cdk.state_writers import StateWriterBase + from airbyte._message_iterators import AirbyteMessageIterator from airbyte.caches._state_backend_base import StateBackendBase from airbyte.datasets._base import DatasetBase + from airbyte.progress import ProgressTracker + from airbyte.strategies import WriteStrategy -class CacheBase(SqlConfig): +class CacheBase(SqlConfig, AirbyteWriterInterface): """Base configuration for a cache. Caches inherit from the matching `SqlConfig` class, which provides the SQL config settings @@ -52,6 +56,8 @@ class CacheBase(SqlConfig): cleanup: bool = TEMP_FILE_CLEANUP """Whether to clean up the cache after use.""" + _name: str = PrivateAttr() + _deployed_api_root: str | None = PrivateAttr(default=None) _deployed_workspace_id: str | None = PrivateAttr(default=None) _deployed_destination_id: str | None = PrivateAttr(default=None) @@ -95,14 +101,6 @@ def __init__(self, **data: Any) -> None: # noqa: ANN401 temp_file_cleanup=self.cleanup, ) - @property - def name(self) -> str: - """Return the name of the cache. - - By default, this is the class name. 
- """ - return type(self).__name__ - @final @property def processor(self) -> SqlProcessorBase: @@ -258,3 +256,25 @@ def __iter__( # type: ignore [override] # Overriding Pydantic model method ) -> Iterator[tuple[str, Any]]: """Iterate over the streams in the cache.""" return ((name, dataset) for name, dataset in self.streams.items()) + + def _write_airbyte_message_stream( + self, + stdin: IO[str] | AirbyteMessageIterator, + *, + catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, + state_writer: StateWriterBase | None = None, + progress_tracker: ProgressTracker, + ) -> None: + """Read from the connector and write to the cache.""" + cache_processor = self.get_record_processor( + source_name=self.name, + catalog_provider=catalog_provider, + state_writer=state_writer, + ) + cache_processor.process_airbyte_messages( + messages=stdin, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + progress_tracker.log_cache_processing_complete() diff --git a/airbyte/caches/duckdb.py b/airbyte/caches/duckdb.py index 68b6e349..8fdf40ad 100644 --- a/airbyte/caches/duckdb.py +++ b/airbyte/caches/duckdb.py @@ -34,7 +34,6 @@ ) -# @dataclass class DuckDBCache(DuckDBConfig, CacheBase): """A DuckDB cache.""" diff --git a/airbyte/destinations/base.py b/airbyte/destinations/base.py index 22133513..8f42362a 100644 --- a/airbyte/destinations/base.py +++ b/airbyte/destinations/base.py @@ -21,9 +21,10 @@ StateProviderBase, StaticInputState, ) -from airbyte._future_cdk.state_writers import NoOpStateWriter, StateWriterBase, StdOutStateWriter +from airbyte._future_cdk.state_writers import NoOpStateWriter, StdOutStateWriter from airbyte._message_iterators import AirbyteMessageIterator from airbyte._util.temp_files import as_temp_files +from airbyte._writers.base import AirbyteWriterInterface from airbyte.caches.util import get_default_cache from airbyte.progress import ProgressTracker from airbyte.results import ReadResult, WriteResult @@ -37,7 +38,7 @@ from airbyte.caches.base import CacheBase -class Destination(ConnectorBase): +class Destination(ConnectorBase, AirbyteWriterInterface): """A class representing a destination that can be called.""" connector_type: Literal["destination"] = "destination" @@ -71,11 +72,12 @@ def write( # noqa: PLR0912, PLR0915 # Too many arguments/statements write_strategy: WriteStrategy = WriteStrategy.AUTO, force_full_refresh: bool = False, ) -> WriteResult: - """Write data to the destination. + """Write data from source connector or already cached source data. + + Caching is enabled by default, unless explicitly disabled. Args: - source_data: The source data to write to the destination. Can be a `Source`, a `Cache`, - or a `ReadResult` object. + source_data: The source data to write. Can be a `Source` or a `ReadResult` object. streams: The streams to write to the destination. If omitted or if "*" is provided, all streams will be written. If `source_data` is a source, then streams must be selected here or on the source. 
If both are specified, this setting will override @@ -225,8 +227,8 @@ def write( # noqa: PLR0912, PLR0915 # Too many arguments/statements self._write_airbyte_message_stream( stdin=message_iterator, catalog_provider=catalog_provider, + write_strategy=write_strategy, state_writer=destination_state_writer, - skip_validation=False, progress_tracker=progress_tracker, ) except Exception as ex: @@ -249,18 +251,18 @@ def _write_airbyte_message_stream( stdin: IO[str] | AirbyteMessageIterator, *, catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, state_writer: StateWriterBase | None = None, - skip_validation: bool = False, progress_tracker: ProgressTracker, ) -> None: """Read from the connector and write to the cache.""" # Run optional validation step - if not skip_validation: - self.validate_config() - if state_writer is None: state_writer = StdOutStateWriter() + # Apply the write strategy to the catalog provider before sending to the destination + catalog_provider = catalog_provider.with_write_strategy(write_strategy) + with as_temp_files( files_contents=[ self._config, diff --git a/airbyte/progress.py b/airbyte/progress.py index 149e9e67..c867b739 100644 --- a/airbyte/progress.py +++ b/airbyte/progress.py @@ -54,6 +54,7 @@ from types import ModuleType from airbyte._message_iterators import AirbyteMessageIterator + from airbyte._writers.base import AirbyteWriterInterface from airbyte.caches.base import CacheBase from airbyte.destinations.base import Destination from airbyte.sources.base import Source @@ -177,7 +178,7 @@ def __init__( *, source: Source | None, cache: CacheBase | None, - destination: Destination | None, + destination: AirbyteWriterInterface | Destination | None, expected_streams: list[str] | None = None, ) -> None: """Initialize the progress tracker.""" diff --git a/airbyte/results.py b/airbyte/results.py index 7035abab..c6a72bf0 100644 --- a/airbyte/results.py +++ b/airbyte/results.py @@ -22,6 +22,7 @@ from airbyte._future_cdk.catalog_providers import CatalogProvider from airbyte._future_cdk.state_providers import StateProviderBase from airbyte._future_cdk.state_writers import StateWriterBase + from airbyte._writers.base import AirbyteWriterInterface from airbyte.caches import CacheBase from airbyte.destinations.base import Destination from airbyte.progress import ProgressTracker @@ -110,7 +111,7 @@ class WriteResult: def __init__( self, *, - destination: Destination, + destination: AirbyteWriterInterface | Destination, source_data: Source | ReadResult, catalog_provider: CatalogProvider, state_writer: StateWriterBase, @@ -121,7 +122,7 @@ def __init__( This class should not be created directly. Instead, it should be returned by the `write` method of the `Destination` class. 
""" - self._destination: Destination = destination + self._destination: AirbyteWriterInterface | Destination = destination self._source_data: Source | ReadResult = source_data self._catalog_provider: CatalogProvider = catalog_provider self._state_writer: StateWriterBase = state_writer diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index fa2e5893..c4493d6d 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -626,9 +626,9 @@ def read( state_provider: StateProviderBase | None = None else: state_provider = cache.get_state_provider( - source_name=self.name, + source_name=self._name, ) - state_writer = cache.get_state_writer(source_name=self.name) + state_writer = cache.get_state_writer(source_name=self._name) if streams: self.select_streams(streams) @@ -717,23 +717,20 @@ def _read_to_cache( # noqa: PLR0913 # Too many arguments if incremental_streams: self._log_incremental_streams(incremental_streams=incremental_streams) - airbyte_message_iterator: Iterator[AirbyteMessage] = self._read_with_catalog( - catalog=catalog_provider.configured_catalog, - state=state_provider, - progress_tracker=progress_tracker, + airbyte_message_iterator = AirbyteMessageIterator( + self._read_with_catalog( + catalog=catalog_provider.configured_catalog, + state=state_provider, + progress_tracker=progress_tracker, + ) ) - cache_processor = cache.get_record_processor( - source_name=self.name, + cache._write_airbyte_message_stream( # noqa: SLF001 # Non-public API + stdin=airbyte_message_iterator, catalog_provider=catalog_provider, - state_writer=state_writer, - ) - cache_processor.process_airbyte_messages( - messages=airbyte_message_iterator, write_strategy=write_strategy, + state_writer=state_writer, progress_tracker=progress_tracker, ) - progress_tracker.log_cache_processing_complete() - return ReadResult( source_name=self.name, progress_tracker=progress_tracker, diff --git a/airbyte/strategies.py b/airbyte/strategies.py index 05ab3ba9..e55b4d9a 100644 --- a/airbyte/strategies.py +++ b/airbyte/strategies.py @@ -6,11 +6,26 @@ from enum import Enum +from airbyte_protocol.models import DestinationSyncMode + + +_MERGE = "merge" +_REPLACE = "replace" +_APPEND = "append" +_AUTO = "auto" + class WriteStrategy(str, Enum): - """Read strategies for PyAirbyte.""" + """Read strategies for PyAirbyte. + + Read strategies set a preferred method for writing data to a destination. The actual method used + may differ based on the capabilities of the destination. - MERGE = "merge" + If a destination does not support the preferred method, it will fall back to the next best + method. + """ + + MERGE = _MERGE """Merge new records with existing records. This requires a primary key to be set on the stream. @@ -20,13 +35,13 @@ class WriteStrategy(str, Enum): please use the `auto` strategy instead. """ - APPEND = "append" + APPEND = _APPEND """Append new records to existing records.""" - REPLACE = "replace" + REPLACE = _REPLACE """Replace existing records with new records.""" - AUTO = "auto" + AUTO = _AUTO """Automatically determine the best strategy to use. This will use the following logic: @@ -34,3 +49,44 @@ class WriteStrategy(str, Enum): - Else, if there's an incremental key, use append. - Else, use full replace (table swap). """ + + +class WriteMethod(str, Enum): + """Write methods for PyAirbyte. + + Unlike write strategies, write methods are expected to be fully resolved and do not require any + additional logic to determine the best method to use. 
+ + If a destination does not support the declared method, it will raise an exception. + """ + + MERGE = _MERGE + """Merge new records with existing records. + + This requires a primary key to be set on the stream. + If no primary key is set, this will raise an exception. + + To apply this strategy in cases where some destination streams don't have a primary key, + please use the `auto` strategy instead. + """ + + APPEND = _APPEND + """Append new records to existing records.""" + + REPLACE = _REPLACE + """Replace existing records with new records.""" + + @property + def destination_sync_mode(self) -> DestinationSyncMode: + """Convert the write method to a destination sync mode.""" + if self == WriteMethod.MERGE: + return DestinationSyncMode.append_dedup + + if self == WriteMethod.APPEND: + return DestinationSyncMode.append + + if self == WriteMethod.REPLACE: + return DestinationSyncMode.overwrite + + msg = f"Unknown write method: {self}" # type: ignore [unreachable] + raise ValueError(msg) diff --git a/tests/integration_tests/destinations/test_source_to_destination.py b/tests/integration_tests/destinations/test_source_to_destination.py index 5a8f5ab1..7cbc56e6 100644 --- a/tests/integration_tests/destinations/test_source_to_destination.py +++ b/tests/integration_tests/destinations/test_source_to_destination.py @@ -89,6 +89,7 @@ def test_duckdb_destination_write_components( catalog_provider=CatalogProvider( configured_catalog=new_source_faker.configured_catalog ), + write_strategy=WriteStrategy.AUTO, progress_tracker=ProgressTracker( source=None, cache=None, diff --git a/tests/integration_tests/test_docker_executable.py b/tests/integration_tests/test_docker_executable.py index 53869791..1869e25c 100644 --- a/tests/integration_tests/test_docker_executable.py +++ b/tests/integration_tests/test_docker_executable.py @@ -84,8 +84,12 @@ def test_faker_pks( read_result = source_docker_faker_seed_a.read( new_duckdb_cache, write_strategy="append" ) - assert read_result.cache.processor._get_primary_keys("products") == ["id"] - assert read_result.cache.processor._get_primary_keys("purchases") == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "products" + ) == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "purchases" + ) == ["id"] @pytest.mark.slow diff --git a/tests/integration_tests/test_source_faker_integration.py b/tests/integration_tests/test_source_faker_integration.py index 3704b40c..2117cf3e 100644 --- a/tests/integration_tests/test_source_faker_integration.py +++ b/tests/integration_tests/test_source_faker_integration.py @@ -130,8 +130,12 @@ def test_faker_pks( assert catalog.streams[1].primary_key read_result = source_faker_seed_a.read(duckdb_cache, write_strategy="append") - assert read_result.cache.processor._get_primary_keys("products") == ["id"] - assert read_result.cache.processor._get_primary_keys("purchases") == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "products" + ) == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "purchases" + ) == ["id"] @pytest.mark.slow
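Editorial note, not part of the patch: the sketch below shows how the strategy resolution added in `airbyte/strategies.py` and `airbyte/_future_cdk/catalog_providers.py` behaves end to end. The `purchases` stream definition is invented for illustration; only APIs introduced in this diff plus standard `airbyte_protocol` models are used.

```python
# Hypothetical, self-contained sketch (not part of this diff) of
# WriteStrategy -> WriteMethod -> DestinationSyncMode resolution.
from airbyte_protocol.models import (
    AirbyteStream,
    ConfiguredAirbyteCatalog,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    SyncMode,
)

from airbyte._future_cdk.catalog_providers import CatalogProvider
from airbyte.strategies import WriteMethod, WriteStrategy

# An invented stream with a primary key and a cursor field.
catalog = ConfiguredAirbyteCatalog(
    streams=[
        ConfiguredAirbyteStream(
            stream=AirbyteStream(
                name="purchases",
                json_schema={"type": "object"},
                supported_sync_modes=[SyncMode.incremental],
            ),
            sync_mode=SyncMode.incremental,
            destination_sync_mode=DestinationSyncMode.append,
            primary_key=[["id"]],  # A primary key makes AUTO resolve to MERGE.
            cursor_field=["updated_at"],
        )
    ]
)
provider = CatalogProvider(configured_catalog=catalog)

# AUTO picks MERGE when primary keys exist, APPEND when only a cursor exists,
# and REPLACE otherwise.
assert provider.resolve_write_method("purchases", WriteStrategy.AUTO) is WriteMethod.MERGE

# Each resolved method maps to a protocol-level destination sync mode
# (MERGE -> append_dedup, APPEND -> append, REPLACE -> overwrite), which
# with_write_strategy() stamps onto a copy of the catalog before the
# destination is invoked by Destination.write().
resolved = provider.with_write_strategy(WriteStrategy.AUTO)
assert (
    resolved.get_configured_stream_info("purchases").destination_sync_mode
    is DestinationSyncMode.append_dedup
)
```

This is the same resolution path exercised by the updated `test_faker_pks` assertions above, which now read primary keys through `catalog_provider.get_primary_keys()` instead of the removed `_get_primary_keys()` helper.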