-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
package migrate | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/charmbracelet/soft-serve/server/config" | ||
"github.com/charmbracelet/soft-serve/server/db" | ||
"github.com/charmbracelet/soft-serve/server/sshutils" | ||
Check failure on line 9 in server/db/migrate/0001_create_tables.go GitHub Actions / coverage
Check failure on line 9 in server/db/migrate/0001_create_tables.go GitHub Actions / build / govulncheck
Check failure on line 9 in server/db/migrate/0001_create_tables.go GitHub Actions / build / govulncheck
|
||
"github.com/charmbracelet/soft-serve/server/store" | ||
Check failure on line 10 in server/db/migrate/0001_create_tables.go GitHub Actions / coverage
Check failure on line 10 in server/db/migrate/0001_create_tables.go GitHub Actions / build / govulncheck
Check failure on line 10 in server/db/migrate/0001_create_tables.go GitHub Actions / build / govulncheck
|
||
) | ||
|
||
const ( | ||
createTablesName = "create tables" | ||
createTablesVersion = 1 | ||
) | ||
|
||
var createTables = Migration{ | ||
Version: createTablesVersion, | ||
Name: createTablesName, | ||
Migrate: func(ctx context.Context, tx *db.Tx) error { | ||
cfg := config.FromContext(ctx) | ||
|
||
if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil { | ||
return err | ||
} | ||
|
||
// Insert default settings | ||
insertSettings := "INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" | ||
insertSettings = tx.Rebind(insertSettings) | ||
settings := []struct { | ||
Key string | ||
Value string | ||
}{ | ||
{"allow_keyless", "true"}, | ||
{"anon_access", store.ReadOnlyAccess.String()}, | ||
{"init", "true"}, | ||
} | ||
|
||
for _, s := range settings { | ||
if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { | ||
return fmt.Errorf("inserting default settings %q: %w", s.Key, err) | ||
} | ||
} | ||
|
||
// Insert default user | ||
insertUser := tx.Rebind("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") | ||
if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil { | ||
return err | ||
} | ||
|
||
for _, k := range cfg.AdminKeys() { | ||
ak := sshutils.MarshalAuthorizedKey(k) | ||
if _, err := tx.ExecContext(ctx, "INSERT INTO public_key (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", 1, ak); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
}, | ||
Rollback: func(ctx context.Context, tx *db.Tx) error { | ||
return migrateDown(ctx, tx, createTablesVersion, createTablesName) | ||
}, | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
DROP TABLE IF EXISTS collab; | ||
DROP TABLE IF EXISTS repo; | ||
DROP TABLE IF EXISTS public_key; | ||
DROP TABLE IF EXISTS user; | ||
DROP TABLE IF EXISTS settings; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
CREATE TABLE IF NOT EXISTS settings ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
key TEXT NOT NULL UNIQUE, | ||
value TEXT NOT NULL, | ||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
updated_at DATETIME NOT NULL | ||
); | ||
CREATE TABLE IF NOT EXISTS user ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
username TEXT UNIQUE, | ||
admin BOOLEAN NOT NULL, | ||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
updated_at DATETIME NOT NULL | ||
); | ||
CREATE TABLE IF NOT EXISTS public_key ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
user_id INTEGER NOT NULL, | ||
public_key TEXT NOT NULL UNIQUE, | ||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
updated_at DATETIME NOT NULL, | ||
UNIQUE (user_id, public_key), | ||
CONSTRAINT user_id_fk | ||
FOREIGN KEY(user_id) REFERENCES user(id) | ||
ON DELETE CASCADE | ||
ON UPDATE CASCADE | ||
); | ||
CREATE TABLE IF NOT EXISTS repo ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
name TEXT NOT NULL UNIQUE, | ||
project_name TEXT NOT NULL, | ||
description TEXT NOT NULL, | ||
private BOOLEAN NOT NULL, | ||
mirror BOOLEAN NOT NULL, | ||
hidden BOOLEAN NOT NULL, | ||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
updated_at DATETIME NOT NULL | ||
); | ||
CREATE TABLE IF NOT EXISTS collab ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
user_id INTEGER NOT NULL, | ||
repo_id INTEGER NOT NULL, | ||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
updated_at DATETIME NOT NULL, | ||
UNIQUE (user_id, repo_id), | ||
CONSTRAINT user_id_fk | ||
FOREIGN KEY(user_id) REFERENCES user(id) | ||
ON DELETE CASCADE | ||
ON UPDATE CASCADE, | ||
CONSTRAINT repo_id_fk | ||
FOREIGN KEY(repo_id) REFERENCES repo(id) | ||
ON DELETE CASCADE | ||
ON UPDATE CASCADE | ||
); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
package migrate | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/charmbracelet/log" | ||
"github.com/charmbracelet/soft-serve/server/db" | ||
) | ||
|
||
// MigrateFunc is a function that executes a migration. | ||
type MigrateFunc func(ctx context.Context, tx *db.Tx) error | ||
|
||
// Migration is a struct that contains the name of the migration and the | ||
// function to execute it. | ||
type Migration struct { | ||
Version int | ||
Name string | ||
Migrate MigrateFunc | ||
Rollback MigrateFunc | ||
} | ||
|
||
// Migrations is a database model to store migrations. | ||
type Migrations struct { | ||
ID int64 `db:"id"` | ||
Name string `db:"name"` | ||
Version int `db:"version"` | ||
} | ||
|
||
func (Migrations) schema(driverName string) string { | ||
switch driverName { | ||
case "sqlite3", "sqlite": | ||
return `CREATE TABLE IF NOT EXISTS migrations ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
name TEXT NOT NULL, | ||
version INTEGER NOT NULL | ||
); | ||
` | ||
case "postgres": | ||
return `CREATE TABLE IF NOT EXISTS migrations ( | ||
id SERIAL PRIMARY KEY, | ||
name TEXT NOT NULL, | ||
version INTEGER NOT NULL | ||
); | ||
` | ||
case "mysql": | ||
return `CREATE TABLE IF NOT EXISTS migrations ( | ||
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, | ||
name TEXT NOT NULL, | ||
version INT NOT NULL | ||
); | ||
` | ||
default: | ||
panic("unknown driver") | ||
} | ||
} | ||
|
||
// Migrate runs the migrations. | ||
func Migrate(ctx context.Context, dbx *db.DB) error { | ||
logger := log.FromContext(ctx).WithPrefix("migrate") | ||
return dbx.TransactionContext(ctx, func(tx *db.Tx) error { | ||
if !hasTable(tx, "migrations") { | ||
if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
var migrs Migrations | ||
if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { | ||
if !errors.Is(err, sql.ErrNoRows) { | ||
return err | ||
} | ||
} | ||
|
||
for _, m := range migrations { | ||
if m.Version <= migrs.Version { | ||
continue | ||
} | ||
|
||
logger.Infof("running migration %d:%s", m.Version, m.Name) | ||
if err := m.Migrate(ctx, tx); err != nil { | ||
return err | ||
} | ||
|
||
if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
}) | ||
} | ||
|
||
// Rollback rolls back a migration. | ||
func Rollback(ctx context.Context, dbx *db.DB) error { | ||
logger := log.FromContext(ctx).WithPrefix("migrate") | ||
return dbx.TransactionContext(ctx, func(tx *db.Tx) error { | ||
var migrs Migrations | ||
if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { | ||
if !errors.Is(err, sql.ErrNoRows) { | ||
return fmt.Errorf("there are no migrations to rollback: %w", err) | ||
} | ||
} | ||
|
||
if len(migrations) < migrs.Version { | ||
return fmt.Errorf("there are no migrations to rollback") | ||
} | ||
|
||
m := migrations[migrs.Version-1] | ||
logger.Infof("rolling back migration %d:%s", m.Version, m.Name) | ||
if err := m.Rollback(ctx, tx); err != nil { | ||
return err | ||
} | ||
|
||
if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE id = ?"), migrs.ID); err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
}) | ||
} | ||
|
||
func hasTable(tx *db.Tx, tableName string) bool { | ||
var query string | ||
switch tx.DriverName() { | ||
case "sqlite3", "sqlite": | ||
query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" | ||
case "postgres": | ||
fallthrough | ||
case "mysql": | ||
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?" | ||
} | ||
|
||
query = tx.Rebind(query) | ||
var name string | ||
err := tx.Get(&name, query, tableName) | ||
return err == nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package migrate | ||
|
||
import ( | ||
"context" | ||
"embed" | ||
"fmt" | ||
"regexp" | ||
"strings" | ||
|
||
"github.com/charmbracelet/soft-serve/server/db" | ||
) | ||
|
||
//go:embed *.sql | ||
var sqls embed.FS | ||
|
||
// Keep this in order of execution, oldest to newest. | ||
var migrations = []Migration{ | ||
createTables, | ||
} | ||
|
||
func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { | ||
direction := "up" | ||
if down { | ||
direction = "down" | ||
} | ||
|
||
driverName := tx.DriverName() | ||
if driverName == "sqlite3" { | ||
driverName = "sqlite" | ||
} | ||
|
||
fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction) | ||
sqlstr, err := sqls.ReadFile(fn) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error { | ||
return execMigration(ctx, tx, version, name, false) | ||
} | ||
|
||
func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error { | ||
return execMigration(ctx, tx, version, name, true) | ||
} | ||
|
||
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") | ||
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") | ||
|
||
func toSnakeCase(str string) string { | ||
str = strings.ReplaceAll(str, "-", "_") | ||
str = strings.ReplaceAll(str, " ", "_") | ||
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") | ||
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") | ||
return strings.ToLower(snake) | ||
} |