diff --git a/testing/__init__.py b/testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/testing/util.py b/testing/util.py new file mode 100644 index 0000000..5d14c23 --- /dev/null +++ b/testing/util.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import sqlalchemy as sa +from data.models import Expense +from data.models import Item +from data.models import User +from sqlalchemy import orm + + +base_dir = Path(__file__).parent.parent / 'bot' + + +def fill_in_db( + session: orm.Session, + users_number: int = 1, + expenses_number: int = 1, + items_number: int = 1, + currencies: list = ['unit1'], + commit: bool = True, +): + for u in range(users_number): + user = User(id=u, first_name=f'user-{u}') + for i_num in range(items_number): + item_name = str(i_num) + item = Item(name=item_name) + user.items.append(item) + for unit in currencies: + for e_num in range(expenses_number): + expense = Expense(item_name=item_name, price=100, unit=unit) + user.expenses.append(expense) + session.add(user) + if commit: + session.commit() + return session + + +def get_random_user(session: orm.Session) -> User: + return session.scalars(sa.select(User)).first() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index fe67cad..aa688cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,17 +5,13 @@ import sqlalchemy as sa from data.db_session import make_alembic_config from data.models import Base -from data.models import Expense -from data.models import Item -from data.models import User from sqlalchemy import orm from sqlalchemy.engine import Engine +from testing.util import base_dir -base_dir = Path(__file__).parent.parent / 'bot' - -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') def engine(): with tempfile.TemporaryDirectory() as dir: tmpfile = Path(dir) / 'pytest.db' @@ -28,6 +24,8 @@ def engine(): def db(engine): _factory = orm.sessionmaker(engine, expire_on_commit=False) session: orm.Session = _factory() + Base.metadata.create_all(bind=engine) + yield session for table in reversed(Base.metadata.sorted_tables): session.execute(table.delete()) @@ -44,31 +42,3 @@ def alembic_config(engine: Engine): ) config.set_main_option('sqlalchemy.url', conn_str) return config - - -@pytest.fixture(scope='function') -def session(): - with tempfile.TemporaryDirectory() as dir: - tmpfile = Path(dir) / 'money.db' - conn_str = f'sqlite:///{tmpfile}' - engine = sa.create_engine(conn_str) - Base.metadata.create_all(bind=engine) - - _factory = orm.sessionmaker(engine, expire_on_commit=False) - session: orm.Session = _factory() - item_name = 'Test' - user = User(id=0, first_name='test0') - item = Item(name=item_name) - user.items.append(item) - - for i in range(1, 4): - user = User(id=i, first_name=f'test{i}') - for _ in range(1, 3): - for unit in 'unit1', 'unit2': - expense = Expense(item_name=item_name, price=100, unit=unit) - user.expenses.append(expense) - session.add(user) - - session.commit() - - yield session diff --git a/tests/item_test.py b/tests/item_test.py index 6b7bf05..a66d3d6 100644 --- a/tests/item_test.py +++ b/tests/item_test.py @@ -4,6 +4,8 @@ from sqlalchemy import func from sqlalchemy import select +from testing.util import fill_in_db + @pytest.mark.parametrize( 'ids,item_name,expected', @@ -13,35 +15,37 @@ ([0, 1, 2], 'item3', 3), ] ) -def test_add_item(session, ids, item_name, expected): +def test_add_item(db, ids, item_name, expected): + fill_in_db(session=db, users_number=3, currencies=['unit1', 'unit2']) + for id_ in ids: item_service.add_item( item_name=item_name, user_id=id_, - session=session, + session=db, ) - cnt = session.scalar( + cnt = db.scalar( select(func.count(Item.name)).where(Item.name == item_name) ) assert cnt == expected -def test_add_item_one_user(session): +def test_add_item_one_user(db): item_name = 'add_item_test' item_service.add_item( item_name=item_name, user_id=1, - session=session, + session=db, ) item_service.add_item( item_name=item_name, user_id=1, - session=session, + session=db, ) - cnt = session.scalar( + cnt = db.scalar( select(func.count(Item.name)).where(Item.name == item_name) ) assert cnt == 1 diff --git a/tests/migration_test.py b/tests/migration_test.py index 286922b..2a64a30 100644 --- a/tests/migration_test.py +++ b/tests/migration_test.py @@ -4,7 +4,8 @@ from alembic.config import Config from alembic.script import Script from alembic.script import ScriptDirectory -from conftest import base_dir + +from testing.util import base_dir def get_revisions(): diff --git a/tests/rm_test.py b/tests/rm_test.py new file mode 100644 index 0000000..a3636b8 --- /dev/null +++ b/tests/rm_test.py @@ -0,0 +1,41 @@ +import pytest +from services import rm_service + +from testing.util import fill_in_db +from testing.util import get_random_user + + +def test_raise_wrong_category(db): + fill_in_db(session=db, users_number=2, items_number=0) + user = get_random_user(db) + + with pytest.raises(rm_service.WrongCategoryError): + rm_service.rm_empty_category( + user_id=user.id, + item_name='no such category', + session=db, + ) + + +def test_raise_not_empty_category(db): + fill_in_db(session=db) + user = get_random_user(db) + + with pytest.raises(rm_service.NotEmptyCategoryError): + rm_service.rm_empty_category( + user_id=user.id, + item_name=user.items[0].name, + session=db, + ) + + +def test_success_delete_empty_cat(db): + fill_in_db(session=db, users_number=10, items_number=1, expenses_number=0) + user = get_random_user(db) + + retval = rm_service.rm_empty_category( + user_id=user.id, + item_name=user.items[0].name, + session=db, + ) + assert retval == 0 diff --git a/tox.ini b/tox.ini index febd112..9f171e4 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] -env_list = py310,py311,py312,pre-commit +env_list = py310,py311,pre-commit [testenv] deps = -rrequirements.txt -rrequirements-dev.txt -commands = pytest +commands = pytest {posargs:tests} [testenv:pre-commit] description = Run pre-commit