From b47cf3ecb772dcbc807c20e8e1c869d4f36816dd Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Thu, 31 Oct 2024 18:36:28 +0000 Subject: [PATCH] adds test for count; refactors filter in_zephir --- aim/digifeeds/database/crud.py | 51 ++++++++------------------- aim/digifeeds/database/main.py | 2 +- tests/digifeeds/database/test_crud.py | 13 +++++-- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/aim/digifeeds/database/crud.py b/aim/digifeeds/database/crud.py index fa4a434..91d9f91 100644 --- a/aim/digifeeds/database/crud.py +++ b/aim/digifeeds/database/crud.py @@ -24,25 +24,9 @@ def get_item(db: Session, barcode: str): return db.query(models.Item).filter(models.Item.barcode == barcode).first() -def get_item_total(db: Session, in_zephir: bool | None): - if in_zephir is True: - return ( - db.query(models.Item) - .filter( - models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .count() - ) - elif in_zephir is False: - return ( - db.query(models.Item) - .filter( - ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .count() - ) - - return db.query(models.Item).count() +def get_items_total(db: Session, in_zephir: bool | None): + query = get_items_query(db=db, in_zephir=in_zephir) + return query.count() def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): @@ -56,28 +40,21 @@ def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): Returns: _type_: _description_ """ + query = get_items_query(db=db, in_zephir=in_zephir) + return query.offset(offset).limit(limit).all() + + +def get_items_query(db: Session, in_zephir: bool | None): + query = db.query(models.Item) if in_zephir is True: - return ( - db.query(models.Item) - .filter( - models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .offset(offset) - .limit(limit) - .all() + query = query.filter( + models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) elif in_zephir is False: - return ( - db.query(models.Item) - .filter( - ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .offset(offset) - .limit(limit) - .all() + query = query.filter( + ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) - - return db.query(models.Item).offset(offset).limit(limit).all() + return query def add_item(db: Session, item: schemas.ItemCreate): diff --git a/aim/digifeeds/database/main.py b/aim/digifeeds/database/main.py index da12190..b7bf5ef 100644 --- a/aim/digifeeds/database/main.py +++ b/aim/digifeeds/database/main.py @@ -59,7 +59,7 @@ def get_items( return { "limit": limit, "offset": offset, - "total": crud.get_item_total(in_zephir=in_zephir, db=db), + "total": crud.get_items_total(in_zephir=in_zephir, db=db), "items": db_items, } diff --git a/tests/digifeeds/database/test_crud.py b/tests/digifeeds/database/test_crud.py index d6f6484..c9b06bd 100644 --- a/tests/digifeeds/database/test_crud.py +++ b/tests/digifeeds/database/test_crud.py @@ -5,6 +5,7 @@ get_status, get_statuses, add_item_status, + get_items_total, ) from aim.digifeeds.database.schemas import ItemCreate @@ -20,38 +21,44 @@ def test_get_item_that_does_not_exist(self, db_session): item_in_db = get_item(barcode="does not exist", db=db_session) assert (item_in_db) is None - def test_get_items_all(self, db_session): + def test_get_items_and_total_any(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=None, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=None) db_session.refresh(item1) db_session.refresh(item2) assert (items[0]) == item1 assert (items[1]) == item2 + assert (count) == 2 - def test_get_items_in_zephir(self, db_session): + def test_get_items_and_total_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=True, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=True) db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 assert (items[0]) == item1 + assert count == 1 - def test_get_items_not_in_zephir(self, db_session): + def test_get_items_and_total_not_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=False, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=False) db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 assert (items[0]) == item2 + assert count == 1 def test_get_status_that_exists(self, db_session): status = get_status(db=db_session, name="in_zephir")