Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add current_user variable and slugify filter to Jinja #1436

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion querybook/server/datasources/query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,10 @@ def get_templated_query(
):
try:
return render_templated_query(
query, var_config_to_var_dict(var_config), engine_id
query,
var_config_to_var_dict(var_config),
engine_id,
current_user.id,
)
except QueryTemplatingError as e:
raise RequestException(e, status_code=INVALID_SEMANTIC_STATUS_CODE)
Expand Down
23 changes: 19 additions & 4 deletions querybook/server/lib/query_analysis/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import meta

from slugify import slugify

from app.db import with_session
from lib import metastore
from logic import admin as admin_logic
from logic import user as user_logic
from models.user import User

_DAG = Dict[str, Set[str]]

Expand Down Expand Up @@ -164,12 +168,19 @@ def get_latest_partition(
return get_latest_partition


def get_templated_query_env(engine_id: int, session=None):
def get_templated_query_env(engine_id: int, user: User, session=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_templated_query_env(engine_id: int, user: User, session=None):
def get_templated_query_env(engine_id: int, user: Optional[User], session=None):

jinja_env = SandboxedEnvironment()

# Inject helper functions
jinja_env.globals.update(
latest_partition=create_get_latest_partition(engine_id, session=session)
latest_partition=create_get_latest_partition(engine_id, session=session),
current_user=user.username if user else None,
current_user_email=user.email if user else None,
)

# Inject filters
jinja_env.filters.update(
slugify=lambda x: slugify(x, separator="_"),
)

# Template rendering config
Expand Down Expand Up @@ -315,7 +326,7 @@ def get_templated_query_variables(variables_provided, jinja_env):


def render_templated_query(
query: str, variables: Dict[str, str], engine_id: int, session=None
query: str, variables: Dict[str, str], engine_id: int, uid: int, session=None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
query: str, variables: Dict[str, str], engine_id: int, uid: int, session=None
query: str, variables: Dict[str, str], engine_id: int, uid: Optional[int] = None, session=None

) -> str:
"""Renders the templated query, with global variables such as today/yesterday
and functions such as `latest_partition`.
Expand All @@ -326,6 +337,8 @@ def render_templated_query(
Arguments:
query {str} -- The query string that would get rendered
raw_variables {Dict[str, str]} -- The variable name, variable value string pair
engine_id {int} -- The engine id that the query is running on
uid {int} -- The id of the user running the query

Raises:
UndefinedVariableException: If the variable refers to a variable that does not exist
Expand All @@ -334,7 +347,9 @@ def render_templated_query(
Returns:
str -- The rendered string
"""
jinja_env = get_templated_query_env(engine_id, session=session)
user = user_logic.get_user_by_id(uid, session=session) if uid else None

jinja_env = get_templated_query_env(engine_id, user, session=session)
try:
escaped_query = _escape_sql_comments(query)
variables_in_query = get_templated_variables_in_string(escaped_query, jinja_env)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def validate_with_templated_vars(
self, query: str, uid: int, engine_id: int, templated_vars: Dict[str, Any]
):
try:
templated_query = render_templated_query(query, templated_vars, engine_id)
templated_query = render_templated_query(
query, templated_vars, engine_id, uid
)
except QueryTemplatingError as e:
return [QueryValidationResult(0, 0, QueryValidationSeverity.ERROR, str(e))]

Expand Down
1 change: 1 addition & 0 deletions querybook/server/tasks/run_datadoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def run_datadoc_with_config(
raw_query,
data_doc.meta_variables,
engine_id,
user_id,
session=session,
)
except Exception as e:
Expand Down
131 changes: 126 additions & 5 deletions querybook/tests/test_lib/test_query_analysis/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class TemplatingTestCase(TestCase):
DEFAULT_ENGINE_ID = 1
DEFAULT_USER_ID = 1

def setUp(self):
self.engine_mock = mock.Mock()
Expand All @@ -35,6 +36,14 @@ def setUp(self):
self.addCleanup(get_metastore_loader_patch.stop)
self.get_metastore_loader_mock.return_value = self.metastore_loader_mock

self.user_mock = mock.Mock()
self.user_mock.username = "test_user"
self.user_mock.email = "[email protected]"
get_user_by_id_patch = mock.patch("logic.user.get_user_by_id")
self.get_user_by_id_mock = get_user_by_id_patch.start()
self.addCleanup(get_user_by_id_patch.stop)
self.get_user_by_id_mock.return_value = self.user_mock


class DetectCycleTestCase(TemplatingTestCase):
def test_simple_no_cycle(self):
Expand Down Expand Up @@ -212,7 +221,9 @@ def test_basic(self):
query = 'select * from table where dt="{{ date }}"'
variable = {"date": "1970-01-01"}
self.assertEqual(
render_templated_query(query, variable, self.DEFAULT_ENGINE_ID),
render_templated_query(
query, variable, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select * from table where dt="1970-01-01"',
)

Expand All @@ -224,7 +235,9 @@ def test_recursion(self):
"date3": "01",
}
self.assertEqual(
render_templated_query(query, variable, self.DEFAULT_ENGINE_ID),
render_templated_query(
query, variable, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select * from table where dt="1970-01-01"',
)

Expand All @@ -235,7 +248,10 @@ def test_global_vars(self):
query = 'select * from table where dt="{{ date }}"'
self.assertEqual(
render_templated_query(
query, {"date": "{{ today }}"}, self.DEFAULT_ENGINE_ID
query,
{"date": "{{ today }}"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
),
'select * from table where dt="1970-01-01"',
)
Expand All @@ -248,6 +264,7 @@ def test_exception(self):
'select * from {{ table }} where dt="{{ date }}"',
{"table": "foo"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

# Missing variable but recursive
Expand All @@ -257,6 +274,7 @@ def test_exception(self):
'select * from {{ table }} where dt="{{ date }}"',
{"table": "foo", "date": "{{ bar }}"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

# Circular dependency
Expand All @@ -266,6 +284,7 @@ def test_exception(self):
'select * from {{ table }} where dt="{{ date }}"',
{"date": "{{ date2 }}", "date2": "{{ date }}"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

# Invalid template usage
Expand All @@ -275,6 +294,7 @@ def test_exception(self):
'select * from {{ table where dt="{{ date }}"',
{"table": "foo", "date": "{{ bar }}"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

def test_escape_comments(self):
Expand All @@ -291,7 +311,10 @@ def test_escape_comments(self):
{{ end_date}}*/
-- {{ end_date }}"""
self.assertEqual(
render_templated_query(query, {}, self.DEFAULT_ENGINE_ID), query
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
query,
)

def test_escape_comments_non_greedy(self):
Expand All @@ -306,7 +329,10 @@ def test_escape_comments_non_greedy(self):
"""
self.assertEqual(
render_templated_query(
query_non_greedy, {"test": "render"}, self.DEFAULT_ENGINE_ID
query_non_greedy,
{"test": "render"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
),
"""select * from
/*
Expand Down Expand Up @@ -390,6 +416,7 @@ def test_invalid_engine_id(self):
'select * from table where dt="{{ latest_partition("default.table", "dt") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

def test_invalid_table_name(self):
Expand All @@ -399,6 +426,7 @@ def test_invalid_table_name(self):
'select * from table where dt="{{ latest_partition("table", "dt") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

def test_invalid_partition_name(self):
Expand All @@ -411,6 +439,7 @@ def test_invalid_partition_name(self):
'select * from table where dt="{{ latest_partition("default.table", "date") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

def test_no_latest_partition(self):
Expand All @@ -421,6 +450,7 @@ def test_no_latest_partition(self):
'select * from table where dt="{{ latest_partition("default.table", "dt") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)

def test_multiple_partition_columns(self):
Expand All @@ -444,6 +474,7 @@ def test_render_templated_query(self):
'select * from table where dt="{{ latest_partition("default.table", "dt") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)
self.assertEqual(templated_query, 'select * from table where dt="2021-01-01"')

Expand All @@ -452,6 +483,7 @@ def test_recursive_get_latest_partition_variable(self):
'select * from table where dt="{{ latest_part }}"',
{"latest_part": '{{latest_partition("default.table", "dt")}}'},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)
self.assertEqual(templated_query, 'select * from table where dt="2021-01-01"')

Expand All @@ -465,4 +497,93 @@ def test_multiple_partition_columns_partition_not_provided(self):
'select * from table where dt="{{ latest_partition("default.table") }}"',
{},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
)


class CurrentUserTestCase(TemplatingTestCase):
def test_current_user(self):
query = 'select * from table where user="{{ current_user }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
f'select * from table where user="{self.user_mock.username}"',
)

def test_current_user_email(self):
query = 'select * from table where user="{{ current_user_email }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
f'select * from table where user="{self.user_mock.email}"',
)


class SlugifyTestCase(TemplatingTestCase):
def test_simple(self):
query = 'select "{{ "Hello World" | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select "hello_world"',
)

def test_remove_special_characters(self):
query = 'select "{{ "Hello World #2024" | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select "hello_world_2024"',
)

def test_trim_leading_and_trailing_spaces(self):
query = 'select "{{ " Hello World " | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select "hello_world"',
)

def test_chinese(self):
query = 'select "{{ "你好世界" | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select "ni_hao_shi_jie"',
)

def test_empty_string(self):
query = 'select "{{ "" | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select ""',
)

def test_non_alphanumeric_characters(self):
query = 'select "{{ "!@#$%^&*()" | slugify }}"'
self.assertEqual(
render_templated_query(
query, {}, self.DEFAULT_ENGINE_ID, self.DEFAULT_USER_ID
),
'select ""',
)

def test_today(self):
query = "select * from report_{{ today | slugify }}"
self.assertEqual(
render_templated_query(
query,
{"today": "2024-01-01"},
self.DEFAULT_ENGINE_ID,
self.DEFAULT_USER_ID,
),
"select * from report_2024_01_01",
)
Loading
Loading