Commit

Feat: Remove need to import CacheConfig classes in addition to `Cache` classes (major refactor) (#59)
aaronsteers authored Feb 22, 2024
1 parent 81d1b9c commit 7b527d0
Showing 35 changed files with 1,638 additions and 1,567 deletions.
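
The core of the refactor: cache settings now live directly on the `Cache` classes, so callers no longer import a separate `CacheConfig` class alongside each cache. A minimal before/after sketch of the usage this enables (the `db_path` argument and the `config=` keyword are illustrative, not taken from this diff):

```python
# Before: two imports, with settings on a separate config object.
from airbyte.caches import DuckDBCache, DuckDBCacheConfig

config = DuckDBCacheConfig(db_path="./my_cache.duckdb")  # illustrative parameter
cache = DuckDBCache(config=config)

# After: one import; the cache class carries its own settings.
from airbyte.caches.duckdb import DuckDBCache

cache = DuckDBCache(db_path="./my_cache.duckdb")  # illustrative parameter
```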
19 changes: 13 additions & 6 deletions airbyte/__init__.py
@@ -5,9 +5,10 @@
 """
 from __future__ import annotations
 
-from airbyte._factories.cache_factories import get_default_cache, new_local_cache
+from airbyte import caches, datasets, registry, secrets
 from airbyte._factories.connector_factories import get_source
-from airbyte.caches import DuckDBCache, DuckDBCacheConfig
+from airbyte.caches.duckdb import DuckDBCache
+from airbyte.caches.factories import get_default_cache, new_local_cache
 from airbyte.datasets import CachedDataset
 from airbyte.registry import get_available_connectors
 from airbyte.results import ReadResult
@@ -16,14 +17,20 @@
 
 
 __all__ = [
-    "CachedDataset",
-    "DuckDBCache",
-    "DuckDBCacheConfig",
+    # Modules
+    "caches",
+    "datasets",
+    "registry",
+    "secrets",
+    # Factories
     "get_available_connectors",
-    "get_source",
     "get_default_cache",
     "get_secret",
+    "get_source",
     "new_local_cache",
+    # Classes
+    "CachedDataset",
+    "DuckDBCache",
     "ReadResult",
     "SecretSource",
     "Source",
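
The regrouped `__all__` (modules, factories, classes) spells out the intended top-level surface. A sketch of how those exports are used, assuming only what the list above declares (the connector name is illustrative):

```python
import airbyte as ab

# Factories
cache = ab.get_default_cache()          # default local cache
local = ab.new_local_cache("my_cache")  # named local cache
source = ab.get_source("source-faker")  # illustrative connector name

# Classes and modules resolve from the same top level
from airbyte import CachedDataset, DuckDBCache, caches, datasets
```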
6 changes: 3 additions & 3 deletions airbyte/_executor.py
@@ -65,7 +65,7 @@ def install(self) -> None:
         pass
 
     @abstractmethod
-    def get_telemetry_info(self) -> SourceTelemetryInfo:
+    def _get_telemetry_info(self) -> SourceTelemetryInfo:
         pass
 
     @abstractmethod
@@ -388,7 +388,7 @@ def execute(self, args: list[str]) -> Iterator[str]:
         with _stream_from_subprocess([str(connector_path), *args]) as stream:
             yield from stream
 
-    def get_telemetry_info(self) -> SourceTelemetryInfo:
+    def _get_telemetry_info(self) -> SourceTelemetryInfo:
         return SourceTelemetryInfo(
             name=self.name,
             type=SourceType.VENV,
@@ -449,7 +449,7 @@ def execute(self, args: list[str]) -> Iterator[str]:
         with _stream_from_subprocess([str(self.path), *args]) as stream:
            yield from stream
 
-    def get_telemetry_info(self) -> SourceTelemetryInfo:
+    def _get_telemetry_info(self) -> SourceTelemetryInfo:
         return SourceTelemetryInfo(
             str(self.name),
             SourceType.LOCAL_INSTALL,
16 changes: 0 additions & 16 deletions airbyte/_file_writers/__init__.py

This file was deleted.

Empty file added airbyte/_processors/__init__.py
45 changes: 21 additions & 24 deletions airbyte/_processors.py → airbyte/_processors/base.py
@@ -1,8 +1,7 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-"""Abstract base class for Processors, including SQL and File writers.
+"""Define abstract base class for Processors, including Caches and File writers.
 
-Processors can all take input from STDIN or a stream of Airbyte messages.
+Processors can take input from STDIN or a stream of Airbyte messages.
 
 Caches will pass their input to the File Writer. They share a common base class so certain
 abstractions like "write" and "finalize" can be handled in either layer, or both.
@@ -34,6 +33,7 @@
 
 from airbyte import exceptions as exc
 from airbyte._util import protocol_util
+from airbyte.caches.base import CacheBase
 from airbyte.progress import progress
 from airbyte.strategies import WriteStrategy
 from airbyte.types import _get_pyarrow_type
@@ -43,7 +43,6 @@
     from collections.abc import Generator, Iterable, Iterator
 
     from airbyte.caches._catalog_manager import CatalogManager
-    from airbyte.config import CacheConfigBase
 
 
 DEFAULT_BATCH_SIZE = 10_000
@@ -61,32 +60,29 @@ class AirbyteMessageParsingError(Exception):
 class RecordProcessor(abc.ABC):
     """Abstract base class for classes which can process input records."""
 
-    config_class: type[CacheConfigBase]
     skip_finalize_step: bool = False
-    _expected_streams: set[str]
 
     def __init__(
         self,
-        config: CacheConfigBase | dict | None,
+        cache: CacheBase,
         *,
         catalog_manager: CatalogManager | None = None,
     ) -> None:
-        if isinstance(config, dict):
-            config = self.config_class(**config)
-
-        self.config = config or self.config_class()
-        if not isinstance(self.config, self.config_class):
-            err_msg = (
-                f"Expected config class of type '{self.config_class.__name__}'. "
-                f"Instead found '{type(self.config).__name__}'."
-            )
-            raise TypeError(err_msg)
+        self._expected_streams: set[str] | None = None
+        self.cache: CacheBase = cache
+        if not isinstance(self.cache, CacheBase):
+            raise exc.AirbyteLibInputError(
+                message=(
+                    f"Expected config class of type 'CacheBase'. "
+                    f"Instead received type '{type(self.cache).__name__}'."
+                ),
+            )
 
         self.source_catalog: ConfiguredAirbyteCatalog | None = None
         self._source_name: str | None = None
 
-        self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
-        self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
+        self._pending_batches: dict[str, dict[str, Any]] = defaultdict(dict, {})
+        self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(dict, {})
 
         self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {})
         self._finalized_state_messages: dict[
@@ -97,6 +93,11 @@ def __init__(
         self._catalog_manager: CatalogManager | None = catalog_manager
         self._setup()
 
+    @property
+    def expected_streams(self) -> set[str]:
+        """Return the expected stream names."""
+        return self._expected_streams or set()
+
     def register_source(
         self,
         source_name: str,
@@ -115,11 +116,6 @@ def register_source(
         )
         self._expected_streams = stream_names
 
-    @property
-    def _streams_with_data(self) -> set[str]:
-        """Return a list of known streams."""
-        return self._pending_batches.keys() | self._finalized_batches.keys()
-
     @final
     def process_stdin(
         self,
@@ -216,7 +212,7 @@ def process_airbyte_messages(
 
         all_streams = list(self._pending_batches.keys())
         # Add empty streams to the streams list, so we create a destination table for it
-        for stream_name in self._expected_streams:
+        for stream_name in self.expected_streams:
             if stream_name not in all_streams:
                 if DEBUG_MODE:
                     print(f"Stream {stream_name} has no data")
@@ -358,6 +354,7 @@ def _setup(self) -> None:  # noqa: B027  # Intentionally empty, not abstract
         By default this is a no-op but subclasses can override this method to prepare
         any necessary resources.
         """
+        pass
 
     def _teardown(self) -> None:
         """Teardown the processor resources.
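
`RecordProcessor.__init__` now takes a live `Cache` object and validates its type up front, instead of accepting a config instance or dict and hydrating it through `config_class`. A sketch of the new contract, assuming `DuckDBCache` subclasses the new `CacheBase` (consistent with the import added above) and using the `JsonlWriter` moved in this same commit (`db_path` is illustrative):

```python
from airbyte import exceptions as exc
from airbyte._processors.file.jsonl import JsonlWriter
from airbyte.caches.duckdb import DuckDBCache

cache = DuckDBCache(db_path="./my_cache.duckdb")  # illustrative parameter
writer = JsonlWriter(cache)  # accepted: any CacheBase subclass

try:
    JsonlWriter({"cache_dir": "./.cache"})  # dicts are no longer coerced
except exc.AirbyteLibInputError as err:
    print(err)  # "Expected config class of type 'CacheBase'. ..."
```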
16 changes: 16 additions & 0 deletions airbyte/_processors/file/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+"""File processors."""
+
+from __future__ import annotations
+
+from .base import FileWriterBase, FileWriterBatchHandle
+from .jsonl import JsonlWriter
+from .parquet import ParquetWriter
+
+
+__all__ = [
+    "FileWriterBatchHandle",
+    "FileWriterBase",
+    "JsonlWriter",
+    "ParquetWriter",
+]
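
The new `airbyte._processors.file` package re-exports the writers, so internal callers import them from one place rather than from each module:

```python
from airbyte._processors.file import FileWriterBase, JsonlWriter, ParquetWriter

# Both concrete writers share the FileWriterBase interface from base.py.
assert issubclass(JsonlWriter, FileWriterBase)
assert issubclass(ParquetWriter, FileWriterBase)
```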
23 changes: 5 additions & 18 deletions airbyte/_file_writers/base.py → airbyte/_processors/file/base.py
@@ -1,21 +1,20 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 
-"""Define abstract base class for File Writers, which write and read from file storage."""
+"""Abstract base class for File Writers, which write and read from file storage."""
 
 from __future__ import annotations
 
 import abc
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import TYPE_CHECKING, cast, final
 
 from overrides import overrides
 
-from airbyte._processors import BatchHandle, RecordProcessor
-from airbyte.config import CacheConfigBase
+from airbyte._processors.base import BatchHandle, RecordProcessor
 
 
 if TYPE_CHECKING:
+    from pathlib import Path
+
     import pyarrow as pa
 
     from airbyte_protocol.models import (
@@ -34,21 +33,9 @@ class FileWriterBatchHandle(BatchHandle):
     files: list[Path] = field(default_factory=list)
 
 
-class FileWriterConfigBase(CacheConfigBase):
-    """Configuration for the Snowflake cache."""
-
-    cache_dir: Path = Path("./.cache/files/")
-    """The directory to store cache files in."""
-    cleanup: bool = True
-    """Whether to clean up temporary files after processing a batch."""
-
-
 class FileWriterBase(RecordProcessor, abc.ABC):
     """A generic base implementation for a file-based cache."""
 
-    config_class = FileWriterConfigBase
-    config: FileWriterConfigBase
-
     @abc.abstractmethod
     @overrides
     def _write_batch(
@@ -91,7 +78,7 @@ def _cleanup_batch(
         This method is a no-op if the `cleanup` config option is set to False.
         """
-        if self.config.cleanup:
+        if self.cache.cleanup:
             batch_handle = cast(FileWriterBatchHandle, batch_handle)
             _ = stream_name, batch_id
             for file_path in batch_handle.files:
airbyte/_file_writers/jsonl.py → airbyte/_processors/file/jsonl.py
@@ -1,47 +1,37 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 
 """A Parquet cache implementation."""
 
 from __future__ import annotations
 
 import gzip
 from pathlib import Path
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
 
 import orjson
 import ulid
 from overrides import overrides
 
-from airbyte._file_writers.base import (
+from airbyte._processors.file.base import (
     FileWriterBase,
     FileWriterBatchHandle,
-    FileWriterConfigBase,
 )
 
 
 if TYPE_CHECKING:
     import pyarrow as pa
 
 
-class JsonlWriterConfig(FileWriterConfigBase):
-    """Configuration for the Snowflake cache."""
-
-    # Inherits `cache_dir` from base class
-
-
 class JsonlWriter(FileWriterBase):
     """A Jsonl cache implementation."""
 
-    config_class = JsonlWriterConfig
-
     def get_new_cache_file_path(
         self,
         stream_name: str,
         batch_id: str | None = None,  # ULID of the batch
     ) -> Path:
         """Return a new cache file path for the given stream."""
         batch_id = batch_id or str(ulid.ULID())
-        config: JsonlWriterConfig = cast(JsonlWriterConfig, self.config)
-        target_dir = Path(config.cache_dir)
+        target_dir = Path(self.cache.cache_dir)
         target_dir.mkdir(parents=True, exist_ok=True)
         return target_dir / f"{stream_name}_{batch_id}.jsonl.gz"
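
With `FileWriterConfigBase` gone, writers read `cache_dir` (and `cleanup`, in `_cleanup_batch` above) straight from the attached cache. A sketch of the resulting path behavior, assuming `cache_dir` is now a field on the cache itself and that the constructor argument shown is accepted:

```python
from pathlib import Path

from airbyte._processors.file.jsonl import JsonlWriter
from airbyte.caches.duckdb import DuckDBCache

cache = DuckDBCache(cache_dir=Path("./.cache/files"))  # illustrative argument
writer = JsonlWriter(cache)

path = writer.get_new_cache_file_path(stream_name="users")
print(path)  # e.g. ".cache/files/users_<ULID>.jsonl.gz"
```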
airbyte/_file_writers/parquet.py → airbyte/_processors/file/parquet.py
@@ -1,6 +1,5 @@
-# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-
-"""A Parquet cache implementation.
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved
+"""A Parquet file writer implementation.
 
 NOTE: Parquet is a strongly typed columnar storage format, which has known issues when applied to
 variable schemas, schemas with indeterminate types, and schemas that have empty data nodes.
@@ -18,34 +17,24 @@
 from pyarrow import parquet
 
 from airbyte import exceptions as exc
-from airbyte._file_writers.base import (
+from airbyte._processors.file.base import (
     FileWriterBase,
     FileWriterBatchHandle,
-    FileWriterConfigBase,
 )
 from airbyte._util.text_util import lower_case_set
 
 
-class ParquetWriterConfig(FileWriterConfigBase):
-    """Configuration for the Snowflake cache."""
-
-    # Inherits `cache_dir` from base class
-
-
 class ParquetWriter(FileWriterBase):
     """A Parquet cache implementation."""
 
-    config_class = ParquetWriterConfig
-
     def get_new_cache_file_path(
         self,
         stream_name: str,
         batch_id: str | None = None,  # ULID of the batch
     ) -> Path:
         """Return a new cache file path for the given stream."""
         batch_id = batch_id or str(ulid.ULID())
-        config: ParquetWriterConfig = cast(ParquetWriterConfig, self.config)
-        target_dir = Path(config.cache_dir)
+        target_dir = Path(self.cache.cache_dir)
         target_dir.mkdir(parents=True, exist_ok=True)
         return target_dir / f"{stream_name}_{batch_id}.parquet"
2 changes: 2 additions & 0 deletions airbyte/_processors/sql/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+"""SQL processors."""