Fix: OpenAlex Schema Generator (#203)
alexmassen-hane authored Dec 6, 2023
1 parent f197c7a commit a00ce0e
Showing 6 changed files with 93 additions and 20 deletions.
3 changed files stored in Git LFS not shown
66 changes: 49 additions & 17 deletions academic_observatory_workflows/workflows/openalex_telescope.py
@@ -30,7 +30,6 @@
from datetime import timedelta
from collections import OrderedDict
from json.encoder import JSONEncoder
from mergedeep import merge, Strategy
from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ProcessPoolExecutor, as_completed
from bigquery_schema_generator.generate_schema import SchemaGenerator, flatten_schema_map
@@ -65,7 +64,7 @@
)
from observatory.platform.observatory_environment import log_diff
from observatory.platform.config import AirflowConns
from observatory.platform.files import clean_dir, load_jsonl
from observatory.platform.files import clean_dir
from observatory.platform.gcs import (
gcs_create_aws_transfer,
gcs_upload_transfer_manifest,
@@ -678,27 +677,30 @@ def transform(self, release: OpenAlexRelease, entity_name: str = None, **kwargs)
max_processes = self.max_processes
logging.info(f"{task_id}: transforming files for OpenAlexEntity({entity_name}), no. workers: {max_processes}")

# Merge function only expects dicts, not lists.
merged_schema = {"schema": []}
# Initialise the merged schema map
merged_schema_map = OrderedDict()

with ProcessPoolExecutor(max_workers=max_processes) as executor:
futures = []
for entry in entity.current_entries:
input_path = os.path.join(release.download_folder, entry.object_key)
output_path = os.path.join(release.transform_folder, entry.object_key)
futures.append(executor.submit(transform_file, input_path, output_path))
for future in as_completed(futures):
result, schema_error = future.result()
generated_schema = {"schema": result}
schema_map, schema_error = future.result()

if schema_error:
logging.info(f"Error generating schema from transformed data, please investigate: {schema_error}")

# Merge the schemas together (some part files may contain more keys/fields than others).
merged_schema = merge(merged_schema, generated_schema, strategy=Strategy.ADDITIVE)
# Merge the schemas from each process. Each data file could have more fields than others.
merged_schema_map = merge_schema_maps(to_add=schema_map, old=merged_schema_map)

# Flatten the schema from nested OrderedDicts to a regular BigQuery schema.
merged_schema = flatten_schema(schema_map=merged_schema_map)

# Save schema to file
with open(entity.generated_schema_path, mode="w") as f_out:
json.dump(merged_schema["schema"], f_out, indent=2)
json.dump(merged_schema, f_out, indent=2)

def upload_schema(self, release: OpenAlexRelease, entity_name: str = None, **kwargs):
"""Upload the generated schema from the transform step to GCS."""
@@ -1092,7 +1094,7 @@ def fetch_merged_ids(
return results


def transform_file(download_path: str, transform_path: str) -> Tuple[dict, list]:
def transform_file(download_path: str, transform_path: str) -> Tuple[OrderedDict, list]:
"""Transforms a single file.
Each entry/object in the gzip input file is transformed and the transformed object is immediately written out to
a gzip file. For each entity only one field has to be transformed.
@@ -1101,14 +1103,16 @@ def transform_file(download_path: str, transform_path: str) -> Tuple[dict, list]
using the SchemaGenerator from the 'bigquery_schema_generator' package.
:param download_path: The path to the file with the OpenAlex entries.
:param transform_path: The path where transformed data will be saved
:return: schema. A BQ style schema generated from the transformed records.
:param transform_path: The path where transformed data will be saved.
:return: schema_map. A nested OrderedDict object produced by the SchemaGenerator.
:return: schema_generator.error_logs: Possible error logs produced by the SchemaGenerator.
"""

# Make base folder, e.g. authors/updated_date=2023-09-17
base_folder = os.path.dirname(transform_path)
os.makedirs(base_folder, exist_ok=True)

# Initialise the schema generator.
schema_map = OrderedDict()
schema_generator = SchemaGenerator(input_format="dict")

@@ -1118,6 +1122,8 @@ def transform_file(download_path: str, transform_path: str) -> Tuple[dict, list]
for obj in reader.iter(skip_empty=True):
transform_object(obj)

# Wrap this in a try/except and pass so that a bad record doesn't
# cause the transform step to fail unexpectedly.
try:
schema_generator.deduce_schema_for_record(obj, schema_map)
except Exception:
@@ -1128,10 +1134,7 @@ def transform_file(download_path: str, transform_path: str) -> Tuple[dict, list]

logging.info(f"Finished transform, saved to {transform_path}")

# Convert schema from nested OrderedDicts to regular dictionaries
schema = flatten_schema(schema_map)

return schema, schema_generator.error_logs
return schema_map, schema_generator.error_logs


def transform_object(obj: dict):
@@ -1222,14 +1225,43 @@ def bq_compare_schemas(expected: List[dict], actual: List[dict], check_types_mat
for exp_field, act_field in zip(expected, actual):
# Ignore the "mode" and "description" definitions in fields as they are not required for the check.
diff = DeepDiff(exp_field, act_field, ignore_order=True, exclude_regex_paths=r"\s*(description|mode)")
logging.info(f"Differences in the fields: {exp_field}")
for diff_type, changes in diff.items():
all_matched = False
log_diff(diff_type, changes)

if "fields" in exp_field and not "fields" in act_field:
logging.info(f"Fields are present under expected but not in actual! Field name: {exp_field['name']}")
all_matched = False
elif not "fields" in exp_field and "fields" in act_field:
logging.info(f"Fields are present under actual but not in expected! Field name: {act_field['name']}")
all_matched = False
elif "fields" in exp_field and "fields" in act_field:
all_matched = bq_compare_schemas(exp_field["fields"], act_field["fields"], check_types_match)

return all_matched


def merge_schema_maps(to_add: OrderedDict, old: OrderedDict) -> OrderedDict:
"""Using the SchemaGenerator from the bigquery_schema_generator library, merge the schemas found
when from scanning through files into one large nested OrderedDict.
:param to_add: The incoming schema to add to the existing "old" schema.
:param old: The existing old schema with previously populated values.
:return: The old schema with newly added fields.
"""

schema_generator = SchemaGenerator()

if old:
for key, value in to_add.items():
old[key] = schema_generator.merge_schema_entry(old_schema_entry=old[key], new_schema_entry=value)
else:
# Initialise it with the first result if it is empty
old = to_add.copy()

return old


def flatten_schema(schema_map: OrderedDict) -> dict:
"""A quick trick using the JSON encoder and load string function to convert from a nested
OrderedDict object to a regular dictionary.
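Taken together, the changes in this file replace the mergedeep-based additive merge of generated schemas with SchemaGenerator schema maps that are merged per file and flattened once at the end. The sketch below is illustrative only: the two example records and the output filename are made up, and it assumes merge_schema_maps and flatten_schema as defined above plus SchemaGenerator from the bigquery_schema_generator package.

import json
from collections import OrderedDict

from bigquery_schema_generator.generate_schema import SchemaGenerator

from academic_observatory_workflows.workflows.openalex_telescope import flatten_schema, merge_schema_maps

# Two hypothetical part files; the second record carries an extra nested field.
part_files = [
    [{"id": "W1", "title": "A paper"}],
    [{"id": "W2", "title": "Another paper", "ids": {"doi": "10.1234/abc"}}],
]

merged_schema_map = OrderedDict()
for records in part_files:
    # Deduce a schema map per file, as transform_file does in each worker process.
    schema_map = OrderedDict()
    schema_generator = SchemaGenerator(input_format="dict")
    for record in records:
        schema_generator.deduce_schema_for_record(record, schema_map)
    # Fold the per-file schema map into the running merged schema map.
    merged_schema_map = merge_schema_maps(to_add=schema_map, old=merged_schema_map)

# Flatten the nested OrderedDicts into a regular BigQuery-style schema and save it.
merged_schema = flatten_schema(schema_map=merged_schema_map)
with open("generated_schema.json", mode="w") as f_out:
    json.dump(merged_schema, f_out, indent=2)

The test_merge_schema_maps test added in the test module below exercises the same flow against fixture files.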
@@ -26,10 +26,12 @@

import boto3
import pendulum
from collections import OrderedDict
from airflow.models import Connection
from airflow.utils.state import State
from click.testing import CliRunner
from google.cloud.exceptions import NotFound
from bigquery_schema_generator.generate_schema import SchemaGenerator

from academic_observatory_workflows.config import test_fixtures_folder
from academic_observatory_workflows.workflows.openalex_telescope import (
@@ -39,20 +41,23 @@
Manifest,
Meta,
MergedId,
load_json,
transform_object,
s3_uri_parts,
OpenAlexEntity,
fetch_manifest,
fetch_merged_ids,
bq_compare_schemas,
merge_schema_maps,
flatten_schema,
)
from academic_observatory_workflows.workflows.openalex_telescope import (
parse_release_msg,
)
from observatory.platform.config import AirflowConns
from observatory.platform.api import get_dataset_releases
from observatory.platform.bigquery import bq_table_id, bq_sharded_table_id
from observatory.platform.files import save_jsonl_gz, load_file
from observatory.platform.files import save_jsonl_gz, load_file, load_jsonl
from observatory.platform.gcs import gcs_blob_name_from_path
from observatory.platform.observatory_config import Workflow, CloudWorkspace
from observatory.platform.observatory_environment import (
@@ -540,6 +545,34 @@ def test_bq_compare_schemas(self):

self.assertFalse(bq_compare_schemas(expected, actual, True))

def test_merge_schema_maps(self):
test1 = load_jsonl(os.path.join(test_fixtures_folder(), "openalex", "schema_generator", "part_000.jsonl"))
test2 = load_jsonl(os.path.join(test_fixtures_folder(), "openalex", "schema_generator", "part_001.jsonl"))

expected_schema_path = os.path.join(test_fixtures_folder(), "openalex", "schema_generator", "expected.json")
expected = load_and_parse_json(file_path=expected_schema_path)

# Create schema maps from both of the test files
schema_map1 = OrderedDict()
schema_map2 = OrderedDict()
schema_generator = SchemaGenerator(input_format="dict")

# Both schema_maps need to be independent of each other here.
for record1 in test1:
schema_generator.deduce_schema_for_record(record1, schema_map1)

for record2 in test2:
schema_generator.deduce_schema_for_record(record2, schema_map2)

# Merge the two schema maps together - this mirrors how the transform step merges the schema map
# returned by each ProcessPool worker for the data file it transformed.
merged_schema_map = OrderedDict()
for incoming in [schema_map1, schema_map2]:
merged_schema_map = merge_schema_maps(to_add=incoming, old=merged_schema_map)
merged_schema = flatten_schema(merged_schema_map)

self.assertTrue(bq_compare_schemas(actual=merged_schema, expected=expected))


def upload_folder_to_s3(bucket_name: str, folder_path: str, s3_prefix=None):
s3 = boto3.client("s3")
@@ -709,7 +742,7 @@ def test_dag_load(self):
def test_telescope(self, m_send_slack_msg):
"""Test the OpenAlex telescope end to end."""

env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port(), age_to_delete=0.05)
env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port())
bq_dataset_id = env.add_dataset()

# Create the Observatory environment and run tests
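For reference, the expected.json fixture used by test_merge_schema_maps is not shown in this diff, but the flattened schema that bq_compare_schemas receives is BigQuery-style: a list of field definitions, with nested RECORD fields carrying their children under "fields". A hypothetical fragment, not the actual fixture contents:

# Hypothetical fragment of a flattened BigQuery-style schema; for illustration only.
expected_fragment = [
    {"name": "id", "type": "STRING", "mode": "NULLABLE"},
    {
        "name": "ids",
        "type": "RECORD",
        "mode": "NULLABLE",
        "fields": [
            {"name": "doi", "type": "STRING", "mode": "NULLABLE"},
        ],
    },
]

# bq_compare_schemas zips the expected and actual field lists, ignores "mode" and
# "description" when diffing, and recurses into nested "fields" entries.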
1 change: 0 additions & 1 deletion requirements.txt
@@ -10,5 +10,4 @@ Deprecated>=1,<2
limits>=3,<4
biopython>=1.81,<2
glom>=23.0.0,<24
mergedeep>=1.3.4
bigquery-schema-generator>=1.5.1
