Skip to content

Commit

Permalink
🐛Airbyte CDK: wrap HTTP error with status code 400 in AirbyteTracedEx…
Browse files Browse the repository at this point in the history
…ception (#31207)
  • Loading branch information
Anton Karpets authored Oct 16, 2023
1 parent 99b9fc9 commit 51fa2b3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

import logging
from abc import abstractmethod
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import pendulum
import requests
from airbyte_cdk.models import Level
from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from requests.auth import AuthBase

from ..exceptions import DefaultBackoffException
Expand All @@ -29,6 +31,20 @@ class AbstractOauth2Authenticator(AuthBase):

_NO_STREAM_NAME = None

def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values

def __call__(self, request: requests.Request) -> requests.Request:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
Expand Down Expand Up @@ -75,6 +91,16 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestException) -> bool:
try:
exception_content = exception.response.json()
except JSONDecodeError:
return False
return (
exception.response.status_code in self._refresh_token_error_status_codes
and exception_content.get(self._refresh_token_error_key) in self._refresh_token_error_values
)

@backoff.on_exception(
backoff.expo,
DefaultBackoffException,
Expand All @@ -92,6 +118,9 @@ def _get_refresh_access_token_response(self):
except requests.exceptions.RequestException as e:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
if self._wrap_refresh_token_exception(e):
message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
raise AirbyteTracedException(internal_message=message, message=message, failure_type=FailureType.config_error)
raise
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
token_expiry_is_time_of_expiration: bool = False,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
):
self._token_refresh_endpoint = token_refresh_endpoint
self._client_secret = client_secret
Expand All @@ -47,6 +50,7 @@ def __init__(
self._token_expiry_date_format = token_expiry_date_format
self._token_expiry_is_time_of_expiration = token_expiry_is_time_of_expiration
self._access_token = None
super().__init__(refresh_token_error_status_codes, refresh_token_error_key, refresh_token_error_values)

def get_token_refresh_endpoint(self) -> str:
return self._token_refresh_endpoint
Expand Down Expand Up @@ -103,8 +107,9 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
Authenticator that should be used for API implementing single use refresh tokens:
when refreshing access token some API returns a new refresh token that needs to used in the next refresh flow.
This authenticator updates the configuration with new refresh token by emitting Airbyte control message from an observed mutation.
By default this authenticator expects a connector config with a"credentials" field with the following nested fields: client_id, client_secret, refresh_token.
This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path, client_secret_config_path, refresh_token_config_path constructor arguments.
By default, this authenticator expects a connector config with a "credentials" field with the following nested fields: client_id,
client_secret, refresh_token. This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path,
client_secret_config_path, refresh_token_config_path constructor arguments.
"""

def __init__(
Expand All @@ -125,9 +130,11 @@ def __init__(
token_expiry_date_format: Optional[str] = None,
message_repository: MessageRepository = NoopMessageRepository(),
token_expiry_is_time_of_expiration: bool = False,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
):
"""
Args:
connector_config (Mapping[str, Any]): The full connector configuration
token_refresh_endpoint (str): Full URL to the token refresh endpoint
Expand Down Expand Up @@ -170,6 +177,9 @@ def __init__(
grant_type=grant_type,
token_expiry_date_format=token_expiry_date_format,
token_expiry_is_time_of_expiration=token_expiry_is_time_of_expiration,
refresh_token_error_status_codes=refresh_token_error_status_codes,
refresh_token_error_key=refresh_token_error_key,
refresh_token_error_values=refresh_token_error_values,
)

def get_refresh_token_name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import pendulum
import pytest
import requests
from airbyte_cdk.models import OrchestratorType, Type
from airbyte_cdk.models import FailureType, OrchestratorType, Type
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Oauth2Authenticator,
SingleUseRefreshTokenOauth2Authenticator,
TokenAuthenticator,
)
from airbyte_cdk.utils import AirbyteTracedException
from requests import Response
from requests.exceptions import RequestException

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -252,6 +254,41 @@ def test_auth_call_method(self, mocker):

assert {"Authorization": "Bearer access_token"} == prepared_request.headers

@pytest.mark.parametrize(
("config_codes", "response_code", "config_key", "response_key", "config_values", "response_value", "wrapped"),
(
((400,), 400, "error", "error", ("invalid_grant",), "invalid_grant", True),
((401,), 400, "error", "error", ("invalid_grant",), "invalid_grant", False),
((400,), 400, "error_key", "error", ("invalid_grant",), "invalid_grant", False),
((400,), 400, "error", "error", ("invalid_grant",), "valid_grant", False),
((), 400, "", "error", (), "valid_grant", False),
),
)
def test_refresh_access_token_wrapped(
self, requests_mock, config_codes, response_code, config_key, response_key, config_values, response_value, wrapped
):
oauth = Oauth2Authenticator(
f"https://{TestOauth2Authenticator.refresh_endpoint}",
TestOauth2Authenticator.client_id,
TestOauth2Authenticator.client_secret,
TestOauth2Authenticator.refresh_token,
refresh_token_error_status_codes=config_codes,
refresh_token_error_key=config_key,
refresh_token_error_values=config_values,
)
error_content = {response_key: response_value}
requests_mock.post(f"https://{TestOauth2Authenticator.refresh_endpoint}", status_code=response_code, json=error_content)

exception_to_raise = AirbyteTracedException if wrapped else RequestException
with pytest.raises(exception_to_raise) as exc_info:
oauth.refresh_access_token()

if wrapped:
error_message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
assert exc_info.value.internal_message == error_message
assert exc_info.value.message == error_message
assert exc_info.value.failure_type == FailureType.config_error


class TestSingleUseRefreshTokenOauth2Authenticator:
@pytest.fixture
Expand Down

0 comments on commit 51fa2b3

Please sign in to comment.