Feat: Replace Parquet File Writer with Gzipped Jsonl File Writer (#60)
aaronsteers authored Feb 22, 2024
1 parent b0a98fe commit 81d1b9c
Showing 11 changed files with 253 additions and 72 deletions.
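
At its core, the change swaps a strongly typed Parquet write for gzip-compressed, newline-delimited JSON. A minimal sketch of the new write path, with made-up records and file name (the real implementation is in airbyte/_file_writers/jsonl.py below):

import gzip

import orjson

# Each record becomes one JSON line; gzip handles compression transparently.
records = [{"id": 1, "name": "widget"}, {"id": 2, "name": "gadget"}]
with gzip.open("users_batch.jsonl.gz", "w") as jsonl_file:
    for record in records:
        jsonl_file.write(orjson.dumps(record) + b"\n")

Readers only need gzip plus a per-line JSON parser, which sidesteps Parquet's per-column typing.
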
3 changes: 3 additions & 0 deletions airbyte/_file_writers/__init__.py
@@ -1,13 +1,16 @@
from __future__ import annotations

from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
from .jsonl import JsonlWriter, JsonlWriterConfig
from .parquet import ParquetWriter, ParquetWriterConfig


__all__ = [
"FileWriterBatchHandle",
"FileWriterBase",
"FileWriterConfigBase",
"JsonlWriter",
"JsonlWriterConfig",
"ParquetWriter",
"ParquetWriterConfig",
]
68 changes: 68 additions & 0 deletions airbyte/_file_writers/jsonl.py
@@ -0,0 +1,68 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

"""A Parquet cache implementation."""
from __future__ import annotations

import gzip
from pathlib import Path
from typing import TYPE_CHECKING, cast

import orjson
import ulid
from overrides import overrides

from airbyte._file_writers.base import (
FileWriterBase,
FileWriterBatchHandle,
FileWriterConfigBase,
)


if TYPE_CHECKING:
import pyarrow as pa


class JsonlWriterConfig(FileWriterConfigBase):
"""Configuration for the Snowflake cache."""

# Inherits `cache_dir` from base class


class JsonlWriter(FileWriterBase):
"""A Jsonl cache implementation."""

config_class = JsonlWriterConfig

def get_new_cache_file_path(
self,
stream_name: str,
batch_id: str | None = None, # ULID of the batch
) -> Path:
"""Return a new cache file path for the given stream."""
batch_id = batch_id or str(ulid.ULID())
config: JsonlWriterConfig = cast(JsonlWriterConfig, self.config)
target_dir = Path(config.cache_dir)
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / f"{stream_name}_{batch_id}.jsonl.gz"

@overrides
def _write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table,
) -> FileWriterBatchHandle:
"""Process a record batch.
Return the path to the cache file.
"""
_ = batch_id # unused
output_file_path = self.get_new_cache_file_path(stream_name)

with gzip.open(output_file_path, "w") as jsonl_file:
for record in record_batch.to_pylist():
jsonl_file.write(orjson.dumps(record) + b"\n")

batch_handle = FileWriterBatchHandle()
batch_handle.files.append(output_file_path)
return batch_handle
24 changes: 21 additions & 3 deletions airbyte/_file_writers/parquet.py
@@ -1,6 +1,12 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

"""A Parquet cache implementation."""
"""A Parquet cache implementation.
NOTE: Parquet is a strongly typed columnar storage format, which has known issues when applied to
variable schemas, schemas with indeterminate types, and schemas that have empty data nodes.
This implementation is deprecated for now in favor of jsonl.gz, and may be removed or revamped in
the future.
"""
from __future__ import annotations

from pathlib import Path
@@ -83,8 +89,20 @@ def _write_batch(
for col in missing_columns:
record_batch = record_batch.append_column(col, null_array)

with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
writer.write_table(record_batch)
try:
with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
writer.write_table(record_batch)
except Exception as e:
raise exc.AirbyteLibInternalError(
message=f"Failed to write record batch to Parquet file: {e}",
context={
"stream_name": stream_name,
"batch_id": batch_id,
"output_file_path": output_file_path,
"schema": record_batch.schema,
"record_batch": record_batch,
},
) from e

batch_handle = FileWriterBatchHandle()
batch_handle.files.append(output_file_path)
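
The deprecation note in the Parquet docstring above is the heart of the motivation: a columnar format needs one concrete type per column, which raw source records do not always provide. A small, hedged illustration of the failure mode (example data is invented):

import pyarrow as pa

# Arrow infers int64 from the first value, then cannot coerce the string.
mixed_records = [{"value": 1}, {"value": "not a number"}]
try:
    pa.Table.from_pylist(mixed_records)
except pa.ArrowInvalid as err:
    print(f"Arrow rejected the batch: {err}")
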
41 changes: 29 additions & 12 deletions airbyte/_processors.py
@@ -17,6 +17,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, cast, final

import pandas as pd
import pyarrow as pa
import ulid

@@ -35,6 +36,7 @@
from airbyte._util import protocol_util
from airbyte.progress import progress
from airbyte.strategies import WriteStrategy
from airbyte.types import _get_pyarrow_type


if TYPE_CHECKING:
@@ -177,17 +179,16 @@ def process_airbyte_messages(
)

stream_batches: dict[str, list[dict]] = defaultdict(list, {})

# Process messages, writing to batches as we go
for message in messages:
if message.type is Type.RECORD:
record_msg = cast(AirbyteRecordMessage, message.record)
stream_name = record_msg.stream
stream_batch = stream_batches[stream_name]
stream_batch.append(protocol_util.airbyte_record_message_to_dict(record_msg))

if len(stream_batch) >= max_batch_size:
record_batch = pa.Table.from_pylist(stream_batch)
batch_df = pd.DataFrame(stream_batch)
record_batch = pa.Table.from_pandas(batch_df)
self._process_batch(stream_name, record_batch)
progress.log_batch_written(stream_name, len(stream_batch))
stream_batch.clear()
@@ -206,21 +207,23 @@
# Type.LOG, Type.TRACE, Type.CONTROL, etc.
pass

# Add empty streams to the dictionary, so we create a destination table for it
for stream_name in self._expected_streams:
if stream_name not in stream_batches:
if DEBUG_MODE:
print(f"Stream {stream_name} has no data")
stream_batches[stream_name] = []

# We are at the end of the stream. Process whatever else is queued.
for stream_name, stream_batch in stream_batches.items():
record_batch = pa.Table.from_pylist(stream_batch)
batch_df = pd.DataFrame(stream_batch)
record_batch = pa.Table.from_pandas(batch_df)
self._process_batch(stream_name, record_batch)
progress.log_batch_written(stream_name, len(stream_batch))

all_streams = list(self._pending_batches.keys())
# Add empty streams to the streams list, so we create a destination table for it
for stream_name in self._expected_streams:
if stream_name not in all_streams:
if DEBUG_MODE:
print(f"Stream {stream_name} has no data")
all_streams.append(stream_name)

# Finalize any pending batches
for stream_name in list(self._pending_batches.keys()):
for stream_name in all_streams:
self._finalize_batches(stream_name, write_strategy=write_strategy)
progress.log_stream_finalized(stream_name)

@@ -394,3 +397,17 @@ def _get_stream_json_schema(
) -> dict[str, Any]:
"""Return the column definitions for the given stream."""
return self._get_stream_config(stream_name).stream.json_schema

def _get_stream_pyarrow_schema(
self,
stream_name: str,
) -> pa.Schema:
"""Return the column definitions for the given stream."""
return pa.schema(
fields=[
pa.field(prop_name, _get_pyarrow_type(prop_def))
for prop_name, prop_def in self._get_stream_json_schema(stream_name)[
"properties"
].items()
]
)
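
The new _get_stream_pyarrow_schema helper turns the stream's JSON schema properties into an Arrow schema via _get_pyarrow_type. A rough sketch of the idea, using a deliberately simplified, hypothetical type mapping in place of the real airbyte.types logic:

import pyarrow as pa

def _toy_pyarrow_type(prop_def: dict) -> pa.DataType:
    """Hypothetical stand-in for airbyte.types._get_pyarrow_type."""
    json_type = prop_def.get("type", "string")
    if isinstance(json_type, list):  # e.g. ["null", "string"]
        json_type = next((t for t in json_type if t != "null"), "string")
    return {
        "string": pa.string(),
        "integer": pa.int64(),
        "number": pa.float64(),
        "boolean": pa.bool_(),
    }.get(json_type, pa.string())

json_schema = {"properties": {"id": {"type": "integer"}, "name": {"type": ["null", "string"]}}}
schema = pa.schema(
    [pa.field(name, _toy_pyarrow_type(prop)) for name, prop in json_schema["properties"].items()]
)
print(schema)  # id: int64, name: string
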
68 changes: 36 additions & 32 deletions airbyte/caches/base.py
@@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, cast, final

import pandas as pd
import pyarrow as pa
import sqlalchemy
import ulid
from overrides import overrides
@@ -42,6 +41,7 @@
from collections.abc import Generator, Iterator
from pathlib import Path

import pyarrow as pa
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
@@ -545,17 +545,6 @@ def _finalize_batches(
although this is a fairly rare edge case we can ignore in V1.
"""
with self._finalizing_batches(stream_name) as batches_to_finalize:
if not batches_to_finalize:
return {}

files: list[Path] = []
# Get a list of all files to finalize from all pending batches.
for batch_handle in batches_to_finalize.values():
batch_handle = cast(FileWriterBatchHandle, batch_handle)
files += batch_handle.files
# Use the max batch ID as the batch ID for table names.
max_batch_id = max(batches_to_finalize.keys())

# Make sure the target schema and target table exist.
self._ensure_schema_exists()
final_table_name = self._ensure_final_table_exists(
@@ -567,6 +556,18 @@
raise_on_error=True,
)

if not batches_to_finalize:
# If there are no batches to finalize, return after ensuring the table exists.
return {}

files: list[Path] = []
# Get a list of all files to finalize from all pending batches.
for batch_handle in batches_to_finalize.values():
batch_handle = cast(FileWriterBatchHandle, batch_handle)
files += batch_handle.files
# Use the max batch ID as the batch ID for table names.
max_batch_id = max(batches_to_finalize.keys())

temp_table_name = self._write_files_to_new_table(
files=files,
stream_name=stream_name,
@@ -659,27 +660,25 @@ def _write_files_to_new_table(
"""
temp_table_name = self._create_table_for_loading(stream_name, batch_id)
for file_path in files:
with pa.parquet.ParquetFile(file_path) as pf:
record_batch = pf.read()
dataframe = record_batch.to_pandas()

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
message="Table does not exist after creation.",
context={
"temp_table_name": temp_table_name,
},
)

dataframe.to_sql(
temp_table_name,
self.get_sql_alchemy_url(),
schema=self.config.schema_name,
if_exists="append",
index=False,
dtype=self._get_sql_column_definitions(stream_name),
dataframe = pd.read_json(file_path, lines=True)

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
message="Table does not exist after creation.",
context={
"temp_table_name": temp_table_name,
},
)

dataframe.to_sql(
temp_table_name,
self.get_sql_alchemy_url(),
schema=self.config.schema_name,
if_exists="append",
index=False,
dtype=self._get_sql_column_definitions(stream_name),
)
return temp_table_name

@final
@@ -959,6 +958,11 @@ def register_source(
This method is called by the source when it is initialized.
"""
self._source_name = source_name
self.file_writer.register_source(
source_name,
incoming_source_catalog,
stream_names=stream_names,
)
self._ensure_schema_exists()
super().register_source(
source_name,
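
On the load side, the generic SQL cache now reads each .jsonl.gz file with pandas (gzip is inferred from the file suffix) and appends it to the temp table with to_sql. A self-contained sketch using SQLite and invented names, not the cache's real configuration:

import gzip

import orjson
import pandas as pd
from sqlalchemy import create_engine

# Write a tiny gzipped JSONL file so the example runs end to end.
with gzip.open("users_batch.jsonl.gz", "w") as f:
    for record in [{"id": 1, "name": "widget"}, {"id": 2, "name": "gadget"}]:
        f.write(orjson.dumps(record) + b"\n")

engine = create_engine("sqlite:///example_cache.db")  # stand-in for the cache's SQLAlchemy URL
dataframe = pd.read_json("users_batch.jsonl.gz", lines=True)  # compression inferred from ".gz"
dataframe.to_sql("users_temp", engine, if_exists="append", index=False)
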
34 changes: 23 additions & 11 deletions airbyte/caches/duckdb.py
@@ -11,7 +11,7 @@

from overrides import overrides

from airbyte._file_writers import ParquetWriter, ParquetWriterConfig
from airbyte._file_writers import JsonlWriter, JsonlWriterConfig
from airbyte.caches.base import SQLCacheBase, SQLCacheConfigBase
from airbyte.telemetry import CacheTelemetryInfo

@@ -24,10 +24,10 @@
)


class DuckDBCacheConfig(SQLCacheConfigBase, ParquetWriterConfig):
class DuckDBCacheConfig(SQLCacheConfigBase, JsonlWriterConfig):
"""Configuration for the DuckDB cache.
Also inherits config from the ParquetWriter, which is responsible for writing files to disk.
Also inherits config from the JsonlWriter, which is responsible for writing files to disk.
"""

db_path: Path | str
@@ -88,7 +88,7 @@ class DuckDBCache(DuckDBCacheBase):
so we insert as values instead.
"""

file_writer_class = ParquetWriter
file_writer_class = JsonlWriter

# TODO: Delete or rewrite this method after DuckDB adds support for primary key inspection.
# @overrides
@@ -181,12 +181,22 @@ def _write_files_to_new_table(
stream_name=stream_name,
batch_id=batch_id,
)
columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
]
columns_list_str = indent("\n, ".join(columns_list), " ")
columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
columns_list_str = indent(
"\n, ".join([self._quote_identifier(c) for c in columns_list]),
" ",
)
files_list = ", ".join([f"'{f!s}'" for f in files])
columns_type_map = indent(
"\n, ".join(
[
f"{self._quote_identifier(c)}: "
f"{self._get_sql_column_definitions(stream_name)[c]!s}"
for c in columns_list
]
),
" ",
)
insert_statement = dedent(
f"""
INSERT INTO {self.config.schema_name}.{temp_table_name}
@@ -195,9 +205,11 @@
)
SELECT
{columns_list_str}
FROM read_parquet(
FROM read_json_auto(
[{files_list}],
union_by_name = true
format = 'newline_delimited',
union_by_name = true,
columns = {{ { columns_type_map } }}
)
"""
)
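
For DuckDB, the files are loaded in SQL through read_json_auto with an explicit column-to-type map instead of read_parquet. A hypothetical, hard-coded version of the kind of statement the code above assembles (table, column, and file names are invented, and the column map follows the documented quoted-type form rather than the dynamically built one):

import gzip

import duckdb
import orjson

# Create a small gzipped JSONL file to load (illustrative data only).
with gzip.open("users_batch.jsonl.gz", "w") as f:
    for record in [{"id": 1, "name": "widget"}, {"id": 2, "name": "gadget"}]:
        f.write(orjson.dumps(record) + b"\n")

con = duckdb.connect()
con.execute("CREATE TABLE users_temp (id BIGINT, name VARCHAR)")
con.execute(
    """
    INSERT INTO users_temp (id, name)
    SELECT id, name
    FROM read_json_auto(
        ['users_batch.jsonl.gz'],
        format = 'newline_delimited',
        union_by_name = true,
        columns = {id: 'BIGINT', name: 'VARCHAR'}
    )
    """
)
print(con.execute("SELECT * FROM users_temp").fetchall())  # [(1, 'widget'), (2, 'gadget')]
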