From d9ab8f00cd4dec02b7418dc441d77ac07fe0a0f2 Mon Sep 17 00:00:00 2001
From: Aaron Steers
Date: Tue, 5 Mar 2024 15:34:10 -0800
Subject: [PATCH] refactor BatchHandle to a single class, fix mypy lint errors

---
 airbyte/_batch_handles.py            | 51 +++++++++++++++++++++++
 airbyte/_processors/base.py          |  9 +----
 airbyte/_processors/file/__init__.py |  9 +++--
 airbyte/_processors/file/base.py     | 57 ++++++++--------------
 airbyte/_processors/file/parquet.py  | 21 ++++------
 airbyte/_processors/sql/base.py      | 10 ++---
 6 files changed, 86 insertions(+), 71 deletions(-)
 create mode 100644 airbyte/_batch_handles.py

diff --git a/airbyte/_batch_handles.py b/airbyte/_batch_handles.py
new file mode 100644
index 00000000..220e9d52
--- /dev/null
+++ b/airbyte/_batch_handles.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+"""Batch handle class."""
+
+from __future__ import annotations
+
+from contextlib import suppress
+from pathlib import Path  # noqa: TCH003  # Pydantic needs this import
+from typing import IO, Any, Optional
+
+from pydantic import BaseModel, Field, PrivateAttr
+
+
+class BatchHandle(BaseModel):
+    """A handle for a batch of records."""
+
+    stream_name: str
+    batch_id: str
+
+    files: list[Path] = Field(default_factory=list)
+    _open_file_writer: Optional[Any] = PrivateAttr(default=None)
+    _record_count: int = PrivateAttr(default=0)
+
+    @property
+    def record_count(self) -> int:
+        """Return the record count."""
+        return self._record_count
+
+    def increment_record_count(self) -> None:
+        """Increment the record count."""
+        self._record_count += 1
+
+    @property
+    def open_file_writer(self) -> IO[bytes] | None:
+        """Return the open file writer, if any, or None."""
+        return self._open_file_writer
+
+    def set_open_file_writer(self, writer: IO[bytes]) -> None:
+        """Attach an open file writer to this batch."""
+        self._open_file_writer = writer
+
+    def close_files(self) -> None:
+        """Close the open file writer, if any, suppressing any errors."""
+        if self._open_file_writer is None:
+            return
+
+        with suppress(Exception):
+            self._open_file_writer.close()
+
+    def __del__(self) -> None:
+        """Upon deletion, close the file writer."""
+        self.close_files()
diff --git a/airbyte/_processors/base.py b/airbyte/_processors/base.py
index 3031ea36..9111a7f9 100644
--- a/airbyte/_processors/base.py
+++ b/airbyte/_processors/base.py
@@ -18,7 +18,6 @@
 
 import pyarrow as pa
 import ulid
-from pydantic import BaseModel
 
 from airbyte_protocol.models import (
     AirbyteMessage,
@@ -32,6 +31,7 @@
 )
 
 from airbyte import exceptions as exc
+from airbyte._batch_handles import BatchHandle
 from airbyte.caches.base import CacheBase
 from airbyte.progress import progress
 from airbyte.strategies import WriteStrategy
@@ -48,13 +48,6 @@
 DEBUG_MODE = False  # Set to True to enable additional debug logging.
 
 
-class BatchHandle(BaseModel): - """A handle for a batch of records.""" - - stream_name: str - batch_id: str - - class AirbyteMessageParsingError(Exception): """Raised when an Airbyte message is invalid or cannot be parsed.""" diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_processors/file/__init__.py index 26c25484..0f83652b 100644 --- a/airbyte/_processors/file/__init__.py +++ b/airbyte/_processors/file/__init__.py @@ -3,13 +3,14 @@ from __future__ import annotations -from .base import FileWriterBase, FileWriterBatchHandle -from .jsonl import JsonlWriter -from .parquet import ParquetWriter +from airbyte._batch_handles import BatchHandle +from airbyte._processors.file.base import FileWriterBase +from airbyte._processors.file.jsonl import JsonlWriter +from airbyte._processors.file.parquet import ParquetWriter __all__ = [ - "FileWriterBatchHandle", + "BatchHandle", "FileWriterBase", "JsonlWriter", "ParquetWriter", diff --git a/airbyte/_processors/file/base.py b/airbyte/_processors/file/base.py index 009b7509..30a6ced3 100644 --- a/airbyte/_processors/file/base.py +++ b/airbyte/_processors/file/base.py @@ -4,17 +4,15 @@ from __future__ import annotations import abc -from contextlib import suppress -from dataclasses import field from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, cast, final +from typing import IO, TYPE_CHECKING, final import ulid from overrides import overrides -from pydantic import Field, PrivateAttr from airbyte import exceptions as exc -from airbyte._processors.base import BatchHandle, RecordProcessor +from airbyte._batch_handles import BatchHandle +from airbyte._processors.base import RecordProcessor from airbyte._util.protocol_util import airbyte_record_message_to_dict from airbyte.progress import progress @@ -31,32 +29,12 @@ DEFAULT_BATCH_SIZE = 10000 -# The batch handle for file writers is a list of Path objects. -class FileWriterBatchHandle(BatchHandle): - """The file writer batch handle is a list of Path objects.""" - - files: list[Path] = Field(default_factory=list) - _open_file_writer: Any | None = PrivateAttr(default=None) - _record_count: int = PrivateAttr(default=0) - - # TODO: Handle pydantic error: Fields of type "" are not supported. - # open_file_writer: IO[bytes] | None = PrivateAttr(default=None) - - @property - def record_count(self) -> int: - return self._record_count - - @property - def open_file_writer(self) -> IO[bytes] | None: - return self._open_file_writer - - class FileWriterBase(RecordProcessor, abc.ABC): """A generic base implementation for a file-based cache.""" default_cache_file_suffix: str = ".batch" - _active_batches: dict[str, FileWriterBatchHandle] + _active_batches: dict[str, BatchHandle] def _get_new_cache_file_path( self, @@ -97,7 +75,7 @@ def _flush_active_batch( if stream_name not in self._active_batches: return - batch_handle: FileWriterBatchHandle = self._active_batches[stream_name] + batch_handle: BatchHandle = self._active_batches[stream_name] del self._active_batches[stream_name] if self.skip_finalize_step: @@ -112,7 +90,7 @@ def _flush_active_batch( def _new_batch( self, stream_name: str, - ) -> FileWriterBatchHandle: + ) -> BatchHandle: """Create and return a new batch handle. 
         The base implementation creates and opens a new file for writing so it is ready to receive records.
@@ -120,11 +98,11 @@ def _new_batch(
         """
         batch_id = self._new_batch_id()
         new_file_path, new_file_handle = self._open_new_file(stream_name=stream_name)
-        batch_handle = FileWriterBatchHandle(
+        batch_handle = BatchHandle(
             stream_name=stream_name,
             batch_id=batch_id,
             files=[new_file_path],
-            open_file_writer=new_file_handle,
         )
+        batch_handle.set_open_file_writer(new_file_handle)
         self._active_batches[stream_name] = batch_handle
         return batch_handle
@@ -132,7 +110,7 @@ def _new_batch(
     @overrides
     def _cleanup_batch(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Clean up the cache.
 
@@ -148,21 +126,18 @@ def _cleanup_batch(
 
     def _close_batch_files(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Close the current batch."""
         if not batch_handle.open_file_writer:
             return
 
-        with suppress(Exception):
-            batch_handle.open_file_writer.close()
-
-        batch_handle.open_file_writer = None
+        batch_handle.close_files()
 
     @final
     def cleanup_batch(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Clean up the cache.
 
@@ -189,7 +164,7 @@ def _finalize_state_messages(
     def _process_record_message(
         self,
         record_msg: AirbyteRecordMessage,
-    ) -> tuple[str, FileWriterBatchHandle]:
+    ) -> tuple[str, BatchHandle]:
         """Write a record to the cache.
 
         This method is called for each record message, before the batch is written.
@@ -199,12 +174,12 @@ def _process_record_message(
         """
         stream_name = record_msg.stream
 
-        batch_handle: FileWriterBatchHandle
+        batch_handle: BatchHandle
         if not self._pending_batches[stream_name]:
            batch_handle = self._new_batch(stream_name=stream_name)
 
         else:
-            batch_handle = cast(FileWriterBatchHandle, self._pending_batches[stream_name][-1])
+            batch_handle = self._pending_batches[stream_name][-1]
 
         if batch_handle.record_count + 1 > self.MAX_BATCH_SIZE:
             # Already at max batch size, so write the batch and start a new one
@@ -222,7 +197,7 @@ def _process_record_message(
             record_dict=airbyte_record_message_to_dict(record_message=record_msg),
             open_file_writer=batch_handle.open_file_writer,
         )
-        batch_handle.record_count += 1
+        batch_handle.increment_record_count()
         return stream_name, batch_handle
 
     def _write_record_dict(
diff --git a/airbyte/_processors/file/parquet.py b/airbyte/_processors/file/parquet.py
index ec99dc90..8a7abf66 100644
--- a/airbyte/_processors/file/parquet.py
+++ b/airbyte/_processors/file/parquet.py
@@ -8,22 +8,17 @@
 """
 from __future__ import annotations
 
-from pathlib import Path
-from typing import cast
-
-import pyarrow as pa
-import ulid
-from overrides import overrides
-from pyarrow import parquet
+from typing import TYPE_CHECKING
 
 from airbyte import exceptions as exc
-from airbyte._processors.file.base import (
-    FileWriterBase,
-    FileWriterBatchHandle,
-)
+from airbyte._processors.file.base import FileWriterBase
 from airbyte._util.text_util import lower_case_set
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
 class ParquetWriter(FileWriterBase):
     """A Parquet cache implementation."""
 
@@ -55,7 +50,7 @@ def _get_missing_columns(
 #     stream_name: str,
 #     batch_id: str,
 #     record_batch: pa.Table,  # TODO: Refactor to remove dependency on pyarrow
-#     ) -> FileWriterBatchHandle:
+#     ) -> BatchHandle:
 #     """Process a record batch.
 
 #     Return the path to the cache file.
@@ -85,6 +80,6 @@ def _get_missing_columns( # }, # ) from e - # batch_handle = FileWriterBatchHandle() + # batch_handle = BatchHandle() # batch_handle.files.append(output_file_path) # return batch_handle diff --git a/airbyte/_processors/sql/base.py b/airbyte/_processors/sql/base.py index 345f1109..6cf3bc2d 100644 --- a/airbyte/_processors/sql/base.py +++ b/airbyte/_processors/sql/base.py @@ -6,7 +6,7 @@ import enum from contextlib import contextmanager from functools import cached_property -from typing import TYPE_CHECKING, cast, final +from typing import TYPE_CHECKING, final import pandas as pd import sqlalchemy @@ -27,8 +27,7 @@ from sqlalchemy.sql.elements import TextClause from airbyte import exceptions as exc -from airbyte._processors.base import BatchHandle, RecordProcessor -from airbyte._processors.file.base import FileWriterBase, FileWriterBatchHandle +from airbyte._processors.base import RecordProcessor from airbyte._util.text_util import lower_case_set from airbyte.caches._catalog_manager import CatalogManager from airbyte.datasets._sql import CachedDataset @@ -50,6 +49,8 @@ ConfiguredAirbyteCatalog, ) + from airbyte._batch_handles import BatchHandle + from airbyte._processors.file.base import FileWriterBase from airbyte.caches.base import CacheBase @@ -480,7 +481,7 @@ def _get_sql_column_definitions( # stream_name: str, # batch_id: str, # record_batch: pa.Table, # TODO: Refactor to remove dependency on pyarrow - # ) -> FileWriterBatchHandle: + # ) -> BatchHandle: # """Process a record batch. # Return the path to the cache file. @@ -535,7 +536,6 @@ def _finalize_batches( files: list[Path] = [] # Get a list of all files to finalize from all pending batches. for batch_handle in batches_to_finalize: - batch_handle = cast(FileWriterBatchHandle, batch_handle) files += batch_handle.files # Use the max batch ID as the batch ID for table names. max_batch_id = max([batch.batch_id for batch in batches_to_finalize])
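
For review context, below is a minimal sketch of how the consolidated BatchHandle is exercised once this patch is applied. The stream name, batch ID, and file path are illustrative placeholders rather than values from the PyAirbyte codebase, and the snippet assumes only the surface defined in airbyte/_batch_handles.py above.

# Illustrative only: "users", "batch-0001", and the file path are made up.
from pathlib import Path

from airbyte._batch_handles import BatchHandle

batch = BatchHandle(
    stream_name="users",
    batch_id="batch-0001",
    files=[Path("users-batch-0001.jsonl")],
)

# A file writer would normally attach the handle returned by _open_new_file().
writer = batch.files[0].open("wb")
batch.set_open_file_writer(writer)

# _process_record_message() bumps the counter once per record written.
for _ in range(3):
    batch.increment_record_count()
assert batch.record_count == 3

# Close the attached writer; exceptions are suppressed, and __del__()
# calls this again as a safety net.
batch.close_files()

Because record_count is a read-only property backed by a private attribute, callers can only move it forward through increment_record_count(), which keeps the batch bookkeeping in one place now that the FileWriterBatchHandle subclass is gone.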