From 46e0c7fc8de9bf66dd401786cfd83fc7468f6584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Fri, 29 Dec 2023 13:41:05 +0100 Subject: [PATCH] feat: add price.origins_tags field (#110) --- ...023_24d71d56d493_add_origins_tags_field.py | 44 +++++++ app/api.py | 22 ---- app/models.py | 1 + app/schemas.py | 65 +++++++++- tests/test_api.py | 114 ++++++++++++++---- tests/unit/test_schema.py | 87 +++++++++++++ 6 files changed, 285 insertions(+), 48 deletions(-) create mode 100644 alembic/versions/20231229_1023_24d71d56d493_add_origins_tags_field.py create mode 100644 tests/unit/test_schema.py diff --git a/alembic/versions/20231229_1023_24d71d56d493_add_origins_tags_field.py b/alembic/versions/20231229_1023_24d71d56d493_add_origins_tags_field.py new file mode 100644 index 00000000..a9276509 --- /dev/null +++ b/alembic/versions/20231229_1023_24d71d56d493_add_origins_tags_field.py @@ -0,0 +1,44 @@ +"""add origins_tags field + +Revision ID: 24d71d56d493 +Revises: 1e60d73e79cd +Create Date: 2023-12-29 10:23:22.430506 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "24d71d56d493" +down_revision: Union[str, None] = "1e60d73e79cd" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "prices", + sa.Column( + "origins_tags", + sa.JSON().with_variant( + postgresql.JSONB(astext_type=sa.Text()), "postgresql" + ), + nullable=True, + ), + ) + op.create_index( + op.f("ix_prices_origins_tags"), "prices", ["origins_tags"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_prices_origins_tags"), table_name="prices") + op.drop_column("prices", "origins_tags") + # ### end Alembic commands ### diff --git a/app/api.py b/app/api.py index 52b446ec..8c5b5e4d 100644 --- a/app/api.py +++ b/app/api.py @@ -21,7 +21,6 @@ from fastapi_filter import FilterDepends from fastapi_pagination import Page, add_pagination from fastapi_pagination.ext.sqlalchemy import paginate -from openfoodfacts.taxonomy import get_taxonomy from openfoodfacts.utils import get_logger from sqlalchemy.orm import Session @@ -209,27 +208,6 @@ def create_price( detail="Proof does not belong to current user", ) - if price.category_tag is not None: - # lowercase the category tag to perform the match - price.category_tag = price.category_tag.lower() - category_taxonomy = get_taxonomy("category") - if price.category_tag not in category_taxonomy: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid category tag: category '{price.category_tag}' does not exist in the taxonomy", - ) - - if price.labels_tags is not None: - # lowercase the labels tags to perform the match - price.labels_tags = [label_tag.lower() for label_tag in price.labels_tags] - labels_taxonomy = get_taxonomy("label") - for label_tag in price.labels_tags: - if label_tag not in labels_taxonomy: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid label tag: label '{label_tag}' does not exist in the taxonomy", - ) - db_price = crud.create_price(db, price=price, user=current_user) background_tasks.add_task(tasks.create_price_product, db, db_price) background_tasks.add_task(tasks.create_price_location, db, db_price) diff --git a/app/models.py b/app/models.py index 0ed51417..0a9274c2 100644 --- a/app/models.py +++ b/app/models.py @@ -98,6 +98,7 @@ class Price(Base): product_name = Column(String, nullable=True) category_tag = Column(String, nullable=True, index=True) labels_tags = Column(JSONVariant, nullable=True, index=True) + origins_tags = Column(JSONVariant, nullable=True, index=True) product_id: Mapped[int] = mapped_column(ForeignKey("products.id"), nullable=True) product: Mapped[Product] = relationship(back_populates="prices") diff --git a/app/schemas.py b/app/schemas.py index 177b9d65..e8de96c4 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -3,6 +3,7 @@ from fastapi_filter.contrib.sqlalchemy import Filter from openfoodfacts import Flavor +from openfoodfacts.taxonomy import get_taxonomy from pydantic import ( AnyHttpUrl, BaseModel, @@ -123,10 +124,23 @@ class PriceCreate(BaseModel): The most common labels are: - `en:organic`: the product is organic + - `fr:ab-agriculture-biologique`: the product is organic, in France - `en:fair-trade`: the product is fair-trade Other labels can be provided if relevant. """, + examples=["en:organic", "fr:ab-agriculture-biologique", "en:fair-trade"], + ) + origins_tags: list[str] | None = Field( + default=None, + description="""origins of the product, only for products without barcode. + + This field is a list as some products may be a mix of several origins, + but most products have only one origin. + + The origins must be valid origins in the Open Food Facts taxonomy. + If one of the origins is not valid, the price will be rejected.""", + examples=["en:california", "en:france", "en:italy", "en:spain"], ) price: float = Field( gt=0, @@ -162,20 +176,61 @@ class PriceCreate(BaseModel): ) @field_validator("labels_tags") - def labels_tags_is_valid(cls, v): + def labels_tags_is_valid(cls, v: list[str] | None): if v is not None: if len(v) == 0: raise ValueError("`labels_tags` cannot be empty") + v = [label_tag.lower() for label_tag in v] + labels_taxonomy = get_taxonomy("label") + for label_tag in v: + if label_tag not in labels_taxonomy: + raise ValueError( + f"Invalid label tag: label '{label_tag}' does not exist in the taxonomy", + ) + return v + + @field_validator("origins_tags") + def origins_tags_is_valid(cls, v: list[str] | None): + if v is not None: + if len(v) == 0: + raise ValueError("`origins_tags` cannot be empty") + v = [origin_tag.lower() for origin_tag in v] + origins_taxonomy = get_taxonomy("origin") + for origin_tag in v: + if origin_tag not in origins_taxonomy: + raise ValueError( + f"Invalid origin tag: origin '{origin_tag}' does not exist in the taxonomy", + ) + return v + + @field_validator("category_tag") + def category_tag_is_valid(cls, v: str | None): + if v is not None: + v = v.lower() + category_taxonomy = get_taxonomy("category") + if v not in category_taxonomy: + raise ValueError( + f"Invalid category tag: category '{v}' does not exist in the taxonomy" + ) return v @model_validator(mode="after") def product_code_and_category_tag_are_exclusive(self): """Validator that checks that `product_code` and `category_tag` are exclusive, and that at least one of them is set.""" - if self.product_code is not None and self.category_tag is not None: - raise ValueError( - "`product_code` and `category_tag` are exclusive, you can't set both" - ) + if self.product_code is not None: + if self.category_tag is not None: + raise ValueError( + "`product_code` and `category_tag` are exclusive, you can't set both" + ) + if self.labels_tags is not None: + raise ValueError( + "`labels_tags` can only be set for products without barcode" + ) + if self.origins_tags is not None: + raise ValueError( + "`origins_tags` can only be set for products without barcode" + ) if self.product_code is None and self.category_tag is None: raise ValueError("either `product_code` or `category_tag` must be set") return self diff --git a/tests/test_api.py b/tests/test_api.py index dcd32ddc..810dc27f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,4 @@ +import datetime import io import pytest @@ -36,6 +37,7 @@ def override_get_db(): app.dependency_overrides[get_db] = override_get_db +db_session = pytest.fixture(override_get_db, scope="module") # client setup & fixtures # ------------------------------------------------------------------------------ @@ -57,26 +59,32 @@ def override_get_db(): @pytest.fixture(scope="module") -def user(db=override_get_db()): - db_user = crud.create_user(next(db), USER) +def user(db_session): + db_user = crud.create_user(db_session, USER) return db_user @pytest.fixture(scope="module") -def product(db=override_get_db()): - db_product = crud.create_product(next(db), PRODUCT) +def product(db_session): + db_product = crud.create_product(db_session, PRODUCT) return db_product @pytest.fixture(scope="module") -def location(db=override_get_db()): - db_location = crud.create_location(next(db), LOCATION) +def location(db_session): + db_location = crud.create_location(db_session, LOCATION) return db_location +@pytest.fixture(scope="function") +def clean_prices(db_session): + db_session.query(crud.Price).delete() + db_session.commit() + + # Tests # ------------------------------------------------------------------------------ -def test_create_price(user, db=override_get_db()): +def test_create_price(db_session, user, clean_prices): # without authentication response = client.post( "/api/v1/prices", @@ -92,12 +100,37 @@ def test_create_price(user, db=override_get_db()): assert response.status_code == 201 assert response.json()["product_code"] == PRICE_1.product_code assert "id" not in response.json() - db_prices = crud.get_prices(next(db)) - assert len(db_prices) == 1 + assert len(crud.get_prices(db_session)) == 1 # assert db_prices[0]["owner"] == user.user_id -def test_create_price_required_fields_validation(user): +def test_create_price_with_category_tag(db_session, user, clean_prices): + PRICE_WITH_CATEGORY_TAG = PRICE_1.model_copy( + update={ + "product_code": None, + "category_tag": "en:tomatoes", + "labels_tags": ["en:Organic"], + "origins_tags": ["en:France"], + "date": "2023-12-01", + } + ) + response = client.post( + "/api/v1/prices", + json=jsonable_encoder(PRICE_WITH_CATEGORY_TAG), + headers={"Authorization": f"Bearer {user.token}"}, + ) + json_response = response.json() + assert response.status_code == 201 + assert json_response.get("category_tag") == "en:tomatoes" + assert json_response.get("labels_tags") == ["en:organic"] + assert json_response.get("origins_tags") == ["en:france"] + assert json_response.get("date") == "2023-12-01" + assert "id" not in response.json() + db_prices = crud.get_prices(db_session) + assert len(db_prices) == 1 + + +def test_create_price_required_fields_validation(db_session, user, clean_prices): REQUIRED_FIELDS = [ "price", "location_osm_id", @@ -112,9 +145,10 @@ def test_create_price_required_fields_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 -def test_create_price_product_code_pattern_validation(user): +def test_create_price_product_code_pattern_validation(db_session, user, clean_prices): # product_code cannot be an empty string, nor contain letters WRONG_PRICE_PRODUCT_CODES = ["", "en:tomates", "8001505005707XYZ"] for wrong_price_product_code in WRONG_PRICE_PRODUCT_CODES: @@ -127,9 +161,10 @@ def test_create_price_product_code_pattern_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 -def test_create_price_category_tag_pattern_validation(user): +def test_create_price_category_tag_pattern_validation(db_session, user, clean_prices): # category_tag must follow a certain pattern (ex: "en:tomatoes") WRONG_PRICE_CATEGORY_TAGS = ["", ":", "en", ":tomatoes"] for wrong_price_category_tag in WRONG_PRICE_CATEGORY_TAGS: @@ -142,9 +177,10 @@ def test_create_price_category_tag_pattern_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 -def test_create_price_currency_validation(user): +def test_create_price_currency_validation(db_session, user, clean_prices): # currency must have a specific format (ex: "EUR") WRONG_PRICE_CURRENCIES = ["", "€", "euro"] for wrong_price_currency in WRONG_PRICE_CURRENCIES: @@ -157,9 +193,10 @@ def test_create_price_currency_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 -def test_create_price_location_osm_type_validation(user): +def test_create_price_location_osm_type_validation(db_session, user, clean_prices): WRONG_PRICE_LOCATION_OSM_TYPES = ["", "node"] for wrong_price_location_osm_type in WRONG_PRICE_LOCATION_OSM_TYPES: PRICE_WITH_LOCATION_OSM_TYPE_ERROR = PRICE_1.model_copy( @@ -171,9 +208,12 @@ def test_create_price_location_osm_type_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 -def test_create_price_code_category_exclusive_validation(user): +def test_create_price_code_category_exclusive_validation( + db_session, user, clean_prices +): # both product_code & category_tag missing: error PRICE_WITH_CODE_AND_CATEGORY_MISSING = PRICE_1.model_copy( update={"product_code": None} @@ -184,6 +224,7 @@ def test_create_price_code_category_exclusive_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 # only product_code: ok PRICE_WITH_ONLY_PRODUCT_CODE = PRICE_1.model_copy() response = client.post( @@ -192,6 +233,7 @@ def test_create_price_code_category_exclusive_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 201 + assert len(crud.get_prices(db_session)) == 1 # only category_tag: ok PRICE_WITH_ONLY_CATEGORY = PRICE_1.model_copy( update={ @@ -206,6 +248,7 @@ def test_create_price_code_category_exclusive_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 201 + assert len(crud.get_prices(db_session)) == 2 # both product_code & category_tag present: error PRICE_WITH_BOTH_CODE_AND_CATEGORY = PRICE_1.model_copy( update={"category_tag": "en:tomatoes"} @@ -216,9 +259,10 @@ def test_create_price_code_category_exclusive_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 2 -def test_create_price_labels_tags_pattern_validation(user): +def test_create_price_labels_tags_pattern_validation(db_session, user, clean_prices): # product_code cannot be an empty string, nor contain letters WRONG_PRICE_LABELS_TAGS = [[]] for wrong_price_labels_tags in WRONG_PRICE_LABELS_TAGS: @@ -231,9 +275,14 @@ def test_create_price_labels_tags_pattern_validation(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 422 + assert len(crud.get_prices(db_session)) == 0 + +def test_get_prices(db_session, user, clean_prices): + for _ in range(3): + crud.create_price(db_session, PRICE_1, user) -def test_get_prices(): + assert len(crud.get_prices(db_session)) == 3 response = client.get("/api/v1/prices") assert response.status_code == 200 assert len(response.json()["items"]) == 3 @@ -250,19 +299,42 @@ def test_get_prices_pagination(): assert key in response.json() -def test_get_prices_filters(): +def test_get_prices_filters(db_session, user, clean_prices): + crud.create_price(db_session, PRICE_1, user) + crud.create_price( + db_session, + PRICE_1.model_copy( + update={"price": 3.99, "date": datetime.date.fromisoformat("2023-11-01")} + ), + user, + ) + crud.create_price(db_session, PRICE_1.model_copy(update={"price": 5.10}), user) + + assert len(crud.get_prices(db_session)) == 3 + response = client.get(f"/api/v1/prices?product_code={PRICE_1.product_code}") assert response.status_code == 200 - assert len(response.json()["items"]) == 2 + # 3 prices with the same product_code + assert len(response.json()["items"]) == 3 response = client.get("/api/v1/prices?price__gt=5") assert response.status_code == 200 - assert len(response.json()["items"]) == 0 + # 1 price with price > 5 + assert len(response.json()["items"]) == 1 response = client.get("/api/v1/prices?date=2023-10-31") assert response.status_code == 200 + # 2 prices with date = 2023-10-31 assert len(response.json()["items"]) == 2 -def test_get_prices_orders(): +def test_get_prices_orders(db_session, user, clean_prices): + crud.create_price(db_session, PRICE_1, user) + crud.create_price( + db_session, + PRICE_1.model_copy( + update={"price": 3.99, "date": datetime.date.fromisoformat("2023-10-01")} + ), + user, + ) response = client.get("/api/v1/prices") assert response.status_code == 200 assert (response.json()["items"][0]["date"]) == "2023-10-31" diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py new file mode 100644 index 00000000..aba979ce --- /dev/null +++ b/tests/unit/test_schema.py @@ -0,0 +1,87 @@ +import datetime + +import pydantic +import pytest + +from app.schemas import CurrencyEnum, LocationOSMEnum, PriceCreate + + +class TestPriceCreate: + def test_simple_price_with_barcode(self): + price = PriceCreate( + product_code="5414661000456", + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + ) + assert price.product_code == "5414661000456" + assert price.location_osm_id == 123 + assert price.location_osm_type == LocationOSMEnum.NODE + assert price.price == 1.99 + assert price.currency == CurrencyEnum.EUR + assert price.date == datetime.date.fromisoformat("2021-01-01") + + def test_simple_price_with_category(self): + price = PriceCreate( + category_tag="en:Fresh-apricots", + labels_tags=["en:Organic", "fr:AB-agriculture-biologique"], + origins_tags=["en:California", "en:Sweden"], + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + ) + assert price.category_tag == "en:fresh-apricots" + assert price.labels_tags == ["en:organic", "fr:ab-agriculture-biologique"] + assert price.origins_tags == ["en:california", "en:sweden"] + + def test_simple_price_with_invalid_taxonomized_values(self): + with pytest.raises(pydantic.ValidationError, match="Invalid category tag"): + PriceCreate( + category_tag="en:unknown-category", + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + ) + + with pytest.raises(pydantic.ValidationError, match="Invalid label tag"): + PriceCreate( + category_tag="en:carrots", + labels_tags=["en:invalid"], + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + ) + + with pytest.raises(pydantic.ValidationError, match="Invalid origin tag"): + PriceCreate( + category_tag="en:carrots", + origins_tags=["en:invalid"], + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + ) + + def test_simple_price_with_product_code_and_labels_tags_raise(self): + with pytest.raises( + pydantic.ValidationError, + match="`labels_tags` can only be set for products without barcode", + ): + PriceCreate( + product_code="5414661000456", + labels_tags=["en:Organic", "fr:AB-agriculture-biologique"], + location_osm_id=123, + location_osm_type=LocationOSMEnum.NODE, + price=1.99, + currency="EUR", + date="2021-01-01", + )