diff --git a/airbyte/_future_cdk/sql_processor.py b/airbyte/_future_cdk/sql_processor.py
index a3ec16cf..0212eac4 100644
--- a/airbyte/_future_cdk/sql_processor.py
+++ b/airbyte/_future_cdk/sql_processor.py
@@ -410,40 +410,20 @@ def _ensure_final_table_exists(
     def _ensure_compatible_table_schema(
         self,
         stream_name: str,
-        *,
-        raise_on_error: bool = False,
-    ) -> bool:
-        """Return true if the given table is compatible with the stream's schema.
+        table_name: str,
+    ) -> None:
+        """Ensure that the given table is compatible with the stream's schema.
 
-        If raise_on_error is true, raise an exception if the table is not compatible.
-
-        TODO: Expand this to check for column types and sizes, and to add missing columns.
+        Raises an exception if the table schema is not compatible with the schema of the
+        input stream.
 
-        Returns true if the table is compatible, false if it is not.
+        TODO:
+        - Expand this to check for column types and sizes.
         """
-        json_schema = self.catalog_provider.get_stream_json_schema(stream_name)
-        stream_column_names: list[str] = json_schema["properties"].keys()
-        table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys()
-
-        lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
-        missing_columns = [
-            stream_col
-            for stream_col in stream_column_names
-            if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
-        ]
-        if missing_columns:
-            if raise_on_error:
-                raise exc.PyAirbyteCacheTableValidationError(
-                    violation="Cache table is missing expected columns.",
-                    context={
-                        "stream_column_names": stream_column_names,
-                        "table_column_names": table_column_names,
-                        "missing_columns": missing_columns,
-                    },
-                )
-            return False  # Some columns are missing.
-
-        return True  # All columns exist.
+        self._add_missing_columns_to_table(
+            stream_name=stream_name,
+            table_name=table_name,
+        )
 
     @final
     def _create_table(
@@ -509,10 +489,6 @@ def write_stream_data(
             stream_name,
             create_if_missing=True,
         )
-        self._ensure_compatible_table_schema(
-            stream_name=stream_name,
-            raise_on_error=True,
-        )
 
         if not batches_to_finalize:
             # If there are no batches to finalize, return after ensuring the table exists.
@@ -675,14 +651,35 @@ def _add_missing_columns_to_table(
         stream_name: str,
         table_name: str,
     ) -> None:
-        """Add missing columns to the table."""
+        """Add missing columns to the table.
+
+        This is a no-op if all columns are already present.
+        """
         columns = self._get_sql_column_definitions(stream_name)
-        table = self._get_table_by_name(table_name, force_refresh=True)
-        for column_name, column_type in columns.items():
-            if column_name not in table.columns:
-                self._add_column_to_table(table, column_name, column_type)
+        # First check without forcing a refresh of the cache (faster). If nothing is missing,
+        # then we're done.
+        table = self._get_table_by_name(
+            table_name,
+            force_refresh=False,
+        )
+        missing_columns: bool = any(column_name not in table.columns for column_name in columns)
 
-        self._invalidate_table_cache(table_name)
+        if missing_columns:
+            # If we found missing columns, refresh the cache and then take action on anything
+            # that's still confirmed missing.
+            columns_added = False
+            table = self._get_table_by_name(
+                table_name,
+                force_refresh=True,
+            )
+            for column_name, column_type in columns.items():
+                if column_name not in table.columns:
+                    self._add_column_to_table(table, column_name, column_type)
+                    columns_added = True
+
+            if columns_added:
+                # We've added columns, so invalidate the cache.
+                self._invalidate_table_cache(table_name)
 
     @final
     def _write_temp_table_to_final_table(
@@ -712,6 +709,8 @@ def _write_temp_table_to_final_table(
             write_strategy = WriteStrategy.REPLACE
 
         if write_strategy == WriteStrategy.REPLACE:
+            # Note: No need to check for schema compatibility
+            # here, because we are fully replacing the table.
             self._swap_temp_table_with_final_table(
                 stream_name=stream_name,
                 temp_table_name=temp_table_name,
@@ -720,7 +719,7 @@
             return
 
         if write_strategy == WriteStrategy.APPEND:
-            self._add_missing_columns_to_table(
+            self._ensure_compatible_table_schema(
                 stream_name=stream_name,
                 table_name=final_table_name,
             )
@@ -732,7 +731,7 @@
             return
 
         if write_strategy == WriteStrategy.MERGE:
-            self._add_missing_columns_to_table(
+            self._ensure_compatible_table_schema(
                 stream_name=stream_name,
                 table_name=final_table_name,
             )
diff --git a/airbyte/_processors/sql/snowflakecortex.py b/airbyte/_processors/sql/snowflakecortex.py
index 6e700b87..5fb0fb6a 100644
--- a/airbyte/_processors/sql/snowflakecortex.py
+++ b/airbyte/_processors/sql/snowflakecortex.py
@@ -10,7 +10,6 @@
 from overrides import overrides
 from sqlalchemy import text
 
-from airbyte import exceptions as exc
 from airbyte._processors.sql.snowflake import (
     SnowflakeConfig,
     SnowflakeSqlProcessor,
@@ -104,39 +103,6 @@ def _get_column_list_from_table(
             conn.close()
         return column_names
 
-    @overrides
-    def _ensure_compatible_table_schema(
-        self,
-        stream_name: str,
-        *,
-        raise_on_error: bool = True,
-    ) -> bool:
-        """Read the existing table schema using Snowflake python connector"""
-        json_schema = self.catalog_provider.get_stream_json_schema(stream_name)
-        stream_column_names: list[str] = json_schema["properties"].keys()
-        table_column_names: list[str] = self._get_column_list_from_table(stream_name)
-
-        lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
-        missing_columns = [
-            stream_col
-            for stream_col in stream_column_names
-            if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
-        ]
-        # TODO: shouldn't we just return false here, so missing tables can be created ?
-        if missing_columns:
-            if raise_on_error:
-                raise exc.PyAirbyteCacheTableValidationError(
-                    violation="Cache table is missing expected columns.",
-                    context={
-                        "stream_column_names": stream_column_names,
-                        "table_column_names": table_column_names,
-                        "missing_columns": missing_columns,
-                    },
-                )
-            return False  # Some columns are missing.
-
-        return True  # All columns exist.
-
     @overrides
     def _write_files_to_new_table(
         self,