From 10bb875cee4d8ab14a77b5e1fcfeb65d5625712f Mon Sep 17 00:00:00 2001 From: Mike Gouline <1960272+gouline@users.noreply.github.com> Date: Wed, 16 Oct 2024 22:24:18 +1100 Subject: [PATCH] Exposure dependency resolution by fully-qualified names --- dbtmetabase/_exposures.py | 36 ++++++++++++++++++++++++++++++------ dbtmetabase/manifest.py | 4 ++++ dbtmetabase/metabase.py | 6 +++++- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/dbtmetabase/_exposures.py b/dbtmetabase/_exposures.py index 3c5d9aa..08251d4 100644 --- a/dbtmetabase/_exposures.py +++ b/dbtmetabase/_exposures.py @@ -20,7 +20,7 @@ from .errors import ArgumentError from .format import Filter, dump_yaml, safe_description, safe_name -from .manifest import Manifest +from .manifest import DEFAULT_SCHEMA, Manifest _RESOURCE_VERSION = 2 @@ -79,8 +79,18 @@ def extract_exposures( models = self.manifest.read_models() ctx = self.__Context( - model_refs={m.alias.lower(): m.ref for m in models if m.ref}, - table_names={t["id"]: t["name"] for t in self.metabase.get_tables()}, + model_refs={m.alias_path.lower(): m.ref for m in models if m.ref}, + database_names={d["id"]: d["name"] for d in self.metabase.get_databases()}, + table_names={ + t["id"]: ".".join( + [ + t.get("db", {}).get("name", ""), + t.get("schema", DEFAULT_SCHEMA), + t["name"], + ] + ).lower() + for t in self.metabase.get_tables() + }, ) exposures = [] @@ -288,13 +298,26 @@ def __extract_card_exposures( # Parse SQL for exposures through FROM or JOIN clauses for sql_ref in re.findall(_EXPOSURE_PARSER, native_query): - # Grab just the table / model name - parsed_model = sql_ref.split(".")[-1].strip('"').lower() + # DATABASE.schema.table -> [database, schema, table] + parsed_model_path = [ + s.strip('"').lower() for s in sql_ref.split(".") + ] # Scrub CTEs (qualified sql_refs can not reference CTEs) - if parsed_model in ctes and "." not in sql_ref: + if parsed_model_path[-1] in ctes and "." not in sql_ref: continue + # Missing schema -> use default schema + if len(parsed_model_path) < 2: + parsed_model_path.insert(0, DEFAULT_SCHEMA.lower()) + # Missing database -> use query's database + if len(parsed_model_path) < 3: + database_name = ctx.database_names.get(query["database"], "") + parsed_model_path.insert(0, database_name.lower()) + + # Fully-qualified database.schema.table + parsed_model = ".".join(parsed_model_path) + # Verify this is one of our parsed refable models so exposures dont break the DAG if not ctx.model_refs.get(parsed_model): continue @@ -429,4 +452,5 @@ def __write_exposures( @dc.dataclass class __Context: model_refs: Mapping[str, str] + database_names: Mapping[str, str] table_names: Mapping[str, str] diff --git a/dbtmetabase/manifest.py b/dbtmetabase/manifest.py index 014ef03..be08663 100644 --- a/dbtmetabase/manifest.py +++ b/dbtmetabase/manifest.py @@ -387,6 +387,10 @@ def ref(self) -> Optional[str]: return f"source('{self.source}', '{self.name}')" return None + @property + def alias_path(self) -> str: + return ".".join([self.database, self.schema or DEFAULT_SCHEMA, self.alias]) + def format_description( self, append_tags: bool = False, diff --git a/dbtmetabase/metabase.py b/dbtmetabase/metabase.py index c50917b..61c9620 100644 --- a/dbtmetabase/metabase.py +++ b/dbtmetabase/metabase.py @@ -95,9 +95,13 @@ def _api( return response_json + def get_databases(self) -> Sequence[Mapping]: + """Retrieves all databases.""" + return list(self._api("get", "/api/database")) + def find_database(self, name: str) -> Optional[Mapping]: """Finds database by name attribute or returns none.""" - for api_database in list(self._api("get", "/api/database")): + for api_database in self.get_databases(): if api_database["name"].upper() == name.upper(): return api_database return None