Skip to content

Commit

Permalink
Refactor -> Unit of Work pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
milinsoft committed May 12, 2024
1 parent 76f39eb commit 7660f27
Show file tree
Hide file tree
Showing 44 changed files with 811 additions and 584 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
FROM alpine

ARG branch
ENV GIT_REPO_URL=https://github.com/milinsoft/bank_app
ENV PROJECT_FOLDER=/app/bank_app

Expand Down
8 changes: 3 additions & 5 deletions __main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from logging import getLogger

from app.database import Database
from app.interface import BankAppCli
from settings import DB_URL
from app.cli import BankAppCli

_logger = getLogger(__name__)
MAJOR = 3
Expand All @@ -18,10 +17,9 @@ def check_python_version(min_version=(MAJOR, MINOR)):

if __name__ == "__main__":
check_python_version()
db_session = Database(DB_URL).session
app = BankAppCli(db_session)
db = Database()
app = BankAppCli(db)
try:
app.main_menu()
except KeyboardInterrupt:
db_session.close()
print("\nGoodbye!")
File renamed without changes.
139 changes: 139 additions & 0 deletions app/cli/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from datetime import date, datetime
from logging import getLogger
from os.path import exists
from typing import Optional, Type, List, Union

from sqlalchemy.exc import SQLAlchemyError
from tabulate import tabulate

import settings
from app.database import Database
from app.models import BankApp, Transaction
from app.domain_classes import AccountType
from app.utils import UnitOfWork
from app.parsers import TransactionParser
from app.services import AccountService, TransactionService

_logger = getLogger(__name__)


class BankAppCli(BankApp):
def __init__(self, db: Type[Database]) -> None:
self.acc_service = AccountService()
self.trx_service = TransactionService()
self.db = db
self.account_id: Optional[int] = None
self.uow: Optional[UnitOfWork] = UnitOfWork(db) # imported instance
self.parser = TransactionParser()

self.menu_options = {
"0": ("Exit", self.exit_app),
"1": ("Import transactions (supported formats are: csv)", self.import_data),
"2": ("Show balance", self.show_balance),
"3": ("Search transactions for the a given period", self.search_transactions),
}
self.menu_msg = "\n".join(f"{k}: {v[0]}" for k, v in self.menu_options.items())

def main_menu(self) -> None:
self.pick_account()
while True:
choice = self.get_valid_action()
action = self.menu_options[choice][1]
action()

def import_data(self):
try:
trx_data = self.parser.parse_data(self.get_file_path())
_, balance = self.trx_service.create(self.uow, self.account_id, trx_data)
print(f"Transactions have been loaded successfully! Current balance: {balance}")
except (ValueError, SQLAlchemyError) as err:
_logger.error(err)

@classmethod
def get_file_path(cls) -> str:
while True:
file_path = input("Please provide the path to your file: ").strip("'\"")
if not exists(file_path):
print("Incorrect file path, please try again!")
else:
return file_path

def show_balance(self):
tar_get_date = self._get_date(mode="end_date")
print(
f"Your balance on {tar_get_date} is: ",
self.acc_service.get_balance(self.uow, self.account_id, tar_get_date),
)

def _search_transactions(self) -> List["Transaction"]:
return self.trx_service.get_by_date_range(
self.uow, self.account_id, self._get_date("start_date"), self._get_date("end_date")
)

def search_transactions(self) -> None:
transactions = self._search_transactions()
print(self._get_transaction_table(transactions) if transactions else "No transactions found!")

def get_valid_action(self):
print("\nPICK AN OPTION: ")
action = False
while action not in self.menu_options.keys():
action = input(f"{self.menu_msg}\n").strip()
return action

@staticmethod
def _get_date(mode: str):
if mode not in (allowed_modes := ("start_date", "end_date")):
raise ValueError(f"Invalid mode: {mode}. Allowed modes are {allowed_modes}")
today_date = date.today()

def compose__get_date_message() -> str:
date_example = datetime.strftime(today_date, settings.DATE_FORMAT)
action_description = (
"search from the oldest transaction\n"
if mode == allowed_modes[0]
else "pick today's date by default!\n"
)
return f"\nProvide the {mode} in the following {date_example} format or {action_description}\n press enter/return to {action_description}"

msg = compose__get_date_message()
while True:
tar_get_date = input(msg).strip()
if not tar_get_date:
return datetime.min.date() if mode == allowed_modes[0] else today_date
try:
tar_get_date = datetime.strptime(tar_get_date, settings.DATE_FORMAT).date()
if tar_get_date > today_date:
tar_get_date = today_date
except ValueError:
_logger.error("Incorrect data format")
else:
return tar_get_date

@staticmethod
def _get_account_type() -> Type[AccountType]:
acc_type: Union[Optional[str], Type[AccountType]] = None
while not acc_type:
acc_type = getattr(
AccountType, input("Pick an account: Debit or Credit (debit/credit): ").upper().strip(), ""
)
return acc_type

def pick_account(self) -> None:
acc_type = self._get_account_type()
existing_account = self.acc_service.get_by_type(self.uow, acc_type)
self.account_id = existing_account.id if existing_account else self.acc_service.create_one(self.uow, acc_type)

@classmethod
def _get_transaction_table(cls, transactions: list[Transaction]) -> str:
"""Return transactions in a tabular str format."""
return tabulate(
[(t.date, t.description, t.amount) for t in transactions],
headers=["Date", "Description", "Amount"],
colalign=("left", "left", "right"),
tablefmt="pretty",
)

@staticmethod
def exit_app():
exit(print("Goodbye!"))
8 changes: 4 additions & 4 deletions app/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import settings
from app.models import Base
from app.utils import Singleton


class Database:
class Database(metaclass=Singleton):
"""DB connection abstraction.
DB url format: ``dialect[+driver]://user:password@host/dbname[?key=value..]``, # pragma: allowlist secret
Expand All @@ -19,9 +20,8 @@ class Database:
def __init__(self, db_url: str = settings.DB_URL) -> None:
self.db_url: str = db_url
self.engine: Engine = create_engine(self.db_url)
self.session: Session = self.create_session()

def create_session(self) -> Session:
# Create tables if they don't exist
Base.metadata.create_all(self.engine)

def create_session(self) -> Session:
return sessionmaker(bind=self.engine)()
2 changes: 2 additions & 0 deletions app/domain_classes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .account_type import AccountType
from .transaction_data import TransactionData
6 changes: 6 additions & 0 deletions app/domain_classes/account_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class AccountType(Enum):
CREDIT: int = 1
DEBIT: int = 2
44 changes: 44 additions & 0 deletions app/domain_classes/transaction_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal, InvalidOperation

from typing import Optional, Type
import settings

@dataclass
class TransactionData:
date: date
amount: Decimal
description: str
account_id: Optional[int] = None

def __post_init__(self) -> None:
self._convert_str_to_date()
self._convert_str_to_decimal_amount()
self._check_description()

def set_account_id(self, account_id: int) -> Type["TransactionData"]:
self.account_id = account_id
return self

def _check_description(self) -> None:
if not self.description:
raise ValueError("Missing transaction description!")

def _convert_str_to_date(self, date_format: str = settings.DATE_FORMAT) -> None:
try:
converted_date = datetime.strptime(self.date, date_format).date()
except ValueError:
raise ValueError(f"Wrong date format! Provided value: {self.date} Please use {date_format}")
if converted_date > date.today():
raise ValueError("Transaction date is in the future!")
self.date = converted_date

def _convert_str_to_decimal_amount(self, rounding: str = settings.ROUNDING) -> None:
try:
converted_amount = Decimal(self.amount).quantize(Decimal("0.00"), rounding=rounding)
if not converted_amount:
raise ValueError
except (ValueError, InvalidOperation):
raise ValueError("Incorrect transaction amount!")
self.amount = converted_amount
117 changes: 0 additions & 117 deletions app/interface/cli.py

This file was deleted.

5 changes: 1 addition & 4 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
from .base import Base
from .account import Account
from .models import Base,Account, Transaction
from .bank_app import BankApp
from .orm_independent import AccountType, TransactionData
from .transaction import Transaction
Loading

0 comments on commit 7660f27

Please sign in to comment.