ORCID fixes
keegansmith21 committed Sep 30, 2024
1 parent f5bdbf1 commit ce14994
Showing 2 changed files with 50 additions and 38 deletions.
First changed file:
@@ -25,6 +25,7 @@
from concurrent.futures import as_completed, ProcessPoolExecutor, ThreadPoolExecutor
import datetime
from os import PathLike
import json
from typing import Dict, Optional, Tuple, Union

import pendulum
@@ -33,7 +34,9 @@
from airflow.exceptions import AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.models import DagRun
from google.auth import default
from google import auth
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from google.cloud import bigquery
from google.cloud.bigquery import SourceFormat

@@ -46,6 +49,7 @@
from observatory_platform.date_utils import datetime_normalise
from observatory_platform.files import change_keys, save_jsonl_gz
from observatory_platform.google import bigquery as bq
from observatory_platform.google.gke import gke_service_account_email
from observatory_platform.google.gcs import (
gcs_blob_uri,
gcs_create_aws_transfer,
@@ -106,8 +110,10 @@ def fetch_release(
"fetch_releases: there should be at least 1 DatasetRelease in the Observatory API after the first DAG run"
)
prev_release = api.get_latest_dataset_release(dag_id=dag_id, entity_id="orcid", date_key="changefile_end_date")
logging.info(f"Extra: {prev_release.extra}")
logging.info(f"Type: {type(prev_release.extra)}")
prev_latest_modified_record = pendulum.parse(json.loads(prev_release.extra["latest_modified_record_date"]))
prev_release_end = prev_release.changefile_end_date
prev_latest_modified_record = pendulum.parse(prev_release.extra["latest_modified_record_date"])

return OrcidRelease(
dag_id=dag_id,
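
As a minimal, standalone sketch of the parsing step introduced above, assuming the value stored under latest_modified_record_date in the release's extra field arrives as a JSON-encoded string (the sample value below is hypothetical):

import json
import pendulum

# Hypothetical extra payload as it might come back from the DatasetRelease API.
extra = {"latest_modified_record_date": '"2024-09-01T12:34:56+00:00"'}

raw = extra["latest_modified_record_date"]      # JSON-encoded string: '"2024-09-01T12:34:56+00:00"'
date_string = json.loads(raw)                   # plain string: '2024-09-01T12:34:56+00:00'
latest_modified = pendulum.parse(date_string)   # pendulum.DateTime in UTC
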
@@ -246,8 +252,21 @@ def download(release: dict):
"""Reads each batch's manifest and downloads the files from the gcs bucket."""

release = OrcidRelease.from_dict(release)
gcs_creds, project_id = default()
with gcs_hmac_key(project_id, gcs_creds.service_account_email) as (key, secret):

# Check the type of credentials
credentials, _ = auth.default()
if isinstance(credentials, ServiceAccountCredentials):
email = credentials.service_account_email
logging.info("Using service account credentials")
elif isinstance(credentials, ComputeEngineCredentials):
logging.info("Using compute engine credentials")
email = gke_service_account_email()
if not email:
raise AirflowException("Email could not be retrieved")
else:
raise AirflowException(f"Unknown credentials type: {type(credentials)}")

with gcs_hmac_key(release.cloud_workspace.project_id, email) as (key, secret):
total_files = 0
start_time = time.time()
for orcid_batch in release.orcid_batches():
@@ -265,6 +284,9 @@ def download(release: dict):

# Check for errors
if returncode != 0:
with open(orcid_batch.download_log_file) as f:
output = f.read()
logging.error(output)
raise RuntimeError(
f"Download attempt '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}. See log file: {orcid_batch.download_log_file}"
)
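
The credential check at the top of this hunk decides how the service account email is obtained before the HMAC key is created. A minimal standalone sketch of the same pattern, assuming Application Default Credentials are configured; the fallback email is only a placeholder for what gke_service_account_email() would return, since compute-engine credentials generally do not expose the full email directly:

from google import auth
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials


def resolve_service_account_email() -> str:
    """Return the service account email for the active Application Default Credentials."""
    credentials, _ = auth.default()
    if isinstance(credentials, ServiceAccountCredentials):
        # Key-file credentials carry the email with them.
        return credentials.service_account_email
    if isinstance(credentials, ComputeEngineCredentials):
        # Placeholder: in the workflow this is resolved via gke_service_account_email().
        return "workflow-sa@example-project.iam.gserviceaccount.com"
    raise RuntimeError(f"Unknown credentials type: {type(credentials)}")
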
@@ -471,6 +493,11 @@ def add_dataset_release(release: dict, *, api_bq_dataset_id: str, latest_modifie
:param latest_modified_date: The modification datetime of the last modified record of this release"""

release = OrcidRelease.from_dict(release)
try:
pendulum.parse(latest_modified_date)
except pendulum.parsing.exceptions.ParserError:
raise AirflowException(f"Latest modified record date not valid: {latest_modified_date}")

api = DatasetAPI(bq_project_id=release.cloud_workspace.project_id, bq_dataset_id=api_bq_dataset_id)
api.seed_db()
now = pendulum.now()
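
A minimal, standalone sketch of the date guard added above, assuming ISO-8601 input; ValueError stands in for AirflowException so the example has no Airflow dependency:

import pendulum
from pendulum.parsing.exceptions import ParserError


def require_valid_date(date_string: str):
    """Parse an ISO-8601 datetime string, raising a clear error if it is malformed."""
    try:
        return pendulum.parse(date_string)
    except ParserError as exc:
        raise ValueError(f"Latest modified record date not valid: {date_string}") from exc


require_valid_date("2024-09-30T00:00:00Z")  # parses fine
# require_valid_date("not-a-date")          # would raise ValueError
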
@@ -547,7 +574,9 @@ def latest_modified_record_date(release: dict) -> str:
with open(release.master_manifest_file, "r") as f:
reader = csv.DictReader(f)
modified_dates = sorted([pendulum.parse(row["updated"]) for row in reader])
return datetime_normalise(modified_dates[-1])
latest_modified_record_date = datetime_normalise(modified_dates[-1])
logging.info(f"Latest modified record date: {latest_modified_record_date}")
return latest_modified_record_date
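
Since only the newest timestamp is needed here, max() over the parsed dates is an equivalent and slightly cheaper alternative to sorting the whole list. A minimal sketch, assuming the manifest has an "updated" column of ISO-8601 strings (to_iso8601_string stands in for the workflow's datetime_normalise helper):

import csv
import pendulum


def newest_updated_date(manifest_path: str) -> str:
    """Return the most recent 'updated' timestamp in the manifest as an ISO-8601 string."""
    with open(manifest_path) as f:
        reader = csv.DictReader(f)
        latest = max(pendulum.parse(row["updated"]) for row in reader)  # raises ValueError if the file is empty
    return latest.to_iso8601_string()
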


def transform_orcid_record(record_path: str) -> Union[Dict, str]:
Second changed file:
@@ -27,6 +27,7 @@
from airflow.utils.trigger_rule import TriggerRule

from academic_observatory_workflows.config import project_path
from academic_observatory_workflows.orcid_telescope import tasks
from observatory_platform.airflow.airflow import on_failure_callback
from observatory_platform.airflow.sensors import PreviousDagRunSensor
from observatory_platform.airflow.tasks import check_dependencies, gke_create_storage, gke_delete_storage
@@ -143,8 +144,6 @@ def orcid():
def fetch_release(**context) -> dict:
"""Generates the OrcidRelease object."""

from academic_observatory_workflows.orcid_telescope import tasks

return tasks.fetch_release(
dag_id=dag_params.dag_id,
run_id=context["run_id"],
@@ -163,16 +162,12 @@ def fetch_release(**context) -> dict:
def create_dataset(release: dict, dag_params, **context) -> None:
"""Create datasets"""

from academic_observatory_workflows.orcid_telescope import tasks

tasks.create_dataset(release, dataset_description=dag_params.dataset_description)

@task
def transfer_orcid(release: dict, dag_params, **context):
"""Sync files from AWS bucket to Google Cloud bucket."""

from academic_observatory_workflows.orcid_telescope import tasks

tasks.transfer_orcid(
release,
aws_orcid_conn_id=dag_params.aws_orcid_conn_id,
@@ -181,17 +176,16 @@ def transfer_orcid(release: dict, dag_params, **context):
orcid_summaries_prefix=dag_params.orcid_summaries_prefix,
)

@task
@task(trigger_rule=TriggerRule.NONE_FAILED)
def bq_create_main_table_snapshot(release: dict, dag_params, **context):
"""Create a snapshot of each main table. The purpose of this table is to be able to rollback the table
if something goes wrong. The snapshot expires after snapshot_expiry_days."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_create_main_table_snapshot(release, snapshot_expiry_days=dag_params.snapshot_expiry_days)

@task.kubernetes(
trigger_rule=TriggerRule.NONE_FAILED,
name="create_manifests",
trigger_rule=TriggerRule.ALL_DONE,
container_resources=gke_make_container_resources(
{"memory": "16G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("create_manifests")
),
@@ -210,9 +204,8 @@ def create_manifests(release: dict, dag_params, **context):

@task.kubernetes(
name="latest_modified_record_date",
trigger_rule=TriggerRule.ALL_DONE,
container_resources=gke_make_container_resources(
{"memory": "2G", "cpu": "2"},
{"memory": "4G", "cpu": "2"},
dag_params.gke_params.gke_resource_overrides.get("latest_modified_record_date"),
),
**kubernetes_task_params,
@@ -225,7 +218,6 @@ def latest_modified_record_date(release: dict, **context):

@task.kubernetes(
name="download",
trigger_rule=TriggerRule.ALL_DONE,
container_resources=gke_make_container_resources(
{"memory": "8G", "cpu": "8"}, dag_params.gke_params.gke_resource_overrides.get("download")
),
@@ -239,9 +231,8 @@ def download(release: dict, **context):

@task.kubernetes(
name="transform",
trigger_rule=TriggerRule.ALL_DONE,
container_resources=gke_make_container_resources(
{"memory": "16G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("transform")
{"memory": "32G", "cpu": "16"}, dag_params.gke_params.gke_resource_overrides.get("transform")
),
**kubernetes_task_params,
)
@@ -253,7 +244,6 @@ def transform(release: dict, dag_params, **context):

@task.kubernetes(
name="upload_transformed",
trigger_rule=TriggerRule.ALL_DONE,
container_resources=gke_make_container_resources(
{"memory": "8G", "cpu": "8"}, dag_params.gke_params.gke_resource_overrides.get("upload_transformed")
),
@@ -265,54 +255,47 @@ def upload_transformed(release: dict, **context):

tasks.upload_transformed(release)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task()
def bq_load_main_table(release: dict, dag_params, **context):
"""Load the main table."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_load_main_table(release, schema_file_path=dag_params.schema_file_path)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task(trigger_rule=TriggerRule.NONE_FAILED)
def bq_load_upsert_table(release: dict, dag_params, **context):
"""Load the upsert table into bigquery"""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_load_upsert_table(release, schema_file_path=dag_params.schema_file_path)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task(trigger_rule=TriggerRule.NONE_FAILED)
def bq_load_delete_table(release: dict, dag_params, **context):
"""Load the delete table into bigquery"""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_load_delete_table(release, delete_schema_file_path=dag_params.delete_schema_file_path)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task(trigger_rule=TriggerRule.NONE_FAILED)
def bq_upsert_records(release: dict, **context):
"""Upsert the records from the upserts table into the main table."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_upsert_records(release)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task(trigger_rule=TriggerRule.NONE_FAILED)
def bq_delete_records(release: dict, **context):
"""Delete the records in the delete table from the main table."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.bq_delete_records(release)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task(trigger_rule=TriggerRule.NONE_FAILED)
def add_dataset_release(release: dict, latest_modified_date: str, dag_params, **context) -> None:
"""Adds release information to API."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.add_dataset_release(
release, api_bq_dataset_id=dag_params.api_bq_dataset_id, latest_modified_date=latest_modified_date
)

@task(trigger_rule=TriggerRule.ALL_DONE)
@task()
def cleanup_workflow(release: dict, **context) -> None:
"""Delete all files, folders and XComs associated with this release."""
from academic_observatory_workflows.orcid_telescope import tasks

tasks.cleanup_workflow(release)

@@ -343,15 +326,15 @@ def cleanup_workflow(release: dict, **context) -> None:
task_id=external_task_id,
)
if dag_params.test_run:
task_create_storage = EmptyOperator(task_id="gke_create_storage")
task_delete_storage = EmptyOperator(task_id="gke_delete_storage")
task_create_storage = EmptyOperator(task_id="gke_create_storage", trigger_rule=TriggerRule.NONE_FAILED)
task_delete_storage = EmptyOperator(task_id="gke_delete_storage", trigger_rule=TriggerRule.NONE_FAILED)
else:
task_create_storage = gke_create_storage(
task_create_storage = gke_create_storage.override(trigger_rule=TriggerRule.NONE_FAILED)(
volume_name=dag_params.gke_params.gke_volume_name,
volume_size=dag_params.gke_params.gke_volume_size,
kubernetes_conn_id=dag_params.gke_params.gke_conn_id,
)
task_delete_storage = gke_delete_storage(
task_delete_storage = gke_delete_storage.override(trigger_rule=TriggerRule.NONE_FAILED)(
volume_name=dag_params.gke_params.gke_volume_name,
kubernetes_conn_id=dag_params.gke_params.gke_conn_id,
)
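
.override() on a TaskFlow-decorated task returns a copy of the task with the given attributes replaced, which is how the NONE_FAILED trigger rule is attached to the shared gke_create_storage / gke_delete_storage tasks above without editing their definitions. A minimal, self-contained sketch with placeholder names, assuming Airflow 2.4+ for the schedule argument:

import pendulum
from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


@task
def example_cleanup():
    """Placeholder task body."""
    print("cleaning up")


@dag(start_date=pendulum.datetime(2024, 1, 1), schedule=None, catchup=False)
def example_dag():
    # A copy of example_cleanup with its trigger rule replaced, instantiated in this DAG.
    example_cleanup.override(trigger_rule=TriggerRule.NONE_FAILED)()


example_dag()
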
