diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 8aef6c1359b5e..7330446d47ed6 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -93,7 +93,7 @@ def __init__(self, data: dict[str, Any]): self._context = context self._uri = uri - def run(self) -> None: # pylint: disable=too-many-statements + def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches self.validate() ex_str = "" ssh_tunnel = self._properties.get("ssh_tunnel") @@ -225,6 +225,10 @@ def ping(engine: Engine) -> bool: # bubble up the exception to return proper status code raise except Exception as ex: + if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2( + ex + ): + database.start_oauth2_dance() event_logger.log_with_context( action=get_log_connection_action( "test_connection_error", ssh_tunnel, ex diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index dcdfff6c3f242..8622a76a50b6d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. + :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: Dictionary with different costs @@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments cursor = conn.cursor() return [ cls.estimate_statement_cost( + database, cls.process_statement(statement, database), cursor, ) @@ -1793,8 +1797,9 @@ def get_url_for_impersonation( return url @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -1804,6 +1809,7 @@ def update_impersonation_config( Update a configuration dictionary that can set the correct properties for impersonating users + :param connect_args: a Database object :param connect_args: config to be updated :param uri: URI :param username: Effective username diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 70373927d521b..6281c6b3b0ff3 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: + """ + Run a SQL query that estimates the cost of a given statement. + :param database: A Database object + :param statement: A single SQL statement + :param cursor: Cursor instance + :return: JSON response from Trino + """ sql = f"EXPLAIN {statement}" cursor.execute(sql) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index f0664564f872c..df5e1c643fa1f 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -365,9 +365,12 @@ def get_schema_from_engine_params( return parse.unquote(database.split("/")[1]) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. + :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: JSON response from Trino @@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return version is not None and Version(version) >= Version("0.319") @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -955,6 +959,8 @@ def update_impersonation_config( """ Update a configuration dictionary that can set the correct properties for impersonating users + + :param connect_args: the Database object :param connect_args: config to be updated :param uri: URI string :param username: Effective username diff --git a/superset/models/core.py b/superset/models/core.py index 5d3a6ea74ddab..6e439fc72ef0e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -29,6 +29,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache +from inspect import signature from typing import Any, Callable, cast, TYPE_CHECKING import numpy @@ -510,12 +511,22 @@ def _get_sqla_engine( # pylint: disable=too-many-locals logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) if self.impersonate_user: - self.db_engine_spec.update_impersonation_config( - connect_args, - str(sqlalchemy_url), - effective_username, - access_token, - ) + # Checking if the function signature can accept database as a param + if "database" in signature(self.db_engine_spec.update_impersonation_config): + self.db_engine_spec.update_impersonation_config( + self, + connect_args, + str(sqlalchemy_url), + effective_username, + access_token, + ) + else: + self.db_engine_spec.update_impersonation_config( + connect_args, + str(sqlalchemy_url), + effective_username, + access_token, + ) if connect_args: params["connect_args"] = connect_args diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index b889ef83c5e75..08a081862f7d8 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -192,3 +192,4 @@ class OAuth2ClientConfigSchema(Schema): ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) + project_id = fields.String(required=False)