diff --git a/api/src/data_inclusion/api/inclusion_data/commands.py b/api/src/data_inclusion/api/inclusion_data/commands.py index b0c047b0..608e8b5f 100644 --- a/api/src/data_inclusion/api/inclusion_data/commands.py +++ b/api/src/data_inclusion/api/inclusion_data/commands.py @@ -123,9 +123,6 @@ def load_inclusion_data(): structures_df = structures_df.replace({np.nan: None}) services_df = services_df.replace({np.nan: None}) - structures_df = structures_df.drop(columns=["_di_geocodage_score"]) - services_df = services_df.drop(columns=["_di_geocodage_score"]) - structure_errors_df = validate_df(structures_df, model_schema=schema.Structure) service_errors_df = validate_df(services_df, model_schema=schema.Service) diff --git a/pipeline/dags/compute_hourly.py b/pipeline/dags/compute_hourly.py index 5d3ca3e7..34e93d14 100644 --- a/pipeline/dags/compute_hourly.py +++ b/pipeline/dags/compute_hourly.py @@ -5,8 +5,7 @@ from dag_utils import date, marts, notifications from dag_utils.dbt import ( - get_after_geocoding_tasks, - get_before_geocoding_tasks, + get_intermediate_tasks, get_staging_tasks, ) @@ -24,8 +23,7 @@ ( start >> get_staging_tasks(schedule="@hourly") - >> get_before_geocoding_tasks() - >> get_after_geocoding_tasks() + >> get_intermediate_tasks() >> marts.export_di_dataset_to_s3() >> end ) diff --git a/pipeline/dags/dag_utils/dbt.py b/pipeline/dags/dag_utils/dbt.py index a5b3a21c..61c0787f 100644 --- a/pipeline/dags/dag_utils/dbt.py +++ b/pipeline/dags/dag_utils/dbt.py @@ -98,9 +98,9 @@ def get_staging_tasks(schedule=None): return task_list -def get_before_geocoding_tasks(): +def get_intermediate_tasks(): return dbt_operator_factory( - task_id="dbt_build_before_geocoding", + task_id="dbt_build_intermediate", command="build", select=" ".join( [ @@ -111,22 +111,11 @@ def get_before_geocoding_tasks(): # into a single DAG. Another way to see it is that it depended on # main since the beginning as it required intermediate data to be # present ? + "path:models/intermediate/int__geocodages.sql", "path:models/intermediate/int__union_contacts.sql", "path:models/intermediate/int__union_adresses.sql", "path:models/intermediate/int__union_services.sql", "path:models/intermediate/int__union_structures.sql", - ] - ), - trigger_rule=TriggerRule.ALL_DONE, - ) - - -def get_after_geocoding_tasks(): - return dbt_operator_factory( - task_id="dbt_build_after_geocoding", - command="build", - select=" ".join( - [ "path:models/intermediate/extra", "path:models/intermediate/int__plausible_personal_emails.sql", "path:models/intermediate/int__union_adresses__enhanced.sql+", @@ -140,4 +129,5 @@ def get_after_geocoding_tasks(): "path:models/intermediate/quality/int_quality__stats.sql+", ] ), + trigger_rule=TriggerRule.ALL_DONE, ) diff --git a/pipeline/dags/dag_utils/geocoding.py b/pipeline/dags/dag_utils/geocoding.py deleted file mode 100644 index 9ced0c17..00000000 --- a/pipeline/dags/dag_utils/geocoding.py +++ /dev/null @@ -1,125 +0,0 @@ -import csv -import io -import logging - -import numpy as np -import pandas as pd -import requests -import tenacity -from tenacity import before, stop, wait - -logger = logging.getLogger(__name__) - - -class GeocodingBackend: - def geocode(self, df: pd.DataFrame) -> pd.DataFrame: - raise NotImplementedError - - -class BaseAdresseNationaleBackend(GeocodingBackend): - def __init__(self, base_url: str): - self.base_url = base_url.strip("/") - - def _geocode(self, df: pd.DataFrame) -> pd.DataFrame: - logger.info("Will send address batch, dimensions=%s", df.shape) - with io.BytesIO() as buf: - df.to_csv(buf, index=False, quoting=csv.QUOTE_ALL, sep="|") - - try: - response = requests.post( - self.base_url + "/search/csv/", - files={"data": ("data.csv", buf.getvalue(), "text/csv")}, - data={ - "columns": ["adresse", "code_postal", "commune"], - # Post-filter on the INSEE code and not the zipcode. - # Explanations from the BAN API creators: - # The postcode is problematic for cities with multiple zipcodes - # if the supplied zipcode is wrong, or the one in the BAN is. - # The INSEE code is more stable, unique and reliable. - # Also this post-filter does not return "possible" results, - # it blindly filters-out. - "citycode": "code_insee", - }, - timeout=180, # we upload 2MB of data, so we need a high timeout - ) - response.raise_for_status() - except requests.RequestException as e: - logger.info("Error while fetching `%s`: %s", e.request.url, e) - return pd.DataFrame() - - with io.StringIO() as f: - f.write(response.text) - f.seek(0) - results_df = pd.read_csv( - f, - encoding_errors="replace", - on_bad_lines="warn", - dtype=str, - sep="|", - ) - results_df = results_df.replace({np.nan: None}) - # In some cases (ex: address='20' and city='Paris'), the BAN API will return - # a municipality as a result with a very high score. This is be discarded - # as this will not be valuable information to localize a structure. - results_df = results_df[results_df.result_type != "municipality"] - - logger.info("Got result for address batch, dimensions=%s", results_df.shape) - return results_df - - def geocode(self, df: pd.DataFrame) -> pd.DataFrame: - # BAN api limits the batch geocoding to 50MB of data - # In our tests, 10_000 rows is about 1MB; but we'll be conservative - # since we also want to avoid upload timeouts. - BATCH_SIZE = 20_000 - - # drop rows that have not at least one commune, code_insee or code_postal - # as the result can't make sense. - # Note that we keep the rows without an address, as they can be used to - # at least resolve the city. - df = df.dropna(subset=["code_postal", "code_insee", "commune"], thresh=2) - df = df.sort_values(by="_di_surrogate_id") - # Cleanup the values a bit to help the BAN's scoring. After experimentation, - # looking for "Ville-Nouvelle" returns worse results than "Ville Nouvelle", - # probably due to a tokenization in the BAN that favors spaces. - # In the same fashion, looking for "U.R.S.S." returns worse results than using - # "URSS" for "Avenue de l'U.R.S.S.". With the dots, it does not find the - # street at all ¯\_(ツ)_/¯ - df = df.assign( - adresse=(df.adresse.str.strip().replace("-", " ").replace(".", "")), - commune=df.commune.str.strip(), - ) - - logger.info(f"Only {len(df)} rows can be geocoded.") - - def _geocode_with_retry(df: pd.DataFrame) -> pd.DataFrame: - try: - for attempt in tenacity.Retrying( - stop=stop.stop_after_attempt(6), - wait=wait.wait_random_exponential(multiplier=10), - before=before.before_log(logger, logging.INFO), - ): - with attempt: - results_df = self._geocode(df) - - if ( - len(df) > 0 - and len(results_df.dropna(subset=["result_citycode"])) - / len(df) - # arbitrary threshold, if less than this percentage of - # the rows have been geocoded, retry. - < 0.3 - ): - # the BAN api often fails to properly complete a - # geocoding batch. If that happens, raise for retry. - raise tenacity.TryAgain - - return results_df - - except tenacity.RetryError: - logger.error("Failed to geocode batch") - return results_df - - return df.groupby( - np.arange(len(df)) // BATCH_SIZE, - group_keys=False, - ).apply(_geocode_with_retry) diff --git a/pipeline/dags/main.py b/pipeline/dags/main.py index 7b39b58e..adde7e64 100644 --- a/pipeline/dags/main.py +++ b/pipeline/dags/main.py @@ -1,76 +1,15 @@ import pendulum import airflow -from airflow.operators import empty, python +from airflow.operators import empty from dag_utils import date, marts from dag_utils.dbt import ( dbt_operator_factory, - get_after_geocoding_tasks, - get_before_geocoding_tasks, + get_intermediate_tasks, get_staging_tasks, ) from dag_utils.notifications import notify_failure_args -from dag_utils.virtualenvs import PYTHON_BIN_PATH - - -def _geocode(): - import logging - - import sqlalchemy as sqla - - from airflow.models import Variable - from airflow.providers.postgres.hooks.postgres import PostgresHook - - from dag_utils import geocoding - from dag_utils.sources import utils - - logger = logging.getLogger(__name__) - - pg_hook = PostgresHook(postgres_conn_id="pg") - - # 1. Retrieve input data - input_df = pg_hook.get_pandas_df( - sql=""" - SELECT - _di_surrogate_id, - adresse, - code_postal, - code_insee, - commune - FROM public_intermediate.int__union_adresses; - """ - ) - - utils.log_df_info(input_df, logger=logger) - - geocoding_backend = geocoding.BaseAdresseNationaleBackend( - base_url=Variable.get("BAN_API_URL") - ) - - # 2. Geocode - output_df = geocoding_backend.geocode(input_df) - - utils.log_df_info(output_df, logger=logger) - - # 3. Write result back - engine = pg_hook.get_sqlalchemy_engine() - - with engine.connect() as conn: - with conn.begin(): - output_df.to_sql( - "extra__geocoded_results", - schema="public", - con=conn, - if_exists="replace", - index=False, - dtype={ - "latitude": sqla.Float, - "longitude": sqla.Float, - "result_score": sqla.Float, - }, - ) - with airflow.DAG( dag_id="main", @@ -93,20 +32,12 @@ def _geocode(): command="run-operation create_udfs", ) - python_geocode = python.ExternalPythonOperator( - task_id="python_geocode", - python=str(PYTHON_BIN_PATH), - python_callable=_geocode, - ) - ( start >> dbt_seed >> dbt_create_udfs >> get_staging_tasks() - >> get_before_geocoding_tasks() - >> python_geocode - >> get_after_geocoding_tasks() + >> get_intermediate_tasks() >> marts.export_di_dataset_to_s3() >> end ) diff --git a/pipeline/dbt/macros/create_udfs.sql b/pipeline/dbt/macros/create_udfs.sql index e8263217..4276c0f9 100644 --- a/pipeline/dbt/macros/create_udfs.sql +++ b/pipeline/dbt/macros/create_udfs.sql @@ -23,4 +23,4 @@ CREATE SCHEMA IF NOT EXISTS processings; {% do run_query(sql) %} -{% endmacro %}s \ No newline at end of file +{% endmacro %} diff --git a/pipeline/dbt/models/intermediate/extra/_extra__models.yml b/pipeline/dbt/models/intermediate/extra/_extra__models.yml index 20020c1f..26af30df 100644 --- a/pipeline/dbt/models/intermediate/extra/_extra__models.yml +++ b/pipeline/dbt/models/intermediate/extra/_extra__models.yml @@ -1,7 +1,4 @@ version: 2 models: - # TODO: cleanup these models, add staging models - - name: int_extra__insee_prenoms_filtered - - name: int_extra__geocoded_results diff --git a/pipeline/dbt/models/intermediate/extra/int_extra__geocoded_results.sql b/pipeline/dbt/models/intermediate/extra/int_extra__geocoded_results.sql deleted file mode 100644 index d03e189b..00000000 --- a/pipeline/dbt/models/intermediate/extra/int_extra__geocoded_results.sql +++ /dev/null @@ -1,48 +0,0 @@ -{% set source_model = source('internal', 'extra__geocoded_results') %} - -{% set table_exists = adapter.get_relation(database=source_model.database, schema=source_model.schema, identifier=source_model.name) is not none %} - -{% if table_exists %} - - WITH source AS ( - SELECT * FROM {{ source_model }} - ), - -{% else %} - -WITH source AS ( - SELECT - NULL AS "_di_surrogate_id", - NULL AS "adresse", - NULL AS "code_postal", - NULL AS "commune", - CAST(NULL AS FLOAT) AS "latitude", - CAST(NULL AS FLOAT) AS "longitude", - NULL AS "result_label", - CAST(NULL AS FLOAT) AS "result_score", - NULL AS "result_score_next", - NULL AS "result_type", - NULL AS "result_id", - NULL AS "result_housenumber", - NULL AS "result_name", - NULL AS "result_street", - NULL AS "result_postcode", - NULL AS "result_city", - NULL AS "result_context", - NULL AS "result_citycode", - NULL AS "result_oldcitycode", - NULL AS "result_oldcity", - NULL AS "result_district", - NULL AS "result_status" - WHERE FALSE -), - -{% endif %} - -final AS ( - SELECT * - FROM source - WHERE result_id IS NOT NULL -) - -SELECT * FROM final diff --git a/pipeline/dbt/models/intermediate/int__union_adresses__enhanced.sql b/pipeline/dbt/models/intermediate/int__union_adresses__enhanced.sql index 11a14d77..ec824feb 100644 --- a/pipeline/dbt/models/intermediate/int__union_adresses__enhanced.sql +++ b/pipeline/dbt/models/intermediate/int__union_adresses__enhanced.sql @@ -6,7 +6,7 @@ geocodages AS ( SELECT * FROM {{ ref('int__geocodages') }} ), -final AS ( +overriden_adresses AS ( SELECT adresses._di_surrogate_id AS "_di_surrogate_id", adresses.id AS "id", @@ -28,6 +28,25 @@ final AS ( ON adresses._di_surrogate_id = geocodages.adresse_id AND geocodages.score >= 0.8 +), + +final AS ( + SELECT overriden_adresses.* + FROM overriden_adresses + LEFT JOIN + LATERAL + LIST_ADRESSE_ERRORS( + overriden_adresses.adresse, + overriden_adresses.code_insee, + overriden_adresses.code_postal, + overriden_adresses.commune, + overriden_adresses.complement_adresse, + overriden_adresses.id, + overriden_adresses.latitude, + overriden_adresses.longitude, + overriden_adresses.source + ) AS errors ON TRUE + WHERE errors.field IS NULL ) SELECT * FROM final diff --git a/pipeline/dbt/models/intermediate/int__union_services__enhanced.sql b/pipeline/dbt/models/intermediate/int__union_services__enhanced.sql index a02857e2..461c1a6c 100644 --- a/pipeline/dbt/models/intermediate/int__union_services__enhanced.sql +++ b/pipeline/dbt/models/intermediate/int__union_services__enhanced.sql @@ -92,8 +92,7 @@ final AS ( adresses.commune AS "commune", adresses.adresse AS "adresse", adresses.code_postal AS "code_postal", - adresses.code_insee AS "code_insee", - adresses.result_score AS "_di_geocodage_score" + adresses.code_insee AS "code_insee" FROM valid_services LEFT JOIN adresses ON valid_services._di_adresse_surrogate_id = adresses._di_surrogate_id diff --git a/pipeline/dbt/models/intermediate/int__union_structures__enhanced.sql b/pipeline/dbt/models/intermediate/int__union_structures__enhanced.sql index ddc703a8..d8d62913 100644 --- a/pipeline/dbt/models/intermediate/int__union_structures__enhanced.sql +++ b/pipeline/dbt/models/intermediate/int__union_structures__enhanced.sql @@ -49,7 +49,6 @@ final AS ( adresses.adresse AS "adresse", adresses.code_postal AS "code_postal", adresses.code_insee AS "code_insee", - adresses.result_score AS "_di_geocodage_score", COALESCE(plausible_personal_emails._di_surrogate_id IS NOT NULL, FALSE) AS "_di_email_is_pii" FROM valid_structures diff --git a/pipeline/dbt/models/marts/inclusion/_inclusion_models.yml b/pipeline/dbt/models/marts/inclusion/_inclusion_models.yml index 38941147..dd1be2e0 100644 --- a/pipeline/dbt/models/marts/inclusion/_inclusion_models.yml +++ b/pipeline/dbt/models/marts/inclusion/_inclusion_models.yml @@ -10,8 +10,6 @@ models: data_type: text constraints: - type: primary_key - - name: _di_geocodage_score - data_type: float - name: id data_type: text constraints: @@ -135,8 +133,6 @@ models: - type: not_null - type: foreign_key expression: "public_marts.marts_inclusion__structures (_di_surrogate_id)" - - name: _di_geocodage_score - data_type: float - name: id data_type: text constraints: diff --git a/pipeline/tests/integration/__init__.py b/pipeline/tests/integration/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pipeline/tests/integration/test_geocoding.py b/pipeline/tests/integration/test_geocoding.py deleted file mode 100644 index b705c2f4..00000000 --- a/pipeline/tests/integration/test_geocoding.py +++ /dev/null @@ -1,73 +0,0 @@ -from unittest.mock import ANY - -import pandas as pd -import pytest - -from dags.dag_utils import geocoding - -pytestmark = pytest.mark.ban_api - - -@pytest.fixture -def ban_backend(): - return geocoding.BaseAdresseNationaleBackend( - base_url="https://api-adresse.data.gouv.fr" - ) - - -@pytest.fixture -def sample_df() -> pd.DataFrame: - return pd.DataFrame.from_records( - data=[ - { - "source": "dora", - "_di_surrogate_id": "1", - "adresse": "17 rue Malus", - "code_postal": "59000", - "code_insee": "59350", - "commune": "Lille", - }, - { - "source": "dora", - "_di_surrogate_id": "2", - "adresse": None, - "code_postal": None, - "code_insee": None, - "commune": None, - }, - ] - ) - - -def test_ban_geocode( - ban_backend: geocoding.BaseAdresseNationaleBackend, - sample_df: pd.DataFrame, -): - assert ban_backend.geocode(sample_df).to_dict(orient="records") == [ - { - "_di_surrogate_id": "1", - "source": "dora", - "adresse": "17 rue Malus", - "code_insee": "59350", - "code_postal": "59000", - "commune": "Lille", - "latitude": "50.627078", - "longitude": "3.067372", - "result_label": "17 Rue Malus 59000 Lille", - "result_score": ANY, - "result_score_next": None, - "result_type": "housenumber", - "result_id": "59350_5835_00017", - "result_housenumber": "17", - "result_name": "17 Rue Malus", - "result_street": "Rue Malus", - "result_postcode": "59000", - "result_city": "Lille", - "result_context": "59, Nord, Hauts-de-France", - "result_citycode": "59350", - "result_oldcitycode": "59350", - "result_oldcity": "Lille", - "result_district": None, - "result_status": "ok", - } - ]