From 0a9a49d37c9c19ddb189ba56992eca79e639fda0 Mon Sep 17 00:00:00 2001 From: Kostiantyn Kovalenko Date: Thu, 28 Dec 2023 23:24:37 +0200 Subject: [PATCH] Refreshing short-live access token (#11) --- superset/db_engine_specs/presto.py | 68 +++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index c1f9e2d671692..6bdf2e8d3fcda 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -21,12 +21,15 @@ import logging import re import time +import jwt +import os +import requests from abc import ABCMeta from collections import defaultdict, deque from datetime import datetime from re import Pattern from textwrap import dedent -from typing import Any, cast, Optional, TYPE_CHECKING +from typing import Any, cast, Optional, TYPE_CHECKING, Dict, List from urllib import parse import pandas as pd @@ -64,6 +67,8 @@ from superset.utils import core as utils from superset.utils.core import GenericDataType +from logging import getLogger + if TYPE_CHECKING: # prevent circular imports from superset.models.core import Database @@ -1378,22 +1383,73 @@ def get_extra_params(database: Database) -> Dict[str, Any]: :param database: database instance from which to extract extras """ from flask import request as r - import requests + session = requests.Session() + + AUTH_TOKEN_NAME = "Authorization" + REFRESH_TOKEN_NAME = "X-Refresh-Token" extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) engine_params = extra.get("engine_params", {}) - connect_args = engine_params.get("connect_args", {}) + connect_args: dict = engine_params.get("connect_args", {}) connect_args.update({ "protocol": "https", "requests_kwargs": {"verify": False} }) - auth_header = r.headers.get("Authorization") - if auth_header: - session.headers["Authorization"] = auth_header + auth_token = r.headers.get(AUTH_TOKEN_NAME) + if auth_token: + refresh_token = r.headers.get(REFRESH_TOKEN_NAME) + if refresh_token: + auth_token = PrestoEngineSpec.__update_auth_header(auth_token, refresh_token) + else: + getLogger().info(f'{REFRESH_TOKEN_NAME} not found in headers') + + session.headers[AUTH_TOKEN_NAME] = auth_token connect_args["requests_session"] = session engine_params["connect_args"] = connect_args extra["engine_params"] = engine_params return extra + + @staticmethod + def __update_auth_header(auth_token, refresh_token): + BEARER = "Bearer" + refresh_timeout = int(os.getenv('REFRESH_TIMEOUT', 600)) + + token = auth_token.split(None, 1) + + if len(token) == 2 and token[0] == BEARER: + decoded_token = jwt.decode(token[1], options={"verify_signature": False}) + current_time = int(time.time()) + expiration_time = decoded_token.get('exp', 0) + if expiration_time - current_time < refresh_timeout: + new_access_token = PrestoEngineSpec.__refresh_access_token(refresh_token) + if new_access_token: + auth_token = f'{BEARER} {new_access_token}' + return auth_token + + @staticmethod + def __refresh_access_token(refresh_token): + access_token_url = os.getenv('ACCESS_TOKEN_URL') + client_id = os.getenv('OIDC_CLIENT_ID') + client_secret = os.getenv('OIDC_CLIENT_SECRET') + + payload = { + 'grant_type': 'refresh_token', + 'client_id': client_id, + 'client_secret': client_secret, + 'refresh_token': refresh_token + } + + try: + response = requests.post(access_token_url, data=payload) + if response.status_code == 200: + new_access_token = response.json().get('access_token') + return new_access_token + else: + getLogger().info(f"Error on processing token: {response.text}") + return None + except requests.RequestException as e: + getLogger().info(f"Error on sending request: {e}") + return None