Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experiemntal: add StoreController to allowing extending interface #752

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions database/controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package database

import (
"context"
"database/sql"
"errors"
)

// ErrNotSupported is returned when an optional method is not supported by the Store implementation.
var ErrNotSupported = errors.New("not supported")

// A StoreController is used by the goose package to interact with a database. This type is a
// wrapper around the Store interface, but can be extended to include additional (optional) methods
// that are not part of the core Store interface.
type StoreController struct {
store Store
}

var _ Store = (*StoreController)(nil)

// NewStoreController returns a new StoreController that wraps the given Store.
//
// If the Store implements the following optional methods, the StoreController will call them as
// appropriate:
//
// - TableExists(context.Context, DBTxConn) (bool, error)
//
// If the Store does not implement a method, it will either return a [ErrNotSupported] error or fall
// back to the default behavior.
func NewStoreController(store Store) *StoreController {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The downside with this pattern is external implementations don't have a type assertion to lean on, and so it's really easy to believe a method is implemented, but in reality the method signature doesn't match. You lose that nice compile-time check.

Another pattern could be to have a second interface like StoreExtender that's a superset of Store.

type StoreExtender interface {
	Store
	// TableExists is an optional method that checks if the version table exists in the database. It is
	// recommended to implement this method if the database supports it, as it can be used to optimize
	// certain operations.
	TableExists(ctx context.Context, db DBTxConn) (bool, error)
}

There's pros/cons to both approaches, need to sleep on this a bit and chat with folks to get feedback.

return &StoreController{store: store}
}

// TableExists is an optional method that checks if the version table exists in the database. It is
// recommended to implement this method if the database supports it, as it can be used to optimize
// certain operations.
func (c *StoreController) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
if t, ok := c.store.(interface {
TableExists(ctx context.Context, db *sql.Conn) (bool, error)
}); ok {
return t.TableExists(ctx, db)
}
return false, ErrNotSupported
}

// Default methods

func (c *StoreController) Tablename() string {
return c.store.Tablename()
}

func (c *StoreController) CreateVersionTable(ctx context.Context, db DBTxConn) error {
return c.store.CreateVersionTable(ctx, db)
}

func (c *StoreController) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
return c.store.Insert(ctx, db, req)
}

func (c *StoreController) Delete(ctx context.Context, db DBTxConn, version int64) error {
return c.store.Delete(ctx, db, version)
}

func (c *StoreController) GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) {
return c.store.GetMigration(ctx, db, version)
}

func (c *StoreController) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
return c.store.GetLatestVersion(ctx, db)
}

func (c *StoreController) ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error) {
return c.store.ListMigrations(ctx, db)
}
16 changes: 14 additions & 2 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func NewStore(dialect Dialect, tablename string) (Store, error) {
}
return &store{
tablename: tablename,
querier: querier,
querier: dialectquery.NewQueryController(querier),
}, nil
}

type store struct {
tablename string
querier dialectquery.Querier
querier *dialectquery.QueryController
}

var _ Store = (*store)(nil)
Expand Down Expand Up @@ -137,3 +137,15 @@ func (s *store) ListMigrations(
}
return migrations, nil
}

func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
q := s.querier.TableExists(s.tablename)
if q == "" {
return false, ErrNotSupported
}
var exists bool
if err := db.QueryRowContext(ctx, q, s.tablename).Scan(&exists); err != nil {
return false, fmt.Errorf("failed to check if table exists: %w", err)
}
return exists, nil
}
63 changes: 53 additions & 10 deletions internal/dialect/dialectquery/dialectquery.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,71 @@
package dialectquery

// Querier is the interface that wraps the basic methods to create a dialect
// specific query.
// Querier is the interface that wraps the basic methods to create a dialect specific query.
type Querier interface {
// CreateTable returns the SQL query string to create the db version table.
CreateTable(tableName string) string

// InsertVersion returns the SQL query string to insert a new version into
// the db version table.
// InsertVersion returns the SQL query string to insert a new version into the db version table.
InsertVersion(tableName string) string

// DeleteVersion returns the SQL query string to delete a version from
// the db version table.
// DeleteVersion returns the SQL query string to delete a version from the db version table.
DeleteVersion(tableName string) string

// GetMigrationByVersion returns the SQL query string to get a single
// migration by version.
// GetMigrationByVersion returns the SQL query string to get a single migration by version.
//
// The query should return the timestamp and is_applied columns.
GetMigrationByVersion(tableName string) string

// ListMigrations returns the SQL query string to list all migrations in
// descending order by id.
// ListMigrations returns the SQL query string to list all migrations in descending order by id.
//
// The query should return the version_id and is_applied columns.
ListMigrations(tableName string) string
}

var _ Querier = (*QueryController)(nil)

type QueryController struct {
querier Querier
}

// NewQueryController returns a new QueryController that wraps the given Querier.
func NewQueryController(querier Querier) *QueryController {
return &QueryController{querier: querier}
}

// Optional methods

// TableExists returns the SQL query string to check if the version table exists. If the Querier
// does not implement this method, it will return an empty string.
//
// The query should return a boolean value.
func (c *QueryController) TableExists(tableName string) string {
if t, ok := c.querier.(interface {
TableExists(string) string
}); ok {
return t.TableExists(tableName)
}
return ""
}

// Default methods

func (c *QueryController) CreateTable(tableName string) string {
return c.querier.CreateTable(tableName)
}

func (c *QueryController) InsertVersion(tableName string) string {
return c.querier.InsertVersion(tableName)
}

func (c *QueryController) DeleteVersion(tableName string) string {
return c.querier.DeleteVersion(tableName)
}

func (c *QueryController) GetMigrationByVersion(tableName string) string {
return c.querier.GetMigrationByVersion(tableName)
}

func (c *QueryController) ListMigrations(tableName string) string {
return c.querier.ListMigrations(tableName)
}
4 changes: 4 additions & 0 deletions internal/dialect/dialectquery/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ func (p *Postgres) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) TableExists(_ string) string {
return `SELECT EXISTS ( SELECT FROM pg_tables WHERE tablename = $1)`
}
4 changes: 2 additions & 2 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Provider struct {
mu sync.Mutex

db *sql.DB
store database.Store
store *database.StoreController

fsys fs.FS
cfg config
Expand Down Expand Up @@ -143,7 +143,7 @@ func newProvider(
db: db,
fsys: fsys,
cfg: cfg,
store: store,
store: database.NewStoreController(store),
migrations: migrations,
}, nil
}
Expand Down
26 changes: 8 additions & 18 deletions provider_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,15 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}

func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// existor is an interface that extends the Store interface with a method to check if the
// version table exists. This API is not stable and may change in the future.
type existor interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
ok, err := p.store.TableExists(ctx, conn)
if err != nil && !errors.Is(err, database.ErrNotSupported) {
return err
}
if e, ok := p.store.(existor); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}
if exists {
return nil
}
} else {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
if ok {
return nil
}
if errors.Is(err, database.ErrNotSupported) {
// Fall back to the default behavior if the Store does not implement TableExists.
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
Expand All @@ -355,9 +348,6 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
if p.cfg.disableVersioning {
return nil
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
})
}
Expand Down
31 changes: 16 additions & 15 deletions provider_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,19 +748,6 @@ func TestGoMigrationPanic(t *testing.T) {
check.Contains(t, expected.Err.Error(), wantErrString)
}

func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()

store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}

func TestProviderApply(t *testing.T) {
t.Parallel()

Expand All @@ -774,15 +761,29 @@ func TestProviderApply(t *testing.T) {
check.HasError(t, err)
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
}
func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()

store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}

type customStoreSQLite3 struct {
database.Store
}

func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
func (s *customStoreSQLite3) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
var exists bool
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
if err := db.QueryRowContext(ctx, q, s.Tablename()).Scan(&exists); err != nil {
return false, err
}
return exists, nil
Expand Down
Loading