Skip to content

Commit

Permalink
feat: skip_workgroup_check setting to reduce AWS throttling (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
amacal authored Sep 12, 2024
1 parent 840e847 commit d2bc485
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 23 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` |
| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| num_boto3_retries | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional | `5` |
| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` |
Expand Down
2 changes: 2 additions & 0 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class AthenaCredentials(Credentials):
region_name: str
endpoint_url: Optional[str] = None
work_group: Optional[str] = None
skip_workgroup_check: bool = False
aws_profile_name: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
Expand Down Expand Up @@ -91,6 +92,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
return (
"s3_staging_dir",
"work_group",
"skip_workgroup_check",
"region_name",
"database",
"schema",
Expand Down
4 changes: 3 additions & 1 deletion dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class AthenaConfig(AdapterConfig):
Args:
work_group: Identifier of Athena workgroup.
skip_workgroup_check: Indicates if the WorkGroup check (additional AWS call) can be skipped.
s3_staging_dir: S3 location to store Athena query results and metadata.
external_location: If set, the full S3 path in which the table will be saved.
partitioned_by: An array list of columns by which the table will be partitioned.
Expand All @@ -102,6 +103,7 @@ class AthenaConfig(AdapterConfig):
"""

work_group: Optional[str] = None
skip_workgroup_check: bool = False
s3_staging_dir: Optional[str] = None
external_location: Optional[str] = None
partitioned_by: Optional[str] = None
Expand Down Expand Up @@ -240,7 +242,7 @@ def is_work_group_output_location_enforced(self) -> bool:
conn = self.connections.get_thread_connection()
creds = conn.credentials

if creds.work_group:
if creds.work_group and not creds.skip_workgroup_check:
work_group = self._get_work_group(creds.work_group)
output_location = (
work_group.get("WorkGroup", {})
Expand Down
63 changes: 41 additions & 22 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,7 @@

class TestAthenaAdapter:
def setup_method(self, _):
project_cfg = {
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"config-version": 2,
}
profile_cfg = {
"outputs": {
"test": {
"type": "athena",
"s3_staging_dir": S3_STAGING_DIR,
"region_name": AWS_REGION,
"database": DATA_CATALOG_NAME,
"work_group": ATHENA_WORKGROUP,
"schema": DATABASE_NAME,
}
},
"target": "test",
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.config = TestAthenaAdapter._config_from_settings()
self._adapter = None
self.used_schemas = frozenset(
{
Expand All @@ -79,6 +58,35 @@ def adapter(self):
inject_adapter(self._adapter, AthenaPlugin)
return self._adapter

@staticmethod
def _config_from_settings(settings={}):
project_cfg = {
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"config-version": 2,
}

profile_cfg = {
"outputs": {
"test": {
**{
"type": "athena",
"s3_staging_dir": S3_STAGING_DIR,
"region_name": AWS_REGION,
"database": DATA_CATALOG_NAME,
"work_group": ATHENA_WORKGROUP,
"schema": DATABASE_NAME,
},
**settings,
}
},
"target": "test",
}

return config_from_parts_or_dicts(project_cfg, profile_cfg)

@mock.patch("dbt.adapters.athena.connections.AthenaConnection")
def test_acquire_connection_validations(self, connection_cls):
try:
Expand Down Expand Up @@ -931,6 +939,17 @@ def test_get_work_group_output_location(self, mock_aws_service):
work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert work_group_location_enforced

def test_get_work_group_output_location_if_workgroup_check_is_skipepd(self):
settings = {
"skip_workgroup_check": True,
}

self.config = TestAthenaAdapter._config_from_settings(settings)
self.adapter.acquire_connection("dummy")

work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert not work_group_location_enforced

@mock_aws
def test_get_work_group_output_location_no_location(self, mock_aws_service):
self.adapter.acquire_connection("dummy")
Expand Down

0 comments on commit d2bc485

Please sign in to comment.