Skip to content

Commit

Permalink
feat(db): add database migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed Jul 11, 2023
1 parent 906b6a4 commit fa1b608
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 1 deletion.
5 changes: 4 additions & 1 deletion server/db/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ import (
var (
// ErrDuplicateKey is a constraint violation error.
ErrDuplicateKey = errors.New("duplicate key value violates table constraint")

// ErrRecordNotFound is returned when a record is not found.
ErrRecordNotFound = errors.New("record not found")
)

// WrapError is a convenient function that unite various database driver
// errors to consistent errors.
func WrapError(err error) error {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return err
return ErrRecordNotFound
}
// Handle sqlite constraint error.
if liteErr, ok := err.(*sqlite.Error); ok {
Expand Down
64 changes: 64 additions & 0 deletions server/db/migrate/0001_create_tables.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package migrate

import (
"context"
"fmt"

"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/db"
"github.com/charmbracelet/soft-serve/server/sshutils"

Check failure on line 9 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / coverage

no required module provides package github.com/charmbracelet/soft-serve/server/sshutils; to add it:

Check failure on line 9 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / govulncheck

no required module provides package github.com/charmbracelet/soft-serve/server/sshutils; to add it:

Check failure on line 9 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / govulncheck

could not import github.com/charmbracelet/soft-serve/server/sshutils (invalid package name: "")

Check failure on line 9 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / build (^1, ubuntu-latest)

no required module provides package github.com/charmbracelet/soft-serve/server/sshutils; to add it:
"github.com/charmbracelet/soft-serve/server/store"

Check failure on line 10 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / coverage

no required module provides package github.com/charmbracelet/soft-serve/server/store; to add it:

Check failure on line 10 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / govulncheck

no required module provides package github.com/charmbracelet/soft-serve/server/store; to add it:

Check failure on line 10 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / govulncheck

could not import github.com/charmbracelet/soft-serve/server/store (invalid package name: "")

Check failure on line 10 in server/db/migrate/0001_create_tables.go

View workflow job for this annotation

GitHub Actions / build / build (^1, ubuntu-latest)

no required module provides package github.com/charmbracelet/soft-serve/server/store; to add it:
)

const (
createTablesName = "create tables"
createTablesVersion = 1
)

var createTables = Migration{
Version: createTablesVersion,
Name: createTablesName,
Migrate: func(ctx context.Context, tx *db.Tx) error {
cfg := config.FromContext(ctx)

if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
return err
}

// Insert default settings
insertSettings := "INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
insertSettings = tx.Rebind(insertSettings)
settings := []struct {
Key string
Value string
}{
{"allow_keyless", "true"},
{"anon_access", store.ReadOnlyAccess.String()},
{"init", "true"},
}

for _, s := range settings {
if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
}
}

// Insert default user
insertUser := tx.Rebind("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
return err
}

for _, k := range cfg.AdminKeys() {
ak := sshutils.MarshalAuthorizedKey(k)
if _, err := tx.ExecContext(ctx, "INSERT INTO public_key (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", 1, ak); err != nil {
return err
}
}

return nil
},
Rollback: func(ctx context.Context, tx *db.Tx) error {
return migrateDown(ctx, tx, createTablesVersion, createTablesName)
},
}
6 changes: 6 additions & 0 deletions server/db/migrate/0001_create_tables_sqlite.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

DROP TABLE IF EXISTS collab;
DROP TABLE IF EXISTS repo;
DROP TABLE IF EXISTS public_key;
DROP TABLE IF EXISTS user;
DROP TABLE IF EXISTS settings;
54 changes: 54 additions & 0 deletions server/db/migrate/0001_create_tables_sqlite.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
CREATE TABLE IF NOT EXISTS settings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT NOT NULL UNIQUE,
value TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL
);
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE,
admin BOOLEAN NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL
);
CREATE TABLE IF NOT EXISTS public_key (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
public_key TEXT NOT NULL UNIQUE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL,
UNIQUE (user_id, public_key),
CONSTRAINT user_id_fk
FOREIGN KEY(user_id) REFERENCES user(id)
ON DELETE CASCADE
ON UPDATE CASCADE
);
CREATE TABLE IF NOT EXISTS repo (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
project_name TEXT NOT NULL,
description TEXT NOT NULL,
private BOOLEAN NOT NULL,
mirror BOOLEAN NOT NULL,
hidden BOOLEAN NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL
);
CREATE TABLE IF NOT EXISTS collab (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
repo_id INTEGER NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL,
UNIQUE (user_id, repo_id),
CONSTRAINT user_id_fk
FOREIGN KEY(user_id) REFERENCES user(id)
ON DELETE CASCADE
ON UPDATE CASCADE,
CONSTRAINT repo_id_fk
FOREIGN KEY(repo_id) REFERENCES repo(id)
ON DELETE CASCADE
ON UPDATE CASCADE
);

140 changes: 140 additions & 0 deletions server/db/migrate/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package migrate

import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/db"
)

// MigrateFunc is a function that executes a migration.
type MigrateFunc func(ctx context.Context, tx *db.Tx) error

// Migration is a struct that contains the name of the migration and the
// function to execute it.
type Migration struct {
Version int
Name string
Migrate MigrateFunc
Rollback MigrateFunc
}

// Migrations is a database model to store migrations.
type Migrations struct {
ID int64 `db:"id"`
Name string `db:"name"`
Version int `db:"version"`
}

func (Migrations) schema(driverName string) string {
switch driverName {
case "sqlite3", "sqlite":
return `CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
version INTEGER NOT NULL
);
`
case "postgres":
return `CREATE TABLE IF NOT EXISTS migrations (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
version INTEGER NOT NULL
);
`
case "mysql":
return `CREATE TABLE IF NOT EXISTS migrations (
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
name TEXT NOT NULL,
version INT NOT NULL
);
`
default:
panic("unknown driver")
}
}

// Migrate runs the migrations.
func Migrate(ctx context.Context, dbx *db.DB) error {
logger := log.FromContext(ctx).WithPrefix("migrate")
return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
if !hasTable(tx, "migrations") {
if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
return err
}
}

var migrs Migrations
if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return err
}
}

for _, m := range migrations {
if m.Version <= migrs.Version {
continue
}

logger.Infof("running migration %d:%s", m.Version, m.Name)
if err := m.Migrate(ctx, tx); err != nil {
return err
}

if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
return err
}
}

return nil
})
}

// Rollback rolls back a migration.
func Rollback(ctx context.Context, dbx *db.DB) error {
logger := log.FromContext(ctx).WithPrefix("migrate")
return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
var migrs Migrations
if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("there are no migrations to rollback: %w", err)
}
}

if len(migrations) < migrs.Version {
return fmt.Errorf("there are no migrations to rollback")
}

m := migrations[migrs.Version-1]
logger.Infof("rolling back migration %d:%s", m.Version, m.Name)
if err := m.Rollback(ctx, tx); err != nil {
return err
}

if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE id = ?"), migrs.ID); err != nil {
return err
}

return nil
})
}

func hasTable(tx *db.Tx, tableName string) bool {
var query string
switch tx.DriverName() {
case "sqlite3", "sqlite":
query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
case "postgres":
fallthrough
case "mysql":
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
}

query = tx.Rebind(query)
var name string
err := tx.Get(&name, query, tableName)
return err == nil
}
62 changes: 62 additions & 0 deletions server/db/migrate/migrations.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package migrate

import (
"context"
"embed"
"fmt"
"regexp"
"strings"

"github.com/charmbracelet/soft-serve/server/db"
)

//go:embed *.sql
var sqls embed.FS

// Keep this in order of execution, oldest to newest.
var migrations = []Migration{
createTables,
}

func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error {
direction := "up"
if down {
direction = "down"
}

driverName := tx.DriverName()
if driverName == "sqlite3" {
driverName = "sqlite"
}

fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction)
sqlstr, err := sqls.ReadFile(fn)
if err != nil {
return err
}

if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil {
return err
}

return nil
}

func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error {
return execMigration(ctx, tx, version, name, false)
}

func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error {
return execMigration(ctx, tx, version, name, true)
}

var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")

func toSnakeCase(str string) string {
str = strings.ReplaceAll(str, "-", "_")
str = strings.ReplaceAll(str, " ", "_")
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
return strings.ToLower(snake)
}

0 comments on commit fa1b608

Please sign in to comment.