Skip to content

Commit

Permalink
adds test for count; refactors filter in_zephir
Browse files Browse the repository at this point in the history
  • Loading branch information
niquerio committed Oct 31, 2024
1 parent 7d91d99 commit b47cf3e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 41 deletions.
51 changes: 14 additions & 37 deletions aim/digifeeds/database/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,9 @@ def get_item(db: Session, barcode: str):
return db.query(models.Item).filter(models.Item.barcode == barcode).first()


def get_item_total(db: Session, in_zephir: bool | None):
if in_zephir is True:
return (
db.query(models.Item)
.filter(
models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)
.count()
)
elif in_zephir is False:
return (
db.query(models.Item)
.filter(
~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)
.count()
)

return db.query(models.Item).count()
def get_items_total(db: Session, in_zephir: bool | None):
query = get_items_query(db=db, in_zephir=in_zephir)
return query.count()


def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int):
Expand All @@ -56,28 +40,21 @@ def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int):
Returns:
_type_: _description_
"""
query = get_items_query(db=db, in_zephir=in_zephir)
return query.offset(offset).limit(limit).all()


def get_items_query(db: Session, in_zephir: bool | None):
query = db.query(models.Item)
if in_zephir is True:
return (
db.query(models.Item)
.filter(
models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)
.offset(offset)
.limit(limit)
.all()
query = query.filter(
models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)
elif in_zephir is False:
return (
db.query(models.Item)
.filter(
~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)
.offset(offset)
.limit(limit)
.all()
query = query.filter(
~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir")
)

return db.query(models.Item).offset(offset).limit(limit).all()
return query


def add_item(db: Session, item: schemas.ItemCreate):
Expand Down
2 changes: 1 addition & 1 deletion aim/digifeeds/database/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_items(
return {
"limit": limit,
"offset": offset,
"total": crud.get_item_total(in_zephir=in_zephir, db=db),
"total": crud.get_items_total(in_zephir=in_zephir, db=db),
"items": db_items,
}

Expand Down
13 changes: 10 additions & 3 deletions tests/digifeeds/database/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
get_status,
get_statuses,
add_item_status,
get_items_total,
)
from aim.digifeeds.database.schemas import ItemCreate

Expand All @@ -20,38 +21,44 @@ def test_get_item_that_does_not_exist(self, db_session):
item_in_db = get_item(barcode="does not exist", db=db_session)
assert (item_in_db) is None

def test_get_items_all(self, db_session):
def test_get_items_and_total_any(self, db_session):
item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode"))
item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2"))
status = get_status(db=db_session, name="in_zephir")
add_item_status(db=db_session, item=item1, status=status)
items = get_items(db=db_session, in_zephir=None, limit=2, offset=0)
count = get_items_total(db=db_session, in_zephir=None)
db_session.refresh(item1)
db_session.refresh(item2)
assert (items[0]) == item1
assert (items[1]) == item2
assert (count) == 2

def test_get_items_in_zephir(self, db_session):
def test_get_items_and_total_in_zephir(self, db_session):
item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode"))
item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2"))
status = get_status(db=db_session, name="in_zephir")
add_item_status(db=db_session, item=item1, status=status)
items = get_items(db=db_session, in_zephir=True, limit=2, offset=0)
count = get_items_total(db=db_session, in_zephir=True)
db_session.refresh(item1)
db_session.refresh(item2)
assert (len(items)) == 1
assert (items[0]) == item1
assert count == 1

def test_get_items_not_in_zephir(self, db_session):
def test_get_items_and_total_not_in_zephir(self, db_session):
item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode"))
item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2"))
status = get_status(db=db_session, name="in_zephir")
add_item_status(db=db_session, item=item1, status=status)
items = get_items(db=db_session, in_zephir=False, limit=2, offset=0)
count = get_items_total(db=db_session, in_zephir=False)
db_session.refresh(item1)
db_session.refresh(item2)
assert (len(items)) == 1
assert (items[0]) == item2
assert count == 1

def test_get_status_that_exists(self, db_session):
status = get_status(db=db_session, name="in_zephir")
Expand Down

0 comments on commit b47cf3e

Please sign in to comment.