Skip to content

Commit

Permalink
Feature/add optional dbt resource type (#14)
Browse files Browse the repository at this point in the history
* Support for different dbt resources

* test: add parse test

* test: Clean up excess input

* style: Modify drt name and formatting code

---------

Co-authored-by: luoweiying <[email protected]>
  • Loading branch information
yingyingqiqi and luoweiying authored Apr 2, 2023
1 parent edacd69 commit 5098b7b
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 52 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Commands:
```bash
# select all models in dbt_resto
dbterd run -ad "samples/dbtresto" -o "target"
# select all models in dbt_resto, Select multiple dbt resources
dbterd run -ad "samples/dbtresto" -o "target" -rt "model" -rt "source"
# select only models in dbt_resto excluding staging
dbterd run -ad "samples/dbtresto" -o "target" -s model.dbt_resto -ns model.dbt_resto.staging
# select only models in schema name "mart" excluding staging
Expand Down
107 changes: 58 additions & 49 deletions dbterd/adapters/targets/dbml/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,27 @@
def parse(manifest, catalog, **kwargs):
# Parse Table
tables = get_tables(manifest, catalog)
# -- apply selection
# Apply selection
select_rule = (kwargs.get("select") or "").lower().split(":")
if select_rule[0].startswith("schema"):
select_rule = select_rule[-1]
tables = [
x
for x in tables
if x.schema.startswith(select_rule) # --select schema:analytics
or f"{x.database}.{x.schema}".startswith(
select_rule
) # --select schema:db.analytics
]
else:
select_rule = select_rule[-1] # only take care of name
tables = [x for x in tables if x.name.startswith(select_rule)]
resource_type_rule = kwargs.get("resource_type") or ""

def filter_table_select(table):
if select_rule[0].startswith("schema"):
schema = f"{table.database}.{table.schema}"
return schema.startswith(select_rule[-1]) or table.schema.startswith(
select_rule[-1]
)
else:
return table.name.startswith(select_rule[-1])

tables = [table for table in tables if filter_table_select(table)]
tables = [table for table in tables if table.resource_type in resource_type_rule]

# -- apply exclusion (take care of name only)
tables = [
x
for x in tables
if kwargs.get("exclude") is None or not x.name.startswith(kwargs.get("exclude"))
]

exclude_rule = kwargs.get("exclude")
if exclude_rule:
tables = [table for table in tables if not table.name.startswith(exclude_rule)]

# Parse Ref
relationships = get_relationships(manifest)
Expand Down Expand Up @@ -71,51 +70,61 @@ def parse(manifest, catalog, **kwargs):

def get_tables(manifest, catalog):
"""Extract tables from dbt artifacts"""
tables = [
Table(
name=x,
raw_sql=get_compiled_sql(manifest.nodes[x]),
database=manifest.nodes[x].database.lower(),
schema=manifest.nodes[x].schema_.lower(),

def create_table_and_columns(table_name, resource, catalog_resource=None):
table = Table(
name=table_name,
raw_sql=get_compiled_sql(resource),
database=resource.database.lower(),
schema=resource.schema_.lower(),
columns=[],
resource_type=table_name.split(".")[0],
)
for x in manifest.nodes
if x.startswith("model")
]

for table in tables:
# Pull columns from the catalog and use the data types declared there
# Catalog is our primary source of information about the target db
if table.name in catalog.nodes: # table might not live yet
cat_columns = catalog.nodes[table.name].columns
for column, metadata in cat_columns.items():
if catalog_resource:
for column, metadata in catalog_resource.columns.items():
table.columns.append(
Column(
name=str(column).lower(),
data_type=str(metadata.type).lower(),
)
)

# Handle cases where columns don't exist yet, but are in manifest
man_columns = manifest.nodes[table.name].columns
for column in man_columns:
column_name = str(column).strip(
'"'
) # remove double quotes from column name if any
if column_name.lower() in [x.name for x in table.columns]:
# Already exists in the remote
continue
table.columns.append(
Column(
name=column_name.lower(),
data_type=str(man_columns[column].data_type or "unknown").lower(),
for column_name, column_metadata in resource.columns.items():
column_name = column_name.strip('"')
if not any(c.name.lower() == column_name.lower() for c in table.columns):
table.columns.append(
Column(
name=column_name.lower(),
data_type=str(column_metadata.data_type or "unknown").lower(),
)
)
)

# Fallback: add dummy column if cannot find any info
if not table.columns:
table.columns.append(Column())

return table

tables = []

if hasattr(manifest, "nodes"):
for table_name, node in manifest.nodes.items():
if (
table_name.startswith("model.")
or table_name.startswith("seed.")
or table_name.startswith("snapshot.")
):
catalog_node = catalog.nodes.get(table_name)
table = create_table_and_columns(table_name, node, catalog_node)
tables.append(table)

if hasattr(manifest, "sources"):
for table_name, source in manifest.sources.items():
if table_name.startswith("source"):
catalog_source = catalog.sources.get(table_name)
table = create_table_and_columns(table_name, source, catalog_source)
tables.append(table)

return tables


Expand Down
1 change: 1 addition & 0 deletions dbterd/adapters/targets/dbml/engine/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Table:
schema: str
columns: Optional[List[Column]] = None
raw_sql: Optional[str] = None
resource_type: str = "model"


@dataclass
Expand Down
8 changes: 8 additions & 0 deletions dbterd/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def common_params(func):
default=None,
type=click.STRING,
)
@click.option(
"--resource-type",
"-rt",
help="Specified dbt resource type(seed, model, source, snapshot),default:model, use examples, -rt model -rt source",
default=["model"],
multiple=True,
type=click.STRING,
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs) # pragma: no cover
Expand Down
23 changes: 20 additions & 3 deletions tests/unit/adapters/targets/dbml/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class DummyCatalogTable:

class TestDbmlEngine:
@pytest.mark.parametrize(
"tables, relationships, select, expected",
"tables, relationships, select, resource_type, expected",
[
(
[
Expand All @@ -164,6 +164,7 @@ class TestDbmlEngine:
],
[],
"",
["model"],
"""//Tables (based on the selection criteria)
//--configured at schema: --database--.--schema--
Table "model.dbt_resto.table1" {
Expand All @@ -188,6 +189,13 @@ class TestDbmlEngine:
columns=[Column(name="name2", data_type="--name2-type2--")],
raw_sql="--irrelevant--",
),
Table(
name="source.dbt_resto.table3",
database="--database3--",
schema="--schema3--",
columns=[Column(name="name3", data_type="--name3-type3--")],
raw_sql="--irrelevant--",
),
],
[
Ref(
Expand All @@ -202,6 +210,7 @@ class TestDbmlEngine:
),
],
"",
["model", "source"],
"""//Tables (based on the selection criteria)
//--configured at schema: --database--.--schema--
Table "model.dbt_resto.table1" {
Expand All @@ -213,6 +222,10 @@ class TestDbmlEngine:
"name2" "--name2-type2--"
"name-notexist2" "unknown"
}
//--configured at schema: --database3--.--schema3--
Table "source.dbt_resto.table3" {
"name3" "--name3-type3--"
}
//Refs (based on the DBT Relationship Tests)
Ref: "model.dbt_resto.table1"."name1" > "model.dbt_resto.table2"."name2"
Ref: "model.dbt_resto.table1"."name-notexist1" > "model.dbt_resto.table2"."name-notexist2"
Expand Down Expand Up @@ -243,6 +256,7 @@ class TestDbmlEngine:
)
],
"schema:--schema--",
["model", "source"],
"""//Tables (based on the selection criteria)
//--configured at schema: --database--.--schema--
Table "model.dbt_resto.table1" {
Expand All @@ -253,7 +267,7 @@ class TestDbmlEngine:
),
],
)
def test_parse(self, tables, relationships, select, expected):
def test_parse(self, tables, relationships, select, resource_type, expected):
with mock.patch(
"dbterd.adapters.targets.dbml.engine.engine.get_tables", return_value=tables
) as mock_get_tables:
Expand All @@ -262,7 +276,10 @@ def test_parse(self, tables, relationships, select, expected):
return_value=relationships,
) as mock_get_relationships:
dbml = engine.parse(
manifest="--manifest--", catalog="--catalog--", select=select
manifest="--manifest--",
catalog="--catalog--",
select=select,
resource_type=resource_type,
)
print("dbml ", dbml.replace(" ", "").replace("\n", ""))
print("expected", expected.replace(" ", "").replace("\n", ""))
Expand Down

0 comments on commit 5098b7b

Please sign in to comment.