diff --git a/academic_observatory_workflows/config.py b/academic_observatory_workflows/config.py index 9b4e18fed..4e99f1e43 100644 --- a/academic_observatory_workflows/config.py +++ b/academic_observatory_workflows/config.py @@ -24,6 +24,7 @@ class Tag: """DAG tag.""" academic_observatory = "academic-observatory" + data_quality = "data-quality" def test_fixtures_folder(*subdirs: str, workflow_module: Optional[str] = None) -> str: diff --git a/academic_observatory_workflows/database/schema/data_quality/data_quality.json b/academic_observatory_workflows/database/schema/data_quality/data_quality.json new file mode 100644 index 000000000..9dfdc3f47 --- /dev/null +++ b/academic_observatory_workflows/database/schema/data_quality/data_quality.json @@ -0,0 +1,116 @@ +[ + { + "name": "full_table_id", + "description": "The fully qualified table name: project_id.dataset_id.table_id", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "hash_id", + "description": "A unique table identifier based on the full_table_id, number of bytes, number of rows and number of columns.", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "project_id", + "description": "Name of the project.", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "dataset_id", + "description": "Dataset that holds the table.", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "table_id", + "description": "Name of the table.", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "is_sharded", + "description": "Whether the table is sharded.", + "type": "BOOL", + "mode": "REQUIRED" + }, + { + "name": "shard_date", + "description": "Date of the table shard (null if the table is not sharded).", + "type": "DATE", + "mode": "NULLABLE" + }, + { + "name": "date_created", + "description": "Date when the table was created.", + "type": "TIMESTAMP", + "mode": "REQUIRED" + }, + { + "name": "expires", + "description": "Whether the table is set to expire.", + "type": "BOOLEAN", + "mode": "REQUIRED" + }, + { + "name": "date_expires", + "description": "Date when the table expires.", + "type": "TIMESTAMP", + "mode": "NULLABLE" + }, + { + "name": "date_last_modified", + "description": "Date when the table was last modified.", + "type": "TIMESTAMP", + "mode": "REQUIRED" + }, + { + "name": "date_checked", + "description": "Date when this table was checked by the data quality workflow.", + "type": "TIMESTAMP", + "mode": "REQUIRED" + }, + { + "name": "size_gb", + "description": "Size of the table in gigabytes.", + "type": "FLOAT", + "mode": "REQUIRED" + }, + { + "name": "primary_key", + "description": "Array of primary keys that are used to identify records.", + "type": "STRING", + "mode": "REPEATED" + }, + { + "name": "num_rows", + "description": "Number of rows / records in the table.", + "type": "INTEGER", + "mode": "REQUIRED" + }, + { + "name": "num_distinct_records", + "description": "Number of records in the table that have a distinct 'primary_key'.", + "type": "INTEGER", + "mode": "REQUIRED" + }, + { + "name": "num_null_records", + "description": "Number of records that do not have an entry under 'primary_key' (None or nulls).", + "type": "INTEGER", + "mode": "REQUIRED" + }, + { + "name": "num_duplicates", + "description": "Number of duplicate records based on the 'primary_key'.", + "type": "INTEGER", + "mode": "REQUIRED" + }, + { + "name": "num_all_fields", + "description": "Number of fields that the table has in total, including nested fields.", + "type": "INTEGER", + "mode": "REQUIRED" + } +]
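For reference, a single record conforming to this schema (as assembled by the create_data_quality_record function in the workflow below) might look like the following sketch; every identifier and value here is illustrative only, not taken from a real table.

```python
# Illustrative data_quality record; all names and numbers are invented for this sketch.
example_record = {
    "full_table_id": "my-project.observatory.doi20230101",  # hypothetical sharded table
    "hash_id": "0f3d0c4f0c1a2b3c4d5e6f708192a3b4",  # md5 of table id + bytes + rows + columns
    "project_id": "my-project",
    "dataset_id": "observatory",
    "table_id": "doi",
    "is_sharded": True,
    "shard_date": "2023-01-01",
    "date_created": "2023-01-01T00:00:00+00:00",
    "expires": False,
    "date_expires": None,
    "date_last_modified": "2023-01-01T00:00:00+00:00",
    "date_checked": "2023-01-08T00:00:00+00:00",
    "size_gb": 1.5,
    "primary_key": ["doi"],
    "num_rows": 1000,
    "num_distinct_records": 999,  # one primary key value appears twice
    "num_null_records": 0,
    "num_duplicates": 2,
    "num_all_fields": 25,
}
```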
diff --git a/academic_observatory_workflows/fixtures/data_quality/people20230101.jsonl b/academic_observatory_workflows/fixtures/data_quality/people20230101.jsonl new file mode 100644 index 000000000..9bd3541b2 --- /dev/null +++ b/academic_observatory_workflows/fixtures/data_quality/people20230101.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90eb993257b8b5dfac45998cdb83935edc02743e75d49a4e9ea2c5a56eedbefe +size 770 diff --git a/academic_observatory_workflows/fixtures/data_quality/people20230108.jsonl b/academic_observatory_workflows/fixtures/data_quality/people20230108.jsonl new file mode 100644 index 000000000..ebd258af6 --- /dev/null +++ b/academic_observatory_workflows/fixtures/data_quality/people20230108.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9065609264312d188c41ca5d0e8d91d3a7448ea97c644dd6cc683c284274ce85 +size 465 diff --git a/academic_observatory_workflows/fixtures/data_quality/people_schema.json b/academic_observatory_workflows/fixtures/data_quality/people_schema.json new file mode 100644 index 000000000..7737bdc64 --- /dev/null +++ b/academic_observatory_workflows/fixtures/data_quality/people_schema.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9bfba90202d6aee71fb7c9b31e436404908926c0bde1596c2624ad951bfccc1 +size 305 diff --git a/academic_observatory_workflows/workflows/data_quality_workflow.py b/academic_observatory_workflows/workflows/data_quality_workflow.py new file mode 100644 index 000000000..e0295821c --- /dev/null +++ b/academic_observatory_workflows/workflows/data_quality_workflow.py @@ -0,0 +1,441 @@ +# Copyright 2023 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: Alex Massen-Hane + + +from __future__ import annotations + +import os +import hashlib +import logging +import pendulum +from google.cloud import bigquery +from dataclasses import dataclass +from typing import Dict, List, Optional, Union +from google.cloud.bigquery import Table as BQTable + +from airflow import DAG +from airflow.operators.python import PythonOperator +from airflow.sensors.external_task import ExternalTaskSensor +from airflow.utils.task_group import TaskGroup + +from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag + +from observatory.platform.observatory_config import CloudWorkspace +from observatory.platform.workflows.workflow import Workflow, set_task_state, Release +from observatory.platform.bigquery import ( + bq_table_id_parts, + bq_load_from_memory, + bq_create_dataset, + bq_run_query, + bq_table_exists, + bq_select_columns, + bq_get_table, + bq_select_table_shard_dates, + bq_table_id as make_bq_table_id, +) + + +@dataclass +class Table: + project_id: str + dataset_id: str + table_id: str + primary_key: List[str] + is_sharded: bool + shard_limit: Optional[Union[int, bool]] = None + + """Create a metadata class for tables to be processed by this Workflow. 
+ + :param project_id: The Google project_id of where the tables are located. + :param dataset_id: The dataset that the table is under. + :param table_id: The name of the table (not the fully qualified table name). + :param primary_key: The field(s) where the primary key is located in the table, e.g. ["doi"]; this + could be multiple identifiers, such as for Pubmed: ["MedlineCitation.PMID.value", "MedlineCitation.PMID.Version"]. + :param is_sharded: True if the table is sharded. + :param shard_limit: The maximum number of table shards to process for this series of tables. + """ + + # Set the default value for the shard limit to 5 just in case it's forgotten in the config file. + def __post_init__(self): + self.shard_limit = 5 if self.shard_limit is None and self.is_sharded else self.shard_limit + + @property + def full_table_id(self): + return f"{self.project_id}.{self.dataset_id}.{self.table_id}" + + +class DataQualityWorkflow(Workflow): + def __init__( + self, + *, + dag_id: str, + cloud_workspace: CloudWorkspace, + datasets: Dict, + sensor_dag_ids: Optional[List[str]] = None, + bq_dataset_id: str = "data_quality_checks", + bq_dataset_description: str = "This dataset holds metadata about the tables that the Academic Observatory Workflows produce. If a table has multiple shards, the workflow will go back over older shards and check any that it has not checked previously.", + bq_table_id: str = "data_quality", + bq_table_description: str = "Data quality check for all tables produced by the Academic Observatory workflows.", + schema_path: str = os.path.join(default_schema_folder(), "data_quality", "data_quality.json"), + start_date: Optional[pendulum.DateTime] = pendulum.datetime(2020, 1, 1), + schedule: str = "@weekly", + queue: str = "default", + ): + """Create the Data Quality Workflow. + + This workflow creates metadata for all the tables defined in the "datasets" dictionary. + If a table has already been checked before, it will not be checked again. This is based on whether the + number of columns, rows or bytes of data stored in the table has changed. We cannot use the "date modified" + from the BigQuery table object because it changes whenever any other metadata is modified, e.g. the description. + + :param dag_id: the DAG ID. + :param cloud_workspace: the cloud workspace settings. + :param datasets: A dictionary of datasets holding the tables that will be processed by this workflow (see the example below). + :param sensor_dag_ids: List of DAGs that this workflow will wait to finish before running. + :param bq_dataset_id: The dataset_id of where the data quality records will be stored. + :param bq_dataset_description: Description of the data quality check dataset. + :param bq_table_id: The name of the table in BigQuery. + :param bq_table_description: The description of the table in BigQuery. + :param schema_path: The path to the schema file for the records produced by this workflow. + :param start_date: The start date of the workflow. + :param schedule: How often the workflow runs. + :param queue: The queue on which this workflow will run.
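+ + An example of the expected structure of the "datasets" argument is shown below; the dataset, table and primary key names are illustrative only (the unit tests use similar configurations): + + datasets = { + "observatory": {"tables": [{"table_id": "doi", "primary_key": ["doi"], "is_sharded": True, "shard_limit": 5}]}, + "pubmed": {"tables": [{"table_id": "pubmed", "primary_key": ["MedlineCitation.PMID.value"], "is_sharded": False}]}, + }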
+ """ + + super().__init__( + dag_id=dag_id, + start_date=start_date, + schedule=schedule, + tags=[Tag.data_quality], + queue=queue, + ) + + self.cloud_workspace = cloud_workspace + self.project_id = cloud_workspace.project_id + self.bq_dataset_id = bq_dataset_id + self.bq_table_id = bq_table_id + self.bq_dataset_description = bq_dataset_description + self.bq_table_description = bq_table_description + self.data_location = cloud_workspace.data_location + self.schema_path = schema_path + self.datasets = datasets + + # Full table id for the data quality records. + self.dqc_full_table_id = make_bq_table_id(self.project_id, self.bq_dataset_id, bq_table_id) + + # If no sensor workflow is given, then it will run on a regular scheduled basis. + self.sensor_dag_ids = sensor_dag_ids if sensor_dag_ids is not None else [] + + assert datasets, "No dataset or tables given for this DQC Workflow! Please revise the config file." + + def make_dag(self) -> DAG: + """Create a DAG object for the workflow explicitly defining the tasks and task groups. + + :return: The DAG object for the workflow.""" + + with self.dag: + # Create a group of sensors for the workflow. + task_sensor_group = [] + if self.sensor_dag_ids: + with TaskGroup(group_id="dag_sensors") as tg_sensors: + for ext_dag_id in self.sensor_dag_ids: + ExternalTaskSensor( + task_id=f"{ext_dag_id}_sensor", + external_dag_id=ext_dag_id, + mode="reschedule", + ) + + task_sensor_group = tg_sensors + + # Add the standard tasks for the workflow + # fmt: off + task_check_dependencies = PythonOperator(python_callable=self.check_dependencies, task_id="check_dependencies") + task_create_dataset = self.make_python_operator(self.create_dataset, "create_dataset") + # fmt: on + + # Add each dataset as a task group and perform the checks on the tables in parallel. + task_datasets_group = [] + for dataset_id in list(self.datasets.keys()): + with TaskGroup(group_id=dataset_id) as tg_dataset: + table_objects = [ + Table( + project_id=self.project_id, + dataset_id=dataset_id, + table_id=table["table_id"], + is_sharded=table["is_sharded"], + primary_key=table["primary_key"], + shard_limit=table.get("shard_limit"), + ) + for table in self.datasets[dataset_id]["tables"] + ] + + for table in table_objects: + self.make_python_operator( + self.perform_data_quality_check, + task_id=f"{table.table_id}", + op_kwargs={f"task_id": f"{table.table_id}", "table": table}, + ) + + task_datasets_group.append(tg_dataset) + + # Link all tasks and task groups together. + (task_sensor_group >> task_check_dependencies >> task_create_dataset >> task_datasets_group) + + return self.dag + + def make_release(self, **kwargs) -> Release: + """Make a release instance. + + :param kwargs: the context passed from the PythonOperator. See + https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed + to this argument. + :return: A release instance. + """ + return Release(dag_id=self.dag_id, run_id=kwargs["run_id"]) + + def create_dataset(self, release: Release, **kwargs): + """Create a dataset for the table.""" + + bq_create_dataset( + project_id=self.project_id, + dataset_id=self.bq_dataset_id, + location=self.data_location, + description=self.bq_dataset_description, + ) + + def perform_data_quality_check(self, release: Release, **kwargs): + """For each dataset, create a table that holds metadata about the dataset. + + Please refer to the output of the create_data_quality_record function for all the metadata included in this workflow. 
+ """ + + task_id = kwargs["task_id"] + table_to_check: Table = kwargs["table"] + + # Grab tables only if they are + if table_to_check.is_sharded: + dates = bq_select_table_shard_dates( + table_id=table_to_check.full_table_id, end_date=pendulum.now(tz="UTC"), limit=1000 + ) + + # Use limited number of table shards checked to reduce querying costs. + if table_to_check.shard_limit and len(dates) > table_to_check.shard_limit: + shard_limit = table_to_check.shard_limit + else: + shard_limit = len(dates) + + tables = [] + for shard_date in dates[:shard_limit]: + table_id = f'{table_to_check.full_table_id}{shard_date.strftime("%Y%m%d")}' + assert bq_table_exists(table_id), f"Sharded table {table_id} does not exist!" + table = bq_get_table(table_id) + tables.append(table) + else: + tables = [bq_get_table(table_to_check.full_table_id)] + + assert ( + len(tables) > 0 and tables[0] is not None + ), f"No table or sharded tables found in Bigquery for: {table_to_check.dataset_id}.{table_to_check.table_id}" + + records = [] + table: BQTable + for table in tables: + full_table_id = str(table.reference) + + hash_id = create_table_hash_id( + full_table_id=full_table_id, + num_bytes=table.num_bytes, + nrows=table.num_rows, + ncols=len(bq_select_columns(table_id=full_table_id)), + ) + + if not is_in_dqc_table(hash_id, self.dqc_full_table_id): + logging.info(f"Performing check on table {full_table_id} with hash {hash_id}") + check = create_data_quality_record( + hash_id=hash_id, + full_table_id=full_table_id, + primary_key=table_to_check.primary_key, + is_sharded=table_to_check.is_sharded, + table_in_bq=table, + ) + records.append(check) + else: + logging.info( + f"Table {table_to_check.full_table_id} with hash {hash_id} has already been checked before. Not performing check again." + ) + + if records: + logging.info(f"Uploading DQC records for table: {task_id}: {self.dqc_full_table_id}") + success = bq_load_from_memory( + table_id=self.dqc_full_table_id, + records=records, + schema_file_path=self.schema_path, + write_disposition=bigquery.WriteDisposition.WRITE_APPEND, + ) + + assert success, f"Error uploading data quality check to Bigquery." + else: + success = True + + set_task_state(success, task_id, release) + + +def create_data_quality_record( + hash_id: str, + full_table_id: str, + primary_key: List[str], + is_sharded: bool, + table_in_bq: BQTable, +) -> Dict[str, Union[str, List[str], float, bool, int]]: + """ + Perform novel data quality checks on a given table in Bigquery. + + :param hash_id: Unique md5 style identifier of the table. + :param full_table_id: The fully qualified table id, including the shard date suffix. + :param primary_key: The key identifier columns for the table. + :param is_sharded: If the table is supposed to be sharded or not. + :param table_in_bq: Table metadata object retrieved from the Bigquery API. + :return: Dictionary of values from the data quality check. + """ + + project_id, dataset_id, table_id, shard_date = bq_table_id_parts(full_table_id) + assert is_sharded == (shard_date is not None), f"Workflow config of table {full_table_id} do not match." + + # Retrieve metadata on the table. 
+ date_created = table_in_bq.created.isoformat() + date_checked = pendulum.now(tz="UTC").isoformat() + date_last_modified = table_in_bq.modified.isoformat() + + expires = bool(table_in_bq.expires) + date_expires = table_in_bq.expires.isoformat() if expires else None + + num_distinct_records = bq_count_distinct_records(full_table_id, fields=primary_key) + num_null_records = bq_count_nulls(full_table_id, fields=primary_key) + num_duplicates = bq_count_duplicate_records(full_table_id, fields=primary_key) + + num_all_fields = len(bq_select_columns(table_id=full_table_id)) + + return dict( + full_table_id=full_table_id, + hash_id=hash_id, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + shard_date=shard_date.to_date_string() if shard_date is not None else shard_date, + is_sharded=is_sharded, + date_created=date_created, + expires=expires, + date_expires=date_expires, + date_checked=date_checked, + date_last_modified=date_last_modified, + size_gb=float(table_in_bq.num_bytes) / (1024.0) ** 3, + primary_key=primary_key, + num_rows=table_in_bq.num_rows, + num_distinct_records=num_distinct_records, + num_null_records=num_null_records, + num_duplicates=num_duplicates, + num_all_fields=num_all_fields, + ) + + +def create_table_hash_id(full_table_id: str, num_bytes: int, nrows: int, ncols: int) -> str: + """Create a unique table identifier based on the input parameters for a table in BigQuery. + + :param full_table_id: The fully qualified table name. + :param num_bytes: Number of bytes stored in the table. + :param nrows: Number of rows/records in the table. + :param ncols: Number of columns/fields in the table. + :return: An md5 hash based on the given input parameters.""" + + return hashlib.md5(f"{full_table_id}{num_bytes}{nrows}{ncols}".encode("utf-8")).hexdigest() + + +def is_in_dqc_table(hash_to_check: str, dqc_full_table_id: str) -> bool: + """Checks whether a table has already been processed by looking for its hash_id in the main DQC table. + + :param hash_to_check: The hash of the table, used to check whether the data quality checks have been performed before. + :param dqc_full_table_id: The fully qualified name of the table that holds all of the DQC records in BigQuery. + :return: True if the check has been done before, otherwise False. + """ + + if bq_table_exists(dqc_full_table_id): + sql = f""" + SELECT True + FROM {dqc_full_table_id} + WHERE hash_id = "{hash_to_check}" + """ + return bool([dict(row) for row in bq_run_query(sql)]) + + else: + logging.info(f"DQC record table: {dqc_full_table_id} does not exist!") + return False + + +def bq_count_distinct_records(full_table_id: str, fields: Union[str, List[str]]) -> int: + """ + Finds the number of distinct records for the given fields. + + :param full_table_id: The fully qualified table id. + :param fields: A single field or list of fields used to determine the distinct records. + :return: The number of distinct records. + """ + + fields_to_check = ", ".join(fields) if isinstance(fields, list) else fields + sql = f""" + SELECT COUNT(*) as count + FROM ( SELECT distinct {fields_to_check} FROM {full_table_id} ) + """ + return int([dict(row) for row in bq_run_query(sql)][0]["count"]) + + +def bq_count_nulls(full_table_id: str, fields: Union[str, List[str]]) -> int: + """Return the number of records with nulls for a single field or for multiple fields. + The fields are combined with an OR condition, so a record is counted if any of the listed fields is null/empty. + + :param full_table_id: The fully qualified table id.
+ :param fields: A single string or list of strings to have the number of nulls checked. + :return: The integer number of records with nulls present in the given fields.""" + + fields_to_check = " IS NULL OR ".join(fields) if isinstance(fields, list) else fields + sql = f""" + SELECT COUNTIF( {fields_to_check} IS NULL) AS nullCount + FROM `{full_table_id}` + """ + return int([dict(row) for row in bq_run_query(sql)][0]["nullCount"]) + + +def bq_count_duplicate_records(full_table_id: str, fields: Union[str, List[str]]) -> int: + """Query a table in BigQuery and return the total number of records that share a duplicated + value of the given field/s. + + :param full_table_id: Fully qualified table name. + :param fields: String or list of strings of the keys used to group records when looking for duplicates. + :return: The total number of records that belong to a duplicated group. + """ + + fields_to_check = ", ".join(fields) if isinstance(fields, list) else fields + sql = f""" + SELECT + SUM(duplicate_count) AS total_duplicate_sum + FROM ( + SELECT {fields_to_check}, COUNT(*) AS duplicate_count + FROM `{full_table_id}` + GROUP BY {fields_to_check} + HAVING COUNT(*) > 1 + ) + """ + result = [dict(row) for row in bq_run_query(sql)][0]["total_duplicate_sum"] + num_duplicates = 0 if result is None else int(result) + + return num_duplicates diff --git a/academic_observatory_workflows/workflows/tests/test_data_quality_workflow.py b/academic_observatory_workflows/workflows/tests/test_data_quality_workflow.py new file mode 100644 index 000000000..ac9732605 --- /dev/null +++ b/academic_observatory_workflows/workflows/tests/test_data_quality_workflow.py @@ -0,0 +1,553 @@ +# Copyright 2023 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +# Author: Alex Massen-Hane + +import os +import pendulum +from google.cloud import bigquery +from google.cloud.bigquery import Table as BQTable + +from academic_observatory_workflows.config import test_fixtures_folder, schema_folder as default_schema_folder +from academic_observatory_workflows.workflows.data_quality_workflow import ( + DataQualityWorkflow, + Table, + create_data_quality_record, + bq_count_distinct_records, + bq_count_nulls, + bq_get_table, + bq_count_duplicate_records, + create_table_hash_id, + is_in_dqc_table, +) +from observatory.platform.bigquery import ( + bq_table_id, + bq_load_from_memory, + bq_select_columns, + bq_upsert_records, +) +from observatory.platform.files import load_jsonl +from observatory.platform.observatory_config import Workflow +from observatory.platform.observatory_environment import ( + ObservatoryEnvironment, + ObservatoryTestCase, + make_dummy_dag, + find_free_port, + random_id, +) + + +class TestDataQualityWorkflow(ObservatoryTestCase): + """Tests for the Data Quality Check Workflow""" + + def __init__(self, *args, **kwargs): + self.dag_id = "data_quality_workflow" + self.project_id = os.getenv("TEST_GCP_PROJECT_ID") + self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") + + super(TestDataQualityWorkflow, self).__init__(*args, **kwargs) + + def test_dag_load(self): + """Test that the DataQualityCheck DAG can be loaded from a DAG bag.""" + + env = ObservatoryEnvironment( + workflows=[ + Workflow( + dag_id=self.dag_id, + name="Data Quality Check Workflow", + class_name="academic_observatory_workflows.workflows.data_quality_workflow.DataQualityWorkflow", + cloud_workspace=self.fake_cloud_workspace, + kwargs=dict( + sensor_dag_ids=["doi"], + datasets={ + "observatory": {"tables": [{"table_id": "doi", "is_sharded": True, "primary_key": ["doi"]}]} + }, + ), + ) + ], + api_port=find_free_port(), + ) + + with env.create(): + self.assert_dag_load_from_config(self.dag_id) + + def test_dag_structure(self): + """Test that the DAG has the correct structure.""" + + workflow = DataQualityWorkflow( + dag_id=self.dag_id, + cloud_workspace=self.fake_cloud_workspace, + sensor_dag_ids=["dummy1", "dummy2"], + datasets={ + "observatory": { + "tables": [ + { + "table_id": "doi", + "primary_key": ["doi"], + "is_sharded": True, + "shard_limit": 5, + } + ], + }, + "pubmed": { + "tables": [ + { + "table_id": "pubmed", + "primary_key": ["MedlineCitation.PMID.value", "MedlineCitation.PMID.Version"], + "is_sharded": False, + } + ] + }, + }, + ) + dag = workflow.make_dag() + self.assert_dag_structure( + { + "dag_sensors.dummy1_sensor": ["check_dependencies"], + "dag_sensors.dummy2_sensor": ["check_dependencies"], + "check_dependencies": ["create_dataset"], + "create_dataset": ["observatory.doi", "pubmed.pubmed"], + "observatory.doi": [], + "pubmed.pubmed": [], + }, + dag, + ) + + def test_workflow(self): + """Test the Data Quality Check Workflow end to end + + Borrowing off of the doi test structure.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + + # Where the metadata generated for this workflow is going to be stored. 
+ dq_dataset_id = env.add_dataset(prefix="data_quality_check") + fake_dataset_id = env.add_dataset() + + test_tables = [ + { + "full_table_id": bq_table_id(self.project_id, fake_dataset_id, "people"), + "schema_path": os.path.join(test_fixtures_folder(), "data_quality", "people_schema.json"), + "expected": load_jsonl(os.path.join(test_fixtures_folder(), "data_quality", "people20230101.jsonl")), + }, + { + "full_table_id": bq_table_id(self.project_id, fake_dataset_id, "people_shard20230101"), + "schema_path": os.path.join(test_fixtures_folder(), "data_quality", "people_schema.json"), + "expected": load_jsonl(os.path.join(test_fixtures_folder(), "data_quality", "people20230101.jsonl")), + }, + { + "full_table_id": bq_table_id(self.project_id, fake_dataset_id, "people_shard20230108"), + "schema_path": os.path.join(test_fixtures_folder(), "data_quality", "people_schema.json"), + "expected": load_jsonl(os.path.join(test_fixtures_folder(), "data_quality", "people20230108.jsonl")), + }, + { + "full_table_id": bq_table_id(self.project_id, fake_dataset_id, "people_shard20230115"), + "schema_path": os.path.join(test_fixtures_folder(), "data_quality", "people_schema.json"), + "expected": load_jsonl(os.path.join(test_fixtures_folder(), "data_quality", "people20230108.jsonl")), + }, + ] + + with env.create(task_logging=True): + # Upload the test tables to Bigquery + for table in test_tables: + bq_load_from_memory( + table_id=table["full_table_id"], records=table["expected"], schema_file_path=table["schema_path"] + ) + + start_date = pendulum.datetime(year=2021, month=10, day=10) + workflow = DataQualityWorkflow( + dag_id=self.dag_id, + cloud_workspace=env.cloud_workspace, + bq_dataset_id=dq_dataset_id, + start_date=start_date, + sensor_dag_ids=["doi", "pubmed"], + datasets={ + fake_dataset_id: { + "tables": [ + {"table_id": "people", "primary_key": ["id"], "is_sharded": False}, + {"table_id": "people_shard", "primary_key": ["id"], "is_sharded": True, "shard_limit": 2}, + ], + }, + }, + ) + + data_quality_dag = workflow.make_dag() + + # Run fake version of the dags that the workflow sensors are waiting for. + execution_date = pendulum.datetime(year=2023, month=1, day=1) + for dag_id in workflow.sensor_dag_ids: + dag = make_dummy_dag(dag_id, execution_date) + with env.create_dag_run(dag, execution_date): + # Running all of a DAGs tasks sets the DAG to finished + ti = env.run_task("dummy_task") + self.assertEqual("success", ti.state) + + ### FIRST RUN ### + # First run of the workflow. Will produce the data_quality table and a record for + # each of the test tables uploaded, but will miss the 20230101 shard due to the shard_limit parameter set. + + # Run end to end tests for DQC DAG + with env.create_dag_run(data_quality_dag, execution_date): + # Test that sensors go into 'success' state as the DAGs that they are waiting for have finished + for task_id in workflow.sensor_dag_ids: + ti = env.run_task(f"dag_sensors.{task_id}_sensor") + self.assertEqual("success", ti.state) + + # Check dependencies + ti = env.run_task(workflow.check_dependencies.__name__) + self.assertEqual("success", ti.state) + + # Create dataset + ti = env.run_task(workflow.create_dataset.__name__) + self.assertEqual("success", ti.state) + + # Perform data quality check + for dataset_id, tables in workflow.datasets.items(): + for table in tables["tables"]: + task_id = f"{dataset_id}.{table['table_id']}" + ti = env.run_task(task_id) + self.assertEqual("success", ti.state) + + # Check that DQC table has been created. 
+ table_id = bq_table_id(self.project_id, workflow.bq_dataset_id, workflow.bq_table_id) + self.assert_table_integrity(table_id, expected_rows=3) # stop and look at table on BQ + + ### SECOND RUN ### + # For the sake of the test, we will change one of the tables by doing an upsert, so that the + # hash_id of the first will be different. + + bq_upsert_records( + main_table_id=test_tables[0]["full_table_id"], + upsert_table_id=test_tables[2]["full_table_id"], + primary_key="id", + ) + self.assert_table_integrity(test_tables[0]["full_table_id"], 16) + + # Run Dummy Dags + execution_date = pendulum.datetime(year=2023, month=2, day=1) + for dag_id in workflow.sensor_dag_ids: + dag = make_dummy_dag(dag_id, execution_date) + with env.create_dag_run(dag, execution_date): + # Running all of a DAGs tasks sets the DAG to finished + ti = env.run_task("dummy_task") + self.assertEqual("success", ti.state) + + # Run end to end tests for DQC DAG + with env.create_dag_run(data_quality_dag, execution_date): + # Test that sensors go into 'success' state as the DAGs that they are waiting for have finished + for task_id in workflow.sensor_dag_ids: + ti = env.run_task(f"dag_sensors.{task_id}_sensor") + self.assertEqual("success", ti.state) + + # Check dependencies + ti = env.run_task(workflow.check_dependencies.__name__) + self.assertEqual("success", ti.state) + + # Create dataset + ti = env.run_task(workflow.create_dataset.__name__) + self.assertEqual("success", ti.state) + # Perform data quality check + for dataset_id, tables in workflow.datasets.items(): + for table in tables["tables"]: + task_id = f"{dataset_id}.{table['table_id']}" + ti = env.run_task(task_id) + self.assertEqual("success", ti.state) + + # Check that DQC table has been created. + table_id = bq_table_id(self.project_id, workflow.bq_dataset_id, workflow.bq_table_id) + self.assert_table_integrity(table_id, expected_rows=4) + + ### THIRD RUN ### + # For this third run, no tables should be updated or changed meaning that there should be + # no data quality checks done. + + # Run Dummy Dags + execution_date = pendulum.datetime(year=2023, month=3, day=1) + for dag_id in workflow.sensor_dag_ids: + dag = make_dummy_dag(dag_id, execution_date) + with env.create_dag_run(dag, execution_date): + # Running all of a DAGs tasks sets the DAG to finished + ti = env.run_task("dummy_task") + self.assertEqual("success", ti.state) + + # Run end to end tests for Data Quality DAG + with env.create_dag_run(data_quality_dag, execution_date): + # Test that sensors go into 'success' state as the DAGs that they are waiting for have finished + for task_id in workflow.sensor_dag_ids: + ti = env.run_task(f"dag_sensors.{task_id}_sensor") + self.assertEqual("success", ti.state) + + # Check dependencies + ti = env.run_task(workflow.check_dependencies.__name__) + self.assertEqual("success", ti.state) + + # Create dataset + ti = env.run_task(workflow.create_dataset.__name__) + self.assertEqual("success", ti.state) + + # Perform data quality check + for dataset_id, tables in workflow.datasets.items(): + for table in tables["tables"]: + task_id = f"{dataset_id}.{table['table_id']}" + ti = env.run_task(task_id) + self.assertEqual("success", ti.state) + + # Check that the DQC table has no new records added for this third run. + # Check that DQC table has been created. 
+ table_id = bq_table_id(self.project_id, workflow.bq_dataset_id, workflow.bq_table_id) + self.assert_table_integrity(table_id, expected_rows=4) + + +class TestDataQualityUtils(ObservatoryTestCase): + def __init__(self, *args, **kwargs): + super(TestDataQualityUtils, self).__init__(*args, **kwargs) + + self.dag_id = "data_quality_checks" + self.project_id = os.getenv("TEST_GCP_PROJECT_ID") + self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") + + self.schema_path = os.path.join(default_schema_folder(), "data_quality", "data_quality.json") + + # Can't use faker here because the number of bytes in a table is needed to be the same for each test run. + self.test_table_hash = "771c9176e77c1b03f64b1b5fa4a39cdb" + self.test_table = [ + dict(id="something", count="1", abstract_text="Hello"), + dict(id="something", count="2", abstract_text="World"), + dict(id="somethingelse", count="3", abstract_text="Science"), + dict(id="somethingelse", count="4", abstract_text="Science"), + dict(id=None, count="5", abstract_text="Maths"), + ] + + self.expected_dqc_record = dict( + table_id="create_dqc_record", + project_id=self.project_id, + is_sharded=False, + shard_date=None, + expires=False, + date_expires=None, + size_gb=1.2200325727462769e-07, + primary_key=["id"], + num_rows=5, + num_distinct_records=3, + num_null_records=1, + num_duplicates=4, + num_all_fields=3, + ) + + def test_create_table_hash_id(self): + """Test if hash can be reliably created.""" + + bq_table_id = "create_table_hash_id" + result = create_table_hash_id(full_table_id=bq_table_id, num_bytes=131, nrows=5, ncols=3) + self.assertEqual(result, self.test_table_hash) + + def test_create_dq_record(self): + """Test if a data quality check record can be reliably created.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + dataset_id = env.add_dataset() + table_id = "create_dqc_record" + + table_to_check = Table( + project_id=self.project_id, + dataset_id=dataset_id, + table_id=table_id, + primary_key=["id"], + is_sharded=False, + ) + + with env.create(task_logging=True): + full_table_id = bq_table_id(self.project_id, dataset_id, table_id) + + # Load the test table from memory to Bigquery. + success = bq_load_from_memory(table_id=full_table_id, records=self.test_table) + self.assertTrue(success) + + # Grab the table from the Bigquery API + table: BQTable = bq_get_table(full_table_id) + + # Need to add a DQC record into a temp table so that we can check if it's in there. + hash_id = create_table_hash_id( + full_table_id=full_table_id, + num_bytes=table.num_bytes, + nrows=table.num_rows, + ncols=len(bq_select_columns(table_id=full_table_id)), + ) + + dqc_record = create_data_quality_record( + hash_id=hash_id, + full_table_id=full_table_id, + primary_key=table_to_check.primary_key, + is_sharded=table_to_check.is_sharded, + table_in_bq=table, + ) + + # Loop through checking all of the values that do not change for each unittest run. 
+ keys_to_check = list(self.expected_dqc_record.keys()) + for key in keys_to_check: + self.assertEqual(dqc_record[key], self.expected_dqc_record[key]) + + def test_is_in_dqc_table(self): + """Test if a data quality check has already been previously performed by checking the table hash that it creates.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + dqc_dataset_id = env.add_dataset(prefix="data_quality_check") + dataset_id = env.add_dataset() + table_id = "is_in_dqc_table" + dag_id = "test_dag" + + table_to_check = Table( + project_id=self.project_id, + dataset_id=dataset_id, + table_id=table_id, + primary_key=["id"], + is_sharded=False, + ) + + with env.create(task_logging=True): + full_table_id = bq_table_id(self.project_id, dataset_id, table_id) + dqc_full_table_id = bq_table_id(self.project_id, dqc_dataset_id, dag_id) + + # Load the test table from memory to Bigquery. + success = bq_load_from_memory(table_id=full_table_id, records=self.test_table) + self.assertTrue(success) + + # Grab the table from the Bigquery API + table: BQTable = bq_get_table(full_table_id) + + # Need to add a DQC record into a temp table so that we can check if it's in there. + hash_id = create_table_hash_id( + full_table_id=full_table_id, + num_bytes=table.num_bytes, + nrows=table.num_rows, + ncols=len(bq_select_columns(table_id=full_table_id)), + ) + + dqc_record = [ + create_data_quality_record( + hash_id=hash_id, + full_table_id=full_table_id, + primary_key=table_to_check.primary_key, + is_sharded=table_to_check.is_sharded, + table_in_bq=table, + ) + ] + success = bq_load_from_memory( + table_id=dqc_full_table_id, + records=dqc_record, + schema_file_path=self.schema_path, + write_disposition=bigquery.WriteDisposition.WRITE_APPEND, + table_description=f"{dag_id}", + ) + self.assertTrue(success) + + # Ensure that hash is in the data quality table. + result = is_in_dqc_table(hash_to_check=hash_id, dqc_full_table_id=dqc_full_table_id) + self.assertTrue(result) + + # A random hash that we know shouldn't be in the data quality table. 
+ random_hash = random_id() + result = is_in_dqc_table(hash_to_check=random_hash, dqc_full_table_id=dqc_full_table_id) + self.assertFalse(result) + + def test_bq_count_duplicate_records(self): + """Test if duplicate records can be reliably found in a table.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + bq_dataset_id = env.add_dataset() + bq_table_id = "count_duplicate_records" + + with env.create(task_logging=True): + full_table_id = f"{self.project_id}.{bq_dataset_id}.{bq_table_id}" + + success = bq_load_from_memory(table_id=full_table_id, records=self.test_table) + self.assertTrue(success) + + num_duplicates = bq_count_duplicate_records(full_table_id, "id") + self.assertEqual(num_duplicates, 4) + + num_duplicates = bq_count_duplicate_records(full_table_id, ["id", "abstract_text"]) + self.assertEqual(num_duplicates, 2) + + def test_bq_count_nulls(self): + """Test if the number of nulls under a field can be correctly determined.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + bq_dataset_id = env.add_dataset() + bq_table_id = "count_num_nulls_for_field" + + with env.create(task_logging=True): + full_table_id = f"{self.project_id}.{bq_dataset_id}.{bq_table_id}" + + success = bq_load_from_memory(table_id=full_table_id, records=self.test_table) + self.assertTrue(success) + + num_nulls = bq_count_nulls(full_table_id, "id") + self.assertEqual(num_nulls, 1) + + num_nulls = bq_count_nulls(full_table_id, ["id", "abstract_text"]) + self.assertEqual(num_nulls, 1) + + def test_bq_count_distinct_records(self): + """Test that the number of distinct records can be reliably determined.""" + + env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port()) + bq_dataset_id = env.add_dataset() + bq_table_id = "distinct_records" + + with env.create(task_logging=True): + full_table_id = f"{self.project_id}.{bq_dataset_id}.{bq_table_id}" + + success = bq_load_from_memory(table_id=full_table_id, records=self.test_table) + self.assertTrue(success) + + num_distinct = bq_count_distinct_records(full_table_id, "id") + self.assertEqual(num_distinct, 3) + + num_distinct = bq_count_distinct_records(full_table_id, ["id", "abstract_text"]) + self.assertEqual(num_distinct, 4) + + def test_table_object(self): + """Test that a table's shard limit can be set properly.""" + + table = Table( + project_id=self.project_id, + dataset_id="dataset_id", + table_id="table_id", + primary_key=["id"], + is_sharded=False, + ) + + self.assertEqual(table.shard_limit, None) + + table = Table( + project_id=self.project_id, + dataset_id="dataset_id", + table_id="table_id", + primary_key=["id"], + is_sharded=True, + ) + + self.assertEqual(table.shard_limit, 5) + + table = Table( + project_id=self.project_id, + dataset_id="dataset_id", + table_id="table_id", + primary_key=["id"], + is_sharded=False, + shard_limit=False, + ) + + self.assertEqual(table.shard_limit, False)
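As a closing note on the three SQL count helpers exercised above, the following is a minimal, BigQuery-free sketch of the semantics the tests assert, computed in plain Python over the same five-row fixture used in TestDataQualityUtils; it is an illustration only and not part of the patch.

```python
from collections import Counter

# The same five-row fixture as TestDataQualityUtils.test_table.
rows = [
    {"id": "something", "count": "1", "abstract_text": "Hello"},
    {"id": "something", "count": "2", "abstract_text": "World"},
    {"id": "somethingelse", "count": "3", "abstract_text": "Science"},
    {"id": "somethingelse", "count": "4", "abstract_text": "Science"},
    {"id": None, "count": "5", "abstract_text": "Maths"},
]
ids = [row["id"] for row in rows]

# bq_count_distinct_records: COUNT(*) over SELECT DISTINCT id; NULL counts as one distinct value.
num_distinct = len(set(ids))  # 3

# bq_count_nulls: COUNTIF(id IS NULL).
num_nulls = sum(1 for value in ids if value is None)  # 1

# bq_count_duplicate_records: sum of group sizes for groups with more than one record.
group_sizes = Counter(ids)
num_duplicates = sum(size for size in group_sizes.values() if size > 1)  # 2 + 2 = 4

assert (num_distinct, num_nulls, num_duplicates) == (3, 1, 4)
```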