From 5a8ebf575d6bee306a8693585aca9d9b33317e0e Mon Sep 17 00:00:00 2001
From: "Aaron (\"AJ\") Steers"
Date: Fri, 8 Mar 2024 14:31:09 -0800
Subject: [PATCH 1/3] Docs: Add usage samples for each cache option (#112)

---
 airbyte/caches/bigquery.py | 21 ++++++++++++++++++++-
 airbyte/caches/duckdb.py | 18 +++++++++++++++---
 airbyte/caches/motherduck.py | 16 +++++++++++++++-
 airbyte/caches/postgres.py | 18 +++++++++++++++++++-
 airbyte/caches/snowflake.py | 20 +++++++++++++++++++-
 5 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/airbyte/caches/bigquery.py b/airbyte/caches/bigquery.py
index 3ff93257..6bc6f677 100644
--- a/airbyte/caches/bigquery.py
+++ b/airbyte/caches/bigquery.py
@@ -1,5 +1,19 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-"""A BigQuery implementation of the cache."""
+"""A BigQuery implementation of the cache.
+
+## Usage Example
+
+```python
+import airbyte as ab
+from airbyte.caches import BigQueryCache
+
+cache = BigQueryCache(
+    project_name="myproject",
+    dataset_name="mydataset",
+    credentials_path="path/to/credentials.json",
+)
+```
+"""
 
 from __future__ import annotations
 
@@ -17,8 +31,13 @@ class BigQueryCache(CacheBase):
     """The BigQuery cache implementation."""
 
     project_name: str
+    """The name of the project to use. In BigQuery, this is equivalent to the database name."""
+
     dataset_name: str = "airbyte_raw"
+    """The name of the dataset to use. In BigQuery, this is equivalent to the schema name."""
+
     credentials_path: str
+    """The path to the credentials file to use."""
 
     _sql_processor_class: type[BigQuerySqlProcessor] = BigQuerySqlProcessor
diff --git a/airbyte/caches/duckdb.py b/airbyte/caches/duckdb.py
index 91693ac5..5ed2abc2 100644
--- a/airbyte/caches/duckdb.py
+++ b/airbyte/caches/duckdb.py
@@ -1,5 +1,17 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-"""A DuckDB implementation of the cache."""
+"""A DuckDB implementation of the PyAirbyte cache.
+
+## Usage Example
+
+```python
+import airbyte as ab
+from airbyte.caches import DuckDBCache
+
+cache = DuckDBCache(
+    db_path="/path/to/my/database.duckdb",
+    schema_name="myschema",
+)
+"""
 
 from __future__ import annotations
 
@@ -27,8 +39,8 @@ class DuckDBCache(CacheBase):
 
     db_path: Union[Path, str]
     """Normally db_path is a Path object.
 
-    There are some cases, such as when connecting to MotherDuck, where it could be a string that
-    is not also a path, such as "md:" to connect the user's default MotherDuck DB.
+    The database name will be inferred from the file name. For example, given a `db_path` of
+    `/path/to/my/my_db.duckdb`, the database name is `my_db`.
     """
 
     schema_name: str = "main"
diff --git a/airbyte/caches/motherduck.py b/airbyte/caches/motherduck.py
index 417dabfb..99b599cc 100644
--- a/airbyte/caches/motherduck.py
+++ b/airbyte/caches/motherduck.py
@@ -1,4 +1,18 @@
-"""A cache implementation for the MotherDuck service, built on DuckDB."""
+"""A MotherDuck implementation of the PyAirbyte cache, built on DuckDB.
+
+## Usage Example
+
+```python
+import airbyte as ab
+from airbyte.caches import MotherDuckCache
+
+cache = MotherDuckCache(
+    database="mydatabase",
+    schema_name="myschema",
+    api_key=ab.get_secret("MOTHERDUCK_API_KEY"),
+)
+"""
+
 from __future__ import annotations
 
 from overrides import overrides
diff --git a/airbyte/caches/postgres.py b/airbyte/caches/postgres.py
index a9cd4727..5d4c33e2 100644
--- a/airbyte/caches/postgres.py
+++ b/airbyte/caches/postgres.py
@@ -1,5 +1,21 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
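The docstrings above only show how to construct each cache. As a rough end-to-end sketch of how a configured cache is then used, the snippet below passes a `DuckDBCache` to `Source.read()`; the `source-faker` connector, its `count` option, and the `users` stream are illustrative stand-ins rather than part of this patch.

```python
import airbyte as ab
from airbyte.caches import DuckDBCache

# Any of the cache classes documented above can be handed to `Source.read()`.
cache = DuckDBCache(
    db_path="./example_data.duckdb",
    schema_name="main",
)

source = ab.get_source(
    "source-faker",            # illustrative connector name
    config={"count": 100},     # illustrative connector config
    install_if_missing=True,
)
source.check()
source.select_all_streams()

# Records are written into the cache, then read back as a dataset.
result = source.read(cache=cache)
print(result["users"].to_pandas().head())
```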
-"""A Postgres implementation of the cache.""" +"""A Postgres implementation of the PyAirbyte cache. + +## Usage Example + +```python +from airbyte as ab +from airbyte.caches import PostgresCache + +cache = PostgresCache( + host="myhost", + port=5432, + username="myusername", + password=ab.get_secret("POSTGRES_PASSWORD"), + database="mydatabase", +) +``` +""" from __future__ import annotations diff --git a/airbyte/caches/snowflake.py b/airbyte/caches/snowflake.py index 993307e8..f0e55f3b 100644 --- a/airbyte/caches/snowflake.py +++ b/airbyte/caches/snowflake.py @@ -1,5 +1,23 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. -"""A Snowflake implementation of the cache.""" +"""A Snowflake implementation of the PyAirbyte cache. + +## Usage Example + +```python +from airbyte as ab +from airbyte.caches import SnowflakeCache + +cache = SnowflakeCache( + account="myaccount", + username="myusername", + password=ab.get_secret("SNOWFLAKE_PASSWORD"), + warehouse="mywarehouse", + database="mydatabase", + role="myrole", + schema_name="myschema", +) +``` +""" from __future__ import annotations From 1e45e42bbd6af2e22cd5d144b8bbbf2637e098cf Mon Sep 17 00:00:00 2001 From: "Aaron (\"AJ\") Steers" Date: Fri, 8 Mar 2024 15:58:01 -0800 Subject: [PATCH 2/3] Fix: Resolve issue where mixed-case stream properties would result in missing data (#114) --- airbyte/_processors/sql/base.py | 44 ++-- airbyte/_processors/sql/duckdb.py | 12 +- airbyte/_processors/sql/snowflake.py | 6 +- airbyte/_util/name_normalizers.py | 205 ++++++++++++++++++ airbyte/_util/text_util.py | 15 -- airbyte/sources/base.py | 26 +-- .../fixtures/source-test/source_test/run.py | 18 +- .../test_source_test_fixture.py | 10 + tests/unit_tests/test_text_normalization.py | 103 +++++++++ 9 files changed, 376 insertions(+), 63 deletions(-) create mode 100644 airbyte/_util/name_normalizers.py delete mode 100644 airbyte/_util/text_util.py create mode 100644 tests/unit_tests/test_text_normalization.py diff --git a/airbyte/_processors/sql/base.py b/airbyte/_processors/sql/base.py index 9beb0b97..c74b9a74 100644 --- a/airbyte/_processors/sql/base.py +++ b/airbyte/_processors/sql/base.py @@ -13,6 +13,7 @@ import sqlalchemy import ulid from overrides import overrides +from pandas import Index from sqlalchemy import ( Column, Table, @@ -29,7 +30,7 @@ from airbyte import exceptions as exc from airbyte._processors.base import RecordProcessor -from airbyte._util.text_util import lower_case_set +from airbyte._util.name_normalizers import LowerCaseNormalizer from airbyte.caches._catalog_manager import CatalogManager from airbyte.datasets._sql import CachedDataset from airbyte.progress import progress @@ -73,9 +74,17 @@ class SqlProcessorBase(RecordProcessor): """A base class to be used for SQL Caches.""" type_converter_class: type[SQLTypeConverter] = SQLTypeConverter + """The type converter class to use for converting JSON schema types to SQL types.""" + + normalizer = LowerCaseNormalizer + """The name normalizer to user for table and column name normalization.""" + file_writer_class: type[FileWriterBase] + """The file writer class to use for writing files to the cache.""" supports_merge_insert = False + """True if the database supports the MERGE INTO syntax.""" + use_singleton_connection = False # If true, the same connection is used for all operations. # Constructor: @@ -197,7 +206,7 @@ def get_sql_table_name( # TODO: Add default prefix based on the source name. 
- return self._normalize_table_name( + return self.normalizer.normalize( f"{table_prefix}{stream_name}{self.cache.table_suffix}", ) @@ -324,7 +333,7 @@ def _get_temp_table_name( ) -> str: """Return a new (unique) temporary table name.""" batch_id = batch_id or str(ulid.ULID()) - return self._normalize_table_name(f"{stream_name}_{batch_id}") + return self.normalizer.normalize(f"{stream_name}_{batch_id}") def _fully_qualified( self, @@ -414,11 +423,11 @@ def _ensure_compatible_table_schema( stream_column_names: list[str] = json_schema["properties"].keys() table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys() - lower_case_table_column_names = lower_case_set(table_column_names) + lower_case_table_column_names = self.normalizer.normalize_set(table_column_names) missing_columns = [ stream_col for stream_col in stream_column_names - if stream_col.lower() not in lower_case_table_column_names + if self.normalizer.normalize(stream_col) not in lower_case_table_column_names ] if missing_columns: if raise_on_error: @@ -452,17 +461,12 @@ def _create_table( """ _ = self._execute_sql(cmd) - def _normalize_column_name( - self, - raw_name: str, - ) -> str: - return raw_name.lower().replace(" ", "_").replace("-", "_") - - def _normalize_table_name( + def _get_stream_properties( self, - raw_name: str, - ) -> str: - return raw_name.lower().replace(" ", "_").replace("-", "_") + stream_name: str, + ) -> dict[str, dict]: + """Return the names of the top-level properties for the given stream.""" + return self._get_stream_json_schema(stream_name)["properties"] @final def _get_sql_column_definitions( @@ -471,9 +475,9 @@ def _get_sql_column_definitions( ) -> dict[str, sqlalchemy.types.TypeEngine]: """Return the column definitions for the given stream.""" columns: dict[str, sqlalchemy.types.TypeEngine] = {} - properties = self._get_stream_json_schema(stream_name)["properties"] + properties = self._get_stream_properties(stream_name) for property_name, json_schema_property_def in properties.items(): - clean_prop_name = self._normalize_column_name(property_name) + clean_prop_name = self.normalizer.normalize(property_name) columns[clean_prop_name] = self.type_converter.to_sql_type( json_schema_property_def, ) @@ -635,6 +639,12 @@ def _write_files_to_new_table( }, ) + # Normalize all column names to lower case. + dataframe.columns = Index( + [LowerCaseNormalizer.normalize(col) for col in dataframe.columns] + ) + + # Write the data to the table. 
dataframe.to_sql( temp_table_name, self.get_sql_alchemy_url(), diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py index a31f39a3..0269323a 100644 --- a/airbyte/_processors/sql/duckdb.py +++ b/airbyte/_processors/sql/duckdb.py @@ -84,6 +84,7 @@ def _write_files_to_new_table( stream_name=stream_name, batch_id=batch_id, ) + properties_list = list(self._get_stream_properties(stream_name).keys()) columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys()) columns_list_str = indent( "\n, ".join([self._quote_identifier(c) for c in columns_list]), @@ -93,9 +94,14 @@ def _write_files_to_new_table( columns_type_map = indent( "\n, ".join( [ - f"{self._quote_identifier(c)}: " - f"{self._get_sql_column_definitions(stream_name)[c]!s}" - for c in columns_list + self._quote_identifier(prop_name) + + ": " + + str( + self._get_sql_column_definitions(stream_name)[ + self.normalizer.normalize(prop_name) + ] + ) + for prop_name in properties_list ] ), " ", diff --git a/airbyte/_processors/sql/snowflake.py b/airbyte/_processors/sql/snowflake.py index 12c04410..7806d91f 100644 --- a/airbyte/_processors/sql/snowflake.py +++ b/airbyte/_processors/sql/snowflake.py @@ -67,14 +67,16 @@ def _write_files_to_new_table( ] ) self._execute_sql(put_files_statements) - + properties_list: list[str] = list(self._get_stream_properties(stream_name).keys()) columns_list = [ self._quote_identifier(c) for c in list(self._get_sql_column_definitions(stream_name).keys()) ] files_list = ", ".join([f"'{f.name}'" for f in files]) columns_list_str: str = indent("\n, ".join(columns_list), " " * 12) - variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list]) + variant_cols_str: str = ("\n" + " " * 21 + ", ").join( + [f"$1:{col}" for col in properties_list] + ) copy_statement = dedent( f""" COPY INTO {temp_table_name} diff --git a/airbyte/_util/name_normalizers.py b/airbyte/_util/name_normalizers.py new file mode 100644 index 00000000..a771ba31 --- /dev/null +++ b/airbyte/_util/name_normalizers.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""Name normalizer classes.""" + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + +class NameNormalizerBase(abc.ABC): + """Abstract base class for name normalizers.""" + + @staticmethod + @abc.abstractmethod + def normalize(name: str) -> str: + """Return the normalized name.""" + ... 
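The DuckDB and Snowflake changes above share one pattern: the original-case property names from the stream's JSON schema are kept for reading the staged files, while the SQL column definitions are looked up by their normalized names. A small sketch of that lookup, with the SQL types shown as plain strings for brevity (the real code uses SQLAlchemy type objects):

```python
from airbyte._util.name_normalizers import LowerCaseNormalizer

# Property names as declared by the source, original case preserved.
properties_list = ["Column1", "Column2", "empty_column"]

# Column definitions are keyed by normalized (lower-case) column names.
sql_column_definitions = {
    "column1": "VARCHAR",
    "column2": "DECIMAL(38, 9)",
    "empty_column": "VARCHAR",
}

# Mixed-case property names resolve to the right definition via the normalizer.
columns_type_map = {
    prop_name: sql_column_definitions[LowerCaseNormalizer.normalize(prop_name)]
    for prop_name in properties_list
}
print(columns_type_map)
# {'Column1': 'VARCHAR', 'Column2': 'DECIMAL(38, 9)', 'empty_column': 'VARCHAR'}
```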
+ + @classmethod + def normalize_set(cls, str_iter: Iterable[str]) -> set[str]: + """Converts string iterable to a set of lower case strings.""" + return {cls.normalize(s) for s in str_iter} + + @classmethod + def normalize_list(cls, str_iter: Iterable[str]) -> list[str]: + """Converts string iterable to a list of lower case strings.""" + return [cls.normalize(s) for s in str_iter] + + @classmethod + def check_matched(cls, name1: str, name2: str) -> bool: + """Return True if the two names match after each is normalized.""" + return cls.normalize(name1) == cls.normalize(name2) + + @classmethod + def check_normalized(cls, name: str) -> bool: + """Return True if the name is already normalized.""" + return cls.normalize(name) == name + + +class LowerCaseNormalizer(NameNormalizerBase): + """A name normalizer that converts names to lower case.""" + + @staticmethod + def normalize(name: str) -> str: + """Return the normalized name.""" + return name.lower().replace(" ", "_").replace("-", "_") + + +class CaseInsensitiveDict(dict[str, Any]): + """A case-aware, case-insensitive dictionary implementation. + + It has these behaviors: + - When a key is retrieved, deleted, or checked for existence, it is always checked in a + case-insensitive manner. + - The original case is stored in a separate dictionary, so that the original case can be + retrieved when needed. + + There are two ways to store keys internally: + - If normalize_keys is True, the keys are normalized using the given normalizer. + - If normalize_keys is False, the original case of the keys is stored. + + In regards to missing values, the dictionary accepts an 'expected_keys' input. When set, the + dictionary will be initialized with the given keys. If a key is not found in the input data, it + will be initialized with a value of None. When provided, the 'expected_keys' input will also + determine the original case of the keys. + """ + + def _display_case(self, key: str) -> str: + """Return the original case of the key.""" + return self._pretty_case_keys[self._normalizer.normalize(key)] + + def _index_case(self, key: str) -> str: + """Return the internal case of the key. + + If normalize_keys is True, return the normalized key. + Otherwise, return the original case of the key. + """ + if self._normalize_keys: + return self._normalizer.normalize(key) + + return self._display_case(key) + + def __init__( + self, + from_dict: dict, + *, + normalize_keys: bool = True, + normalizer: type[NameNormalizerBase] | None = None, + expected_keys: list[str] | None = None, + ) -> None: + """Initialize the dictionary with the given data. + + If normalize_keys is True, the keys will be normalized using the given normalizer. + If expected_keys is provided, the dictionary will be initialized with the given keys. + """ + # If no normalizer is provided, use LowerCaseNormalizer. + self._normalize_keys = normalize_keys + self._normalizer: type[NameNormalizerBase] = normalizer or LowerCaseNormalizer + + # If no expected keys are provided, use all keys from the input dictionary. + if not expected_keys: + expected_keys = list(from_dict.keys()) + + # Store a lookup from normalized keys to pretty cased (originally cased) keys. 
+ self._pretty_case_keys: dict[str, str] = { + self._normalizer.normalize(pretty_case.lower()): pretty_case + for pretty_case in expected_keys + } + + if normalize_keys: + index_keys = [self._normalizer.normalize(key) for key in expected_keys] + else: + index_keys = expected_keys + + self.update({k: None for k in index_keys}) # Start by initializing all values to None + for k, v in from_dict.items(): + self[self._index_case(k)] = v + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + if super().__contains__(key): + return super().__getitem__(key) + + if super().__contains__(self._index_case(key)): + return super().__getitem__(self._index_case(key)) + + raise KeyError(key) + + def __setitem__(self, key: str, value: Any) -> None: # noqa: ANN401 + if super().__contains__(key): + super().__setitem__(key, value) + return + + if super().__contains__(self._index_case(key)): + super().__setitem__(self._index_case(key), value) + return + + # Store the pretty cased (originally cased) key: + self._pretty_case_keys[self._normalizer.normalize(key)] = key + + # Store the data with the normalized key: + super().__setitem__(self._index_case(key), value) + + def __delitem__(self, key: str) -> None: + if super().__contains__(key): + super().__delitem__(key) + return + + if super().__contains__(self._index_case(key)): + super().__delitem__(self._index_case(key)) + return + + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + assert isinstance(key, str), "Key must be a string." + return super().__contains__(key) or super().__contains__(self._index_case(key)) + + def __iter__(self) -> Any: # noqa: ANN401 + return iter(super().__iter__()) + + def __len__(self) -> int: + return super().__len__() + + def __eq__(self, other: object) -> bool: + if isinstance(other, CaseInsensitiveDict): + return dict(self) == dict(other) + + if isinstance(other, dict): + return {k.lower(): v for k, v in self.items()} == { + k.lower(): v for k, v in other.items() + } + return False + + +def normalize_records( + records: Iterable[dict[str, Any]], + expected_keys: list[str], +) -> Iterator[CaseInsensitiveDict]: + """Add missing columns to the record with null values. + + Also conform the column names to the case in the catalog. + + This is a generator that yields CaseInsensitiveDicts, which allows for case-insensitive + lookups of columns. This is useful because the case of the columns in the records may + not match the case of the columns in the catalog. + """ + yield from ( + CaseInsensitiveDict( + from_dict=record, + expected_keys=expected_keys, + ) + for record in records + ) + + +__all__ = [ + "NameNormalizerBase", + "LowerCaseNormalizer", + "CaseInsensitiveDict", + "normalize_records", +] diff --git a/airbyte/_util/text_util.py b/airbyte/_util/text_util.py deleted file mode 100644 index d5f89099..00000000 --- a/airbyte/_util/text_util.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
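A short sketch of how the `normalize_records()` helper defined above behaves, using the same mixed-case property names as the test fixture later in this patch; the record values themselves are illustrative:

```python
from airbyte._util.name_normalizers import CaseInsensitiveDict, normalize_records

records = [{"Column1": "value1", "Column2": 1}]
expected_keys = ["Column1", "Column2", "empty_column"]

for record in normalize_records(records, expected_keys):
    # Keys are stored normalized, lookups are case-insensitive, and
    # expected columns missing from the record are filled with None.
    assert isinstance(record, CaseInsensitiveDict)
    assert record["column1"] == "value1"
    assert record["Column2"] == 1
    assert record["empty_column"] is None
    assert dict(record) == {"column1": "value1", "column2": 1, "empty_column": None}
```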
- -"""Internal utility functions for dealing with text.""" -from __future__ import annotations - -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - from collections.abc import Iterable - - -def lower_case_set(str_iter: Iterable[str]) -> set[str]: - """Converts a list of strings to a set of lower case strings.""" - return {s.lower() for s in str_iter} diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 0f8ac4b0..a8d47f30 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -5,7 +5,7 @@ import tempfile import warnings from contextlib import contextmanager, suppress -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import jsonschema import pendulum @@ -28,11 +28,11 @@ from airbyte import exceptions as exc from airbyte._util import protocol_util +from airbyte._util.name_normalizers import normalize_records from airbyte._util.telemetry import ( SyncState, send_telemetry, ) -from airbyte._util.text_util import lower_case_set # Internal utility functions from airbyte.caches.util import get_default_cache from airbyte.datasets._lazy import LazyDataset from airbyte.progress import progress @@ -329,29 +329,21 @@ def get_records(self, stream: str) -> LazyDataset: ) from KeyError(stream) configured_stream = configured_catalog.streams[0] - all_properties = set(configured_stream.stream.json_schema["properties"].keys()) + all_properties = cast( + list[str], list(configured_stream.stream.json_schema["properties"].keys()) + ) def _with_logging(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]: self._log_sync_start(cache=None) yield from records self._log_sync_success(cache=None) - def _with_missing_columns(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]: - """Add missing columns to the record with null values.""" - for record in records: - existing_properties_lower = lower_case_set(record.keys()) - appended_dict = { - prop: None - for prop in all_properties - if prop.lower() not in existing_properties_lower - } - yield {**record, **appended_dict} - iterator: Iterator[dict[str, Any]] = _with_logging( - _with_missing_columns( - protocol_util.airbyte_messages_to_record_dicts( + normalize_records( + records=protocol_util.airbyte_messages_to_record_dicts( self._read_with_catalog(configured_catalog), - ) + ), + expected_keys=all_properties, ) ) return LazyDataset( diff --git a/tests/integration_tests/fixtures/source-test/source_test/run.py b/tests/integration_tests/fixtures/source-test/source_test/run.py index 1d6c4820..a843fd41 100644 --- a/tests/integration_tests/fixtures/source-test/source_test/run.py +++ b/tests/integration_tests/fixtures/source-test/source_test/run.py @@ -17,8 +17,8 @@ "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "column1": {"type": "string"}, - "column2": {"type": "number"}, + "Column1": {"type": "string"}, + "Column2": {"type": "number"}, }, }, }, @@ -30,8 +30,8 @@ "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "column1": {"type": "string"}, - "column2": {"type": "number"}, + "Column1": {"type": "string"}, + "Column2": {"type": "number"}, "empty_column": {"type": "string"}, }, }, @@ -45,8 +45,8 @@ "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "column1": {"type": "string"}, - "column2": {"type": "number"}, + "Column1": {"type": "string"}, + "Column2": {"type": "number"}, "empty_column": {"type": "string"}, }, }, @@ -86,7 +86,7 @@ sample_record1_stream1 = { 
"type": "RECORD", "record": { - "data": {"column1": "value1", "column2": 1}, + "data": {"Column1": "value1", "Column2": 1}, "stream": "stream1", "emitted_at": 123456789, }, @@ -94,7 +94,7 @@ sample_record2_stream1 = { "type": "RECORD", "record": { - "data": {"column1": "value2", "column2": 2}, + "data": {"Column1": "value2", "Column2": 2}, "stream": "stream1", "emitted_at": 123456789, }, @@ -102,7 +102,7 @@ sample_record_stream2 = { "type": "RECORD", "record": { - "data": {"column1": "value1", "column2": 1}, + "data": {"Column1": "value1", "Column2": 1}, "stream": "stream2", "emitted_at": 123456789, }, diff --git a/tests/integration_tests/test_source_test_fixture.py b/tests/integration_tests/test_source_test_fixture.py index 33e1f30e..116ce871 100644 --- a/tests/integration_tests/test_source_test_fixture.py +++ b/tests/integration_tests/test_source_test_fixture.py @@ -249,6 +249,16 @@ def test_dataset_list_and_len(expected_test_stream_data): source = ab.get_source("source-test", config={"apiKey": "test"}) source.select_all_streams() + # Test the lazy dataset implementation + lazy_dataset = source.get_records("stream1") + # assert len(stream_1) == 2 # This is not supported by the lazy dataset + lazy_dataset_list = list(lazy_dataset) + # Make sure counts are correct + assert len(list(lazy_dataset_list)) == 2 + # Make sure records are correct + assert list(lazy_dataset_list) == [{"column1": "value1", "column2": 1}, {"column1": "value2", "column2": 2}] + + # Test the cached dataset implementation result: ReadResult = source.read(ab.new_local_cache()) stream_1 = result["stream1"] assert len(stream_1) == 2 diff --git a/tests/unit_tests/test_text_normalization.py b/tests/unit_tests/test_text_normalization.py new file mode 100644 index 00000000..0791552a --- /dev/null +++ b/tests/unit_tests/test_text_normalization.py @@ -0,0 +1,103 @@ +from math import exp +import pytest +from airbyte._util.name_normalizers import CaseInsensitiveDict, LowerCaseNormalizer + +def test_case_insensitive_dict(): + # Initialize a CaseInsensitiveDict + cid = CaseInsensitiveDict({"Upper": 1, "lower": 2}) + + # Test __getitem__ + assert cid["Upper"] == 1 + assert cid["lower"] == 2 + + # Test __len__ + assert len(cid) == 2 + + # Test __setitem__ and __getitem__ with case mismatch + cid["upper"] = 3 + cid["Lower"] = 4 + assert len(cid) == 2 + + assert cid["upper"] == 3 + assert cid["Lower"] == 4 + + # Test __contains__ + assert "Upper" in cid + assert "lower" in cid + assert "Upper" in cid + assert "lower" in cid + + # Test __contains__ with case-insensitive normalizer + assert "Upper" in cid + assert "lower" in cid + assert "upper" in cid + assert "Lower" in cid + + # Assert __eq__ + # When converting to dict, the keys should be normalized to the original case. + assert dict(cid) != {"Upper": 3, "lower": 4} + assert dict(cid) == {"upper": 3, "lower": 4} + # When comparing directly to a dict, should use case insensitive comparison. 
+ assert cid == {"Upper": 3, "lower": 4} + assert cid == {"upper": 3, "Lower": 4} + + # Test __iter__ + assert set(key for key in cid) == {"upper", "lower"} + + # Test __delitem__ + del cid["lower"] + with pytest.raises(KeyError): + _ = cid["lower"] + assert len(cid) == 1 + + del cid["upper"] + with pytest.raises(KeyError): + _ = cid["upper"] + assert len(cid) == 0 + + + + +def test_case_insensitive_dict_w() -> None: + # Initialize a CaseInsensitiveDict + cid = CaseInsensitiveDict({"Upper": 1, "lower": 2}, expected_keys=["Upper", "lower", "other"]) + + # Test __len__ + assert len(cid) == 3 + + # Test __getitem__ + assert cid["Upper"] == 1 + assert cid["lower"] == 2 + assert cid["other"] is None + + # Use dict() to test exact contents + assert dict(cid) == {"upper": 1, "lower": 2, "other": None} + + # Assert case insensitivity when comparing natively to a dict + assert cid == {"UPPER": 1, "lower": 2, "other": None} + assert cid == {"upper": 1, "lower": 2, "other": None} + + + +def test_case_insensitive_w_pretty_keys() -> None: + # Initialize a CaseInsensitiveDict + cid = CaseInsensitiveDict( + {"Upper": 1, "lower": 2}, + expected_keys=["Upper", "lower", "other"], + normalize_keys=False, + ) + + # Test __len__ + assert len(cid) == 3 + + # Test __getitem__ + assert cid["Upper"] == 1 + assert cid["lower"] == 2 + assert cid["other"] is None + + # Because we're not normalizing keys, the keys should be stored as-is + assert dict(cid) == {"Upper": 1, "lower": 2, "other": None} + + # Assert case insensitivity when comparing natively to a dict + assert cid == {"UPPER": 1, "lower": 2, "other": None} + assert cid == {"upper": 1, "lower": 2, "other": None} From 2f483ec1af451bbd8cc55b45893274ccceba3bcb Mon Sep 17 00:00:00 2001 From: "Aaron (\"AJ\") Steers" Date: Tue, 12 Mar 2024 16:55:39 -0700 Subject: [PATCH 3/3] Chore: Add anonymous user ID in tracking events (#124) --- .github/workflows/autofix.yml | 3 + .github/workflows/pydoc_preview.yml | 2 + .github/workflows/pydoc_publish.yml | 3 + .github/workflows/pypi_publish.yml | 3 + .github/workflows/python_lint.yml | 3 + .github/workflows/python_pytest.yml | 3 + .github/workflows/release_drafter.yml | 3 + .github/workflows/semantic_pr_check.yml | 3 + .github/workflows/slash_command_dispatch.yml | 3 + airbyte/_util/telemetry.py | 96 ++++++++++++++++++- .../unit_tests/test_anonymous_usage_stats.py | 78 ++++++++++++++- 11 files changed, 195 insertions(+), 5 deletions(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index f513c27b..4e9ecd51 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -5,6 +5,9 @@ on: repository_dispatch: types: [autofix-command] +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + jobs: python-autofix: runs-on: ubuntu-latest diff --git a/.github/workflows/pydoc_preview.yml b/.github/workflows/pydoc_preview.yml index 7fac9a63..8284dfde 100644 --- a/.github/workflows/pydoc_preview.yml +++ b/.github/workflows/pydoc_preview.yml @@ -6,6 +6,8 @@ on: - main pull_request: {} +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} jobs: preview_docs: diff --git a/.github/workflows/pydoc_publish.yml b/.github/workflows/pydoc_publish.yml index bbe9aeaa..0d719dbb 100644 --- a/.github/workflows/pydoc_publish.yml +++ b/.github/workflows/pydoc_publish.yml @@ -8,6 +8,9 @@ on: # Allows you to run this workflow manually from the Actions tab workflow_dispatch: +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + # Sets permissions of the GITHUB_TOKEN to allow 
deployment to GitHub Pages permissions: contents: read diff --git a/.github/workflows/pypi_publish.yml b/.github/workflows/pypi_publish.yml index efe374a6..f31f98b8 100644 --- a/.github/workflows/pypi_publish.yml +++ b/.github/workflows/pypi_publish.yml @@ -5,6 +5,9 @@ on: workflow_dispatch: +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + jobs: build: runs-on: ubuntu-latest diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml index 66e92d0d..7630d76b 100644 --- a/.github/workflows/python_lint.yml +++ b/.github/workflows/python_lint.yml @@ -6,6 +6,9 @@ on: - main pull_request: {} +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + jobs: ruff-lint-check: name: Ruff Lint Check diff --git a/.github/workflows/python_pytest.yml b/.github/workflows/python_pytest.yml index 79c0fa90..12c126e5 100644 --- a/.github/workflows/python_pytest.yml +++ b/.github/workflows/python_pytest.yml @@ -13,6 +13,9 @@ on: - main pull_request: {} +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + jobs: pytest-fast: name: Pytest (Fast) diff --git a/.github/workflows/release_drafter.yml b/.github/workflows/release_drafter.yml index feadbd4a..e35a2372 100644 --- a/.github/workflows/release_drafter.yml +++ b/.github/workflows/release_drafter.yml @@ -5,6 +5,9 @@ on: branches: - main +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + permissions: contents: read diff --git a/.github/workflows/semantic_pr_check.yml b/.github/workflows/semantic_pr_check.yml index 7e311577..bd01b877 100644 --- a/.github/workflows/semantic_pr_check.yml +++ b/.github/workflows/semantic_pr_check.yml @@ -7,6 +7,9 @@ on: - edited - synchronize +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + permissions: pull-requests: read diff --git a/.github/workflows/slash_command_dispatch.yml b/.github/workflows/slash_command_dispatch.yml index 1da8ec92..14207ce6 100644 --- a/.github/workflows/slash_command_dispatch.yml +++ b/.github/workflows/slash_command_dispatch.yml @@ -4,6 +4,9 @@ on: issue_comment: types: [created] +env: + AIRBYTE_ANALYTICS_ID: ${{ vars.AIRBYTE_ANALYTICS_ID }} + jobs: slashCommandDispatch: runs-on: ubuntu-latest diff --git a/airbyte/_util/telemetry.py b/airbyte/_util/telemetry.py index 8eca682b..88d1bc91 100644 --- a/airbyte/_util/telemetry.py +++ b/airbyte/_util/telemetry.py @@ -37,10 +37,12 @@ from dataclasses import asdict, dataclass from enum import Enum from functools import lru_cache -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast import requests import ulid +import yaml from airbyte import exceptions as exc from airbyte._util import meta @@ -52,6 +54,10 @@ from airbyte.sources.base import Source +DEBUG = True +"""Enable debug mode for telemetry code.""" + + HASH_SEED = "PyAirbyte:" """Additional seed for randomizing one-way hashed strings.""" @@ -73,6 +79,92 @@ DO_NOT_TRACK = "DO_NOT_TRACK" """Environment variable to opt-out of telemetry.""" +_ENV_ANALYTICS_ID = "AIRBYTE_ANALYTICS_ID" # Allows user to override the anonymous user ID +_ANALYTICS_FILE = Path.home() / ".airbyte" / "analytics.yml" +_ANALYTICS_ID: str | bool | None = None + + +def _setup_analytics() -> str | bool: + """Set up the analytics file if it doesn't exist. + + Return the anonymous user ID or False if the user has opted out. + """ + anonymous_user_id: str | None = None + issues: list[str] = [] + + if os.environ.get(DO_NOT_TRACK): + # User has opted out of tracking. 
+ return False + + if _ENV_ANALYTICS_ID in os.environ: + # If the user has chosen to override their analytics ID, use that value and + # remember it for future invocations. + anonymous_user_id = os.environ[_ENV_ANALYTICS_ID] + + if not _ANALYTICS_FILE.exists(): + # This is a one-time message to inform the user that we are tracking anonymous usage stats. + print( + "Anonymous usage reporting is enabled. For more information or to opt out, please" + " see https://docs.airbyte.io/pyairbyte/anonymized-usage-statistics" + ) + + if _ANALYTICS_FILE.exists(): + analytics_text = _ANALYTICS_FILE.read_text() + try: + analytics: dict = yaml.safe_load(analytics_text) + except Exception as ex: + issues += f"File appears corrupted. Error was: {ex!s}" + + if analytics and "anonymous_user_id" in analytics: + # The analytics ID was successfully located. + if not anonymous_user_id: + return analytics["anonymous_user_id"] + + if anonymous_user_id == analytics["anonymous_user_id"]: + # Values match, no need to update the file. + return analytics["anonymous_user_id"] + + issues.append("Provided analytics ID did not match the file. Rewriting the file.") + print( + f"Received a user-provided analytics ID override in the '{_ENV_ANALYTICS_ID}' " + "environment variable." + ) + + # File is missing, incomplete, or stale. Create a new one. + anonymous_user_id = anonymous_user_id or str(ulid.ULID()) + try: + _ANALYTICS_FILE.parent.mkdir(exist_ok=True, parents=True) + _ANALYTICS_FILE.write_text( + "# This file is used by PyAirbyte to track anonymous usage statistics.\n" + "# For more information or to opt out, please see\n" + "# - https://docs.airbyte.com/operator-guides/telemetry\n" + f"anonymous_user_id: {anonymous_user_id}\n" + ) + except Exception: + # Failed to create the analytics file. Likely due to a read-only filesystem. + issues.append("Failed to write the analytics file. 
Check filesystem permissions.") + pass + + if DEBUG and issues: + nl = "\n" + print(f"One or more issues occurred when configuring usage tracking:\n{nl.join(issues)}") + + return anonymous_user_id + + +def _get_analytics_id() -> str | None: + result: str | bool | None = _ANALYTICS_ID + if result is None: + result = _setup_analytics() + + if result is False: + return None + + return cast(str, result) + + +_ANALYTICS_ID = _get_analytics_id() + class SyncState(str, Enum): STARTED = "started" @@ -174,7 +266,7 @@ def send_telemetry( "https://api.segment.io/v1/track", auth=(PYAIRBYTE_APP_TRACKING_KEY, ""), json={ - "anonymousId": "airbyte-lib-user", + "anonymousId": _get_analytics_id(), "event": "sync", "properties": payload_props, "timestamp": datetime.datetime.utcnow().isoformat(), # noqa: DTZ003 diff --git a/tests/unit_tests/test_anonymous_usage_stats.py b/tests/unit_tests/test_anonymous_usage_stats.py index b1b953e8..8e65ea78 100644 --- a/tests/unit_tests/test_anonymous_usage_stats.py +++ b/tests/unit_tests/test_anonymous_usage_stats.py @@ -4,8 +4,10 @@ import itertools from contextlib import nullcontext as does_not_raise import json +import os +from pathlib import Path import re -from unittest.mock import Mock, call, patch +from unittest.mock import MagicMock, call, patch from freezegun import freeze_time import responses @@ -16,8 +18,6 @@ from airbyte.version import get_version import airbyte as ab from airbyte._util import telemetry -import requests -import datetime @responses.activate @@ -174,3 +174,75 @@ def test_tracking( } ) ]) + + +def test_setup_analytics_existing_file(monkeypatch): + # Mock the environment variable and the analytics file + monkeypatch.delenv(telemetry._ENV_ANALYTICS_ID, raising=False) + monkeypatch.delenv(telemetry.DO_NOT_TRACK, raising=False) + + monkeypatch.setattr(Path, 'exists', lambda x: True) + monkeypatch.setattr(Path, 'read_text', lambda x: "anonymous_user_id: test_id\n") + assert telemetry._setup_analytics() == 'test_id' + + +def test_setup_analytics_missing_file(monkeypatch): + """Mock the environment variable and the missing analytics file.""" + monkeypatch.setenv(telemetry._ENV_ANALYTICS_ID, 'test_id') + monkeypatch.delenv(telemetry.DO_NOT_TRACK, raising=False) + monkeypatch.setattr(Path, 'exists', lambda x: False) + + mock_path = MagicMock() + monkeypatch.setattr(Path, 'write_text', mock_path) + + assert telemetry._setup_analytics() == 'test_id' + + assert mock_path.call_count == 1 + + +def test_setup_analytics_read_only_filesystem(monkeypatch, capfd): + """Mock the environment variable and simulate a read-only filesystem.""" + monkeypatch.setenv(telemetry._ENV_ANALYTICS_ID, 'test_id') + monkeypatch.delenv(telemetry.DO_NOT_TRACK, raising=False) + monkeypatch.setattr(Path, 'exists', lambda x: False) + + mock_write_text = MagicMock(side_effect=PermissionError("Read-only filesystem")) + monkeypatch.setattr(Path, 'write_text', mock_write_text) + + # We should not raise an exception + assert telemetry._setup_analytics() == "test_id" + + assert mock_write_text.call_count == 1 + + # Capture print outputs + captured = capfd.readouterr() + + # Validate print message + assert "Read-only filesystem" not in captured.out + + +def test_setup_analytics_corrupt_file(monkeypatch): + """Mock the environment variable and the missing analytics file.""" + monkeypatch.delenv(telemetry._ENV_ANALYTICS_ID, raising=False) + monkeypatch.delenv(telemetry.DO_NOT_TRACK, raising=False) + monkeypatch.setattr(Path, 'exists', lambda x: True) + monkeypatch.setattr(Path, 'read_text', 
lambda x: "not-a-valid ::: yaml file\n") + + mock = MagicMock() + monkeypatch.setattr(Path, 'write_text', mock) + + assert telemetry._setup_analytics() + + assert mock.call_count == 1 + + +def test_get_analytics_id(monkeypatch): + # Mock the _ANALYTICS_ID variable + monkeypatch.delenv(telemetry._ENV_ANALYTICS_ID, raising=False) + monkeypatch.delenv(telemetry.DO_NOT_TRACK, raising=False) + monkeypatch.setattr(telemetry, '_ANALYTICS_ID', 'test_id') + + mock = MagicMock() + monkeypatch.setattr(Path, 'write_text', mock) + + assert telemetry._get_analytics_id() == 'test_id'
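Since the anonymous ID is resolved when the module loads (`_ANALYTICS_ID = _get_analytics_id()` above runs at import time), the two environment variables introduced in this patch should be set before `airbyte` is imported. A brief sketch of the two options:

```python
import os

# Option 1: opt out of anonymous usage tracking entirely.
os.environ["DO_NOT_TRACK"] = "true"

# Option 2: keep tracking but pin the anonymous ID to a known value,
# as the CI workflows above do via the AIRBYTE_ANALYTICS_ID variable.
# os.environ["AIRBYTE_ANALYTICS_ID"] = "my-ci-analytics-id"

import airbyte as ab  # noqa: E402, F401

# With tracking enabled and no override set, a random ULID is generated and
# cached in ~/.airbyte/analytics.yml under the `anonymous_user_id` key.
```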