Skip to content

Commit

Permalink
feat(server): abstract database from backend
Browse files Browse the repository at this point in the history
Prepare for multi database driver support
  • Loading branch information
aymanbagabas committed Jul 10, 2023
1 parent 75172cf commit 6e95a80
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 116 deletions.
55 changes: 2 additions & 53 deletions server/backend/sqlite/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/charmbracelet/soft-serve/server/backend"
"github.com/jmoiron/sqlx"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
"github.com/charmbracelet/soft-serve/server/db"
)

// Close closes the database.
Expand All @@ -19,7 +16,7 @@ func (d *SqliteBackend) Close() error {

// init creates the database.
func (d *SqliteBackend) init() error {
return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
return d.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
if _, err := tx.Exec(sqlCreateSettingsTable); err != nil {
return err
}
Expand Down Expand Up @@ -91,51 +88,3 @@ func (d *SqliteBackend) init() error {
return nil
})
}

func wrapDbErr(err error) error {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrNoRecord
}
if liteErr, ok := err.(*sqlite.Error); ok {
code := liteErr.Code()
if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
code == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
return ErrDuplicateKey
}
}
}
return err
}

func wrapTx(db *sqlx.DB, ctx context.Context, fn func(tx *sqlx.Tx) error) error {
tx, err := db.BeginTxx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

if err := fn(tx); err != nil {
return rollback(tx, err)
}

if err := tx.Commit(); err != nil {
if errors.Is(err, sql.ErrTxDone) {
// this is ok because whoever did finish the tx should have also written the error already.
return nil
}
return fmt.Errorf("failed to commit transaction: %w", err)
}

return nil
}

func rollback(tx *sqlx.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
if errors.Is(rerr, sql.ErrTxDone) {
return err
}
return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
}

return err
}
16 changes: 8 additions & 8 deletions server/backend/sqlite/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

"github.com/charmbracelet/soft-serve/git"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/jmoiron/sqlx"
"github.com/charmbracelet/soft-serve/server/db"
)

var _ backend.Repository = (*Repo)(nil)
Expand All @@ -19,7 +19,7 @@ var _ backend.Repository = (*Repo)(nil)
type Repo struct {
name string
path string
db *sqlx.DB
db *db.DB

// cache
// updatedAt is cached in "last-modified" file.
Expand All @@ -42,7 +42,7 @@ func (r *Repo) Description() string {
}

var desc string
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", r.name)
}); err != nil {
return ""
Expand All @@ -63,7 +63,7 @@ func (r *Repo) IsMirror() bool {
}

var mirror bool
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", r.name)
}); err != nil {
return false
Expand All @@ -84,7 +84,7 @@ func (r *Repo) IsPrivate() bool {
}

var private bool
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", r.name)
}); err != nil {
return false
Expand Down Expand Up @@ -119,7 +119,7 @@ func (r *Repo) ProjectName() string {
}

var name string
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", r.name)
}); err != nil {
return ""
Expand All @@ -140,7 +140,7 @@ func (r *Repo) IsHidden() bool {
}

var hidden bool
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", r.name)
}); err != nil {
return false
Expand Down Expand Up @@ -170,7 +170,7 @@ func (r *Repo) UpdatedAt() time.Time {
}

if updatedAt.IsZero() {
if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
if err := r.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
return tx.Get(&updatedAt, "SELECT updated_at FROM repo WHERE name = ?", r.name)
}); err != nil {
return time.Time{}
Expand Down
62 changes: 31 additions & 31 deletions server/backend/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import (
"github.com/charmbracelet/soft-serve/git"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/db"
"github.com/charmbracelet/soft-serve/server/hooks"
"github.com/charmbracelet/soft-serve/server/utils"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite" // sqlite driver
)

Expand All @@ -25,7 +25,7 @@ type SqliteBackend struct { //nolint: revive
cfg *config.Config
ctx context.Context
dp string
db *sqlx.DB
db *db.DB
logger *log.Logger

// Repositories cache
Expand All @@ -46,7 +46,7 @@ func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) {
return nil, err
}

db, err := sqlx.Connect("sqlite", filepath.Join(dataPath, "soft-serve.db"+
db, err := db.Open("sqlite", filepath.Join(dataPath, "soft-serve.db"+
"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)"))
if err != nil {
return nil, err
Expand Down Expand Up @@ -85,7 +85,7 @@ func (d SqliteBackend) WithContext(ctx context.Context) backend.Backend {
// It implements backend.Backend.
func (d *SqliteBackend) AllowKeyless() bool {
var allow bool
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless")
}); err != nil {
return false
Expand All @@ -99,7 +99,7 @@ func (d *SqliteBackend) AllowKeyless() bool {
// It implements backend.Backend.
func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
var level string
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access")
}); err != nil {
return backend.NoAccess
Expand All @@ -112,8 +112,8 @@ func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
//
// It implements backend.Backend.
func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
return wrapDbErr(
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(
d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless")
return err
}),
Expand All @@ -124,8 +124,8 @@ func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
//
// It implements backend.Backend.
func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error {
return wrapDbErr(
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(
d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access")
return err
}),
Expand All @@ -144,7 +144,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
repo := name + ".git"
rp := filepath.Join(d.reposPath(), repo)

if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
if _, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at)
VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`,
name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden); err != nil {
Expand All @@ -160,7 +160,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
return nil
}); err != nil {
d.logger.Debug("failed to create repository in database", "err", err)
return nil, wrapDbErr(err)
return nil, db.WrapError(err)
}

r := &Repo{
Expand Down Expand Up @@ -226,7 +226,7 @@ func (d *SqliteBackend) DeleteRepository(name string) error {
repo := name + ".git"
rp := filepath.Join(d.reposPath(), repo)

return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
// Delete repo from cache
defer d.cache.Delete(name)

Expand Down Expand Up @@ -263,7 +263,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
return ErrRepoExist
}

if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
// Delete cache
defer d.cache.Delete(oldName)

Expand All @@ -283,7 +283,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {

return nil
}); err != nil {
return wrapDbErr(err)
return db.WrapError(err)
}

return nil
Expand All @@ -295,7 +295,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
repos := make([]backend.Repository, 0)

if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
rows, err := tx.Query("SELECT name FROM repo")
if err != nil {
return err
Expand Down Expand Up @@ -327,7 +327,7 @@ func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {

return nil
}); err != nil {
return nil, wrapDbErr(err)
return nil, db.WrapError(err)
}

return repos, nil
Expand All @@ -349,10 +349,10 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
}

var count int
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo)
}); err != nil {
return nil, wrapDbErr(err)
return nil, db.WrapError(err)
}

if count == 0 {
Expand Down Expand Up @@ -429,7 +429,7 @@ func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
// Delete cache
d.cache.Delete(repo)

return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
var count int
if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
return err
Expand Down Expand Up @@ -463,7 +463,7 @@ func (d *SqliteBackend) SetDescription(repo string, desc string) error {
// Delete cache
d.cache.Delete(repo)

return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
var count int
if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
return err
Expand All @@ -485,8 +485,8 @@ func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
// Delete cache
d.cache.Delete(repo)

return wrapDbErr(
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(
d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
var count int
if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
return err
Expand All @@ -509,8 +509,8 @@ func (d *SqliteBackend) SetProjectName(repo string, name string) error {
// Delete cache
d.cache.Delete(repo)

return wrapDbErr(
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(
d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
var count int
if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
return err
Expand All @@ -534,7 +534,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
}

repo = utils.SanitizeRepo(repo)
return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
_, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at)
VALUES (
(SELECT id FROM user WHERE username = ?),
Expand All @@ -552,13 +552,13 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
repo = utils.SanitizeRepo(repo)
var users []string
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
return tx.Select(&users, `SELECT user.username FROM user
INNER JOIN collab ON user.id = collab.user_id
INNER JOIN repo ON repo.id = collab.repo_id
WHERE repo.name = ?`, repo)
}); err != nil {
return nil, wrapDbErr(err)
return nil, db.WrapError(err)
}

return users, nil
Expand All @@ -570,13 +570,13 @@ func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) {
repo = utils.SanitizeRepo(repo)
var count int
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
if err := d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
return tx.Get(&count, `SELECT COUNT(*) FROM user
INNER JOIN collab ON user.id = collab.user_id
INNER JOIN repo ON repo.id = collab.repo_id
WHERE repo.name = ? AND user.username = ?`, repo, username)
}); err != nil {
return false, wrapDbErr(err)
return false, db.WrapError(err)
}

return count > 0, nil
Expand All @@ -587,8 +587,8 @@ func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, erro
// It implements backend.Backend.
func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error {
repo = utils.SanitizeRepo(repo)
return wrapDbErr(
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return db.WrapError(
d.db.TransactionContext(d.ctx, func(tx *db.Tx) error {
_, err := tx.Exec(`DELETE FROM collab
WHERE user_id = (SELECT id FROM user WHERE username = ?)
AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo)
Expand Down
Loading

0 comments on commit 6e95a80

Please sign in to comment.