Skip to content

Commit

Permalink
feat(proofs): New location fields (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Jun 22, 2024
1 parent 5a309b4 commit fe54229
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Add Proof location fields
Revision ID: cf78049d89b3
Revises: 49a828f10b05
Create Date: 2024-06-22 22:31:47.195399
"""
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cf78049d89b3"
down_revision: Union[str, None] = "49a828f10b05"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"proofs", sa.Column("location_osm_id", sa.BigInteger(), nullable=True)
)
op.add_column(
"proofs", sa.Column("location_osm_type", sa.String(length=255), nullable=True)
)
op.add_column("proofs", sa.Column("location_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "proofs", "locations", ["location_id"], ["id"])
# Set the location to the location of the first price for each proof
op.execute(
"""
UPDATE proofs
SET location_osm_id = (
SELECT location_osm_id
FROM prices
WHERE prices.proof_id = proofs.id
LIMIT 1
)
WHERE type IN ('PRICE_TAG', 'RECEIPT')
"""
)
op.execute(
"""
UPDATE proofs
SET location_osm_type = (
SELECT location_osm_type
FROM prices
WHERE prices.proof_id = proofs.id
LIMIT 1
)
WHERE type IN ('PRICE_TAG', 'RECEIPT')
"""
)
op.execute(
"""
UPDATE proofs
SET location_id = (
SELECT location_id
FROM prices
WHERE prices.proof_id = proofs.id
LIMIT 1
)
WHERE type IN ('PRICE_TAG', 'RECEIPT')
"""
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "proofs", type_="foreignkey")
op.drop_column("proofs", "location_id")
op.drop_column("proofs", "location_osm_type")
op.drop_column("proofs", "location_osm_id")
# ### end Alembic commands ###
11 changes: 11 additions & 0 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ def create_proof(
mimetype: str,
type: ProofTypeEnum,
user: UserCreate,
location_osm_id: int = None,
location_osm_type: LocationOSMEnum = None,
date: str = None,
currency: CurrencyEnum = None,
source: str = None,
Expand All @@ -399,6 +401,8 @@ def create_proof(
file_path=file_path,
mimetype=mimetype,
type=type,
location_osm_id=location_osm_id,
location_osm_type=location_osm_type,
date=date,
currency=currency,
owner=user.user_id,
Expand Down Expand Up @@ -480,6 +484,13 @@ def increment_proof_price_count(db: Session, proof: Proof) -> Proof:
return proof


def set_proof_location(db: Session, proof: Proof, location: Location) -> Proof:
proof.location_id = location.id
db.commit()
db.refresh(proof)
return proof


def update_proof(
db: Session, proof: Proof, new_values: ProofBasicUpdatableFields
) -> Proof:
Expand Down
10 changes: 9 additions & 1 deletion app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Location(Base):
price_count: Mapped[int] = mapped_column(
Integer, nullable=False, server_default="0", index=True
)
proofs: Mapped[list["Proof"]] = relationship(back_populates="location")

created = mapped_column(DateTime(timezone=True), server_default=func.now())
updated = mapped_column(DateTime(timezone=True), onupdate=func.now())
Expand All @@ -124,10 +125,17 @@ class Proof(Base):
Integer, nullable=False, server_default="0", index=True
)

location_osm_id = mapped_column(BigInteger, nullable=True)
location_osm_type: Mapped[LocationOSMEnum] = mapped_column(
ChoiceType(LocationOSMEnum), nullable=True
)
location_id: Mapped[int] = mapped_column(ForeignKey("locations.id"), nullable=True)
location: Mapped[Location] = relationship(back_populates="proofs")

date = mapped_column(Date, nullable=True)
currency: Mapped[CurrencyEnum] = mapped_column(
ChoiceType(CurrencyEnum), nullable=True
)
date = mapped_column(Date, nullable=True)

owner = mapped_column(String, index=True)

Expand Down
3 changes: 2 additions & 1 deletion app/routers/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def create_price(
)
# create price
db_price = crud.create_price(db, price=price, user=current_user, source=app_name)
# update counts
# relationships
background_tasks.add_task(tasks.create_price_product, db, price=db_price)
background_tasks.add_task(tasks.create_price_location, db, price=db_price)
# update counts
background_tasks.add_task(tasks.increment_user_price_count, db, user=current_user)
if price.proof_id and db_proof:
background_tasks.add_task(tasks.increment_proof_price_count, db, proof=db_proof)
Expand Down
28 changes: 24 additions & 4 deletions app/routers/proofs.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, Form, HTTPException, UploadFile, status
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
Form,
HTTPException,
UploadFile,
status,
)
from fastapi_filter import FilterDepends
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import Select
from sqlalchemy.orm import Session

from app import crud, schemas
from app import crud, schemas, tasks
from app.auth import get_current_user
from app.db import get_db
from app.enums import CurrencyEnum, ProofTypeEnum
from app.enums import CurrencyEnum, LocationOSMEnum, ProofTypeEnum
from app.models import Proof

router = APIRouter(prefix="/proofs")


@router.get("", response_model=Page[schemas.ProofFull])
@router.get("", response_model=Page[schemas.ProofFullWithRelations])
def get_user_proofs(
current_user: schemas.UserCreate = Depends(get_current_user),
filters: schemas.ProofFilter = FilterDepends(schemas.ProofFilter),
Expand All @@ -39,6 +47,13 @@ def get_user_proofs(
def upload_proof(
file: UploadFile,
type: Annotated[ProofTypeEnum, Form(description="The type of the proof")],
background_tasks: BackgroundTasks,
location_osm_id: Optional[str] = Form(
description="Proof location OSM id", default=None
),
location_osm_type: Optional[LocationOSMEnum] = Form(
description="Proof location OSM type", default=None
),
date: Optional[str] = Form(description="Proof date", default=None),
currency: Optional[CurrencyEnum] = Form(description="Proof currency", default=None),
app_name: str | None = None,
Expand All @@ -54,16 +69,21 @@ def upload_proof(
This endpoint requires authentication.
"""
file_path, mimetype = crud.create_proof_file(file)
# create proof
db_proof = crud.create_proof(
db,
file_path,
mimetype,
type=type,
user=current_user,
location_osm_id=location_osm_id,
location_osm_type=location_osm_type,
date=date,
currency=currency,
source=app_name,
)
# relationships
background_tasks.add_task(tasks.create_proof_location, db, proof=db_proof)
return db_proof


Expand Down
27 changes: 23 additions & 4 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,27 @@ class ProofFull(BaseModel):
price_count: int = Field(
description="number of prices for this proof.", examples=[15], default=0
)
location_osm_id: int | None = Field(
gt=0,
description="ID of the location in OpenStreetMap: the store where the product was bought.",
examples=[1234567890],
)
location_osm_type: LocationOSMEnum | None = Field(
description="type of the OpenStreetMap location object. Stores can be represented as nodes, "
"ways or relations in OpenStreetMap. It is necessary to be able to fetch the correct "
"information about the store using the ID.",
examples=["NODE", "WAY", "RELATION"],
)
date: datetime.date | None = Field(
description="date of the proof.", examples=["2024-01-01"]
)
currency: CurrencyEnum | None = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
date: datetime.date | None = Field(
description="date of the proof.", examples=["2024-01-01"]
)
location_id: int | None
owner: str
# source: str | None = Field(
# description="Source (App name)",
Expand All @@ -206,6 +218,10 @@ class ProofFull(BaseModel):
)


class ProofFullWithRelations(ProofFull):
location: LocationFull | None


class ProofBasicUpdatableFields(BaseModel):
type: ProofTypeEnum | None = None
currency: CurrencyEnum | None = None
Expand Down Expand Up @@ -498,8 +514,11 @@ class ProofFilter(Filter):
price_count: Optional[int] | None = None
price_count__gte: Optional[int] | None = None
price_count__lte: Optional[int] | None = None
currency: Optional[str] | None = None
location_osm_id: Optional[int] | None = None
location_osm_type: Optional[LocationOSMEnum] | None = None
location_id: Optional[int] | None = None
date: Optional[str] | None = None
currency: Optional[str] | None = None
date__gt: Optional[str] | None = None
date__gte: Optional[str] | None = None
date__lt: Optional[str] | None = None
Expand Down
26 changes: 26 additions & 0 deletions app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,32 @@ def create_price_location(db: Session, price: Price) -> None:
crud.increment_location_price_count(db, location=db_location)


def create_proof_location(db: Session, proof: Proof) -> None:
if proof.location_osm_id and proof.location_osm_type:
# get or create the corresponding location
location = LocationCreate(
osm_id=proof.location_osm_id,
osm_type=proof.location_osm_type,
)
db_location, created = crud.get_or_create_location(db, location=location)
# link the location to the proof
crud.set_proof_location(db, proof=proof, location=db_location)
# fetch data from OpenStreetMap if created
if created:
location_openstreetmap_details = fetch_location_openstreetmap_details(
location=db_location
)
if location_openstreetmap_details:
crud.update_location(
db, location=db_location, update_dict=location_openstreetmap_details
)
# else:
# # Increment the proof count of the location
# crud.increment_location_proof_count(db, location=db_location)


# Other
# ------------------------------------------------------------------------------
def dump_db(db: Session, output_dir: Path) -> None:
"""Dump the database to gzipped JSONL files."""
logger.info("Creating dumps of the database")
Expand Down
18 changes: 16 additions & 2 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,17 +861,27 @@ def test_create_proof(db_session, user_session: SessionModel, clean_proofs):
headers={"Authorization": f"Bearer {user_session.token}"},
)
assert response.status_code == 201
assert response.json()["location_osm_id"] is None
assert response.json()["location_osm_type"] is None
assert response.json()["date"] is None
assert response.json()["currency"] is None
assert len(crud.get_proofs(db_session)) == 1

response = client.post(
"/api/v1/proofs/upload",
files={"file": ("filename", (io.BytesIO(b"test")), "image/webp")},
data={"type": "RECEIPT", "date": "2024-01-01", "currency": ""},
data={
"type": "RECEIPT",
"location_osm_id": 123,
"location_osm_type": "NODE",
"date": "2024-01-01",
"currency": "",
},
headers={"Authorization": f"Bearer {user_session.token}"},
)
assert response.status_code == 201
assert response.json()["location_osm_id"] == 123
assert response.json()["location_osm_type"] == "NODE"
assert response.json()["date"] == "2024-01-01"
assert response.json()["currency"] is None
assert len(crud.get_proofs(db_session)) == 1 + 1
Expand Down Expand Up @@ -927,8 +937,12 @@ def test_get_proofs(db_session, user_session: SessionModel, clean_proofs):
"mimetype",
"type",
"price_count",
"currency",
"location_osm_id",
"location_osm_type",
"date",
"currency",
"location",
"location_id",
"owner",
"created",
"updated",
Expand Down

0 comments on commit fe54229

Please sign in to comment.