diff --git a/build/install_tools.sh b/build/install_tools.sh
index da40003..c1067a4 100755
--- a/build/install_tools.sh
+++ b/build/install_tools.sh
@@ -3,26 +3,51 @@ set -eux
source $(dirname "$0")/env.sh
+FORCE=false
+GODLV_VERSION=${GODLV_VERSION:=1.21.0}
if go version | grep 1.2; then
- GOLNT_VERSION=${GOLNT_VERSION:=v1.52.2}
+ GOLNT_VERSION=${GOLNT_VERSION:=1.52.2}
else
- GOLNT_VERSION=${GOLNT_VERSION:=v1.50.1}
+ GOLNT_VERSION=${GOLNT_VERSION:=1.50.1}
fi
-GOMRK_VERSION=v1.1.0
-SHFMT_VERSION=v3.6.0
-TYPOS_VERSION=1.13.6
+GOMRK_VERSION=${GOMRK_VERSION:=1.1.0}
+SHFMT_VERSION=${SHFMT_VERSION:=3.6.0}
+GOIMP_VERSION=${GOIMP_VERSION:=0.5.0}
+GOIMPORTS_REVISER_VERSION=${GOIMPORTS_REVISER_VERSION:=3.4.2}
+TYPOS_VERSION=${TYPOS_VERSION:=1.13.6}
+ADDLICENSE_VERSION=${ADDLICENSE_VERSION:=1.1.1}
+CODEOWNERS_VALIDATOR_VERSION=${CODEOWNERS_VALIDATOR_VERSION:=0.7.4}
+YAMLFMT_VERSION=${YAMLFMT_VERSION:=0.9.0}
-go-licenser -version || go install github.com/elastic/go-licenser@latest
-$(go env GOPATH)/bin/shfmt -version | grep ${SHFMT_VERSION} || go install mvdan.cc/sh/v3/cmd/shfmt@${SHFMT_VERSION}
-$(go env GOPATH)/bin/gomarkdoc --version | grep ${GOMRK_VERSION} || go install github.com/princjef/gomarkdoc/cmd/gomarkdoc@${GOMRK_VERSION}
-$(go env GOPATH)/bin/golangci-lint version | grep ${GOLNT_VERSION} || curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin ${GOLNT_VERSION}
-go install -v github.com/incu6us/goimports-reviser/v3@latest
+# With version check
+go_install() {
+ ([ ${FORCE} = false ] && $1 | grep $2) || go install $3@v$2
+}
-if [[ "${OSTYPE}" == "linux"* ]]; then
+# Without version check
+go_install_any() {
+ ([ ${FORCE} = false ] && $1 >/dev/null 2>&1) || go install $2@$3
+}
+
+go version | grep '1.19\|1.20\|1.21' || (
+ echo "Install supported version (>=1.19) of golang to use saas-ci"
+ exit 1
+)
+go_install "dlv version" ${GODLV_VERSION} github.com/go-delve/delve/cmd/dlv
+go_install "golangci-lint version" ${GOLNT_VERSION} github.com/golangci/golangci-lint/cmd/golangci-lint
+go_install "gomarkdoc --version" ${GOMRK_VERSION} github.com/princjef/gomarkdoc/cmd/gomarkdoc
+go_install "shfmt -version" ${SHFMT_VERSION} mvdan.cc/sh/v3/cmd/shfmt
+go_install_any "goimports -h" golang.org/x/tools/cmd/goimports v${GOIMP_VERSION}
+go_install_any "goimports-reviser -h" "github.com/incu6us/goimports-reviser/v3" v${GOIMPORTS_REVISER_VERSION}
+go_install_any "golicenser -version" github.com/elastic/go-licenser latest
+go_install_any "addlicense -h" github.com/google/addlicense v${ADDLICENSE_VERSION}
+go_install_any "codeowners-validator -v" github.com/mszostok/codeowners-validator v${CODEOWNERS_VALIDATOR_VERSION}
+go_install_any "yamlfmt -h" github.com/google/yamlfmt/cmd/yamlfmt v${YAMLFMT_VERSION}
+
+OS_TYPE=$(uname -s)
+if [ "${OS_TYPE}" = "Linux" ]; then
/tmp/typos --version | grep ${TYPOS_VERSION} || wget -qO- https://github.com/crate-ci/typos/releases/download/v${TYPOS_VERSION}/typos-v${TYPOS_VERSION}-x86_64-unknown-linux-musl.tar.gz | tar -zxf - -C /tmp/ ./typos
- sudo snap install mdl
-elif [[ "${OSTYPE}" == "darwin"* ]]; then
+elif [ "${OS_TYPE}" = "Darwin" ]; then
/tmp/typos --version | grep ${TYPOS_VERSION} || wget -qO- https://github.com/crate-ci/typos/releases/download/v${TYPOS_VERSION}/typos-v${TYPOS_VERSION}-x86_64-apple-darwin.tar.gz | tar -zxf - -C /tmp/ ./typos
fi
-
pip3 install addlicense mdv yamllint
diff --git a/docs/DOCUMENTATION.md b/docs/DOCUMENTATION.md
index c3a4288..7edc6a5 100755
--- a/docs/DOCUMENTATION.md
+++ b/docs/DOCUMENTATION.md
@@ -27,7 +27,13 @@ import "github.com/vmware-labs/multi-tenant-persistence-for-saas/pkg/authorizer"
- [type SimpleInstancer](<#SimpleInstancer>)
- [func \(s SimpleInstancer\) GetInstanceId\(ctx context.Context\) \(string, error\)](<#SimpleInstancer.GetInstanceId>)
- [func \(s SimpleInstancer\) WithInstanceId\(ctx context.Context, instanceId string\) context.Context](<#SimpleInstancer.WithInstanceId>)
+- [type SimpleTransactionFetcher](<#SimpleTransactionFetcher>)
+ - [func \(s SimpleTransactionFetcher\) GetTransactionCtx\(ctx context.Context\) \*gorm.DB](<#SimpleTransactionFetcher.GetTransactionCtx>)
+ - [func \(s SimpleTransactionFetcher\) IsTransactionCtx\(ctx context.Context\) bool](<#SimpleTransactionFetcher.IsTransactionCtx>)
+ - [func \(s SimpleTransactionFetcher\) WithTransactionCtx\(ctx context.Context, tx \*gorm.DB\) context.Context](<#SimpleTransactionFetcher.WithTransactionCtx>)
- [type Tenancer](<#Tenancer>)
+- [type TransactionContextKey](<#TransactionContextKey>)
+- [type TransactionFetcher](<#TransactionFetcher>)
## Constants
@@ -55,6 +61,14 @@ const (
)
```
+
+
+```go
+const (
+ TransactionCtx = TransactionContextKey("DB_TRANSACTION")
+)
+```
+
## type [Authorizer]()
@@ -172,6 +186,42 @@ func (s SimpleInstancer) WithInstanceId(ctx context.Context, instanceId string)
+
+## type [SimpleTransactionFetcher]()
+
+
+
+```go
+type SimpleTransactionFetcher struct{}
+```
+
+
+### func \(SimpleTransactionFetcher\) [GetTransactionCtx]()
+
+```go
+func (s SimpleTransactionFetcher) GetTransactionCtx(ctx context.Context) *gorm.DB
+```
+
+
+
+
+### func \(SimpleTransactionFetcher\) [IsTransactionCtx]()
+
+```go
+func (s SimpleTransactionFetcher) IsTransactionCtx(ctx context.Context) bool
+```
+
+
+
+
+### func \(SimpleTransactionFetcher\) [WithTransactionCtx]()
+
+```go
+func (s SimpleTransactionFetcher) WithTransactionCtx(ctx context.Context, tx *gorm.DB) context.Context
+```
+
+
+
## type [Tenancer]()
@@ -183,6 +233,28 @@ type Tenancer interface {
}
```
+
+## type [TransactionContextKey]()
+
+
+
+```go
+type TransactionContextKey string
+```
+
+
+## type [TransactionFetcher]()
+
+
+
+```go
+type TransactionFetcher interface {
+ IsTransactionCtx(ctx context.Context) bool
+ GetTransactionCtx(ctx context.Context) *gorm.DB
+ WithTransactionCtx(ctx context.Context, tx *gorm.DB) context.Context
+}
+```
+
# datastore
```go
@@ -226,8 +298,8 @@ DataStore interface exposes basic methods like Find/FindAll/Upsert/Delete. For r
- [type DBConfig](<#DBConfig>)
- [func ConfigFromEnv\(dbName string\) DBConfig](<#ConfigFromEnv>)
- [type DataStore](<#DataStore>)
- - [func FromConfig\(l \*logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig\) \(d DataStore, err error\)](<#FromConfig>)
- - [func FromEnv\(l \*logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer\) \(d DataStore, err error\)](<#FromEnv>)
+ - [func FromConfig\(l \*logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig\) \(d DataStore, err error\)](<#FromConfig>)
+ - [func FromEnv\(l \*logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer\) \(d DataStore, err error\)](<#FromEnv>)
- [type Helper](<#Helper>)
- [type Pagination](<#Pagination>)
- [func DefaultPagination\(\) \*Pagination](<#DefaultPagination>)
@@ -318,7 +390,7 @@ var TRACE = func(format string, v ...any) {
```
-## func [DBCreate]()
+## func [DBCreate]()
```go
func DBCreate(cfg DBConfig) error
@@ -327,7 +399,7 @@ func DBCreate(cfg DBConfig) error
Create a Postgres DB using the provided config if it doesn't exist.
-## func [DBExists]()
+## func [DBExists]()
```go
func DBExists(cfg DBConfig) bool
@@ -790,7 +862,7 @@ err != nil - true
### func [FromConfig]()
```go
-func FromConfig(l *logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig) (d DataStore, err error)
+func FromConfig(l *logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig) (d DataStore, err error)
```
@@ -799,7 +871,7 @@ func FromConfig(l *logrus.Entry, authorizer authorizer.Authorizer, instancer aut
### func [FromEnv]()
```go
-func FromEnv(l *logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer) (d DataStore, err error)
+func FromEnv(l *logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer) (d DataStore, err error)
```
@@ -879,7 +951,7 @@ func GetRecordInstanceFromSlice(x interface{}) Record
-## type [TenancyInfo]()
+## type [TenancyInfo]()
diff --git a/pkg/authorizer/transaction.go b/pkg/authorizer/transaction.go
new file mode 100644
index 0000000..4271a56
--- /dev/null
+++ b/pkg/authorizer/transaction.go
@@ -0,0 +1,43 @@
+package authorizer
+
+import (
+ "context"
+
+ "gorm.io/gorm"
+)
+
+type TransactionContextKey string
+
+const (
+ TransactionCtx = TransactionContextKey("DB_TRANSACTION")
+)
+
+type TransactionFetcher interface {
+ IsTransactionCtx(ctx context.Context) bool
+ GetTransactionCtx(ctx context.Context) *gorm.DB
+ WithTransactionCtx(ctx context.Context, tx *gorm.DB) context.Context
+}
+
+type SimpleTransactionFetcher struct{}
+
+func (s SimpleTransactionFetcher) GetTransactionCtx(ctx context.Context) *gorm.DB {
+ if v := ctx.Value(TransactionCtx); v != nil {
+ if dbTx, ok := v.(*gorm.DB); ok {
+ return dbTx
+ }
+ }
+ return nil
+}
+
+func (s SimpleTransactionFetcher) WithTransactionCtx(ctx context.Context, tx *gorm.DB) context.Context {
+ return context.WithValue(ctx, TransactionCtx, tx)
+}
+
+func (s SimpleTransactionFetcher) IsTransactionCtx(ctx context.Context) bool {
+ if v := ctx.Value(TransactionCtx); v != nil {
+ if _, ok := v.(*gorm.DB); ok {
+ return true
+ }
+ }
+ return false
+}
diff --git a/pkg/datastore/database.go b/pkg/datastore/database.go
index 71e3691..780ba1b 100644
--- a/pkg/datastore/database.go
+++ b/pkg/datastore/database.go
@@ -67,6 +67,7 @@ type relationalDb struct {
gormDBMap map[dbrole.DbRole]*gorm.DB
logger *logrus.Entry
initializer func(db *relationalDb, dbRole dbrole.DbRole) error
+ txFetcher authorizer.TransactionFetcher
}
type TenancyInfo struct {
@@ -100,10 +101,13 @@ func (db *relationalDb) getDBTransaction(ctx context.Context, tableName string,
return nil, err
}
- if tx, err = db.configureTxWithTenancyScope(tenancyInfo); err != nil {
- return nil, err
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if tx, err = db.configureTxWithTenancyScope(tenancyInfo); err != nil {
+ return nil, err
+ }
+ } else {
+ tx = db.txFetcher.GetTransactionCtx(ctx).Clauses(clause.Locking{Strength: "UPDATE"})
}
-
tx = tx.Table(tableName)
if err = tx.Error; err != nil {
err = ErrStartingTx.Wrap(err).WithMap(map[ErrorContextKey]string{
@@ -218,7 +222,12 @@ func (db *relationalDb) FindInTable(ctx context.Context, tableName string, recor
if tx, err = db.GetDBTransaction(ctx, tableName, record); err != nil {
return err
}
- defer rollbackTx(tx, db)
+
+ defer func() {
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ rollbackTx(tx, db)
+ }
+ }()
if err = tx.Table(tableName).Where(record).First(record).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -227,9 +236,11 @@ func (db *relationalDb) FindInTable(ctx context.Context, tableName string, recor
db.logger.Debug(err)
return ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return ErrExecutingSqlStmt.Wrap(err)
+ }
}
return nil
}
@@ -290,9 +301,11 @@ func (db *relationalDb) FindWithFilterInTable(ctx context.Context, tableName str
return ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return ErrExecutingSqlStmt.Wrap(err)
+ }
}
return nil
}
@@ -312,15 +325,22 @@ func (db *relationalDb) InsertInTable(ctx context.Context, tableName string, rec
if tx, err = db.GetDBTransaction(ctx, tableName, record); err != nil {
return 0, err
}
- defer rollbackTx(tx, db)
+
+ defer func() {
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ rollbackTx(tx, db)
+ }
+ }()
if err = tx.Create(record).Error; err != nil {
db.logger.Debug(err)
return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return 0, ErrExecutingSqlStmt.Wrap(err)
+ }
}
return tx.RowsAffected, nil
}
@@ -354,11 +374,16 @@ func (db *relationalDb) DeleteInTable(ctx context.Context, tableName string, rec
}
func (db *relationalDb) delete(ctx context.Context, tableName string, record Record, softDelete bool) (rowsAffected int64, err error) {
- tx, err := db.GetDBTransaction(ctx, tableName, record)
- if err != nil {
+ var tx *gorm.DB
+ if tx, err = db.GetDBTransaction(ctx, tableName, record); err != nil {
return 0, err
}
- defer rollbackTx(tx, db)
+
+ defer func() {
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ rollbackTx(tx, db)
+ }
+ }()
if softDelete {
if err = tx.Delete(record).Error; err != nil {
@@ -371,9 +396,11 @@ func (db *relationalDb) delete(ctx context.Context, tableName string, record Rec
return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
}
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return 0, ErrExecutingSqlStmt.Wrap(err)
+ }
}
return tx.RowsAffected, nil
}
@@ -453,7 +480,12 @@ func (db *relationalDb) UpsertInTable(ctx context.Context, tableName string, rec
if tx, err = db.GetDBTransaction(ctx, tableName, record); err != nil {
return 0, err
}
- defer rollbackTx(tx, db)
+
+ defer func() {
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ rollbackTx(tx, db)
+ }
+ }()
if err = tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(record).Error; err != nil {
db.logger.Debug(err)
@@ -463,9 +495,11 @@ func (db *relationalDb) UpsertInTable(ctx context.Context, tableName string, rec
return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
}
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return 0, ErrExecutingSqlStmt.Wrap(err)
+ }
}
return tx.RowsAffected, nil
}
@@ -478,7 +512,12 @@ func (db *relationalDb) UpdateInTable(ctx context.Context, tableName string, rec
if tx, err = db.GetDBTransaction(ctx, tableName, record); err != nil {
return 0, err
}
- defer rollbackTx(tx, db)
+
+ defer func() {
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ rollbackTx(tx, db)
+ }
+ }()
if err = tx.Model(record).Select("*").Updates(record).Error; err != nil {
db.logger.Debug(err)
@@ -490,9 +529,11 @@ func (db *relationalDb) UpdateInTable(ctx context.Context, tableName string, rec
return 0, err
}
}
- if err = tx.Commit().Error; err != nil {
- db.logger.Debug(err)
- return 0, ErrExecutingSqlStmt.Wrap(err).WithValue(DB_NAME, db.dbName)
+ if !db.txFetcher.IsTransactionCtx(ctx) {
+ if err = tx.Commit().Error; err != nil {
+ db.logger.Debug(err)
+ return 0, ErrExecutingSqlStmt.Wrap(err)
+ }
}
return tx.RowsAffected, nil
}
diff --git a/pkg/datastore/datastore_with_txcontext_test.go b/pkg/datastore/datastore_with_txcontext_test.go
new file mode 100644
index 0000000..997c532
--- /dev/null
+++ b/pkg/datastore/datastore_with_txcontext_test.go
@@ -0,0 +1,140 @@
+// Copyright 2023 VMware, Inc.
+// Licensed to VMware, Inc. under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. VMware, Inc. licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package datastore_test
+
+import (
+ "context"
+ "database/sql"
+ "io"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "gorm.io/gorm"
+
+ "github.com/vmware-labs/multi-tenant-persistence-for-saas/pkg/authorizer"
+ "github.com/vmware-labs/multi-tenant-persistence-for-saas/pkg/datastore"
+ . "github.com/vmware-labs/multi-tenant-persistence-for-saas/pkg/errors"
+ . "github.com/vmware-labs/multi-tenant-persistence-for-saas/test"
+)
+
+func rollbackTx(t *testing.T, tx *gorm.DB) {
+ t.Helper()
+ if err := tx.Rollback().Error; err != nil && err != sql.ErrTxDone {
+ t.Log("Rollback of tx errored", err)
+ }
+}
+
+func testTxCrud(t *testing.T, ds datastore.DataStore, ctx context.Context, myCokeApp *App, user1, user2 *AppUser) {
+ t.Helper()
+ assert := assert.New(t)
+ var err error
+
+ txFetcher := authorizer.SimpleTransactionFetcher{}
+
+ // Querying of previously inserted records should succeed
+ for _, record := range []*AppUser{user1, user2} {
+ queryResult := AppUser{Id: record.Id}
+ err = ds.Find(ctx, &queryResult)
+ assert.NoError(err)
+ assert.Equal(record, &queryResult)
+ }
+
+ // Updating non-key fields in a record should succeed
+ user1.Name = "Jeyhun G."
+ user1.Email = "jeyhun111@mail.com"
+ user1.EmailConfirmed = !user1.EmailConfirmed
+ user1.NumFollowers++
+ user2.Name = "Jahangir G."
+ user2.Email = "jahangir111@mail.com"
+ user2.EmailConfirmed = !user2.EmailConfirmed
+ user2.NumFollowers--
+
+ // TODO: Find, Update and Delete are still not supported due to appending where clauses,
+ // after each call to Find/Update/Delete methods
+
+ tx, err := ds.Helper().GetDBTransaction(ctx, datastore.GetTableName(user1), user1)
+ assert.NoError(err)
+ defer rollbackTx(t, tx)
+ txCtx := txFetcher.WithTransactionCtx(ctx, tx)
+ for _, record := range []*AppUser{user1, user2} {
+ rowsAffected, err := ds.Upsert(txCtx, record)
+ assert.NoError(err)
+ assert.EqualValues(1, rowsAffected)
+ }
+ assert.NoError(tx.Commit().Error)
+
+ for _, record := range []*AppUser{user1, user2} {
+ queryResult := &AppUser{Id: record.Id}
+ err = ds.Find(ctx, queryResult)
+ assert.NoError(err)
+ assert.Equal(record, queryResult)
+ }
+
+ tx, err = ds.Helper().GetDBTransaction(ctx, datastore.GetTableName(user1), user1)
+ assert.NoError(err)
+ defer rollbackTx(t, tx)
+ txCtx = txFetcher.WithTransactionCtx(ctx, tx)
+ // Upsert operation should be an update for already existing records
+ user1.NumFollowers++
+ user2.NumFollowers--
+ for _, record := range []*AppUser{user1, user2} {
+ rowsAffected, err := ds.Upsert(txCtx, record)
+ assert.NoError(err)
+ assert.EqualValues(1, rowsAffected)
+ }
+ assert.NoError(tx.Commit().Error)
+ for _, record := range []*AppUser{user1, user2} {
+ queryResult := &AppUser{Id: record.Id}
+ err = ds.Find(ctx, queryResult)
+ assert.NoError(err)
+ assert.Equal(record, queryResult)
+ }
+
+ // Deletion of existing records should not fail, and the records should no longer be found in the DB
+ for _, record := range []*AppUser{user1, user2} {
+ rowsAffected, err := ds.Delete(ctx, record)
+ assert.NoError(err)
+ assert.EqualValues(1, rowsAffected)
+ queryResult := AppUser{Id: record.Id}
+ err = ds.Find(ctx, &queryResult)
+ assert.ErrorIs(err, ErrRecordNotFound)
+ assert.True(queryResult.AreNonKeyFieldsEmpty())
+ }
+}
+
+func BenchmarkTxCrud(b *testing.B) {
+ logger := datastore.GetLogger()
+ logger.SetLevel(logrus.FatalLevel)
+ logger.SetOutput(io.Discard)
+ LOG = logger.WithField(datastore.COMP, datastore.SAAS_PERSISTENCE)
+
+ var t testing.T
+ ds, _ := SetupDataStore("BenchmarkCrud")
+ myCokeApp, user1, user2 := SetupDbTables(ds)
+ for n := 0; n < b.N; n++ {
+ testCrud(&t, ds, CokeAdminCtx, myCokeApp, user1, user2)
+ }
+}
+
+func TestTxCrud(t *testing.T) {
+ ds, _ := SetupDataStore("TestTxCrud")
+ myCokeApp, user1, user2 := SetupDbTables(ds)
+ testTxCrud(t, ds, CokeAdminCtx, myCokeApp, user1, user2)
+}
diff --git a/pkg/datastore/helper.go b/pkg/datastore/helper.go
index fea8af2..1e9e417 100644
--- a/pkg/datastore/helper.go
+++ b/pkg/datastore/helper.go
@@ -97,11 +97,11 @@ func ConfigFromEnv(dbName string) DBConfig {
return cfg
}
-func FromEnv(l *logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer) (d DataStore, err error) {
- return FromConfig(l, authorizer, instancer, ConfigFromEnv(""))
+func FromEnv(l *logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer) (d DataStore, err error) {
+ return FromConfig(l, a, instancer, ConfigFromEnv(""))
}
-func FromConfig(l *logrus.Entry, authorizer authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig) (d DataStore, err error) {
+func FromConfig(l *logrus.Entry, a authorizer.Authorizer, instancer authorizer.Instancer, cfg DBConfig) (d DataStore, err error) {
gl := GetGormLogger(l)
dbConnInitializer := func(db *relationalDb, dbRole dbrole.DbRole) error {
db.Lock()
@@ -180,11 +180,12 @@ func FromConfig(l *logrus.Entry, authorizer authorizer.Authorizer, instancer aut
}
return &relationalDb{
dbName: cfg.dbName,
- authorizer: authorizer,
+ authorizer: a,
instancer: instancer,
gormDBMap: make(map[dbrole.DbRole]*gorm.DB),
initializer: dbConnInitializer,
logger: l,
+ txFetcher: authorizer.SimpleTransactionFetcher{},
}, nil
}