From 76f6d5f501611dc98dabb177d330de370499080e Mon Sep 17 00:00:00 2001
From: Serhii Dimchenko <39801237+svdimchenko@users.noreply.github.com>
Date: Tue, 28 May 2024 19:14:35 +0200
Subject: [PATCH] feat: Implement iceberg retry logic (#657)

Co-authored-by: nicor88 <6278547+nicor88@users.noreply.github.com>
---
 README.md                           |   5 +-
 dbt/adapters/athena/connections.py  |  87 ++++++-----
 .../adapter/test_retries_iceberg.py | 135 ++++++++++++++++++
 3 files changed, 191 insertions(+), 36 deletions(-)
 create mode 100644 tests/functional/adapter/test_retries_iceberg.py

diff --git a/README.md b/README.md
index b7b23289..b78c4eb0 100644
--- a/README.md
+++ b/README.md
@@ -119,7 +119,7 @@ You can either:
 A dbt profile can be configured to run against AWS Athena using the following configuration:
 
 | Option                | Description                                                                               | Required? | Example                                |
-| --------------------- | ----------------------------------------------------------------------------------------- | --------- | -------------------------------------- |
+|-----------------------|---------------------------------------------------------------------------------------------|-----------|------------------------------------------|
 | s3_staging_dir        | S3 location to store Athena query results and metadata                                   | Required  | `s3://bucket/dbt/`                     |
 | s3_data_dir           | Prefix for storing tables, if different from the connection's `s3_staging_dir`           | Optional  | `s3://bucket2/dbt/`                    |
 | s3_data_naming        | How to generate table paths in `s3_data_dir`                                              | Optional  | `schema_table_unique`                  |
@@ -134,8 +134,9 @@ A dbt profile can be configured to run against AWS Athena using the following co
 | aws_profile_name      | Profile to use from your AWS shared credentials file                                      | Optional  | `my-profile`                           |
 | work_group            | Identifier of Athena workgroup                                                            | Optional  | `my-custom-workgroup`                  |
 | num_retries           | Number of times to retry a failing query                                                  | Optional  | `3`                                    |
-| spark_work_group      | Identifier of Athena Spark workgroup for running Python models                            | Optional  | `my-spark-workgroup`                   |
 | num_boto3_retries     | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional  | `5`                                    |
+| num_iceberg_retries   | Number of times to retry Iceberg commit queries that fail with `ICEBERG_COMMIT_ERROR`    | Optional  | `0`                                    |
+| spark_work_group      | Identifier of Athena Spark workgroup for running Python models                            | Optional  | `my-spark-workgroup`                   |
 | seed_s3_upload_args   | Dictionary containing boto3 ExtraArgs when uploading to S3                                | Optional  | `{"ACL": "bucket-owner-full-control"}` |
 | lf_tags_database      | Default LF tags for new database if it's created by dbt                                   | Optional  | `tag_key: tag_value`                   |
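For context, the new option sits in the same profile block as `num_retries` and `num_boto3_retries`. A minimal `profiles.yml` sketch is shown below — the bucket, region, schema, and thread values are placeholders for illustration, not part of the patch:

```yaml
athena:
  target: dev
  outputs:
    dev:
      type: athena
      s3_staging_dir: s3://my-bucket/dbt/   # placeholder bucket
      region_name: eu-west-1                # placeholder region
      database: awsdatacatalog
      schema: analytics                     # placeholder schema
      threads: 8
      num_retries: 3
      num_iceberg_retries: 3                # retry ICEBERG_COMMIT_ERROR up to 3 times
```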
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index 5962eaa1..8536ae3c 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -25,11 +25,12 @@
 from pyathena.result_set import AthenaResultSet
 from pyathena.util import RetryConfig
 from tenacity import (
-    Retrying,
+    retry,
     retry_if_exception,
     stop_after_attempt,
     wait_random_exponential,
 )
+from typing_extensions import Self
 
 from dbt.adapters.athena.config import get_boto3_config
 from dbt.adapters.athena.constants import LOGGER
@@ -64,8 +65,9 @@ class AthenaCredentials(Credentials):
     _ALIASES = {"catalog": "database"}
     num_retries: int = 5
     num_boto3_retries: Optional[int] = None
+    num_iceberg_retries: int = 3
     s3_data_dir: Optional[str] = None
-    s3_data_naming: Optional[str] = "schema_table_unique"
+    s3_data_naming: str = "schema_table_unique"
     spark_work_group: Optional[str] = None
     s3_tmp_table_dir: Optional[str] = None
     # Unfortunately we can not just use dict, must be Dict because we'll get the following error:
@@ -147,7 +149,7 @@ def __poll(self, query_id: str) -> AthenaQueryExecution:
             LOGGER.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
             time.sleep(self._poll_interval)
 
-    def execute(  # type: ignore
+    def execute(
         self,
         operation: str,
         parameters: Optional[Dict[str, Any]] = None,
@@ -157,35 +159,9 @@ def execute(  # type: ignore
         cache_size: int = 0,
         cache_expiration_time: int = 0,
         catch_partitions_limit: bool = False,
-        **kwargs,
-    ):
-        def inner() -> AthenaCursor:
-            query_id = self._execute(
-                operation,
-                parameters=parameters,
-                work_group=work_group,
-                s3_staging_dir=s3_staging_dir,
-                cache_size=cache_size,
-                cache_expiration_time=cache_expiration_time,
-            )
-
-            LOGGER.debug(f"Athena query ID {query_id}")
-
-            query_execution = self._executor.submit(self._collect_result_set, query_id).result()
-            if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
-                self.result_set = self._result_set_class(
-                    self._connection,
-                    self._converter,
-                    query_execution,
-                    self.arraysize,
-                    self._retry_config,
-                )
-
-            else:
-                raise OperationalError(query_execution.state_change_reason)
-            return self
-
-        retry = Retrying(
+        **kwargs: Dict[str, Any],
+    ) -> Self:
+        @retry(
             # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
             # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after a retry,
             # because not all files are removed immediately after the first attempt to create the table
@@ -200,7 +176,47 @@ def inner() -> AthenaCursor:
             ),
             reraise=True,
         )
-        return retry(inner)
+        def inner() -> AthenaCursor:
+            num_iceberg_retries = self.connection.cursor_kwargs.get("num_iceberg_retries") + 1
+
+            @retry(
+                # A nested retry is needed to handle ICEBERG_COMMIT_ERROR raised by parallel inserts
+                retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)),
+                stop=stop_after_attempt(num_iceberg_retries),
+                wait=wait_random_exponential(
+                    multiplier=num_iceberg_retries,
+                    max=self._retry_config.max_delay,
+                    exp_base=self._retry_config.exponential_base,
+                ),
+                reraise=True,
+            )
+            def execute_with_iceberg_retries() -> AthenaCursor:
+                query_id = self._execute(
+                    operation,
+                    parameters=parameters,
+                    work_group=work_group,
+                    s3_staging_dir=s3_staging_dir,
+                    cache_size=cache_size,
+                    cache_expiration_time=cache_expiration_time,
+                )
+
+                LOGGER.debug(f"Athena query ID {query_id}")
+
+                query_execution = self._executor.submit(self._collect_result_set, query_id).result()
+                if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
+                    self.result_set = self._result_set_class(
+                        self._connection,
+                        self._converter,
+                        query_execution,
+                        self.arraysize,
+                        self._retry_config,
+                    )
+                    return self
+                raise OperationalError(query_execution.state_change_reason)
+
+            return execute_with_iceberg_retries()  # type: ignore
+
+        return inner()  # type: ignore
 
 
 class AthenaConnectionManager(SQLConnectionManager):
@@ -243,7 +259,10 @@ def open(cls, connection: Connection) -> Connection:
             schema_name=creds.schema,
             work_group=creds.work_group,
             cursor_class=AthenaCursor,
-            cursor_kwargs={"debug_query_state": creds.debug_query_state},
+            cursor_kwargs={
+                "debug_query_state": creds.debug_query_state,
+                "num_iceberg_retries": creds.num_iceberg_retries,
+            },
             formatter=AthenaParameterFormatter(),
             poll_interval=creds.poll_interval,
             session=get_boto3_session(connection),
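Stripped of the cursor plumbing, the nested decorator reduces to the pattern below — a self-contained sketch using the same tenacity primitives the patch imports. `run_query` and `FakeOperationalError` are illustrative stand-ins, not adapter code. Note that the adapter passes `num_iceberg_retries + 1` to `stop_after_attempt`, because the first execution itself consumes an attempt; with `num_iceberg_retries: 0` the query still runs once.

```python
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_random_exponential,
)

NUM_ICEBERG_RETRIES = 3  # what the profile's num_iceberg_retries would supply


class FakeOperationalError(Exception):
    """Stand-in for pyathena's OperationalError."""


@retry(
    # Retry only Iceberg optimistic-commit conflicts, matched by message,
    # just as the adapter's nested decorator does.
    retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)),
    # +1 because the first execution itself counts as an attempt.
    stop=stop_after_attempt(NUM_ICEBERG_RETRIES + 1),
    # Exponential backoff with jitter, bounded like RetryConfig.max_delay.
    wait=wait_random_exponential(multiplier=1, max=30),
    reraise=True,  # surface the original error once attempts are exhausted
)
def run_query(sql: str) -> None:
    # The adapter submits the query and polls Athena here; this stub
    # always fails so the retry behaviour is observable.
    raise FakeOperationalError("ICEBERG_COMMIT_ERROR: concurrent update to table snapshot")


if __name__ == "__main__":
    try:
        run_query("insert into target values (1, 1)")
    except FakeOperationalError as exc:
        print(f"gave up after {NUM_ICEBERG_RETRIES + 1} attempts: {exc}")
```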
diff --git a/tests/functional/adapter/test_retries_iceberg.py b/tests/functional/adapter/test_retries_iceberg.py
new file mode 100644
index 00000000..adb80498
--- /dev/null
+++ b/tests/functional/adapter/test_retries_iceberg.py
@@ -0,0 +1,135 @@
+"""Test parallel inserts into an Iceberg table."""
+import copy
+import os
+
+import pytest
+
+from dbt.artifacts.schemas.results import RunStatus
+from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture
+
+PARALLELISM = 10
+
+base_dbt_profile = {
+    "type": "athena",
+    "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"),
+    "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"),
+    "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"),
+    "database": os.getenv("DBT_TEST_ATHENA_DATABASE"),
+    "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"),
+    "threads": PARALLELISM,
+    "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")),
+    "num_retries": 0,
+    "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"),
+    "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None,
+}
+
+models__target = """
+{{
+    config(
+        table_type='iceberg',
+        materialized='table'
+    )
+}}
+
+select * from (
+    values
+    (1, -1)
+) as t (id, status)
+limit 0
+
+"""
+
+models__source = {
+    f"model_{i}.sql": f"""
+{{{{
+    config(
+        table_type='iceberg',
+        materialized='table',
+        tags=['src'],
+        pre_hook='insert into target values ({i}, {i})'
+    )
+}}}}
+
+select 1 as col
+"""
+    for i in range(PARALLELISM)
+}
+
+seeds__expected_target_init = "id,status"
+seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)])
"id,status" +seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)]) + + +class TestIcebergRetriesDisabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + profile["num_iceberg_retries"] = 0 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run, log = run_dbt_and_capture(["run", "--select", "tag:src"], expect_pass=False) + assert any(model_run_result.status == RunStatus.Error for model_run_result in run.results) + assert "ICEBERG_COMMIT_ERROR" in log + + +class TestIcebergRetriesEnabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + profile["num_iceberg_retries"] = 1 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run = run_dbt(["run", "--select", "tag:src"]) + assert all([model_run_result.status == RunStatus.Success for model_run_result in run.results]) + check_relations_equal(project.adapter, [relation_name, expected__post_seed_name])