Skip to content

Commit

Permalink
fix(tests): use LocationFull for fixtures (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
vjousse committed Mar 4, 2024
1 parent 6d05821 commit f04a53f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 42 deletions.
14 changes: 4 additions & 10 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,7 @@ def get_location_by_osm_id_and_type(
)


def create_location(
db: Session, location: LocationCreate, price_count: int = 0
) -> Location:
def create_location(db: Session, location: LocationCreate) -> Location:
"""Create a location in the database.
:param db: the database session
Expand All @@ -513,16 +511,14 @@ def create_location(
to 0
:return: the created location
"""
db_location = Location(price_count=price_count, **location.model_dump())
db_location = Location(**location.model_dump())
db.add(db_location)
db.commit()
db.refresh(db_location)
return db_location


def get_or_create_location(
db: Session, location: LocationCreate, init_price_count: int = 0
):
def get_or_create_location(db: Session, location: LocationCreate):
"""Get or create a location in the database.
:param db: the database session
Expand All @@ -537,9 +533,7 @@ def get_or_create_location(
db, osm_id=location.osm_id, osm_type=location.osm_type
)
if not db_location:
db_location = create_location(
db, location=location, price_count=init_price_count
)
db_location = create_location(db, location=location)
created = True
return db_location, created

Expand Down
22 changes: 11 additions & 11 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,22 @@ class LocationCreate(BaseModel):

osm_id: int = Field(gt=0)
osm_type: LocationOSMEnum
price_count: int = Field(
description="number of prices for this location.", examples=[15], default=0
)


class LocationFull(LocationCreate):
id: int
osm_name: str | None
osm_display_name: str | None
osm_address_postcode: str | None
osm_address_city: str | None
osm_address_country: str | None
osm_lat: float | None
osm_lon: float | None
price_count: int = Field(
description="number of prices for this location.", examples=[15], default=0
)
osm_name: str | None = None
osm_display_name: str | None = None
osm_address_postcode: str | None = None
osm_address_city: str | None = None
osm_address_country: str | None = None
osm_lat: float | None = None
osm_lon: float | None = None
created: datetime.datetime
updated: datetime.datetime | None
updated: datetime.datetime | None = None


# Proof
Expand Down
8 changes: 4 additions & 4 deletions app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ def create_price_location(db: Session, price: PriceFull):
if price.location_osm_id and price.location_osm_type:
# get or create the corresponding location
location = LocationCreate(
osm_id=price.location_osm_id, osm_type=price.location_osm_type
)
db_location, created = crud.get_or_create_location(
db, location=location, init_price_count=1
osm_id=price.location_osm_id,
osm_type=price.location_osm_type,
price_count=1,
)
db_location, created = crud.get_or_create_location(db, location=location)
# link the location to the price
crud.set_price_location(db, price=price, location=db_location)
# fetch data from OpenStreetMap if created
Expand Down
49 changes: 32 additions & 17 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from app.db import Base, engine, get_db, session
from app.models import Session as SessionModel
from app.schemas import (
LocationCreate,
LocationFull,
PriceCreate,
ProductCreate,
ProofFilter,
UserCreate,
)

Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)


Expand Down Expand Up @@ -79,22 +80,36 @@ def override_get_db():
source="off",
unique_scans_n=0,
)
LOCATION = LocationCreate(osm_id=3344841823, osm_type="NODE")
LOCATION_1 = LocationCreate(
LOCATION = LocationFull(
id=1, osm_id=3344841823, osm_type="NODE", created=datetime.datetime.now()
)
LOCATION_1 = LocationFull(
id=2,
osm_id=652825274,
osm_type="NODE",
osm_name="Monoprix",
osm_address_postcode="38000",
osm_address_city="Grenoble",
osm_address_country="France",
osm_display_name="MMonoprix, Boulevard Joseph Vallier, Secteur 1, Grenoble, Isère, Auvergne-Rhône-Alpes, France métropolitaine, 38000, France",
osm_lat=45.1805534,
osm_lon=5.7153387,
created=datetime.datetime.now(),
updated=datetime.datetime.now(),
)
LOCATION_2 = LocationCreate(
LOCATION_2 = LocationFull(
id=3,
osm_id=6509705997,
osm_type="NODE",
osm_name="Carrefour",
osm_address_postcode="1000",
osm_address_city="Bruxelles - Brussel",
osm_address_country="België / Belgique / Belgien",
osm_display_name="Carrefour à Bruxelles",
osm_lat=1,
osm_lon=2,
created=datetime.datetime.now(),
updated=datetime.datetime.now(),
)
PRICE_1 = PriceCreate(
product_code="8001505005707",
Expand Down Expand Up @@ -1095,28 +1110,28 @@ def test_get_locations_pagination(clean_locations):
assert key in response.json()


# def test_get_locations_filters(db_session, clean_locations):
# crud.create_location(db_session, LOCATION_1)
# crud.create_location(db_session, LOCATION_2)
def test_get_locations_filters(db_session):
crud.create_location(db_session, LOCATION_1)
crud.create_location(db_session, LOCATION_2)

# assert len(crud.get_locations(db_session)) == 2
assert len(crud.get_locations(db_session)) == 2

# # 1 location Monoprix
# response = client.get("/api/v1/locations?osm_name__like=Monoprix")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 1
# # 1 location in France
# response = client.get("/api/v1/locations?osm_address_country__like=France") # noqa
# assert response.status_code == 200
# assert len(response.json()["items"]) == 1
# 1 location Monoprix
response = client.get("/api/v1/locations?osm_name__like=Monoprix")
assert response.status_code == 200
assert len(response.json()["items"]) == 1
# 1 location in France
response = client.get("/api/v1/locations?osm_address_country__like=France") # noqa
assert response.status_code == 200
assert len(response.json()["items"]) == 1


def test_get_location(location):
# by id: location exists
response = client.get(f"/api/v1/locations/{location.id}")
assert response.status_code == 200
# by id: location does not exist
response = client.get(f"/api/v1/locations/{location.id + 1}")
response = client.get("/api/v1/locations/99999")
assert response.status_code == 404
# by osm id & type: location exists
response = client.get(
Expand Down

0 comments on commit f04a53f

Please sign in to comment.