diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index a5108b39..ea422e02 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -19,7 +19,6 @@ from dbt_common.clients.agate_helper import table_from_rows from dbt_common.contracts.constraints import ConstraintType from dbt_common.exceptions import DbtRuntimeError -from mypy_boto3_athena import AthenaClient from mypy_boto3_athena.type_defs import DataCatalogTypeDef, GetWorkGroupOutputTypeDef from mypy_boto3_glue.type_defs import ( ColumnTypeDef, @@ -218,14 +217,11 @@ def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str, lf_permissions.process_permissions(lf_config) @lru_cache() - def _get_work_group(self, client: AthenaClient, work_group: str) -> GetWorkGroupOutputTypeDef: + def _get_work_group(self, work_group: str) -> GetWorkGroupOutputTypeDef: """ helper function to cache the result of the get_work_group to avoid APIs throttling """ - return client.get_work_group(WorkGroup=work_group) - - @available - def is_work_group_output_location_enforced(self) -> bool: + LOGGER.debug("get_work_group for %s", work_group) conn = self.connections.get_thread_connection() creds = conn.credentials client = conn.handle @@ -237,8 +233,15 @@ def is_work_group_output_location_enforced(self) -> bool: config=get_boto3_config(num_retries=creds.effective_num_retries), ) + return athena_client.get_work_group(WorkGroup=work_group) + + @available + def is_work_group_output_location_enforced(self) -> bool: + conn = self.connections.get_thread_connection() + creds = conn.credentials + if creds.work_group: - work_group = self._get_work_group(athena_client, creds.work_group) + work_group = self._get_work_group(creds.work_group) output_location = ( work_group.get("WorkGroup", {}) .get("Configuration", {})