handle pruning of extra properties in streams; make get_stream_json_schema a public method
aaronsteers committed Mar 20, 2024
1 parent 320b6a6 commit b0b57fb
Showing 8 changed files with 99 additions and 14 deletions.
16 changes: 14 additions & 2 deletions airbyte/_processors/base.py
@@ -145,6 +145,7 @@ def process_input_stream(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.
@@ -167,11 +168,22 @@ def process_airbyte_messages(
context={"write_strategy": write_strategy},
)

stream_schemas: dict[str, dict] = {}

# Process messages, writing to batches as we go
for message in messages:
if message.type is Type.RECORD:
stream_name: str = message.record.stream
if stream_name not in stream_schemas:
    stream_schemas[stream_name] = self.cache.processor.get_stream_json_schema(
        stream_name=stream_name
    )

record_msg = cast(AirbyteRecordMessage, message.record)
self.process_record_message(record_msg)
self.process_record_message(
record_msg,
stream_schema=stream_schemas[stream_name],
)

elif message.type is Type.STATE:
state_msg = cast(AirbyteStateMessage, message.state)
@@ -248,7 +260,7 @@ def _get_stream_config(
return self._catalog_manager.get_stream_config(stream_name)

@final
def _get_stream_json_schema(
def get_stream_json_schema(
self,
stream_name: str,
) -> dict[str, Any]:
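The heart of this change is the per-stream schema cache in process_airbyte_messages: each stream's JSON schema is fetched once and then handed to process_record_message for every record in that stream. A minimal standalone sketch of the same caching pattern, with FakeRecord, get_schema, and write_record as invented stand-ins rather than PyAirbyte APIs:

from dataclasses import dataclass
from typing import Callable, Iterable


@dataclass
class FakeRecord:
    stream: str
    data: dict


def process_records(
    records: Iterable[FakeRecord],
    get_schema: Callable[[str], dict],
    write_record: Callable[[dict, dict], None],
) -> None:
    stream_schemas: dict[str, dict] = {}
    for record in records:
        if record.stream not in stream_schemas:
            # Fetch each stream's JSON schema only once, then reuse it per record.
            stream_schemas[record.stream] = get_schema(record.stream)
        write_record(record.data, stream_schemas[record.stream])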
8 changes: 7 additions & 1 deletion airbyte/_processors/file/base.py
@@ -31,6 +31,7 @@ class FileWriterBase(abc.ABC):
"""A generic base implementation for a file-based cache."""

default_cache_file_suffix: str = ".batch"
prune_extra_fields: bool = False

MAX_BATCH_SIZE: int = DEFAULT_BATCH_SIZE

@@ -140,6 +141,7 @@ def cleanup_all(self) -> None:
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.
@@ -165,7 +167,11 @@ def process_record_message(
raise exc.AirbyteLibInternalError(message="Expected open file writer.")

self._write_record_dict(
record_dict=airbyte_record_message_to_dict(record_message=record_msg),
record_dict=airbyte_record_message_to_dict(
record_message=record_msg,
stream_schema=stream_schema,
prune_extra_fields=self.prune_extra_fields,
),
open_file_writer=batch_handle.open_file_writer,
)
batch_handle.increment_record_count()
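prune_extra_fields is a class-level switch on FileWriterBase, and a concrete writer opts in by overriding it (JsonlWriter does exactly that in the next file). A hypothetical subclass just to show where the flag lives; StrictWriter and its suffix are invented, and a real writer would also implement the file-opening and record-writing hooks:

from airbyte._processors.file.base import FileWriterBase


class StrictWriter(FileWriterBase):
    """Hypothetical writer that drops record fields not declared in the stream schema."""

    default_cache_file_suffix = ".strict.jsonl"
    prune_extra_fields = True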
1 change: 1 addition & 0 deletions airbyte/_processors/file/jsonl.py
@@ -23,6 +23,7 @@ class JsonlWriter(FileWriterBase):
"""A Jsonl cache implementation."""

default_cache_file_suffix = ".jsonl.gz"
prune_extra_fields = True

def _open_new_file(
self,
19 changes: 14 additions & 5 deletions airbyte/_processors/sql/base.py
@@ -241,6 +241,7 @@ def get_pandas_dataframe(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.
@@ -249,7 +250,10 @@ def process_record_message(
In most cases, the SQL processor will not perform any action, but will pass this along to
the file processor.
"""
self.file_writer.process_record_message(record_msg)
self.file_writer.process_record_message(
record_msg,
stream_schema=stream_schema,
)

# Protected members (non-public interface):

@@ -419,7 +423,7 @@ def _ensure_compatible_table_schema(
Returns true if the table is compatible, false if it is not.
"""
json_schema = self._get_stream_json_schema(stream_name)
json_schema = self.get_stream_json_schema(stream_name)
stream_column_names: list[str] = json_schema["properties"].keys()
table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys()

@@ -461,12 +465,12 @@ def _create_table(
"""
_ = self._execute_sql(cmd)

def _get_stream_properties(
def get_stream_properties(
self,
stream_name: str,
) -> dict[str, dict]:
"""Return the names of the top-level properties for the given stream."""
return self._get_stream_json_schema(stream_name)["properties"]
return self.get_stream_json_schema(stream_name)["properties"]

@final
def _get_sql_column_definitions(
@@ -475,7 +479,7 @@ def _get_sql_column_definitions(
) -> dict[str, sqlalchemy.types.TypeEngine]:
"""Return the column definitions for the given stream."""
columns: dict[str, sqlalchemy.types.TypeEngine] = {}
properties = self._get_stream_properties(stream_name)
properties = self.get_stream_properties(stream_name)
for property_name, json_schema_property_def in properties.items():
clean_prop_name = self.normalizer.normalize(property_name)
columns[clean_prop_name] = self.type_converter.to_sql_type(
@@ -630,6 +634,11 @@ def _write_files_to_new_table(
for file_path in files:
dataframe = pd.read_json(file_path, lines=True)

# Remove fields that are not in the schema
for col_name in dataframe.columns:
    if col_name not in self.get_stream_properties(stream_name):
        dataframe = dataframe.drop(columns=col_name)

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
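The new dataframe step in _write_files_to_new_table drops any column that is not among the stream's declared properties before loading the batch. The same filtering in isolation, using only pandas and a made-up schema and batch:

import pandas as pd

stream_properties = {"id": {"type": "integer"}, "name": {"type": "string"}}
df = pd.DataFrame([{"id": 1, "name": "a", "debug_info": "not in schema"}])

# Keep only the columns declared in the stream's JSON schema properties.
df = df.drop(columns=[col for col in df.columns if col not in stream_properties])
print(list(df.columns))  # ['id', 'name']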
2 changes: 1 addition & 1 deletion airbyte/_processors/sql/duckdb.py
@@ -84,7 +84,7 @@ def _write_files_to_new_table(
stream_name=stream_name,
batch_id=batch_id,
)
properties_list = list(self._get_stream_properties(stream_name).keys())
properties_list = list(self.get_stream_properties(stream_name).keys())
columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
columns_list_str = indent(
"\n, ".join([self._quote_identifier(c) for c in columns_list]),
2 changes: 1 addition & 1 deletion airbyte/_processors/sql/snowflake.py
@@ -68,7 +68,7 @@ def path_str(path: Path) -> str:
[f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};" for file_path in files]
)
self._execute_sql(put_files_statements)
properties_list: list[str] = list(self._get_stream_properties(stream_name).keys())
properties_list: list[str] = list(self.get_stream_properties(stream_name).keys())
columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
39 changes: 35 additions & 4 deletions airbyte/_util/protocol_util.py
@@ -21,41 +21,72 @@

def airbyte_messages_to_record_dicts(
messages: Iterable[AirbyteMessage],
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> Iterator[dict[str, Any]]:
"""Convert an AirbyteMessage to a dictionary."""
yield from (
cast(dict[str, Any], airbyte_message_to_record_dict(message))
cast(
dict[str, Any],
airbyte_message_to_record_dict(
message,
stream_schema=stream_schema,
prune_extra_fields=prune_extra_fields,
),
)
for message in messages
if message is not None and message.type == Type.RECORD
)


def airbyte_message_to_record_dict(message: AirbyteMessage) -> dict[str, Any] | None:
def airbyte_message_to_record_dict(
message: AirbyteMessage,
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> dict[str, Any] | None:
"""Convert an AirbyteMessage to a dictionary.
Return None if the message is not a record message.
"""
if message.type != Type.RECORD:
return None

return airbyte_record_message_to_dict(message.record)
return airbyte_record_message_to_dict(
message.record,
stream_schema=stream_schema,
prune_extra_fields=prune_extra_fields,
)


def airbyte_record_message_to_dict(
record_message: AirbyteRecordMessage,
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> dict[str, Any]:
"""Convert an AirbyteMessage to a dictionary.
Return None if the message is not a record message.
"""
result = record_message.data

if prune_extra_fields:
if not stream_schema or "properties" not in stream_schema:
raise exc.AirbyteLibInternalError(
message="A valid `stream_schema` is required when `prune_extra_fields` is `True`."
)
for prop_name in list(result.keys()):
    if prop_name not in stream_schema["properties"]:
        result.pop(prop_name)

# TODO: Add the metadata columns (this breaks tests)
# result["_airbyte_extracted_at"] = datetime.datetime.fromtimestamp(
# record_message.emitted_at
# )

return result # noqa: RET504 # unnecessary assignment and then return (see TODO above)
return result


def get_primary_keys_from_stream(
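The pruning path added here can be exercised on its own. A usage sketch, assuming PyAirbyte at this commit and the airbyte-protocol models are installed; the stream name, schema, and record data are invented:

from airbyte_protocol.models.airbyte_protocol import AirbyteRecordMessage

from airbyte._util import protocol_util

schema = {"properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}
record = AirbyteRecordMessage(
    stream="users",
    data={"id": 1, "name": "ada", "unexpected": "dropped when pruning"},
    emitted_at=0,
)

row = protocol_util.airbyte_record_message_to_dict(
    record,
    stream_schema=schema,
    prune_extra_fields=True,
)
print(row)  # {'id': 1, 'name': 'ada'}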
26 changes: 26 additions & 0 deletions airbyte/sources/base.py
@@ -7,6 +7,7 @@
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Any, cast

from airbyte_protocol.models.airbyte_protocol import AirbyteStream
import jsonschema
import pendulum
import yaml
@@ -292,6 +293,29 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
],
)

def get_stream_json_schema(self, stream_name: str) -> dict[str, Any]:
"""Return the JSON Schema spec for the specified stream name."""
catalog: AirbyteCatalog = self.discovered_catalog
found: list[AirbyteStream] = [
stream for stream in catalog.streams if stream.name == stream_name
]

if len(found) == 0:
raise exc.AirbyteLibInputError(
message="Stream name does not exist in catalog.",
input_value=stream_name,
)

if len(found) > 1:
raise exc.AirbyteLibInternalError(
message="Duplicate streams found with the same name.",
context={
"found_streams": found,
},
)

return found[0].json_schema

def get_records(self, stream: str) -> LazyDataset:
"""Read a stream from the connector.
Expand Down Expand Up @@ -339,6 +363,8 @@ def _with_logging(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]
normalize_records(
records=protocol_util.airbyte_messages_to_record_dicts(
self._read_with_catalog(configured_catalog),
stream_schema=self.get_stream_json_schema(stream),
prune_extra_fields=True,
),
expected_keys=all_properties,
)
Expand Down
