From ce14994a29c966b5f1ed97f5181f56aca1f80c68 Mon Sep 17 00:00:00 2001 From: keegansmith21 Date: Fri, 27 Sep 2024 08:43:29 +0000 Subject: [PATCH] ORCID fixes --- .../orcid_telescope/tasks.py | 39 +++++++++++++-- .../orcid_telescope/telescope.py | 49 ++++++------------- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/tasks.py b/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/tasks.py index e275da4a1..696411628 100644 --- a/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/tasks.py +++ b/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/tasks.py @@ -25,6 +25,7 @@ from concurrent.futures import as_completed, ProcessPoolExecutor, ThreadPoolExecutor import datetime from os import PathLike +import json from typing import Dict, Optional, Tuple, Union import pendulum @@ -33,7 +34,9 @@ from airflow.exceptions import AirflowSkipException from airflow.hooks.base import BaseHook from airflow.models import DagRun -from google.auth import default +from google import auth +from google.auth.compute_engine import Credentials as ComputeEngineCredentials +from google.oauth2.service_account import Credentials as ServiceAccountCredentials from google.cloud import bigquery from google.cloud.bigquery import SourceFormat @@ -46,6 +49,7 @@ from observatory_platform.date_utils import datetime_normalise from observatory_platform.files import change_keys, save_jsonl_gz from observatory_platform.google import bigquery as bq +from observatory_platform.google.gke import gke_service_account_email from observatory_platform.google.gcs import ( gcs_blob_uri, gcs_create_aws_transfer, @@ -106,8 +110,10 @@ def fetch_release( "fetch_releases: there should be at least 1 DatasetRelease in the Observatory API after the first DAG run" ) prev_release = api.get_latest_dataset_release(dag_id=dag_id, entity_id="orcid", date_key="changefile_end_date") + logging.info(f"Extra: {prev_release.extra}") + logging.info(f"Type: {type(prev_release.extra)}") + prev_latest_modified_record = pendulum.parse(json.loads(prev_release.extra["latest_modified_record_date"])) prev_release_end = prev_release.changefile_end_date - prev_latest_modified_record = pendulum.parse(prev_release.extra["latest_modified_record_date"]) return OrcidRelease( dag_id=dag_id, @@ -246,8 +252,21 @@ def download(release: dict): """Reads each batch's manifest and downloads the files from the gcs bucket.""" release = OrcidRelease.from_dict(release) - gcs_creds, project_id = default() - with gcs_hmac_key(project_id, gcs_creds.service_account_email) as (key, secret): + + # Check the type of credentials + credentials, _ = auth.default() + if isinstance(credentials, ServiceAccountCredentials): + email = credentials.service_account_email + logging.info("Using service account credentials") + elif isinstance(credentials, ComputeEngineCredentials): + logging.info("Using compute engine credentials") + email = gke_service_account_email() + if not email: + raise AirflowException("Email could not be retrieved") + else: + raise AirflowException(f"Unknown credentials type: {type(credentials)}") + + with gcs_hmac_key(release.cloud_workspace.project_id, email) as (key, secret): total_files = 0 start_time = time.time() for orcid_batch in release.orcid_batches(): @@ -265,6 +284,9 @@ def download(release: dict): # Check for errors if returncode != 0: + with open(orcid_batch.download_log_file) as f: + output = f.read() + 
logging.error(output)
                     raise RuntimeError(
                         f"Download attempt '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}. See log file: {orcid_batch.download_log_file}"
                     )
@@ -471,6 +493,11 @@ def add_dataset_release(release: dict, *, api_bq_dataset_id: str, latest_modifie
     :param last_modified_release_date: The modification datetime of the last modified record of this release"""
 
     release = OrcidRelease.from_dict(release)
+    try:
+        pendulum.parse(latest_modified_date)
+    except pendulum.parsing.exceptions.ParserError:
+        raise AirflowException(f"Latest modified record date not valid: {latest_modified_date}")
+
     api = DatasetAPI(bq_project_id=release.cloud_workspace.project_id, bq_dataset_id=api_bq_dataset_id)
     api.seed_db()
     now = pendulum.now()
@@ -547,7 +574,9 @@ def latest_modified_record_date(release: dict) -> str:
     with open(release.master_manifest_file, "r") as f:
         reader = csv.DictReader(f)
         modified_dates = sorted([pendulum.parse(row["updated"]) for row in reader])
-    return datetime_normalise(modified_dates[-1])
+    latest_modified_record_date = datetime_normalise(modified_dates[-1])
+    logging.info(f"Latest modified record date: {latest_modified_record_date}")
+    return latest_modified_record_date
 
 
 def transform_orcid_record(record_path: str) -> Union[Dict, str]:
diff --git a/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/telescope.py b/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/telescope.py
index d0cb660a5..746781cb4 100644
--- a/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/telescope.py
+++ b/academic-observatory-workflows/academic_observatory_workflows/orcid_telescope/telescope.py
@@ -27,6 +27,7 @@
 from airflow.utils.trigger_rule import TriggerRule
 
 from academic_observatory_workflows.config import project_path
+from academic_observatory_workflows.orcid_telescope import tasks
 from observatory_platform.airflow.airflow import on_failure_callback
 from observatory_platform.airflow.sensors import PreviousDagRunSensor
 from observatory_platform.airflow.tasks import check_dependencies, gke_create_storage, gke_delete_storage
@@ -143,8 +144,6 @@ def orcid():
     def fetch_release(**context) -> dict:
         """Generates the OrcidRelease object."""
 
-        from academic_observatory_workflows.orcid_telescope import tasks
-
         return tasks.fetch_release(
             dag_id=dag_params.dag_id,
             run_id=context["run_id"],
@@ -163,16 +162,12 @@ def fetch_release(**context) -> dict:
     def create_dataset(release: dict, dag_params, **context) -> None:
         """Create datasets"""
 
-        from academic_observatory_workflows.orcid_telescope import tasks
-
         tasks.create_dataset(release, dataset_description=dag_params.dataset_description)
 
     @task
     def transfer_orcid(release: dict, dag_params, **context):
         """Sync files from AWS bucket to Google Cloud bucket."""
 
-        from academic_observatory_workflows.orcid_telescope import tasks
-
         tasks.transfer_orcid(
             release,
             aws_orcid_conn_id=dag_params.aws_orcid_conn_id,
@@ -181,17 +176,16 @@ def transfer_orcid(release: dict, dag_params, **context):
             orcid_summaries_prefix=dag_params.orcid_summaries_prefix,
         )
 
-    @task
+    @task(trigger_rule=TriggerRule.NONE_FAILED)
     def bq_create_main_table_snapshot(release: dict, dag_params, **context):
         """Create a snapshot of each main table. The purpose of this table is to be able to rollback the table
         if something goes wrong. 
The snapshot expires after snapshot_expiry_days.""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.bq_create_main_table_snapshot(release, snapshot_expiry_days=dag_params.snapshot_expiry_days) @task.kubernetes( + trigger_rule=TriggerRule.NONE_FAILED, name="create_manifests", - trigger_rule=TriggerRule.ALL_DONE, container_resources=gke_make_container_resources( {"memory": "16G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("create_manifests") ), @@ -210,9 +204,8 @@ def create_manifests(release: dict, dag_params, **context): @task.kubernetes( name="latest_modified_record_date", - trigger_rule=TriggerRule.ALL_DONE, container_resources=gke_make_container_resources( - {"memory": "2G", "cpu": "2"}, + {"memory": "4G", "cpu": "2"}, dag_params.gke_params.gke_resource_overrides.get("latest_modified_record_date"), ), **kubernetes_task_params, @@ -225,7 +218,6 @@ def latest_modified_record_date(release: dict, **context): @task.kubernetes( name="download", - trigger_rule=TriggerRule.ALL_DONE, container_resources=gke_make_container_resources( {"memory": "8G", "cpu": "8"}, dag_params.gke_params.gke_resource_overrides.get("download") ), @@ -239,9 +231,8 @@ def download(release: dict, **context): @task.kubernetes( name="transform", - trigger_rule=TriggerRule.ALL_DONE, container_resources=gke_make_container_resources( - {"memory": "16G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("transform") + {"memory": "32G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("transform") ), **kubernetes_task_params, ) @@ -253,7 +244,6 @@ def transform(release: dict, dag_params, **context): @task.kubernetes( name="upload_transformed", - trigger_rule=TriggerRule.ALL_DONE, container_resources=gke_make_container_resources( {"memory": "8G", "cpu": "8"}, dag_params.gke_params.gke_resource_overrides.get("upload_transformed") ), @@ -265,54 +255,47 @@ def upload_transformed(release: dict, **context): tasks.upload_transformed(release) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task() def bq_load_main_table(release: dict, dag_params, **context): """Load the main table.""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.bq_load_main_table(release, schema_file_path=dag_params.schema_file_path) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task(trigger_rule=TriggerRule.NONE_FAILED) def bq_load_upsert_table(release: dict, dag_params, **context): """Load the upsert table into bigquery""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.bq_load_upsert_table(release, schema_file_path=dag_params.schema_file_path) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task(trigger_rule=TriggerRule.NONE_FAILED) def bq_load_delete_table(release: dict, dag_params, **context): """Load the delete table into bigquery""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.bq_load_delete_table(release, delete_schema_file_path=dag_params.delete_schema_file_path) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task(trigger_rule=TriggerRule.NONE_FAILED) def bq_upsert_records(release: dict, **context): """Upsert the records from the upserts table into the main table.""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.bq_upsert_records(release) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task(trigger_rule=TriggerRule.NONE_FAILED) def bq_delete_records(release: dict, **context): """Delete the records in the delete table from the main table.""" - from 
academic_observatory_workflows.orcid_telescope import tasks tasks.bq_delete_records(release) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task(trigger_rule=TriggerRule.NONE_FAILED) def add_dataset_release(release: dict, latest_modified_date: str, dag_params, **context) -> None: """Adds release information to API.""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.add_dataset_release( release, api_bq_dataset_id=dag_params.api_bq_dataset_id, latest_modified_date=latest_modified_date ) - @task(trigger_rule=TriggerRule.ALL_DONE) + @task() def cleanup_workflow(release: dict, **context) -> None: """Delete all files, folders and XComs associated with this release.""" - from academic_observatory_workflows.orcid_telescope import tasks tasks.cleanup_workflow(release) @@ -343,15 +326,15 @@ def cleanup_workflow(release: dict, **context) -> None: task_id=external_task_id, ) if dag_params.test_run: - task_create_storage = EmptyOperator(task_id="gke_create_storage") - task_delete_storage = EmptyOperator(task_id="gke_delete_storage") + task_create_storage = EmptyOperator(task_id="gke_create_storage", trigger_rule=TriggerRule.NONE_FAILED) + task_delete_storage = EmptyOperator(task_id="gke_delete_storage", trigger_rule=TriggerRule.NONE_FAILED) else: - task_create_storage = gke_create_storage( + task_create_storage = gke_create_storage.override(trigger_rule=TriggerRule.NONE_FAILED)( volume_name=dag_params.gke_params.gke_volume_name, volume_size=dag_params.gke_params.gke_volume_size, kubernetes_conn_id=dag_params.gke_params.gke_conn_id, ) - task_delete_storage = gke_delete_storage( + task_delete_storage = gke_delete_storage.override(trigger_rule=TriggerRule.NONE_FAILED)( volume_name=dag_params.gke_params.gke_volume_name, kubernetes_conn_id=dag_params.gke_params.gke_conn_id, )
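
Illustrative sketch (not part of the patch): the download() change in tasks.py branches on the concrete credential class returned by Application Default Credentials before requesting an HMAC key. The standalone sketch below shows that pattern in isolation, assuming the requests library is available; resolve_service_account_email() is a hypothetical helper, and the metadata-server lookup is assumed to approximate what observatory_platform's gke_service_account_email() does.

    import google.auth
    import requests
    from google.auth.compute_engine import Credentials as ComputeEngineCredentials
    from google.oauth2.service_account import Credentials as ServiceAccountCredentials

    # GCE/GKE metadata endpoint that returns the email of the attached service account.
    METADATA_EMAIL_URL = (
        "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email"
    )


    def resolve_service_account_email() -> str:
        """Return the service account email behind the active Application Default Credentials."""
        credentials, _ = google.auth.default()
        if isinstance(credentials, ServiceAccountCredentials):
            # A service account key file exposes the email directly on the credentials object.
            return credentials.service_account_email
        if isinstance(credentials, ComputeEngineCredentials):
            # Inside GCE/GKE the email is only known to the metadata server, so query it.
            response = requests.get(METADATA_EMAIL_URL, headers={"Metadata-Flavor": "Google"}, timeout=10)
            response.raise_for_status()
            return response.text
        raise TypeError(f"Unsupported credentials type: {type(credentials)}")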