Skip to content

Commit

Permalink
Add more types for function arguments and return values
Browse files Browse the repository at this point in the history
  • Loading branch information
syou6162 committed Feb 10, 2024
1 parent e8122d9 commit 3ebc5a2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 17 deletions.
24 changes: 14 additions & 10 deletions dbterd/adapters/algos/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from typing import Dict, List

import click

Expand All @@ -9,9 +10,10 @@
TEST_META_RELATIONSHIP_TYPE,
)
from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def get_tables_from_metadata(data=[], **kwargs):
def get_tables_from_metadata(data=[], **kwargs) -> List[Table]:
"""Extract tables from dbt metadata
Args:
Expand Down Expand Up @@ -47,7 +49,7 @@ def get_tables_from_metadata(data=[], **kwargs):
return tables


def get_tables(manifest, catalog, **kwargs):
def get_tables(manifest: Manifest, catalog: Catalog, **kwargs) -> List[Table]:
"""Extract tables from dbt artifacts
Args:
Expand Down Expand Up @@ -94,7 +96,9 @@ def get_tables(manifest, catalog, **kwargs):
return tables


def enrich_tables_from_relationships(tables, relationships):
def enrich_tables_from_relationships(
tables: List[Table], relationships: List[Ref]
) -> List[Table]:
"""Fullfill columns in Table due to `select *`
Args:
Expand Down Expand Up @@ -180,7 +184,7 @@ def get_table_from_metadata(model_metadata, exposures=[], **kwargs) -> Table:


def get_table(
node_name, manifest_node, catalog_node=None, exposures=[], **kwargs
node_name: str, manifest_node, catalog_node=None, exposures=[], **kwargs
) -> Table:
"""Construct a single Table object
Expand Down Expand Up @@ -313,7 +317,7 @@ def get_node_exposures_from_metadata(data=[], **kwargs):
return exposures


def get_node_exposures(manifest):
def get_node_exposures(manifest: Manifest) -> List[Dict[str, str]]:
"""Get the mapping of table name and exposure name
Args:
Expand Down Expand Up @@ -349,7 +353,7 @@ def get_table_name(format: str, **kwargs) -> str:
return ".".join([kwargs.get(x.lower()) or "KEYNOTFOUND" for x in format.split(".")])


def get_relationships_from_metadata(data=[], **kwargs) -> list[Ref]:
def get_relationships_from_metadata(data=[], **kwargs) -> List[Ref]:
"""Extract relationships from Metadata result list on test relationship
Args:
Expand Down Expand Up @@ -409,7 +413,7 @@ def get_relationships_from_metadata(data=[], **kwargs) -> list[Ref]:
return get_unique_refs(refs=refs)


def get_relationships(manifest, **kwargs):
def get_relationships(manifest: Manifest, **kwargs) -> List[Ref]:
"""Extract relationships from dbt artifacts based on test relationship
Args:
Expand Down Expand Up @@ -482,7 +486,7 @@ def get_unique_refs(refs: list[Ref] = []) -> list[Ref]:
return distinct_list


def get_algo_rule(**kwargs):
def get_algo_rule(**kwargs) -> Dict[str, str]:
"""Extract rule from the --algo option
Args:
Expand Down Expand Up @@ -517,7 +521,7 @@ def get_algo_rule(**kwargs):
return rules


def get_table_map_from_metadata(test_node, **kwargs):
def get_table_map_from_metadata(test_node, **kwargs) -> List[str]:
"""Get the table map with order of [to, from] guaranteed
(for Metadata)
Expand Down Expand Up @@ -570,7 +574,7 @@ def get_table_map_from_metadata(test_node, **kwargs):
return list(reversed(test_parents))


def get_table_map(test_node, **kwargs):
def get_table_map(test_node, **kwargs) -> List[str]:
"""Get the table map with order of [to, from] guaranteed
Args:
Expand Down
15 changes: 10 additions & 5 deletions dbterd/adapters/algos/test_relationship.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import List, Tuple, Union

from dbterd.adapters.algos import base
from dbterd.adapters.filter import is_selected_table
from dbterd.adapters.meta import Ref
from dbterd.adapters.meta import Ref, Table
from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def parse_metadata(data, **kwargs):
def parse_metadata(manifest: Manifest, **kwargs) -> Tuple[List[Table], List[Ref]]:
"""Get all information (tables, relationships) needed for building diagram
(from Metadata)
Expand All @@ -18,7 +21,7 @@ def parse_metadata(data, **kwargs):
relationships = []

# Parse Table
tables = base.get_tables_from_metadata(data=data, **kwargs)
tables = base.get_tables_from_metadata(data=manifest, **kwargs)

# Apply selection
tables = [
Expand All @@ -33,7 +36,7 @@ def parse_metadata(data, **kwargs):
]

# Parse Ref
relationships = base.get_relationships_from_metadata(data=data, **kwargs)
relationships = base.get_relationships_from_metadata(data=manifest, **kwargs)
node_names = [x.node_name for x in tables]
relationships = [
x
Expand All @@ -47,7 +50,9 @@ def parse_metadata(data, **kwargs):
return (tables, relationships)


def parse(manifest, catalog, **kwargs):
def parse(
manifest: Manifest, catalog: Union[str, Catalog], **kwargs
) -> Tuple[List[Table], List[Ref]]:
"""Get all information (tables, relationships) needed for building diagram
Args:
Expand Down
5 changes: 3 additions & 2 deletions dbterd/helpers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbt_artifacts_parser import parser

from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def get_sys_platform(): # pragma: no cover
Expand Down Expand Up @@ -109,7 +110,7 @@ def win_prepare_path(path: str) -> str: # pragma: no cover
return path


def read_manifest(path: str, version: int = None):
def read_manifest(path: str, version: int = None) -> Manifest:
"""Reads in the manifest.json file, with optional version specification
Args:
Expand All @@ -134,7 +135,7 @@ def read_manifest(path: str, version: int = None):
return parse_func(manifest=_dict)


def read_catalog(path: str, version: int = None):
def read_catalog(path: str, version: int = None) -> Catalog:
"""Reads in the catalog.json file, with optional version specification
Args:
Expand Down
31 changes: 31 additions & 0 deletions dbterd/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Union

from dbt_artifacts_parser.parsers.catalog.catalog_v1 import CatalogV1
from dbt_artifacts_parser.parsers.manifest.manifest_v1 import ManifestV1
from dbt_artifacts_parser.parsers.manifest.manifest_v2 import ManifestV2
from dbt_artifacts_parser.parsers.manifest.manifest_v3 import ManifestV3
from dbt_artifacts_parser.parsers.manifest.manifest_v4 import ManifestV4
from dbt_artifacts_parser.parsers.manifest.manifest_v5 import ManifestV5
from dbt_artifacts_parser.parsers.manifest.manifest_v6 import ManifestV6
from dbt_artifacts_parser.parsers.manifest.manifest_v7 import ManifestV7
from dbt_artifacts_parser.parsers.manifest.manifest_v8 import ManifestV8
from dbt_artifacts_parser.parsers.manifest.manifest_v9 import ManifestV9
from dbt_artifacts_parser.parsers.manifest.manifest_v10 import ManifestV10
from dbt_artifacts_parser.parsers.manifest.manifest_v11 import ManifestV11

Manifest = Union[
ManifestV1,
ManifestV2,
ManifestV3,
ManifestV4,
ManifestV5,
ManifestV6,
ManifestV7,
ManifestV8,
ManifestV9,
ManifestV10,
ManifestV11,
]

# If a new version of Catalog is added, replace with `Union[CatalogV1, CatalogV2, ...]`.
Catalog = CatalogV1

0 comments on commit 3ebc5a2

Please sign in to comment.