Skip to content

Commit

Permalink
Refreshing short-live access token (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kosta91 authored Dec 28, 2023
1 parent 10e8f1c commit b566a94
Showing 1 changed file with 62 additions and 6 deletions.
68 changes: 62 additions & 6 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import logging
import re
import time
import jwt
import os
import requests
from abc import ABCMeta
from collections import defaultdict, deque
from datetime import datetime
from re import Pattern
from textwrap import dedent
from typing import Any, cast, Optional, TYPE_CHECKING
from typing import Any, cast, Optional, TYPE_CHECKING, Dict, List
from urllib import parse

import pandas as pd
Expand Down Expand Up @@ -63,6 +66,8 @@
from superset.utils import core as utils
from superset.utils.core import GenericDataType

from logging import getLogger

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database
Expand Down Expand Up @@ -1373,22 +1378,73 @@ def get_extra_params(database: Database) -> Dict[str, Any]:
:param database: database instance from which to extract extras
"""
from flask import request as r
import requests

session = requests.Session()

AUTH_TOKEN_NAME = "Authorization"
REFRESH_TOKEN_NAME = "X-Refresh-Token"

extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
connect_args: dict = engine_params.get("connect_args", {})
connect_args.update({
"protocol": "https",
"requests_kwargs": {"verify": False}
})

auth_header = r.headers.get("Authorization")
if auth_header:
session.headers["Authorization"] = auth_header
auth_token = r.headers.get(AUTH_TOKEN_NAME)
if auth_token:
refresh_token = r.headers.get(REFRESH_TOKEN_NAME)
if refresh_token:
auth_token = PrestoEngineSpec.__update_auth_header(auth_token, refresh_token)
else:
getLogger().info(f'{REFRESH_TOKEN_NAME} not found in headers')

session.headers[AUTH_TOKEN_NAME] = auth_token
connect_args["requests_session"] = session

engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra

@staticmethod
def __update_auth_header(auth_token, refresh_token):
BEARER = "Bearer"
refresh_timeout = int(os.getenv('REFRESH_TIMEOUT', 600))

token = auth_token.split(None, 1)

if len(token) == 2 and token[0] == BEARER:
decoded_token = jwt.decode(token[1], options={"verify_signature": False})
current_time = int(time.time())
expiration_time = decoded_token.get('exp', 0)
if expiration_time - current_time < refresh_timeout:
new_access_token = PrestoEngineSpec.__refresh_access_token(refresh_token)
if new_access_token:
auth_token = f'{BEARER} {new_access_token}'
return auth_token

@staticmethod
def __refresh_access_token(refresh_token):
access_token_url = os.getenv('ACCESS_TOKEN_URL')
client_id = os.getenv('OIDC_CLIENT_ID')
client_secret = os.getenv('OIDC_CLIENT_SECRET')

payload = {
'grant_type': 'refresh_token',
'client_id': client_id,
'client_secret': client_secret,
'refresh_token': refresh_token
}

try:
response = requests.post(access_token_url, data=payload)
if response.status_code == 200:
new_access_token = response.json().get('access_token')
return new_access_token
else:
getLogger().info(f"Error on processing token: {response.text}")
return None
except requests.RequestException as e:
getLogger().info(f"Error on sending request: {e}")
return None

0 comments on commit b566a94

Please sign in to comment.