From f654446ae34bb2351b5e52456404efa8c6002812 Mon Sep 17 00:00:00 2001
From: Staci Cooper <63313398+stacimc@users.noreply.github.com>
Date: Mon, 19 Dec 2022 10:15:37 -0800
Subject: [PATCH] Break load_from_s3 into separate tasks to fix duplicate reporting (#914)

* Separate load_from_s3 from upsert_data step

* Fix load_local_data_to_intermediate_table

* Remove clean step from load_local_data_to_intermediate_table

As far as I can tell this method is tested but never used anywhere. This
should mirror `load_s3_data_to_intermediate_table` and only handle the
loading steps, not the cleaning steps. If this method *is* used elsewhere,
it will need to be updated to call the cleaning steps separately.

* Fix tests

* Also separate out the clean data step

* Clarify 'load_timeout' to 'upsert_timeout'

* Extend Smithsonian upsert timeout

* Add types to the clean_intermediate_table_data function

* Add NMNHANTHRO Smithsonian subprovider
---
 .../dags/common/loader/loader.py            | 25 ++++++------
 .../dags/common/loader/provider_details.py  |  1 +
 openverse_catalog/dags/common/loader/sql.py | 24 ++++++------
 .../dags/providers/provider_dag_factory.py  | 38 ++++++++++++++++---
 .../dags/providers/provider_workflows.py    | 10 ++---
 tests/dags/common/loader/test_loader.py     | 16 ++++----
 tests/dags/common/loader/test_sql.py        | 27 ++++++++++++-
 7 files changed, 100 insertions(+), 41 deletions(-)

diff --git a/openverse_catalog/dags/common/loader/loader.py b/openverse_catalog/dags/common/loader/loader.py
index 21988097f..76e201cf3 100644
--- a/openverse_catalog/dags/common/loader/loader.py
+++ b/openverse_catalog/dags/common/loader/loader.py
@@ -34,19 +34,22 @@ def load_s3_data(
     )
 
 
-def load_from_s3(
-    bucket,
-    key,
-    postgres_conn_id,
-    media_type,
-    tsv_version,
-    identifier,
+def upsert_data(
+    postgres_conn_id: str,
+    media_type: str,
+    tsv_version: str,
+    identifier: str,
+    loaded_count: int,
+    duplicates_count: tuple[int, int],
 ) -> RecordMetrics:
-    loaded, missing_columns, foreign_id_dup = sql.load_s3_data_to_intermediate_table(
-        postgres_conn_id, bucket, key, identifier, media_type
-    )
+    """
+    Upserts data into the catalog DB from the loading table, and calculates
+    final record metrics.
+    """
+    missing_columns, foreign_id_dup = duplicates_count
     upserted = sql.upsert_records_to_db_table(
         postgres_conn_id, identifier, media_type=media_type, tsv_version=tsv_version
     )
-    url_dup = loaded - missing_columns - foreign_id_dup - upserted
+
+    url_dup = loaded_count - missing_columns - foreign_id_dup - upserted
     return RecordMetrics(upserted, missing_columns, foreign_id_dup, url_dup)
diff --git a/openverse_catalog/dags/common/loader/provider_details.py b/openverse_catalog/dags/common/loader/provider_details.py
index 55922c407..92c22a0ad 100644
--- a/openverse_catalog/dags/common/loader/provider_details.py
+++ b/openverse_catalog/dags/common/loader/provider_details.py
@@ -66,6 +66,7 @@
 # Smithsonian parameters
 SMITHSONIAN_SUB_PROVIDERS = {
     "smithsonian_national_museum_of_natural_history": {
+        "NMNHANTHRO",  # NMNH - Anthropology Dept.
         "NMNHBIRDS",  # NMNH - Vertebrate Zoology - Birds Division
         "NMNHBOTANY",  # NMNH - Botany Dept.
"NMNHEDUCATION", # NMNH - Education & Outreach diff --git a/openverse_catalog/dags/common/loader/sql.py b/openverse_catalog/dags/common/loader/sql.py index 826cd3c08..6659f79f5 100644 --- a/openverse_catalog/dags/common/loader/sql.py +++ b/openverse_catalog/dags/common/loader/sql.py @@ -2,7 +2,7 @@ from textwrap import dedent from airflow.providers.postgres.hooks.postgres import PostgresHook -from common.constants import AUDIO, IMAGE +from common.constants import AUDIO, IMAGE, MediaType from common.loader import provider_details as prov from common.loader.paths import _extract_media_type from common.storage import columns as col @@ -125,8 +125,6 @@ def load_local_data_to_intermediate_table( "Exceeded the maximum number of allowed defective rows" ) - _clean_intermediate_table_data(postgres, load_table) - def _handle_s3_load_result(cursor) -> int: """ @@ -147,7 +145,7 @@ def load_s3_data_to_intermediate_table( s3_key, identifier, media_type=IMAGE, -) -> tuple[int, int, int]: +) -> int: load_table = _get_load_table_name(identifier, media_type=media_type) logger.info(f"Loading {s3_key} from S3 Bucket {bucket} into {load_table}") @@ -168,13 +166,14 @@ def load_s3_data_to_intermediate_table( handler=_handle_s3_load_result, ) logger.info(f"Successfully loaded {loaded} records from S3") - missing_columns, foreign_id_dup = _clean_intermediate_table_data( - postgres, load_table - ) - return loaded, missing_columns, foreign_id_dup + return loaded -def _clean_intermediate_table_data(postgres_hook, load_table) -> tuple[int, int]: +def clean_intermediate_table_data( + postgres_conn_id: str, + identifier: str, + media_type: MediaType = IMAGE, +) -> tuple[int, int]: """ Necessary for old TSV files that have not been cleaned up, using `MediaStore` class: @@ -183,13 +182,16 @@ def _clean_intermediate_table_data(postgres_hook, load_table) -> tuple[int, int] Also removes any duplicate rows that have the same `provider` and `foreign_id`. 
""" + load_table = _get_load_table_name(identifier, media_type=media_type) + postgres = PostgresHook(postgres_conn_id=postgres_conn_id) + missing_columns = 0 for column in required_columns: - missing_columns += postgres_hook.run( + missing_columns += postgres.run( f"DELETE FROM {load_table} WHERE {column.db_name} IS NULL;", handler=RETURN_ROW_COUNT, ) - foreign_id_dup = postgres_hook.run( + foreign_id_dup = postgres.run( dedent( f""" DELETE FROM {load_table} p1 diff --git a/openverse_catalog/dags/providers/provider_dag_factory.py b/openverse_catalog/dags/providers/provider_dag_factory.py index b83211634..cfd63929b 100644 --- a/openverse_catalog/dags/providers/provider_dag_factory.py +++ b/openverse_catalog/dags/providers/provider_dag_factory.py @@ -197,18 +197,46 @@ def append_day_shift(id_str): ) load_from_s3 = PythonOperator( task_id=append_day_shift("load_from_s3"), - execution_timeout=conf.load_timeout, retries=1, - python_callable=loader.load_from_s3, + python_callable=sql.load_s3_data_to_intermediate_table, op_kwargs={ + "postgres_conn_id": DB_CONN_ID, "bucket": OPENVERSE_BUCKET, - "key": XCOM_PULL_TEMPLATE.format(copy_to_s3.task_id, "s3_key"), + "s3_key": XCOM_PULL_TEMPLATE.format( + copy_to_s3.task_id, "s3_key" + ), + "identifier": identifier, + "media_type": media_type, + }, + ) + clean_data = PythonOperator( + task_id=append_day_shift("clean_data"), + retries=1, + python_callable=sql.clean_intermediate_table_data, + op_kwargs={ + "postgres_conn_id": DB_CONN_ID, + "identifier": identifier, + "media_type": media_type, + }, + ) + upsert_data = PythonOperator( + task_id=append_day_shift("upsert_data"), + execution_timeout=conf.upsert_timeout, + retries=1, + python_callable=loader.upsert_data, + op_kwargs={ "postgres_conn_id": DB_CONN_ID, "media_type": media_type, "tsv_version": XCOM_PULL_TEMPLATE.format( copy_to_s3.task_id, "tsv_version" ), "identifier": identifier, + "loaded_count": XCOM_PULL_TEMPLATE.format( + load_from_s3.task_id, "return_value" + ), + "duplicates_count": XCOM_PULL_TEMPLATE.format( + clean_data.task_id, "return_value" + ), }, ) drop_loading_table = PythonOperator( @@ -222,10 +250,10 @@ def append_day_shift(id_str): trigger_rule=TriggerRule.ALL_DONE, ) [create_loading_table, copy_to_s3] >> load_from_s3 - load_from_s3 >> drop_loading_table + load_from_s3 >> clean_data >> upsert_data >> drop_loading_table record_counts_by_media_type[media_type] = XCOM_PULL_TEMPLATE.format( - load_from_s3.task_id, "return_value" + upsert_data.task_id, "return_value" ) load_tasks.append(load_data) diff --git a/openverse_catalog/dags/providers/provider_workflows.py b/openverse_catalog/dags/providers/provider_workflows.py index 3b95221e9..e668affef 100644 --- a/openverse_catalog/dags/providers/provider_workflows.py +++ b/openverse_catalog/dags/providers/provider_workflows.py @@ -59,7 +59,7 @@ class ProviderWorkflow: which data should be ingested). pull_timeout: datetime.timedelta giving the amount of time a given data pull may take. - load_timeout: datetime.timedelta giving the amount of time the load_data + upsert_timeout: datetime.timedelta giving the amount of time the upsert_data task may take. 
     doc_md: string which should be used for the DAG's documentation markdown
     media_types: list describing the media type(s) that this provider handles
@@ -84,7 +84,7 @@ class ProviderWorkflow:
     schedule_string: str = "@monthly"
     dated: bool = False
     pull_timeout: timedelta = timedelta(hours=24)
-    load_timeout: timedelta = timedelta(hours=1)
+    upsert_timeout: timedelta = timedelta(hours=1)
     doc_md: str = ""
     media_types: Sequence[str] = ()
     create_preingestion_tasks: Callable | None = None
@@ -125,7 +125,7 @@ def __post_init__(self):
     ProviderWorkflow(
         ingester_class=FinnishMuseumsDataIngester,
         start_date=datetime(2015, 11, 1),
-        load_timeout=timedelta(hours=5),
+        upsert_timeout=timedelta(hours=5),
         schedule_string="@daily",
         dated=True,
     ),
@@ -144,7 +144,7 @@ def __post_init__(self):
         create_postingestion_tasks=INaturalistDataIngester.create_postingestion_tasks,
         schedule_string="@monthly",
         pull_timeout=timedelta(days=5),
-        load_timeout=timedelta(days=5),
+        upsert_timeout=timedelta(days=5),
     ),
     ProviderWorkflow(
         ingester_class=JamendoDataIngester,
@@ -183,7 +183,7 @@ def __post_init__(self):
         ingester_class=SmithsonianDataIngester,
         start_date=datetime(2020, 1, 1),
         schedule_string="@weekly",
-        load_timeout=timedelta(hours=4),
+        upsert_timeout=timedelta(hours=6),
     ),
     ProviderWorkflow(
         ingester_class=SmkDataIngester,
diff --git a/tests/dags/common/loader/test_loader.py b/tests/dags/common/loader/test_loader.py
index 3f0b500f8..b731bd55b 100644
--- a/tests/dags/common/loader/test_loader.py
+++ b/tests/dags/common/loader/test_loader.py
@@ -6,19 +6,19 @@
 
 
 @pytest.mark.parametrize(
-    "load_value, upsert_value, expected",
+    "load_value, clean_data_value, upsert_value, expected",
     [
-        ((100, 10, 15), 75, RecordMetrics(75, 10, 15, 0)),
-        ((100, 0, 15), 75, RecordMetrics(75, 0, 15, 10)),
-        ((100, 10, 0), 75, RecordMetrics(75, 10, 0, 15)),
+        (100, (10, 15), 75, RecordMetrics(75, 10, 15, 0)),
+        (100, (0, 15), 75, RecordMetrics(75, 0, 15, 10)),
+        (100, (10, 0), 75, RecordMetrics(75, 10, 0, 15)),
     ],
 )
-def test_load_from_s3_calculations(load_value, upsert_value, expected):
+def test_upsert_data_calculations(load_value, clean_data_value, upsert_value, expected):
     with mock.patch("common.loader.loader.sql") as sql_mock:
-        sql_mock.load_s3_data_to_intermediate_table.return_value = load_value
+        sql_mock.clean_intermediate_table_data.return_value = clean_data_value
         sql_mock.upsert_records_to_db_table.return_value = upsert_value
-        actual = loader.load_from_s3(
-            mock.Mock(), "fake", "fake", "fake", "fake", "fake"
+        actual = loader.upsert_data(
+            mock.Mock(), "fake", "fake", "fake", load_value, clean_data_value
         )
     assert actual == expected
 
diff --git a/tests/dags/common/loader/test_sql.py b/tests/dags/common/loader/test_sql.py
index cd5129c69..a3bd31efd 100644
--- a/tests/dags/common/loader/test_sql.py
+++ b/tests/dags/common/loader/test_sql.py
@@ -267,7 +267,7 @@ def test_delete_more_than_max_malformed_rows(
 
 @flaky
 @pytest.mark.parametrize("load_function", [_load_local_tsv, _load_s3_tsv])
-def test_loaders_delete_null_url_rows(
+def test_loaders_deletes_null_url_rows(
     postgres_with_load_table,
     tmpdir,
     empty_s3_bucket,
     load_function,
     load_table,
     identifier,
 ):
+    # Load test data with some null urls into the intermediate table
     load_function(tmpdir, empty_s3_bucket, "url_missing.tsv", identifier)
+    # Clean data
+    sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier)
+
+    # Check that rows with null urls were deleted
     null_url_check = f"SELECT COUNT (*) FROM {load_table} WHERE url IS NULL;"
NULL;" postgres_with_load_table.cursor.execute(null_url_check) null_url_num_rows = postgres_with_load_table.cursor.fetchone()[0] @@ -297,7 +302,12 @@ def test_loaders_delete_null_license_rows( load_table, identifier, ): + # Load test data with some null licenses into the intermediate table load_function(tmpdir, empty_s3_bucket, "license_missing.tsv", identifier) + # Clean data + sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier) + + # Check that rows with null licenses were deleted license_check = f"SELECT COUNT (*) FROM {load_table} WHERE license IS NULL;" postgres_with_load_table.cursor.execute(license_check) null_license_num_rows = postgres_with_load_table.cursor.fetchone()[0] @@ -319,9 +329,14 @@ def test_loaders_delete_null_foreign_landing_url_rows( load_table, identifier, ): + # Load test data with null foreign landings url into the intermediate table load_function( tmpdir, empty_s3_bucket, "foreign_landing_url_missing.tsv", identifier ) + # Clean data + sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier) + + # Check that rows with null foreign landing urls were deleted foreign_landing_url_check = ( f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_landing_url IS NULL;" ) @@ -344,7 +359,12 @@ def test_data_loaders_delete_null_foreign_identifier_rows( load_table, identifier, ): + # Load test data with null foreign identifiers into the intermediate table load_function(tmpdir, empty_s3_bucket, "foreign_identifier_missing.tsv", identifier) + # Clean data + sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier) + + # Check that rows with null foreign identifiers were deleted foreign_identifier_check = ( f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_identifier IS NULL;" ) @@ -367,9 +387,14 @@ def test_import_data_deletes_duplicate_foreign_identifier_rows( load_table, identifier, ): + # Load test data with duplicate foreign identifiers into the intermediate table load_function( tmpdir, empty_s3_bucket, "foreign_identifier_duplicate.tsv", identifier ) + # Clean data + sql.clean_intermediate_table_data(POSTGRES_CONN_ID, identifier) + + # Check that rows with duplicate foreign ids were deleted foreign_id_duplicate_check = ( f"SELECT COUNT (*) FROM {load_table} " f"WHERE foreign_identifier='135257';" )