diff --git a/openverse_catalog/dags/common/operators/postgres_result.py b/openverse_catalog/dags/common/operators/postgres_result.py
deleted file mode 100644
index 66473e974..000000000
--- a/openverse_catalog/dags/common/operators/postgres_result.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from collections.abc import Callable, Iterable, Mapping
-from typing import TYPE_CHECKING
-
-from airflow.providers.postgres.hooks.postgres import PostgresHook
-from airflow.providers.postgres.operators.postgres import PostgresOperator
-from psycopg2.sql import SQL, Identifier
-
-
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
-
-class PostgresResultOperator(PostgresOperator):
-    """
-    Override for the PostgresOperator which is functionally identical except that a
-    handler can be specified for the PostgresHook. The value(s) accumulated from the
-    run function will then be pushed as an XCom.
-    """
-
-    def __init__(
-        self,
-        *,
-        sql: str | list[str],
-        handler: Callable,
-        postgres_conn_id: str = "postgres_default",
-        autocommit: bool = False,
-        parameters: Mapping | Iterable | None = None,
-        database: str | None = None,
-        runtime_parameters: Mapping | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(
-            sql=sql,
-            postgres_conn_id=postgres_conn_id,
-            autocommit=autocommit,
-            parameters=parameters,
-            database=database,
-            runtime_parameters=runtime_parameters,
-            **kwargs,
-        )
-        self.handler = handler
-
-    def execute(self, context: "Context"):
-        """
-        This almost exactly mirrors PostgresOperator::execute, except that it allows
-        passing a handler into the hook.
-        """
-        self.hook = PostgresHook(
-            postgres_conn_id=self.postgres_conn_id, schema=self.database
-        )
-        if self.runtime_parameters:
-            final_sql = []
-            sql_param = {}
-            for param in self.runtime_parameters:
-                set_param_sql = f"SET {{}} TO %({param})s;"
-                dynamic_sql = SQL(set_param_sql).format(Identifier(f"{param}"))
-                final_sql.append(dynamic_sql)
-            for param, val in self.runtime_parameters.items():
-                sql_param.update({f"{param}": f"{val}"})
-            if self.parameters:
-                sql_param.update(self.parameters)
-            if isinstance(self.sql, str):
-                final_sql.append(SQL(self.sql))
-            else:
-                final_sql.extend(list(map(SQL, self.sql)))
-            results = self.hook.run(
-                final_sql,
-                self.autocommit,
-                parameters=sql_param,
-                handler=self.handler,
-            )
-        else:
-            results = self.hook.run(
-                self.sql,
-                self.autocommit,
-                parameters=self.parameters,
-                handler=self.handler,
-            )
-        for output in self.hook.conn.notices:
-            self.log.info(output)
-
-        return results
diff --git a/openverse_catalog/dags/data_refresh/dag_factory.py b/openverse_catalog/dags/data_refresh/dag_factory.py
index ee51c5409..0815ae44d 100644
--- a/openverse_catalog/dags/data_refresh/dag_factory.py
+++ b/openverse_catalog/dags/data_refresh/dag_factory.py
@@ -37,6 +37,7 @@
 from airflow import DAG
 from airflow.models.dagrun import DagRun
 from airflow.operators.python import BranchPythonOperator, PythonOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.settings import SASession
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
@@ -45,7 +46,6 @@
     OPENLEDGER_API_CONN_ID,
     XCOM_PULL_TEMPLATE,
 )
-from common.operators.postgres_result import PostgresResultOperator
 from data_refresh.data_refresh_task_factory import create_data_refresh_task_group
 from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
 from data_refresh.refresh_popularity_metrics_task_factory import (
@@ -206,11 +206,12 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc
         )
 
         # Get the current number of records in the target API table
-        before_record_count = PostgresResultOperator(
+        before_record_count = SQLExecuteQueryOperator(
             task_id="get_before_record_count",
-            postgres_conn_id=OPENLEDGER_API_CONN_ID,
+            conn_id=OPENLEDGER_API_CONN_ID,
             sql=count_sql,
             handler=_single_value,
+            return_last=True,
         )
 
         # Refresh underlying popularity tables. This is required infrequently in order
@@ -231,11 +232,12 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc
         )
 
         # Get the final number of records in the API table after the refresh
-        after_record_count = PostgresResultOperator(
+        after_record_count = SQLExecuteQueryOperator(
             task_id="get_after_record_count",
-            postgres_conn_id=OPENLEDGER_API_CONN_ID,
+            conn_id=OPENLEDGER_API_CONN_ID,
             sql=count_sql,
            handler=_single_value,
+            return_last=True,
         )
 
         # Report the count difference to Slack
diff --git a/openverse_catalog/dags/providers/provider_api_scripts/inaturalist.py b/openverse_catalog/dags/providers/provider_api_scripts/inaturalist.py
index 284da17dc..4952d46aa 100644
--- a/openverse_catalog/dags/providers/provider_api_scripts/inaturalist.py
+++ b/openverse_catalog/dags/providers/provider_api_scripts/inaturalist.py
@@ -26,8 +26,8 @@
 from airflow.exceptions import AirflowSkipException
 from airflow.operators.python import PythonOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.providers.postgres.hooks.postgres import PostgresHook
-from airflow.providers.postgres.operators.postgres import PostgresOperator
 from airflow.utils.task_group import TaskGroup
 from common.constants import POSTGRES_CONN_ID
 from common.licenses import NO_LICENSE_FOUND, get_license_info
@@ -137,17 +137,17 @@ def create_preingestion_tasks():
             },
         )
 
-        create_inaturalist_schema = PostgresOperator(
+        create_inaturalist_schema = SQLExecuteQueryOperator(
             task_id="create_inaturalist_schema",
-            postgres_conn_id=POSTGRES_CONN_ID,
+            conn_id=POSTGRES_CONN_ID,
             sql=(SCRIPT_DIR / "create_schema.sql").read_text(),
         )
 
         with TaskGroup(group_id="load_source_files") as load_source_files:
             for source_name in SOURCE_FILE_NAMES:
-                PostgresOperator(
+                SQLExecuteQueryOperator(
                     task_id=f"load_{source_name}",
-                    postgres_conn_id=POSTGRES_CONN_ID,
+                    conn_id=POSTGRES_CONN_ID,
                     sql=(SCRIPT_DIR / f"{source_name}.sql").read_text(),
                 ),
 
@@ -157,9 +157,9 @@ def create_preingestion_tasks():
 
     @staticmethod
     def create_postingestion_tasks():
-        drop_inaturalist_schema = PostgresOperator(
+        drop_inaturalist_schema = SQLExecuteQueryOperator(
             task_id="drop_inaturalist_schema",
-            postgres_conn_id=POSTGRES_CONN_ID,
+            conn_id=POSTGRES_CONN_ID,
             sql="DROP SCHEMA IF EXISTS inaturalist CASCADE",
         )
         return drop_inaturalist_schema
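Note on the two operator migrations above: the common-sql provider that ships with Airflow 2.5.0 (see the requirements_prod.txt bump below) gives SQLExecuteQueryOperator a `handler` argument and pushes the handler's return value to XCom, which is exactly what the deleted PostgresResultOperator existed to do. A minimal sketch of the new pattern follows; the DAG id, connection id, table name, and handler body are only illustrative, not the project's actual code:

    import datetime

    from airflow import DAG
    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


    def _single_value(cursor):
        # Illustrative handler: return the lone scalar produced by a COUNT(*)-style query.
        return cursor.fetchone()[0]


    with DAG("record_count_sketch", start_date=datetime.datetime(2023, 1, 1), schedule=None):
        SQLExecuteQueryOperator(
            task_id="get_record_count",
            conn_id="postgres_openledger_testing",  # illustrative connection id
            sql="SELECT COUNT(*) FROM image",  # illustrative query
            handler=_single_value,
            # return_last keeps only the final statement's handler result, so the XCom
            # value is a single number, as with the old PostgresResultOperator.
            return_last=True,
        )

Because `do_xcom_push` defaults to True on BaseOperator, the handler's return value lands in XCom without extra wiring, so downstream tasks can keep pulling the before/after record counts exactly as they did before.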
diff --git a/pytest.ini b/pytest.ini
index 7263d4702..b3525fece 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -21,10 +21,6 @@ addopts =
 # This occurs because the default config is loaded when running `just test --extended`
 # which happens to still have SMTP credential defaults assigned in it. We do not set
 # these anywhere in the dev stack so it can be safely ignored.
-# Subdag
-# This appears to be coming from Airflow internals during testing as a result of
-# loading the example DAGs:
-# /opt/airflow/.local/lib/python3.10/site-packages/airflow/example_dags/example_subdag_operator.py:43: RemovedInAirflow3Warning
 # distutils
 # Warning in dependency, nothing we can do
 # flask
@@ -32,6 +28,5 @@ addopts =
 # "removed"/"remoevd" is due to https://github.com/pallets/flask/pull/4757
 filterwarnings=
     ignore:Fetching SMTP credentials from configuration variables will be deprecated in a future release. Please set credentials using a connection instead:airflow.exceptions.RemovedInAirflow3Warning
-    ignore:This class is deprecated. Please use `airflow.utils.task_group.TaskGroup`.:airflow.exceptions.RemovedInAirflow3Warning
     ignore:distutils Version classes are deprecated. Use packaging.version instead:DeprecationWarning
     ignore:.*is deprecated and will be (remoevd|removed) in Flask 2.3.:DeprecationWarning
diff --git a/requirements_prod.txt b/requirements_prod.txt
index 3093bc6cc..eb5a91d15 100644
--- a/requirements_prod.txt
+++ b/requirements_prod.txt
@@ -2,7 +2,7 @@
 
 # Note: Unpinned packages have their versions determined by the Airflow constraints file
 
-apache-airflow[amazon,postgres,http]==2.4.2
+apache-airflow[amazon,postgres,http]==2.5.0
 lxml
 psycopg2-binary
 requests-file==1.5.1
diff --git a/tests/dags/common/operators/test_postgres_result.py b/tests/dags/common/operators/test_postgres_result.py
deleted file mode 100644
index 79bf79c93..000000000
--- a/tests/dags/common/operators/test_postgres_result.py
+++ /dev/null
@@ -1,71 +0,0 @@
-"""
-Based on:
-https://airflow.apache.org/docs/apache-airflow/stable/best-practices.html#unit-tests
-"""
-import datetime
-import os
-
-import pytest
-from airflow import DAG
-from airflow.models import DagRun, TaskInstance
-from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.types import DagRunType
-from common.operators.postgres_result import PostgresResultOperator
-
-
-DATA_INTERVAL_START = datetime.datetime(2021, 9, 13, tzinfo=datetime.timezone.utc)
-DATA_INTERVAL_END = DATA_INTERVAL_START + datetime.timedelta(days=1)
-
-TEST_DAG_ID = "test_postgres_result_dag"
-TEST_TASK_ID = "test_postgres_result_task"
-DB_CONN_ID = os.getenv("OPENLEDGER_CONN_ID", "postgres_openledger_testing")
-
-
-@pytest.fixture(autouse=True)
-def clean_db():
-    with create_session() as session:
-        session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete()
-        session.query(TaskInstance).filter(TaskInstance.dag_id == TEST_DAG_ID).delete()
-        yield
-        session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete()
-        session.query(TaskInstance).filter(TaskInstance.dag_id == TEST_DAG_ID).delete()
-
-
-def get_dag(sql, handler):
-    with DAG(
-        dag_id=TEST_DAG_ID,
-        schedule="@daily",
-        start_date=DATA_INTERVAL_START,
-    ) as dag:
-        PostgresResultOperator(
-            task_id=TEST_TASK_ID, postgres_conn_id=DB_CONN_ID, sql=sql, handler=handler
-        )
-    return dag
-
-
-@pytest.mark.parametrize(
-    "sql, handler, expected",
-    [
-        ["SELECT 1", lambda c: c.fetchone(), (1,)],
-        ["SELECT 1", lambda c: c.fetchone()[0], 1],
-        ["SELECT UNNEST(ARRAY[1, 2, 3])", lambda c: c.fetchone(), (1,)],
-        ["SELECT UNNEST(ARRAY[1, 2, 3])", lambda c: c.fetchall(), [(1,), (2,), (3,)]],
-        ["SELECT UNNEST(ARRAY[1, 2, 3])", lambda c: c.rowcount, 3],
-    ],
-)
-def test_postgres_result_operator(sql, handler, expected):
-    dag = get_dag(sql, handler)
-    dagrun = dag.create_dagrun(
-        state=DagRunState.RUNNING,
-        execution_date=DATA_INTERVAL_START,
-        data_interval=(DATA_INTERVAL_START, DATA_INTERVAL_END),
-        start_date=DATA_INTERVAL_END,
-        run_type=DagRunType.MANUAL,
-    )
-    ti: TaskInstance = dagrun.get_task_instance(task_id=TEST_TASK_ID)
-    ti.task = dag.get_task(task_id=TEST_TASK_ID)
-    ti.run(ignore_ti_state=True)
-    assert ti.state == TaskInstanceState.SUCCESS
-    value = ti.xcom_pull(TEST_TASK_ID)
-    assert value == expected
diff --git a/tests/dags/common/sensors/test_single_run_external_dags_sensor.py b/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
index f1d801f1c..142aee650 100644
--- a/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
+++ b/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
@@ -74,6 +74,20 @@ def create_dagrun(dag, dag_state):
     )
 
 
+# This appears to be coming from Airflow internals during testing as a result of
+# loading the example DAGs:
+# /opt/airflow/.local/lib/python3.10/site-packages/airflow/example_dags/example_subdag_operator.py:43: RemovedInAirflow3Warning # noqa: E501
+@pytest.mark.filterwarnings(
+    "ignore:This class is deprecated. Please use "
+    "`airflow.utils.task_group.TaskGroup`.:airflow.exceptions.RemovedInAirflow3Warning"
+)
+# This also appears to be coming from Airflow internals during testing as a result of
+# loading the example bash operator DAG:
+# /home/airflow/.local/lib/python3.10/site-packages/airflow/models/dag.py:3492: RemovedInAirflow3Warning # noqa: E501
+@pytest.mark.filterwarnings(
+    "ignore:Param `schedule_interval` is deprecated and will be removed in a future release. "
+    "Please use `schedule` instead.:airflow.exceptions.RemovedInAirflow3Warning"
+)
 class TestExternalDAGsSensor(unittest.TestCase):
     def setUp(self):
         Pool.create_or_update_pool(TEST_POOL, slots=1, description="test pool")
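A side note on the warning-handling change above: the same `ignore:<message>:<category>` filter string works at any of pytest's scopes, so ignores that only matter to one test module no longer need to live in the global pytest.ini list. A small sketch of the available scopes, using a made-up warning message (not one emitted by Airflow):

    import warnings

    import pytest

    # 1. Global scope, in pytest.ini:
    #        filterwarnings=
    #            ignore:Some noisy message:DeprecationWarning
    # 2. Module scope, via a module-level pytestmark:
    pytestmark = pytest.mark.filterwarnings("ignore:Some noisy message:DeprecationWarning")


    # 3. Class or function scope, the form used in test_single_run_external_dags_sensor.py:
    @pytest.mark.filterwarnings("ignore:Some noisy message:DeprecationWarning")
    def test_tolerates_the_warning():
        # The filter's message part is matched against the start of the warning text,
        # so this warning is suppressed for this test only.
        warnings.warn("Some noisy message from a dependency", DeprecationWarning)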