Fix: Resolve issues in Postgres name normalization when names are >63 characters #359

Merged
14 commits merged on Sep 9, 2024
22 changes: 22 additions & 0 deletions airbyte/_processors/sql/postgres.py
@@ -2,9 +2,11 @@
"""A Postgres implementation of the cache."""

from __future__ import annotations
from functools import lru_cache

from overrides import overrides

from airbyte._util.name_normalizers import LowerCaseNormalizer
from airbyte._writers.jsonl import JsonlWriter
from airbyte.secrets.base import SecretString
from airbyte.shared.sql_processor import SqlConfig, SqlProcessorBase
@@ -35,6 +37,23 @@ def get_database_name(self) -> str:
return self.database


class PostgresNormalizer(LowerCaseNormalizer):
"""A name normalizer for Postgres.

Postgres has specific name length limits:
- Table names are limited to 63 characters.
- Column names are limited to 63 characters.

The Postgres normalizer inherits from the default LowerCaseNormalizer class and
additionally truncates column and table names to 63 characters.
"""

@lru_cache
def normalize(self, name: str) -> str:
"""Normalize the name, truncating to 63 characters."""
return super().normalize(name)[:63]


class PostgresSqlProcessor(SqlProcessorBase):
"""A Postgres implementation of the cache.

@@ -49,3 +68,6 @@ class PostgresSqlProcessor(SqlProcessorBase):
supports_merge_insert = False
file_writer_class = JsonlWriter
sql_config: PostgresConfig

normalizer = PostgresNormalizer
"""A Postgres-specific name normalizer for table and column name normalization."""
15 changes: 13 additions & 2 deletions airbyte/shared/sql_processor.py
@@ -497,8 +497,19 @@ def _get_temp_table_name(
batch_id: str | None = None, # ULID of the batch
) -> str:
"""Return a new (unique) temporary table name."""
batch_id = batch_id or str(ulid.ULID())
return self.normalizer.normalize(f"{stream_name}_{batch_id}")
if not batch_id:
batch_id = str(ulid.ULID())

if len(batch_id) > 9:
# Use the first 6 and last 3 characters of the ULID. This gives great uniqueness while
# limiting the table name suffix to 10 characters, including the underscore.
suffix = f"{batch_id[:6]}{batch_id[-3:]}"
else:
suffix = batch_id

# Note: The normalizer may truncate the table name if the database has a name length limit.
# For instance, the Postgres normalizer will enforce a 63-character limit on table names.
return self.normalizer.normalize(f"{stream_name}_{suffix}")

def _fully_qualified(
self,
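
To make the suffix rule concrete, here is a small standalone sketch that mirrors the logic in the diff above; make_temp_suffix is a hypothetical helper for illustration only and is not part of the PR.

from __future__ import annotations

import ulid  # assumed to be the python-ulid package, matching ulid.ULID() in the diff above


def make_temp_suffix(batch_id: str | None = None) -> str:
    """Keep the first 6 and last 3 characters of the batch ID (a 26-character ULID string)."""
    batch_id = batch_id or str(ulid.ULID())
    if len(batch_id) > 9:
        # 9 characters of suffix; 10 including the joining underscore.
        return f"{batch_id[:6]}{batch_id[-3:]}"
    return batch_id


suffix = make_temp_suffix()
assert len(suffix) == 9
# The full temp table name is then normalized (and, on Postgres, truncated to 63 characters).
temp_table_name = f"my_stream_{suffix}"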
22 changes: 19 additions & 3 deletions airbyte/sources/base.py
@@ -405,9 +405,25 @@ def get_stream_json_schema(self, stream_name: str) -> dict[str, Any]:

return found[0].json_schema

def get_records(self, stream: str) -> LazyDataset:
def get_records(
self,
stream: str,
*,
normalize_field_names: bool = False,
prune_undeclared_fields: bool = True,
) -> LazyDataset:
"""Read a stream from the connector.

Args:
stream: The name of the stream to read.
normalize_field_names: When `True`, field names will be normalized to lower case, with
special characters removed. This matches the behavior of PyAirbyte caches and most
Airbyte destinations.
prune_undeclared_fields: When `True`, fields not declared in the catalog will be pruned
from the records. This generally matches the behavior of PyAirbyte caches and most
Airbyte destinations, and is especially useful when the catalog may be stale. You can
disable this to keep all fields in the records.

This involves the following steps:
* Call discover to get the catalog
* Generate a configured catalog that syncs the given stream in full_refresh mode
@@ -445,8 +461,8 @@ def _with_logging(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]

stream_record_handler = StreamRecordHandler(
json_schema=self.get_stream_json_schema(stream),
prune_extra_fields=True,
normalize_keys=False,
prune_extra_fields=prune_undeclared_fields,
normalize_keys=normalize_field_names,
)

# This method is non-blocking, so we use "PLAIN" to avoid a live progress display
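
For context on the new keyword arguments, the following usage sketch is illustrative only; the "source-faker" connector, its config, and the "users" stream are assumptions, not part of this PR.

import airbyte as ab

source = ab.get_source("source-faker", config={"count": 100})
source.check()

dataset = source.get_records(
    "users",
    normalize_field_names=True,      # lowercase names, special characters replaced
    prune_undeclared_fields=False,   # keep fields even if missing from the catalog
)
for record in dataset:
    print(record)
    break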
42 changes: 32 additions & 10 deletions tests/integration_tests/test_source_test_fixture.py
@@ -130,6 +130,15 @@ def expected_test_stream_data() -> dict[str, list[dict[str, str | int]]]:
},
],
"always-empty-stream": [],
"primary-key-with-dot": [
# Expect field names lowercase, with '.' replaced by '_':
{
"table1_column1": "value1",
"table1_column2": 1,
"table1_empty_column": None,
"table1_big_number": 1234567890123456,
}
],
}


@@ -325,7 +334,7 @@ def test_file_write_and_cleanup() -> None:

# There are four streams, but only three of them have data:
assert (
len(list(Path(temp_dir_2).glob("*.jsonl.gz"))) == 2
len(list(Path(temp_dir_2).glob("*.jsonl.gz"))) == 3
), "Expected files to exist"

with suppress(Exception):
@@ -342,21 +351,23 @@ def test_sync_to_duckdb(

result: ReadResult = source.read(cache)

assert result.processed_records == 3
assert result.processed_records == sum(
len(stream_data) for stream_data in expected_test_stream_data.values()
)
assert_data_matches_cache(expected_test_stream_data, cache)


def test_read_result_mapping():
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()
result: ReadResult = source.read(ab.new_local_cache())
assert len(result) == 3
assert len(result) == 4
assert isinstance(result, Mapping)
assert "stream1" in result
assert "stream2" in result
assert "always-empty-stream" in result
assert "stream3" not in result
assert result.keys() == {"stream1", "stream2", "always-empty-stream"}
assert result.keys() == {"stream1", "stream2", "always-empty-stream", "primary-key-with-dot"}


def test_dataset_list_and_len(expected_test_stream_data):
@@ -381,7 +392,7 @@ def test_dataset_list_and_len(expected_test_stream_data):
assert "stream2" in result
assert "always-empty-stream" in result
assert "stream3" not in result
assert result.keys() == {"stream1", "stream2", "always-empty-stream"}
assert result.keys() == {"stream1", "stream2", "always-empty-stream", "primary-key-with-dot"}


def test_read_from_cache(
@@ -456,6 +467,8 @@ def test_merge_streams_in_cache(
"""
Test that we can extend a cache with new streams
"""
expected_test_stream_data.pop("primary-key-with-dot") # Stream not needed for this test.

cache_name = str(ulid.ULID())
source = ab.get_source("source-test", config={"apiKey": "test"})
cache = ab.new_local_cache(cache_name)
@@ -552,7 +565,7 @@ def test_sync_with_merge_to_duckdb(
result: ReadResult = source.read(cache)
result: ReadResult = source.read(cache)

assert result.processed_records == 3
assert result.processed_records == 4
for stream_name, expected_data in expected_test_stream_data.items():
if len(cache[stream_name]) > 0:
pd.testing.assert_frame_equal(
@@ -713,7 +726,10 @@ def test_lazy_dataset_from_source(
for stream_name in source.get_available_streams():
assert isinstance(stream_name, str)

lazy_dataset: LazyDataset = source.get_records(stream_name)
lazy_dataset: LazyDataset = source.get_records(
stream_name,
normalize_field_names=True,
)
assert isinstance(lazy_dataset, LazyDataset)

list_data = list(lazy_dataset)
@@ -756,7 +772,9 @@ def test_sync_with_merge_to_postgres(
result: ReadResult = source.read(new_postgres_cache, write_strategy="merge")
result: ReadResult = source.read(new_postgres_cache, write_strategy="merge")

assert result.processed_records == 3
assert result.processed_records == sum(
len(stream_data) for stream_data in expected_test_stream_data.values()
)
assert_data_matches_cache(
expected_test_stream_data=expected_test_stream_data,
cache=new_postgres_cache,
@@ -780,7 +798,9 @@ def test_sync_to_postgres(

result: ReadResult = source.read(new_postgres_cache)

assert result.processed_records == 3
assert result.processed_records == sum(
len(stream_data) for stream_data in expected_test_stream_data.values()
)
for stream_name, expected_data in expected_test_stream_data.items():
if len(new_postgres_cache[stream_name]) > 0:
pd.testing.assert_frame_equal(
@@ -804,7 +824,9 @@ def test_sync_to_snowflake(

result: ReadResult = source.read(new_snowflake_cache)

assert result.processed_records == 3
assert result.processed_records == sum(
len(stream_data) for stream_data in expected_test_stream_data.values()
)
for stream_name, expected_data in expected_test_stream_data.items():
if len(new_snowflake_cache[stream_name]) > 0:
pd.testing.assert_frame_equal(