Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/2024-unit-of-work'
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjan Egges committed May 1, 2024
2 parents d31e297 + 4a9c1fa commit 0623ab9
Show file tree
Hide file tree
Showing 4 changed files with 363 additions and 0 deletions.
15 changes: 15 additions & 0 deletions 2024/unit_of_work/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[tool.poetry]
name = "unit-of-work"
version = "0.1.0"
description = ""
authors = ["Your Name <[email protected]>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
SQLAlchemy = "^2.0.28"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
161 changes: 161 additions & 0 deletions 2024/unit_of_work/unit_of_work.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import contextlib
from typing import Optional

from sqlalchemy import ForeignKey, String, create_engine
from sqlalchemy.orm import (
Mapped,
declarative_base,
mapped_column,
relationship,
scoped_session,
sessionmaker,
)

DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(DATABASE_URL)
Session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base()


class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(
primary_key=True, index=True, unique=True, autoincrement=True
)
name: Mapped[str] = mapped_column()
user_detail: Mapped["UserDetail"] = relationship(
"UserDetail", back_populates="user", uselist=False, cascade="all, delete-orphan"
)
user_preference: Mapped["UserPreference"] = relationship(
"UserPreference",
back_populates="user",
uselist=False,
cascade="all, delete-orphan",
)

def __repr__(self) -> str:
return f"User(id={self.id!r}, name={self.name!r})"


class UserDetail(Base):
__tablename__ = "user_details"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True)
details: Mapped[str] = mapped_column()
user: Mapped["User"] = relationship("User", back_populates="user_detail")

def __repr__(self) -> str:
return f"UserDetail(id={self.id!r}, details={self.details!r})"


class UserPreference(Base):
__tablename__ = "user_preferences"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True)
preference: Mapped[str] = mapped_column(String)
user: Mapped["User"] = relationship("User", back_populates="user_preference")

def __repr__(self) -> str:
return f"UserPreference(id={self.id!r}, preference={self.preference!r})"


def create_user(
session: Session,
name: str,
details: Optional[str] = None,
preferences: Optional[str] = None,
) -> User:
user = User(name=name)
if details:
user.user_detail = UserDetail(details=details)
if preferences:
user.user_preference = UserPreference(preference=preferences)
session.add(user)
return user


def update_user(
session: Session,
user_id: int,
name: Optional[str] = None,
details: Optional[str] = None,
preferences: Optional[str] = None,
) -> User:
user: Optional[User] = session.get(User, user_id)
if not user:
raise ValueError(f"User {user_id} not found")
if name:
user.name = name
if details:
if user.user_detail:
user.user_detail.details = details
else:
user.user_detail = UserDetail(details=details)
if preferences:
if user.user_preference:
user.user_preference.preference = preferences
else:
user.user_preference = UserPreference(preference=preferences)
return user


def delete_user(session: Session, user_id: int) -> None:
user: Optional[User] = session.get(User, user_id)
if not user:
raise ValueError(f"User {user_id} not found")
session.delete(user)


def get_user(session: Session, user_id: int) -> User:
user = session.get(User, user_id)
if not user:
raise ValueError(f"User {user_id} not found")
return user


def get_users(session: Session) -> list[User]:
users: list[User] = session.query(User).all()
return users


Base.metadata.create_all(engine)


@contextlib.contextmanager
def unit():
session = Session()
try:
yield session
session.commit()
except Exception as e:
print("Rolling back")
session.rollback()
raise e
finally:
session.close()


def main() -> None:
try:
with unit() as session:
user = create_user(session, "Arjan", "details", "preferences")
print(user, user.user_detail, user.user_preference)
session.flush()
print(user)
user = update_user(
session, 1, "Arjan Updated", "more details", "updated preferences"
)
print(user, user.user_detail, user.user_preference)
user = get_user(session, user.id)
print(user)
delete_user(session, 1)
except ValueError as e:
print(e)

with unit() as session:
users = get_users(session)
print(users)


if __name__ == "__main__":
main()
73 changes: 73 additions & 0 deletions 2024/unit_of_work/unit_of_work_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from dataclasses import dataclass, field


@dataclass
class User:
username: str


@dataclass
class UnitOfWork:
new_users: list = field(default_factory=list)
dirty_users: list = field(default_factory=list)
removed_users: list = field(default_factory=list)

def register_new(self, user: User) -> None:
self.new_users.append(user)

def register_dirty(self, user: User) -> None:
if user not in self.dirty_users:
self.dirty_users.append(user)

def register_removed(self, user: User) -> None:
self.removed_users.append(user)

def commit(self) -> None:
self.insert_new()
self.update_dirty()
self.delete_removed()

def insert_new(self):
for obj in self.new_users:
# Insert new object into the database
print(f"Inserting {obj}")
# For demonstration, pretend we insert into a database here
self.new_users.clear()

def update_dirty(self):
for obj in self.dirty_users:
# Update existing object in the database
print(f"Updating {obj}")
# For demonstration, pretend we update in a database here
self.dirty_users.clear()

def delete_removed(self):
for obj in self.removed_users:
# Remove object from the database
print(f"Deleting {obj}")
# For demonstration, pretend we delete from a database here
self.removed_users.clear()


def main() -> None:
# Creating a new Unit of Work
uow = UnitOfWork()

# Creating a new user
new_user = User("john_doe")
uow.register_new(new_user)

# Simulate updating a user
existing_user = User("existing_user")
uow.register_dirty(existing_user)

# Simulate removing a user
removed_user = User("removed_user")
uow.register_removed(removed_user)

# Committing changes
uow.commit()


if __name__ == "__main__":
main()
114 changes: 114 additions & 0 deletions 2024/unit_of_work/unit_of_work_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import logging
import sqlite3
from typing import Any, Optional

logging.basicConfig(level=logging.INFO)


class DBConnectionHandler:
def __init__(self, db_name: str):
self.db_name: str = db_name
self.connection: Optional[sqlite3.Connection] = None

def __enter__(self) -> sqlite3.Connection:
self.connection = sqlite3.connect(self.db_name)
self.connection.row_factory = sqlite3.Row
return self.connection

def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[Any],
) -> None:
assert self.connection is not None
self.connection.close()


def create_tables() -> None:
with DBConnectionHandler("example.db") as connection:
cursor: sqlite3.Cursor = connection.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
quantity INTEGER NOT NULL
)
""")
connection.commit()


def drop_tables() -> None:
with DBConnectionHandler("example.db") as connection:
cursor: sqlite3.Cursor = connection.cursor()
cursor.execute("DROP TABLE IF EXISTS items")
connection.commit()


class Repository:
def __init__(self, connection: sqlite3.Connection):
self.connection: sqlite3.Connection = connection

def add(self, name: str, quantity: int) -> None:
cursor: sqlite3.Cursor = self.connection.cursor()
cursor.execute(
"INSERT INTO items (name, quantity) VALUES (?, ?)", (name, quantity)
)
self.connection.commit()

def all(self) -> list[dict[str, Any]]:
cursor: sqlite3.Cursor = self.connection.cursor()
cursor.execute("SELECT * FROM items")
return [dict(row) for row in cursor.fetchall()]


class UnitOfWork:
def __init__(self, db_name: str = "example.db"):
self.db_name: str = db_name
self.connection: Optional[sqlite3.Connection] = None
self.repository: Optional[Repository] = None

def __enter__(self) -> "UnitOfWork":
self.connection = sqlite3.connect(self.db_name)
self.connection.execute("BEGIN")
self.connection.row_factory = sqlite3.Row
self.repository = Repository(self.connection)
return self

def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[Any],
) -> None:
assert self.connection is not None
if exc_val:
self.connection.rollback()
else:
self.connection.commit()
self.connection.close()


def main() -> None:
create_tables()

try:
with UnitOfWork() as uow:
assert uow.repository is not None
uow.repository.add("Apple", 10)
uow.repository.add("Banana", 20)
# raise Exception("Something went wrong")
except Exception as e:
logging.error(f"Error during database operation: {e}")

with DBConnectionHandler("example.db") as connection:
repo: Repository = Repository(connection)
logging.info("Items in the database:")
for item in repo.all():
logging.info(item)

drop_tables()


if __name__ == "__main__":
main()

0 comments on commit 0623ab9

Please sign in to comment.