diff --git a/alembic/versions/20240622_2231_cf78049d89b3_add_proof_location_fields.py b/alembic/versions/20240622_2231_cf78049d89b3_add_proof_location_fields.py new file mode 100644 index 00000000..50561041 --- /dev/null +++ b/alembic/versions/20240622_2231_cf78049d89b3_add_proof_location_fields.py @@ -0,0 +1,77 @@ +"""Add Proof location fields + +Revision ID: cf78049d89b3 +Revises: 49a828f10b05 +Create Date: 2024-06-22 22:31:47.195399 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cf78049d89b3" +down_revision: Union[str, None] = "49a828f10b05" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "proofs", sa.Column("location_osm_id", sa.BigInteger(), nullable=True) + ) + op.add_column( + "proofs", sa.Column("location_osm_type", sa.String(length=255), nullable=True) + ) + op.add_column("proofs", sa.Column("location_id", sa.Integer(), nullable=True)) + op.create_foreign_key(None, "proofs", "locations", ["location_id"], ["id"]) + # Set the location to the location of the first price for each proof + op.execute( + """ + UPDATE proofs + SET location_osm_id = ( + SELECT location_osm_id + FROM prices + WHERE prices.proof_id = proofs.id + LIMIT 1 + ) + WHERE type IN ('PRICE_TAG', 'RECEIPT') + """ + ) + op.execute( + """ + UPDATE proofs + SET location_osm_type = ( + SELECT location_osm_type + FROM prices + WHERE prices.proof_id = proofs.id + LIMIT 1 + ) + WHERE type IN ('PRICE_TAG', 'RECEIPT') + """ + ) + op.execute( + """ + UPDATE proofs + SET location_id = ( + SELECT location_id + FROM prices + WHERE prices.proof_id = proofs.id + LIMIT 1 + ) + WHERE type IN ('PRICE_TAG', 'RECEIPT') + """ + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "proofs", type_="foreignkey") + op.drop_column("proofs", "location_id") + op.drop_column("proofs", "location_osm_type") + op.drop_column("proofs", "location_osm_id") + # ### end Alembic commands ### diff --git a/app/crud.py b/app/crud.py index d45870db..f3b7fa88 100644 --- a/app/crud.py +++ b/app/crud.py @@ -383,6 +383,8 @@ def create_proof( mimetype: str, type: ProofTypeEnum, user: UserCreate, + location_osm_id: int = None, + location_osm_type: LocationOSMEnum = None, date: str = None, currency: CurrencyEnum = None, source: str = None, @@ -399,6 +401,8 @@ def create_proof( file_path=file_path, mimetype=mimetype, type=type, + location_osm_id=location_osm_id, + location_osm_type=location_osm_type, date=date, currency=currency, owner=user.user_id, @@ -480,6 +484,13 @@ def increment_proof_price_count(db: Session, proof: Proof) -> Proof: return proof +def set_proof_location(db: Session, proof: Proof, location: Location) -> Proof: + proof.location_id = location.id + db.commit() + db.refresh(proof) + return proof + + def update_proof( db: Session, proof: Proof, new_values: ProofBasicUpdatableFields ) -> Proof: diff --git a/app/models.py b/app/models.py index 65f4013d..c3111b9d 100644 --- a/app/models.py +++ b/app/models.py @@ -105,6 +105,7 @@ class Location(Base): price_count: Mapped[int] = mapped_column( Integer, nullable=False, server_default="0", index=True ) + proofs: Mapped[list["Proof"]] = relationship(back_populates="location") created = mapped_column(DateTime(timezone=True), server_default=func.now()) updated = mapped_column(DateTime(timezone=True), onupdate=func.now()) @@ -124,10 +125,17 @@ class Proof(Base): Integer, nullable=False, server_default="0", index=True ) + location_osm_id = mapped_column(BigInteger, nullable=True) + location_osm_type: Mapped[LocationOSMEnum] = mapped_column( + ChoiceType(LocationOSMEnum), nullable=True + ) + location_id: Mapped[int] = mapped_column(ForeignKey("locations.id"), nullable=True) + location: Mapped[Location] = relationship(back_populates="proofs") + + date = mapped_column(Date, nullable=True) currency: Mapped[CurrencyEnum] = mapped_column( ChoiceType(CurrencyEnum), nullable=True ) - date = mapped_column(Date, nullable=True) owner = mapped_column(String, index=True) diff --git a/app/routers/prices.py b/app/routers/prices.py index 10967e2d..589ebe7b 100644 --- a/app/routers/prices.py +++ b/app/routers/prices.py @@ -60,9 +60,10 @@ def create_price( ) # create price db_price = crud.create_price(db, price=price, user=current_user, source=app_name) - # update counts + # relationships background_tasks.add_task(tasks.create_price_product, db, price=db_price) background_tasks.add_task(tasks.create_price_location, db, price=db_price) + # update counts background_tasks.add_task(tasks.increment_user_price_count, db, user=current_user) if price.proof_id and db_proof: background_tasks.add_task(tasks.increment_proof_price_count, db, proof=db_proof) diff --git a/app/routers/proofs.py b/app/routers/proofs.py index 62474788..0a6e67c2 100644 --- a/app/routers/proofs.py +++ b/app/routers/proofs.py @@ -1,22 +1,30 @@ from typing import Annotated, Optional -from fastapi import APIRouter, Depends, Form, HTTPException, UploadFile, status +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + Form, + HTTPException, + UploadFile, + status, +) from fastapi_filter import FilterDepends from fastapi_pagination import Page from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import Select from sqlalchemy.orm import Session -from app import crud, schemas +from app import crud, schemas, tasks from app.auth import get_current_user from app.db import get_db -from app.enums import CurrencyEnum, ProofTypeEnum +from app.enums import CurrencyEnum, LocationOSMEnum, ProofTypeEnum from app.models import Proof router = APIRouter(prefix="/proofs") -@router.get("", response_model=Page[schemas.ProofFull]) +@router.get("", response_model=Page[schemas.ProofFullWithRelations]) def get_user_proofs( current_user: schemas.UserCreate = Depends(get_current_user), filters: schemas.ProofFilter = FilterDepends(schemas.ProofFilter), @@ -39,6 +47,13 @@ def get_user_proofs( def upload_proof( file: UploadFile, type: Annotated[ProofTypeEnum, Form(description="The type of the proof")], + background_tasks: BackgroundTasks, + location_osm_id: Optional[str] = Form( + description="Proof location OSM id", default=None + ), + location_osm_type: Optional[LocationOSMEnum] = Form( + description="Proof location OSM type", default=None + ), date: Optional[str] = Form(description="Proof date", default=None), currency: Optional[CurrencyEnum] = Form(description="Proof currency", default=None), app_name: str | None = None, @@ -54,16 +69,21 @@ def upload_proof( This endpoint requires authentication. """ file_path, mimetype = crud.create_proof_file(file) + # create proof db_proof = crud.create_proof( db, file_path, mimetype, type=type, user=current_user, + location_osm_id=location_osm_id, + location_osm_type=location_osm_type, date=date, currency=currency, source=app_name, ) + # relationships + background_tasks.add_task(tasks.create_proof_location, db, proof=db_proof) return db_proof diff --git a/app/schemas.py b/app/schemas.py index 90886897..6fd62b8b 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -185,15 +185,27 @@ class ProofFull(BaseModel): price_count: int = Field( description="number of prices for this proof.", examples=[15], default=0 ) + location_osm_id: int | None = Field( + gt=0, + description="ID of the location in OpenStreetMap: the store where the product was bought.", + examples=[1234567890], + ) + location_osm_type: LocationOSMEnum | None = Field( + description="type of the OpenStreetMap location object. Stores can be represented as nodes, " + "ways or relations in OpenStreetMap. It is necessary to be able to fetch the correct " + "information about the store using the ID.", + examples=["NODE", "WAY", "RELATION"], + ) + date: datetime.date | None = Field( + description="date of the proof.", examples=["2024-01-01"] + ) currency: CurrencyEnum | None = Field( description="currency of the price, as a string. " "The currency must be a valid currency code. " "See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.", examples=["EUR", "USD"], ) - date: datetime.date | None = Field( - description="date of the proof.", examples=["2024-01-01"] - ) + location_id: int | None owner: str # source: str | None = Field( # description="Source (App name)", @@ -206,6 +218,10 @@ class ProofFull(BaseModel): ) +class ProofFullWithRelations(ProofFull): + location: LocationFull | None + + class ProofBasicUpdatableFields(BaseModel): type: ProofTypeEnum | None = None currency: CurrencyEnum | None = None @@ -498,8 +514,11 @@ class ProofFilter(Filter): price_count: Optional[int] | None = None price_count__gte: Optional[int] | None = None price_count__lte: Optional[int] | None = None - currency: Optional[str] | None = None + location_osm_id: Optional[int] | None = None + location_osm_type: Optional[LocationOSMEnum] | None = None + location_id: Optional[int] | None = None date: Optional[str] | None = None + currency: Optional[str] | None = None date__gt: Optional[str] | None = None date__gte: Optional[str] | None = None date__lt: Optional[str] | None = None diff --git a/app/tasks.py b/app/tasks.py index 94f16d30..40ffc0dc 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -209,6 +209,32 @@ def create_price_location(db: Session, price: Price) -> None: crud.increment_location_price_count(db, location=db_location) +def create_proof_location(db: Session, proof: Proof) -> None: + if proof.location_osm_id and proof.location_osm_type: + # get or create the corresponding location + location = LocationCreate( + osm_id=proof.location_osm_id, + osm_type=proof.location_osm_type, + ) + db_location, created = crud.get_or_create_location(db, location=location) + # link the location to the proof + crud.set_proof_location(db, proof=proof, location=db_location) + # fetch data from OpenStreetMap if created + if created: + location_openstreetmap_details = fetch_location_openstreetmap_details( + location=db_location + ) + if location_openstreetmap_details: + crud.update_location( + db, location=db_location, update_dict=location_openstreetmap_details + ) + # else: + # # Increment the proof count of the location + # crud.increment_location_proof_count(db, location=db_location) + + +# Other +# ------------------------------------------------------------------------------ def dump_db(db: Session, output_dir: Path) -> None: """Dump the database to gzipped JSONL files.""" logger.info("Creating dumps of the database") diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 0cb14693..080b7f4d 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -861,6 +861,8 @@ def test_create_proof(db_session, user_session: SessionModel, clean_proofs): headers={"Authorization": f"Bearer {user_session.token}"}, ) assert response.status_code == 201 + assert response.json()["location_osm_id"] is None + assert response.json()["location_osm_type"] is None assert response.json()["date"] is None assert response.json()["currency"] is None assert len(crud.get_proofs(db_session)) == 1 @@ -868,10 +870,18 @@ def test_create_proof(db_session, user_session: SessionModel, clean_proofs): response = client.post( "/api/v1/proofs/upload", files={"file": ("filename", (io.BytesIO(b"test")), "image/webp")}, - data={"type": "RECEIPT", "date": "2024-01-01", "currency": ""}, + data={ + "type": "RECEIPT", + "location_osm_id": 123, + "location_osm_type": "NODE", + "date": "2024-01-01", + "currency": "", + }, headers={"Authorization": f"Bearer {user_session.token}"}, ) assert response.status_code == 201 + assert response.json()["location_osm_id"] == 123 + assert response.json()["location_osm_type"] == "NODE" assert response.json()["date"] == "2024-01-01" assert response.json()["currency"] is None assert len(crud.get_proofs(db_session)) == 1 + 1 @@ -927,8 +937,12 @@ def test_get_proofs(db_session, user_session: SessionModel, clean_proofs): "mimetype", "type", "price_count", - "currency", + "location_osm_id", + "location_osm_type", "date", + "currency", + "location", + "location_id", "owner", "created", "updated",