Skip to content

Commit

Permalink
fixage
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch committed Jun 6, 2024
1 parent b6a4033 commit f952323
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 231 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ disable=
output-format=text

# Tells whether to display a full report or only the messages
reports=yes
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
Expand Down
9 changes: 9 additions & 0 deletions docker/pythonpath_dev/superset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class CeleryConfig:

SQLLAB_CTAS_NO_LIMIT = True

DB_NAME = "test"
DB_USER = "superset"
DB_PASSWORD = "superset"
SQLALCHEMY_DATABASE_URI = (
f"{DATABASE_DIALECT}://"
f"{DATABASE_USER}:{DATABASE_PASSWORD}@"
f"{DATABASE_HOST}:{DATABASE_PORT}/{DB_NAME}"
)

#
# Optionally import superset_config_docker.py (which will have been included on
# the PYTHONPATH) in order to allow for local settings to be overridden
Expand Down
16 changes: 6 additions & 10 deletions superset/commands/report/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from celery.exceptions import SoftTimeLimitExceeded
from flask_babel import lazy_gettext as _

from superset import app, jinja_context, security_manager
from superset import app, security_manager
from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import (
AlertQueryError,
Expand Down Expand Up @@ -143,23 +143,19 @@ def _execute_query(self) -> pd.DataFrame:
:raises AlertQueryError: SQL query is not valid
:raises AlertQueryTimeout: The SQL query received a celery soft timeout
"""
sql_template = jinja_context.get_template_processor(
database=self._report_schedule.database
)
rendered_sql = sql_template.process_template(self._report_schedule.sql)
try:
limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
rendered_sql, ALERT_SQL_LIMIT
)

executor, username = get_executor( # pylint: disable=unused-variable
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._report_schedule,
)
user = security_manager.find_user(username)
with override_user(user):
start = default_timer()
df = self._report_schedule.database.get_df(sql=limited_rendered_sql)
df = self._report_schedule.database.get_df(
sql=self._report_schedule.sql,
limit=ALERT_SQL_LIMIT,
render_template=True,
)
stop = default_timer()
logger.info(
"Query for %s took %.2f ms",
Expand Down
61 changes: 22 additions & 39 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,13 @@ def type_generic(self) -> utils.GenericDataType | None:
def get_sqla_col(
self,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
type_ = column_spec.sqla_type if column_spec else None
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(expression)
expression = self.table.render_sql(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
Expand All @@ -328,14 +326,12 @@ def get_timestamp_expression(
self,
time_grain: str | None,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> TimestampExpression | Label:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
:param time_grain: Optional time grain, e.g. P1Y
:param label: alias/label that column is expected to have
:param template_processor: template processor
:return: A TimeExpression object wrapped in a Label if supported by db
"""
label = label or utils.DTTM_ALIAS
Expand All @@ -350,8 +346,7 @@ def get_timestamp_expression(
sqla_col = column(self.column_name, type_=type_)
return self.database.make_sqla_column_compatible(sqla_col, label)
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(expression)
expression = self.table.render_sql(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
Expand Down Expand Up @@ -426,12 +421,10 @@ def __repr__(self) -> str:
def get_sqla_col(
self,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.metric_name
expression = self.expression
if template_processor:
expression = template_processor.process_template(expression)
expression = self.table.render_sql(expression)

sqla_col: ColumnClause = literal_column(expression)
return self.table.database.make_sqla_column_compatible(sqla_col, label)
Expand Down Expand Up @@ -1233,13 +1226,8 @@ def extra_dict(self) -> dict[str, Any]:

def get_fetch_values_predicate(
self,
template_processor: BaseTemplateProcessor | None = None,
) -> TextClause:
fetch_values_predicate = self.fetch_values_predicate
if template_processor:
fetch_values_predicate = template_processor.process_template(
fetch_values_predicate
)
fetch_values_predicate = self.render_sql(self.fetch_values_predicate)
try:
return self.text(fetch_values_predicate)
except TemplateError as ex:
Expand All @@ -1250,8 +1238,12 @@ def get_fetch_values_predicate(
)
) from ex

def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)
def render_sql(self, sql: str, **kwargs: Any) -> str:
return self.database.render_sql(sql, **self.template_params_dict)

@property
def template_params_dict(self) -> dict[Any, Any]:
return json.json_to_dict(self.template_params)

def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
Expand All @@ -1264,9 +1256,7 @@ def get_sqla_table(self) -> TableClause:
tbl.schema = self.schema
return tbl

def get_from_clause(
self, template_processor: BaseTemplateProcessor | None = None
) -> tuple[TableClause | Alias, str | None]:
def get_from_clause(self) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
Expand All @@ -1275,7 +1265,7 @@ def get_from_clause(
if not self.is_virtual:
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
from_sql = self.get_rendered_sql()
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
Expand All @@ -1298,14 +1288,12 @@ def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
columns_by_name: dict[str, TableColumn],
template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition
:param dict columns_by_name: Columns for the current table
:param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
Expand All @@ -1317,9 +1305,7 @@ def adhoc_metric_to_sqla(
column_name = cast(str, metric_column.get("column_name"))
table_column: TableColumn | None = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col(
template_processor=template_processor
)
sqla_column = table_column.get_sqla_col()
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
Expand All @@ -1328,7 +1314,9 @@ def adhoc_metric_to_sqla(
expression=metric["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
template_processor=get_template_processor(
table=self, database=self.database
),
)
sqla_metric = literal_column(expression)
else:
Expand All @@ -1340,7 +1328,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
self,
col: AdhocColumn,
force_type_check: bool = False,
template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc column into a sqlalchemy column.
Expand All @@ -1349,7 +1336,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
:param force_type_check: Should the column type be checked in the db.
This is needed to validate if a filter with an adhoc column
is applicable.
:param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
Expand All @@ -1358,24 +1344,24 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
expression=col["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
template_processor=get_template_processor(
table=self, database=self.database
),
)
time_grain = col.get("timeGrain")
has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
is_dttm = False
pdf = None
if col_in_metadata := self.get_column(expression):
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
)
sqla_column = col_in_metadata.get_sqla_col()
is_dttm = col_in_metadata.is_temporal
pdf = col_in_metadata.python_date_format
else:
sqla_column = literal_column(expression)
if has_timegrain or force_type_check:
try:
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
tbl, _ = self.get_from_clause()
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(
Expand Down Expand Up @@ -1435,7 +1421,6 @@ def _get_series_orderby(
series_limit_metric: Metric,
metrics_by_name: dict[str, SqlMetric],
columns_by_name: dict[str, TableColumn],
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
Expand All @@ -1444,9 +1429,7 @@ def _get_series_orderby(
isinstance(series_limit_metric, str)
and series_limit_metric in metrics_by_name
):
ob = metrics_by_name[series_limit_metric].get_sqla_col(
template_processor=template_processor
)
ob = metrics_by_name[series_limit_metric].get_sqla_col()
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
Expand Down
18 changes: 3 additions & 15 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
SupersetSecurityException,
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType

Expand Down Expand Up @@ -102,9 +101,7 @@ def get_virtual_table_metadata(dataset: Dataset) -> list[ResultSetColumnType]:
)

db_engine_spec = dataset.database.db_engine_spec
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
sql = dataset.render_sql(dataset.sql, **dataset.template_params_dict)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
Expand Down Expand Up @@ -137,18 +134,9 @@ def get_columns_description(
schema: str | None,
query: str,
) -> list[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
db_engine_spec = database.db_engine_spec
try:
with database.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query, database)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
result_set = database.get_result_set(query, catalog, schema, limit=1)
return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex

Expand Down
Loading

0 comments on commit f952323

Please sign in to comment.