-
Notifications
You must be signed in to change notification settings - Fork 146
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/2024-unit-of-work'
- Loading branch information
Showing
4 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[tool.poetry] | ||
name = "unit-of-work" | ||
version = "0.1.0" | ||
description = "" | ||
authors = ["Your Name <[email protected]>"] | ||
readme = "README.md" | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.10" | ||
SQLAlchemy = "^2.0.28" | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import contextlib | ||
from typing import Optional | ||
|
||
from sqlalchemy import ForeignKey, String, create_engine | ||
from sqlalchemy.orm import ( | ||
Mapped, | ||
declarative_base, | ||
mapped_column, | ||
relationship, | ||
scoped_session, | ||
sessionmaker, | ||
) | ||
|
||
DATABASE_URL = "sqlite:///:memory:" | ||
engine = create_engine(DATABASE_URL) | ||
Session = scoped_session(sessionmaker(bind=engine)) | ||
Base = declarative_base() | ||
|
||
|
||
class User(Base): | ||
__tablename__ = "users" | ||
id: Mapped[int] = mapped_column( | ||
primary_key=True, index=True, unique=True, autoincrement=True | ||
) | ||
name: Mapped[str] = mapped_column() | ||
user_detail: Mapped["UserDetail"] = relationship( | ||
"UserDetail", back_populates="user", uselist=False, cascade="all, delete-orphan" | ||
) | ||
user_preference: Mapped["UserPreference"] = relationship( | ||
"UserPreference", | ||
back_populates="user", | ||
uselist=False, | ||
cascade="all, delete-orphan", | ||
) | ||
|
||
def __repr__(self) -> str: | ||
return f"User(id={self.id!r}, name={self.name!r})" | ||
|
||
|
||
class UserDetail(Base): | ||
__tablename__ = "user_details" | ||
id: Mapped[int] = mapped_column(primary_key=True) | ||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True) | ||
details: Mapped[str] = mapped_column() | ||
user: Mapped["User"] = relationship("User", back_populates="user_detail") | ||
|
||
def __repr__(self) -> str: | ||
return f"UserDetail(id={self.id!r}, details={self.details!r})" | ||
|
||
|
||
class UserPreference(Base): | ||
__tablename__ = "user_preferences" | ||
id: Mapped[int] = mapped_column(primary_key=True) | ||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True) | ||
preference: Mapped[str] = mapped_column(String) | ||
user: Mapped["User"] = relationship("User", back_populates="user_preference") | ||
|
||
def __repr__(self) -> str: | ||
return f"UserPreference(id={self.id!r}, preference={self.preference!r})" | ||
|
||
|
||
def create_user( | ||
session: Session, | ||
name: str, | ||
details: Optional[str] = None, | ||
preferences: Optional[str] = None, | ||
) -> User: | ||
user = User(name=name) | ||
if details: | ||
user.user_detail = UserDetail(details=details) | ||
if preferences: | ||
user.user_preference = UserPreference(preference=preferences) | ||
session.add(user) | ||
return user | ||
|
||
|
||
def update_user( | ||
session: Session, | ||
user_id: int, | ||
name: Optional[str] = None, | ||
details: Optional[str] = None, | ||
preferences: Optional[str] = None, | ||
) -> User: | ||
user: Optional[User] = session.get(User, user_id) | ||
if not user: | ||
raise ValueError(f"User {user_id} not found") | ||
if name: | ||
user.name = name | ||
if details: | ||
if user.user_detail: | ||
user.user_detail.details = details | ||
else: | ||
user.user_detail = UserDetail(details=details) | ||
if preferences: | ||
if user.user_preference: | ||
user.user_preference.preference = preferences | ||
else: | ||
user.user_preference = UserPreference(preference=preferences) | ||
return user | ||
|
||
|
||
def delete_user(session: Session, user_id: int) -> None: | ||
user: Optional[User] = session.get(User, user_id) | ||
if not user: | ||
raise ValueError(f"User {user_id} not found") | ||
session.delete(user) | ||
|
||
|
||
def get_user(session: Session, user_id: int) -> User: | ||
user = session.get(User, user_id) | ||
if not user: | ||
raise ValueError(f"User {user_id} not found") | ||
return user | ||
|
||
|
||
def get_users(session: Session) -> list[User]: | ||
users: list[User] = session.query(User).all() | ||
return users | ||
|
||
|
||
Base.metadata.create_all(engine) | ||
|
||
|
||
@contextlib.contextmanager | ||
def unit(): | ||
session = Session() | ||
try: | ||
yield session | ||
session.commit() | ||
except Exception as e: | ||
print("Rolling back") | ||
session.rollback() | ||
raise e | ||
finally: | ||
session.close() | ||
|
||
|
||
def main() -> None: | ||
try: | ||
with unit() as session: | ||
user = create_user(session, "Arjan", "details", "preferences") | ||
print(user, user.user_detail, user.user_preference) | ||
session.flush() | ||
print(user) | ||
user = update_user( | ||
session, 1, "Arjan Updated", "more details", "updated preferences" | ||
) | ||
print(user, user.user_detail, user.user_preference) | ||
user = get_user(session, user.id) | ||
print(user) | ||
delete_user(session, 1) | ||
except ValueError as e: | ||
print(e) | ||
|
||
with unit() as session: | ||
users = get_users(session) | ||
print(users) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass | ||
class User: | ||
username: str | ||
|
||
|
||
@dataclass | ||
class UnitOfWork: | ||
new_users: list = field(default_factory=list) | ||
dirty_users: list = field(default_factory=list) | ||
removed_users: list = field(default_factory=list) | ||
|
||
def register_new(self, user: User) -> None: | ||
self.new_users.append(user) | ||
|
||
def register_dirty(self, user: User) -> None: | ||
if user not in self.dirty_users: | ||
self.dirty_users.append(user) | ||
|
||
def register_removed(self, user: User) -> None: | ||
self.removed_users.append(user) | ||
|
||
def commit(self) -> None: | ||
self.insert_new() | ||
self.update_dirty() | ||
self.delete_removed() | ||
|
||
def insert_new(self): | ||
for obj in self.new_users: | ||
# Insert new object into the database | ||
print(f"Inserting {obj}") | ||
# For demonstration, pretend we insert into a database here | ||
self.new_users.clear() | ||
|
||
def update_dirty(self): | ||
for obj in self.dirty_users: | ||
# Update existing object in the database | ||
print(f"Updating {obj}") | ||
# For demonstration, pretend we update in a database here | ||
self.dirty_users.clear() | ||
|
||
def delete_removed(self): | ||
for obj in self.removed_users: | ||
# Remove object from the database | ||
print(f"Deleting {obj}") | ||
# For demonstration, pretend we delete from a database here | ||
self.removed_users.clear() | ||
|
||
|
||
def main() -> None: | ||
# Creating a new Unit of Work | ||
uow = UnitOfWork() | ||
|
||
# Creating a new user | ||
new_user = User("john_doe") | ||
uow.register_new(new_user) | ||
|
||
# Simulate updating a user | ||
existing_user = User("existing_user") | ||
uow.register_dirty(existing_user) | ||
|
||
# Simulate removing a user | ||
removed_user = User("removed_user") | ||
uow.register_removed(removed_user) | ||
|
||
# Committing changes | ||
uow.commit() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import logging | ||
import sqlite3 | ||
from typing import Any, Optional | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
class DBConnectionHandler: | ||
def __init__(self, db_name: str): | ||
self.db_name: str = db_name | ||
self.connection: Optional[sqlite3.Connection] = None | ||
|
||
def __enter__(self) -> sqlite3.Connection: | ||
self.connection = sqlite3.connect(self.db_name) | ||
self.connection.row_factory = sqlite3.Row | ||
return self.connection | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[type], | ||
exc_val: Optional[Exception], | ||
exc_tb: Optional[Any], | ||
) -> None: | ||
assert self.connection is not None | ||
self.connection.close() | ||
|
||
|
||
def create_tables() -> None: | ||
with DBConnectionHandler("example.db") as connection: | ||
cursor: sqlite3.Cursor = connection.cursor() | ||
cursor.execute(""" | ||
CREATE TABLE IF NOT EXISTS items ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
name TEXT NOT NULL, | ||
quantity INTEGER NOT NULL | ||
) | ||
""") | ||
connection.commit() | ||
|
||
|
||
def drop_tables() -> None: | ||
with DBConnectionHandler("example.db") as connection: | ||
cursor: sqlite3.Cursor = connection.cursor() | ||
cursor.execute("DROP TABLE IF EXISTS items") | ||
connection.commit() | ||
|
||
|
||
class Repository: | ||
def __init__(self, connection: sqlite3.Connection): | ||
self.connection: sqlite3.Connection = connection | ||
|
||
def add(self, name: str, quantity: int) -> None: | ||
cursor: sqlite3.Cursor = self.connection.cursor() | ||
cursor.execute( | ||
"INSERT INTO items (name, quantity) VALUES (?, ?)", (name, quantity) | ||
) | ||
self.connection.commit() | ||
|
||
def all(self) -> list[dict[str, Any]]: | ||
cursor: sqlite3.Cursor = self.connection.cursor() | ||
cursor.execute("SELECT * FROM items") | ||
return [dict(row) for row in cursor.fetchall()] | ||
|
||
|
||
class UnitOfWork: | ||
def __init__(self, db_name: str = "example.db"): | ||
self.db_name: str = db_name | ||
self.connection: Optional[sqlite3.Connection] = None | ||
self.repository: Optional[Repository] = None | ||
|
||
def __enter__(self) -> "UnitOfWork": | ||
self.connection = sqlite3.connect(self.db_name) | ||
self.connection.execute("BEGIN") | ||
self.connection.row_factory = sqlite3.Row | ||
self.repository = Repository(self.connection) | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[type], | ||
exc_val: Optional[Exception], | ||
exc_tb: Optional[Any], | ||
) -> None: | ||
assert self.connection is not None | ||
if exc_val: | ||
self.connection.rollback() | ||
else: | ||
self.connection.commit() | ||
self.connection.close() | ||
|
||
|
||
def main() -> None: | ||
create_tables() | ||
|
||
try: | ||
with UnitOfWork() as uow: | ||
assert uow.repository is not None | ||
uow.repository.add("Apple", 10) | ||
uow.repository.add("Banana", 20) | ||
# raise Exception("Something went wrong") | ||
except Exception as e: | ||
logging.error(f"Error during database operation: {e}") | ||
|
||
with DBConnectionHandler("example.db") as connection: | ||
repo: Repository = Repository(connection) | ||
logging.info("Items in the database:") | ||
for item in repo.all(): | ||
logging.info(item) | ||
|
||
drop_tables() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |