import re
from typing import Optional

# Root identifier followed by an angle-bracketed parameter list, e.g.
# ``Struct<a string>``, ``ARRAY<STRING>`` or ``STRUCT<a STRUCT<b INT64>>``.
# The content between the brackets is deliberately not parsed: it may be
# nested to any depth and only the root type name is needed.
_COMPLEX_TYPE_PATTERN = re.compile(r"\s*(\w+)\s*<.*>", re.DOTALL)

# Single-pass substitutions applied to column names.
_COLUMN_NAME_TRANSLATION = str.maketrans({" ": "-", ".": "__"})


def replace_column_name(column_name: str) -> str:
    """Replace special characters in a column name.

    Spaces become ``-`` and dots become ``__`` so that names such as
    ``address.city`` do not prevent mermaid from rendering the diagram.

    Args:
        column_name (str): column name

    Returns:
        str: Column name with special characters substituted
    """
    return column_name.translate(_COLUMN_NAME_TRANSLATION)


def match_complex_column_type(column_type: str) -> Optional[str]:
    """Return the root type of a (possibly nested) complex type.

    As an example, if the input is ``Struct<field1 string, field2 string>``,
    return ``Struct``. Nested and single-argument forms such as
    ``STRUCT<a STRUCT<b INT64>>`` or ``ARRAY<STRING>`` are recognised as well.

    Args:
        column_type (str): column type

    Returns:
        Optional[str]: Root type if the input is a complex type carrying a
            ``<...>`` parameter list, otherwise ``None`` (primitive types)
    """
    match = _COMPLEX_TYPE_PATTERN.match(column_type)
    return match.group(1) if match else None


def replace_column_type(column_type: str) -> str:
    """Make a column type safe to draw in a mermaid diagram.

    Complex types (``Struct<...>``, ``ARRAY<...>``, ...) are collapsed to
    ``<root>[OMITTED]``; for every other type, spaces are replaced with ``-``.

    Args:
        column_type (str): column type

    Returns:
        str: Type of column with special characters substituted or omitted
    """
    # Some specific DWHs may have types that cannot be drawn in mermaid, such
    # as `Struct<...>`. These types may be nested and can be very long, so
    # only the root type is kept.
    complex_column_type = match_complex_column_type(column_type)
    if complex_column_type:
        return f"{complex_column_type}[OMITTED]"
    return column_type.replace(" ", "-")
data_type="name2-type2"), + Column( + name="complex_struct", + data_type="Struct", + ), + ], + raw_sql="--irrelevant--", + ), + ], + [ + Ref( + name="test.dbt_resto.relationships_table1", + table_map=["model.dbt_resto.table2", "model.dbt_resto.table1"], + column_map=["name2.first_name", "name1.first_name"], + ), + ], + [], + [], + ["model", "source"], + False, + """erDiagram + "MODEL.DBT_RESTO.TABLE1" { + name1-type name1__first_name + } + "MODEL.DBT_RESTO.TABLE2" { + name2-type2 name2__first_name + Struct[OMITTED] complex_struct + } + "MODEL.DBT_RESTO.TABLE1" }|--|| "MODEL.DBT_RESTO.TABLE2": name2__first_name--name1__first_name + """, + ), ], ) def test_parse(