diff --git a/app/api.py b/app/api.py index d3f1c55e..93729c5c 100644 --- a/app/api.py +++ b/app/api.py @@ -18,9 +18,9 @@ from openfoodfacts.utils import get_logger from app import crud +from app import schemas from app.config import settings from app.db import session -from app.schemas import UserBase from app.utils import init_sentry @@ -55,6 +55,18 @@ async def create_token(user_id: str): return f"{user_id}__U{str(uuid.uuid4())}" +async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): + if token and '__U' in token: + current_user: schemas.UserBase = crud.update_user_last_used_field(db, token=token) # type: ignore + if current_user: + return current_user + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # App startup & shutdown # ------------------------------------------------------------------------------ @app.on_event("startup") @@ -99,7 +111,7 @@ async def authentication(form_data: Annotated[OAuth2PasswordRequestForm, Depends r = requests.post(settings.oauth2_server_url, data=data) # type: ignore if r.status_code == 200: token = await create_token(form_data.username) - user: UserBase = {"user_id": form_data.username, "token": token} # type: ignore + user: schemas.UserBase = {"user_id": form_data.username, "token": token} # type: ignore crud.create_user(db, user=user) # type: ignore return {"access_token": token, "token_type": "bearer"} elif r.status_code == 403: @@ -112,6 +124,12 @@ async def authentication(form_data: Annotated[OAuth2PasswordRequestForm, Depends raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Server error") +@app.post("/prices", response_model=schemas.PriceBase) +async def create_price(price: schemas.PriceCreate, current_user: schemas.UserBase = Depends(get_current_user)): + db_price = crud.create_price(db, price=price, user=current_user) # type: ignore + return db_price + + @app.get("/robots.txt", response_class=PlainTextResponse) def robots_txt(): return """User-agent: *\nDisallow: /""" diff --git a/app/crud.py b/app/crud.py index 16ca652f..efb39975 100644 --- a/app/crud.py +++ b/app/crud.py @@ -1,6 +1,9 @@ from sqlalchemy.orm import Session +from sqlalchemy.sql import func +from app.models import Price from app.models import User +from app.schemas import PriceCreate from app.schemas import UserBase @@ -27,6 +30,16 @@ def create_user(db: Session, user: UserBase): return db_user +def update_user_last_used_field(db: Session, token: str): + db_user = get_user_by_token(db, token=token) + if db_user: + db.query(User).filter(User.user_id == db_user.user_id).update({"last_used": func.now()}) + db.commit() + db.refresh(db_user) + return db_user + return False + + def delete_user(db: Session, user_id: UserBase): db_user = get_user_by_user_id(db, user_id=user_id) if db_user: @@ -34,3 +47,11 @@ def delete_user(db: Session, user_id: UserBase): db.commit() return True return False + + +def create_price(db: Session, price: PriceCreate, user: UserBase): + db_price = Price(**price.dict(), owner=user.user_id) + db.add(db_price) + db.commit() + db.refresh(db_price) + return db_price diff --git a/app/schemas.py b/app/schemas.py index b795f3bc..1cd233af 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,5 +1,14 @@ +from datetime import date +from datetime import datetime + from pydantic import BaseModel from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_serializer +from pydantic import field_validator +from sqlalchemy_utils import Currency + +from app.enums import PriceLocationOSMType class UserBase(BaseModel): @@ -7,3 +16,32 @@ class UserBase(BaseModel): user_id: str token: str + + +class PriceCreate(BaseModel): + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + + product_code: str = Field(min_length=1, pattern="^[0-9]+$") + price: float + currency: str | Currency + location_osm_id: int = Field(gt=0) + location_osm_type: PriceLocationOSMType + date: date + + @field_validator("currency") + def currency_is_valid(cls, v): + try: + return Currency(v).code + except ValueError: + raise ValueError("not a valid currency code") + + @field_serializer("currency") + def serialize_currency(self, currency: Currency, _info): + if type(currency) is Currency: + return currency.code + return currency + + +class PriceBase(PriceCreate): + # owner: str + created: datetime