diff --git a/.pylintrc b/.pylintrc index e11376fc1053b..46c5cf207176c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -89,7 +89,7 @@ disable= output-format=text # Tells whether to display a full report or only the messages -reports=yes +reports=no # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py index 69477d56e706d..fdfd1e1262350 100644 --- a/docker/pythonpath_dev/superset_config.py +++ b/docker/pythonpath_dev/superset_config.py @@ -100,6 +100,15 @@ class CeleryConfig: SQLLAB_CTAS_NO_LIMIT = True +DB_NAME = "test" +DB_USER = "superset" +DB_PASSWORD = "superset" +SQLALCHEMY_DATABASE_URI = ( + f"{DATABASE_DIALECT}://" + f"{DATABASE_USER}:{DATABASE_PASSWORD}@" + f"{DATABASE_HOST}:{DATABASE_PORT}/{DB_NAME}" +) + # # Optionally import superset_config_docker.py (which will have been included on # the PYTHONPATH) in order to allow for local settings to be overridden diff --git a/superset/commands/report/alert.py b/superset/commands/report/alert.py index 30861bddaa2d9..857a0fa42241a 100644 --- a/superset/commands/report/alert.py +++ b/superset/commands/report/alert.py @@ -26,7 +26,7 @@ from celery.exceptions import SoftTimeLimitExceeded from flask_babel import lazy_gettext as _ -from superset import app, jinja_context, security_manager +from superset import app, security_manager from superset.commands.base import BaseCommand from superset.commands.report.exceptions import ( AlertQueryError, @@ -143,15 +143,7 @@ def _execute_query(self) -> pd.DataFrame: :raises AlertQueryError: SQL query is not valid :raises AlertQueryTimeout: The SQL query received a celery soft timeout """ - sql_template = jinja_context.get_template_processor( - database=self._report_schedule.database - ) - rendered_sql = sql_template.process_template(self._report_schedule.sql) try: - limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql( - rendered_sql, ALERT_SQL_LIMIT - ) - executor, username = get_executor( # pylint: disable=unused-variable executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"], model=self._report_schedule, @@ -159,7 +151,11 @@ def _execute_query(self) -> pd.DataFrame: user = security_manager.find_user(username) with override_user(user): start = default_timer() - df = self._report_schedule.database.get_df(sql=limited_rendered_sql) + df = self._report_schedule.database.get_df( + sql=self._report_schedule.sql, + limit=ALERT_SQL_LIMIT, + render_template=True, + ) stop = default_timer() logger.info( "Query for %s took %.2f ms", diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 518cc9c705bbd..e499b6b74d7a6 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -305,15 +305,13 @@ def type_generic(self) -> utils.GenericDataType | None: def get_sqla_col( self, label: str | None = None, - template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.column_name db_engine_spec = self.db_engine_spec column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) type_ = column_spec.sqla_type if column_spec else None if expression := self.expression: - if template_processor: - expression = template_processor.process_template(expression) + expression = self.table.render_sql(expression) col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) @@ -328,14 +326,12 @@ def get_timestamp_expression( self, time_grain: str | None, label: str | None = None, - template_processor: BaseTemplateProcessor | None = None, ) -> TimestampExpression | Label: """ Return a SQLAlchemy Core element representation of self to be used in a query. :param time_grain: Optional time grain, e.g. P1Y :param label: alias/label that column is expected to have - :param template_processor: template processor :return: A TimeExpression object wrapped in a Label if supported by db """ label = label or utils.DTTM_ALIAS @@ -350,8 +346,7 @@ def get_timestamp_expression( sqla_col = column(self.column_name, type_=type_) return self.database.make_sqla_column_compatible(sqla_col, label) if expression := self.expression: - if template_processor: - expression = template_processor.process_template(expression) + expression = self.table.render_sql(expression) col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) @@ -426,12 +421,10 @@ def __repr__(self) -> str: def get_sqla_col( self, label: str | None = None, - template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.metric_name expression = self.expression - if template_processor: - expression = template_processor.process_template(expression) + expression = self.table.render_sql(expression) sqla_col: ColumnClause = literal_column(expression) return self.table.database.make_sqla_column_compatible(sqla_col, label) @@ -1233,13 +1226,8 @@ def extra_dict(self) -> dict[str, Any]: def get_fetch_values_predicate( self, - template_processor: BaseTemplateProcessor | None = None, ) -> TextClause: - fetch_values_predicate = self.fetch_values_predicate - if template_processor: - fetch_values_predicate = template_processor.process_template( - fetch_values_predicate - ) + fetch_values_predicate = self.render_sql(self.fetch_values_predicate) try: return self.text(fetch_values_predicate) except TemplateError as ex: @@ -1250,8 +1238,12 @@ def get_fetch_values_predicate( ) ) from ex - def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: - return get_template_processor(table=self, database=self.database, **kwargs) + def render_sql(self, sql: str, **kwargs: Any) -> str: + return self.database.render_sql(sql, **self.template_params_dict) + + @property + def template_params_dict(self) -> dict[Any, Any]: + return json.json_to_dict(self.template_params) def get_query_str(self, query_obj: QueryObjectDict) -> str: query_str_ext = self.get_query_str_extended(query_obj) @@ -1264,9 +1256,7 @@ def get_sqla_table(self) -> TableClause: tbl.schema = self.schema return tbl - def get_from_clause( - self, template_processor: BaseTemplateProcessor | None = None - ) -> tuple[TableClause | Alias, str | None]: + def get_from_clause(self) -> tuple[TableClause | Alias, str | None]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a @@ -1275,7 +1265,7 @@ def get_from_clause( if not self.is_virtual: return self.get_sqla_table(), None - from_sql = self.get_rendered_sql(template_processor) + from_sql = self.get_rendered_sql() parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() @@ -1298,14 +1288,12 @@ def adhoc_metric_to_sqla( self, metric: AdhocMetric, columns_by_name: dict[str, TableColumn], - template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table - :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -1317,9 +1305,7 @@ def adhoc_metric_to_sqla( column_name = cast(str, metric_column.get("column_name")) table_column: TableColumn | None = columns_by_name.get(column_name) if table_column: - sqla_column = table_column.get_sqla_col( - template_processor=template_processor - ) + sqla_column = table_column.get_sqla_col() else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) @@ -1328,7 +1314,9 @@ def adhoc_metric_to_sqla( expression=metric["sqlExpression"], database_id=self.database_id, schema=self.schema, - template_processor=template_processor, + template_processor=get_template_processor( + table=self, database=self.database + ), ) sqla_metric = literal_column(expression) else: @@ -1340,7 +1328,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals self, col: AdhocColumn, force_type_check: bool = False, - template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc column into a sqlalchemy column. @@ -1349,7 +1336,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals :param force_type_check: Should the column type be checked in the db. This is needed to validate if a filter with an adhoc column is applicable. - :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -1358,16 +1344,16 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals expression=col["sqlExpression"], database_id=self.database_id, schema=self.schema, - template_processor=template_processor, + template_processor=get_template_processor( + table=self, database=self.database + ), ) time_grain = col.get("timeGrain") has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain is_dttm = False pdf = None if col_in_metadata := self.get_column(expression): - sqla_column = col_in_metadata.get_sqla_col( - template_processor=template_processor - ) + sqla_column = col_in_metadata.get_sqla_col() is_dttm = col_in_metadata.is_temporal pdf = col_in_metadata.python_date_format else: @@ -1375,7 +1361,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals if has_timegrain or force_type_check: try: # probe adhoc column type - tbl, _ = self.get_from_clause(template_processor) + tbl, _ = self.get_from_clause() qry = sa.select([sqla_column]).limit(1).select_from(tbl) sql = self.database.compile_sqla_query(qry) col_desc = get_columns_description( @@ -1435,7 +1421,6 @@ def _get_series_orderby( series_limit_metric: Metric, metrics_by_name: dict[str, SqlMetric], columns_by_name: dict[str, TableColumn], - template_processor: BaseTemplateProcessor | None = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) @@ -1444,9 +1429,7 @@ def _get_series_orderby( isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col( - template_processor=template_processor - ) + ob = metrics_by_name[series_limit_metric].get_sqla_col() else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index bbd3d436f5362..a149643ce08ad 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -37,7 +37,6 @@ SupersetSecurityException, ) from superset.models.core import Database -from superset.result_set import SupersetResultSet from superset.sql_parse import ParsedQuery, Table from superset.superset_typing import ResultSetColumnType @@ -102,9 +101,7 @@ def get_virtual_table_metadata(dataset: Dataset) -> list[ResultSetColumnType]: ) db_engine_spec = dataset.database.db_engine_spec - sql = dataset.get_template_processor().process_template( - dataset.sql, **dataset.template_params_dict - ) + sql = dataset.render_sql(dataset.sql, **dataset.template_params_dict) parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine) if not db_engine_spec.is_readonly_query(parsed_query): raise SupersetSecurityException( @@ -137,18 +134,9 @@ def get_columns_description( schema: str | None, query: str, ) -> list[ResultSetColumnType]: - # TODO(villebro): refactor to use same code that's used by - # sql_lab.py:execute_sql_statements - db_engine_spec = database.db_engine_spec try: - with database.get_raw_connection(catalog=catalog, schema=schema) as conn: - cursor = conn.cursor() - query = database.apply_limit_to_sql(query, limit=1) - cursor.execute(query) - db_engine_spec.execute(cursor, query, database) - result = db_engine_spec.fetch_data(cursor, limit=1) - result_set = SupersetResultSet(result, cursor.description, db_engine_spec) - return result_set.columns + result_set = database.get_result_set(query, catalog, schema, limit=1) + return result_set.columns except Exception as ex: raise SupersetGenericDBErrorException(message=str(ex)) from ex diff --git a/superset/models/core.py b/superset/models/core.py index be2963a91127c..f5c0c842a83ad 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -29,7 +29,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -59,7 +59,7 @@ from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import ColumnElement, expression, Select -from superset import app, db_engine_specs, is_feature_enabled +from superset import app, db_engine_specs, is_feature_enabled, jinja_context from superset.commands.database.exceptions import DatabaseInvalidError from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK from superset.databases.utils import make_url_safe @@ -636,14 +636,24 @@ def mutate_sql_based_on_config(self, sql_: str, is_split: bool = False) -> str: ) return sql_ - def get_df( # pylint: disable=too-many-locals + def render_sql(self, sql: str, **kwargs: Any) -> str: + tp = jinja_context.get_template_processor(self) + return tp.process_template(sql, **kwargs) + + def get_result_set( self, sql: str, catalog: str | None = None, schema: str | None = None, - mutator: Callable[[pd.DataFrame], None] | None = None, + limit: Optional[int] = None, + render_template: Optional[bool] = False, ) -> pd.DataFrame: - sqls = self.db_engine_spec.parse_sql(sql) + if render_template: + sql = self.render_sql(sql) + if limit: + sql = self.apply_limit_to_sql(sql, limit) + + statements = self.db_engine_spec.parse_sql(sql) with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url @@ -659,8 +669,7 @@ def _log_query(sql: str) -> None: with self.get_raw_connection(catalog=catalog, schema=schema) as conn: cursor = conn.cursor() - df = None - for i, sql_ in enumerate(sqls): + for i, sql_ in enumerate(statements): sql_ = self.mutate_sql_based_on_config(sql_, is_split=True) _log_query(sql_) with event_logger.log_context( @@ -669,20 +678,31 @@ def _log_query(sql: str) -> None: object_ref=__name__, ): self.db_engine_spec.execute(cursor, sql_, self) - if i < len(sqls) - 1: + if i < len(statements) - 1: # If it's not the last, we don't keep the results cursor.fetchall() else: # Last query, fetch and process the results data = self.db_engine_spec.fetch_data(cursor) - result_set = SupersetResultSet( + return SupersetResultSet( data, cursor.description, self.db_engine_spec ) - df = result_set.to_pandas_df() - if mutator: - df = mutator(df) + return None - return self.post_process_df(df) + def get_df( + self, + sql: str, + catalog: str | None = None, + schema: str | None = None, + mutator: Callable[[pd.DataFrame], None] | None = None, + limit: Optional[int] = None, + render_template: Optional[bool] = False, + ) -> pd.DataFrame: + result_set = self.get_result_set(sql, catalog, schema, limit, render_template) + df = result_set.to_pandas_df() + if mutator: + df = mutator(df) + return self.post_process_df(df) def compile_sqla_query( self, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 48b95566af9bd..52df59b330449 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -68,7 +68,6 @@ SupersetSecurityException, ) from superset.extensions import feature_flag_manager -from superset.jinja_context import BaseTemplateProcessor from superset.sql_parse import ( has_table_query, insert_rls_in_predicate, @@ -144,15 +143,6 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) -def json_to_dict(json_str: str) -> dict[Any, Any]: - if json_str: - val = re.sub(",[ \t\r\n]+}", "}", json_str) - val = re.sub(",[ \t\r\n]+\\]", "]", val) - return json.loads(val) - - return {} - - def convert_uuids(obj: Any) -> Any: """ Convert UUID objects to str so we can use yaml.safe_dump @@ -459,11 +449,7 @@ def reset_ownership(self) -> None: @property def params_dict(self) -> dict[Any, Any]: - return json_to_dict(self.params) - - @property - def template_params_dict(self) -> dict[Any, Any]: - return json_to_dict(self.template_params) # type: ignore + return json.json_to_dict(self.params) def _user(user: User) -> str: @@ -793,38 +779,26 @@ def columns(self) -> list[Any]: def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]: raise NotImplementedError() - def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: + def render_sql(self, sql: str, **kwargs: Any) -> str: raise NotImplementedError() - def get_fetch_values_predicate( - self, - template_processor: Optional[ # pylint: disable=unused-argument - BaseTemplateProcessor - ] = None, - ) -> TextClause: + def get_fetch_values_predicate(self) -> TextClause: return self.fetch_values_predicate - def get_sqla_row_level_filters( - self, - template_processor: Optional[BaseTemplateProcessor] = None, - ) -> list[TextClause]: + def get_sqla_row_level_filters(self) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the current user. A custom username can be passed when the user is not present in the Flask global namespace. - :param template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. """ - template_processor = template_processor or self.get_template_processor() - all_filters: list[TextClause] = [] filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): - clause = self.text( - f"({template_processor.process_template(filter_.clause)})" - ) + snippet = self.render_sql(filter_.clause) + clause = self.text(f"({snippet})") if filter_.group_key: filter_groups[filter_.group_key].append(clause) else: @@ -832,9 +806,8 @@ def get_sqla_row_level_filters( if is_feature_enabled("EMBEDDED_SUPERSET"): for rule in security_manager.get_guest_rls_filters(self): - clause = self.text( - f"({template_processor.process_template(rule['clause'])})" - ) + snippet = self.render_sql(rule["clause"]) + clause = self.text(f"({snippet})") all_filters.append(clause) grouped_filters = [or_(*clauses) for clauses in filter_groups.values()] @@ -853,11 +826,9 @@ def _process_sql_expression( expression: Optional[str], database_id: int, schema: str, - template_processor: Optional[BaseTemplateProcessor], ) -> Optional[str]: - if template_processor and expression: - expression = template_processor.process_template(expression) if expression: + expression = self.render_sql(expression) expression = validate_adhoc_subquery( expression, database_id, @@ -1065,23 +1036,20 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: def get_rendered_sql( self, - template_processor: Optional[BaseTemplateProcessor] = None, ) -> str: """ Render sql with template engine (Jinja). """ - sql = self.sql - if template_processor: - try: - sql = template_processor.process_template(sql) - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error while rendering virtual dataset query: %(msg)s", - msg=ex.message, - ) - ) from ex + try: + sql = self.render_sql(self.sql) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error while rendering virtual dataset query: %(msg)s", + msg=ex.message, + ) + ) from ex script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine) if len(script.statements) > 1: @@ -1097,16 +1065,14 @@ def get_rendered_sql( def text(self, clause: str) -> TextClause: return self.db_engine_spec.get_text_clause(clause) - def get_from_clause( - self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> tuple[Union[TableClause, Alias], Optional[str]]: + def get_from_clause(self) -> tuple[Union[TableClause, Alias], Optional[str]]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a CTE, the CTE is returned as the second value in the return tuple. """ - from_sql = self.get_rendered_sql(template_processor) + from_sql = self.get_rendered_sql() parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() @@ -1129,14 +1095,12 @@ def adhoc_metric_to_sqla( self, metric: AdhocMetric, columns_by_name: dict[str, "TableColumn"], # pylint: disable=unused-argument - template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table - :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -1153,7 +1117,6 @@ def adhoc_metric_to_sqla( expression=metric["sqlExpression"], database_id=self.database_id, schema=self.schema, - template_processor=template_processor, ) sqla_metric = literal_column(expression) else: @@ -1237,7 +1200,6 @@ def _get_series_orderby( series_limit_metric: Metric, metrics_by_name: dict[str, "SqlMetric"], columns_by_name: dict[str, "TableColumn"], - template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) @@ -1246,9 +1208,7 @@ def _get_series_orderby( isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col( - template_processor=template_processor - ) + ob = metrics_by_name[series_limit_metric].get_sqla_col() else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) @@ -1259,7 +1219,6 @@ def adhoc_column_to_sqla( self, col: "AdhocColumn", # type: ignore # noqa: F821 force_type_check: bool = False, - template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: raise NotImplementedError() @@ -1322,18 +1281,14 @@ def get_time_filter( # pylint: disable=too-many-arguments end_dttm: Optional[sa.DateTime], time_grain: Optional[str] = None, label: Optional[str] = "__time", - template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: col = ( time_col.get_timestamp_expression( time_grain=time_grain, label=label, - template_processor=template_processor, ) if time_grain - else self.convert_tbl_column_to_sqla_col( - time_col, label=label, template_processor=template_processor - ) + else self.convert_tbl_column_to_sqla_col(time_col, label=label) ) l = [] # noqa: E741 @@ -1369,8 +1324,7 @@ def values_for_column( ) cols = {col.column_name: col for col in self.columns} target_col = cols[column_name_] - tp = self.get_template_processor() - tbl, cte = self.get_from_clause(tp) + tbl, cte = self.get_from_clause() qry = ( sa.select( @@ -1378,7 +1332,7 @@ def values_for_column( # automatically add a random alias to the projection because of the # call to DISTINCT; others will uppercase the column names. This # gives us a deterministic column name in the dataframe. - [target_col.get_sqla_col(template_processor=tp).label("column_values")] + [target_col.get_sqla_col().label("column_values")] ) .select_from(tbl) .distinct() @@ -1387,7 +1341,7 @@ def values_for_column( qry = qry.limit(limit) if self.fetch_values_predicate: - qry = qry.where(self.get_fetch_values_predicate(template_processor=tp)) + qry = qry.where(self.get_fetch_values_predicate()) with self.database.get_sqla_engine() as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -1408,7 +1362,6 @@ def get_timestamp_expression( column: dict[str, Any], time_grain: Optional[str], label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, ) -> Union[TimestampExpression, Label]: """ Return a SQLAlchemy Core element representation of self to be used in a query. @@ -1416,7 +1369,6 @@ def get_timestamp_expression( :param column: column object :param time_grain: Optional time grain, e.g. P1Y :param label: alias/label that column is expected to have - :param template_processor: template processor :return: A TimeExpression object wrapped in a Label if supported by db """ label = label or utils.DTTM_ALIAS @@ -1424,9 +1376,8 @@ def get_timestamp_expression( type_ = column_spec.sqla_type if column_spec else sa.DateTime col = sa.column(column.get("column_name"), type_=type_) - if template_processor: - expression = template_processor.process_template(column["column_name"]) - col = sa.literal_column(expression, type_=type_) + expression = self.render_sql(column["column_name"]) + col = sa.literal_column(expression, type_=type_) time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) return self.make_sqla_column_compatible(time_expr, label) @@ -1435,15 +1386,13 @@ def convert_tbl_column_to_sqla_col( self, tbl_column: "TableColumn", label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: label = label or tbl_column.column_name db_engine_spec = self.db_engine_spec column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) type_ = column_spec.sqla_type if column_spec else None if expression := tbl_column.expression: - if template_processor: - expression = template_processor.process_template(expression) + expression = self.render_sql(expression) col = literal_column(expression, type_=type_) else: col = sa.column(tbl_column.column_name, type_=type_) @@ -1520,7 +1469,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma applied_template_filters: list[str] = [] template_kwargs["removed_filters"] = removed_filters template_kwargs["applied_filters"] = applied_template_filters - template_processor = self.get_template_processor(**template_kwargs) prequeries: list[str] = [] orderby = orderby or [] need_groupby = bool(metrics is not None or groupby) @@ -1556,15 +1504,10 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma self.adhoc_metric_to_sqla( metric=metric, columns_by_name=columns_by_name, - template_processor=template_processor, ) ) elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append( - metrics_by_name[metric].get_sqla_col( - template_processor=template_processor - ) - ) + metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1593,7 +1536,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma expression=col["sqlExpression"], database_id=self.database_id, schema=self.schema, - template_processor=template_processor, ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists @@ -1603,16 +1545,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - col = self.convert_tbl_column_to_sqla_col( - columns_by_name[col], template_processor=template_processor - ) + col = self.convert_tbl_column_to_sqla_col(columns_by_name[col]) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col( - template_processor=template_processor - ) + col = metrics_by_name[col].get_sqla_col() need_groupby = True if isinstance(col, ColumnElement): @@ -1642,13 +1580,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma outer = table_col.get_timestamp_expression( time_grain=time_grain, label=selected, - template_processor=template_processor, ) # if groupby field equals a selected column elif selected in columns_by_name: outer = self.convert_tbl_column_to_sqla_col( columns_by_name[selected], - template_processor=template_processor, ) else: selected = validate_adhoc_subquery( @@ -1659,9 +1595,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: - outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor - ) + outer = self.adhoc_column_to_sqla(col=selected) groupby_all_columns[outer.name] = outer if ( is_timeseries and not series_column_labels @@ -1684,9 +1618,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) select_exprs.append( - self.convert_tbl_column_to_sqla_col( - columns_by_name[selected], template_processor=template_processor - ) + self.convert_tbl_column_to_sqla_col(columns_by_name[selected]) if isinstance(selected, str) and selected in columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label @@ -1705,9 +1637,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_filters = [] if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + timestamp = dttm_col.get_timestamp_expression(time_grain=time_grain) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp @@ -1723,7 +1653,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_col=columns_by_name[self.main_dttm_col], start_dttm=from_dttm, end_dttm=to_dttm, - template_processor=template_processor, ) ) @@ -1731,7 +1660,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_col=dttm_col, start_dttm=from_dttm, end_dttm=to_dttm, - template_processor=template_processor, ) time_filters.append(time_filter_column) @@ -1752,7 +1680,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma qry = sa.select(select_exprs) - tbl, cte = self.get_from_clause(template_processor) + tbl, cte = self.get_from_clause() if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -1790,12 +1718,10 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if sqla_col is not None: pass elif col_obj and filter_grain: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, template_processor=template_processor - ) + sqla_col = col_obj.get_timestamp_expression(time_grain=filter_grain) elif col_obj: sqla_col = self.convert_tbl_column_to_sqla_col( - tbl_column=col_obj, template_processor=template_processor + tbl_column=col_obj, ) col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec(native_type=col_type) @@ -1926,19 +1852,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma end_dttm=_until, time_grain=flt_grain, label=sqla_col.key, - template_processor=template_processor, ) ) else: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - where_clause_and += self.get_sqla_row_level_filters(template_processor) + where_clause_and += self.get_sqla_row_level_filters() if extras: where = extras.get("where") if where: try: - where = template_processor.process_template(f"({where})") + where = self.render_sql(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1950,13 +1875,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma expression=where, database_id=self.database_id, schema=self.schema, - template_processor=template_processor, ) where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(f"({having})") + having = self.render_sql(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1968,14 +1892,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma expression=having, database_id=self.database_id, schema=self.schema, - template_processor=template_processor, ) having_clause_and += [self.text(having)] if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) - ) + qry = qry.where(self.get_fetch_values_predicate()) if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -2031,7 +1952,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_col=dttm_col, start_dttm=inner_from_dttm or from_dttm, end_dttm=inner_to_dttm or to_dttm, - template_processor=template_processor, ) ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) @@ -2043,7 +1963,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma series_limit_metric=series_limit_metric, metrics_by_name=metrics_by_name, columns_by_name=columns_by_name, - template_processor=template_processor, ) direction = sa.desc if order_desc else sa.asc subq = subq.order_by(direction(ob)) @@ -2066,7 +1985,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma series_limit_metric=series_limit_metric, metrics_by_name=metrics_by_name, columns_by_name=columns_by_name, - template_processor=template_processor, ), not order_desc, ) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 3c26c6c12fb81..e0743fbb34921 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -49,7 +49,6 @@ from superset import security_manager from superset.exceptions import SupersetSecurityException -from superset.jinja_context import BaseTemplateProcessor, get_template_processor from superset.models.helpers import ( AuditMixinNullable, ExploreMixin, @@ -158,9 +157,6 @@ class Query( __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),) - def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: - return get_template_processor(query=self, database=self.database, **kwargs) - def to_dict(self) -> dict[str, Any]: return { "changed_on": self.changed_on.isoformat(), @@ -361,12 +357,10 @@ def adhoc_column_to_sqla( self, col: "AdhocColumn", # type: ignore # noqa: F821 force_type_check: bool = False, - template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ Turn an adhoc column into a sqlalchemy column. :param col: Adhoc column definition - :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -375,7 +369,6 @@ def adhoc_column_to_sqla( expression=col["sqlExpression"], database_id=self.database_id, schema=self.schema, - template_processor=template_processor, ) sqla_column = literal_column(expression) return self.make_sqla_column_compatible(sqla_column, label) diff --git a/superset/utils/json.py b/superset/utils/json.py index 0d7e31b9cd8af..3eb10b3db623d 100644 --- a/superset/utils/json.py +++ b/superset/utils/json.py @@ -16,6 +16,7 @@ # under the License. import decimal import logging +import re import uuid from datetime import date, datetime, time, timedelta from typing import Any, Callable, Optional, Union @@ -250,3 +251,11 @@ def loads( except JSONDecodeError as ex: logger.error("JSON is not valid %s", str(ex), exc_info=True) raise ex + + +def json_to_dict(json_str: str) -> dict[Any, Any]: + if json_str: + val = re.sub(",[ \t\r\n]+}", "}", json_str) + val = re.sub(",[ \t\r\n]+\\]", "]", val) + return loads(val) + return {} diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 0107c142409bf..ef9b2446016a9 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging from collections import Counter from typing import Any @@ -57,6 +58,8 @@ from superset.views.datasource.utils import get_samples from superset.views.utils import sanitize_datasource_data +logger = logging.getLogger(__name__) + class Datasource(BaseSupersetView): """Datasource-related views""" @@ -146,6 +149,7 @@ def external_metadata( try: external_metadata = datasource.external_metadata() except SupersetException as ex: + logger.error(ex) return json_error_response(str(ex), status=400) return self.json_response(external_metadata) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index f4b2bf688f3c9..81c1ab674bf74 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -20,19 +20,22 @@ from datetime import datetime import imp from contextlib import contextmanager -from typing import Any, Union, Optional +from typing import Any, Union, Optional, List from unittest.mock import Mock, patch, MagicMock +import logging import pandas as pd from flask import Response, g from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase from sqlalchemy.engine.interfaces import Dialect -from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.orm import Session # noqa: F401 from sqlalchemy.sql import func from sqlalchemy.dialects.mysql import dialect +from sqlalchemy.exc import InvalidRequestError, IntegrityError +from sqlalchemy.ext.declarative import DeclarativeMeta + from tests.integration_tests.test_app import app, login from superset.sql_parse import CtasMethod from superset import db, security_manager @@ -203,8 +206,7 @@ def temporary_user( previous_g_user = g.user if hasattr(g, "user") else None try: if login: - resp = self.login(username=temp_user.username) - print(resp) + self.login(username=temp_user.username) else: g.user = temp_user yield temp_user @@ -592,12 +594,42 @@ def insert_dashboard( @contextmanager -def db_insert_temp_object(obj: DeclarativeMeta): - """Insert a temporary object in database; delete when done.""" +def db_insert_temp_object( + obj: DeclarativeMeta, unique_attrs: Optional[List[str]] = None +): + """ + Insert a temporary object in the database; delete when done. + Optionally will look at a combination of unique keys, and pre-delete if the object exists already. + """ + session = db.session try: - db.session.add(obj) - db.session.commit() + # Ensure the session is clean before starting + session.expire_all() + + if unique_attrs: + filter_by_kwargs = { + attr: getattr(obj, attr) for attr in unique_attrs if hasattr(obj, attr) + } + if filter_by_kwargs: + logging.debug(f"Deleting with filter: {filter_by_kwargs}") + with session.no_autoflush: + session.query(obj.__class__).filter_by(**filter_by_kwargs).delete() + session.commit() + + session.add(obj) + session.commit() yield obj + + except (IntegrityError, InvalidRequestError) as e: + session.rollback() + logging.error(f"Error: {e}") + raise e + finally: - db.session.delete(obj) - db.session.commit() + try: + if False and session.object_session(obj): + session.delete(obj) + session.commit() + except InvalidRequestError as e: + session.rollback() + logging.error(f"Error during cleanup: {e}") diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 69d15833cf44b..acbb9dea270a0 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -49,6 +49,8 @@ ) from tests.integration_tests.fixtures.datasource import get_datasource_post +dataset_unique_attrs = ["database_id", "schema", "table_name"] + @contextmanager def create_test_table_context(database: Database): @@ -270,24 +272,25 @@ def test_external_metadata_by_name_from_sqla_inspector(self): resp = self.get_json_resp(url) self.assertIn("error", resp) + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"ENABLE_TEMPLATE_PROCESSING": True}, + clear=True, + ) def test_external_metadata_for_virtual_table_template_params(self): self.login(ADMIN_USERNAME) table = Dataset( - table_name="dummy_sql_table_with_template_params", + table_name="dummy_sql_table_with_template_params_2", database=get_example_database(), schema=get_example_default_schema(), sql="select {{ foo }} as intcol", template_params=json.dumps({"foo": "123"}), ) - db.session.add(table) - db.session.commit() - - table = self.get_table(name="dummy_sql_table_with_template_params") - url = f"/datasource/external_metadata/table/{table.id}/" - resp = self.get_json_resp(url) - assert {o.get("column_name") for o in resp} == {"intcol"} - db.session.delete(table) - db.session.commit() + with db_insert_temp_object(table, dataset_unique_attrs): + url = f"/datasource/external_metadata/table/{table.id}/" + resp = self.get_json_resp(url) + print(resp) + assert {o.get("column_name") for o in resp} == {"intcol"} def test_external_metadata_for_malicious_virtual_table(self): self.login(ADMIN_USERNAME) @@ -297,7 +300,7 @@ def test_external_metadata_for_malicious_virtual_table(self): schema=get_example_default_schema(), sql="delete table birth_names", ) - with db_insert_temp_object(table): + with db_insert_temp_object(table, dataset_unique_attrs): url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) self.assertEqual(resp["error"], "Only `SELECT` statements are allowed") @@ -311,7 +314,7 @@ def test_external_metadata_for_multistatement_virtual_table(self): sql="select 123 as intcol, 'abc' as strcol;" "select 123 as intcol, 'abc' as strcol", ) - with db_insert_temp_object(table): + with db_insert_temp_object(table, dataset_unique_attrs): url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) self.assertEqual(resp["error"], "Only single queries supported") diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py index 5ef3e2aa0a9c2..721e1aaabb1ea 100644 --- a/tests/integration_tests/superset_test_config.py +++ b/tests/integration_tests/superset_test_config.py @@ -32,7 +32,7 @@ logging.getLogger("flask_appbuilder.security.sqla.manager").setLevel(logging.WARNING) logging.getLogger("sqlalchemy.engine.Engine").setLevel(logging.WARNING) -SECRET_KEY = "dummy_secret_key_for_test_to_silence_warnings" +SECRET_KEY = "TEST_NON_DEV_SECRET" AUTH_USER_REGISTRATION_ROLE = "alpha" SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join( # noqa: F405 DATA_DIR,