Break load_from_s3 into separate tasks to fix duplicate reporting (#914)
* Separate load_from_s3 from upsert_data step

* Fix load_local_data_to_intermediate_table

* Remove clean step from load_local_data_to_intermediate_table

As far as I can tell this method is tested but never used anywhere.
This should mirror `load_s3_data_to_intermediate_table` and only
handle the loading steps, not the cleaning steps.

If this method *is* used elsewhere, it will need to be updated to
call the cleaning steps separately.

* Fix tests

* Also separate out the clean data step

* Clarify 'load_timeout' to 'upsert_timeout'

* Extend Smithsonian upsert timeout

* Add types to the clean_intermediate_table_data function

* Add NMNHANTHRO Smithsonian subprovider
stacimc authored Dec 19, 2022
1 parent 3979fc2 commit f654446
Showing 7 changed files with 100 additions and 41 deletions.
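At a high level, the change splits the single load_from_s3 task into three chained steps (load, clean, upsert) whose return values feed forward into the final record metrics. Below is a minimal, Airflow-free sketch of that handoff; the stub counts are hypothetical, and RecordMetrics here is a local stand-in for the catalog's own type, which is not shown in this diff.

from typing import NamedTuple


class RecordMetrics(NamedTuple):
    upserted: int
    missing_columns: int
    foreign_id_dup: int
    url_dup: int


def load_step() -> int:
    # Stands in for sql.load_s3_data_to_intermediate_table: number of rows
    # copied from S3 into the loading table (hypothetical count).
    return 100


def clean_step() -> tuple[int, int]:
    # Stands in for sql.clean_intermediate_table_data: rows deleted for
    # missing required columns and for duplicate foreign IDs (hypothetical).
    return 10, 15


def upsert_step(loaded_count: int, duplicates_count: tuple[int, int]) -> RecordMetrics:
    # Stands in for loader.upsert_data; the upserted count is hypothetical.
    upserted = 75
    missing_columns, foreign_id_dup = duplicates_count
    url_dup = loaded_count - missing_columns - foreign_id_dup - upserted
    return RecordMetrics(upserted, missing_columns, foreign_id_dup, url_dup)


print(upsert_step(load_step(), clean_step()))
# RecordMetrics(upserted=75, missing_columns=10, foreign_id_dup=15, url_dup=0)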
25 changes: 14 additions & 11 deletions openverse_catalog/dags/common/loader/loader.py
@@ -34,19 +34,22 @@ def load_s3_data(
)


def load_from_s3(
bucket,
key,
postgres_conn_id,
media_type,
tsv_version,
identifier,
def upsert_data(
postgres_conn_id: str,
media_type: str,
tsv_version: str,
identifier: str,
loaded_count: int,
duplicates_count: tuple[int, int],
) -> RecordMetrics:
loaded, missing_columns, foreign_id_dup = sql.load_s3_data_to_intermediate_table(
postgres_conn_id, bucket, key, identifier, media_type
)
"""
Upserts data into the catalog DB from the loading table, and calculates
final record metrics.
"""
missing_columns, foreign_id_dup = duplicates_count
upserted = sql.upsert_records_to_db_table(
postgres_conn_id, identifier, media_type=media_type, tsv_version=tsv_version
)
url_dup = loaded - missing_columns - foreign_id_dup - upserted

url_dup = loaded_count - missing_columns - foreign_id_dup - upserted
return RecordMetrics(upserted, missing_columns, foreign_id_dup, url_dup)
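Concretely, with the counts used in the updated test cases further below, url_dup is whatever remains after the other outcomes are accounted for:

# Counts taken from the parametrized cases in tests/dags/common/loader/test_loader.py.
loaded_count = 100                        # rows loaded into the intermediate table
missing_columns, foreign_id_dup = 10, 15  # rows removed by the clean_data step
upserted = 75                             # rows upserted into the catalog table
url_dup = loaded_count - missing_columns - foreign_id_dup - upserted
assert url_dup == 0                       # the remainder is attributed to duplicate URLs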
1 change: 1 addition & 0 deletions openverse_catalog/dags/common/loader/provider_details.py
@@ -66,6 +66,7 @@
# Smithsonian parameters
SMITHSONIAN_SUB_PROVIDERS = {
"smithsonian_national_museum_of_natural_history": {
"NMNHANTHRO", # NMNH - Anthropology Dept.
"NMNHBIRDS", # NMNH - Vertebrate Zoology - Birds Division
"NMNHBOTANY", # NMNH - Botany Dept.
"NMNHEDUCATION", # NMNH - Education & Outreach
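The new NMNHANTHRO unit code simply joins the existing set for the natural history museum sub-provider. The sketch below illustrates how a unit-code-to-sub-provider lookup over this mapping could work; the sub_provider_for helper is hypothetical, and the set is truncated to two of the codes shown above:

SMITHSONIAN_SUB_PROVIDERS = {
    "smithsonian_national_museum_of_natural_history": {
        "NMNHANTHRO",  # NMNH - Anthropology Dept.
        "NMNHBIRDS",  # NMNH - Vertebrate Zoology - Birds Division
    },
}


def sub_provider_for(unit_code: str) -> str | None:
    # Hypothetical helper: return the sub-provider whose unit-code set
    # contains the given code, or None if there is no match.
    return next(
        (name for name, units in SMITHSONIAN_SUB_PROVIDERS.items() if unit_code in units),
        None,
    )


assert sub_provider_for("NMNHANTHRO") == "smithsonian_national_museum_of_natural_history"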
24 changes: 13 additions & 11 deletions openverse_catalog/dags/common/loader/sql.py
@@ -2,7 +2,7 @@
from textwrap import dedent

from airflow.providers.postgres.hooks.postgres import PostgresHook
from common.constants import AUDIO, IMAGE
from common.constants import AUDIO, IMAGE, MediaType
from common.loader import provider_details as prov
from common.loader.paths import _extract_media_type
from common.storage import columns as col
@@ -125,8 +125,6 @@ def load_local_data_to_intermediate_table(
"Exceeded the maximum number of allowed defective rows"
)

_clean_intermediate_table_data(postgres, load_table)


def _handle_s3_load_result(cursor) -> int:
"""
@@ -147,7 +145,7 @@ def load_s3_data_to_intermediate_table(
s3_key,
identifier,
media_type=IMAGE,
) -> tuple[int, int, int]:
) -> int:
load_table = _get_load_table_name(identifier, media_type=media_type)
logger.info(f"Loading {s3_key} from S3 Bucket {bucket} into {load_table}")

@@ -168,13 +166,14 @@ def load_s3_data_to_intermediate_table(
handler=_handle_s3_load_result,
)
logger.info(f"Successfully loaded {loaded} records from S3")
missing_columns, foreign_id_dup = _clean_intermediate_table_data(
postgres, load_table
)
return loaded, missing_columns, foreign_id_dup
return loaded


def _clean_intermediate_table_data(postgres_hook, load_table) -> tuple[int, int]:
def clean_intermediate_table_data(
postgres_conn_id: str,
identifier: str,
media_type: MediaType = IMAGE,
) -> tuple[int, int]:
"""
Necessary for old TSV files that have not been cleaned up,
using `MediaStore` class:
@@ -183,13 +182,16 @@ def _clean_intermediate_table_data(postgres_hook, load_table) -> tuple[int, int]
Also removes any duplicate rows that have the same `provider`
and `foreign_id`.
"""
load_table = _get_load_table_name(identifier, media_type=media_type)
postgres = PostgresHook(postgres_conn_id=postgres_conn_id)

missing_columns = 0
for column in required_columns:
missing_columns += postgres_hook.run(
missing_columns += postgres.run(
f"DELETE FROM {load_table} WHERE {column.db_name} IS NULL;",
handler=RETURN_ROW_COUNT,
)
foreign_id_dup = postgres_hook.run(
foreign_id_dup = postgres.run(
dedent(
f"""
DELETE FROM {load_table} p1
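With the cleaning step now a public function that takes a connection id and identifier, it can be invoked on its own. A minimal sketch, assuming the catalog's dags directory is importable and that an Airflow Postgres connection exists for the placeholder conn id; the identifier value is also a placeholder:

from common.loader import sql

missing_columns, foreign_id_dup = sql.clean_intermediate_table_data(
    postgres_conn_id="my_postgres_conn",  # placeholder Airflow connection id
    identifier="20221219T120000",         # placeholder load-table identifier
)
print(f"Deleted {missing_columns} rows with missing columns and {foreign_id_dup} foreign ID duplicates")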
38 changes: 33 additions & 5 deletions openverse_catalog/dags/providers/provider_dag_factory.py
@@ -197,18 +197,46 @@ def append_day_shift(id_str):
)
load_from_s3 = PythonOperator(
task_id=append_day_shift("load_from_s3"),
execution_timeout=conf.load_timeout,
retries=1,
python_callable=loader.load_from_s3,
python_callable=sql.load_s3_data_to_intermediate_table,
op_kwargs={
"postgres_conn_id": DB_CONN_ID,
"bucket": OPENVERSE_BUCKET,
"key": XCOM_PULL_TEMPLATE.format(copy_to_s3.task_id, "s3_key"),
"s3_key": XCOM_PULL_TEMPLATE.format(
copy_to_s3.task_id, "s3_key"
),
"identifier": identifier,
"media_type": media_type,
},
)
clean_data = PythonOperator(
task_id=append_day_shift("clean_data"),
retries=1,
python_callable=sql.clean_intermediate_table_data,
op_kwargs={
"postgres_conn_id": DB_CONN_ID,
"identifier": identifier,
"media_type": media_type,
},
)
upsert_data = PythonOperator(
task_id=append_day_shift("upsert_data"),
execution_timeout=conf.upsert_timeout,
retries=1,
python_callable=loader.upsert_data,
op_kwargs={
"postgres_conn_id": DB_CONN_ID,
"media_type": media_type,
"tsv_version": XCOM_PULL_TEMPLATE.format(
copy_to_s3.task_id, "tsv_version"
),
"identifier": identifier,
"loaded_count": XCOM_PULL_TEMPLATE.format(
load_from_s3.task_id, "return_value"
),
"duplicates_count": XCOM_PULL_TEMPLATE.format(
clean_data.task_id, "return_value"
),
},
)
drop_loading_table = PythonOperator(
@@ -222,10 +250,10 @@ def append_day_shift(id_str):
trigger_rule=TriggerRule.ALL_DONE,
)
[create_loading_table, copy_to_s3] >> load_from_s3
load_from_s3 >> drop_loading_table
load_from_s3 >> clean_data >> upsert_data >> drop_loading_table

record_counts_by_media_type[media_type] = XCOM_PULL_TEMPLATE.format(
load_from_s3.task_id, "return_value"
upsert_data.task_id, "return_value"
)
load_tasks.append(load_data)

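The three operators hand results to one another through XCom rather than a single in-process call. A sketch of how the templated op_kwargs above are expected to resolve, assuming XCOM_PULL_TEMPLATE is a standard ti.xcom_pull Jinja template; its definition is not part of this diff, and the task ids here omit the append_day_shift suffixing:

# Assumed shape of the template; the real constant lives elsewhere in the catalog.
XCOM_PULL_TEMPLATE = "{{{{ ti.xcom_pull(task_ids='{}', key='{}') }}}}"

loaded_count_arg = XCOM_PULL_TEMPLATE.format("load_from_s3", "return_value")
duplicates_arg = XCOM_PULL_TEMPLATE.format("clean_data", "return_value")

print(loaded_count_arg)  # {{ ti.xcom_pull(task_ids='load_from_s3', key='return_value') }}
print(duplicates_arg)    # {{ ti.xcom_pull(task_ids='clean_data', key='return_value') }}
# At render time Airflow substitutes each expression with the upstream task's
# return value, so upsert_data receives loaded_count and duplicates_count
# without re-reading anything from S3.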
10 changes: 5 additions & 5 deletions openverse_catalog/dags/providers/provider_workflows.py
@@ -59,7 +59,7 @@ class ProviderWorkflow:
which data should be ingested).
pull_timeout: datetime.timedelta giving the amount of time a given data
pull may take.
load_timeout: datetime.timedelta giving the amount of time the load_data
upsert_timeout: datetime.timedelta giving the amount of time the upsert_data
task may take.
doc_md: string which should be used for the DAG's documentation markdown
media_types: list describing the media type(s) that this provider handles
@@ -84,7 +84,7 @@ class ProviderWorkflow:
schedule_string: str = "@monthly"
dated: bool = False
pull_timeout: timedelta = timedelta(hours=24)
load_timeout: timedelta = timedelta(hours=1)
upsert_timeout: timedelta = timedelta(hours=1)
doc_md: str = ""
media_types: Sequence[str] = ()
create_preingestion_tasks: Callable | None = None
@@ -125,7 +125,7 @@ def __post_init__(self):
ProviderWorkflow(
ingester_class=FinnishMuseumsDataIngester,
start_date=datetime(2015, 11, 1),
load_timeout=timedelta(hours=5),
upsert_timeout=timedelta(hours=5),
schedule_string="@daily",
dated=True,
),
@@ -144,7 +144,7 @@ def __post_init__(self):
create_postingestion_tasks=INaturalistDataIngester.create_postingestion_tasks,
schedule_string="@monthly",
pull_timeout=timedelta(days=5),
load_timeout=timedelta(days=5),
upsert_timeout=timedelta(days=5),
),
ProviderWorkflow(
ingester_class=JamendoDataIngester,
@@ -183,7 +183,7 @@ def __post_init__(self):
ingester_class=SmithsonianDataIngester,
start_date=datetime(2020, 1, 1),
schedule_string="@weekly",
load_timeout=timedelta(hours=4),
upsert_timeout=timedelta(hours=6),
),
ProviderWorkflow(
ingester_class=SmkDataIngester,
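The timeout rename is a configuration-surface change: the class-level default stays at one hour, and individual providers override it where the upsert is slow. A runnable stand-in (not the real ProviderWorkflow class, which carries many more fields) showing the defaults from this diff and the Smithsonian override:

from dataclasses import dataclass
from datetime import timedelta


@dataclass
class ProviderWorkflowSketch:
    # Only the two timeout fields are modeled here; defaults mirror the diff.
    pull_timeout: timedelta = timedelta(hours=24)
    upsert_timeout: timedelta = timedelta(hours=1)


smithsonian = ProviderWorkflowSketch(upsert_timeout=timedelta(hours=6))
print(smithsonian.upsert_timeout)  # 6:00:00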
16 changes: 8 additions & 8 deletions tests/dags/common/loader/test_loader.py
@@ -6,19 +6,19 @@


@pytest.mark.parametrize(
"load_value, upsert_value, expected",
"load_value, clean_data_value, upsert_value, expected",
[
((100, 10, 15), 75, RecordMetrics(75, 10, 15, 0)),
((100, 0, 15), 75, RecordMetrics(75, 0, 15, 10)),
((100, 10, 0), 75, RecordMetrics(75, 10, 0, 15)),
(100, (10, 15), 75, RecordMetrics(75, 10, 15, 0)),
(100, (0, 15), 75, RecordMetrics(75, 0, 15, 10)),
(100, (10, 0), 75, RecordMetrics(75, 10, 0, 15)),
],
)
def test_load_from_s3_calculations(load_value, upsert_value, expected):
def test_upsert_data_calculations(load_value, clean_data_value, upsert_value, expected):
with mock.patch("common.loader.loader.sql") as sql_mock:
sql_mock.load_s3_data_to_intermediate_table.return_value = load_value
sql_mock.clean_intermediate_table_data.return_value = clean_data_value
sql_mock.upsert_records_to_db_table.return_value = upsert_value

actual = loader.load_from_s3(
mock.Mock(), "fake", "fake", "fake", "fake", "fake"
actual = loader.upsert_data(
mock.Mock(), "fake", "fake", "fake", load_value, clean_data_value
)
assert actual == expected
27 changes: 26 additions & 1 deletion tests/dags/common/loader/test_sql.py
@@ -267,15 +267,20 @@ def test_delete_more_than_max_malformed_rows(

@flaky
@pytest.mark.parametrize("load_function", [_load_local_tsv, _load_s3_tsv])
def test_loaders_delete_null_url_rows(
def test_loaders_deletes_null_url_rows(
postgres_with_load_table,
tmpdir,
empty_s3_bucket,
load_function,
load_table,
identifier,
):
# Load test data with some null urls into the intermediate table
load_function(tmpdir, empty_s3_bucket, "url_missing.tsv", identifier)
# Clean data
sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)

# Check that rows with null urls were deleted
null_url_check = f"SELECT COUNT (*) FROM {load_table} WHERE url IS NULL;"
postgres_with_load_table.cursor.execute(null_url_check)
null_url_num_rows = postgres_with_load_table.cursor.fetchone()[0]
@@ -297,7 +302,12 @@ def test_loaders_delete_null_license_rows(
load_table,
identifier,
):
# Load test data with some null licenses into the intermediate table
load_function(tmpdir, empty_s3_bucket, "license_missing.tsv", identifier)
# Clean data
sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)

# Check that rows with null licenses were deleted
license_check = f"SELECT COUNT (*) FROM {load_table} WHERE license IS NULL;"
postgres_with_load_table.cursor.execute(license_check)
null_license_num_rows = postgres_with_load_table.cursor.fetchone()[0]
@@ -319,9 +329,14 @@ def test_loaders_delete_null_foreign_landing_url_rows(
load_table,
identifier,
):
# Load test data with null foreign landings url into the intermediate table
load_function(
tmpdir, empty_s3_bucket, "foreign_landing_url_missing.tsv", identifier
)
# Clean data
sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)

# Check that rows with null foreign landing urls were deleted
foreign_landing_url_check = (
f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_landing_url IS NULL;"
)
@@ -344,7 +359,12 @@ def test_data_loaders_delete_null_foreign_identifier_rows(
load_table,
identifier,
):
# Load test data with null foreign identifiers into the intermediate table
load_function(tmpdir, empty_s3_bucket, "foreign_identifier_missing.tsv", identifier)
# Clean data
sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)

# Check that rows with null foreign identifiers were deleted
foreign_identifier_check = (
f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_identifier IS NULL;"
)
@@ -367,9 +387,14 @@ def test_import_data_deletes_duplicate_foreign_identifier_rows(
load_table,
identifier,
):
# Load test data with duplicate foreign identifiers into the intermediate table
load_function(
tmpdir, empty_s3_bucket, "foreign_identifier_duplicate.tsv", identifier
)
# Clean data
sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)

# Check that rows with duplicate foreign ids were deleted
foreign_id_duplicate_check = (
f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_identifier='135257';"
)
