Skip to content

Commit

Permalink
major refactor increment (some tests still failing)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Feb 21, 2024
1 parent 3039498 commit 816d807
Show file tree
Hide file tree
Showing 20 changed files with 249 additions and 275 deletions.
3 changes: 1 addition & 2 deletions airbyte/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from airbyte._factories.cache_factories import get_default_cache, new_local_cache
from airbyte._factories.connector_factories import get_source
from airbyte.caches import DuckDBCacheInstance, DuckDBCache
from airbyte.caches import DuckDBCache
from airbyte.datasets import CachedDataset
from airbyte.registry import get_available_connectors
from airbyte.results import ReadResult
Expand All @@ -17,7 +17,6 @@

__all__ = [
"CachedDataset",
"DuckDBCacheInstance",
"DuckDBCache",
"get_available_connectors",
"get_source",
Expand Down
13 changes: 6 additions & 7 deletions airbyte/_factories/cache_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@
import ulid

from airbyte import exceptions as exc
from airbyte.caches.duckdb import DuckDBCacheInstance, DuckDBCache
from airbyte.caches.duckdb import DuckDBCache


def get_default_cache() -> DuckDBCacheInstance:
def get_default_cache() -> DuckDBCache:
"""Get a local cache for storing data, using the default database path.
Cache files are stored in the `.cache` directory, relative to the current
working directory.
"""
config = DuckDBCache(

return DuckDBCache(
db_path="./.cache/default_cache_db.duckdb",
)
return DuckDBCacheInstance(config=config)


def new_local_cache(
cache_name: str | None = None,
cache_dir: str | Path | None = None,
*,
cleanup: bool = True,
) -> DuckDBCacheInstance:
) -> DuckDBCache:
"""Get a local cache for storing data, using a name string to seed the path.
Args:
Expand Down Expand Up @@ -55,9 +55,8 @@ def new_local_cache(
if not isinstance(cache_dir, Path):
cache_dir = Path(cache_dir)

config = DuckDBCache(
return DuckDBCache(
db_path=cache_dir / f"db_{cache_name}.duckdb",
cache_dir=cache_dir,
cleanup=cleanup,
)
return DuckDBCacheInstance(config=config)
1 change: 0 additions & 1 deletion airbyte/_file_writers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class FileWriterConfigBase(CacheConfigBase):
class FileWriterBase(RecordProcessor, abc.ABC):
"""A generic base implementation for a file-based cache."""

config_class = FileWriterConfigBase
config: FileWriterConfigBase

@abc.abstractmethod
Expand Down
2 changes: 0 additions & 2 deletions airbyte/_file_writers/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class ParquetWriterConfig(FileWriterConfigBase):
class ParquetWriter(FileWriterBase):
"""A Parquet cache implementation."""

config_class = ParquetWriterConfig

def get_new_cache_file_path(
self,
stream_name: str,
Expand Down
25 changes: 11 additions & 14 deletions airbyte/_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airbyte import exceptions as exc
from airbyte._util import protocol_util
from airbyte.config import CacheConfigBase
from airbyte.progress import progress
from airbyte.strategies import WriteStrategy

Expand All @@ -41,7 +42,6 @@
from collections.abc import Generator, Iterable, Iterator

from airbyte.caches._catalog_manager import CatalogManager
from airbyte.config import CacheConfigBase


DEFAULT_BATCH_SIZE = 10_000
Expand All @@ -59,32 +59,29 @@ class AirbyteMessageParsingError(Exception):
class RecordProcessor(abc.ABC):
"""Abstract base class for classes which can process input records."""

config_class: type[CacheConfigBase]
skip_finalize_step: bool = False
_expected_streams: set[str]

def __init__(
self,
config: CacheConfigBase | dict | None,
config: CacheConfigBase,
*,
catalog_manager: CatalogManager | None = None,
) -> None:
if isinstance(config, dict):
config = self.config_class(**config)

self.config = config or self.config_class()
if not isinstance(self.config, self.config_class):
err_msg = (
f"Expected config class of type '{self.config_class.__name__}'. "
f"Instead found '{type(self.config).__name__}'."
self.config = config
if not isinstance(self.config, CacheConfigBase):
raise exc.AirbyteLibInputError(
message=(
f"Expected config class of type 'CacheConfigBase'. "
f"Instead received type '{type(self.config).__name__}'."
),
)
raise TypeError(err_msg)

self.source_catalog: ConfiguredAirbyteCatalog | None = None
self._source_name: str | None = None

self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
self._pending_batches: dict[str, dict[str, Any]] = defaultdict(dict, {})
self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(dict, {})

self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {})
self._finalized_state_messages: dict[
Expand Down
8 changes: 3 additions & 5 deletions airbyte/caches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@

from airbyte.caches.base import SQLCacheBase
from airbyte.caches.duckdb import DuckDBCache
from airbyte.caches.postgres import PostgresCache, PostgresCacheConfig
from airbyte.caches.snowflake import SnowflakeCacheConfig, SnowflakeSQLCache
from airbyte.caches.postgres import PostgresCache
from airbyte.caches.snowflake import SnowflakeCache


# We export these classes for easy access: `airbyte.caches...`
__all__ = [
"DuckDBCache",
"PostgresCache",
"PostgresCacheConfig",
"SQLCacheBase",
"SnowflakeCacheConfig",
"SnowflakeSQLCache",
"SnowflakeCache",
]
8 changes: 4 additions & 4 deletions airbyte/caches/_catalog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
if TYPE_CHECKING:
from sqlalchemy.engine import Engine

STREAMS_TABLE_NAME = "_airbytelib_streams"
STATE_TABLE_NAME = "_airbytelib_state"
STREAMS_TABLE_NAME = "_airbyte_streams"
STATE_TABLE_NAME = "_airbyte_state"

GLOBAL_STATE_STREAM_NAMES = ["_GLOBAL", "_LEGACY"]

Expand Down Expand Up @@ -90,7 +90,7 @@ def _ensure_internal_tables(self) -> None:
engine = self._engine
Base.metadata.create_all(engine)

def _save_state(
def save_state(
self,
source_name: str,
state: AirbyteStateMessage,
Expand All @@ -113,7 +113,7 @@ def _save_state(
)
session.commit()

def _get_state(
def get_state(
self,
source_name: str,
streams: list[str],
Expand Down
Loading

0 comments on commit 816d807

Please sign in to comment.