diff --git a/airbyte/_processors/sql/bigquery.py b/airbyte/_processors/sql/bigquery.py
index 87116fd4..7ebfac7d 100644
--- a/airbyte/_processors/sql/bigquery.py
+++ b/airbyte/_processors/sql/bigquery.py
@@ -196,3 +196,36 @@ def _get_tables_list(
             table.replace(schema_prefix, "", 1) if table.startswith(schema_prefix) else table
             for table in tables
         ]
+
+    def _swap_temp_table_with_final_table(
+        self,
+        stream_name: str,
+        temp_table_name: str,
+        final_table_name: str,
+    ) -> None:
+        """Swap the temp table with the main one, dropping the old version of the 'final' table.
+
+        The BigQuery RENAME implementation requires that the table schema (dataset) is named in the
+        first part of the ALTER statement, but not in the second part.
+
+        For example, BigQuery expects this format:
+
+        ALTER TABLE my_schema.my_old_table_name RENAME TO my_new_table_name;
+        """
+        if final_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'final_table_name' cannot be None.")
+        if temp_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'temp_table_name' cannot be None.")
+
+        _ = stream_name
+        deletion_name = f"{final_table_name}_deleteme"
+        commands = "\n".join(
+            [
+                f"ALTER TABLE {self._fully_qualified(final_table_name)} "
+                f"RENAME TO {deletion_name};",
+                f"ALTER TABLE {self._fully_qualified(temp_table_name)} "
+                f"RENAME TO {final_table_name};",
+                f"DROP TABLE {self._fully_qualified(deletion_name)};",
+            ]
+        )
+        self._execute_sql(commands)
diff --git a/airbyte/_util/google_secrets.py b/airbyte/_util/google_secrets.py
new file mode 100644
index 00000000..7ff426dc
--- /dev/null
+++ b/airbyte/_util/google_secrets.py
@@ -0,0 +1,48 @@
+"""Helpers for accessing Google secrets."""
+
+from __future__ import annotations
+
+import json
+import os
+
+from google.cloud import secretmanager
+
+
+def get_gcp_secret(
+    project_name: str,
+    secret_name: str,
+) -> str:
+    """Try to get a GCP secret from the environment, or raise an error.
+
+    We assume that the Google service account credentials file contents are stored in the
+    environment variable GCP_GSM_CREDENTIALS. If this environment variable is not set, we raise an
+    error. Otherwise, we use the Google Secret Manager API to fetch the secret with the given name.
+    """
+    if "GCP_GSM_CREDENTIALS" not in os.environ:
+        raise EnvironmentError(  # noqa: TRY003, UP024
+            "GCP_GSM_CREDENTIALS env variable not set, can't fetch secrets. Make sure they are set "
+            "up as described: "
+            "https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/"
+            "README.md#get-gsm-access"
+        )
+
+    # load secrets from GSM using the GCP_GSM_CREDENTIALS env variable
+    secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
+        json.loads(os.environ["GCP_GSM_CREDENTIALS"])
+    )
+    return secret_client.access_secret_version(
+        name=f"projects/{project_name}/secrets/{secret_name}/versions/latest"
+    ).payload.data.decode("UTF-8")
+
+
+def get_gcp_secret_json(
+    project_name: str,
+    secret_name: str,
+) -> dict:
+    """Get a JSON GCP secret and return as a dict.
+
+    We assume that the Google service account credentials file contents are stored in the
+    environment variable GCP_GSM_CREDENTIALS. If this environment variable is not set, we raise an
+    error. Otherwise, we use the Google Secret Manager API to fetch the secret with the given name.
+ """ + return json.loads(get_gcp_secret(secret_name, project_name)) diff --git a/airbyte/caches/bigquery.py b/airbyte/caches/bigquery.py index 6bc6f677..993d1ae5 100644 --- a/airbyte/caches/bigquery.py +++ b/airbyte/caches/bigquery.py @@ -18,8 +18,10 @@ from __future__ import annotations import urllib +from typing import Any from overrides import overrides +from pydantic import root_validator from airbyte._processors.sql.bigquery import BigQuerySqlProcessor from airbyte.caches.base import ( @@ -41,9 +43,14 @@ class BigQueryCache(CacheBase): _sql_processor_class: type[BigQuerySqlProcessor] = BigQuerySqlProcessor - def __post_init__(self) -> None: - """Initialize the BigQuery cache.""" - self.schema_name = self.dataset_name + @root_validator(pre=True) + @classmethod + def set_schema_name(cls, values: dict[str, Any]) -> dict[str, Any]: + dataset_name = values.get("dataset_name") + if dataset_name is None: + raise ValueError("dataset_name must be defined") # noqa: TRY003 + values["schema_name"] = dataset_name + return values @overrides def get_database_name(self) -> str: diff --git a/examples/run_bigquery_faker.py b/examples/run_bigquery_faker.py index f4ac4922..eb1f7139 100644 --- a/examples/run_bigquery_faker.py +++ b/examples/run_bigquery_faker.py @@ -7,29 +7,20 @@ from __future__ import annotations -import json -import os import tempfile import warnings -from google.cloud import secretmanager - import airbyte as ab +from airbyte._util.google_secrets import get_gcp_secret_json from airbyte.caches.bigquery import BigQueryCache warnings.filterwarnings("ignore", message="Cannot create BigQuery Storage client") -# load secrets from GSM using the GCP_GSM_CREDENTIALS env variable -secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info( - json.loads(os.environ["GCP_GSM_CREDENTIALS"]) -) - -bigquery_destination_secret = json.loads( - secret_client.access_secret_version( - name="projects/dataline-integration-testing/secrets/SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS/versions/latest" - ).payload.data.decode("UTF-8") +bigquery_destination_secret = get_gcp_secret_json( + project_name="dataline-integration-testing", + secret_name="SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS", ) @@ -55,6 +46,9 @@ def main() -> None: result = source.read(cache) + # Read a second time to make sure table swaps and incremental are working. + result = source.read(cache) + for name, records in result.streams.items(): print(f"Stream {name}: {len(records)} records") diff --git a/examples/run_integ_test_source.py b/examples/run_integ_test_source.py index 4e7f472a..51fa1de7 100644 --- a/examples/run_integ_test_source.py +++ b/examples/run_integ_test_source.py @@ -11,14 +11,13 @@ """ from __future__ import annotations -import json -import os import sys -from typing import Any - -from google.cloud import secretmanager import airbyte as ab +from airbyte._util.google_secrets import get_gcp_secret_json + + +GCP_SECRETS_PROJECT_NAME = "dataline-integration-testing" def get_secret_name(connector_name: str) -> str: @@ -35,31 +34,15 @@ def get_secret_name(connector_name: str) -> str: return f"SECRET_{connector_name.upper()}_CREDS" -def get_integ_test_config(secret_name: str) -> dict[str, Any]: - if "GCP_GSM_CREDENTIALS" not in os.environ: - raise Exception( # noqa: TRY002, TRY003 - f"GCP_GSM_CREDENTIALS env var not set, can't fetch secrets for '{connector_name}'. 
" - "Make sure they are set up as described: " - "https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/" - "README.md#get-gsm-access" - ) - - secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info( - json.loads(os.environ["GCP_GSM_CREDENTIALS"]) - ) - return json.loads( - secret_client.access_secret_version( - name=f"projects/dataline-integration-testing/secrets/{secret_name}/versions/latest" - ).payload.data.decode("UTF-8") - ) - - def main( connector_name: str, secret_name: str | None, streams: list[str] | None, ) -> None: - config = get_integ_test_config(secret_name) + config = get_gcp_secret_json( + secret_name=secret_name, + project_name=GCP_SECRETS_PROJECT_NAME, + ) source = ab.get_source( connector_name, config=config, diff --git a/examples/run_snowflake_faker.py b/examples/run_snowflake_faker.py index 93a386f0..d0ea5910 100644 --- a/examples/run_snowflake_faker.py +++ b/examples/run_snowflake_faker.py @@ -1,12 +1,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. from __future__ import annotations -import json -import os - -from google.cloud import secretmanager - import airbyte as ab +from airbyte._util.google_secrets import get_gcp_secret_json from airbyte.caches import SnowflakeCache @@ -16,14 +12,9 @@ install_if_missing=True, ) -# load secrets from GSM using the GCP_GSM_CREDENTIALS env variable -secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info( - json.loads(os.environ["GCP_GSM_CREDENTIALS"]) -) -secret = json.loads( - secret_client.access_secret_version( - name="projects/dataline-integration-testing/secrets/AIRBYTE_LIB_SNOWFLAKE_CREDS/versions/latest" - ).payload.data.decode("UTF-8") +secret = get_gcp_secret_json( + project_name="dataline-integration-testing", + secret_name="AIRBYTE_LIB_SNOWFLAKE_CREDS", ) cache = SnowflakeCache( diff --git a/tests/conftest.py b/tests/conftest.py index f2fa114a..1dbad159 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ """Global pytest fixtures.""" from __future__ import annotations +from contextlib import suppress import json import logging import os @@ -13,21 +14,23 @@ from requests.exceptions import HTTPError import ulid +from airbyte._util.google_secrets import get_gcp_secret from airbyte._util.meta import is_windows from airbyte.caches.base import CacheBase +from airbyte.caches.bigquery import BigQueryCache +from airbyte.caches.duckdb import DuckDBCache from airbyte.caches.snowflake import SnowflakeCache import docker import psycopg2 as psycopg import pytest from _pytest.nodes import Item -from google.cloud import secretmanager from sqlalchemy import create_engine from airbyte.caches import PostgresCache from airbyte._executor import _get_bin_dir from airbyte.caches.util import new_local_cache -from tests.integration_tests.test_source_faker_integration import all_cache_types +from airbyte.sources.base import as_temp_files logger = logging.getLogger(__name__) @@ -38,6 +41,23 @@ LOCAL_TEST_REGISTRY_URL = "./tests/integration_tests/fixtures/registry.json" +AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing" + + +def get_ci_secret( + secret_name, + project_name: str = AIRBYTE_INTERNAL_GCP_PROJECT, +) -> str: + return get_gcp_secret(project_name=project_name, secret_name=secret_name) + + +def get_ci_secret_json( + secret_name, + project_name: str = AIRBYTE_INTERNAL_GCP_PROJECT, +) -> dict: + return json.loads(get_ci_secret(secret_name=secret_name, project_name=project_name)) + + def 
 def pytest_collection_modifyitems(items: list[Item]) -> None:
     """Override default pytest behavior, sorting our tests in a sensible execution order.
@@ -193,15 +213,8 @@ def new_postgres_cache():
 
 @pytest.fixture
 def new_snowflake_cache():
-    if "GCP_GSM_CREDENTIALS" not in os.environ:
-        raise Exception("GCP_GSM_CREDENTIALS env variable not set, can't fetch secrets for Snowflake. Make sure they are set up as described: https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/README.md#get-gsm-access")
-    secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
-        json.loads(os.environ["GCP_GSM_CREDENTIALS"])
-    )
-    secret = json.loads(
-        secret_client.access_secret_version(
-            name="projects/dataline-integration-testing/secrets/AIRBYTE_LIB_SNOWFLAKE_CREDS/versions/latest"
-        ).payload.data.decode("UTF-8")
+    secret = get_ci_secret_json(
+        "AIRBYTE_LIB_SNOWFLAKE_CREDS",
     )
     config = SnowflakeCache(
         account=secret["account"],
@@ -220,6 +233,30 @@ def new_snowflake_cache():
         connection.execute(f"DROP SCHEMA IF EXISTS {config.schema_name}")
 
 
+@pytest.fixture
+@pytest.mark.requires_creds
+def new_bigquery_cache():
+    dest_bigquery_config = get_ci_secret_json(
+        "SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS"
+    )
+
+    dataset_name = f"test_deleteme_{str(ulid.ULID()).lower()[-6:]}"
+    credentials_json = dest_bigquery_config["credentials_json"]
+    with as_temp_files([credentials_json]) as (credentials_path,):
+        cache = BigQueryCache(
+            credentials_path=credentials_path,
+            project_name=dest_bigquery_config["project_id"],
+            dataset_name=dataset_name
+        )
+        yield cache
+
+        url = cache.get_sql_alchemy_url()
+        engine = create_engine(url)
+        with suppress(Exception):
+            with engine.begin() as connection:
+                connection.execute(f"DROP SCHEMA IF EXISTS {cache.schema_name}")
+
+
 @pytest.fixture(autouse=True)
 def source_test_registry(monkeypatch):
     """
@@ -268,3 +305,36 @@ def source_test_installation():
     yield
 
     shutil.rmtree(venv_dir)
+
+
+@pytest.fixture(scope="function")
+def new_duckdb_cache() -> DuckDBCache:
+    return new_local_cache()
+
+
+@pytest.fixture(scope="function")
+def new_generic_cache(request) -> CacheBase:
+    """This is a placeholder fixture that will be overridden by pytest_generate_tests()."""
+    return request.getfixturevalue(request.param)
+
+
+def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
+    """Override default pytest behavior, parameterizing our tests based on the available cache types.
+
+    This is useful for running the same tests with different cache types, to ensure that the tests
+    can pass across all cache types.
+ """ + all_cache_type_fixtures: dict[str, str] = { + "BigQuery": "new_bigquery_cache", + "DuckDB": "new_duckdb_cache", + "Postgres": "new_pg_cache", + "Snowflake": "new_snowflake_cache", + } + if "new_generic_cache" in metafunc.fixturenames: + metafunc.parametrize( + "new_generic_cache", + all_cache_type_fixtures.values(), + ids=all_cache_type_fixtures.keys(), + indirect=True, + scope="function", + ) diff --git a/tests/integration_tests/test_snowflake_cache.py b/tests/integration_tests/test_all_cache_types.py similarity index 84% rename from tests/integration_tests/test_snowflake_cache.py rename to tests/integration_tests/test_all_cache_types.py index 962ddd69..d2462186 100644 --- a/tests/integration_tests/test_snowflake_cache.py +++ b/tests/integration_tests/test_all_cache_types.py @@ -7,7 +7,6 @@ """ from __future__ import annotations -from collections.abc import Generator import os from pathlib import Path import sys @@ -15,9 +14,9 @@ import pytest import airbyte as ab -from airbyte import caches from airbyte._executor import _get_bin_dir + # Product count is always the same, regardless of faker scale. NUM_PRODUCTS = 100 @@ -87,13 +86,13 @@ def source_faker_seed_b() -> ab.Source: #@viztracer.trace_and_save(output_dir=".pytest_cache/snowflake_trace/") @pytest.mark.requires_creds @pytest.mark.slow -def test_faker_read_to_snowflake( +def test_faker_read( source_faker_seed_a: ab.Source, - new_snowflake_cache: ab.SnowflakeCache, + new_generic_cache: ab.caches.CacheBase, ) -> None: """Test that the append strategy works as expected.""" result = source_faker_seed_a.read( - new_snowflake_cache, write_strategy="replace", force_full_refresh=True + new_generic_cache, write_strategy="replace", force_full_refresh=True ) assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A @@ -102,12 +101,12 @@ def test_faker_read_to_snowflake( @pytest.mark.slow def test_replace_strategy( source_faker_seed_a: ab.Source, - new_snowflake_cache: ab.SnowflakeCache, + new_generic_cache: ab.caches.CacheBase, ) -> None: """Test that the append strategy works as expected.""" for _ in range(2): result = source_faker_seed_a.read( - new_snowflake_cache, write_strategy="replace", force_full_refresh=True + new_generic_cache, write_strategy="replace", force_full_refresh=True ) assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A @@ -117,27 +116,31 @@ def test_replace_strategy( def test_merge_strategy( source_faker_seed_a: ab.Source, source_faker_seed_b: ab.Source, - new_snowflake_cache: ab.caches.SnowflakeCache, + new_generic_cache: ab.caches.CacheBase, ) -> None: """Test that the merge strategy works as expected. Since all streams have primary keys, we should expect the auto strategy to be identical to the merge strategy. """ + + assert new_generic_cache, "Cache should not be None." + # First run, seed A (counts should match the scale or the product count) - result = source_faker_seed_a.read(new_snowflake_cache, write_strategy="merge") - assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A + result = source_faker_seed_a.read(new_generic_cache, write_strategy="merge") + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A, \ + f"Incorrect number of records in the cache. 
{new_generic_cache}" # Second run, also seed A (should have same exact data, no change in counts) - result = source_faker_seed_a.read(new_snowflake_cache, write_strategy="merge") + result = source_faker_seed_a.read(new_generic_cache, write_strategy="merge") assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A # Third run, seed B - should increase record count to the scale of B, which is greater than A. # TODO: See if we can reliably predict the exact number of records, since we use fixed seeds. - result = source_faker_seed_b.read(new_snowflake_cache, write_strategy="merge") + result = source_faker_seed_b.read(new_generic_cache, write_strategy="merge") assert len(list(result.cache.streams["users"])) == FAKER_SCALE_B # Third run, seed A again - count should stay at scale B, since A is smaller. # TODO: See if we can reliably predict the exact number of records, since we use fixed seeds. - result = source_faker_seed_a.read(new_snowflake_cache, write_strategy="merge") + result = source_faker_seed_a.read(new_generic_cache, write_strategy="merge") assert len(list(result.cache.streams["users"])) == FAKER_SCALE_B diff --git a/tests/unit_tests/test_bigquery_cache.py b/tests/unit_tests/test_bigquery_cache.py new file mode 100644 index 00000000..391263f9 --- /dev/null +++ b/tests/unit_tests/test_bigquery_cache.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""Unit tests specific to BigQuery caches.""" +from __future__ import annotations + +import airbyte as ab + + +def test_bigquery_props( + new_bigquery_cache: ab.BigQueryCache, +) -> None: + """Test that the BigQueryCache properties are set correctly.""" + # assert new_bigquery_cache.credentials_path.endswith(".json") + assert new_bigquery_cache.dataset_name == new_bigquery_cache.schema_name, \ + "Dataset name should be the same as schema name." + assert new_bigquery_cache.schema_name != "airbyte_raw" \ + "Schema name should not be the default value." + + assert new_bigquery_cache.get_database_name() == new_bigquery_cache.project_name, \ + "Database name should be the same as project name."