Fix: Resolve issue where read() would fail if it received unexpected/undeclared top-level properties in a stream #131

Merged Mar 27, 2024 (55 commits)

Commits
6778dd7
add back ci tests for 3.9 and 3.11
aaronsteers Mar 5, 2024
636c70f
try with explicit pydantic Field() syntax
aaronsteers Mar 5, 2024
bb73c7c
disable lint rule
aaronsteers Mar 5, 2024
6a8d3ec
use legacy Optional in type hints
aaronsteers Mar 5, 2024
5c7e82c
update pydantic model properties
aaronsteers Mar 5, 2024
82010d3
fix duckdb type
aaronsteers Mar 5, 2024
4e3b02b
add to test files: 'from __future__ import annotations'
aaronsteers Mar 5, 2024
3a23e7c
add to integ tests: 'from __future__ import annotations'
aaronsteers Mar 5, 2024
48cd156
Merge remote-tracking branch 'origin/main' into aj/add-windows-supprt
aaronsteers Mar 5, 2024
a303ecf
ci: Add Windows to test matrix
aaronsteers Mar 5, 2024
6350c7a
support windows venv execution
aaronsteers Mar 5, 2024
cc4658d
fix test
aaronsteers Mar 5, 2024
b9a6fb9
fix other refs
aaronsteers Mar 5, 2024
3de0114
update test fixtures
aaronsteers Mar 5, 2024
f5023b6
skip postgres tests on windows
aaronsteers Mar 5, 2024
e30d752
Merge branch 'main' into aj/add-windows-supprt
aaronsteers Mar 5, 2024
8b42cb4
find executable on windows
aaronsteers Mar 5, 2024
1001e40
improve 'which' check logic
aaronsteers Mar 5, 2024
6e164c3
poetry add --dev airbyte-source-pokeapi
aaronsteers Mar 14, 2024
86722f4
remove unused import
aaronsteers Mar 14, 2024
88af72c
update from main
aaronsteers Mar 18, 2024
90d78f7
use windows-style separator if needed
aaronsteers Mar 18, 2024
4524d75
use os-specific separator for path
aaronsteers Mar 18, 2024
c581b6b
fix temp file deletion issue
aaronsteers Mar 18, 2024
a030413
remove dupe test
aaronsteers Mar 18, 2024
f65895c
simplify test
aaronsteers Mar 18, 2024
b2f014d
fix install/uninstall tests
aaronsteers Mar 18, 2024
9500220
fix connector path
aaronsteers Mar 18, 2024
6e7e5b7
fix pathsep
aaronsteers Mar 18, 2024
de46973
fix failing test
aaronsteers Mar 18, 2024
ab74e10
lint fix
aaronsteers Mar 18, 2024
bce1867
handle http error
aaronsteers Mar 18, 2024
6384b3d
fix exception import
aaronsteers Mar 18, 2024
c3acdad
skip pg tests on windows
aaronsteers Mar 19, 2024
20d5d84
escape windows paths
aaronsteers Mar 19, 2024
42d3ae6
fix interpreter_path
aaronsteers Mar 19, 2024
23d1eec
ignore cleanup errors
aaronsteers Mar 19, 2024
70080ce
fix temp dir access
aaronsteers Mar 19, 2024
c488139
Merge remote-tracking branch 'origin/main' into aj/add-windows-supprt
aaronsteers Mar 19, 2024
5b2236a
Merge remote-tracking branch 'origin/main' into aj/add-poke-api-integ…
aaronsteers Mar 19, 2024
adf9304
add pokeapi test
aaronsteers Mar 19, 2024
769b01e
fix renamed fixture
aaronsteers Mar 19, 2024
7a0275b
drop postgres fixture on windows
aaronsteers Mar 19, 2024
2acf111
Merge remote-tracking branch 'origin/aj/add-windows-supprt' into aj/a…
aaronsteers Mar 19, 2024
a0ed05d
Merge branch 'main' into aj/add-poke-api-integ-tests
aaronsteers Mar 19, 2024
14e30d9
fix test
aaronsteers Mar 19, 2024
c9885ba
update docstring
aaronsteers Mar 19, 2024
bd64605
quote bigquery identifiers
aaronsteers Mar 19, 2024
320b6a6
use json type for array and object
aaronsteers Mar 19, 2024
b0b57fb
handle pruning of extra properties in streams; make get_stream_json_s…
aaronsteers Mar 20, 2024
413f364
fix prune of extra columns
aaronsteers Mar 22, 2024
d98f6dd
Merge branch 'main' into aj/add-poke-api-integ-tests
aaronsteers Mar 22, 2024
bd5d3a6
Merge branch 'main' into aj/add-poke-api-integ-tests
aaronsteers Mar 27, 2024
735da0c
`poetry lock`
aaronsteers Mar 27, 2024
ce8a178
skip flaky test in ci
aaronsteers Mar 27, 2024
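In short, this PR makes the read/write path schema-aware: top-level record properties that are not declared in the stream's JSON Schema are pruned before records are written to the cache, instead of causing read() to fail. A minimal, self-contained sketch of that pruning idea (illustrative names only, not the library's actual API):

from typing import Any

def prune_undeclared_fields(
    record: dict[str, Any],
    stream_schema: dict[str, Any],
) -> dict[str, Any]:
    """Drop top-level keys that are not declared in the stream's JSON Schema."""
    declared = stream_schema.get("properties", {})
    return {key: value for key, value in record.items() if key in declared}

# A connector emits an undeclared "surprise" field; it is dropped before caching.
schema = {"properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}
record = {"id": 1, "name": "bulbasaur", "surprise": "not in schema"}
assert prune_undeclared_fields(record, schema) == {"id": 1, "name": "bulbasaur"}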
17 changes: 15 additions & 2 deletions airbyte/_processors/base.py
@@ -145,6 +145,7 @@ def process_input_stream(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.

@@ -167,11 +168,23 @@ def process_airbyte_messages(
context={"write_strategy": write_strategy},
)

stream_schemas: dict[str, dict] = {}

# Process messages, writing to batches as we go
for message in messages:
if message.type is Type.RECORD:
record_msg = cast(AirbyteRecordMessage, message.record)
self.process_record_message(record_msg)
stream_name = record_msg.stream

if stream_name not in stream_schemas:
stream_schemas[stream_name] = self.cache.processor.get_stream_json_schema(
stream_name=stream_name
)

self.process_record_message(
record_msg,
stream_schema=stream_schemas[stream_name],
)

elif message.type is Type.STATE:
state_msg = cast(AirbyteStateMessage, message.state)
@@ -248,7 +261,7 @@ def _get_stream_config(
return self._catalog_manager.get_stream_config(stream_name)

@final
def _get_stream_json_schema(
def get_stream_json_schema(
self,
stream_name: str,
) -> dict[str, Any]:
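The hunk above resolves each stream's schema once and reuses it for every later record of that stream, so schema lookups do not repeat per record. A rough sketch of that caching pattern with simplified stand-in types (the callables are placeholders, not the processor's real interface):

from typing import Any, Callable, Iterable

def process_record_messages(
    messages: Iterable[dict[str, Any]],
    get_schema: Callable[[str], dict[str, Any]],
    write_record: Callable[[dict[str, Any], dict[str, Any]], None],
) -> None:
    """Fetch each stream's schema on first use, then reuse it for later records."""
    stream_schemas: dict[str, dict[str, Any]] = {}
    for message in messages:
        stream_name = message["stream"]
        if stream_name not in stream_schemas:
            stream_schemas[stream_name] = get_schema(stream_name)
        write_record(message["data"], stream_schemas[stream_name])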
8 changes: 7 additions & 1 deletion airbyte/_processors/file/base.py
@@ -31,6 +31,7 @@ class FileWriterBase(abc.ABC):
"""A generic base implementation for a file-based cache."""

default_cache_file_suffix: str = ".batch"
prune_extra_fields: bool = False

MAX_BATCH_SIZE: int = DEFAULT_BATCH_SIZE

@@ -140,6 +141,7 @@ def cleanup_all(self) -> None:
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.

@@ -165,7 +167,11 @@ def process_record_message(
raise exc.AirbyteLibInternalError(message="Expected open file writer.")

self._write_record_dict(
record_dict=airbyte_record_message_to_dict(record_message=record_msg),
record_dict=airbyte_record_message_to_dict(
record_message=record_msg,
stream_schema=stream_schema,
prune_extra_fields=self.prune_extra_fields,
),
open_file_writer=batch_handle.open_file_writer,
)
batch_handle.increment_record_count()
1 change: 1 addition & 0 deletions airbyte/_processors/file/jsonl.py
@@ -23,6 +23,7 @@ class JsonlWriter(FileWriterBase):
"""A Jsonl cache implementation."""

default_cache_file_suffix = ".jsonl.gz"
prune_extra_fields = True

def _open_new_file(
self,
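The JSONL writer opts in to pruning through a class attribute that the base file writer consults at write time; other writers keep the default of False. A hedged sketch of that opt-in pattern (class names are illustrative, not the real hierarchy):

class FileWriterSketch:
    """Base writer: extra-field pruning is off unless a subclass opts in."""

    prune_extra_fields: bool = False


class JsonlWriterSketch(FileWriterSketch):
    """JSONL writer: prune undeclared fields so cached files match the schema."""

    prune_extra_fields = True


assert JsonlWriterSketch().prune_extra_fields is True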
19 changes: 14 additions & 5 deletions airbyte/_processors/sql/base.py
@@ -241,6 +241,7 @@ def get_pandas_dataframe(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
) -> None:
"""Write a record to the cache.

@@ -249,7 +250,10 @@ def process_record_message(
In most cases, the SQL processor will not perform any action, but will pass this along to
the file processor.
"""
self.file_writer.process_record_message(record_msg)
self.file_writer.process_record_message(
record_msg,
stream_schema=stream_schema,
)

# Protected members (non-public interface):

@@ -419,7 +423,7 @@ def _ensure_compatible_table_schema(

Returns true if the table is compatible, false if it is not.
"""
json_schema = self._get_stream_json_schema(stream_name)
json_schema = self.get_stream_json_schema(stream_name)
stream_column_names: list[str] = json_schema["properties"].keys()
table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys()

Expand Down Expand Up @@ -461,12 +465,12 @@ def _create_table(
"""
_ = self._execute_sql(cmd)

def _get_stream_properties(
def get_stream_properties(
self,
stream_name: str,
) -> dict[str, dict]:
"""Return the names of the top-level properties for the given stream."""
return self._get_stream_json_schema(stream_name)["properties"]
return self.get_stream_json_schema(stream_name)["properties"]

@final
def _get_sql_column_definitions(
@@ -475,7 +479,7 @@ def _get_sql_column_definitions(
) -> dict[str, sqlalchemy.types.TypeEngine]:
"""Return the column definitions for the given stream."""
columns: dict[str, sqlalchemy.types.TypeEngine] = {}
properties = self._get_stream_properties(stream_name)
properties = self.get_stream_properties(stream_name)
for property_name, json_schema_property_def in properties.items():
clean_prop_name = self.normalizer.normalize(property_name)
columns[clean_prop_name] = self.type_converter.to_sql_type(
@@ -630,6 +634,11 @@ def _write_files_to_new_table(
for file_path in files:
dataframe = pd.read_json(file_path, lines=True)

# Remove fields that are not in the schema
for col_name in dataframe.columns:
if col_name not in self.get_stream_properties(stream_name):
dataframe = dataframe.drop(columns=col_name)

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
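The DataFrame column pruning added above can also be expressed as a single drop call over all undeclared columns at once. A hedged equivalent, assuming the declared properties are available as a plain dict:

import pandas as pd

def drop_undeclared_columns(
    dataframe: pd.DataFrame,
    stream_properties: dict[str, dict],
) -> pd.DataFrame:
    """Drop DataFrame columns that are not declared in the stream's properties."""
    undeclared = [col for col in dataframe.columns if col not in stream_properties]
    return dataframe.drop(columns=undeclared)

df = pd.DataFrame([{"id": 1, "name": "pikachu", "extra": "x"}])
properties = {"id": {"type": "integer"}, "name": {"type": "string"}}
print(drop_undeclared_columns(df, properties).columns.tolist())  # ['id', 'name']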
8 changes: 6 additions & 2 deletions airbyte/_processors/sql/bigquery.py
@@ -75,8 +75,12 @@ def _fully_qualified(
@final
@overrides
def _quote_identifier(self, identifier: str) -> str:
"""Return the identifier name as is. BigQuery does not require quoting identifiers"""
return f"{identifier}"
"""Return the identifier name.

BigQuery uses backticks to quote identifiers. Because BigQuery is case-sensitive for quoted
identifiers, we convert the identifier to lowercase before quoting it.
"""
return f"`{identifier.lower()}`"

def _write_files_to_new_table(
self,
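For reference, the new BigQuery quoting lower-cases the identifier and wraps it in backticks. A quick standalone illustration of what the method now returns (the surrounding class is omitted):

def quote_bigquery_identifier(identifier: str) -> str:
    """Lower-case the identifier and wrap it in backticks for BigQuery."""
    return f"`{identifier.lower()}`"

assert quote_bigquery_identifier("MyColumn") == "`mycolumn`"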
2 changes: 1 addition & 1 deletion airbyte/_processors/sql/duckdb.py
@@ -84,7 +84,7 @@ def _write_files_to_new_table(
stream_name=stream_name,
batch_id=batch_id,
)
properties_list = list(self._get_stream_properties(stream_name).keys())
properties_list = list(self.get_stream_properties(stream_name).keys())
columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
columns_list_str = indent(
"\n, ".join([self._quote_identifier(c) for c in columns_list]),
2 changes: 1 addition & 1 deletion airbyte/_processors/sql/snowflake.py
@@ -68,7 +68,7 @@ def path_str(path: Path) -> str:
[f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};" for file_path in files]
)
self._execute_sql(put_files_statements)
properties_list: list[str] = list(self._get_stream_properties(stream_name).keys())
properties_list: list[str] = list(self.get_stream_properties(stream_name).keys())
columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
39 changes: 35 additions & 4 deletions airbyte/_util/protocol_util.py
@@ -21,41 +21,72 @@

def airbyte_messages_to_record_dicts(
messages: Iterable[AirbyteMessage],
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> Iterator[dict[str, Any]]:
"""Convert an AirbyteMessage to a dictionary."""
yield from (
cast(dict[str, Any], airbyte_message_to_record_dict(message))
cast(
dict[str, Any],
airbyte_message_to_record_dict(
message,
stream_schema=stream_schema,
prune_extra_fields=prune_extra_fields,
),
)
for message in messages
if message is not None and message.type == Type.RECORD
)


def airbyte_message_to_record_dict(message: AirbyteMessage) -> dict[str, Any] | None:
def airbyte_message_to_record_dict(
message: AirbyteMessage,
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> dict[str, Any] | None:
"""Convert an AirbyteMessage to a dictionary.

Return None if the message is not a record message.
"""
if message.type != Type.RECORD:
return None

return airbyte_record_message_to_dict(message.record)
return airbyte_record_message_to_dict(
message.record,
stream_schema=stream_schema,
prune_extra_fields=prune_extra_fields,
)


def airbyte_record_message_to_dict(
record_message: AirbyteRecordMessage,
stream_schema: dict,
*,
prune_extra_fields: bool = False,
) -> dict[str, Any]:
"""Convert an AirbyteMessage to a dictionary.

Return None if the message is not a record message.
"""
result = record_message.data

if prune_extra_fields:
if not stream_schema or "properties" not in stream_schema:
raise exc.AirbyteLibInternalError(
message="A valid `stream_schema` is required when `prune_extra_fields` is `True`."
)
for prop_name in list(result.keys()):
if prop_name not in stream_schema["properties"]:
result.pop(prop_name)

# TODO: Add the metadata columns (this breaks tests)
# result["_airbyte_extracted_at"] = datetime.datetime.fromtimestamp(
# record_message.emitted_at
# )

return result # noqa: RET504 # unnecessary assignment and then return (see TODO above)
return result


def get_primary_keys_from_stream(
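Putting the protocol_util changes together: when prune_extra_fields is set, a missing or property-less schema is treated as an internal error, and any undeclared top-level keys are popped from the record dict. A standalone sketch of that flow (a plain RuntimeError stands in for the library's AirbyteLibInternalError):

from typing import Any, Optional

def record_message_to_dict_sketch(
    record_data: dict[str, Any],
    stream_schema: Optional[dict[str, Any]],
    *,
    prune_extra_fields: bool = False,
) -> dict[str, Any]:
    """Validate the schema when pruning is requested, then drop undeclared keys."""
    result = dict(record_data)
    if prune_extra_fields:
        if not stream_schema or "properties" not in stream_schema:
            raise RuntimeError(
                "A valid `stream_schema` is required when `prune_extra_fields` is `True`."
            )
        for prop_name in list(result.keys()):
            if prop_name not in stream_schema["properties"]:
                result.pop(prop_name)
    return result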
27 changes: 27 additions & 0 deletions airbyte/sources/base.py
@@ -50,6 +50,8 @@
if TYPE_CHECKING:
from collections.abc import Generator, Iterable, Iterator

from airbyte_protocol.models.airbyte_protocol import AirbyteStream

from airbyte._executor import Executor
from airbyte.caches import CacheBase
from airbyte.documents import Document
@@ -368,6 +370,29 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
],
)

def get_stream_json_schema(self, stream_name: str) -> dict[str, Any]:
"""Return the JSON Schema spec for the specified stream name."""
catalog: AirbyteCatalog = self.discovered_catalog
found: list[AirbyteStream] = [
stream for stream in catalog.streams if stream.name == stream_name
]

if len(found) == 0:
raise exc.AirbyteLibInputError(
message="Stream name does not exist in catalog.",
input_value=stream_name,
)

if len(found) > 1:
raise exc.AirbyteLibInternalError(
message="Duplicate streams found with the same name.",
context={
"found_streams": found,
},
)

return found[0].json_schema

def get_records(self, stream: str) -> LazyDataset:
"""Read a stream from the connector.

@@ -415,6 +440,8 @@ def _with_logging(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]
normalize_records(
records=protocol_util.airbyte_messages_to_record_dicts(
self._read_with_catalog(configured_catalog),
stream_schema=self.get_stream_json_schema(stream),
prune_extra_fields=True,
),
expected_keys=all_properties,
)
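The new Source.get_stream_json_schema helper is essentially a guarded lookup against the discovered catalog: exactly one stream must match the requested name. A simplified sketch using plain dicts in place of the protocol models:

from typing import Any

def get_stream_json_schema_sketch(
    catalog_streams: list[dict[str, Any]],
    stream_name: str,
) -> dict[str, Any]:
    """Return the JSON Schema of the single stream matching `stream_name`."""
    found = [stream for stream in catalog_streams if stream["name"] == stream_name]
    if not found:
        raise ValueError(f"Stream name does not exist in catalog: {stream_name!r}")
    if len(found) > 1:
        raise RuntimeError(f"Duplicate streams found with the same name: {stream_name!r}")
    return found[0]["json_schema"]

streams = [{"name": "pokemon", "json_schema": {"properties": {"id": {"type": "integer"}}}}]
print(get_stream_json_schema_sketch(streams, "pokemon"))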
6 changes: 2 additions & 4 deletions airbyte/types.py
@@ -122,11 +122,9 @@ def to_sql_type(
return sqlalchemy.types.TIMESTAMP()

if json_schema_type == "array":
# TODO: Implement array type conversion.
return self.get_failover_type()
return sqlalchemy.types.JSON()

if json_schema_type == "object":
# TODO: Implement object type handling.
return self.get_failover_type()
return sqlalchemy.types.JSON()

return self.get_failover_type()
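With the types.py change, array and object JSON Schema types now map to a SQLAlchemy JSON column instead of falling back to the failover (string) type. A hedged sketch of the mapping, independent of the library's converter class:

import sqlalchemy

def to_sql_type_sketch(json_schema_type: str) -> sqlalchemy.types.TypeEngine:
    """Map a JSON Schema type name to a SQLAlchemy column type."""
    if json_schema_type in ("array", "object"):
        return sqlalchemy.types.JSON()
    # Failover for unrecognized types (the real converter handles more cases).
    return sqlalchemy.types.VARCHAR()

print(to_sql_type_sketch("array"))    # JSON
print(to_sql_type_sketch("mystery"))  # VARCHAR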