Commit caf429a

major refactor, passing tests again
aaronsteers committed Mar 7, 2024
1 parent b9b0296 commit caf429a
Showing 6 changed files with 231 additions and 276 deletions.
49 changes: 36 additions & 13 deletions airbyte/_batch_handles.py
@@ -3,22 +3,45 @@

from __future__ import annotations

from contextlib import suppress
from pathlib import Path  # noqa: TCH003  # Pydantic needs this import
from typing import IO, Any, Optional
from typing import IO, TYPE_CHECKING, Callable

from pydantic import BaseModel, Field, PrivateAttr

if TYPE_CHECKING:
    from pathlib import Path

class BatchHandle(BaseModel):

class BatchHandle:
    """A handle for a batch of records."""

    stream_name: str
    batch_id: str
    def __init__(
        self,
        stream_name: str,
        batch_id: str,
        files: list[Path],
        file_opener: Callable[[Path], IO[bytes]],
    ) -> None:
        """Initialize the batch handle."""
        self._stream_name = stream_name
        self._batch_id = batch_id
        self._files = files
        self._record_count = 0
        assert self._files, "A batch must have at least one file."
        self._open_file_writer: IO[bytes] = file_opener(self._files[0])

    @property
    def files(self) -> list[Path]:
        """Return the files."""
        return self._files

    @property
    def batch_id(self) -> str:
        """Return the batch ID."""
        return self._batch_id

    files: list[Path] = Field(default_factory=list)
    _open_file_writer: Optional[Any] = PrivateAttr(default=None)
    _record_count: int = PrivateAttr(default=0)
    @property
    def stream_name(self) -> str:
        """Return the stream name."""
        return self._stream_name

    @property
    def record_count(self) -> int:
@@ -36,11 +36,11 @@ def open_file_writer(self) -> IO[bytes] | None:

    def close_files(self) -> None:
        """Close the file writer."""
        if self._open_file_writer is None:
        if self.open_file_writer is None:
            return

        with suppress(Exception):
            self._open_file_writer.close()
        # with suppress(Exception):
        self.open_file_writer.close()

    def __del__(self) -> None:
        """Upon deletion, close the file writer."""
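For orientation, here is a minimal sketch of the refactored BatchHandle API in use. The opener, file name, and batch ID below are illustrative placeholders, not taken from the commit; per the new __init__, the caller supplies the file list plus a file_opener callback, and the first file is opened eagerly.

from pathlib import Path
from typing import IO

from airbyte._batch_handles import BatchHandle


def plain_opener(path: Path) -> IO[bytes]:
    """Open a batch file for binary writing (illustrative; real writers may wrap gzip/JSONL)."""
    return path.open("wb")


handle = BatchHandle(
    stream_name="users",  # hypothetical stream name
    batch_id="01HREXAMPLEULID0000000000",  # placeholder; IDs are ULID strings elsewhere in the codebase
    files=[Path("users_batch.jsonl")],  # must be non-empty, per the assert in __init__
    file_opener=plain_opener,
)
assert handle.open_file_writer is not None  # the first file is opened on construction
handle.close_files()  # __del__ also closes, but closing explicitly is safer

Note that close_files() no longer suppresses exceptions: the `with suppress(Exception)` wrapper is commented out in this commit, so errors on close now propagate to the caller.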
189 changes: 33 additions & 156 deletions airbyte/_processors/base.py
@@ -10,14 +10,12 @@
from __future__ import annotations

import abc
import contextlib
import io
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Any, cast, final

import pyarrow as pa
import ulid

from airbyte_protocol.models import (
    AirbyteMessage,
@@ -31,20 +29,17 @@
)

from airbyte import exceptions as exc
from airbyte._batch_handles import BatchHandle
from airbyte.caches.base import CacheBase
from airbyte.progress import progress
from airbyte.strategies import WriteStrategy
from airbyte.types import _get_pyarrow_type


if TYPE_CHECKING:
    from collections.abc import Generator, Iterable, Iterator
    from collections.abc import Iterable, Iterator

    from airbyte.caches._catalog_manager import CatalogManager


DEFAULT_BATCH_SIZE = 10_000
DEBUG_MODE = False # Set to True to enable additional debug logging.


@@ -55,10 +50,6 @@ class AirbyteMessageParsingError(Exception):
class RecordProcessor(abc.ABC):
"""Abstract base class for classes which can process input records."""

MAX_BATCH_SIZE: int = DEFAULT_BATCH_SIZE

skip_finalize_step: bool = False

def __init__(
self,
cache: CacheBase,
@@ -78,10 +69,6 @@ def __init__(
        self.source_catalog: ConfiguredAirbyteCatalog | None = None
        self._source_name: str | None = None

        self._active_batches: dict[str, BatchHandle] = {}
        self._pending_batches: dict[str, list[BatchHandle]] = defaultdict(list, {})
        self._finalized_batches: dict[str, list[BatchHandle]] = defaultdict(list, {})

        self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {})
        self._finalized_state_messages: dict[
            str,
@@ -150,19 +137,18 @@ def process_input_stream(
            write_strategy=write_strategy,
        )

    def _process_record_message(
    @abc.abstractmethod
    def process_record_message(
        self,
        record_msg: AirbyteRecordMessage,
    ) -> None:
        """Write a record to the cache.
        This method is called for each record message, before the batch is written.
        This method is called for each record message.
        By default this is a no-op but file writers can override this method to write the record to
        files.
        In most cases, the SQL processor will not perform any action, but will pass this along to
        the file processor.
        """
        _ = record_msg  # Unused
        pass

    @final
    def process_airbyte_messages(
@@ -181,7 +167,7 @@ def process_airbyte_messages(
        for message in messages:
            if message.type is Type.RECORD:
                record_msg = cast(AirbyteRecordMessage, message.record)
                self._process_record_message(record_msg)
                self.process_record_message(record_msg)

            elif message.type is Type.STATE:
                state_msg = cast(AirbyteStateMessage, message.state)
@@ -197,129 +183,34 @@
                # Type.LOG, Type.TRACE, Type.CONTROL, etc.
                pass

        # We are at the end of the stream. Process whatever else is queued.
        self._flush_active_batches()

        all_streams = list(set(self._pending_batches.keys()) | set(self._finalized_batches.keys()))
        # Add empty streams to the streams list, so we create a destination table for it
        for stream_name in self.expected_streams:
            if stream_name not in all_streams:
                if DEBUG_MODE:
                    print(f"Stream {stream_name} has no data")
                all_streams.append(stream_name)

        # Finalize any pending batches
        for stream_name in all_streams:
            self._finalize_batches(stream_name, write_strategy=write_strategy)
            progress.log_stream_finalized(stream_name)

    def _flush_active_batches(
        self,
    ) -> None:
        """Flush active batches for all streams."""
        for stream_name in self._active_batches:
            self._flush_active_batch(stream_name)

    def _flush_active_batch(
        self,
        stream_name: str,
    ) -> None:
        """Flush the active batch for the given stream.
        This entails moving the active batch to the pending batches, closing any open files, and
        logging the batch as written.
        """
        raise NotImplementedError(
            "Subclasses must implement the _flush_active_batch() method.",
        self.flush_all(
            write_strategy=write_strategy,
        )

    def _cleanup_batch(  # noqa: B027  # Intentionally empty, not abstract
        self,
        batch_handle: BatchHandle,
    ) -> None:
        """Clean up the cache.
        This method is called after the given batch has been finalized.
        For instance, file writers can override this method to delete the files created. Caches,
        similarly, can override this method to delete any other temporary artifacts.
        """
        pass

    def _new_batch_id(self) -> str:
        """Return a new batch handle."""
        return str(ulid.ULID())

    def _new_batch(
        self,
        stream_name: str,
    ) -> BatchHandle:
        """Create and return a new batch handle.
        By default this is a concatenation of the stream name and batch ID.
        However, any Python object can be returned, such as a Path object.
        """
        batch_id = self._new_batch_id()
        return BatchHandle(stream_name=stream_name, batch_id=batch_id)

    def _finalize_batches(
        self,
        stream_name: str,
        write_strategy: WriteStrategy,
    ) -> list[BatchHandle]:
        """Finalize all uncommitted batches.
        Returns a mapping of batch IDs to batch handles, for processed batches.
        This is a generic implementation, which can be overridden.
        """
        _ = write_strategy  # Unused
        with self._finalizing_batches(stream_name) as batches_to_finalize:
            if batches_to_finalize and not self.skip_finalize_step:
                raise NotImplementedError(
                    "Caches need to be finalized but no _finalize_batch() method "
                    f"exists for class {self.__class__.__name__}",
                )

            return batches_to_finalize
        # Clean up files, if requested.
        if self.cache.cleanup:
            self.cleanup_all()

    @abc.abstractmethod
    def flush_all(self, write_strategy: WriteStrategy) -> None:
        """Finalize any pending writes."""

    def _finalize_state_messages(
        self,
        stream_name: str,
        state_messages: list[AirbyteStateMessage],
    ) -> None:
        """Handle state messages.
        Might be a no-op if the processor doesn't handle incremental state."""
        pass

    @final
    @contextlib.contextmanager
    def _finalizing_batches(
        self,
        stream_name: str,
    ) -> Generator[list[BatchHandle], str, None]:
        """Context manager to use for finalizing batches, if applicable.
        Returns a mapping of batch IDs to batch handles, for those processed batches.
        """
        batches_to_finalize: list[BatchHandle] = self._pending_batches[stream_name].copy()
        state_messages_to_finalize: list[AirbyteStateMessage] = self._pending_state_messages[
            stream_name
        ].copy()
        self._pending_batches[stream_name].clear()
        self._pending_state_messages[stream_name].clear()

        progress.log_batches_finalizing(stream_name, len(batches_to_finalize))
        yield batches_to_finalize
        self._finalize_state_messages(stream_name, state_messages_to_finalize)
        progress.log_batches_finalized(stream_name, len(batches_to_finalize))

        self._finalized_batches[stream_name] += batches_to_finalize
        self._finalized_state_messages[stream_name] += state_messages_to_finalize

        for batch_handle in batches_to_finalize:
            self._cleanup_batch(batch_handle)
        """Handle state messages by passing them to the catalog manager."""
        if not self._catalog_manager:
            raise exc.AirbyteLibInternalError(
                message="Catalog manager should exist but does not.",
            )
        if state_messages and self._source_name:
            self._catalog_manager.save_state(
                source_name=self._source_name,
                stream_name=stream_name,
                state=state_messages[-1],
            )

    def _setup(self) -> None:  # noqa: B027  # Intentionally empty, not abstract
        """Create the database.
@@ -329,27 +220,6 @@ def _setup(self) -> None:  # noqa: B027  # Intentionally empty, not abstract
"""
pass

    def _teardown(self) -> None:
        """Teardown the processor resources.
        By default, the base implementation simply calls _cleanup_batch() for all pending batches.
        """
        batch_lists: list[list[BatchHandle]] = list(self._pending_batches.values()) + list(
            self._finalized_batches.values()
        )

        # TODO: flatten lists and remove nested 'for'
        for batch_list in batch_lists:
            for batch_handle in batch_list:
                self._cleanup_batch(
                    batch_handle=batch_handle,
                )

    @final
    def __del__(self) -> None:
        """Teardown temporary resources when instance is unloaded from memory."""
        self._teardown()

    @final
    def _get_stream_config(
        self,
@@ -384,3 +254,10 @@ def _get_stream_pyarrow_schema(
                ].items()
            ]
        )

    def cleanup_all(self) -> None:  # noqa: B027  # Intentionally empty, not abstract
        """Clean up all resources.
        The default implementation is a no-op.
        """
        pass
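After this refactor, batch bookkeeping moves out of the base class: a concrete processor now implements process_record_message() (called once per record) and flush_all() (called once at end of stream), with cleanup_all() available as an optional override. A minimal sketch of a hypothetical subclass follows; the class name, in-memory buffer, and constructor passthrough are illustrative, not part of the commit.

from airbyte._processors.base import RecordProcessor
from airbyte.strategies import WriteStrategy
from airbyte_protocol.models import AirbyteRecordMessage


class InMemoryProcessor(RecordProcessor):
    """Hypothetical processor that buffers records in memory (illustrative only)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # base signature abbreviated in this view
        self._buffer: dict[str, list[dict]] = {}

    def process_record_message(self, record_msg: AirbyteRecordMessage) -> None:
        # Called per record by process_airbyte_messages(); replaces the old
        # _process_record_message() plus the base class's batch tracking.
        self._buffer.setdefault(record_msg.stream, []).append(record_msg.data)

    def flush_all(self, write_strategy: WriteStrategy) -> None:
        # Called once at end of stream; cleanup_all() runs afterwards when
        # self.cache.cleanup is enabled.
        for stream_name, records in self._buffer.items():
            print(f"would write {len(records)} records to {stream_name!r} ({write_strategy})")
        self._buffer.clear()

State handling is no longer a per-processor override point: _finalize_state_messages() now always saves the most recent state message through the catalog manager, raising AirbyteLibInternalError if no catalog manager is attached.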
