diff --git a/README.md b/README.md index d6b8220..c90d656 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ Commands: ```bash # select all models in dbt_resto dbterd run -ad "samples/dbtresto" -o "target" +# select all models in dbt_resto, Select multiple dbt resources +dbterd run -ad "samples/dbtresto" -o "target" -rt "model" -rt "source" # select only models in dbt_resto excluding staging dbterd run -ad "samples/dbtresto" -o "target" -s model.dbt_resto -ns model.dbt_resto.staging # select only models in schema name "mart" excluding staging diff --git a/dbterd/adapters/targets/dbml/engine/engine.py b/dbterd/adapters/targets/dbml/engine/engine.py index 4c718a3..a7f6e7f 100644 --- a/dbterd/adapters/targets/dbml/engine/engine.py +++ b/dbterd/adapters/targets/dbml/engine/engine.py @@ -4,28 +4,27 @@ def parse(manifest, catalog, **kwargs): # Parse Table tables = get_tables(manifest, catalog) - # -- apply selection + # Apply selection select_rule = (kwargs.get("select") or "").lower().split(":") - if select_rule[0].startswith("schema"): - select_rule = select_rule[-1] - tables = [ - x - for x in tables - if x.schema.startswith(select_rule) # --select schema:analytics - or f"{x.database}.{x.schema}".startswith( - select_rule - ) # --select schema:db.analytics - ] - else: - select_rule = select_rule[-1] # only take care of name - tables = [x for x in tables if x.name.startswith(select_rule)] + resource_type_rule = kwargs.get("resource_type") or "" + + def filter_table_select(table): + if select_rule[0].startswith("schema"): + schema = f"{table.database}.{table.schema}" + return schema.startswith(select_rule[-1]) or table.schema.startswith( + select_rule[-1] + ) + else: + return table.name.startswith(select_rule[-1]) + + tables = [table for table in tables if filter_table_select(table)] + tables = [table for table in tables if table.resource_type in resource_type_rule] # -- apply exclusion (take care of name only) - tables = [ - x - for x in tables - if kwargs.get("exclude") is None or not x.name.startswith(kwargs.get("exclude")) - ] + + exclude_rule = kwargs.get("exclude") + if exclude_rule: + tables = [table for table in tables if not table.name.startswith(exclude_rule)] # Parse Ref relationships = get_relationships(manifest) @@ -71,24 +70,19 @@ def parse(manifest, catalog, **kwargs): def get_tables(manifest, catalog): """Extract tables from dbt artifacts""" - tables = [ - Table( - name=x, - raw_sql=get_compiled_sql(manifest.nodes[x]), - database=manifest.nodes[x].database.lower(), - schema=manifest.nodes[x].schema_.lower(), + + def create_table_and_columns(table_name, resource, catalog_resource=None): + table = Table( + name=table_name, + raw_sql=get_compiled_sql(resource), + database=resource.database.lower(), + schema=resource.schema_.lower(), columns=[], + resource_type=table_name.split(".")[0], ) - for x in manifest.nodes - if x.startswith("model") - ] - for table in tables: - # Pull columns from the catalog and use the data types declared there - # Catalog is our primary source of information about the target db - if table.name in catalog.nodes: # table might not live yet - cat_columns = catalog.nodes[table.name].columns - for column, metadata in cat_columns.items(): + if catalog_resource: + for column, metadata in catalog_resource.columns.items(): table.columns.append( Column( name=str(column).lower(), @@ -96,26 +90,41 @@ def get_tables(manifest, catalog): ) ) - # Handle cases where columns don't exist yet, but are in manifest - man_columns = manifest.nodes[table.name].columns - for column in man_columns: - column_name = str(column).strip( - '"' - ) # remove double quotes from column name if any - if column_name.lower() in [x.name for x in table.columns]: - # Already exists in the remote - continue - table.columns.append( - Column( - name=column_name.lower(), - data_type=str(man_columns[column].data_type or "unknown").lower(), + for column_name, column_metadata in resource.columns.items(): + column_name = column_name.strip('"') + if not any(c.name.lower() == column_name.lower() for c in table.columns): + table.columns.append( + Column( + name=column_name.lower(), + data_type=str(column_metadata.data_type or "unknown").lower(), + ) ) - ) - # Fallback: add dummy column if cannot find any info if not table.columns: table.columns.append(Column()) + return table + + tables = [] + + if hasattr(manifest, "nodes"): + for table_name, node in manifest.nodes.items(): + if ( + table_name.startswith("model.") + or table_name.startswith("seed.") + or table_name.startswith("snapshot.") + ): + catalog_node = catalog.nodes.get(table_name) + table = create_table_and_columns(table_name, node, catalog_node) + tables.append(table) + + if hasattr(manifest, "sources"): + for table_name, source in manifest.sources.items(): + if table_name.startswith("source"): + catalog_source = catalog.sources.get(table_name) + table = create_table_and_columns(table_name, source, catalog_source) + tables.append(table) + return tables diff --git a/dbterd/adapters/targets/dbml/engine/meta.py b/dbterd/adapters/targets/dbml/engine/meta.py index a13a999..4960c26 100644 --- a/dbterd/adapters/targets/dbml/engine/meta.py +++ b/dbterd/adapters/targets/dbml/engine/meta.py @@ -26,6 +26,7 @@ class Table: schema: str columns: Optional[List[Column]] = None raw_sql: Optional[str] = None + resource_type: str = "model" @dataclass diff --git a/dbterd/cli/params.py b/dbterd/cli/params.py index 6627c3a..3078f8b 100644 --- a/dbterd/cli/params.py +++ b/dbterd/cli/params.py @@ -66,6 +66,14 @@ def common_params(func): default=None, type=click.STRING, ) + @click.option( + "--resource-type", + "-rt", + help="Specified dbt resource type(seed, model, source, snapshot),default:model, use examples, -rt model -rt source", + default=["model"], + multiple=True, + type=click.STRING, + ) @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) # pragma: no cover diff --git a/tests/unit/adapters/targets/dbml/test_engine.py b/tests/unit/adapters/targets/dbml/test_engine.py index b97c04f..50733d2 100644 --- a/tests/unit/adapters/targets/dbml/test_engine.py +++ b/tests/unit/adapters/targets/dbml/test_engine.py @@ -150,7 +150,7 @@ class DummyCatalogTable: class TestDbmlEngine: @pytest.mark.parametrize( - "tables, relationships, select, expected", + "tables, relationships, select, resource_type, expected", [ ( [ @@ -164,6 +164,7 @@ class TestDbmlEngine: ], [], "", + ["model"], """//Tables (based on the selection criteria) //--configured at schema: --database--.--schema-- Table "model.dbt_resto.table1" { @@ -188,6 +189,13 @@ class TestDbmlEngine: columns=[Column(name="name2", data_type="--name2-type2--")], raw_sql="--irrelevant--", ), + Table( + name="source.dbt_resto.table3", + database="--database3--", + schema="--schema3--", + columns=[Column(name="name3", data_type="--name3-type3--")], + raw_sql="--irrelevant--", + ), ], [ Ref( @@ -202,6 +210,7 @@ class TestDbmlEngine: ), ], "", + ["model", "source"], """//Tables (based on the selection criteria) //--configured at schema: --database--.--schema-- Table "model.dbt_resto.table1" { @@ -213,6 +222,10 @@ class TestDbmlEngine: "name2" "--name2-type2--" "name-notexist2" "unknown" } + //--configured at schema: --database3--.--schema3-- + Table "source.dbt_resto.table3" { + "name3" "--name3-type3--" + } //Refs (based on the DBT Relationship Tests) Ref: "model.dbt_resto.table1"."name1" > "model.dbt_resto.table2"."name2" Ref: "model.dbt_resto.table1"."name-notexist1" > "model.dbt_resto.table2"."name-notexist2" @@ -243,6 +256,7 @@ class TestDbmlEngine: ) ], "schema:--schema--", + ["model", "source"], """//Tables (based on the selection criteria) //--configured at schema: --database--.--schema-- Table "model.dbt_resto.table1" { @@ -253,7 +267,7 @@ class TestDbmlEngine: ), ], ) - def test_parse(self, tables, relationships, select, expected): + def test_parse(self, tables, relationships, select, resource_type, expected): with mock.patch( "dbterd.adapters.targets.dbml.engine.engine.get_tables", return_value=tables ) as mock_get_tables: @@ -262,7 +276,10 @@ def test_parse(self, tables, relationships, select, expected): return_value=relationships, ) as mock_get_relationships: dbml = engine.parse( - manifest="--manifest--", catalog="--catalog--", select=select + manifest="--manifest--", + catalog="--catalog--", + select=select, + resource_type=resource_type, ) print("dbml ", dbml.replace(" ", "").replace("\n", "")) print("expected", expected.replace(" ", "").replace("\n", ""))