Commit d9ab8f0: refactor BatchHandle to a single class, fix mypy lint errors

aaronsteers committed Mar 5, 2024 · 1 parent 38d5754
Showing 6 changed files with 81 additions and 71 deletions.
47 changes: 47 additions & 0 deletions airbyte/_batch_handles.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+"""Batch handle class."""
+
+from __future__ import annotations
+
+from contextlib import suppress
+from pathlib import Path  # noqa: TCH003  # Pydantic needs this import
+from typing import IO, Any, Optional
+
+from pydantic import BaseModel, Field, PrivateAttr
+
+
+class BatchHandle(BaseModel):
+    """A handle for a batch of records."""
+
+    stream_name: str
+    batch_id: str
+
+    files: list[Path] = Field(default_factory=list)
+    _open_file_writer: Optional[Any] = PrivateAttr(default=None)
+    _record_count: int = PrivateAttr(default=0)
+
+    @property
+    def record_count(self) -> int:
+        """Return the record count."""
+        return self._record_count
+
+    def increment_record_count(self) -> None:
+        """Increment the record count."""
+        self._record_count += 1
+
+    @property
+    def open_file_writer(self) -> IO[bytes] | None:
+        """Return the open file writer, if any, or None."""
+        return self._open_file_writer
+
+    def close_files(self) -> None:
+        """Close the file writer."""
+        if self._open_file_writer is None:
+            return
+
+        with suppress(Exception):
+            self._open_file_writer.close()
+
+    def __del__(self) -> None:
+        """Upon deletion, close the file writer."""
+        self.close_files()
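
Note: a minimal usage sketch of the consolidated class (hypothetical stream and batch names, not part of the commit):

    from pathlib import Path

    from airbyte._batch_handles import BatchHandle

    # Create a handle for one batch of a (hypothetical) "users" stream.
    handle = BatchHandle(stream_name="users", batch_id="batch-0001")
    handle.files.append(Path("users_batch-0001.jsonl"))

    # The record count is private state, mutated only through the helper.
    for _ in range(3):
        handle.increment_record_count()
    assert handle.record_count == 3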
9 changes: 1 addition & 8 deletions airbyte/_processors/base.py
@@ -18,7 +18,6 @@
 
 import pyarrow as pa
 import ulid
-from pydantic import BaseModel
 
 from airbyte_protocol.models import (
     AirbyteMessage,
@@ -32,6 +31,7 @@
 )
 
 from airbyte import exceptions as exc
+from airbyte._batch_handles import BatchHandle
 from airbyte.caches.base import CacheBase
 from airbyte.progress import progress
 from airbyte.strategies import WriteStrategy
@@ -48,13 +48,6 @@
 DEBUG_MODE = False  # Set to True to enable additional debug logging.
 
 
-class BatchHandle(BaseModel):
-    """A handle for a batch of records."""
-
-    stream_name: str
-    batch_id: str
-
-
 class AirbyteMessageParsingError(Exception):
     """Raised when an Airbyte message is invalid or cannot be parsed."""
 
9 changes: 5 additions & 4 deletions airbyte/_processors/file/__init__.py
@@ -3,13 +3,14 @@
 
 from __future__ import annotations
 
-from .base import FileWriterBase, FileWriterBatchHandle
-from .jsonl import JsonlWriter
-from .parquet import ParquetWriter
+from airbyte._batch_handles import BatchHandle
+from airbyte._processors.file.base import FileWriterBase
+from airbyte._processors.file.jsonl import JsonlWriter
+from airbyte._processors.file.parquet import ParquetWriter
 
 
 __all__ = [
-    "FileWriterBatchHandle",
+    "BatchHandle",
     "FileWriterBase",
     "JsonlWriter",
     "ParquetWriter",
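
Since the package __init__ now re-exports the class, both import paths should resolve to the same object after this commit — a quick sanity check (sketch):

    from airbyte._batch_handles import BatchHandle
    from airbyte._processors.file import BatchHandle as ReexportedBatchHandle

    # The re-export aliases the same class object, not a copy.
    assert BatchHandle is ReexportedBatchHandle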
56 changes: 15 additions & 41 deletions airbyte/_processors/file/base.py
@@ -4,17 +4,15 @@
 from __future__ import annotations
 
 import abc
-from contextlib import suppress
-from dataclasses import field
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, cast, final
+from typing import IO, TYPE_CHECKING, final
 
 import ulid
 from overrides import overrides
-from pydantic import Field, PrivateAttr
 
 from airbyte import exceptions as exc
-from airbyte._processors.base import BatchHandle, RecordProcessor
+from airbyte._batch_handles import BatchHandle
+from airbyte._processors.base import RecordProcessor
 from airbyte._util.protocol_util import airbyte_record_message_to_dict
 from airbyte.progress import progress
@@ -31,32 +29,12 @@
 DEFAULT_BATCH_SIZE = 10000
 
 
-# The batch handle for file writers is a list of Path objects.
-class FileWriterBatchHandle(BatchHandle):
-    """The file writer batch handle is a list of Path objects."""
-
-    files: list[Path] = Field(default_factory=list)
-    _open_file_writer: Any | None = PrivateAttr(default=None)
-    _record_count: int = PrivateAttr(default=0)
-
-    # TODO: Handle pydantic error: Fields of type "<class 'typing.IO'>" are not supported.
-    # open_file_writer: IO[bytes] | None = PrivateAttr(default=None)
-
-    @property
-    def record_count(self) -> int:
-        return self._record_count
-
-    @property
-    def open_file_writer(self) -> IO[bytes] | None:
-        return self._open_file_writer
-
-
 class FileWriterBase(RecordProcessor, abc.ABC):
     """A generic base implementation for a file-based cache."""
 
     default_cache_file_suffix: str = ".batch"
 
-    _active_batches: dict[str, FileWriterBatchHandle]
+    _active_batches: dict[str, BatchHandle]
 
     def _get_new_cache_file_path(
         self,
@@ -97,7 +75,7 @@ def _flush_active_batch(
         if stream_name not in self._active_batches:
             return
 
-        batch_handle: FileWriterBatchHandle = self._active_batches[stream_name]
+        batch_handle: BatchHandle = self._active_batches[stream_name]
         del self._active_batches[stream_name]
 
         if self.skip_finalize_step:
@@ -112,27 +90,26 @@
     def _new_batch(
         self,
         stream_name: str,
-    ) -> FileWriterBatchHandle:
+    ) -> BatchHandle:
         """Create and return a new batch handle.
 
         The base implementation creates and opens a new file for writing so it is ready to receive
         records.
         """
         batch_id = self._new_batch_id()
         new_file_path, new_file_handle = self._open_new_file(stream_name=stream_name)
-        batch_handle = FileWriterBatchHandle(
+        batch_handle = BatchHandle(
             stream_name=stream_name,
             batch_id=batch_id,
             files=[new_file_path],
             open_file_writer=new_file_handle,
         )
         self._active_batches[stream_name] = batch_handle
         return batch_handle
 
     @overrides
     def _cleanup_batch(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Clean up the cache.
@@ -148,21 +125,18 @@ def _cleanup_batch(
 
     def _close_batch_files(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Close the current batch."""
         if not batch_handle.open_file_writer:
             return
 
-        with suppress(Exception):
-            batch_handle.open_file_writer.close()
-
-        batch_handle.open_file_writer = None
+        batch_handle.close_files()
 
     @final
     def cleanup_batch(
         self,
-        batch_handle: FileWriterBatchHandle,
+        batch_handle: BatchHandle,
     ) -> None:
         """Clean up the cache.
@@ -189,7 +163,7 @@ def _finalize_state_messages(
     def _process_record_message(
         self,
         record_msg: AirbyteRecordMessage,
-    ) -> tuple[str, FileWriterBatchHandle]:
+    ) -> tuple[str, BatchHandle]:
         """Write a record to the cache.
 
         This method is called for each record message, before the batch is written.
@@ -199,12 +173,12 @@ def _process_record_message(
         """
         stream_name = record_msg.stream
 
-        batch_handle: FileWriterBatchHandle
+        batch_handle: BatchHandle
         if not self._pending_batches[stream_name]:
             batch_handle = self._new_batch(stream_name=stream_name)
 
         else:
-            batch_handle = cast(FileWriterBatchHandle, self._pending_batches[stream_name][-1])
+            batch_handle = self._pending_batches[stream_name][-1]
 
         if batch_handle.record_count + 1 > self.MAX_BATCH_SIZE:
             # Already at max batch size, so write the batch and start a new one
@@ -222,7 +196,7 @@ def _process_record_message(
             record_dict=airbyte_record_message_to_dict(record_message=record_msg),
             open_file_writer=batch_handle.open_file_writer,
         )
-        batch_handle.record_count += 1
+        batch_handle.increment_record_count()
         return stream_name, batch_handle
 
     def _write_record_dict(
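
The file-closing logic that _close_batch_files() used to carry inline now delegates to the handle itself. A sketch of the resulting semantics (illustrative names, assuming the module above):

    from airbyte._batch_handles import BatchHandle

    handle = BatchHandle(stream_name="users", batch_id="batch-0002")

    # close_files() is a no-op when no file writer is open, and it
    # suppresses exceptions raised while closing, so callers no longer
    # need their own suppress() block.
    handle.close_files()
    handle.close_files()  # idempotent: safe to call again

    # Dropping the last reference also closes files, via __del__.
    del handle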
21 changes: 8 additions & 13 deletions airbyte/_processors/file/parquet.py
@@ -8,22 +8,17 @@
 """
 from __future__ import annotations
 
-from pathlib import Path
-from typing import cast
-
-import pyarrow as pa
-import ulid
-from overrides import overrides
-from pyarrow import parquet
+from typing import TYPE_CHECKING
 
 from airbyte import exceptions as exc
-from airbyte._processors.file.base import (
-    FileWriterBase,
-    FileWriterBatchHandle,
-)
+from airbyte._processors.file.base import FileWriterBase
 from airbyte._util.text_util import lower_case_set
 
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
 class ParquetWriter(FileWriterBase):
     """A Parquet cache implementation."""
 
@@ -55,7 +50,7 @@ def _get_missing_columns(
     #     stream_name: str,
     #     batch_id: str,
     #     record_batch: pa.Table,  # TODO: Refactor to remove dependency on pyarrow
-    # ) -> FileWriterBatchHandle:
+    # ) -> BatchHandle:
     #     """Process a record batch.
 
     #     Return the path to the cache file.
@@ -85,6 +80,6 @@ def _get_missing_columns(
     #             },
     #         ) from e
 
-    #     batch_handle = FileWriterBatchHandle()
+    #     batch_handle = BatchHandle()
     #     batch_handle.files.append(output_file_path)
     #     return batch_handle
10 changes: 5 additions & 5 deletions airbyte/_processors/sql/base.py
@@ -6,7 +6,7 @@
 import enum
 from contextlib import contextmanager
 from functools import cached_property
-from typing import TYPE_CHECKING, cast, final
+from typing import TYPE_CHECKING, final
 
 import pandas as pd
 import sqlalchemy
@@ -27,8 +27,7 @@
 from sqlalchemy.sql.elements import TextClause
 
 from airbyte import exceptions as exc
-from airbyte._processors.base import BatchHandle, RecordProcessor
-from airbyte._processors.file.base import FileWriterBase, FileWriterBatchHandle
+from airbyte._processors.base import RecordProcessor
 from airbyte._util.text_util import lower_case_set
 from airbyte.caches._catalog_manager import CatalogManager
 from airbyte.datasets._sql import CachedDataset
@@ -50,6 +49,8 @@
         ConfiguredAirbyteCatalog,
     )
 
+    from airbyte._batch_handles import BatchHandle
+    from airbyte._processors.file.base import FileWriterBase
     from airbyte.caches.base import CacheBase
 
 
@@ -480,7 +481,7 @@ def _get_sql_column_definitions(
     #     stream_name: str,
     #     batch_id: str,
     #     record_batch: pa.Table,  # TODO: Refactor to remove dependency on pyarrow
-    # ) -> FileWriterBatchHandle:
+    # ) -> BatchHandle:
     #     """Process a record batch.
 
     #     Return the path to the cache file.
@@ -535,7 +536,6 @@ def _finalize_batches(
         files: list[Path] = []
         # Get a list of all files to finalize from all pending batches.
         for batch_handle in batches_to_finalize:
-            batch_handle = cast(FileWriterBatchHandle, batch_handle)
             files += batch_handle.files
         # Use the max batch ID as the batch ID for table names.
         max_batch_id = max([batch.batch_id for batch in batches_to_finalize])
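
Dropping the cast() here is one of the mypy fixes named in the commit message: every pending batch is now the same concrete BatchHandle class, so .files is visible to the type checker without help. A minimal illustration (hypothetical list):

    from pathlib import Path

    from airbyte._batch_handles import BatchHandle

    batches_to_finalize: list[BatchHandle] = [
        BatchHandle(stream_name="users", batch_id="batch-0003"),
    ]

    # mypy resolves batch_handle.files directly; no cast() needed.
    files: list[Path] = []
    for batch_handle in batches_to_finalize:
        files += batch_handle.files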
