Chore: Bump to Sqlalchemy 2.0 #396

Merged · 12 commits · Sep 23, 2024
9 changes: 5 additions & 4 deletions airbyte/_processors/sql/bigquery.py
@@ -5,7 +5,7 @@

import warnings
from pathlib import Path
from typing import TYPE_CHECKING, final
from typing import TYPE_CHECKING, cast, final

import google.oauth2
import sqlalchemy
@@ -14,6 +14,7 @@
from google.oauth2 import service_account
from overrides import overrides
from pydantic import Field
from sqlalchemy import types as sqlalchemy_types
from sqlalchemy.engine import make_url

from airbyte import exceptions as exc
@@ -98,7 +99,7 @@ class BigQueryTypeConverter(SQLTypeConverter):
@classmethod
def get_string_type(cls) -> sqlalchemy.types.TypeEngine:
"""Return the string type for BigQuery."""
return "String"
return cast(sqlalchemy.types.TypeEngine, "String") # BigQuery uses STRING for all strings

@overrides
def to_sql_type(
@@ -115,9 +116,9 @@ def to_sql_type(
if isinstance(sql_type, sqlalchemy.types.VARCHAR):
return self.get_string_type()
if isinstance(sql_type, sqlalchemy.types.BIGINT):
return "INT64"
return sqlalchemy_types.Integer() # All integers are 64-bit in BigQuery

return sql_type.__class__.__name__
return sql_type


class BigQuerySqlProcessor(SqlProcessorBase):
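Taken together, the BigQuery converter changes above follow the SQLAlchemy 2.0 expectation that type converters hand back `TypeEngine` values (or values cast to that type for the checker) rather than bare strings such as `"INT64"`. A minimal sketch of the pattern, using an illustrative function rather than PyAirbyte's actual converter class:

```python
from sqlalchemy import types as sqlalchemy_types


def to_sql_type_sketch(json_type: str) -> sqlalchemy_types.TypeEngine:
    """Illustrative mapping from a JSON-schema type name to a SQLAlchemy type instance."""
    if json_type == "string":
        return sqlalchemy_types.String()
    if json_type == "integer":
        # BigQuery stores every integer as INT64, so one generic integer type suffices.
        return sqlalchemy_types.BigInteger()
    # Fall back to a JSON column for anything unrecognized.
    return sqlalchemy_types.JSON()
```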
21 changes: 11 additions & 10 deletions airbyte/_processors/sql/snowflake.py
@@ -4,14 +4,15 @@
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent, indent
from textwrap import indent
from typing import TYPE_CHECKING

import sqlalchemy
from overrides import overrides
from pydantic import Field
from snowflake import connector
from snowflake.sqlalchemy import URL, VARIANT
from sqlalchemy import text

from airbyte import exceptions as exc
from airbyte._writers.jsonl import JsonlWriter
@@ -146,8 +147,7 @@ def upload_file(file_path: Path) -> None:
files_list = ", ".join([f"'{f.name}'" for f in files])
columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)
variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list])
copy_statement = dedent(
f"""
copy_statement = f"""
COPY INTO {temp_table_name}
(
{columns_list_str}
@@ -160,8 +160,7 @@
FILE_FORMAT = ( TYPE = JSON, COMPRESSION = GZIP )
;
"""
)
self._execute_sql(copy_statement)
self._execute_sql(text(copy_statement))
return temp_table_name

@overrides
@@ -175,9 +174,11 @@ def _init_connection_settings(self, connection: Connection) -> None:
This also sets MULTI_STATEMENT_COUNT to 0, which allows multi-statement commands.
"""
connection.execute(
"""
ALTER SESSION SET
QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE
MULTI_STATEMENT_COUNT = 0
"""
text(
"""
ALTER SESSION SET
QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE
MULTI_STATEMENT_COUNT = 0
"""
)
)
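Both Snowflake edits above come down to the same SQLAlchemy 2.0 rule: `Connection.execute()` no longer accepts a raw SQL string, so hand-written statements have to be wrapped in `text()`. A self-contained sketch of that rule (SQLite is used here only so the snippet runs without a Snowflake account):

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:", future=True)

with engine.connect() as connection:
    # 2.0-style connections reject plain strings; every raw statement goes through text().
    connection.execute(text("CREATE TABLE demo (id INTEGER)"))
    connection.execute(text("INSERT INTO demo (id) VALUES (1)"))
    connection.commit()  # explicit commit: 2.0 connections no longer autocommit
    count = connection.execute(text("SELECT COUNT(*) FROM demo")).scalar()  # -> 1
```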
7 changes: 3 additions & 4 deletions airbyte/caches/_catalog_backend.py
@@ -13,8 +13,7 @@
from typing import TYPE_CHECKING

from sqlalchemy import Column, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, declarative_base

from airbyte_protocol.models import (
AirbyteStream,
@@ -127,7 +126,7 @@ def __init__(

def _ensure_internal_tables(self) -> None:
engine = self._engine
SqlAlchemyModel.metadata.create_all(engine)
SqlAlchemyModel.metadata.create_all(engine) # type: ignore[attr-defined]

def register_source(
self,
@@ -232,7 +231,7 @@ def _fetch_streams_info(
ConfiguredAirbyteStream(
stream=AirbyteStream(
name=stream.stream_name,
json_schema=json.loads(stream.catalog_metadata),
json_schema=json.loads(stream.catalog_metadata), # type: ignore[arg-type]
supported_sync_modes=[SyncMode.full_refresh],
),
sync_mode=SyncMode.full_refresh,
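The import change here tracks SQLAlchemy's relocation of `declarative_base()` from `sqlalchemy.ext.declarative` to `sqlalchemy.orm`; the old location is deprecated under 2.0. A short sketch of the 2.0-style declarative pattern, using an illustrative model rather than the real catalog table:

```python
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base

Base = declarative_base()  # imported from sqlalchemy.orm, not sqlalchemy.ext.declarative


class ExampleStream(Base):  # illustrative model, not PyAirbyte's catalog table
    __tablename__ = "example_streams"
    stream_name = Column(String, primary_key=True)


engine = create_engine("sqlite:///:memory:", future=True)
Base.metadata.create_all(engine)  # mirrors _ensure_internal_tables() above
```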
13 changes: 6 additions & 7 deletions airbyte/caches/_state_backend.py
@@ -9,8 +9,7 @@

from pytz import utc
from sqlalchemy import Column, DateTime, PrimaryKeyConstraint, String, and_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, declarative_base

from airbyte_protocol.models import (
AirbyteStateMessage,
@@ -39,11 +38,11 @@
GLOBAL_STATE_STREAM_NAMES = [GLOBAL_STATE_STREAM_NAME, LEGACY_STATE_STREAM_NAME]


SqlAlchemyModel = declarative_base()
SqlAlchemyModel: type = declarative_base()
"""A base class to use for SQLAlchemy ORM models."""


class CacheStreamStateModel(SqlAlchemyModel): # type: ignore[valid-type,misc]
class CacheStreamStateModel(SqlAlchemyModel): # type: ignore[misc]
"""A SQLAlchemy ORM model to store state metadata for internal caches."""

__tablename__ = CACHE_STATE_TABLE_NAME
@@ -66,7 +65,7 @@ class CacheStreamStateModel(SqlAlchemyModel): # type: ignore[valid-type,misc]
"""The last time the state was updated."""


class DestinationStreamStateModel(SqlAlchemyModel): # type: ignore[valid-type,misc]
class DestinationStreamStateModel(SqlAlchemyModel): # type: ignore[misc]
"""A SQLAlchemy ORM model to store state metadata for destinations.

This is a separate table from the cache state table. The destination state table
@@ -127,7 +126,7 @@ def _write_state(
stream_name = GLOBAL_STATE_STREAM_NAME
if state_message.type == AirbyteStateType.LEGACY:
stream_name = LEGACY_STATE_STREAM_NAME
elif state_message.type == AirbyteStateType.STREAM:
elif state_message.type == AirbyteStateType.STREAM and state_message.stream:
stream_name = state_message.stream.stream_descriptor.name
else:
raise PyAirbyteInternalError(
@@ -199,7 +198,7 @@
def _ensure_internal_tables(self) -> None:
"""Ensure the internal tables exist in the SQL database."""
engine = self._engine
SqlAlchemyModel.metadata.create_all(engine)
SqlAlchemyModel.metadata.create_all(engine) # type: ignore[attr-defined]

def get_state_provider(
self,
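Beyond the same `declarative_base()` relocation, this file annotates the base as `type`, which appears to be why the `valid-type` suppression can be dropped and only `misc` remains; the added `and state_message.stream` guard narrows an optional field before it is dereferenced. A minimal sketch of the annotation pattern (model and table names are illustrative):

```python
from sqlalchemy import Column, String
from sqlalchemy.orm import declarative_base

# Annotating the dynamically created base as `type` lets static checkers accept
# it as a base class, so subclasses need fewer type: ignore codes.
SqlAlchemyModel: type = declarative_base()


class StreamStateSketch(SqlAlchemyModel):  # type: ignore[misc]
    __tablename__ = "stream_state_sketch"
    stream_name = Column(String, primary_key=True)
```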
4 changes: 2 additions & 2 deletions airbyte/cloud/sync_results.py
@@ -277,7 +277,7 @@ def get_sql_cache(self) -> CacheBase:

def get_sql_engine(self) -> sqlalchemy.engine.Engine:
"""Return a SQL Engine for querying a SQL-based destination."""
self.get_sql_cache().get_sql_engine()
return self.get_sql_cache().get_sql_engine()

def get_sql_table_name(self, stream_name: str) -> str:
"""Return the SQL table name of the named stream."""
@@ -288,7 +288,7 @@ def get_sql_table(
stream_name: str,
) -> sqlalchemy.Table:
"""Return a SQLAlchemy table object for the named stream."""
self.get_sql_cache().processor.get_sql_table(stream_name)
return self.get_sql_cache().processor.get_sql_table(stream_name)

def get_dataset(self, stream_name: str) -> CachedDataset:
"""Retrieve an `airbyte.datasets.CachedDataset` object for a given stream name.
14 changes: 7 additions & 7 deletions airbyte/datasets/_sql.py
@@ -22,7 +22,7 @@
from pyarrow.dataset import Dataset
from sqlalchemy import Table
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.selectable import Selectable
from sqlalchemy.sql.expression import Select

from airbyte_protocol.models import ConfiguredAirbyteStream

@@ -40,7 +40,7 @@ def __init__(
self,
cache: CacheBase,
stream_name: str,
query_statement: Selectable,
query_statement: Select,
stream_configuration: ConfiguredAirbyteStream | None | Literal[False] = None,
) -> None:
"""Initialize the dataset with a cache, stream name, and query statement.
@@ -60,7 +60,7 @@ def __init__(
self._length: int | None = None
self._cache: CacheBase = cache
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement
self._query_statement: Select = query_statement
if stream_configuration is None:
try:
stream_configuration = cache.processor.catalog_provider.get_configured_stream_info(
@@ -72,8 +72,8 @@ def __init__(
stacklevel=2,
)

# Coalesce False to None
stream_configuration = stream_configuration or None
# Coalesce False to None
stream_configuration = stream_configuration or None

super().__init__(stream_metadata=stream_configuration)

@@ -94,11 +94,11 @@ def __len__(self) -> int:
This method caches the length of the dataset after the first call.
"""
if self._length is None:
count_query = select([func.count()]).select_from(self._query_statement.alias())
count_query = select(func.count()).select_from(self._query_statement.subquery())
with self._cache.processor.get_sql_connection() as conn:
self._length = conn.execute(count_query).scalar()

return self._length
return cast(int, self._length)

def to_pandas(self) -> DataFrame:
return self._cache.get_pandas_dataframe(self._stream_name)
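The dataset changes adopt the 2.0-style `select()` call (columns and tables passed positionally rather than in a list) and wrap the inner statement with `.subquery()` before counting over it, mirroring the `__len__` fix above. A self-contained sketch of that pattern against a throwaway SQLite table:

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, func, insert, select

metadata = MetaData()
records = Table("records", metadata, Column("id", Integer, primary_key=True))

engine = create_engine("sqlite:///:memory:", future=True)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(insert(records), [{"id": 1}, {"id": 2}])

query = select(records)  # 2.0 style: no list argument
count_query = select(func.count()).select_from(query.subquery())

with engine.connect() as conn:
    row_count = conn.execute(count_query).scalar()  # -> 2
```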
11 changes: 4 additions & 7 deletions airbyte/shared/sql_processor.py
@@ -28,7 +28,6 @@
text,
update,
)
from sqlalchemy.sql.elements import TextClause

from airbyte_protocol.models import (
AirbyteMessage,
@@ -63,6 +62,7 @@
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import TypeEngine

from airbyte._batch_handles import BatchHandle
@@ -132,6 +132,7 @@ def get_sql_engine(self) -> Engine:
execution_options={
"schema_translate_map": {None: self.schema_name},
},
future=True,
)

def get_vendor_client(self) -> object:
@@ -782,10 +783,6 @@ def _execute_sql(self, sql: str | TextClause | Executable) -> CursorResult:
"""Execute the given SQL statement."""
if isinstance(sql, str):
sql = text(sql)
if isinstance(sql, TextClause):
sql = sql.execution_options(
autocommit=True,
)

with self.get_sql_connection() as conn:
try:
@@ -852,7 +849,7 @@ def _write_files_to_new_table(
schema=self.sql_config.schema_name,
if_exists="append",
index=False,
dtype=sql_column_definitions,
dtype=sql_column_definitions, # type: ignore[arg-type]
)
return temp_table_name

@@ -1117,7 +1114,7 @@ def _emulated_merge_temp_table_to_final_table(

# Select records from temp_table that are not in final_table
select_new_records_stmt = (
select([temp_table]).select_from(joined_table).where(where_not_exists_clause)
select(temp_table).select_from(joined_table).where(where_not_exists_clause)
)

# Craft the INSERT statement using the select statement
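The remaining processor changes are engine- and transaction-level: `future=True` opts the engine into 2.0-style behavior (a no-op once on 2.0 itself, where it is the default), and the removed `autocommit` execution option is replaced by explicit transaction scope. A sketch of the resulting pattern (the URL and schema name are placeholders, not the processor's real configuration):

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "sqlite:///:memory:",  # placeholder URL; the real processor builds a vendor-specific one
    future=True,
    execution_options={"schema_translate_map": {None: "airbyte_raw"}},  # placeholder schema
)

# engine.begin() commits on success and rolls back on error, replacing the
# deprecated autocommit execution option that this PR removes.
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE example (id INTEGER)"))
```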