From f04a53f6132905ff2f5c16b7bd9854992092885e Mon Sep 17 00:00:00 2001 From: Vincent Jousse Date: Mon, 4 Mar 2024 21:33:32 +0100 Subject: [PATCH] fix(tests): use LocationFull for fixtures (#240) --- app/crud.py | 14 +++------- app/schemas.py | 22 ++++++++-------- app/tasks.py | 8 +++--- tests/integration/test_api.py | 49 +++++++++++++++++++++++------------ 4 files changed, 51 insertions(+), 42 deletions(-) diff --git a/app/crud.py b/app/crud.py index 874e317e..5ca95215 100644 --- a/app/crud.py +++ b/app/crud.py @@ -502,9 +502,7 @@ def get_location_by_osm_id_and_type( ) -def create_location( - db: Session, location: LocationCreate, price_count: int = 0 -) -> Location: +def create_location(db: Session, location: LocationCreate) -> Location: """Create a location in the database. :param db: the database session @@ -513,16 +511,14 @@ def create_location( to 0 :return: the created location """ - db_location = Location(price_count=price_count, **location.model_dump()) + db_location = Location(**location.model_dump()) db.add(db_location) db.commit() db.refresh(db_location) return db_location -def get_or_create_location( - db: Session, location: LocationCreate, init_price_count: int = 0 -): +def get_or_create_location(db: Session, location: LocationCreate): """Get or create a location in the database. :param db: the database session @@ -537,9 +533,7 @@ def get_or_create_location( db, osm_id=location.osm_id, osm_type=location.osm_type ) if not db_location: - db_location = create_location( - db, location=location, price_count=init_price_count - ) + db_location = create_location(db, location=location) created = True return db_location, created diff --git a/app/schemas.py b/app/schemas.py index 20c82a2e..60862797 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -114,22 +114,22 @@ class LocationCreate(BaseModel): osm_id: int = Field(gt=0) osm_type: LocationOSMEnum + price_count: int = Field( + description="number of prices for this location.", examples=[15], default=0 + ) class LocationFull(LocationCreate): id: int - osm_name: str | None - osm_display_name: str | None - osm_address_postcode: str | None - osm_address_city: str | None - osm_address_country: str | None - osm_lat: float | None - osm_lon: float | None - price_count: int = Field( - description="number of prices for this location.", examples=[15], default=0 - ) + osm_name: str | None = None + osm_display_name: str | None = None + osm_address_postcode: str | None = None + osm_address_city: str | None = None + osm_address_country: str | None = None + osm_lat: float | None = None + osm_lon: float | None = None created: datetime.datetime - updated: datetime.datetime | None + updated: datetime.datetime | None = None # Proof diff --git a/app/tasks.py b/app/tasks.py index 14351cee..949ca45f 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -163,11 +163,11 @@ def create_price_location(db: Session, price: PriceFull): if price.location_osm_id and price.location_osm_type: # get or create the corresponding location location = LocationCreate( - osm_id=price.location_osm_id, osm_type=price.location_osm_type - ) - db_location, created = crud.get_or_create_location( - db, location=location, init_price_count=1 + osm_id=price.location_osm_id, + osm_type=price.location_osm_type, + price_count=1, ) + db_location, created = crud.get_or_create_location(db, location=location) # link the location to the price crud.set_price_location(db, price=price, location=db_location) # fetch data from OpenStreetMap if created diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 6f6524cd..350a6097 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -10,13 +10,14 @@ from app.db import Base, engine, get_db, session from app.models import Session as SessionModel from app.schemas import ( - LocationCreate, + LocationFull, PriceCreate, ProductCreate, ProofFilter, UserCreate, ) +Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) @@ -79,22 +80,36 @@ def override_get_db(): source="off", unique_scans_n=0, ) -LOCATION = LocationCreate(osm_id=3344841823, osm_type="NODE") -LOCATION_1 = LocationCreate( +LOCATION = LocationFull( + id=1, osm_id=3344841823, osm_type="NODE", created=datetime.datetime.now() +) +LOCATION_1 = LocationFull( + id=2, osm_id=652825274, osm_type="NODE", osm_name="Monoprix", osm_address_postcode="38000", osm_address_city="Grenoble", osm_address_country="France", + osm_display_name="MMonoprix, Boulevard Joseph Vallier, Secteur 1, Grenoble, Isère, Auvergne-Rhône-Alpes, France métropolitaine, 38000, France", + osm_lat=45.1805534, + osm_lon=5.7153387, + created=datetime.datetime.now(), + updated=datetime.datetime.now(), ) -LOCATION_2 = LocationCreate( +LOCATION_2 = LocationFull( + id=3, osm_id=6509705997, osm_type="NODE", osm_name="Carrefour", osm_address_postcode="1000", osm_address_city="Bruxelles - Brussel", osm_address_country="België / Belgique / Belgien", + osm_display_name="Carrefour à Bruxelles", + osm_lat=1, + osm_lon=2, + created=datetime.datetime.now(), + updated=datetime.datetime.now(), ) PRICE_1 = PriceCreate( product_code="8001505005707", @@ -1095,20 +1110,20 @@ def test_get_locations_pagination(clean_locations): assert key in response.json() -# def test_get_locations_filters(db_session, clean_locations): -# crud.create_location(db_session, LOCATION_1) -# crud.create_location(db_session, LOCATION_2) +def test_get_locations_filters(db_session): + crud.create_location(db_session, LOCATION_1) + crud.create_location(db_session, LOCATION_2) -# assert len(crud.get_locations(db_session)) == 2 + assert len(crud.get_locations(db_session)) == 2 -# # 1 location Monoprix -# response = client.get("/api/v1/locations?osm_name__like=Monoprix") -# assert response.status_code == 200 -# assert len(response.json()["items"]) == 1 -# # 1 location in France -# response = client.get("/api/v1/locations?osm_address_country__like=France") # noqa -# assert response.status_code == 200 -# assert len(response.json()["items"]) == 1 + # 1 location Monoprix + response = client.get("/api/v1/locations?osm_name__like=Monoprix") + assert response.status_code == 200 + assert len(response.json()["items"]) == 1 + # 1 location in France + response = client.get("/api/v1/locations?osm_address_country__like=France") # noqa + assert response.status_code == 200 + assert len(response.json()["items"]) == 1 def test_get_location(location): @@ -1116,7 +1131,7 @@ def test_get_location(location): response = client.get(f"/api/v1/locations/{location.id}") assert response.status_code == 200 # by id: location does not exist - response = client.get(f"/api/v1/locations/{location.id + 1}") + response = client.get("/api/v1/locations/99999") assert response.status_code == 404 # by osm id & type: location exists response = client.get(