diff --git a/airbyte/_batch_handles.py b/airbyte/_batch_handles.py index 220e9d52..e405dfca 100644 --- a/airbyte/_batch_handles.py +++ b/airbyte/_batch_handles.py @@ -3,22 +3,45 @@ from __future__ import annotations -from contextlib import suppress -from pathlib import Path # noqa: TCH003 # Pydantic needs this import -from typing import IO, Any, Optional +from typing import IO, TYPE_CHECKING, Callable -from pydantic import BaseModel, Field, PrivateAttr +if TYPE_CHECKING: + from pathlib import Path -class BatchHandle(BaseModel): + +class BatchHandle: """A handle for a batch of records.""" - stream_name: str - batch_id: str + def __init__( + self, + stream_name: str, + batch_id: str, + files: list[Path], + file_opener: Callable[[Path], IO[bytes]], + ) -> None: + """Initialize the batch handle.""" + self._stream_name = stream_name + self._batch_id = batch_id + self._files = files + self._record_count = 0 + assert self._files, "A batch must have at least one file." + self._open_file_writer: IO[bytes] = file_opener(self._files[0]) + + @property + def files(self) -> list[Path]: + """Return the files.""" + return self._files + + @property + def batch_id(self) -> str: + """Return the batch ID.""" + return self._batch_id - files: list[Path] = Field(default_factory=list) - _open_file_writer: Optional[Any] = PrivateAttr(default=None) - _record_count: int = PrivateAttr(default=0) + @property + def stream_name(self) -> str: + """Return the stream name.""" + return self._stream_name @property def record_count(self) -> int: @@ -36,11 +59,11 @@ def open_file_writer(self) -> IO[bytes] | None: def close_files(self) -> None: """Close the file writer.""" - if self._open_file_writer is None: + if self.open_file_writer is None: return - with suppress(Exception): - self._open_file_writer.close() + # with suppress(Exception): + self.open_file_writer.close() def __del__(self) -> None: """Upon deletion, close the file writer.""" diff --git a/airbyte/_processors/base.py b/airbyte/_processors/base.py index dcde9c02..49acb1e0 100644 --- a/airbyte/_processors/base.py +++ b/airbyte/_processors/base.py @@ -10,14 +10,12 @@ from __future__ import annotations import abc -import contextlib import io import sys from collections import defaultdict from typing import TYPE_CHECKING, Any, cast, final import pyarrow as pa -import ulid from airbyte_protocol.models import ( AirbyteMessage, @@ -31,20 +29,17 @@ ) from airbyte import exceptions as exc -from airbyte._batch_handles import BatchHandle from airbyte.caches.base import CacheBase -from airbyte.progress import progress from airbyte.strategies import WriteStrategy from airbyte.types import _get_pyarrow_type if TYPE_CHECKING: - from collections.abc import Generator, Iterable, Iterator + from collections.abc import Iterable, Iterator from airbyte.caches._catalog_manager import CatalogManager -DEFAULT_BATCH_SIZE = 10_000 DEBUG_MODE = False # Set to True to enable additional debug logging. 
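For context on the `BatchHandle` rewrite above: the class is now a plain Python object (no longer a Pydantic model) that opens its first file eagerly via an injected `file_opener` callable. A minimal usage sketch, assuming a hypothetical cache file path and a gzip-based opener like the one `JsonlWriter` supplies; the names `demo_path` and `open_gzip` are illustrative only and not part of this change:

    import gzip
    from pathlib import Path
    from typing import IO, cast

    from airbyte._batch_handles import BatchHandle

    def open_gzip(file_path: Path) -> IO[bytes]:
        # Stand-in for JsonlWriter._open_new_file: open the batch file for binary writes.
        return cast(IO[bytes], gzip.open(file_path, "w"))

    demo_path = Path("/tmp/users_01HTEXAMPLE.jsonl.gz")  # hypothetical; real paths come from _get_new_cache_file_path()
    handle = BatchHandle(
        stream_name="users",
        batch_id="01HTEXAMPLE",  # illustrative ULID-style ID
        files=[demo_path],
        file_opener=open_gzip,
    )
    handle.open_file_writer.write(b'{"id": 1}\n')  # file writers normally go through _write_record_dict()
    handle.increment_record_count()
    handle.close_files()  # flushing the batch closes the open file handle

The constructor asserts that at least one file is provided and opens it immediately, so callers no longer manage the file handle separately.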
@@ -55,10 +50,6 @@ class AirbyteMessageParsingError(Exception): class RecordProcessor(abc.ABC): """Abstract base class for classes which can process input records.""" - MAX_BATCH_SIZE: int = DEFAULT_BATCH_SIZE - - skip_finalize_step: bool = False - def __init__( self, cache: CacheBase, @@ -78,10 +69,6 @@ def __init__( self.source_catalog: ConfiguredAirbyteCatalog | None = None self._source_name: str | None = None - self._active_batches: dict[str, BatchHandle] = {} - self._pending_batches: dict[str, list[BatchHandle]] = defaultdict(list, {}) - self._finalized_batches: dict[str, list[BatchHandle]] = defaultdict(list, {}) - self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {}) self._finalized_state_messages: dict[ str, @@ -150,19 +137,18 @@ def process_input_stream( write_strategy=write_strategy, ) - def _process_record_message( + @abc.abstractmethod + def process_record_message( self, record_msg: AirbyteRecordMessage, ) -> None: """Write a record to the cache. - This method is called for each record message, before the batch is written. + This method is called for each record message. - By default this is a no-op but file writers can override this method to write the record to - files. + In most cases, the SQL processor will not perform any action, but will pass this along to + the file processor. """ - _ = record_msg # Unused - pass @final def process_airbyte_messages( @@ -181,7 +167,7 @@ def process_airbyte_messages( for message in messages: if message.type is Type.RECORD: record_msg = cast(AirbyteRecordMessage, message.record) - self._process_record_message(record_msg) + self.process_record_message(record_msg) elif message.type is Type.STATE: state_msg = cast(AirbyteStateMessage, message.state) @@ -197,129 +183,34 @@ def process_airbyte_messages( # Type.LOG, Type.TRACE, Type.CONTROL, etc. pass - # We are at the end of the stream. Process whatever else is queued. - self._flush_active_batches() - - all_streams = list(set(self._pending_batches.keys()) | set(self._finalized_batches.keys())) - # Add empty streams to the streams list, so we create a destination table for it - for stream_name in self.expected_streams: - if stream_name not in all_streams: - if DEBUG_MODE: - print(f"Stream {stream_name} has no data") - all_streams.append(stream_name) - - # Finalize any pending batches - for stream_name in all_streams: - self._finalize_batches(stream_name, write_strategy=write_strategy) - progress.log_stream_finalized(stream_name) - - def _flush_active_batches( - self, - ) -> None: - """Flush active batches for all streams.""" - for stream_name in self._active_batches: - self._flush_active_batch(stream_name) - - def _flush_active_batch( - self, - stream_name: str, - ) -> None: - """Flush the active batch for the given stream. - - This entails moving the active batch to the pending batches, closing any open files, and - logging the batch as written. - """ - raise NotImplementedError( - "Subclasses must implement the _flush_active_batch() method.", + self.flush_all( + write_strategy=write_strategy, ) - def _cleanup_batch( # noqa: B027 # Intentionally empty, not abstract - self, - batch_handle: BatchHandle, - ) -> None: - """Clean up the cache. - - This method is called after the given batch has been finalized. - - For instance, file writers can override this method to delete the files created. Caches, - similarly, can override this method to delete any other temporary artifacts.
- """ - pass - - def _new_batch_id(self) -> str: - """Return a new batch handle.""" - return str(ulid.ULID()) - - def _new_batch( - self, - stream_name: str, - ) -> BatchHandle: - """Create and return a new batch handle. - - By default this is a concatenation of the stream name and batch ID. - However, any Python object can be returned, such as a Path object. - """ - batch_id = self._new_batch_id() - return BatchHandle(stream_name=stream_name, batch_id=batch_id) - - def _finalize_batches( - self, - stream_name: str, - write_strategy: WriteStrategy, - ) -> list[BatchHandle]: - """Finalize all uncommitted batches. - - Returns a mapping of batch IDs to batch handles, for processed batches. - - This is a generic implementation, which can be overridden. - """ - _ = write_strategy # Unused - with self._finalizing_batches(stream_name) as batches_to_finalize: - if batches_to_finalize and not self.skip_finalize_step: - raise NotImplementedError( - "Caches need to be finalized but no _finalize_batch() method " - f"exists for class {self.__class__.__name__}", - ) - - return batches_to_finalize + # Clean up files, if requested. + if self.cache.cleanup: + self.cleanup_all() @abc.abstractmethod + def flush_all(self, write_strategy: WriteStrategy) -> None: + """Finalize any pending writes.""" + def _finalize_state_messages( self, stream_name: str, state_messages: list[AirbyteStateMessage], ) -> None: - """Handle state messages. - Might be a no-op if the processor doesn't handle incremental state.""" - pass - - @final - @contextlib.contextmanager - def _finalizing_batches( - self, - stream_name: str, - ) -> Generator[list[BatchHandle], str, None]: - """Context manager to use for finalizing batches, if applicable. - - Returns a mapping of batch IDs to batch handles, for those processed batches. - """ - batches_to_finalize: list[BatchHandle] = self._pending_batches[stream_name].copy() - state_messages_to_finalize: list[AirbyteStateMessage] = self._pending_state_messages[ - stream_name - ].copy() - self._pending_batches[stream_name].clear() - self._pending_state_messages[stream_name].clear() - - progress.log_batches_finalizing(stream_name, len(batches_to_finalize)) - yield batches_to_finalize - self._finalize_state_messages(stream_name, state_messages_to_finalize) - progress.log_batches_finalized(stream_name, len(batches_to_finalize)) - - self._finalized_batches[stream_name] += batches_to_finalize - self._finalized_state_messages[stream_name] += state_messages_to_finalize - - for batch_handle in batches_to_finalize: - self._cleanup_batch(batch_handle) + """Handle state messages by passing them to the catalog manager.""" + if not self._catalog_manager: + raise exc.AirbyteLibInternalError( + message="Catalog manager should exist but does not.", + ) + if state_messages and self._source_name: + self._catalog_manager.save_state( + source_name=self._source_name, + stream_name=stream_name, + state=state_messages[-1], + ) def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract """Create the database. @@ -329,27 +220,6 @@ def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract """ pass - def _teardown(self) -> None: - """Teardown the processor resources. - - By default, the base implementation simply calls _cleanup_batch() for all pending batches. 
- """ - batch_lists: list[list[BatchHandle]] = list(self._pending_batches.values()) + list( - self._finalized_batches.values() - ) - - # TODO: flatten lists and remove nested 'for' - for batch_list in batch_lists: - for batch_handle in batch_list: - self._cleanup_batch( - batch_handle=batch_handle, - ) - - @final - def __del__(self) -> None: - """Teardown temporary resources when instance is unloaded from memory.""" - self._teardown() - @final def _get_stream_config( self, @@ -384,3 +254,10 @@ def _get_stream_pyarrow_schema( ].items() ] ) + + def cleanup_all(self) -> None: # noqa: B027 # Intentionally empty, not abstract + """Clean up all resources. + + The default implementation is a no-op. + """ + pass diff --git a/airbyte/_processors/file/base.py b/airbyte/_processors/file/base.py index 50e12a14..3339b61b 100644 --- a/airbyte/_processors/file/base.py +++ b/airbyte/_processors/file/base.py @@ -4,37 +4,48 @@ from __future__ import annotations import abc +from collections import defaultdict from pathlib import Path from typing import IO, TYPE_CHECKING, final import ulid -from overrides import overrides from airbyte import exceptions as exc from airbyte._batch_handles import BatchHandle -from airbyte._processors.base import RecordProcessor from airbyte._util.protocol_util import airbyte_record_message_to_dict from airbyte.progress import progress if TYPE_CHECKING: - from io import BufferedWriter from airbyte_protocol.models import ( AirbyteRecordMessage, - AirbyteStateMessage, ) + from airbyte.caches.base import CacheBase + from airbyte.strategies import WriteStrategy + DEFAULT_BATCH_SIZE = 10000 -class FileWriterBase(RecordProcessor, abc.ABC): +class FileWriterBase(abc.ABC): """A generic base implementation for a file-based cache.""" default_cache_file_suffix: str = ".batch" - _active_batches: dict[str, BatchHandle] + MAX_BATCH_SIZE: int = DEFAULT_BATCH_SIZE + + def __init__( + self, + cache: CacheBase, + ) -> None: + """Initialize the file writer.""" + self.cache = cache + + self._active_batches: dict[str, BatchHandle] = {} + self._pending_batches: dict[str, list[BatchHandle]] = defaultdict(list, {}) + self._finalized_batches: dict[str, list[BatchHandle]] = defaultdict(list, {}) def _get_new_cache_file_path( self, @@ -45,16 +56,14 @@ def _get_new_cache_file_path( batch_id = batch_id or str(ulid.ULID()) target_dir = Path(self.cache.cache_dir) target_dir.mkdir(parents=True, exist_ok=True) - return target_dir / f"{stream_name}_{batch_id}.{self.default_cache_file_suffix}" + return target_dir / f"{stream_name}_{batch_id}{self.default_cache_file_suffix}" def _open_new_file( self, - stream_name: str, - ) -> tuple[Path, IO[bytes]]: + file_path: Path, + ) -> IO[bytes]: """Open a new file for writing.""" - file_path: Path = self._get_new_cache_file_path(stream_name) - file_handle: BufferedWriter = file_path.open("wb") - return file_path, file_handle + return file_path.open("wb") def _flush_active_batch( self, @@ -69,12 +78,10 @@ def _flush_active_batch( return batch_handle: BatchHandle = self._active_batches[stream_name] + batch_handle.close_files() del self._active_batches[stream_name] - if self.skip_finalize_step: - self._finalized_batches[stream_name].append(batch_handle) - else: - self._pending_batches[stream_name].append(batch_handle) + self._pending_batches[stream_name].append(batch_handle) progress.log_batch_written( stream_name=stream_name, batch_size=batch_handle.record_count, @@ -88,35 +95,25 @@ def _new_batch( The base implementation creates and opens a new file for writing so it is 
ready to receive records. + + This also flushes the active batch if one already exists for the given stream. """ + if stream_name in self._active_batches: + self._flush_active_batch(stream_name) + batch_id = self._new_batch_id() - new_file_path, new_file_handle = self._open_new_file(stream_name=stream_name) + new_file_path = self._get_new_cache_file_path(stream_name) + batch_handle = BatchHandle( stream_name=stream_name, batch_id=batch_id, files=[new_file_path], + file_opener=self._open_new_file, ) self._active_batches[stream_name] = batch_handle return batch_handle - @overrides - def _cleanup_batch( - self, - batch_handle: BatchHandle, - ) -> None: - """Clean up the cache. - - For file writers, this means deleting the files created and declared in the batch. - - This method is a no-op if the `cleanup` config option is set to False. - """ - self._close_batch_files(batch_handle) - - if self.cache.cleanup: - for file_path in batch_handle.files: - file_path.unlink() - - def _close_batch_files( + def _close_batch( self, batch_handle: BatchHandle, ) -> None: @@ -127,10 +124,7 @@ def _close_batch_files( batch_handle.close_files() @final - def cleanup_batch( - self, - batch_handle: BatchHandle, - ) -> None: + def cleanup_all(self) -> None: """Clean up the cache. For file writers, this means deleting the files created and declared in the batch. @@ -139,21 +133,15 @@ def cleanup_batch( Subclasses should override `_cleanup_batch` instead. """ - self._cleanup_batch(batch_handle) + for batch_list in self._pending_batches.values(): + for batch_handle in batch_list: + self._cleanup_batch(batch_handle) - @overrides - def _finalize_state_messages( - self, - stream_name: str, - state_messages: list[AirbyteStateMessage], - ) -> None: - """ - State messages are not used in file writers, so this method is a no-op. - """ - pass + for batch_list in self._finalized_batches.values(): + for batch_handle in batch_list: + self._cleanup_batch(batch_handle) - @overrides - def _process_record_message( + def process_record_message( self, record_msg: AirbyteRecordMessage, ) -> None: @@ -167,22 +155,17 @@ def _process_record_message( stream_name = record_msg.stream batch_handle: BatchHandle - if not self._pending_batches[stream_name]: + if stream_name not in self._active_batches: batch_handle = self._new_batch(stream_name=stream_name) else: - batch_handle = self._pending_batches[stream_name][-1] + batch_handle = self._active_batches[stream_name] if batch_handle.record_count + 1 > self.MAX_BATCH_SIZE: - # Already at max batch size, so write the batch and start a new one - self._close_batch_files(batch_handle) - progress.log_batch_written( - stream_name=batch_handle.stream_name, - batch_size=batch_handle.record_count, - ) + # Already at max batch size, so start a new batch. batch_handle = self._new_batch(stream_name=stream_name) - if not batch_handle.open_file_writer: + if batch_handle.open_file_writer is None: raise exc.AirbyteLibInternalError(message="Expected open file writer.") self._write_record_dict( @@ -191,6 +174,51 @@ def _process_record_message( ) batch_handle.increment_record_count() + def _flush_active_batches( + self, + ) -> None: + """Flush active batches for all streams.""" + streams = list(self._active_batches.keys()) + for stream_name in streams: + self._flush_active_batch(stream_name) + + def _cleanup_batch( + self, + batch_handle: BatchHandle, + ) -> None: + """Clean up the cache. + + For file writers, this means deleting the files created and declared in the batch. 
+ + This method is a no-op if the `cleanup` config option is set to False. + """ + self._close_batch(batch_handle) + + if self.cache.cleanup: + for file_path in batch_handle.files: + if file_path.exists(): + file_path.unlink() + + def _new_batch_id(self) -> str: + """Return a new batch ID.""" + return str(ulid.ULID()) + + def flush_all(self, write_strategy: WriteStrategy) -> None: + """Finalize any pending writes.""" + # We are at the end of the stream. Process whatever else is queued. + self._flush_active_batches() + + # Destructor + + @final + def __del__(self) -> None: + """Teardown temporary resources when instance is unloaded from memory.""" + if self.cache.cleanup: + self.cleanup_all() + + # Abstract methods + + @abc.abstractmethod def _write_record_dict( self, record_dict: dict, diff --git a/airbyte/_processors/file/jsonl.py b/airbyte/_processors/file/jsonl.py index 8a238abb..0578fded 100644 --- a/airbyte/_processors/file/jsonl.py +++ b/airbyte/_processors/file/jsonl.py @@ -26,11 +26,10 @@ class JsonlWriter(FileWriterBase): def _open_new_file( self, - stream_name: str, - ) -> tuple[Path, IO[bytes]]: + file_path: Path, + ) -> IO[bytes]: """Open a new file for writing.""" - file_path = self._get_new_cache_file_path(stream_name) - return file_path, cast(IO[bytes], gzip.open(file_path, "w")) + return cast(IO[bytes], gzip.open(file_path, "w")) def _write_record_dict( self, diff --git a/airbyte/_processors/file/parquet.py b/airbyte/_processors/file/parquet.py index 8a7abf66..34498fe2 100644 --- a/airbyte/_processors/file/parquet.py +++ b/airbyte/_processors/file/parquet.py @@ -33,9 +33,9 @@ def _get_missing_columns( The comparison is based on a case-insensitive comparison of the column names. """ - if not self._catalog_manager: + if not self.cache.processor._catalog_manager: raise exc.AirbyteLibInternalError(message="Catalog manager should exist but does not.") - stream = self._catalog_manager.get_stream_config(stream_name) + stream = self.cache.processor._catalog_manager.get_stream_config(stream_name) stream_property_names = stream.stream.json_schema["properties"].keys() return [ col diff --git a/airbyte/_processors/sql/base.py b/airbyte/_processors/sql/base.py index 7bc26a07..4a8ea9b7 100644 --- a/airbyte/_processors/sql/base.py +++ b/airbyte/_processors/sql/base.py @@ -3,6 +3,7 @@ from __future__ import annotations +import contextlib import enum from contextlib import contextmanager from functools import cached_property @@ -31,6 +32,7 @@ from airbyte._util.text_util import lower_case_set from airbyte.caches._catalog_manager import CatalogManager from airbyte.datasets._sql import CachedDataset +from airbyte.progress import progress from airbyte.strategies import WriteStrategy from airbyte.types import SQLTypeConverter @@ -45,6 +47,7 @@ from sqlalchemy.sql.base import Executable from airbyte_protocol.models import ( + AirbyteRecordMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ) @@ -90,9 +93,7 @@ def __init__( engine=self.get_sql_engine(), table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name), ) - self.file_writer = file_writer or self.file_writer_class( - cache, catalog_manager=self._catalog_manager - ) + self.file_writer = file_writer or self.file_writer_class(cache) self.type_converter = self.type_converter_class() self._cached_table_definitions: dict[str, sqlalchemy.Table] = {} @@ -160,11 +161,11 @@ def register_source( This method is called by the source when it is initialized.
""" self._source_name = source_name - self.file_writer.register_source( - source_name, - incoming_source_catalog, - stream_names=stream_names, - ) + # self.register_source( + # source_name, + # incoming_source_catalog, + # stream_names=stream_names, + # ) self._ensure_schema_exists() super().register_source( source_name, @@ -233,6 +234,19 @@ def get_pandas_dataframe( engine = self.get_sql_engine() return pd.read_sql_table(table_name, engine) + def process_record_message( + self, + record_msg: AirbyteRecordMessage, + ) -> None: + """Write a record to the cache. + + This method is called for each record message, before the batch is written. + + In most cases, the SQL processor will not perform any action, but will pass this along to to + the file processor. + """ + self.file_writer.process_record_message(record_msg) + # Protected members (non-public interface): def _init_connection_settings(self, connection: Connection) -> None: @@ -474,28 +488,20 @@ def _get_sql_column_definitions( # columns["_airbyte_loaded_at"] = sqlalchemy.TIMESTAMP() return columns - def _cleanup_batch( - self, - batch_handle: BatchHandle, - ) -> None: - """Clean up the cache. - - For SQL caches, we only need to call the cleanup operation on the file writer. - - Subclasses should call super() if they override this method. - """ - self.file_writer.cleanup_batch(batch_handle) + def flush_all(self, write_strategy: WriteStrategy) -> None: + """Finalize any pending writes.""" + for stream_name in self.expected_streams: + self.flush_stream(stream_name, write_strategy=write_strategy) @final - @overrides - def _finalize_batches( + def flush_stream( self, stream_name: str, write_strategy: WriteStrategy, ) -> list[BatchHandle]: """Finalize all uncommitted batches. - This is a generic 'final' implementation, which should not be overridden. + This is a generic 'final' SQL implementation, which should not be overridden. Returns a mapping of batch IDs to batch handles, for those processed batches. @@ -503,7 +509,10 @@ def _finalize_batches( Some sources will send us duplicate records within the same stream, although this is a fairly rare edge case we can ignore in V1. """ - with self._finalizing_batches(stream_name) as batches_to_finalize: + # Flush any pending writes + self.file_writer.flush_all(write_strategy=write_strategy) + + with self.finalizing_batches(stream_name) as batches_to_finalize: # Make sure the target schema and target table exist. self._ensure_schema_exists() final_table_name = self._ensure_final_table_exists( @@ -541,26 +550,45 @@ def _finalize_batches( finally: self._drop_temp_table(temp_table_name, if_exists=True) - # Return the batch handles as measure of work completed. - return batches_to_finalize + # Return the batch handles as measure of work completed. 
+ return batches_to_finalize - @overrides - def _finalize_state_messages( + @final + def cleanup_all(self) -> None: + """Clean resources.""" + self.file_writer.cleanup_all() + + # Finalizing context manager + + @final + @contextlib.contextmanager + def finalizing_batches( self, stream_name: str, - state_messages: list[AirbyteStateMessage], - ) -> None: - """Handle state messages by passing them to the catalog manager.""" - if not self._catalog_manager: - raise exc.AirbyteLibInternalError( - message="Catalog manager should exist but does not.", - ) - if state_messages and self._source_name: - self._catalog_manager.save_state( - source_name=self._source_name, - stream_name=stream_name, - state=state_messages[-1], - ) + ) -> Generator[list[BatchHandle], str, None]: + """Context manager to use for finalizing batches, if applicable. + + Returns a mapping of batch IDs to batch handles, for those processed batches. + """ + batches_to_finalize: list[BatchHandle] = self.file_writer._pending_batches[ + stream_name + ].copy() + state_messages_to_finalize: list[AirbyteStateMessage] = self._pending_state_messages[ + stream_name + ].copy() + self.file_writer._pending_batches[stream_name].clear() + self._pending_state_messages[stream_name].clear() + + progress.log_batches_finalizing(stream_name, len(batches_to_finalize)) + yield batches_to_finalize + self._finalize_state_messages(stream_name, state_messages_to_finalize) + progress.log_batches_finalized(stream_name, len(batches_to_finalize)) + + self.file_writer._finalized_batches[stream_name] += batches_to_finalize + self._finalized_state_messages[stream_name] += state_messages_to_finalize + + for batch_handle in batches_to_finalize: + self.file_writer._cleanup_batch(batch_handle) def _execute_sql(self, sql: str | TextClause | Executable) -> CursorResult: """Execute the given SQL statement."""
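Taken together, the responsibilities after this refactor split as follows: the SQL processor owns message handling, state, and finalization into tables, while the file writer owns batching and file lifecycle. Below is a rough sketch of the intended call order, assuming an already-configured `cache` (with its SQL processor) and an iterable of Airbyte messages named `messages`; both are placeholders, and the commented steps mirror the methods introduced above:

    from airbyte.strategies import WriteStrategy

    processor = cache.processor  # a SqlProcessorBase subclass; the same accessor parquet.py uses above

    # process_airbyte_messages() now drives the whole pipeline:
    #   RECORD messages -> processor.process_record_message() -> file_writer.process_record_message()
    #   end of stream   -> processor.flush_all() -> file_writer.flush_all(), then flush_stream() per
    #                      stream, with finalizing_batches() loading pending batches into final tables
    #   afterwards      -> processor.cleanup_all() -> file_writer.cleanup_all() (only if cache.cleanup)
    processor.process_airbyte_messages(messages, write_strategy=WriteStrategy.AUTO)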