diff --git a/README.md b/README.md
index ff811492..f83df9a1 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
 | aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` |
 | aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` |
 | work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
+| skip_workgroup_check | Indicates whether the WorkGroup check (an additional AWS call) can be skipped | Optional | `true` |
 | num_retries | Number of times to retry a failing query | Optional | `3` |
 | num_boto3_retries | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional | `5` |
 | num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` |
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index 8536ae3c..39e4c59e 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -56,6 +56,7 @@ class AthenaCredentials(Credentials):
     region_name: str
     endpoint_url: Optional[str] = None
     work_group: Optional[str] = None
+    skip_workgroup_check: bool = False
     aws_profile_name: Optional[str] = None
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
@@ -91,6 +92,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
         return (
             "s3_staging_dir",
             "work_group",
+            "skip_workgroup_check",
             "region_name",
             "database",
             "schema",
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index 65a3c1ba..2ce693ed 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -79,6 +79,7 @@ class AthenaConfig(AdapterConfig):

     Args:
         work_group: Identifier of Athena workgroup.
+        skip_workgroup_check: Indicates whether the WorkGroup check (an additional AWS call) can be skipped.
         s3_staging_dir: S3 location to store Athena query results and metadata.
         external_location: If set, the full S3 path in which the table will be saved.
         partitioned_by: An array list of columns by which the table will be partitioned.
@@ -102,6 +103,7 @@ class AthenaConfig(AdapterConfig):
     """

     work_group: Optional[str] = None
+    skip_workgroup_check: bool = False
     s3_staging_dir: Optional[str] = None
     external_location: Optional[str] = None
     partitioned_by: Optional[str] = None
@@ -240,7 +242,7 @@ def is_work_group_output_location_enforced(self) -> bool:
         conn = self.connections.get_thread_connection()
         creds = conn.credentials
-        if creds.work_group:
+        if creds.work_group and not creds.skip_workgroup_check:
             work_group = self._get_work_group(creds.work_group)
             output_location = (
                 work_group.get("WorkGroup", {})
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 3ae756bc..e3c3b923 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -39,28 +39,7 @@ class TestAthenaAdapter:
     def setup_method(self, _):
-        project_cfg = {
-            "name": "X",
-            "version": "0.1",
-            "profile": "test",
-            "project-root": "/tmp/dbt/does-not-exist",
-            "config-version": 2,
-        }
-        profile_cfg = {
-            "outputs": {
-                "test": {
-                    "type": "athena",
-                    "s3_staging_dir": S3_STAGING_DIR,
-                    "region_name": AWS_REGION,
-                    "database": DATA_CATALOG_NAME,
-                    "work_group": ATHENA_WORKGROUP,
-                    "schema": DATABASE_NAME,
-                }
-            },
-            "target": "test",
-        }
-
-        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
+        self.config = TestAthenaAdapter._config_from_settings()
         self._adapter = None
         self.used_schemas = frozenset(
             {
@@ -79,6 +58,35 @@ def adapter(self):
         inject_adapter(self._adapter, AthenaPlugin)
         return self._adapter

+    @staticmethod
+    def _config_from_settings(settings={}):
+        project_cfg = {
+            "name": "X",
+            "version": "0.1",
+            "profile": "test",
+            "project-root": "/tmp/dbt/does-not-exist",
+            "config-version": 2,
+        }
+
+        profile_cfg = {
+            "outputs": {
+                "test": {
+                    **{
+                        "type": "athena",
+                        "s3_staging_dir": S3_STAGING_DIR,
+                        "region_name": AWS_REGION,
+                        "database": DATA_CATALOG_NAME,
+                        "work_group": ATHENA_WORKGROUP,
+                        "schema": DATABASE_NAME,
+                    },
+                    **settings,
+                }
+            },
+            "target": "test",
+        }
+
+        return config_from_parts_or_dicts(project_cfg, profile_cfg)
+
     @mock.patch("dbt.adapters.athena.connections.AthenaConnection")
     def test_acquire_connection_validations(self, connection_cls):
         try:
@@ -931,6 +939,17 @@ def test_get_work_group_output_location(self, mock_aws_service):
         work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
         assert work_group_location_enforced

+    def test_get_work_group_output_location_if_workgroup_check_is_skipped(self):
+        settings = {
+            "skip_workgroup_check": True,
+        }
+
+        self.config = TestAthenaAdapter._config_from_settings(settings)
+        self.adapter.acquire_connection("dummy")
+
+        work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
+        assert not work_group_location_enforced
+
     @mock_aws
     def test_get_work_group_output_location_no_location(self, mock_aws_service):
         self.adapter.acquire_connection("dummy")
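For reference, a minimal sketch of a `profiles.yml` target that opts into the new flag; the profile name, region, bucket, schema, and workgroup values below are placeholders, not values taken from this change:

```yaml
athena:
  target: dev
  outputs:
    dev:
      type: athena
      region_name: eu-west-1                          # placeholder region
      database: awsdatacatalog
      schema: analytics                                # placeholder schema
      s3_staging_dir: s3://my-query-results-bucket/    # placeholder bucket
      work_group: my-custom-workgroup
      # Skip the extra workgroup lookup (an additional AWS call) when the
      # workgroup is known not to enforce a query result output location.
      skip_workgroup_check: true
```

With `skip_workgroup_check: true`, `is_work_group_output_location_enforced` short-circuits to `False` without looking up the workgroup, which is the behaviour asserted by the new unit test.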