feat: Allow custom schema def for tmp tables generated by incremental (#659)

Co-authored-by: nicor88 <[email protected]>
pierrebzl and nicor88 authored May 28, 2024
1 parent 8faa921 commit 97430f9
Showing 7 changed files with 174 additions and 39 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -213,6 +213,10 @@ athena:
  - For incremental models using the insert_overwrite strategy on Hive tables
  - Replaces the __dbt_tmp suffix used as the temporary table name suffix with a unique UUID
  - Useful if you are looking to run multiple dbt builds that insert into the same table in parallel
- `temp_schema` (`default=none`)
  - For incremental models, allows you to define a schema to hold the temporary tables created during incremental model runs
  - The schema will be created in the model's target database if it does not exist
- `lf_tags_config` (`default=none`)
  - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
  - `enabled` (`default=False`) whether LF tags management is enabled for a model
2 changes: 2 additions & 0 deletions dbt/adapters/athena/impl.py
@@ -99,6 +99,7 @@ class AthenaConfig(AdapterConfig):
partitions_limit: Maximum number of partitions when batching.
force_batch: Skip creating the table as CTAS and run the operation directly in batch insert mode.
unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
temp_schema: Define in which schema to create temporary tables used in incremental runs.
"""

work_group: Optional[str] = None
@@ -120,6 +121,7 @@ class AthenaConfig(AdapterConfig):
partitions_limit: Optional[int] = None
force_batch: bool = False
unique_tmp_table_suffix: bool = False
temp_schema: Optional[str] = None


class AthenaAdapter(SQLAdapter):
15 changes: 15 additions & 0 deletions dbt/include/athena/macros/adapters/relation.sql
@@ -36,6 +36,21 @@
{%- endcall %}
{%- endmacro %}

{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
{%- set temp_identifier = base_relation.identifier ~ suffix -%}
{%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}

{%- if temp_schema is not none -%}
{%- set temp_relation = temp_relation.incorporate(path={
"identifier": temp_identifier,
"schema": temp_schema
}) -%}
{%- do create_schema(temp_relation) -%}
{% endif %}

{{ return(temp_relation) }}
{% endmacro %}
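
To make the schema override concrete, here is a minimal Python sketch (illustrative only, not adapter code) of the path the macro computes; the model and schema names are hypothetical:

from typing import Optional, Tuple


def temp_relation_path(
    schema: str,
    identifier: str,
    suffix: str = "__dbt_tmp",
    temp_schema: Optional[str] = None,
) -> Tuple[str, str]:
    # Mirror the macro: append the suffix to the identifier and, when a
    # temp_schema is given, relocate the temporary relation into it.
    return (temp_schema or schema, identifier + suffix)


# Hypothetical names, for illustration.
assert temp_relation_path("analytics", "orders") == ("analytics", "orders__dbt_tmp")
assert temp_relation_path("analytics", "orders", temp_schema="analytics_tmp") == (
    "analytics_tmp",
    "orders__dbt_tmp",
)

When temp_schema is set, the macro additionally calls create_schema on the relocated relation, which is why the schema is created on first use if it does not already exist.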

{% macro athena__rename_relation(from_relation, to_relation) %}
{% call statement('rename_relation') -%}
alter table {{ from_relation.render_hive() }} rename to `{{ to_relation.schema }}`.`{{ to_relation.identifier }}`
3 changes: 2 additions & 1 deletion dbt/include/athena/macros/materializations/models/incremental/incremental.sql
@@ -10,6 +10,7 @@
{% set partitioned_by = config.get('partitioned_by') %}
{% set force_batch = config.get('force_batch', False) | as_bool -%}
{% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%}
{% set temp_schema = config.get('temp_schema') %}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
-- If using insert_overwrite on a Hive table, allow setting a unique tmp table suffix
@@ -22,7 +23,7 @@
{% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix,
schema=schema,
database=database) %}
{% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix) %}
{% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix, temp_schema=temp_schema) %}

-- If no partitions are used with insert_overwrite, we fall back to append mode.
{% if partitioned_by is none and strategy == 'insert_overwrite' %}
108 changes: 108 additions & 0 deletions tests/functional/adapter/test_incremental_tmp_schema.py
@@ -0,0 +1,108 @@
import pytest
import yaml
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_running_create_statements,
)

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt

models__schema_tmp_sql = """
{{ config(
materialized='incremental',
incremental_strategy='insert_overwrite',
partitioned_by=['date_column'],
temp_schema=var('temp_schema_name')
)
}}
select
random() as rnd,
cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column
"""


class TestIncrementalTmpSchema:
@pytest.fixture(scope="class")
def models(self):
return {"schema_tmp.sql": models__schema_tmp_sql}

def test__schema_tmp(self, project, capsys):
relation_name = "schema_tmp"
temp_schema_name = f"{project.test_schema}_tmp"
drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"

vars_dict = {
"temp_schema_name": temp_schema_name,
"logical_date": "2024-01-01",
}

first_model_run = run_dbt(
[
"run",
"--select",
relation_name,
"--vars",
yaml.safe_dump(vars_dict),
"--log-level",
"debug",
"--log-format",
"json",
]
)

first_model_run_result = first_model_run.results[0]

assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 1

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)

assert len(athena_running_create_statements) == 1

incremental_model_run_result_table_name = extract_create_statement_table_names(
athena_running_create_statements[0]
)[0]

assert temp_schema_name not in incremental_model_run_result_table_name

vars_dict["logical_date"] = "2024-01-02"
incremental_model_run = run_dbt(
[
"run",
"--select",
relation_name,
"--vars",
yaml.safe_dump(vars_dict),
"--log-level",
"debug",
"--log-format",
"json",
]
)

incremental_model_run_result = incremental_model_run.results[0]

assert incremental_model_run_result.status == RunStatus.Success

records_count_incremental_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_incremental_run == 2

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)

assert len(athena_running_create_statements) == 1

incremental_model_run_result_table_name = extract_create_statement_table_names(
athena_running_create_statements[0]
)[0]

assert temp_schema_name == incremental_model_run_result_table_name.split(".")[1].strip('"')

project.run_sql(drop_temp_schema)
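
The final assertion above relies on the create statement rendering the table as a quoted three-part name; a tiny sketch of that check, with a fabricated rendered name (the exact rendering is an assumption of this example):

# Fabricated rendered name, assuming a "catalog"."schema"."table" form.
rendered = '"awsdatacatalog"."test_schema_tmp"."schema_tmp__dbt_tmp"'
assert rendered.split(".")[1].strip('"') == "test_schema_tmp"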
45 changes: 7 additions & 38 deletions tests/functional/adapter/test_unique_tmp_table_suffix.py
@@ -1,8 +1,10 @@
import json
import re
from typing import List

import pytest
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_running_create_statements,
)

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt
@@ -21,39 +23,6 @@
"""


def extract_running_create_statements(dbt_run_capsys_output: str) -> List[str]:
sql_create_statements = []
# Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
for events_msg in dbt_run_capsys_output.split("\n")[1:]:
base_msg_data = None
# Best effort solution to avoid invalid records and blank lines
try:
base_msg_data = json.loads(events_msg).get("data")
except json.JSONDecodeError:
pass
"""First run will not produce data.sql object in the execution logs, only data.base_msg
containing the "Running Athena query:" initial create statement.
Subsequent incremental runs will only contain the insert from the tmp table into the model
table destination.
Since we want to compare both run create statements, we need to handle both cases"""
if base_msg_data:
base_msg = base_msg_data.get("base_msg")
if "Running Athena query:" in str(base_msg):
if "create table" in base_msg:
sql_create_statements.append(base_msg)

if base_msg_data.get("conn_name") == "model.test.unique_tmp_table_suffix" and "sql" in base_msg_data:
if "create table" in base_msg_data.get("sql"):
sql_create_statements.append(base_msg_data.get("sql"))

return sql_create_statements


def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
return [table_name.rstrip() for table_name in table_names]


class TestUniqueTmpTableSuffix:
@pytest.fixture(scope="class")
def models(self):
@@ -86,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert first_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out)
athena_running_create_statements = extract_running_create_statements(out, relation_name)

assert len(athena_running_create_statements) == 1

@@ -118,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out)
athena_running_create_statements = extract_running_create_statements(out, relation_name)

assert len(athena_running_create_statements) == 1

@@ -150,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out)
athena_running_create_statements = extract_running_create_statements(out, relation_name)

incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
athena_running_create_statements[0]
36 changes: 36 additions & 0 deletions tests/functional/adapter/utils/parse_dbt_run_output.py
@@ -0,0 +1,36 @@
import json
import re
from typing import List


def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
sql_create_statements = []
    # Skip the first captured line: "Invoking dbt with ['run', '--select', ...]"
for events_msg in dbt_run_capsys_output.split("\n")[1:]:
base_msg_data = None
# Best effort solution to avoid invalid records and blank lines
try:
base_msg_data = json.loads(events_msg).get("data")
except json.JSONDecodeError:
pass
"""First run will not produce data.sql object in the execution logs, only data.base_msg
containing the "Running Athena query:" initial create statement.
Subsequent incremental runs will only contain the insert from the tmp table into the model
table destination.
Since we want to compare both run create statements, we need to handle both cases"""
if base_msg_data:
base_msg = base_msg_data.get("base_msg")
if "Running Athena query:" in str(base_msg):
if "create table" in base_msg:
sql_create_statements.append(base_msg)

if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
if "create table" in base_msg_data.get("sql"):
sql_create_statements.append(base_msg_data.get("sql"))

return sql_create_statements


def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
return [table_name.rstrip() for table_name in table_names]
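
As a rough illustration of how these two helpers fit together, the snippet below runs a fabricated dbt JSON log line through both of them; the log record, names, and SQL are invented for this example rather than taken from real dbt output:

import json

from tests.functional.adapter.utils.parse_dbt_run_output import (
    extract_create_statement_table_names,
    extract_running_create_statements,
)

# A fabricated single-line debug log record for a model named schema_tmp.
fake_log_line = json.dumps(
    {
        "data": {
            "conn_name": "model.test.schema_tmp",
            "sql": 'create table "awsdatacatalog"."my_schema_tmp"."schema_tmp__dbt_tmp" with (...) as select 1',
        }
    }
)
# The parser skips the first captured line, so prepend a dummy invocation line.
out = "Invoking dbt with ['run', '--select', 'schema_tmp', ...]\n" + fake_log_line

statements = extract_running_create_statements(out, "schema_tmp")
assert len(statements) == 1

# The regex captures everything between "create table " and "with".
assert extract_create_statement_table_names(statements[0]) == [
    '"awsdatacatalog"."my_schema_tmp"."schema_tmp__dbt_tmp"'
]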
