diff --git a/server/db/errors.go b/server/db/errors.go index 872c223de..7a3f1fcc2 100644 --- a/server/db/errors.go +++ b/server/db/errors.go @@ -11,6 +11,9 @@ import ( var ( // ErrDuplicateKey is a constraint violation error. ErrDuplicateKey = errors.New("duplicate key value violates table constraint") + + // ErrRecordNotFound is returned when a record is not found. + ErrRecordNotFound = errors.New("record not found") ) // WrapError is a convenient function that unite various database driver @@ -18,7 +21,7 @@ var ( func WrapError(err error) error { if err != nil { if errors.Is(err, sql.ErrNoRows) { - return err + return ErrRecordNotFound } // Handle sqlite constraint error. if liteErr, ok := err.(*sqlite.Error); ok { diff --git a/server/db/migrate/0001_create_tables.go b/server/db/migrate/0001_create_tables.go new file mode 100644 index 000000000..7a22825d7 --- /dev/null +++ b/server/db/migrate/0001_create_tables.go @@ -0,0 +1,64 @@ +package migrate + +import ( + "context" + "fmt" + + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/charmbracelet/soft-serve/server/store" +) + +const ( + createTablesName = "create tables" + createTablesVersion = 1 +) + +var createTables = Migration{ + Version: createTablesVersion, + Name: createTablesName, + Migrate: func(ctx context.Context, tx *db.Tx) error { + cfg := config.FromContext(ctx) + + if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil { + return err + } + + // Insert default settings + insertSettings := "INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" + insertSettings = tx.Rebind(insertSettings) + settings := []struct { + Key string + Value string + }{ + {"allow_keyless", "true"}, + {"anon_access", store.ReadOnlyAccess.String()}, + {"init", "true"}, + } + + for _, s := range settings { + if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { + return fmt.Errorf("inserting default settings %q: %w", s.Key, err) + } + } + + // Insert default user + insertUser := tx.Rebind("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") + if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil { + return err + } + + for _, k := range cfg.AdminKeys() { + ak := sshutils.MarshalAuthorizedKey(k) + if _, err := tx.ExecContext(ctx, "INSERT INTO public_key (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", 1, ak); err != nil { + return err + } + } + + return nil + }, + Rollback: func(ctx context.Context, tx *db.Tx) error { + return migrateDown(ctx, tx, createTablesVersion, createTablesName) + }, +} diff --git a/server/db/migrate/0001_create_tables_sqlite.down.sql b/server/db/migrate/0001_create_tables_sqlite.down.sql new file mode 100644 index 000000000..116aaa276 --- /dev/null +++ b/server/db/migrate/0001_create_tables_sqlite.down.sql @@ -0,0 +1,6 @@ + +DROP TABLE IF EXISTS collab; +DROP TABLE IF EXISTS repo; +DROP TABLE IF EXISTS public_key; +DROP TABLE IF EXISTS user; +DROP TABLE IF EXISTS settings; diff --git a/server/db/migrate/0001_create_tables_sqlite.up.sql b/server/db/migrate/0001_create_tables_sqlite.up.sql new file mode 100644 index 000000000..0b2c70d9d --- /dev/null +++ b/server/db/migrate/0001_create_tables_sqlite.up.sql @@ -0,0 +1,54 @@ +CREATE TABLE IF NOT EXISTS settings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT NOT NULL UNIQUE, + value TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL + ); +CREATE TABLE IF NOT EXISTS user ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE, + admin BOOLEAN NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL + ); +CREATE TABLE IF NOT EXISTS public_key ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + public_key TEXT NOT NULL UNIQUE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, public_key), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES user(id) + ON DELETE CASCADE + ON UPDATE CASCADE + ); +CREATE TABLE IF NOT EXISTS repo ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + project_name TEXT NOT NULL, + description TEXT NOT NULL, + private BOOLEAN NOT NULL, + mirror BOOLEAN NOT NULL, + hidden BOOLEAN NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL + ); +CREATE TABLE IF NOT EXISTS collab ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + repo_id INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, repo_id), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES user(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repo(id) + ON DELETE CASCADE + ON UPDATE CASCADE + ); + diff --git a/server/db/migrate/migrate.go b/server/db/migrate/migrate.go new file mode 100644 index 000000000..6bea1cea2 --- /dev/null +++ b/server/db/migrate/migrate.go @@ -0,0 +1,140 @@ +package migrate + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/db" +) + +// MigrateFunc is a function that executes a migration. +type MigrateFunc func(ctx context.Context, tx *db.Tx) error + +// Migration is a struct that contains the name of the migration and the +// function to execute it. +type Migration struct { + Version int + Name string + Migrate MigrateFunc + Rollback MigrateFunc +} + +// Migrations is a database model to store migrations. +type Migrations struct { + ID int64 `db:"id"` + Name string `db:"name"` + Version int `db:"version"` +} + +func (Migrations) schema(driverName string) string { + switch driverName { + case "sqlite3", "sqlite": + return `CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version INTEGER NOT NULL + ); + ` + case "postgres": + return `CREATE TABLE IF NOT EXISTS migrations ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + version INTEGER NOT NULL + ); + ` + case "mysql": + return `CREATE TABLE IF NOT EXISTS migrations ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name TEXT NOT NULL, + version INT NOT NULL + ); + ` + default: + panic("unknown driver") + } +} + +// Migrate runs the migrations. +func Migrate(ctx context.Context, dbx *db.DB) error { + logger := log.FromContext(ctx).WithPrefix("migrate") + return dbx.TransactionContext(ctx, func(tx *db.Tx) error { + if !hasTable(tx, "migrations") { + if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil { + return err + } + } + + var migrs Migrations + if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err + } + } + + for _, m := range migrations { + if m.Version <= migrs.Version { + continue + } + + logger.Infof("running migration %d:%s", m.Version, m.Name) + if err := m.Migrate(ctx, tx); err != nil { + return err + } + + if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + return err + } + } + + return nil + }) +} + +// Rollback rolls back a migration. +func Rollback(ctx context.Context, dbx *db.DB) error { + logger := log.FromContext(ctx).WithPrefix("migrate") + return dbx.TransactionContext(ctx, func(tx *db.Tx) error { + var migrs Migrations + if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("there are no migrations to rollback: %w", err) + } + } + + if len(migrations) < migrs.Version { + return fmt.Errorf("there are no migrations to rollback") + } + + m := migrations[migrs.Version-1] + logger.Infof("rolling back migration %d:%s", m.Version, m.Name) + if err := m.Rollback(ctx, tx); err != nil { + return err + } + + if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE id = ?"), migrs.ID); err != nil { + return err + } + + return nil + }) +} + +func hasTable(tx *db.Tx, tableName string) bool { + var query string + switch tx.DriverName() { + case "sqlite3", "sqlite": + query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" + case "postgres": + fallthrough + case "mysql": + query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?" + } + + query = tx.Rebind(query) + var name string + err := tx.Get(&name, query, tableName) + return err == nil +} diff --git a/server/db/migrate/migrations.go b/server/db/migrate/migrations.go new file mode 100644 index 000000000..88a9e4346 --- /dev/null +++ b/server/db/migrate/migrations.go @@ -0,0 +1,62 @@ +package migrate + +import ( + "context" + "embed" + "fmt" + "regexp" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" +) + +//go:embed *.sql +var sqls embed.FS + +// Keep this in order of execution, oldest to newest. +var migrations = []Migration{ + createTables, +} + +func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { + direction := "up" + if down { + direction = "down" + } + + driverName := tx.DriverName() + if driverName == "sqlite3" { + driverName = "sqlite" + } + + fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction) + sqlstr, err := sqls.ReadFile(fn) + if err != nil { + return err + } + + if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil { + return err + } + + return nil +} + +func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error { + return execMigration(ctx, tx, version, name, false) +} + +func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error { + return execMigration(ctx, tx, version, name, true) +} + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func toSnakeCase(str string) string { + str = strings.ReplaceAll(str, "-", "_") + str = strings.ReplaceAll(str, " ", "_") + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +}