diff --git a/app/crud.py b/app/crud.py index 6f41d6fe..e2311700 100644 --- a/app/crud.py +++ b/app/crud.py @@ -8,7 +8,7 @@ from sqlalchemy.sql import func from app import config -from app.enums import LocationOSMType +from app.enums import LocationOSMEnum from app.models import Location, Price, Product, Proof, User from app.schemas import ( LocationBase, @@ -214,7 +214,7 @@ def get_location_by_id(db: Session, id: int): def get_location_by_osm_id_and_type( - db: Session, osm_id: int, osm_type: LocationOSMType + db: Session, osm_id: int, osm_type: LocationOSMEnum ): return ( db.query(Location) diff --git a/app/enums.py b/app/enums.py index 3ebb1c47..f971a05f 100644 --- a/app/enums.py +++ b/app/enums.py @@ -1,7 +1,13 @@ from enum import Enum +from babel.numbers import list_currencies -class LocationOSMType(Enum): +CURRENCIES = [(currency, currency) for currency in list_currencies()] + +CurrencyEnum = Enum("CurrencyEnum", CURRENCIES) + + +class LocationOSMEnum(Enum): NODE = "NODE" WAY = "WAY" RELATION = "RELATION" diff --git a/app/models.py b/app/models.py index 10d172d0..ccda8106 100644 --- a/app/models.py +++ b/app/models.py @@ -14,10 +14,9 @@ from sqlalchemy.sql import func from sqlalchemy_utils import force_auto_coercion from sqlalchemy_utils.types.choice import ChoiceType -from sqlalchemy_utils.types.currency import CurrencyType from app.db import Base -from app.enums import LocationOSMType +from app.enums import CurrencyEnum, LocationOSMEnum force_auto_coercion() @@ -55,7 +54,7 @@ class Location(Base): id = Column(Integer, primary_key=True, index=True) osm_id = Column(BigInteger) - osm_type = Column(ChoiceType(LocationOSMType)) + osm_type = Column(ChoiceType(LocationOSMEnum)) osm_name = Column(String) osm_display_name = Column(String) osm_address_postcode = Column(String) @@ -97,10 +96,10 @@ class Price(Base): product: Mapped[Product] = relationship(back_populates="prices") price = Column(Numeric(precision=10, scale=2)) - currency = Column(CurrencyType) + currency = Column(ChoiceType(CurrencyEnum)) location_osm_id = Column(BigInteger, index=True) - location_osm_type = Column(ChoiceType(LocationOSMType)) + location_osm_type = Column(ChoiceType(LocationOSMEnum)) location_id: Mapped[int] = mapped_column(ForeignKey("locations.id"), nullable=True) location: Mapped[Location] = relationship(back_populates="prices") diff --git a/app/schemas.py b/app/schemas.py index 7878ae65..f571bbc3 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -8,13 +8,11 @@ BaseModel, ConfigDict, Field, - field_serializer, field_validator, model_validator, ) -from sqlalchemy_utils import Currency -from app.enums import LocationOSMType +from app.enums import CurrencyEnum, LocationOSMEnum from app.models import Price @@ -65,7 +63,7 @@ class LocationCreate(BaseModel): model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) osm_id: int = Field(gt=0) - osm_type: LocationOSMType + osm_type: LocationOSMEnum class LocationBase(LocationCreate): @@ -126,7 +124,7 @@ class PriceCreate(BaseModel): "kilogram or per liter.", examples=["1.99"], ) - currency: str | Currency = Field( + currency: CurrencyEnum = Field( description="currency of the price, as a string. " "The currency must be a valid currency code. " "See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.", @@ -137,7 +135,7 @@ class PriceCreate(BaseModel): description="ID of the location in OpenStreetMap: the store where the product was bought.", examples=[1234567890], ) - location_osm_type: LocationOSMType = Field( + location_osm_type: LocationOSMEnum = Field( description="type of the OpenStreetMap location object. Stores can be represented as nodes, " "ways or relations in OpenStreetMap. It is necessary to be able to fetch the correct " "information about the store using the ID.", @@ -152,25 +150,12 @@ class PriceCreate(BaseModel): examples=[15], ) - @field_validator("currency") - def currency_is_valid(cls, v): - try: - return Currency(v).code - except ValueError: - raise ValueError("not a valid currency code") - @field_validator("labels_tags") def labels_tags_is_valid(cls, v): if v is not None: if len(v) == 0: raise ValueError("`labels_tags` cannot be empty") - @field_serializer("currency") - def serialize_currency(self, currency: Currency, _info): - if type(currency) is Currency: - return currency.code - return currency - @model_validator(mode="after") def product_code_and_category_tag_are_exclusive(self): """Validator that checks that `product_code` and `category_tag` are @@ -207,7 +192,7 @@ class ProofBase(ProofCreate): class PriceFilter(Filter): product_code: Optional[str] | None = None location_osm_id: Optional[int] | None = None - location_osm_type: Optional[LocationOSMType] | None = None + location_osm_type: Optional[LocationOSMEnum] | None = None price: Optional[int] | None = None currency: Optional[str] | None = None price__gt: Optional[int] | None = None