From 7660f274f38c27c8eefd5c4e3a50e41e74631728 Mon Sep 17 00:00:00 2001 From: "milinsoft@gmail.com" Date: Mon, 29 Apr 2024 11:22:08 -0400 Subject: [PATCH] Refactor -> Unit of Work pattern --- Dockerfile | 1 + __main__.py | 8 +- app/{interface => cli}/__init__.py | 0 app/cli/cli.py | 139 ++++++++++++++++++ app/database/database.py | 8 +- app/domain_classes/__init__.py | 2 + app/domain_classes/account_type.py | 6 + app/domain_classes/transaction_data.py | 44 ++++++ app/interface/cli.py | 117 --------------- app/models/__init__.py | 5 +- app/models/account.py | 27 ---- app/models/bank_app.py | 56 +++---- app/models/base.py | 3 - app/models/models.py | 56 +++++++ app/models/orm_independent.py | 16 -- app/models/transaction.py | 15 -- app/parsers/base.py | 40 ----- app/parsers/csv.py | 27 ++-- app/parsers/exceptions.py | 0 app/parsers/parse_strategy.py | 12 ++ app/parsers/transaction_parser.py | 23 +-- app/repositories/__init__.py | 2 +- app/repositories/account.py | 40 ++--- app/repositories/sql_alchemy/__init__.py | 2 - app/repositories/sql_alchemy/account.py | 43 ------ app/repositories/sql_alchemy/transaction.py | 65 -------- app/repositories/transaction.py | 34 +---- app/schemas/__init__.py | 3 + app/schemas/account.py | 12 ++ app/schemas/transaction.py | 14 ++ app/services/__init__.py | 2 + app/services/account_service.py | 66 +++++++++ app/services/transaction_service.py | 55 +++++++ app/utils/__init__.py | 3 + app/utils/repository.py | 104 +++++++++++++ app/utils/singleton.py | 18 +++ app/utils/unit_of_work.py | 48 ++++++ requirements.txt | 2 +- settings.py | 1 - tests/common.py | 74 ++++++---- tests/test_bank_app.py | 69 --------- tests/test_bank_app_cli.py | 85 +++++++++++ tests/test_file_parse_csv.py | 42 +++--- ...d_format.py => test_transaction_parser.py} | 6 +- 44 files changed, 811 insertions(+), 584 deletions(-) rename app/{interface => cli}/__init__.py (100%) create mode 100644 app/cli/cli.py create mode 100644 app/domain_classes/__init__.py create mode 100644 app/domain_classes/account_type.py create mode 100644 app/domain_classes/transaction_data.py delete mode 100644 app/interface/cli.py delete mode 100644 app/models/account.py delete mode 100644 app/models/base.py create mode 100644 app/models/models.py delete mode 100644 app/models/orm_independent.py delete mode 100644 app/models/transaction.py delete mode 100644 app/parsers/base.py delete mode 100644 app/parsers/exceptions.py create mode 100644 app/parsers/parse_strategy.py delete mode 100644 app/repositories/sql_alchemy/__init__.py delete mode 100644 app/repositories/sql_alchemy/account.py delete mode 100644 app/repositories/sql_alchemy/transaction.py create mode 100644 app/schemas/__init__.py create mode 100644 app/schemas/account.py create mode 100644 app/schemas/transaction.py create mode 100644 app/services/__init__.py create mode 100644 app/services/account_service.py create mode 100644 app/services/transaction_service.py create mode 100644 app/utils/__init__.py create mode 100644 app/utils/repository.py create mode 100644 app/utils/singleton.py create mode 100644 app/utils/unit_of_work.py delete mode 100644 tests/test_bank_app.py create mode 100644 tests/test_bank_app_cli.py rename tests/{test_file_parse_unsupported_format.py => test_transaction_parser.py} (57%) diff --git a/Dockerfile b/Dockerfile index 251a86b..0f57a2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ FROM alpine +ARG branch ENV GIT_REPO_URL=https://github.com/milinsoft/bank_app ENV PROJECT_FOLDER=/app/bank_app diff --git a/__main__.py b/__main__.py index bfee8bc..1957484 100644 --- a/__main__.py +++ b/__main__.py @@ -2,8 +2,7 @@ from logging import getLogger from app.database import Database -from app.interface import BankAppCli -from settings import DB_URL +from app.cli import BankAppCli _logger = getLogger(__name__) MAJOR = 3 @@ -18,10 +17,9 @@ def check_python_version(min_version=(MAJOR, MINOR)): if __name__ == "__main__": check_python_version() - db_session = Database(DB_URL).session - app = BankAppCli(db_session) + db = Database() + app = BankAppCli(db) try: app.main_menu() except KeyboardInterrupt: - db_session.close() print("\nGoodbye!") diff --git a/app/interface/__init__.py b/app/cli/__init__.py similarity index 100% rename from app/interface/__init__.py rename to app/cli/__init__.py diff --git a/app/cli/cli.py b/app/cli/cli.py new file mode 100644 index 0000000..9e7b537 --- /dev/null +++ b/app/cli/cli.py @@ -0,0 +1,139 @@ +from datetime import date, datetime +from logging import getLogger +from os.path import exists +from typing import Optional, Type, List, Union + +from sqlalchemy.exc import SQLAlchemyError +from tabulate import tabulate + +import settings +from app.database import Database +from app.models import BankApp, Transaction +from app.domain_classes import AccountType +from app.utils import UnitOfWork +from app.parsers import TransactionParser +from app.services import AccountService, TransactionService + +_logger = getLogger(__name__) + + +class BankAppCli(BankApp): + def __init__(self, db: Type[Database]) -> None: + self.acc_service = AccountService() + self.trx_service = TransactionService() + self.db = db + self.account_id: Optional[int] = None + self.uow: Optional[UnitOfWork] = UnitOfWork(db) # imported instance + self.parser = TransactionParser() + + self.menu_options = { + "0": ("Exit", self.exit_app), + "1": ("Import transactions (supported formats are: csv)", self.import_data), + "2": ("Show balance", self.show_balance), + "3": ("Search transactions for the a given period", self.search_transactions), + } + self.menu_msg = "\n".join(f"{k}: {v[0]}" for k, v in self.menu_options.items()) + + def main_menu(self) -> None: + self.pick_account() + while True: + choice = self.get_valid_action() + action = self.menu_options[choice][1] + action() + + def import_data(self): + try: + trx_data = self.parser.parse_data(self.get_file_path()) + _, balance = self.trx_service.create(self.uow, self.account_id, trx_data) + print(f"Transactions have been loaded successfully! Current balance: {balance}") + except (ValueError, SQLAlchemyError) as err: + _logger.error(err) + + @classmethod + def get_file_path(cls) -> str: + while True: + file_path = input("Please provide the path to your file: ").strip("'\"") + if not exists(file_path): + print("Incorrect file path, please try again!") + else: + return file_path + + def show_balance(self): + tar_get_date = self._get_date(mode="end_date") + print( + f"Your balance on {tar_get_date} is: ", + self.acc_service.get_balance(self.uow, self.account_id, tar_get_date), + ) + + def _search_transactions(self) -> List["Transaction"]: + return self.trx_service.get_by_date_range( + self.uow, self.account_id, self._get_date("start_date"), self._get_date("end_date") + ) + + def search_transactions(self) -> None: + transactions = self._search_transactions() + print(self._get_transaction_table(transactions) if transactions else "No transactions found!") + + def get_valid_action(self): + print("\nPICK AN OPTION: ") + action = False + while action not in self.menu_options.keys(): + action = input(f"{self.menu_msg}\n").strip() + return action + + @staticmethod + def _get_date(mode: str): + if mode not in (allowed_modes := ("start_date", "end_date")): + raise ValueError(f"Invalid mode: {mode}. Allowed modes are {allowed_modes}") + today_date = date.today() + + def compose__get_date_message() -> str: + date_example = datetime.strftime(today_date, settings.DATE_FORMAT) + action_description = ( + "search from the oldest transaction\n" + if mode == allowed_modes[0] + else "pick today's date by default!\n" + ) + return f"\nProvide the {mode} in the following {date_example} format or {action_description}\n press enter/return to {action_description}" + + msg = compose__get_date_message() + while True: + tar_get_date = input(msg).strip() + if not tar_get_date: + return datetime.min.date() if mode == allowed_modes[0] else today_date + try: + tar_get_date = datetime.strptime(tar_get_date, settings.DATE_FORMAT).date() + if tar_get_date > today_date: + tar_get_date = today_date + except ValueError: + _logger.error("Incorrect data format") + else: + return tar_get_date + + @staticmethod + def _get_account_type() -> Type[AccountType]: + acc_type: Union[Optional[str], Type[AccountType]] = None + while not acc_type: + acc_type = getattr( + AccountType, input("Pick an account: Debit or Credit (debit/credit): ").upper().strip(), "" + ) + return acc_type + + def pick_account(self) -> None: + acc_type = self._get_account_type() + existing_account = self.acc_service.get_by_type(self.uow, acc_type) + self.account_id = existing_account.id if existing_account else self.acc_service.create_one(self.uow, acc_type) + + @classmethod + def _get_transaction_table(cls, transactions: list[Transaction]) -> str: + """Return transactions in a tabular str format.""" + return tabulate( + [(t.date, t.description, t.amount) for t in transactions], + headers=["Date", "Description", "Amount"], + colalign=("left", "left", "right"), + tablefmt="pretty", + ) + + @staticmethod + def exit_app(): + exit(print("Goodbye!")) diff --git a/app/database/database.py b/app/database/database.py index c7b6548..13e46fd 100644 --- a/app/database/database.py +++ b/app/database/database.py @@ -3,9 +3,10 @@ import settings from app.models import Base +from app.utils import Singleton -class Database: +class Database(metaclass=Singleton): """DB connection abstraction. DB url format: ``dialect[+driver]://user:password@host/dbname[?key=value..]``, # pragma: allowlist secret @@ -19,9 +20,8 @@ class Database: def __init__(self, db_url: str = settings.DB_URL) -> None: self.db_url: str = db_url self.engine: Engine = create_engine(self.db_url) - self.session: Session = self.create_session() - - def create_session(self) -> Session: # Create tables if they don't exist Base.metadata.create_all(self.engine) + + def create_session(self) -> Session: return sessionmaker(bind=self.engine)() diff --git a/app/domain_classes/__init__.py b/app/domain_classes/__init__.py new file mode 100644 index 0000000..383cedb --- /dev/null +++ b/app/domain_classes/__init__.py @@ -0,0 +1,2 @@ +from .account_type import AccountType +from .transaction_data import TransactionData \ No newline at end of file diff --git a/app/domain_classes/account_type.py b/app/domain_classes/account_type.py new file mode 100644 index 0000000..00c553e --- /dev/null +++ b/app/domain_classes/account_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class AccountType(Enum): + CREDIT: int = 1 + DEBIT: int = 2 diff --git a/app/domain_classes/transaction_data.py b/app/domain_classes/transaction_data.py new file mode 100644 index 0000000..de5d66f --- /dev/null +++ b/app/domain_classes/transaction_data.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from datetime import date, datetime +from decimal import Decimal, InvalidOperation + +from typing import Optional, Type +import settings + +@dataclass +class TransactionData: + date: date + amount: Decimal + description: str + account_id: Optional[int] = None + + def __post_init__(self) -> None: + self._convert_str_to_date() + self._convert_str_to_decimal_amount() + self._check_description() + + def set_account_id(self, account_id: int) -> Type["TransactionData"]: + self.account_id = account_id + return self + + def _check_description(self) -> None: + if not self.description: + raise ValueError("Missing transaction description!") + + def _convert_str_to_date(self, date_format: str = settings.DATE_FORMAT) -> None: + try: + converted_date = datetime.strptime(self.date, date_format).date() + except ValueError: + raise ValueError(f"Wrong date format! Provided value: {self.date} Please use {date_format}") + if converted_date > date.today(): + raise ValueError("Transaction date is in the future!") + self.date = converted_date + + def _convert_str_to_decimal_amount(self, rounding: str = settings.ROUNDING) -> None: + try: + converted_amount = Decimal(self.amount).quantize(Decimal("0.00"), rounding=rounding) + if not converted_amount: + raise ValueError + except (ValueError, InvalidOperation): + raise ValueError("Incorrect transaction amount!") + self.amount = converted_amount diff --git a/app/interface/cli.py b/app/interface/cli.py deleted file mode 100644 index 5168054..0000000 --- a/app/interface/cli.py +++ /dev/null @@ -1,117 +0,0 @@ -from datetime import date, datetime -from logging import getLogger -from os.path import exists - -from sqlalchemy.exc import SQLAlchemyError -from tabulate import tabulate - -import settings -from app.models import BankApp, Transaction - -_logger = getLogger(__name__) - - -class BankAppCli(BankApp): - MENU_OPTIONS = { - '0': 'Exit', - '1': 'Import transactions (supported formats are: csv)', - '2': 'Show balance', - '3': 'Search transactions for the a given period', - } - MENU_MSG = '\n'.join((f'{k}: {v}' for k, v in MENU_OPTIONS.items())) - - def main_menu(self): - print('Welcome to the Bank App!') - self.current_account = self.pick_account() - actions = {'0': self.exit_app, '1': self.import_data, '2': self.show_balance, '3': self.search_transactions} - while True: - action = self.get_valid_action() - actions[action]() - - def import_data(self): - try: - self.parser.parse_data(self.get_file_path(), self.current_account) - print(f'Transactions have been loaded successfully! Current balance: {self.current_account.balance}') - except (ValueError, SQLAlchemyError) as err: - _logger.error(err) - - @classmethod - def get_file_path(cls) -> str: - while True: - file_path = input('Please provide the path to your file: ').strip("'\"") - if not exists(file_path): - print('Incorrect file path, please try again!') - else: - return file_path - - def show_balance(self): - target_date = self.get_date(mode='end_date') - print( - f'Your balance on {target_date} is: ', - self.account_repository.get_balance(self.current_account, target_date), - ) - - def search_transactions(self): - transactions = self.transaction_repository.get_by_date_range( - self.current_account, self.get_date('start_date'), self.get_date('end_date') - ) - if not transactions: - print('No transactions found!') - else: - self.display_transactions(transactions) - - @classmethod - def get_valid_action(cls): - print('\nWelcome to the main menu, how can I help you today?: ') - action = False - while action not in cls.MENU_OPTIONS.keys(): - action = input(f'{cls.MENU_MSG}\n').strip() - return action - - @staticmethod - def get_date(mode): - assert mode in (allowed_modes := ('start_date', 'end_date')), 'invalid mode' - today_date = date.today() - - msg = ( - f'\nProvide the {mode} in the following {datetime.strftime(today_date, settings.DATE_FORMAT)} format or\n' - 'press enter/return to ' - ) - msg += 'search from the oldest transaction\n' if mode == allowed_modes[0] else "pick today's date by default!\n" - - while True: - target_date = input(msg).strip() - if not target_date: - return datetime.min.date() if mode == allowed_modes[0] else today_date - try: - target_date = datetime.strptime(target_date, settings.DATE_FORMAT).date() - assert target_date <= today_date, 'Cannot lookup in the future! :)' - except ValueError: - _logger.error('Incorrect data format') - except AssertionError as err: - _logger.error(err) - else: - return target_date - - def pick_account(self): - acc = False - while acc not in ('d', 'c'): - acc = input('Pick an account: Debit or Credit (d/c): ').lower().strip() - return self.accounts['debit'] if acc == 'd' else self.accounts['credit'] - - @classmethod - def display_transactions(cls, transactions: list[Transaction]) -> None: - """Display transactions in a tabular format.""" - table_data = [(t.date, t.description, t.amount) for t in transactions] - print( - tabulate( - table_data, - headers=['Date', 'Description', 'Amount'], - colalign=('left', 'left', 'right'), - tablefmt='pretty', - ) - ) - - def exit_app(self): - self.session.close() - exit(print('Goodbye!')) diff --git a/app/models/__init__.py b/app/models/__init__.py index 34ab742..9e8ccea 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,5 +1,2 @@ -from .base import Base -from .account import Account +from .models import Base,Account, Transaction from .bank_app import BankApp -from .orm_independent import AccountType, TransactionData -from .transaction import Transaction diff --git a/app/models/account.py b/app/models/account.py deleted file mode 100644 index dbfc387..0000000 --- a/app/models/account.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlalchemy import CheckConstraint, Column, Integer, Numeric, String - -from .base import Base -from .orm_independent import AccountType - -if TYPE_CHECKING: - pass - - -class Account(Base): - __tablename__ = "account" - - id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String, nullable=False) - account_type = Column(Integer, nullable=False) - credit_limit = Column(Numeric(10, 2), default=0, comment="only for credit accounts") - balance = Column(Numeric(10, 2), default=0) - - __table_args__ = ( - CheckConstraint( - f"(account_type = '{AccountType.DEBIT.value}' AND credit_limit = 0 AND balance >= 0) OR " - f"(account_type = '{AccountType.CREDIT.value}' AND credit_limit < 0 AND balance >= credit_limit)", - name="Balance and credit limit constraints", - ), - ) diff --git a/app/models/bank_app.py b/app/models/bank_app.py index 9826534..de2a3e1 100644 --- a/app/models/bank_app.py +++ b/app/models/bank_app.py @@ -1,54 +1,32 @@ -from importlib import import_module -from typing import TYPE_CHECKING, Optional +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Annotated -import settings -from app.models.orm_independent import AccountType -from app.parsers import TransactionParser - -CREDIT = AccountType.CREDIT.value -DEBIT = AccountType.DEBIT.value - -orm = import_module(f"app.repositories.{settings.ORM}") +AccountID = Annotated[int, "Bank Account ID"] if TYPE_CHECKING: - from app.models import Account, Transaction - + from app.models import Transaction -class BankApp: - def __init__(self, session) -> None: - self.session = session - self.account_repository = orm.AccountRepository(self.session) - self.transaction_repository = orm.TransactionRepository(self.session) - self.accounts: dict[str, "Account"] = self.set_default_accounts() - self.current_account: Optional["Account"] = None - self.parser = TransactionParser(self.session, self.transaction_repository) - def main_menu(self): +class BankApp(ABC): + @abstractmethod + def main_menu(self) -> None: raise NotImplementedError - def show_balance(self): + @abstractmethod + def show_balance(self) -> None: raise NotImplementedError - def search_transactions(self): + @abstractmethod + def search_transactions(self) -> None: raise NotImplementedError - def pick_account(self): + @abstractmethod + def pick_account(self) -> AccountID: raise NotImplementedError + # TODO: consider refactoring, so it would be in a repo? + @classmethod - def display_transactions(cls, transactions: list["Transaction"]): + @abstractmethod + def _get_transaction_table(cls, transactions: list["Transaction"]) -> str: raise NotImplementedError - - def set_default_accounts(self) -> dict[str, "Account"]: - """Initialize default Debit and Credit Accounts if they don't exist.""" - debit_acc = self.account_repository.get_by_type(DEBIT) or self.account_repository.create( - name="Debit Account", account_type=DEBIT - ) - credit_acc = self.account_repository.get_by_type(CREDIT) or self.account_repository.create( - name="Credit Account", - account_type=CREDIT, - credit_limit=settings.DEFAULT_CREDIT_LIMIT, - ) - self.session.add_all([debit_acc, credit_acc]) - self.session.commit() - return {"credit": credit_acc, "debit": debit_acc} diff --git a/app/models/base.py b/app/models/base.py deleted file mode 100644 index 860e542..0000000 --- a/app/models/base.py +++ /dev/null @@ -1,3 +0,0 @@ -from sqlalchemy.ext.declarative import declarative_base - -Base = declarative_base() diff --git a/app/models/models.py b/app/models/models.py new file mode 100644 index 0000000..afbf729 --- /dev/null +++ b/app/models/models.py @@ -0,0 +1,56 @@ +from sqlalchemy import Column, Date, ForeignKey, Integer, Numeric, String, CheckConstraint +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship +from typing import Optional, Type +from app.domain_classes import AccountType +import settings +from app.schemas import AccountSchema, TransactionSchema + +Base = declarative_base() + +class Account(Base): + # TODO: update syntax to 2.0 version for models + __tablename__ = "account" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + account_type = Column(Integer, nullable=False) + credit_limit = Column(Numeric(10, 2), default=0, comment="only for credit accounts") + balance = Column(Numeric(10, 2), default=0) + + __table_args__ = ( + CheckConstraint( + f"(account_type = '{AccountType.DEBIT.value}' AND credit_limit = 0 AND balance >= 0) OR " + f"(account_type = '{AccountType.CREDIT.value}' AND credit_limit < 0 AND balance >= credit_limit)", + name="Balance and credit limit constraints", + ), + ) + + def to_read_model(self) -> AccountSchema: + return AccountSchema( + id=self.id, + name=self.name, + account_type=self.account_type, + credit_limit=self.credit_limit, + balance=self.balance, + ) + + +class Transaction(Base): + __tablename__ = "transaction" + + id = Column(Integer, primary_key=True, autoincrement=True) + date = Column(Date, nullable=False) + description = Column(String, nullable=False) + amount = Column(Numeric(10, 2), nullable=False) + account_id = Column(Integer, ForeignKey("account.id"), nullable=False) + account = relationship("Account", backref="Transaction", order_by=date) + + def to_read_model(self) -> TransactionSchema: + return TransactionSchema( + id=self.id, + date=self.date, + description=self.description, + amount=self.amount, + account_id=self.account_id, + ) diff --git a/app/models/orm_independent.py b/app/models/orm_independent.py deleted file mode 100644 index 7c54b5d..0000000 --- a/app/models/orm_independent.py +++ /dev/null @@ -1,16 +0,0 @@ -from dataclasses import dataclass -from datetime import date -from decimal import Decimal -from enum import Enum - - -@dataclass -class TransactionData: - date: date - amount: Decimal - description: str - - -class AccountType(Enum): - CREDIT = 1 - DEBIT = 2 diff --git a/app/models/transaction.py b/app/models/transaction.py deleted file mode 100644 index 43fc22f..0000000 --- a/app/models/transaction.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import Column, Date, ForeignKey, Integer, Numeric, String -from sqlalchemy.orm import relationship - -from .base import Base - - -class Transaction(Base): - __tablename__ = "transaction" - - id = Column(Integer, primary_key=True, autoincrement=True) - date = Column(Date, nullable=False) - description = Column(String, nullable=False) - amount = Column(Numeric(10, 2), nullable=False) - account_id = Column(Integer, ForeignKey("account.id"), nullable=False) - account = relationship("Account", backref="Transaction", order_by=date) diff --git a/app/parsers/base.py b/app/parsers/base.py deleted file mode 100644 index c6d497b..0000000 --- a/app/parsers/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Sequence -from datetime import date, datetime -from decimal import Decimal, InvalidOperation -from typing import TYPE_CHECKING - -import settings - -if TYPE_CHECKING: - from app.models import TransactionData - - -class ParseStrategy: - @classmethod - def parse_data(cls, file_path) -> Sequence["TransactionData"]: - raise NotImplementedError - - @classmethod - def _convert_date(cls, trx_date: str) -> date: - try: - converted_date = datetime.strptime(trx_date, settings.DATE_FORMAT).date() - except ValueError: - raise ValueError(f"wrong date format! Please use {settings.DATE_FORMAT}") - if converted_date > date.today(): - raise ValueError("Transaction date is in the future!") - return converted_date - - @classmethod - def _convert_amount(cls, trx_amount: str) -> Decimal: - try: - converted_amount = Decimal(trx_amount).quantize(Decimal("0.00"), rounding=settings.ROUNDING) - if not converted_amount: - raise ValueError - except (ValueError, InvalidOperation): - raise ValueError("Incorrect transaction amount!") - return converted_amount - - @classmethod - def _check_description(cls, description: str) -> None: - if not description: - raise ValueError("Missing transaction description!") diff --git a/app/parsers/csv.py b/app/parsers/csv.py index 62cd2a3..d780fc8 100644 --- a/app/parsers/csv.py +++ b/app/parsers/csv.py @@ -1,40 +1,39 @@ import csv -from collections.abc import Sequence +from typing import List -from app.models.orm_independent import TransactionData -from app.parsers.base import ParseStrategy +from app.domain_classes import TransactionData +from app.parsers.parse_strategy import AbstractParseStrategy # noinspection PyClassHasNoInit -class ParseCsv(ParseStrategy): +class ParseCsv(AbstractParseStrategy): EXPECTED_HEADER = ["date", "description", "amount"] - ROW_LENGTH = 3 + ROW_LENGTH = len(EXPECTED_HEADER) @classmethod - def parse_data(cls, file_path: str) -> Sequence[TransactionData]: - parsed_data = [] + def parse_data(cls, file_path: str) -> List[TransactionData]: with open(file_path, encoding="UTF-8-SIG") as f: csv_reader = csv.reader(f) cls._validate_header(next(csv_reader)) # Validate and skip the header row + parsed_data = [] for row_number, row in enumerate(csv_reader, start=1): try: parsed_data.append(cls._process_row(row)) except ValueError as err: raise ValueError(f"The row number {row_number}: {err}") - if not parsed_data: - raise ValueError("No data to import!") + if not parsed_data: + raise ValueError("No data to import!") return parsed_data @classmethod def _validate_header(cls, header): if [col_name.lower().strip() for col_name in header] != cls.EXPECTED_HEADER: - msg = "Incorrect header! Expected: " + ",".join(cls.EXPECTED_HEADER) - raise ValueError(msg) + expected_header = ",".join(cls.EXPECTED_HEADER) + raise ValueError(f"Incorrect header! Expected: {expected_header}") @classmethod - def _process_row(cls, row: Sequence[str]) -> TransactionData: + def _process_row(cls, row: List[str]) -> TransactionData: if row_len := len(row) != cls.ROW_LENGTH: raise ValueError(f"Incorrect number of elements. Expected: {cls.ROW_LENGTH} Found:{row_len}.") date_str, description, amount_str = row - cls._check_description(description) - return TransactionData(cls._convert_date(date_str), cls._convert_amount(amount_str), description) + return TransactionData(date_str, amount_str, description) diff --git a/app/parsers/exceptions.py b/app/parsers/exceptions.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/parsers/parse_strategy.py b/app/parsers/parse_strategy.py new file mode 100644 index 0000000..a2d719b --- /dev/null +++ b/app/parsers/parse_strategy.py @@ -0,0 +1,12 @@ +from typing import TYPE_CHECKING, Sequence +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + from app.domain_classes import TransactionData + + +class AbstractParseStrategy(ABC): + @classmethod + @abstractmethod + def parse_data(cls, file_path: str) -> Sequence["TransactionData"]: + raise NotImplementedError diff --git a/app/parsers/transaction_parser.py b/app/parsers/transaction_parser.py index c0e2915..461b5ca 100644 --- a/app/parsers/transaction_parser.py +++ b/app/parsers/transaction_parser.py @@ -1,12 +1,12 @@ -from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence from app.parsers.csv import ParseCsv -if TYPE_CHECKING: - from app.models import Account, TransactionData - from .csv import ParseStrategy + +if TYPE_CHECKING: + from app.domain_classes import TransactionData + from .csv import AbstractParseStrategy class TransactionParser: @@ -21,20 +21,11 @@ class TransactionParser: strategy_map = {"csv": ParseCsv} - def __init__(self, session, transaction_repository): - self.session = session - self.transaction_repository = transaction_repository - - def _get_strategy(self, file_path: str) -> type["ParseStrategy"] | None: + def _get_strategy(self, file_path: str) -> type["AbstractParseStrategy"] | None: return self.strategy_map.get(file_path.split(".")[-1].lower()) - def parse_data(self, file_path: str, current_account: "Account") -> Sequence["TransactionData"]: + def parse_data(self, file_path: str) -> Sequence["TransactionData"]: if not (parser := self._get_strategy(file_path)): raise ValueError("Unsupported file format.") data = parser.parse_data(file_path) - self._save_to_db(current_account, data) return data - - def _save_to_db(self, current_account, transaction_data: Sequence["TransactionData"]): - self.session.add_all(self.transaction_repository.create(current_account, transaction_data)) - self.session.commit() diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py index 88816c6..dcaeaae 100644 --- a/app/repositories/__init__.py +++ b/app/repositories/__init__.py @@ -1,2 +1,2 @@ from .account import AccountRepository -from .transaction import TransactionRepository +from . transaction import TransactionRepository diff --git a/app/repositories/account.py b/app/repositories/account.py index e2f2bfd..0e45389 100644 --- a/app/repositories/account.py +++ b/app/repositories/account.py @@ -1,30 +1,22 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING -if TYPE_CHECKING: - from datetime import date - from decimal import Decimal - - from app.models import Account +from app.models import Account +from app.utils import SqlAlchemyRepository -class AccountRepository(ABC): - @abstractmethod - def create(self, **kwargs) -> list["Account"]: - pass - - @abstractmethod - def get_by_type(self, account_type: str): - pass +if TYPE_CHECKING: + from decimal import Decimal - @abstractmethod - def get_all(self) -> list["Account"]: - pass - @abstractmethod - def get_by_id(self, account_id: int) -> Optional["Account"]: - pass +class AccountRepository(SqlAlchemyRepository): + model = Account - @abstractmethod - def get_balance(self, account: "Account", trx_date: Optional["date"] = None) -> Union[int, "Decimal"]: - pass + def update_balance(self, account_id: int, amount_to_add: "Decimal") -> "Decimal": + account = self.get_by_id(account_id) + new_balance = account.balance + amount_to_add + if new_balance < account.credit_limit: + raise ValueError( + f"\nImpossible to import data, as your account balance would go less than {account.credit_limit}" + ) + self.update({"balance": new_balance}, where=[self.model.id == account_id]) + return new_balance diff --git a/app/repositories/sql_alchemy/__init__.py b/app/repositories/sql_alchemy/__init__.py deleted file mode 100644 index 88816c6..0000000 --- a/app/repositories/sql_alchemy/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .account import AccountRepository -from .transaction import TransactionRepository diff --git a/app/repositories/sql_alchemy/account.py b/app/repositories/sql_alchemy/account.py deleted file mode 100644 index d4a647a..0000000 --- a/app/repositories/sql_alchemy/account.py +++ /dev/null @@ -1,43 +0,0 @@ -from datetime import date -from typing import TYPE_CHECKING, Union - -from app import repositories -from app.models.account import Account - -from .transaction import TransactionRepository - -if TYPE_CHECKING: - from decimal import Decimal - - from sqlalchemy.orm import Session - - -class AccountRepository(repositories.AccountRepository): - def __init__(self, session: "Session"): - self.session: "Session" = session - self.query = self.session.query(Account) - - def create(self, **kwargs) -> list[Account]: - new_account = Account(**kwargs) - self.session.add(new_account) - self.session.commit() - return new_account - - def get_by_type(self, account_type: str): - return self.query.filter_by(account_type=account_type).first() - - def get_all(self) -> list[Account]: - return self.query.all() - - def get_by_id(self, account_id: int) -> Account | None: - return self.query.filter(Account.id == account_id).first() - - def get_balance(self, account: Account, trx_date: date | None = None) -> Union[int, "Decimal"]: - if not trx_date: - return account.balance # type: ignore - # Incompatible return value type (got "Column[Decimal]", expected "int | Decimal") - if trx_date > date.today(): - raise ValueError("You cannot lookup in the future! :)") - transactions = TransactionRepository(self.session).get_by_date_range(account, end_date=trx_date) - return sum(t.amount for t in transactions) # type: ignore - # Generator has incompatible item type "Column[Decimal]"; expected "bool" diff --git a/app/repositories/sql_alchemy/transaction.py b/app/repositories/sql_alchemy/transaction.py deleted file mode 100644 index e5a17b7..0000000 --- a/app/repositories/sql_alchemy/transaction.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Sequence -from datetime import date, datetime -from typing import TYPE_CHECKING - -from sqlalchemy import desc -from sqlalchemy.sql.expression import and_ - -from app import repositories -from app.models.account import Account -from app.models.transaction import Transaction - -if TYPE_CHECKING: - from sqlalchemy.orm import Session - - from app.models import TransactionData - - -class TransactionRepository(repositories.TransactionRepository): - def __init__(self, session: "Session"): - self.session: "Session" = session - self.query = self.session.query(Transaction) - - def create(self, account: "Account", data: Sequence["TransactionData"]) -> list[Transaction]: - new_balance = account.balance + sum(_t.amount for _t in data) - if new_balance < account.credit_limit: - raise ValueError( - f"\nImpossible to import data, as your account balance would go less than {account.credit_limit}" - ) - new_transactions = [ - Transaction( - date=_t.date, - description=_t.description, - amount=_t.amount, - account_id=account.id, - ) - for _t in data - ] - # account.balance = new_balance - self.session.add_all(new_transactions) - self.session.commit() - account.balance = new_balance # type: ignore - # expression has type "ColumnElement[Decimal]", variable has type "Column[Decimal]" - return new_transactions - - def get_by_account(self, account_id: str): - return self.query.filter_by(account_id=account_id).all() - - def get_by_id(self, transaction_id: int) -> Transaction | None: - return self.query.filter(Transaction.id == transaction_id).first() - - def get_all(self) -> list[Transaction]: - return self.query.all() - - def get_by_date_range( - self, account: "Account", start_date: date | None = None, end_date: date | None = None - ) -> list["Transaction"]: - start_date = start_date or datetime.min.date() - end_date = end_date or date.today() - # Splitting filters for better readability - date_filter = Transaction.date.between(start_date, end_date) - account_filter = Transaction.account_id == account.id - - query = self.query.filter(and_(date_filter, account_filter)) - query = query.order_by(desc(Transaction.date)) - return query.all() diff --git a/app/repositories/transaction.py b/app/repositories/transaction.py index 98e6457..cc1d10c 100644 --- a/app/repositories/transaction.py +++ b/app/repositories/transaction.py @@ -1,32 +1,6 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional, Union +from app.models import Transaction +from app.utils import SqlAlchemyRepository -if TYPE_CHECKING: - from datetime import date - from app.models import Account, Transaction, TransactionData - - -class TransactionRepository(ABC): - @abstractmethod - def create(self, account: "Account", data: Sequence["TransactionData"]) -> list["Transaction"]: - pass - - @abstractmethod - def get_by_account(self, account_id: str): - pass - - @abstractmethod - def get_by_id(self, transaction_id: int) -> Optional["Transaction"]: - pass - - @abstractmethod - def get_all(self) -> list["Transaction"]: - pass - - @abstractmethod - def get_by_date_range( - self, account: "Account", start_date: Optional["date"] = None, end_date: Union["date", None] = None - ) -> list["Transaction"]: - pass +class TransactionRepository(SqlAlchemyRepository): + model = Transaction diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..0cd1b9c --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1,3 @@ +from . account import AccountSchema +from .transaction import TransactionSchema + diff --git a/app/schemas/account.py b/app/schemas/account.py new file mode 100644 index 0000000..605c5c2 --- /dev/null +++ b/app/schemas/account.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, PositiveInt +from decimal import Decimal + +class AccountSchema(BaseModel): + id: PositiveInt + name: str + account_type: PositiveInt + credit_limit: Decimal + balance: Decimal + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/schemas/transaction.py b/app/schemas/transaction.py new file mode 100644 index 0000000..bacbd00 --- /dev/null +++ b/app/schemas/transaction.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, PositiveInt +from datetime import date +from decimal import Decimal + + +class TransactionSchema(BaseModel): + id: PositiveInt + date: date + description: str + amount: Decimal + account_id: PositiveInt + + class Config: + from_attributes = True diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..42126d9 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,2 @@ +from .account_service import AccountService +from .transaction_service import TransactionService diff --git a/app/services/account_service.py b/app/services/account_service.py new file mode 100644 index 0000000..a0bc864 --- /dev/null +++ b/app/services/account_service.py @@ -0,0 +1,66 @@ +from datetime import date +from typing import TYPE_CHECKING, Optional, Union, Type, List +import settings +from sqlalchemy import func +from app.domain_classes import AccountType +from app.models import Transaction, Account +from app.schemas import AccountSchema +from .transaction_service import TransactionService + +if TYPE_CHECKING: + from decimal import Decimal + from app.utils import AbstractUnitOfWork + +# TODO: add abstract class + + +class AccountService: + def __init__(self): + self.trx_service = TransactionService() + + @classmethod + def get_balance( + cls, uow: "AbstractUnitOfWork", account_id: int, trx_date: Optional[date] = None + ) -> Union[int, "Decimal"]: + if trx_date: + if trx_date > date.today(): + raise ValueError("You cannot lookup in the future! :)") + else: + trx_date = date.today() + with uow: + result = uow.transactions.get_aggregated( + filters=[Transaction.date <= trx_date, Transaction.account_id == account_id], + aggregate_func=func.sum, + column_name="amount", + ) + return result[0] or 0 + + @classmethod + def create_one(cls, uow: "AbstractUnitOfWork", acc_type: Type["AccountType"], data: Optional[dict] = None) -> int: + data = data or {} + cls._add_defaults(data, acc_type) + with uow: + account = uow.account.create_one(data) + uow.commit() + return account + + @classmethod + def get_one(cls, uow: "AbstractUnitOfWork", filters=None, order_by=None) -> Type["AccountSchema"]: + with uow: + return uow.account.get_one(filters, order_by) + + def get_by_type(self, uow: "AbstractUnitOfWork", account_type: Type["AccountType"]) -> List[Type["AccountSchema"]]: + return self.get_one(uow, filters=[Account.account_type == account_type.value]) + + @classmethod + def get_by_id(cls, uow: "AbstractUnitOfWork", rec_id: int) -> "AccountSchema": + with uow: + return uow.account.get_by_id(rec_id) + + @staticmethod + def _add_defaults(params: dict, account_type: Type[AccountType]) -> None: + """Add default values for debit account.""" + params.setdefault("name", f"{account_type.name.capitalize()} Account") + params.setdefault("account_type", account_type.value) + if account_type == AccountType.CREDIT: + params.setdefault("credit_limit", settings.DEFAULT_CREDIT_LIMIT) diff --git a/app/services/transaction_service.py b/app/services/transaction_service.py new file mode 100644 index 0000000..436bb4c --- /dev/null +++ b/app/services/transaction_service.py @@ -0,0 +1,55 @@ +from datetime import date, datetime +from dataclasses import asdict +from typing import TYPE_CHECKING, Sequence, Tuple, Optional, List +from sqlalchemy import and_, desc + +if TYPE_CHECKING: + from decimal import Decimal + from app.utils import AbstractUnitOfWork + from app.domain_classes import TransactionData + from app.schemas import TransactionSchema + + +class TransactionService: + @classmethod + def get_by_date_range( + cls, uow: "AbstractUnitOfWork", account_id: int, start_date: date | None = None, end_date: date | None = None + ) -> List["TransactionSchema"]: + start_date = start_date or datetime.min.date() + end_date = end_date or date.today() + with uow: + trx = uow.transactions.model + res = uow.transactions.get_all( + filters=[and_(trx.date.between(start_date, end_date), trx.account_id == account_id)], + order_by=desc(trx.date), + ) + return res + + @classmethod + def create( + cls, uow: "AbstractUnitOfWork", account_id: int, data: Sequence["TransactionData"] + ) -> Tuple[list["TransactionSchema"], "Decimal"]: + amount_to_add = 0 + data_dicts = [] + for trx_data in data: + data_dicts.append(asdict(trx_data.set_account_id(account_id))) + amount_to_add += trx_data.amount + with uow: # single transaction + transactions = uow.transactions.create_multi(data_dicts) + new_balance = uow.account.update_balance(account_id=account_id, amount_to_add=amount_to_add) + uow.commit() + return transactions, new_balance + + @classmethod + def get_one( + cls, + uow: "AbstractUnitOfWork", + filters=None, + order_by=None, + aggregate_func=None, + column_name: Optional[str] = None, + ): + with uow: + return uow.transactions.get_one( + filters, order_by=order_by, aggregate_func=aggregate_func, column_name=column_name + ) diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..9776941 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,3 @@ +from .repository import AbstractRepository, SqlAlchemyRepository +from .unit_of_work import AbstractUnitOfWork, UnitOfWork +from .singleton import Singleton diff --git a/app/utils/repository.py b/app/utils/repository.py new file mode 100644 index 0000000..61dd7fd --- /dev/null +++ b/app/utils/repository.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from pyexpat import model +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Union, Type + +from sqlalchemy import func, insert, select, update +from sqlalchemy.orm import Session +from sqlalchemy.sql._typing import _ColumnExpressionOrStrLabelArgument +from sqlalchemy.sql.base import _NoArg + + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + from sqlalchemy.engine.result import _RowData + from pydantic import BaseModel + + # TODO: orderby types Union[ + # Literal[None, _NoArg.NO_ARG], + # _ColumnExpressionOrStrLabelArgument[Any], + # ] + # create type aliases +# FIXME: add test to choosing EXIT from the app menu + + +class AbstractRepository(ABC): + @abstractmethod + def create_one(self, data: dict) -> Type["BaseModel"]: + raise NotImplementedError + + @abstractmethod + def create_multi(self, data: Sequence[dict]) -> List[Type["BaseModel"]]: + raise NotImplementedError + + @abstractmethod + def get_one(self, filters=None, order_by=None) -> Type["BaseModel"]: + raise NotImplementedError + + @abstractmethod + def get_all( + self, + filters=None, + order_by: Union[ + Literal[None, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], + ] = None, + aggregate_func: Optional[func] = None, + column_name: Optional[str] = None, + ) -> List[Type["BaseModel"]]: + raise NotImplementedError + + @abstractmethod + def get_by_id(self, rec_id: int) -> Type[model]: + raise NotImplementedError + + +class SqlAlchemyRepository(AbstractRepository): + model: Any = None + + def __init__(self, session: Session) -> None: + self.session = session + + def create_one(self, data: dict) -> int: + stmt = insert(self.model).values(**data).returning(self.model.id) + res = self.session.execute(stmt).scalar_one() + return res + + def create_multi(self, data: Sequence[dict]) -> List[Type["BaseModel"]]: + stmt = insert(self.model).values(data).returning(self.model.id) + res = self.session.execute(stmt).scalars().all() + return res + + def _build_selectee(self, aggregate_function: Optional[func] = None, column_name: Optional[str] = None) -> Any: + # TODO: docstring + selectee = self.model + if column_name: + selectee = getattr(selectee, column_name) + if aggregate_function: + selectee = aggregate_function(selectee) + return selectee + + # FIXME add annotations for the filters + + def get_one(self, filters=None, order_by=None) -> Type["BaseModel"]: + # TODO: docstring, if not column_name -> return obj | or column value + stmt = select(self.model).filter(*filters).order_by(order_by) + result = self.session.execute(stmt).scalars().first() + return result and result.to_read_model() + + def get_all(self, filters=None, order_by=None) -> "List[BaseModel]": + stmt = select(self.model).filter(*filters).order_by(order_by) + result_rows = self.session.execute(stmt).scalars().fetchall() + return result_rows and [result.to_read_model() for result in result_rows] + + def get_aggregated(self, aggregate_func: func, column_name: str, filters=None, order_by=None) -> Any: + selectee = self._build_selectee(aggregate_func, column_name) + stmt = select(selectee).filter(*filters).order_by(order_by) + result = self.session.execute(stmt).scalars() + return result.all() + + def get_by_id(self, rec_id: int) -> Type["BaseModel"]: + return self.get_one(filters=[self.model.id == rec_id]) + + def update(self, data: dict, where=None) -> None: + stmt = update(self.model.__table__).values(data).where(*where) + self.session.execute(stmt) diff --git a/app/utils/singleton.py b/app/utils/singleton.py new file mode 100644 index 0000000..cddd579 --- /dev/null +++ b/app/utils/singleton.py @@ -0,0 +1,18 @@ +from threading import Lock +from typing import Type + + +class Singleton(type): + _instance = None + _lock = Lock() + + # FIXME: create lock file to make sure multi process safety + def __new__(cls, *args, **kwargs) -> Type["Singleton"]: + if not cls._instance: + with cls._lock: + # Another thread could have created the instance + # before the lock was acquired. So check that the + # instance is still nonexistent. + if not cls._instance: + cls._instance = super().__new__(cls, *args) + return cls._instance diff --git a/app/utils/unit_of_work.py b/app/utils/unit_of_work.py new file mode 100644 index 0000000..29e4a75 --- /dev/null +++ b/app/utils/unit_of_work.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Type, TYPE_CHECKING + +from app.repositories import AccountRepository, TransactionRepository + +if TYPE_CHECKING: + from app.database import Database + +class AbstractUnitOfWork(ABC): + account: Type[AccountRepository] + transactions: Type[TransactionRepository] # Type as class is accepted, not instance + + @abstractmethod + def __enter__(self): + raise NotImplementedError + + @abstractmethod + def __exit__(self, *args): + raise NotImplementedError + + @abstractmethod + def commit(self): + raise NotImplementedError + + @abstractmethod + def rollback(self): + raise NotImplementedError + + +class UnitOfWork(AbstractUnitOfWork): + + def __init__(self, database: Type["Database"]) -> None: + self.database = database + + def __enter__(self) -> None: + self.session = self.database.create_session() + self.account = AccountRepository(self.session) + self.transactions = TransactionRepository(self.session) + + def __exit__(self, *args) -> None: + self.rollback() + self.session.close() + + def commit(self) -> None: + self.session.commit() + + def rollback(self) -> None: + self.session.rollback() diff --git a/requirements.txt b/requirements.txt index 943f3cc..e0d5edf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ SQLAlchemy==2.0.20 tabulate==0.9.0 -sqlalchemy +pydantic==2.7.1 diff --git a/settings.py b/settings.py index 9743404..31f0329 100644 --- a/settings.py +++ b/settings.py @@ -5,4 +5,3 @@ TEST_DB_URL = 'sqlite:///test_bank_app.database' DEFAULT_CREDIT_LIMIT = -3000 ROUNDING = ROUND_HALF_UP -ORM = "sql_alchemy" diff --git a/tests/common.py b/tests/common.py index cda7d8a..20f46a6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,3 +1,4 @@ +import os import pathlib import unittest from decimal import Decimal @@ -5,59 +6,78 @@ import settings from app.database import Database -from app.interface.cli import BankAppCli -from app.models import Account, Base -from app.parsers import TransactionParser +from app.cli import BankAppCli +from app.models import Base +from app.domain_classes import AccountType +from app.repositories import AccountRepository, TransactionRepository +from typing import TYPE_CHECKING, List, Optional, Type, Tuple + +if TYPE_CHECKING: + from app.domain_classes import TransactionData + from datetime import date + from app.models import BankApp + from sqlalchemy import Engine + # Done this way to make sure CI works test_directory = pathlib.Path(__file__).parent.absolute().resolve() / "data" # Paths to data folders -correct_test_files_dir = f'{test_directory}/correct' -incorrect_test_files_dir = f'{test_directory}/incorrect' +correct_test_files_dir = f"{test_directory}/correct" +incorrect_test_files_dir = f"{test_directory}/incorrect" _logger = getLogger(__name__) def convert_to_decimal(amount: float, rounding: str) -> Decimal: - return Decimal(amount).quantize(Decimal('0.00'), rounding=rounding) + return Decimal(amount).quantize(Decimal("0.00"), rounding=rounding) + + +# TODO move tests to APP folder class TestBankAppCommon(unittest.TestCase): """Test cases for the BankAppCli class.""" - # noinspection PyPep8Naming - # noinspection PyAttributeOutsideInit - def setUp(self): + # TODO: test change to SetUpClass + def setUp(self) -> None: """Set up the test environment.""" - self.db = Database(settings.TEST_DB_URL) - self.session = self.db.session - # TODO: split BankApp from BankAppCli - self.bank_app = BankAppCli(self.session) - self.credit_acc, self.debit_acc = self.bank_app.accounts.values() - self.parser = TransactionParser(self.session, self.bank_app.transaction_repository) - - def tearDown(self): + self.db: Type["Database"] = Database(settings.TEST_DB_URL) + self.bank_app: Type["BankApp"] = BankAppCli(self.db) + self.credit_acc_id: int = self.bank_app.acc_service.create_one(self.bank_app.uow, AccountType.DEBIT) + self.debit_acc_id: int = self.bank_app.acc_service.create_one(self.bank_app.uow, AccountType.CREDIT) + + def tearDown(self) -> None: Base.metadata.drop_all(self.db.engine) # noinspection PyPep8Naming @classmethod def tearDownClass(cls) -> None: """Remove sqlite database (if applicable) after all tests run.""" - import os - try: - os.remove(settings.TEST_DB_URL.replace('sqlite:///', '')) + os.remove(settings.TEST_DB_URL.replace("sqlite:///", "")) except FileNotFoundError: pass except OSError as e: _logger.error(e) # HELPER METHODS - def _test_account_limit(self, file_path: str, account: Account, expect_failure: bool): - if expect_failure: - with self.assertRaises(ValueError): - self.parser.parse_data(file_path, account) + def parse_data(self, file_path: str) -> List["TransactionData"]: + return self.bank_app.parser.parse_data(file_path) + + def get_balance(self, account_id: int, trx_date: Optional["date"] = None) -> Decimal: + return self.bank_app.acc_service.get_balance(self.bank_app.uow, account_id, trx_date) + + def create_transactions(self, account_id: int, data: "TransactionData") -> List["Transaction"]: + return self.bank_app.trx_service.create(self.bank_app.uow, account_id, data) + + def _test_credit_limit( + self, account_id: int, transactions: List["Transaction"], expect_error: bool + ) -> Tuple[List[int], Decimal]: + if expect_error: + with self.assertRaisesRegex( + ValueError, "Impossible to import data, as your account balance would go less than" + ): + self.create_transactions(account_id, transactions) else: - data = self.parser.parse_data(file_path, account) - self.assertEqual(len(data), 1) - self.assertEqual(data[0].amount, -3000) + transaction_ids, balance = self.create_transactions(account_id, transactions) + return transaction_ids, balance diff --git a/tests/test_bank_app.py b/tests/test_bank_app.py deleted file mode 100644 index f945fe7..0000000 --- a/tests/test_bank_app.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -from datetime import date, datetime, timedelta -from unittest.mock import patch - -import settings -from tests.common import TestBankAppCommon, convert_to_decimal, correct_test_files_dir - -# Test files with CORRECT data -TRANSACTIONS_1 = f'{correct_test_files_dir}/transactions_1.csv' -TRANSACTIONS_2 = f'{correct_test_files_dir}/transactions_2.csv' -TRANSACTIONS_3 = f'{correct_test_files_dir}/transactions_3.csv' - -TEST_DATES = [ - datetime.strptime(_d, settings.DATE_FORMAT).date() - for _d in ("2023-04-01", "2023-04-21", "2023-05-22", "2023-06-23", "2023-07-23", "2023-08-23", "2023-08-24") -] -TEST_AMOUNTS = [ - convert_to_decimal(charge, settings.ROUNDING) - for charge in (100_000.00, 99_987.75, 99_967.5, 99_761.5, 99_761.51, 96_523.0, 296_523.0) -] - - -class TestBankApp(TestBankAppCommon): - def test_01_balance_on_date(self): - """Test account balance on specific dates.""" - self.parser.parse_data(TRANSACTIONS_1, self.debit_acc) - - for trx_date, trx_balance in zip(TEST_DATES, TEST_AMOUNTS): - self.assertEqual(trx_balance, self.bank_app.account_repository.get_balance(self.debit_acc, trx_date)) - - expected_balance = 0 - self.assertEqual(expected_balance, self.bank_app.account_repository.get_balance(self.debit_acc, date.min)) - - # test balance for future date - with self.assertRaisesRegex(ValueError, "You cannot lookup in the future!"): - self.bank_app.account_repository.get_balance(self.debit_acc, date.today() + timedelta(days=1)) - - def test_02_debit_acc_negative_balance(self): - """Test handling negative balance in a Debit account.""" - self._test_account_limit(TRANSACTIONS_2, self.debit_acc, expect_failure=True) - - def test_03_credit_acc_limit(self): - """Test credit account limits.""" - self._test_account_limit(TRANSACTIONS_3, self.credit_acc, expect_failure=False) - self._test_account_limit(TRANSACTIONS_2, self.credit_acc, expect_failure=True) - - def test_04_transactions_lookup_by_range(self): - """Test transactions search by range""" - self.parser.parse_data(TRANSACTIONS_1, self.debit_acc) - - # TRANSACTIONS_1 contains 1 transaction per date, - # the number of found transactions should be increased by 1 by moving to the next date - for n, _date in enumerate(TEST_DATES, start=1): - self.assertEqual( - n, len(self.bank_app.transaction_repository.get_by_date_range(self.debit_acc, end_date=_date)) - ) - - # case where no transactions will be found - self.assertFalse(len(self.bank_app.transaction_repository.get_by_date_range(self.debit_acc, end_date=date.min))) - - @patch("builtins.input", side_effect=["clearly_non_existing_path", TRANSACTIONS_1, TRANSACTIONS_2]) - def test_05_get_file_path_with_retry(self, mock_input): - expected_result = TRANSACTIONS_1 - result = self.bank_app.get_file_path() - self.assertEqual(result, expected_result) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_bank_app_cli.py b/tests/test_bank_app_cli.py new file mode 100644 index 0000000..a555a32 --- /dev/null +++ b/tests/test_bank_app_cli.py @@ -0,0 +1,85 @@ +import unittest +from datetime import date, datetime, timedelta +from unittest.mock import patch + +import settings +from app.domain_classes import AccountType +from tests.common import TestBankAppCommon, convert_to_decimal, correct_test_files_dir + +# Test files with CORRECT data +TRANSACTIONS_1 = f"{correct_test_files_dir}/transactions_1.csv" +TRANSACTIONS_2 = f"{correct_test_files_dir}/transactions_2.csv" +TRANSACTIONS_3 = f"{correct_test_files_dir}/transactions_3.csv" + +TEST_DATES = ( + date.min, + date(2023, 4, 1), + date(2023, 4, 21), + date(2023, 5, 22), + date(2023, 6, 23), + date(2023, 7, 23), + date(2023, 8, 23), + date(2023, 8, 24), +) +TEST_AMOUNTS = [ + convert_to_decimal(charge, settings.ROUNDING) + for charge in (0, 100_000.00, 99_987.75, 99_967.5, 99_761.5, 99_761.51, 96_523.0, 296_523.0) +] + +class TestBankApp(TestBankAppCommon): + + def test_01_balance_on_date(self): + """Test account balance on specific dates.""" + # GIVEN + parsed_data = self.parse_data(TRANSACTIONS_1) + # WHEN + self.create_transactions(self.debit_acc_id, parsed_data) + # THEN + for trx_date, trx_balance in zip(TEST_DATES, TEST_AMOUNTS): + self.assertEqual(trx_balance, self.get_balance(self.debit_acc_id, trx_date)) + self.assertEqual(0, self.get_balance(self.debit_acc_id, date.min)) + + def test_02_balance_on_date_future_date(self): + """Test account balance with the future date.""" + with self.assertRaisesRegex(ValueError, "You cannot lookup in the future!"): + self.get_balance(self.debit_acc_id, date.today() + timedelta(days=1)) + + def test_03_debit_acc_negative_balance(self): + """Test handling negative balance in a Debit and Credit accounts.""" + # GIVEN + parsed_data = self.parse_data(TRANSACTIONS_2) + # WHEN/THEN + self._test_credit_limit(self.debit_acc_id, parsed_data, expect_error=True) + self._test_credit_limit(self.credit_acc_id, parsed_data, expect_error=True) + + def test_04_credit_acc_limit_no_exception(self): + # GIVEN + parsed_data = self.parse_data(TRANSACTIONS_3) + # WHEN + transaction_ids, balance = self._test_credit_limit(self.debit_acc_id, parsed_data, expect_error=False) + # THEN + self.assertEqual(transaction_ids, [1]) + self.assertEqual(balance, -3000) + + def test_05_transactions_lookup_by_range(self): + """Test transactions search by range""" + # GIVEN + parsed_data = self.parse_data(TRANSACTIONS_1) + # WHEN + self.create_transactions(self.debit_acc_id, parsed_data) + # THEN + # TRANSACTIONS_1 contains 1 transaction per date, + # the number of found transactions should be increased by 1 by moving to the next date + for n, _date in enumerate(TEST_DATES): + self.assertEqual( + n, + len(self.bank_app.trx_service.get_by_date_range(self.bank_app.uow, self.debit_acc_id, end_date=_date)), + ) + + @patch("builtins.input", side_effect=["clearly_non_existing_path", TRANSACTIONS_1, TRANSACTIONS_2]) + def test_07_get_file_path_with_retry(self, mock_input): + result = self.bank_app.get_file_path() + self.assertEqual(result, TRANSACTIONS_1) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_file_parse_csv.py b/tests/test_file_parse_csv.py index ff8e7be..71803ab 100644 --- a/tests/test_file_parse_csv.py +++ b/tests/test_file_parse_csv.py @@ -1,50 +1,56 @@ from tests.common import TestBankAppCommon, correct_test_files_dir, incorrect_test_files_dir # Test files with INCORRECT data -INCORRECT_DATA_DATE_FORMAT = f'{incorrect_test_files_dir}/date_format.csv' -INCORRECT_DATA_EMPTY_AMOUNT = f'{incorrect_test_files_dir}/empty_amount.csv' -INCORRECT_DATA_ZERO_AMOUNT = f'{incorrect_test_files_dir}/zero_amount.csv' -INCORRECT_DATA_EMPTY_DESCRIPTION = f'{incorrect_test_files_dir}/empty_description.csv' -INCORRECT_DATA_HEADER_ONLY = f'{incorrect_test_files_dir}/header_only.csv' -INCORRECT_DATA_WRONG_HEADER = f'{incorrect_test_files_dir}/header.csv' +INCORRECT_DATA_DATE_FORMAT = f"{incorrect_test_files_dir}/date_format.csv" +INCORRECT_DATA_EMPTY_AMOUNT = f"{incorrect_test_files_dir}/empty_amount.csv" +INCORRECT_DATA_ZERO_AMOUNT = f"{incorrect_test_files_dir}/zero_amount.csv" +INCORRECT_DATA_EMPTY_DESCRIPTION = f"{incorrect_test_files_dir}/empty_description.csv" +INCORRECT_DATA_HEADER_ONLY = f"{incorrect_test_files_dir}/header_only.csv" +INCORRECT_DATA_WRONG_HEADER = f"{incorrect_test_files_dir}/header.csv" # Test files with CORRECT data -TRANSACTIONS_1 = f'{correct_test_files_dir}/transactions_1.csv' -TRANSACTIONS_2 = f'{correct_test_files_dir}/transactions_2.csv' -TRANSACTIONS_3 = f'{correct_test_files_dir}/transactions_3.csv' +TRANSACTIONS_1 = f"{correct_test_files_dir}/transactions_1.csv" +TRANSACTIONS_2 = f"{correct_test_files_dir}/transactions_2.csv" +TRANSACTIONS_3 = f"{correct_test_files_dir}/transactions_3.csv" class TestFileParseCSV(TestBankAppCommon): + def test_01_parse_transactions(self): - """Test parseing transactions and checking account balance.""" - self.parser.parse_data(TRANSACTIONS_1, self.debit_acc) - self.assertEqual(self.debit_acc.balance, 296_523.00) + """Test parsing transactions and checking account balance.""" + # GIVEN + parsed_data = self.parse_data(TRANSACTIONS_1) + # WHEN + self.create_transactions(self.debit_acc_id, parsed_data) + # THEN + acc_balance = self.get_balance(self.debit_acc_id) + self.assertEqual(acc_balance, 296_523.00) def test_02_parse_file_with_INCORRECT_date(self): """Test parse file with incorrect date format""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_DATE_FORMAT, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_DATE_FORMAT) def test_03_parse_file_without_amount(self): """Test parse file with skipped amount""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_EMPTY_AMOUNT, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_EMPTY_AMOUNT) def test_04_parse_file_without_description(self): """Test parse file without transaction description""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_EMPTY_DESCRIPTION, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_EMPTY_DESCRIPTION) def test_05_parse_transactions_header_only(self): """Test parse correct header, but missing transactions""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_HEADER_ONLY, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_HEADER_ONLY) def test_06_parse_transactions_wrong_header(self): """Test parse data with INCORRECT header""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_WRONG_HEADER, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_WRONG_HEADER) def test_07_parse_file_with_amount_zero(self): """Test parse file containing row with the amount 0""" with self.assertRaises(ValueError): - self.parser.parse_data(INCORRECT_DATA_ZERO_AMOUNT, self.debit_acc) + self.bank_app.parser.parse_data(INCORRECT_DATA_ZERO_AMOUNT) diff --git a/tests/test_file_parse_unsupported_format.py b/tests/test_transaction_parser.py similarity index 57% rename from tests/test_file_parse_unsupported_format.py rename to tests/test_transaction_parser.py index e02db04..b9b151e 100644 --- a/tests/test_file_parse_unsupported_format.py +++ b/tests/test_transaction_parser.py @@ -3,8 +3,8 @@ UNSUPPORTED_FILE_FORMAT = f'{incorrect_test_files_dir}/unsupported_file_format.jpeg' -class TestFileParseUnsupportedFormat(TestBankAppCommon): - def test_01_parse_unsupported_format(self): +class TestTransactionParser(TestBankAppCommon): + def test_01_parsing_strategy_unsupported_format(self): """Test import of unsupported file""" with self.assertRaises(ValueError): - self.parser.parse_data(UNSUPPORTED_FILE_FORMAT, self.debit_acc) + self.bank_app.parser.parse_data(UNSUPPORTED_FILE_FORMAT)