From 51fa2b3c3186fbd366db4e891c21df928dbf900b Mon Sep 17 00:00:00 2001 From: Anton Karpets Date: Mon, 16 Oct 2023 11:15:04 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9BAirbyte=20CDK:=20wrap=20HTTP=20erro?= =?UTF-8?q?r=20with=20status=20code=20400=20in=20AirbyteTracedException=20?= =?UTF-8?q?(#31207)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../requests_native_auth/abstract_oauth.py | 31 ++++++++++++++- .../http/requests_native_auth/oauth.py | 16 ++++++-- .../test_requests_native_auth.py | 39 ++++++++++++++++++- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index bc9b8e63883b..22e2caa6a2e8 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -4,14 +4,16 @@ import logging from abc import abstractmethod +from json import JSONDecodeError from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union import backoff import pendulum import requests -from airbyte_cdk.models import Level +from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.http_logger import format_http_message from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository +from airbyte_cdk.utils import AirbyteTracedException from requests.auth import AuthBase from ..exceptions import DefaultBackoffException @@ -29,6 +31,20 @@ class AbstractOauth2Authenticator(AuthBase): _NO_STREAM_NAME = None + def __init__( + self, + refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_key: str = "", + refresh_token_error_values: Tuple[str, ...] = (), + ) -> None: + """ + If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, + then http errors with such params will be wrapped in AirbyteTracedException. + """ + self._refresh_token_error_status_codes = refresh_token_error_status_codes + self._refresh_token_error_key = refresh_token_error_key + self._refresh_token_error_values = refresh_token_error_values + def __call__(self, request: requests.Request) -> requests.Request: """Attach the HTTP headers required to authenticate on the HTTP request""" request.headers.update(self.get_auth_header()) @@ -75,6 +91,16 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: return payload + def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestException) -> bool: + try: + exception_content = exception.response.json() + except JSONDecodeError: + return False + return ( + exception.response.status_code in self._refresh_token_error_status_codes + and exception_content.get(self._refresh_token_error_key) in self._refresh_token_error_values + ) + @backoff.on_exception( backoff.expo, DefaultBackoffException, @@ -92,6 +118,9 @@ def _get_refresh_access_token_response(self): except requests.exceptions.RequestException as e: if e.response.status_code == 429 or e.response.status_code >= 500: raise DefaultBackoffException(request=e.response.request, response=e.response) + if self._wrap_refresh_token_exception(e): + message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings." + raise AirbyteTracedException(internal_message=message, message=message, failure_type=FailureType.config_error) raise except Exception as e: raise Exception(f"Error while refreshing access token: {e}") from e diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 6a92fad6af44..48a855fa515f 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -32,6 +32,9 @@ def __init__( refresh_request_body: Mapping[str, Any] = None, grant_type: str = "refresh_token", token_expiry_is_time_of_expiration: bool = False, + refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_key: str = "", + refresh_token_error_values: Tuple[str, ...] = (), ): self._token_refresh_endpoint = token_refresh_endpoint self._client_secret = client_secret @@ -47,6 +50,7 @@ def __init__( self._token_expiry_date_format = token_expiry_date_format self._token_expiry_is_time_of_expiration = token_expiry_is_time_of_expiration self._access_token = None + super().__init__(refresh_token_error_status_codes, refresh_token_error_key, refresh_token_error_values) def get_token_refresh_endpoint(self) -> str: return self._token_refresh_endpoint @@ -103,8 +107,9 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator): Authenticator that should be used for API implementing single use refresh tokens: when refreshing access token some API returns a new refresh token that needs to used in the next refresh flow. This authenticator updates the configuration with new refresh token by emitting Airbyte control message from an observed mutation. - By default this authenticator expects a connector config with a"credentials" field with the following nested fields: client_id, client_secret, refresh_token. - This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path, client_secret_config_path, refresh_token_config_path constructor arguments. + By default, this authenticator expects a connector config with a "credentials" field with the following nested fields: client_id, + client_secret, refresh_token. This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path, + client_secret_config_path, refresh_token_config_path constructor arguments. """ def __init__( @@ -125,9 +130,11 @@ def __init__( token_expiry_date_format: Optional[str] = None, message_repository: MessageRepository = NoopMessageRepository(), token_expiry_is_time_of_expiration: bool = False, + refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_key: str = "", + refresh_token_error_values: Tuple[str, ...] = (), ): """ - Args: connector_config (Mapping[str, Any]): The full connector configuration token_refresh_endpoint (str): Full URL to the token refresh endpoint @@ -170,6 +177,9 @@ def __init__( grant_type=grant_type, token_expiry_date_format=token_expiry_date_format, token_expiry_is_time_of_expiration=token_expiry_is_time_of_expiration, + refresh_token_error_status_codes=refresh_token_error_status_codes, + refresh_token_error_key=refresh_token_error_key, + refresh_token_error_values=refresh_token_error_values, ) def get_refresh_token_name(self) -> str: diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index e4d0c9e23297..8af1199dea7e 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -11,7 +11,7 @@ import pendulum import pytest import requests -from airbyte_cdk.models import OrchestratorType, Type +from airbyte_cdk.models import FailureType, OrchestratorType, Type from airbyte_cdk.sources.streams.http.requests_native_auth import ( BasicHttpAuthenticator, MultipleTokenAuthenticator, @@ -19,7 +19,9 @@ SingleUseRefreshTokenOauth2Authenticator, TokenAuthenticator, ) +from airbyte_cdk.utils import AirbyteTracedException from requests import Response +from requests.exceptions import RequestException LOGGER = logging.getLogger(__name__) @@ -252,6 +254,41 @@ def test_auth_call_method(self, mocker): assert {"Authorization": "Bearer access_token"} == prepared_request.headers + @pytest.mark.parametrize( + ("config_codes", "response_code", "config_key", "response_key", "config_values", "response_value", "wrapped"), + ( + ((400,), 400, "error", "error", ("invalid_grant",), "invalid_grant", True), + ((401,), 400, "error", "error", ("invalid_grant",), "invalid_grant", False), + ((400,), 400, "error_key", "error", ("invalid_grant",), "invalid_grant", False), + ((400,), 400, "error", "error", ("invalid_grant",), "valid_grant", False), + ((), 400, "", "error", (), "valid_grant", False), + ), + ) + def test_refresh_access_token_wrapped( + self, requests_mock, config_codes, response_code, config_key, response_key, config_values, response_value, wrapped + ): + oauth = Oauth2Authenticator( + f"https://{TestOauth2Authenticator.refresh_endpoint}", + TestOauth2Authenticator.client_id, + TestOauth2Authenticator.client_secret, + TestOauth2Authenticator.refresh_token, + refresh_token_error_status_codes=config_codes, + refresh_token_error_key=config_key, + refresh_token_error_values=config_values, + ) + error_content = {response_key: response_value} + requests_mock.post(f"https://{TestOauth2Authenticator.refresh_endpoint}", status_code=response_code, json=error_content) + + exception_to_raise = AirbyteTracedException if wrapped else RequestException + with pytest.raises(exception_to_raise) as exc_info: + oauth.refresh_access_token() + + if wrapped: + error_message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings." + assert exc_info.value.internal_message == error_message + assert exc_info.value.message == error_message + assert exc_info.value.failure_type == FailureType.config_error + class TestSingleUseRefreshTokenOauth2Authenticator: @pytest.fixture