Skip to content

Commit

Permalink
Chore: Clean up missing column additions logic (#243)
Browse files (browse the repository at this point in the history)
Co-authored-by: octavia-squidington-iii <[email protected]>
  • Loading branch information
aaronsteers and octavia-squidington-iii authored May 17, 2024
1 parent e82d37c commit 3f161ee
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 76 deletions.
83 changes: 41 additions & 42 deletions airbyte/_future_cdk/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,40 +410,20 @@ def _ensure_final_table_exists(
def _ensure_compatible_table_schema(
self,
stream_name: str,
*,
raise_on_error: bool = False,
) -> bool:
table_name: str,
) -> None:
"""Return true if the given table is compatible with the stream's schema.
If raise_on_error is true, raise an exception if the table is not compatible.
TODO: Expand this to check for column types and sizes, and to add missing columns.
Raises an exception if the table schema is not compatible with the schema of the
input stream.
Returns true if the table is compatible, false if it is not.
TODO:
- Expand this to check for column types and sizes.
"""
json_schema = self.catalog_provider.get_stream_json_schema(stream_name)
stream_column_names: list[str] = json_schema["properties"].keys()
table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys()

lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
missing_columns = [
stream_col
for stream_col in stream_column_names
if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
]
if missing_columns:
if raise_on_error:
raise exc.PyAirbyteCacheTableValidationError(
violation="Cache table is missing expected columns.",
context={
"stream_column_names": stream_column_names,
"table_column_names": table_column_names,
"missing_columns": missing_columns,
},
)
return False # Some columns are missing.

return True # All columns exist.
self._add_missing_columns_to_table(
stream_name=stream_name,
table_name=table_name,
)

@final
def _create_table(
Expand Down Expand Up @@ -509,10 +489,6 @@ def write_stream_data(
stream_name,
create_if_missing=True,
)
self._ensure_compatible_table_schema(
stream_name=stream_name,
raise_on_error=True,
)

if not batches_to_finalize:
# If there are no batches to finalize, return after ensuring the table exists.
Expand Down Expand Up @@ -675,14 +651,35 @@ def _add_missing_columns_to_table(
stream_name: str,
table_name: str,
) -> None:
"""Add missing columns to the table."""
"""Add missing columns to the table.
This is a no-op if all columns are already present.
"""
columns = self._get_sql_column_definitions(stream_name)
table = self._get_table_by_name(table_name, force_refresh=True)
for column_name, column_type in columns.items():
if column_name not in table.columns:
self._add_column_to_table(table, column_name, column_type)
# First check without forcing a refresh of the cache (faster). If nothing is missing,
# then we're done.
table = self._get_table_by_name(
table_name,
force_refresh=False,
)
missing_columns: bool = any(column_name not in table.columns for column_name in columns)

self._invalidate_table_cache(table_name)
if missing_columns:
# If we found missing columns, refresh the cache and then take action on anything
# that's still confirmed missing.
columns_added = False
table = self._get_table_by_name(
table_name,
force_refresh=True,
)
for column_name, column_type in columns.items():
if column_name not in table.columns:
self._add_column_to_table(table, column_name, column_type)
columns_added = True

if columns_added:
# We've added columns, so invalidate the cache.
self._invalidate_table_cache(table_name)

@final
def _write_temp_table_to_final_table(
Expand Down Expand Up @@ -712,6 +709,8 @@ def _write_temp_table_to_final_table(
write_strategy = WriteStrategy.REPLACE

if write_strategy == WriteStrategy.REPLACE:
# Note: No need to check for schema compatibility
# here, because we are fully replacing the table.
self._swap_temp_table_with_final_table(
stream_name=stream_name,
temp_table_name=temp_table_name,
Expand All @@ -720,7 +719,7 @@ def _write_temp_table_to_final_table(
return

if write_strategy == WriteStrategy.APPEND:
self._add_missing_columns_to_table(
self._ensure_compatible_table_schema(
stream_name=stream_name,
table_name=final_table_name,
)
Expand All @@ -732,7 +731,7 @@ def _write_temp_table_to_final_table(
return

if write_strategy == WriteStrategy.MERGE:
self._add_missing_columns_to_table(
self._ensure_compatible_table_schema(
stream_name=stream_name,
table_name=final_table_name,
)
Expand Down
34 changes: 0 additions & 34 deletions airbyte/_processors/sql/snowflakecortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from overrides import overrides
from sqlalchemy import text

from airbyte import exceptions as exc
from airbyte._processors.sql.snowflake import (
SnowflakeConfig,
SnowflakeSqlProcessor,
Expand Down Expand Up @@ -104,39 +103,6 @@ def _get_column_list_from_table(
conn.close()
return column_names

@overrides
def _ensure_compatible_table_schema(
self,
stream_name: str,
*,
raise_on_error: bool = True,
) -> bool:
"""Read the existing table schema using Snowflake python connector"""
json_schema = self.catalog_provider.get_stream_json_schema(stream_name)
stream_column_names: list[str] = json_schema["properties"].keys()
table_column_names: list[str] = self._get_column_list_from_table(stream_name)

lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
missing_columns = [
stream_col
for stream_col in stream_column_names
if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
]
# TODO: shouldn't we just return false here, so missing tables can be created ?
if missing_columns:
if raise_on_error:
raise exc.PyAirbyteCacheTableValidationError(
violation="Cache table is missing expected columns.",
context={
"stream_column_names": stream_column_names,
"table_column_names": table_column_names,
"missing_columns": missing_columns,
},
)
return False # Some columns are missing.

return True # All columns exist.

@overrides
def _write_files_to_new_table(
self,
Expand Down

0 comments on commit 3f161ee

Please sign in to comment.