From 24fbe22171be522f11cb760566694aa2917ebd5c Mon Sep 17 00:00:00 2001 From: ajlake Date: Mon, 12 Feb 2018 13:29:03 -0800 Subject: [PATCH] Revert gorm changes (#33) * Revert "Use gorm for DB interactions (#18)" This reverts commit 1ad6d6ef8913014c7369b4587425a56e5017850b. --- Gopkg.lock | 17 +- Gopkg.toml | 35 +- persist/repo.go | 61 + persist/schema.go | 114 +- persist/user.go | 78 + server/endpoints/hook.go | 22 +- server/endpoints/repositories.go | 58 +- server/endpoints/token.go | 31 +- server/init.go | 12 +- server/server.go | 6 +- .../github.com/jinzhu/gorm/.codeclimate.yml | 11 - vendor/github.com/jinzhu/gorm/.gitignore | 2 - vendor/github.com/jinzhu/gorm/CONTRIBUTING.md | 52 - vendor/github.com/jinzhu/gorm/License | 21 - vendor/github.com/jinzhu/gorm/README.md | 46 - vendor/github.com/jinzhu/gorm/association.go | 359 ---- .../jinzhu/gorm/association_test.go | 842 -------- vendor/github.com/jinzhu/gorm/callback.go | 237 --- .../github.com/jinzhu/gorm/callback_create.go | 144 -- .../github.com/jinzhu/gorm/callback_delete.go | 53 - .../github.com/jinzhu/gorm/callback_query.go | 93 - .../jinzhu/gorm/callback_query_preload.go | 310 --- .../github.com/jinzhu/gorm/callback_save.go | 92 - .../jinzhu/gorm/callback_system_test.go | 112 -- .../github.com/jinzhu/gorm/callback_update.go | 104 - .../github.com/jinzhu/gorm/callbacks_test.go | 177 -- vendor/github.com/jinzhu/gorm/create_test.go | 164 -- .../jinzhu/gorm/customize_column_test.go | 280 --- vendor/github.com/jinzhu/gorm/delete_test.go | 68 - vendor/github.com/jinzhu/gorm/dialect.go | 100 - .../github.com/jinzhu/gorm/dialect_common.go | 137 -- .../github.com/jinzhu/gorm/dialect_mysql.go | 113 -- .../jinzhu/gorm/dialect_postgres.go | 132 -- .../github.com/jinzhu/gorm/dialect_sqlite3.go | 106 - .../jinzhu/gorm/dialects/postgres/postgres.go | 54 - .../jinzhu/gorm/embedded_struct_test.go | 48 - vendor/github.com/jinzhu/gorm/errors.go | 58 - vendor/github.com/jinzhu/gorm/field.go | 58 - vendor/github.com/jinzhu/gorm/field_test.go | 49 - vendor/github.com/jinzhu/gorm/interface.go | 19 - .../jinzhu/gorm/join_table_handler.go | 204 -- .../github.com/jinzhu/gorm/join_table_test.go | 72 - vendor/github.com/jinzhu/gorm/logger.go | 99 - vendor/github.com/jinzhu/gorm/main.go | 700 ------- vendor/github.com/jinzhu/gorm/main_test.go | 774 ------- .../github.com/jinzhu/gorm/migration_test.go | 349 ---- vendor/github.com/jinzhu/gorm/model.go | 14 - vendor/github.com/jinzhu/gorm/model_struct.go | 542 ----- .../jinzhu/gorm/multi_primary_keys_test.go | 381 ---- vendor/github.com/jinzhu/gorm/pointer_test.go | 84 - .../jinzhu/gorm/polymorphic_test.go | 219 -- vendor/github.com/jinzhu/gorm/preload_test.go | 1327 ------------ vendor/github.com/jinzhu/gorm/query_test.go | 636 ------ vendor/github.com/jinzhu/gorm/scaner_test.go | 70 - vendor/github.com/jinzhu/gorm/scope.go | 1246 ------------ vendor/github.com/jinzhu/gorm/scope_test.go | 43 - vendor/github.com/jinzhu/gorm/search.go | 149 -- vendor/github.com/jinzhu/gorm/search_test.go | 30 - vendor/github.com/jinzhu/gorm/test_all.sh | 5 - vendor/github.com/jinzhu/gorm/update_test.go | 435 ---- vendor/github.com/jinzhu/gorm/utils.go | 264 --- vendor/github.com/jinzhu/gorm/utils_test.go | 30 - vendor/github.com/jinzhu/inflection/LICENSE | 21 - vendor/github.com/jinzhu/inflection/README.md | 55 - .../jinzhu/inflection/inflections.go | 273 --- .../jinzhu/inflection/inflections_test.go | 213 -- vendor/github.com/jmoiron/sqlx/.gitignore | 24 + vendor/github.com/jmoiron/sqlx/LICENSE | 23 + vendor/github.com/jmoiron/sqlx/README.md | 183 ++ vendor/github.com/jmoiron/sqlx/bind.go | 207 ++ vendor/github.com/jmoiron/sqlx/doc.go | 12 + vendor/github.com/jmoiron/sqlx/named.go | 344 ++++ .../github.com/jmoiron/sqlx/named_context.go | 132 ++ .../jmoiron/sqlx/named_context_test.go | 136 ++ vendor/github.com/jmoiron/sqlx/named_test.go | 227 +++ .../jmoiron/sqlx/reflectx/README.md | 17 + .../jmoiron/sqlx/reflectx/reflect.go | 422 ++++ .../jmoiron/sqlx/reflectx/reflect_test.go | 905 +++++++++ vendor/github.com/jmoiron/sqlx/sqlx.go | 1035 ++++++++++ .../github.com/jmoiron/sqlx/sqlx_context.go | 335 +++ .../jmoiron/sqlx/sqlx_context_test.go | 1344 +++++++++++++ vendor/github.com/jmoiron/sqlx/sqlx_test.go | 1792 +++++++++++++++++ vendor/github.com/lib/pq/hstore/hstore.go | 118 -- .../github.com/lib/pq/hstore/hstore_test.go | 148 -- 84 files changed, 7471 insertions(+), 12643 deletions(-) create mode 100644 persist/repo.go create mode 100644 persist/user.go delete mode 100644 vendor/github.com/jinzhu/gorm/.codeclimate.yml delete mode 100644 vendor/github.com/jinzhu/gorm/.gitignore delete mode 100644 vendor/github.com/jinzhu/gorm/CONTRIBUTING.md delete mode 100644 vendor/github.com/jinzhu/gorm/License delete mode 100644 vendor/github.com/jinzhu/gorm/README.md delete mode 100644 vendor/github.com/jinzhu/gorm/association.go delete mode 100644 vendor/github.com/jinzhu/gorm/association_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_create.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_delete.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query_preload.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_save.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_system_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_update.go delete mode 100644 vendor/github.com/jinzhu/gorm/callbacks_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/create_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/customize_column_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/delete_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_common.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_mysql.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_postgres.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_sqlite3.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go delete mode 100644 vendor/github.com/jinzhu/gorm/embedded_struct_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/errors.go delete mode 100644 vendor/github.com/jinzhu/gorm/field.go delete mode 100644 vendor/github.com/jinzhu/gorm/field_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/interface.go delete mode 100644 vendor/github.com/jinzhu/gorm/join_table_handler.go delete mode 100644 vendor/github.com/jinzhu/gorm/join_table_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/logger.go delete mode 100644 vendor/github.com/jinzhu/gorm/main.go delete mode 100644 vendor/github.com/jinzhu/gorm/main_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/migration_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/model.go delete mode 100644 vendor/github.com/jinzhu/gorm/model_struct.go delete mode 100644 vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/pointer_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/polymorphic_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/preload_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/query_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/scaner_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/scope.go delete mode 100644 vendor/github.com/jinzhu/gorm/scope_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/search.go delete mode 100644 vendor/github.com/jinzhu/gorm/search_test.go delete mode 100755 vendor/github.com/jinzhu/gorm/test_all.sh delete mode 100644 vendor/github.com/jinzhu/gorm/update_test.go delete mode 100644 vendor/github.com/jinzhu/gorm/utils.go delete mode 100644 vendor/github.com/jinzhu/gorm/utils_test.go delete mode 100644 vendor/github.com/jinzhu/inflection/LICENSE delete mode 100644 vendor/github.com/jinzhu/inflection/README.md delete mode 100644 vendor/github.com/jinzhu/inflection/inflections.go delete mode 100644 vendor/github.com/jinzhu/inflection/inflections_test.go create mode 100644 vendor/github.com/jmoiron/sqlx/.gitignore create mode 100644 vendor/github.com/jmoiron/sqlx/LICENSE create mode 100644 vendor/github.com/jmoiron/sqlx/README.md create mode 100644 vendor/github.com/jmoiron/sqlx/bind.go create mode 100644 vendor/github.com/jmoiron/sqlx/doc.go create mode 100644 vendor/github.com/jmoiron/sqlx/named.go create mode 100644 vendor/github.com/jmoiron/sqlx/named_context.go create mode 100644 vendor/github.com/jmoiron/sqlx/named_context_test.go create mode 100644 vendor/github.com/jmoiron/sqlx/named_test.go create mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/README.md create mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/reflect.go create mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go create mode 100644 vendor/github.com/jmoiron/sqlx/sqlx.go create mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_context.go create mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_context_test.go create mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_test.go delete mode 100644 vendor/github.com/lib/pq/hstore/hstore.go delete mode 100644 vendor/github.com/lib/pq/hstore/hstore_test.go diff --git a/Gopkg.lock b/Gopkg.lock index aed2cea81..5c77bdb5e 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -103,16 +103,9 @@ version = "v3.1.1" [[projects]] - name = "github.com/jinzhu/gorm" - packages = [".","dialects/postgres"] - revision = "5174cc5c242a728b435ea2be8a2f7f998e15429b" - version = "v1.0" - -[[projects]] - branch = "master" - name = "github.com/jinzhu/inflection" - packages = ["."] - revision = "1c35d901db3da928c72a72d8458480cc9ade058f" + name = "github.com/jmoiron/sqlx" + packages = [".","reflectx"] + revision = "d9bd385d68c068f1fabb5057e3dedcbcbb039d0f" [[projects]] name = "github.com/labstack/echo" @@ -128,7 +121,7 @@ [[projects]] name = "github.com/lib/pq" - packages = [".","hstore","oid"] + packages = [".","oid"] revision = "b77235e3890a962fe8a6f8c4c7198679ca7814e7" [[projects]] @@ -287,6 +280,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "3f288cd94a1d8e2d6de01bba4ea5f70692f19f52d3ec13677a5dc8496f385e20" + inputs-digest = "40bc7063ffd9b03f5f059775cab1296e4b6efc4fcd2e316eedf49e91dfb7e6c9" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index df7934aef..15d93117b 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,3 +1,26 @@ + +# Gopkg.toml example +# +# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" + + [[constraint]] name = "github.com/google/go-github" revision = "511f540f1887d30b88cee4a2fcd1f2922754acf4" @@ -6,10 +29,18 @@ name = "github.com/ipfans/echo-session" version = "3.1.1" +[[constraint]] + name = "github.com/jmoiron/sqlx" + revision = "d9bd385d68c068f1fabb5057e3dedcbcbb039d0f" + [[constraint]] name = "github.com/labstack/echo" version = "3.2.3" +[[constraint]] + name = "github.com/lib/pq" + revision = "b77235e3890a962fe8a6f8c4c7198679ca7814e7" + [[constraint]] name = "github.com/pkg/errors" version = "0.8.0" @@ -41,7 +72,3 @@ [[constraint]] name = "gopkg.in/yaml.v2" revision = "eb3733d160e74a9c7e442f435eb3bea458e1d19f" - -[[constraint]] - name = "github.com/jinzhu/gorm" - version = "1.0.0" diff --git a/persist/repo.go b/persist/repo.go new file mode 100644 index 000000000..95c72ccc7 --- /dev/null +++ b/persist/repo.go @@ -0,0 +1,61 @@ +// Copyright 2017 Palantir Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "fmt" + "strings" + + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" +) + +type Repository struct { + ID int `db:"id"` + Name string `db:"name"` + EnabledBy string `db:"enabled_by"` + EnabledAt int64 `db:"enabled_at"` + HookID int `db:"hook_id"` +} + +func (*Repository) InsertStmt() string { + return "INSERT INTO REPOS (id, name, enabled_by, enabled_at, hook_id) VALUES (:id, :name, :enabled_by, :enabled_at, :hook_id)" +} + +func (*Repository) DeleteStmt() string { + return "DELETE FROM REPOS WHERE id = :id" +} + +func (*Repository) UpdateStmt() string { + return "TODO" +} + +func GetRepositoryByID(db *sqlx.DB, repoID int) (*Repository, error) { + r := &Repository{} + + q := fmt.Sprintf("SELECT * FROM REPOS WHERE id=%d", repoID) + row := db.QueryRowx(q) + err := row.StructScan(r) + + if err != nil { + if strings.Contains(err.Error(), "no rows in result set") { + return nil, nil + } + + return nil, errors.Wrapf(err, "cannot get repo %d", repoID) + } + + return r, nil +} diff --git a/persist/schema.go b/persist/schema.go index 280435ece..5ebd484b7 100644 --- a/persist/schema.go +++ b/persist/schema.go @@ -15,28 +15,108 @@ package persist import ( - "time" - - "github.com/jinzhu/gorm" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" ) -type Repository struct { - gorm.Model - GitHubID int `gorm:"column:github_id"` - Name string - EnabledBy User - EnabledAt time.Time - HookID int +var metaSchema = ` +CREATE TABLE IF NOT EXISTS schema ( + version INTEGER PRIMARY KEY CHECK (version > 0) +); + +CREATE UNIQUE INDEX IF NOT EXISTS schema_one_row ON schema((TRUE));` + +var currSchemaVersion = 1 + +var schema = ` +CREATE TABLE IF NOT EXISTS USERS ( + github_id INTEGER PRIMARY KEY UNIQUE, + name TEXT, + token TEXT +); + +CREATE TABLE IF NOT EXISTS REPOS ( + id INTEGER PRIMARY KEY UNIQUE, + name TEXT, + enabled_by TEXT, + enabled_at BIGINT, + hook_id INTEGER +); +` + +// Persistable are structs that can be persisted to a DB and +// are compatible with associated utility methods in this package +type Persistable interface { + InsertStmt() string + DeleteStmt() string + UpdateStmt() string +} + +// Put persists a Persistable to the given DB +func Put(db *sqlx.DB, p Persistable) error { + _, err := db.NamedExec(p.InsertStmt(), p) + if err != nil { + return errors.Wrapf(err, "failed persisting %v", p) + } + return nil +} + +// Delete deletes a Persistable from the given DB +func Delete(db *sqlx.DB, p Persistable) error { + _, err := db.NamedExec(p.DeleteStmt(), p) + if err != nil { + return errors.Wrapf(err, "failed deleting %v", p) + } + return nil } -type User struct { - gorm.Model - GitHubID int `gorm:"column:github_id"` - Name string - Token string +func Update(db *sqlx.DB, p Persistable) error { + _, err := db.NamedExec(p.UpdateStmt(), p) + if err != nil { + return errors.Wrapf(err, "failed updating %v", p) + } + return nil } // InitializeSchema initializes the schema for storing artifact data -func InitializeSchema(db *gorm.DB) { - db.AutoMigrate(&Repository{}, &User{}) +func InitializeSchema(db *sqlx.DB) error { + version, err := getSchemaVersion(db) + if err != nil { + return errors.Wrapf(err, "failed to determine database schema version") + } + if version != currSchemaVersion { + err := migrateSchema(db, currSchemaVersion) + if err != nil { + return errors.Wrapf(err, "failed migrating database schema from version %d to %d", version, currSchemaVersion) + } + } + _, err = db.Exec(schema) + if err != nil { + return errors.Wrapf(err, "failed initializing database schema") + } + return nil +} + +func getSchemaVersion(db *sqlx.DB) (int, error) { + _, err := db.Exec(metaSchema) + if err != nil { + return 0, errors.Wrapf(err, "failed initializing schema table") + } + var version []int + err = db.Select(&version, "SELECT version FROM schema") + if err != nil { + return 0, errors.Wrapf(err, "failed querying schema table") + } + if len(version) == 0 { + _, err := db.Exec("INSERT INTO schema (version) VALUES ($1)", currSchemaVersion) + if err != nil { + return 0, errors.Wrapf(err, "failed setting schema version in database") + } + version = append(version, currSchemaVersion) + } + return version[0], nil +} + +func migrateSchema(db *sqlx.DB, schemaVersion int) error { + return errors.New("SCHEMA MIGRATION NOT IMPLEMENTED AT THIS TIME :(") } diff --git a/persist/user.go b/persist/user.go new file mode 100644 index 000000000..bf6caeeb3 --- /dev/null +++ b/persist/user.go @@ -0,0 +1,78 @@ +// Copyright 2017 Palantir Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "fmt" + "strings" + + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" +) + +type User struct { + GithubID int `db:"github_id"` + Name string `db:"name"` + Token string `db:"token"` +} + +func (*User) InsertStmt() string { + return "INSERT INTO USERS (github_id, name, token) VALUES (:github_id, :name, :token)" +} + +func (*User) DeleteStmt() string { + return "DELETE FROM USERS WHERE github_id = :github_id" +} + +func (*User) UpdateStmt() string { + return "UPDATE USERS SET token = :token WHERE github_id = :github_id" +} + +func GetUserByName(db *sqlx.DB, name string) (*User, error) { + u := &User{} + + q := fmt.Sprintf("SELECT * FROM USERS WHERE name='%s'", name) + row := db.QueryRowx(q) + err := row.StructScan(u) + + if err != nil { + if strings.Contains(err.Error(), "no rows in result set") { + return nil, nil + } + + return nil, errors.Wrapf(err, "cannot get user %s", name) + } + + return u, nil +} + +func GetUserByID(db *sqlx.DB, id int) (*User, error) { + u := &User{} + + q := fmt.Sprintf("SELECT * FROM USERS WHERE github_id='%d'", id) + row := db.QueryRowx(q) + err := row.StructScan(u) + + if err != nil { + return nil, errors.Wrapf(err, "cannot get user %d", id) + } + + return u, nil +} + +func UpdateUserToken(db *sqlx.DB, id int, token string) error { + q := fmt.Sprintf("UPDATE USERS SET token='%s' WHERE github_id='%d'", token, id) + return db.QueryRowx(q).Err() +} diff --git a/server/endpoints/hook.go b/server/endpoints/hook.go index d5e42ce67..f4f873081 100644 --- a/server/endpoints/hook.go +++ b/server/endpoints/hook.go @@ -20,7 +20,7 @@ import ( "strings" "github.com/google/go-github/github" - "github.com/jinzhu/gorm" + "github.com/jmoiron/sqlx" "github.com/labstack/echo" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -31,7 +31,7 @@ import ( "github.com/palantir/bulldozer/server/config" ) -func Hook(db *gorm.DB, secret string) echo.HandlerFunc { +func Hook(db *sqlx.DB, secret string) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) @@ -42,18 +42,18 @@ func Hook(db *gorm.DB, secret string) echo.HandlerFunc { logger.Debugf("ProcessHook returned %+v", result) - var dbRepo persist.Repository - res := db.Where("github_id = ?", result.RepoID).First(&dbRepo) - if err := res.Error; err != nil { - return errors.Wrap(err, "cannot get repository from db") + dbRepo, err := persist.GetRepositoryByID(db, result.RepoID) + if err != nil { + return errors.Wrapf(err, "cannot get repo with id %d from database", result.RepoID) } - if res.RecordNotFound() { - return c.String(http.StatusOK, "Repository not enabled for bulldozer") + + if dbRepo == nil { + return errors.Wrapf(err, "repository with ID not enabled", result.RepoID) } - var user persist.User - if err := db.First(&dbRepo.EnabledBy).Error; err != nil { - return errors.Wrapf(err, "cannot get user %s from database", dbRepo.EnabledBy.Name) + user, err := persist.GetUserByName(db, dbRepo.EnabledBy) + if err != nil { + return errors.Wrapf(err, "cannot get user %s from database", dbRepo.EnabledBy) } ghClient := gh.FromToken(c, user.Token) diff --git a/server/endpoints/repositories.go b/server/endpoints/repositories.go index 5a36c2f3d..cac5a534b 100644 --- a/server/endpoints/repositories.go +++ b/server/endpoints/repositories.go @@ -21,7 +21,7 @@ import ( "time" "github.com/google/go-github/github" - "github.com/jinzhu/gorm" + "github.com/jmoiron/sqlx" "github.com/labstack/echo" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -41,7 +41,7 @@ type Repository struct { EnabledAt string `json:"enabledAt,omitempty"` } -func worker(c echo.Context, db *gorm.DB, wg *sync.WaitGroup, repo *github.Repository, repoc chan *Repository, user *github.User, client *gh.Client) { +func worker(c echo.Context, db *sqlx.DB, wg *sync.WaitGroup, repo *github.Repository, repoc chan *Repository, user *github.User, client *gh.Client) { logger := log.FromContext(c) defer wg.Done() @@ -57,19 +57,18 @@ func worker(c echo.Context, db *gorm.DB, wg *sync.WaitGroup, repo *github.Reposi isAdmin = perm.GetPermission() == "admin" } - var repository persist.Repository - result := db.Where("github_id = ?", repo.GetID()).First(&repository) - if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { + repository, err := persist.GetRepositoryByID(db, repo.GetID()) + if err != nil { logger.WithFields(logrus.Fields{ "repo": repo.GetFullName(), - }).Error(errors.Wrap(err, "Cannot get repository from db")) + }).Error(errors.Wrap(err, "Cannot get repository from database")) return } - if !result.RecordNotFound() { + if repository != nil { isEnabled = true - enabledBy = repository.EnabledBy.Name - enabledAt = repository.EnabledAt.Format(time.RFC3339) + enabledBy = repository.EnabledBy + enabledAt = time.Unix(repository.EnabledAt, 0).Format(time.RFC3339) } repoc <- &Repository{ @@ -83,7 +82,7 @@ func worker(c echo.Context, db *gorm.DB, wg *sync.WaitGroup, repo *github.Reposi } } -func Repositories(db *gorm.DB) echo.HandlerFunc { +func Repositories(db *sqlx.DB) echo.HandlerFunc { return func(c echo.Context) error { var repositories []*Repository var wg sync.WaitGroup @@ -123,7 +122,7 @@ func Repositories(db *gorm.DB) echo.HandlerFunc { } } -func RepositoryEnable(db *gorm.DB, webHookURL string, webHookSecret string) echo.HandlerFunc { +func RepositoryEnable(db *sqlx.DB, webHookURL string, webHookSecret string) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) @@ -137,12 +136,6 @@ func RepositoryEnable(db *gorm.DB, webHookURL string, webHookSecret string) echo return errors.Wrap(err, "cannot get current user") } - var dbUser persist.User - result := db.Where("github_id = ?", user.GetID()).First(&dbUser) - if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { - return errors.Wrap(err, "cannot get current user from db") - } - owner := c.Param("owner") name := c.Param("name") repo, _, err := client.Repositories.Get(client.Ctx, owner, name) @@ -189,20 +182,20 @@ func RepositoryEnable(db *gorm.DB, webHookURL string, webHookSecret string) echo }).Info("Created hook on repository") dbRepo := &persist.Repository{ - GitHubID: repo.GetID(), + ID: repo.GetID(), Name: repo.GetFullName(), - EnabledAt: time.Now().UTC(), - EnabledBy: dbUser, + EnabledAt: time.Now().UTC().Unix(), + EnabledBy: user.GetLogin(), HookID: hook.GetID(), } - result = db.Create(dbRepo) - if err := result.Error; err != nil { + err = persist.Put(db, dbRepo) + if err != nil { _, e := client.Repositories.DeleteHook(client.Ctx, owner, name, hook.GetID()) if e != nil { - logger.Error(errors.Wrapf(err, "cannot delete hook on %s/%s (repo not saved to db)", owner, name)) + logger.Error(errors.Wrapf(err, "cannot delete hook on %s/%s (repo not saved to DB)", owner, name)) } - return errors.Wrapf(err, "cannot add %s/%s to the db", owner, name) + return errors.Wrapf(err, "cannot add %s/%s to the database", owner, name) } data := struct { @@ -227,7 +220,7 @@ func RepositoryEnable(db *gorm.DB, webHookURL string, webHookSecret string) echo } } -func RepositoryDisable(db *gorm.DB) echo.HandlerFunc { +func RepositoryDisable(db *sqlx.DB) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) @@ -263,15 +256,13 @@ func RepositoryDisable(db *gorm.DB) echo.HandlerFunc { "user": user.GetLogin(), }).Debug("Deleting hook from repository") - var repository persist.Repository - result := db.Where("github_id = ?", repo.GetID()).First(&repository) - if err := result.Error; err != nil { - return errors.Wrap(err, "cannot get repository from db") + dbRepo, err := persist.GetRepositoryByID(db, repo.GetID()) + if err != nil { + return errors.Wrapf(err, "cannot get repo with ID %d from database", repo.GetID()) } - - _, err = client.Repositories.DeleteHook(client.Ctx, owner, name, repository.HookID) + _, err = client.Repositories.DeleteHook(client.Ctx, owner, name, dbRepo.HookID) if err != nil { - return errors.Wrapf(err, "cannot delete hook %d for %s/%s via %s", owner, name, repository.HookID, user.GetLogin()) + return errors.Wrapf(err, "cannot delete hook %d for %s/%s via %s", owner, name, dbRepo.HookID, user.GetLogin()) } logger.WithFields(logrus.Fields{ @@ -279,7 +270,8 @@ func RepositoryDisable(db *gorm.DB) echo.HandlerFunc { "user": user.GetLogin(), }).Info("Deleted hook from repository") - if err := db.Delete(&repository).Error; err != nil { + err = persist.Delete(db, dbRepo) + if err != nil { return errors.Wrapf(err, "cannot remove %s/%s from database", owner, name) } diff --git a/server/endpoints/token.go b/server/endpoints/token.go index d7ae2e817..31ba757f2 100644 --- a/server/endpoints/token.go +++ b/server/endpoints/token.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/jinzhu/gorm" + "github.com/jmoiron/sqlx" "github.com/labstack/echo" "github.com/pkg/errors" @@ -28,9 +28,10 @@ import ( "github.com/palantir/bulldozer/persist" ) -func Token(db *gorm.DB) echo.HandlerFunc { +func Token(db *sqlx.DB) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) + token, err := auth.GithubOauthConfig.Exchange(context.TODO(), c.QueryParam("code")) if err != nil { return errors.Wrap(err, "Cannot get code from GitHub") @@ -43,22 +44,22 @@ func Token(db *gorm.DB) echo.HandlerFunc { return errors.Wrap(err, "Cannot get user from token") } - var user persist.User - result := db.Where("github_id = ?", u.GetID()).First(&user) - if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { - return errors.Wrap(err, "cannot get user from db") - } - if result.RecordNotFound() { - db.Create(&persist.User{ - GitHubID: u.GetID(), + user, err := persist.GetUserByID(db, u.GetID()) + if err != nil { + dbUser := &persist.User{ + GithubID: u.GetID(), Name: u.GetLogin(), Token: accessToken, - }) + } + if err := persist.Put(db, dbUser); err != nil { + return errors.Wrapf(err, "Cannot add %s to the database", u.GetLogin()) + } } else { - logger.Infof("%+v", user) - user.Token = accessToken - if err := db.Save(&user).Error; err != nil { - return errors.Wrapf(err, "cannot update token for user %s", u.GetLogin()) + if user.Token != accessToken { + if err := persist.UpdateUserToken(db, u.GetID(), accessToken); err != nil { + return errors.Wrapf(err, "Cannot update token for user %s", u.GetLogin()) + } + logger.Debugf("Updated token for user %s", u.GetLogin()) } } diff --git a/server/init.go b/server/init.go index ebb220e84..3971bad2c 100644 --- a/server/init.go +++ b/server/init.go @@ -17,8 +17,8 @@ package server import ( "fmt" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/postgres" // Import for side-effects + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" // postgres bindings "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -26,7 +26,7 @@ import ( "github.com/palantir/bulldozer/server/config" ) -func InitDB(dbc *config.DatabaseConfig) (*gorm.DB, error) { +func InitDB(dbc *config.DatabaseConfig) (*sqlx.DB, error) { connectStr := fmt.Sprintf("host=%s dbname=%s user=%s sslmode=%s", dbc.Host, dbc.DBName, dbc.Username, dbc.SSLMode) log.WithFields(log.Fields{ "connectionString": connectStr, @@ -36,11 +36,11 @@ func InitDB(dbc *config.DatabaseConfig) (*gorm.DB, error) { connectStr += fmt.Sprintf(" password=%s", dbc.Password) } - db, err := gorm.Open("postgres", connectStr) + db, err := sqlx.Connect("postgres", connectStr) if err != nil { return nil, errors.Wrapf(err, "failed connecting to postgres") } - persist.InitializeSchema(db) - return db, nil + err = persist.InitializeSchema(db) + return db, errors.Wrap(err, "failed to initialize schema") } diff --git a/server/server.go b/server/server.go index 4acc27882..8b6766b65 100644 --- a/server/server.go +++ b/server/server.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/ipfans/echo-session" - "github.com/jinzhu/gorm" + "github.com/jmoiron/sqlx" "github.com/labstack/echo" "github.com/labstack/echo/middleware" "github.com/pkg/errors" @@ -38,7 +38,7 @@ type Server struct { e *echo.Echo } -func New(db *gorm.DB, startup *config.Startup) *Server { +func New(db *sqlx.DB, startup *config.Startup) *Server { e := echo.New() e.Use(bm.ContextMiddleware) @@ -58,7 +58,7 @@ func New(db *gorm.DB, startup *config.Startup) *Server { return &Server{startup.Server, e} } -func registerEndpoints(startup *config.Startup, e *echo.Echo, db *gorm.DB) { +func registerEndpoints(startup *config.Startup, e *echo.Echo, db *sqlx.DB) { e.Static("/", startup.AssetDir) e.GET("/repositories", func(c echo.Context) error { diff --git a/vendor/github.com/jinzhu/gorm/.codeclimate.yml b/vendor/github.com/jinzhu/gorm/.codeclimate.yml deleted file mode 100644 index 51aba50cb..000000000 --- a/vendor/github.com/jinzhu/gorm/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/vendor/github.com/jinzhu/gorm/.gitignore b/vendor/github.com/jinzhu/gorm/.gitignore deleted file mode 100644 index 01dc5ce07..000000000 --- a/vendor/github.com/jinzhu/gorm/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -documents -_book diff --git a/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md b/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md deleted file mode 100644 index c54d572d2..000000000 --- a/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md +++ /dev/null @@ -1,52 +0,0 @@ -# How to Contribute - -## Bug Report - -- Do a search on GitHub under Issues in case it has already been reported -- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE* - -## Feature Request - -- Feature request with pull request is welcome -- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work - -## Pull Request - -- Prefer single commit pull request, that make the git history can be a bit easier to follow. -- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future - -## Contributing to Documentation - -- You are welcome ;) -- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow. -- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides - -### Executable script template - -```go -package main - -import ( - _ "github.com/mattn/go-sqlite3" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/jinzhu/gorm" -) - -var db gorm.DB - -func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // db, err := gorm.Open("postgres", "user=username dbname=password sslmode=disable") - // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") - if err != nil { - panic(err) - } - db.LogMode(true) -} - -func main() { - // Your code -} -``` diff --git a/vendor/github.com/jinzhu/gorm/License b/vendor/github.com/jinzhu/gorm/License deleted file mode 100644 index 037e1653e..000000000 --- a/vendor/github.com/jinzhu/gorm/License +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2013-NOW Jinzhu - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/vendor/github.com/jinzhu/gorm/README.md b/vendor/github.com/jinzhu/gorm/README.md deleted file mode 100644 index c3f209c9d..000000000 --- a/vendor/github.com/jinzhu/gorm/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# GORM - -The fantastic ORM library for Golang, aims to be developer friendly. - -[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) - -## Overview - -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions -* Composite Primary Key -* SQL Builder -* Auto Migrations -* Logger -* Extendable, write Plugins based on GORM callbacks -* Every feature comes with tests -* Developer Friendly - -## Getting Started - -* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) - -## Upgrading To V1.0 - -* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) - -# Author - -**jinzhu** - -* -* -* - -# Contributors - -https://github.com/jinzhu/gorm/graphs/contributors - -## License - -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). diff --git a/vendor/github.com/jinzhu/gorm/association.go b/vendor/github.com/jinzhu/gorm/association.go deleted file mode 100644 index cd8fd9125..000000000 --- a/vendor/github.com/jinzhu/gorm/association.go +++ /dev/null @@ -1,359 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for _, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - } else { - // If other relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - if relationship.Kind == "many_to_many" { - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - scope.TableName(), - ) - } - - query.Model(fieldValue).Count(&count) - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/vendor/github.com/jinzhu/gorm/association_test.go b/vendor/github.com/jinzhu/gorm/association_test.go deleted file mode 100644 index 52d2303f6..000000000 --- a/vendor/github.com/jinzhu/gorm/association_test.go +++ /dev/null @@ -1,842 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - "sort" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestBelongsTo(t *testing.T) { - post := Post{ - Title: "post belongs to", - Body: "body belongs to", - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - if post.Category.ID == 0 || post.MainCategory.ID == 0 { - t.Errorf("Category's primary key should be updated") - } - - if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { - t.Errorf("post's foreign key should be updated") - } - - // Query - var category1 Category - DB.Model(&post).Association("Category").Find(&category1) - if category1.Name != "Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var mainCategory1 Category - DB.Model(&post).Association("MainCategory").Find(&mainCategory1) - if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var category11 Category - DB.Model(&post).Related(&category11) - if category11.Name != "Category 1" { - t.Errorf("Query belongs to relations with Related") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - if DB.Model(&post).Association("MainCategory").Count() != 1 { - t.Errorf("Post's main category count should be 1") - } - - // Append - var category2 = Category{ - Name: "Category 2", - } - DB.Model(&post).Association("Category").Append(&category2) - - if category2.ID == 0 { - t.Errorf("Category should has ID when created with Append") - } - - var category21 Category - DB.Model(&post).Related(&category21) - - if category21.Name != "Category 2" { - t.Errorf("Category should be updated with Append") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Replace - var category3 = Category{ - Name: "Category 3", - } - DB.Model(&post).Association("Category").Replace(&category3) - - if category3.ID == 0 { - t.Errorf("Category should has ID when created with Replace") - } - - var category31 Category - DB.Model(&post).Related(&category31) - if category31.Name != "Category 3" { - t.Errorf("Category should be updated with Replace") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Delete - DB.Model(&post).Association("Category").Delete(&category2) - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not delete any category when Delete a unrelated Category") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should not be reseted when Delete a unrelated Category") - } - - DB.Model(&post).Association("Category").Delete(&category3) - - if post.Category.Name != "" { - t.Errorf("Post's category should be reseted after Delete") - } - - var category41 Category - DB.Model(&post).Related(&category41) - if category41.Name != "" { - t.Errorf("Category should be deleted with Delete") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Delete, but got %v", count) - } - - // Clear - DB.Model(&post).Association("Category").Append(&Category{ - Name: "Category 2", - }) - - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should find category after append") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should has value after Append") - } - - DB.Model(&post).Association("Category").Clear() - - if post.Category.Name != "" { - t.Errorf("Post's category should be cleared after Clear") - } - - if !DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not find any category after Clear") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Clear, but got %v", count) - } - - // Check Association mode with soft delete - category6 := Category{ - Name: "Category 6", - } - DB.Model(&post).Association("Category").Append(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 after Append, but got %v", count) - } - - DB.Delete(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) - } - - if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { - t.Errorf("Post's category is not findable after Delete") - } - - if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { - t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) - } -} - -func TestBelongsToOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` - ProfileRefer int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestBelongsToOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Refer string - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` - ProfileID int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOne(t *testing.T) { - user := User{ - Name: "has one", - CreditCard: CreditCard{Number: "411111111111"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Error("Got errors when save user", err.Error()) - } - - if user.CreditCard.UserId.Int64 == 0 { - t.Errorf("CreditCard's foreign key should be updated") - } - - // Query - var creditCard1 CreditCard - DB.Model(&user).Related(&creditCard1) - - if creditCard1.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - var creditCard11 CreditCard - DB.Model(&user).Association("CreditCard").Find(&creditCard11) - - if creditCard11.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Append - var creditcard2 = CreditCard{ - Number: "411111111112", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard2) - - if creditcard2.ID == 0 { - t.Errorf("Creditcard should has ID when created with Append") - } - - var creditcard21 CreditCard - DB.Model(&user).Related(&creditcard21) - if creditcard21.Number != "411111111112" { - t.Errorf("CreditCard should be updated with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Replace - var creditcard3 = CreditCard{ - Number: "411111111113", - } - DB.Model(&user).Association("CreditCard").Replace(&creditcard3) - - if creditcard3.ID == 0 { - t.Errorf("Creditcard should has ID when created with Replace") - } - - var creditcard31 CreditCard - DB.Model(&user).Related(&creditcard31) - if creditcard31.Number != "411111111113" { - t.Errorf("CreditCard should be updated with Replace") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Delete - DB.Model(&user).Association("CreditCard").Delete(&creditcard2) - var creditcard4 CreditCard - DB.Model(&user).Related(&creditcard4) - if creditcard4.Number != "411111111113" { - t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Delete(&creditcard3) - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should delete credit card with Delete") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Delete") - } - - // Clear - var creditcard5 = CreditCard{ - Number: "411111111115", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard5) - - if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should added credit card with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Clear() - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Credit card should be deleted with Clear") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Clear") - } - - // Check Association mode with soft delete - var creditcard6 = CreditCard{ - Number: "411111111116", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 after Append, but got %v", count) - } - - DB.Delete(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { - t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) - } - - if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { - t.Errorf("User's creditcard is not findable after Delete") - } - - if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { - t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) - } -} - -func TestHasOneOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOneOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasMany(t *testing.T) { - post := Post{ - Title: "post has many", - Body: "body has many", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - for _, comment := range post.Comments { - if comment.PostId == 0 { - t.Errorf("comment's PostID should be updated") - } - } - - var compareComments = func(comments []Comment, contents []string) bool { - var commentContents []string - for _, comment := range comments { - commentContents = append(commentContents, comment.Content) - } - sort.Strings(commentContents) - sort.Strings(contents) - return reflect.DeepEqual(commentContents, contents) - } - - // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - - var comments1 []Comment - DB.Model(&post).Association("Comments").Find(&comments1) - if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Association") - } - - var comments11 []Comment - DB.Model(&post).Related(&comments11) - if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Related") - } - - if DB.Model(&post).Association("Comments").Count() != 2 { - t.Errorf("Post's comments count should be 2") - } - - // Append - DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) - - var comments2 []Comment - DB.Model(&post).Related(&comments2) - if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { - t.Errorf("Append new record to has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 3 { - t.Errorf("Post's comments count should be 3 after Append") - } - - // Delete - DB.Model(&post).Association("Comments").Delete(comments11) - - var comments3 []Comment - DB.Model(&post).Related(&comments3) - if !compareComments(comments3, []string{"Comment 3"}) { - t.Errorf("Delete an existing resource for has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 1 { - t.Errorf("Post's comments count should be 1 after Delete 2") - } - - // Replace - DB.Model(&Post{Id: 999}).Association("Comments").Replace() - - var comments4 []Comment - DB.Model(&post).Related(&comments4) - if len(comments4) == 0 { - t.Errorf("Replace for other resource should not clear all comments") - } - - DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) - - var comments41 []Comment - DB.Model(&post).Related(&comments41) - if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { - t.Errorf("Replace has many relations") - } - - // Clear - DB.Model(&Post{Id: 999}).Association("Comments").Clear() - - var comments5 []Comment - DB.Model(&post).Related(&comments5) - if len(comments5) == 0 { - t.Errorf("Clear should not clear all comments") - } - - DB.Model(&post).Association("Comments").Clear() - - var comments51 []Comment - DB.Model(&post).Related(&comments51) - if len(comments51) != 0 { - t.Errorf("Clear has many relations") - } - - // Check Association mode with soft delete - var comment6 = Comment{ - Content: "comment 6", - } - DB.Model(&post).Association("Comments").Append(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 after Append, but got %v", count) - } - - DB.Delete(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 0 { - t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) - } - - var comments6 []Comment - if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { - t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) - } - - if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) - } - - var comments61 []Comment - if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) - } -} - -func TestHasManyOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile []Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasManyOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestManyToMany(t *testing.T) { - DB.Raw("delete from languages") - var languages = []Language{{Name: "ZH"}, {Name: "EN"}} - user := User{Name: "Many2Many", Languages: languages} - DB.Save(&user) - - // Query - var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Query many to many relations") - } - - DB.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Should be able to find many to many relations") - } - - if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { - t.Errorf("Count should return correct result") - } - - // Append - DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) - if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { - t.Errorf("New record should be saved when append") - } - - languageA := Language{Name: "AA"} - DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) - - languageC := Language{Name: "CC"} - DB.Save(&languageC) - DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) - - totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { - t.Errorf("All appended languages should be saved") - } - - // Delete - user.Languages = []Language{} - DB.Model(&user).Association("Languages").Find(&user.Languages) - - var language Language - DB.Where("name = ?", "EE").First(&language) - DB.Model(&user).Association("Languages").Delete(language, &language) - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { - t.Errorf("Relations should be deleted with Delete") - } - if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { - t.Errorf("Language EE should not be deleted") - } - - DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) - - user2 := User{Name: "Many2Many_User2", Languages: languages} - DB.Save(&user2) - - DB.Model(&user).Association("Languages").Delete(languages, &languages) - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { - t.Errorf("Relations should be deleted with Delete") - } - - if DB.Model(&user2).Association("Languages").Count() == 0 { - t.Errorf("Other user's relations should not be deleted") - } - - // Replace - var languageB Language - DB.Where("name = ?", "BB").First(&languageB) - DB.Model(&user).Association("Languages").Replace(languageB) - if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { - t.Errorf("Relations should be replaced") - } - - DB.Model(&user).Association("Languages").Replace() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be replaced with empty") - } - - DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) - if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { - t.Errorf("Relations should be replaced") - } - - // Clear - DB.Model(&user).Association("Languages").Clear() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be cleared") - } - - // Check Association mode with soft delete - var language6 = Language{ - Name: "language 6", - } - DB.Model(&user).Association("Languages").Append(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 after Append, but got %v", count) - } - - DB.Delete(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 0 { - t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) - } - - var languages6 []Language - if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { - t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) - } - - if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) - } - - var languages61 []Language - if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) - } -} - -func TestRelated(t *testing.T) { - user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - CreditCard: CreditCard{Number: "1234567890"}, - Company: Company{Name: "company1"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Errorf("No error should happen when saving user") - } - - if user.CreditCard.ID == 0 { - t.Errorf("After user save, credit card should have id") - } - - if user.BillingAddress.ID == 0 { - t.Errorf("After user save, billing address should have id") - } - - if user.Emails[0].Id == 0 { - t.Errorf("After user save, billing address should have id") - } - - var emails []Email - DB.Model(&user).Related(&emails) - if len(emails) != 2 { - t.Errorf("Should have two emails") - } - - var emails2 []Email - DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) - if len(emails2) != 1 { - t.Errorf("Should have two emails") - } - - var emails3 []*Email - DB.Model(&user).Related(&emails3) - if len(emails3) != 2 { - t.Errorf("Should have two emails") - } - - var user1 User - DB.Model(&user).Related(&user1.Emails) - if len(user1.Emails) != 2 { - t.Errorf("Should have only one email match related condition") - } - - var address1 Address - DB.Model(&user).Related(&address1, "BillingAddressId") - if address1.Address1 != "Billing Address - Address 1" { - t.Errorf("Should get billing address from user correctly") - } - - user1 = User{} - DB.Model(&address1).Related(&user1, "BillingAddressId") - if DB.NewRecord(user1) { - t.Errorf("Should get user from address correctly") - } - - var user2 User - DB.Model(&emails[0]).Related(&user2) - if user2.Id != user.Id || user2.Name != user.Name { - t.Errorf("Should get user from email correctly") - } - - var creditcard CreditCard - var user3 User - DB.First(&creditcard, "number = ?", "1234567890") - DB.Model(&creditcard).Related(&user3) - if user3.Id != user.Id || user3.Name != user.Name { - t.Errorf("Should get user from credit card correctly") - } - - if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { - t.Errorf("RecordNotFound for Related") - } - - var company Company - if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { - t.Errorf("RecordNotFound for Related") - } -} - -func TestForeignKey(t *testing.T) { - for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Email{}).GetStructFields() { - for _, foreignKey := range []string{"UserId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Post{}).GetStructFields() { - for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { - for _, foreignKey := range []string{"PostId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback.go b/vendor/github.com/jinzhu/gorm/callback.go deleted file mode 100644 index 93198a71a..000000000 --- a/vendor/github.com/jinzhu/gorm/callback.go +++ /dev/null @@ -1,237 +0,0 @@ -package gorm - -import ( - "fmt" -) - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} - -// Callback is a struct that contains all CURD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone() *Callback { - return &Callback{ - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... -// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor - } - } - return nil -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CURD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/vendor/github.com/jinzhu/gorm/callback_create.go b/vendor/github.com/jinzhu/gorm/callback_create.go deleted file mode 100644 index e3cd2f0b4..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_create.go +++ /dev/null @@ -1,144 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := NowFunc() - scope.SetColumn("CreatedAt", now) - scope.SetColumn("UpdatedAt", now) - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal { - if !field.IsPrimaryKey || !field.IsBlank { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", - quotedTableName, - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", - scope.QuotedTableName(), - strings.Join(columns, ","), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - } else { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - scope.db.RowsAffected = 1 - } - } - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_delete.go b/vendor/github.com/jinzhu/gorm/callback_delete.go deleted file mode 100644 index c8ffcc821..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_delete.go +++ /dev/null @@ -1,53 +0,0 @@ -package gorm - -import "fmt" - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET deleted_at=%v%v%v", - scope.QuotedTableName(), - scope.AddToVars(NowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query.go b/vendor/github.com/jinzhu/gorm/callback_query.go deleted file mode 100644 index 93782b1dc..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = reflect.Indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go deleted file mode 100644 index 5746f533a..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query_preload.go +++ /dev/null @@ -1,310 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - currentFields = currentScope.Fields() - } - } - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { - objectField := object.FieldByName(field.Name) - objectField.Set(reflect.Append(objectField, result)) - break - } - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Select("*") - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) - } - } else if indirectScopeValue.IsValid() { - fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) - } - - for source, link := range linkHash { - fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_save.go b/vendor/github.com/jinzhu/gorm/callback_save.go deleted file mode 100644 index 5ffe53b97..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_save.go +++ /dev/null @@ -1,92 +0,0 @@ -package gorm - -import "reflect" - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) - } - - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) - } - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_system_test.go b/vendor/github.com/jinzhu/gorm/callback_system_test.go deleted file mode 100644 index 13ca3f428..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_system_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package gorm - -import ( - "reflect" - "runtime" - "strings" - "testing" -) - -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) -} - -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} - -func TestRegisterCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } -} - -func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } -} - -func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{} - - callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) - - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } -} - -func replaceCreate(s *Scope) {} - -func TestReplaceCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } -} - -func TestRemoveCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_update.go b/vendor/github.com/jinzhu/gorm/callback_update.go deleted file mode 100644 index aa27b5fb7..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_update.go +++ /dev/null @@ -1,104 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callbacks_test.go b/vendor/github.com/jinzhu/gorm/callbacks_test.go deleted file mode 100644 index a58913d76..000000000 --- a/vendor/github.com/jinzhu/gorm/callbacks_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package gorm_test - -import ( - "errors" - - "github.com/jinzhu/gorm" - - "reflect" - "testing" -) - -func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return -} - -func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return -} - -func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return -} - -func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 -} - -func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) -} - -func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 -} - -func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return -} - -func (s *Product) BeforeDelete() (err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return -} - -func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return -} - -func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} -} - -func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) - - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } - - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } - - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } -} - -func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } - - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } - - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } - - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) - - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } - - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before save callbacks happened when update with invalid value") - } - - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } - - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks happened") - } - - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } - - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } - - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } -} diff --git a/vendor/github.com/jinzhu/gorm/create_test.go b/vendor/github.com/jinzhu/gorm/create_test.go deleted file mode 100644 index dc82de50d..000000000 --- a/vendor/github.com/jinzhu/gorm/create_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "testing" - "time" -) - -func TestCreate(t *testing.T) { - float := 35.03554004971999 - user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Save(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - if DB.NewRecord(user) || DB.NewRecord(&user) { - t.Error("User should not new record after save") - } - - var newUser User - DB.First(&newUser, user.Id) - - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { - t.Errorf("User's PasswordHash should be saved ([]byte)") - } - - if newUser.Age != 18 { - t.Errorf("User's Age should be saved (int)") - } - - if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") - } - - if newUser.Latitude != float { - t.Errorf("Float64 should not be changed after save") - } - - if user.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - if newUser.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - DB.Model(user).Update("name", "create_user_new_name") - DB.First(&user, user.Id) - if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { - t.Errorf("CreatedAt should not be changed after update") - } -} - -func TestCreateWithNoGORMPrimayKey(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { - t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") - } - - jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error - if err != nil { - t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) - } -} - -func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create a record without std primary key") - } - - if animal.Counter == 0 { - t.Errorf("No std primary key should be filled value after create") - } - - if animal.Name != "Ferdinand" { - t.Errorf("Default value should be overrided") - } - - // Test create with default value not overrided - an := Animal{From: "nerdz"} - - if DB.Save(&an).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") - } - - // We must fetch the value again, to have the default fields updated - // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" - DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) - - if an.Name != "galeone" { - t.Errorf("Default value should fill the field. But got %v", an.Name) - } -} - -func TestAnonymousScanner(t *testing.T) { - user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_scanner") - if user2.Role.Name != "admin" { - t.Errorf("Should be able to get anonymous scanner") - } - - if !user2.IsAdmin() { - t.Errorf("Should be able to get anonymous scanner") - } -} - -func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Company: Company{Name: "company"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_field") - DB.Model(&user2).Related(&user2.Company) - if user2.Company.Name != "company" { - t.Errorf("Should be able to get anonymous field") - } -} - -func TestSelectWithCreate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name != user.Name || queryuser.Age == user.Age { - t.Errorf("Should only create users with name column") - } - - if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || - queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { - t.Errorf("Should only create selected relationships") - } -} - -func TestOmitWithCreate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name == user.Name || queryuser.Age != user.Age { - t.Errorf("Should only create users with age column") - } - - if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || - queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { - t.Errorf("Should not create omited relationships") - } -} diff --git a/vendor/github.com/jinzhu/gorm/customize_column_test.go b/vendor/github.com/jinzhu/gorm/customize_column_test.go deleted file mode 100644 index 177b4a5de..000000000 --- a/vendor/github.com/jinzhu/gorm/customize_column_test.go +++ /dev/null @@ -1,280 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date time.Time `gorm:"column:mapped_time"` -} - -// Make sure an ignored field does not interfere with another field's custom -// column name that matches the ignored field. -type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` -} - -func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } - - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } - - expected := "foo" - cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, 666) - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } -} - -type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` -} - -type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string -} - -func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } - - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } - - if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } -} - -type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` -} - -type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` -} - -func TestOneToOneWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } - - DB.Create(&user) - DB.Create(&invitation) - - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } -} - -type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` -} - -type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` -} - -type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount -} - -type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount -} - -func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } - - var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } -} - -func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } - - var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } -} - -func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } - - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } -} diff --git a/vendor/github.com/jinzhu/gorm/delete_test.go b/vendor/github.com/jinzhu/gorm/delete_test.go deleted file mode 100644 index d3de0a6d9..000000000 --- a/vendor/github.com/jinzhu/gorm/delete_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" -) - -func TestDelete(t *testing.T) { - user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if err := DB.Delete(&user1).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } - - if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("Other users that not deleted should be found-able") - } -} - -func TestInlineDelete(t *testing.T) { - user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if DB.Delete(&User{}, user1.Id).Error != nil { - t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } -} - -func TestSoftDelete(t *testing.T) { - type User struct { - Id int64 - Name string - DeletedAt *time.Time - } - DB.AutoMigrate(&User{}) - - user := User{Name: "soft_delete"} - DB.Save(&user) - DB.Delete(&user) - - if DB.First(&User{}, "name = ?", user.Name).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} diff --git a/vendor/github.com/jinzhu/gorm/dialect.go b/vendor/github.com/jinzhu/gorm/dialect.go deleted file mode 100644 index 6c9405da3..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect.go +++ /dev/null @@ -1,100 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db *sql.DB) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset int) string - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db *sql.DB) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// ParseFieldStructForDialect parse field struct for dialect -func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var reflectType = field.Struct.Type - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - - // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - - return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_common.go b/vendor/github.com/jinzhu/gorm/dialect_common.go deleted file mode 100644 index f009271b3..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_common.go +++ /dev/null @@ -1,137 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strings" - "time" -) - -type commonDialect struct { - db *sql.DB -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db *sql.DB) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) currentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { - if limit > 0 || offset > 0 { - if limit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", limit) - } - if offset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", offset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/vendor/github.com/jinzhu/gorm/dialect_mysql.go deleted file mode 100644 index 6fade59d2..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_mysql.go +++ /dev/null @@ -1,113 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" - } else { - sqlType = "timestamp NULL" - } - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) currentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/vendor/github.com/jinzhu/gorm/dialect_postgres.go deleted file mode 100644 index 09ac59616..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_postgres.go +++ /dev/null @@ -1,132 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if isByteArrayOrSlice(dataValue) { - sqlType = "bytea" - } else if isUUID(dataValue) { - sqlType = "uuid" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) currentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go deleted file mode 100644 index 5c262aaf2..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go +++ /dev/null @@ -1,106 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite", &sqlite3{}) - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) currentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go b/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go deleted file mode 100644 index adeeec7bf..000000000 --- a/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go +++ /dev/null @@ -1,54 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} diff --git a/vendor/github.com/jinzhu/gorm/embedded_struct_test.go b/vendor/github.com/jinzhu/gorm/embedded_struct_test.go deleted file mode 100644 index 7be75d990..000000000 --- a/vendor/github.com/jinzhu/gorm/embedded_struct_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package gorm_test - -import "testing" - -type BasePost struct { - Id int64 - Title string - URL string -} - -type HNPost struct { - BasePost - Upvotes int32 -} - -type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - ImageUrl string -} - -func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } - - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/errors.go b/vendor/github.com/jinzhu/gorm/errors.go deleted file mode 100644 index ce3a25c0f..000000000 --- a/vendor/github.com/jinzhu/gorm/errors.go +++ /dev/null @@ -1,58 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -type errorsInterface interface { - GetErrors() []error -} - -// Errors contains all happened errors -type Errors struct { - errors []error -} - -// GetErrors get all happened errors -func (errs Errors) GetErrors() []error { - return errs.errors -} - -// Add add an error -func (errs *Errors) Add(err error) { - if errors, ok := err.(errorsInterface); ok { - for _, err := range errors.GetErrors() { - errs.Add(err) - } - } else { - for _, e := range errs.errors { - if err == e { - return - } - } - errs.errors = append(errs.errors, err) - } -} - -// Error format happened errors -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs.errors { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/vendor/github.com/jinzhu/gorm/field.go b/vendor/github.com/jinzhu/gorm/field.go deleted file mode 100644 index 11c410b0f..000000000 --- a/vendor/github.com/jinzhu/gorm/field.go +++ /dev/null @@ -1,58 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/vendor/github.com/jinzhu/gorm/field_test.go b/vendor/github.com/jinzhu/gorm/field_test.go deleted file mode 100644 index 30e9a778d..000000000 --- a/vendor/github.com/jinzhu/gorm/field_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -type CalculateField struct { - gorm.Model - Name string - Children []CalculateFieldChild - Category CalculateFieldCategory - EmbeddedField -} - -type EmbeddedField struct { - EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` -} - -type CalculateFieldChild struct { - gorm.Model - CalculateFieldID uint - Name string -} - -type CalculateFieldCategory struct { - gorm.Model - CalculateFieldID uint - Name string -} - -func TestCalculateField(t *testing.T) { - var field CalculateField - var scope = DB.NewScope(&field) - if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("embedded_name"); !ok { - t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { - t.Errorf("should find embedded field's tag settings") - } -} diff --git a/vendor/github.com/jinzhu/gorm/interface.go b/vendor/github.com/jinzhu/gorm/interface.go deleted file mode 100644 index 7b02aa664..000000000 --- a/vendor/github.com/jinzhu/gorm/interface.go +++ /dev/null @@ -1,19 +0,0 @@ -package gorm - -import "database/sql" - -type sqlCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/vendor/github.com/jinzhu/gorm/join_table_handler.go b/vendor/github.com/jinzhu/gorm/join_table_handler.go deleted file mode 100644 index 18c12a859..000000000 --- a/vendor/github.com/jinzhu/gorm/join_table_handler.go +++ /dev/null @@ -1,204 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return s.TableName -} - -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } - } - return values -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range searchMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - ) - - for key, value := range s.getSearchMap(db, sources...) { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/vendor/github.com/jinzhu/gorm/join_table_test.go b/vendor/github.com/jinzhu/gorm/join_table_test.go deleted file mode 100644 index 1a83a9c87..000000000 --- a/vendor/github.com/jinzhu/gorm/join_table_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm_test - -import ( - "fmt" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type Person struct { - Id int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` -} - -type PersonAddress struct { - gorm.JoinTableHandler - PersonID int - AddressID int - DeletedAt *time.Time - CreatedAt time.Time -} - -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - return db.Where(map[string]interface{}{ - "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), - "address_id": db.NewScope(associationValue).PrimaryKeyValue(), - }).Assign(map[string]interface{}{ - "person_id": foreignValue, - "address_id": associationValue, - "deleted_at": gorm.Expr("NULL"), - }).FirstOrCreate(&PersonAddress{}).Error -} - -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error -} - -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { - table := pa.Table(db) - return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) -} - -func TestJoinTable(t *testing.T) { - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &Person{Name: "person", Addresses: []*Address{address1, address2}} - DB.Save(person) - - DB.Model(person).Association("Addresses").Delete(address1) - - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { - t.Errorf("Should found one address") - } - - if DB.Model(person).Association("Addresses").Count() != 1 { - t.Errorf("Should found one address") - } - - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { - t.Errorf("Found two addresses with Unscoped") - } - - if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} diff --git a/vendor/github.com/jinzhu/gorm/logger.go b/vendor/github.com/jinzhu/gorm/logger.go deleted file mode 100644 index 2c4ccbbc4..000000000 --- a/vendor/github.com/jinzhu/gorm/logger.go +++ /dev/null @@ -1,99 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) -) - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - if len(values) > 1 { - level := values[0] - currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - messages := []interface{}{source, currentTime} - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - var sql string - var formattedValues []string - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - - var formattedValuesLength = len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - - messages = append(messages, sql) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - logger.Println(messages...) - } -} - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} diff --git a/vendor/github.com/jinzhu/gorm/main.go b/vendor/github.com/jinzhu/gorm/main.go deleted file mode 100644 index cd4455551..000000000 --- a/vendor/github.com/jinzhu/gorm/main.go +++ /dev/null @@ -1,700 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -// DB contains information for current db connection -type DB struct { - Value interface{} - Error error - RowsAffected int64 - callbacks *Callback - db sqlCommon - parent *DB - search *search - logMode int - logger logger - dialect Dialect - singularTable bool - source string - values map[string]interface{} - joinTableHandlers map[string]JoinTableHandler -} - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (*DB, error) { - var db DB - var err error - - if len(args) == 0 { - err = errors.New("invalid database source") - } else { - var source string - var dbSQL sqlCommon - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - case sqlCommon: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSQL = value - } - - db = DB{ - dialect: newDialect(dialect, dbSQL.(*sql.DB)), - logger: defaultLogger, - callbacks: DefaultCallback, - source: source, - values: map[string]interface{}{}, - db: dbSQL, - } - db.parent = &db - - if err == nil { - err = db.DB().Ping() // Send a ping to make sure the database connection is alive. - } - } - - return &db, err -} - -// Close close current db connection -func (s *DB) Close() error { - return s.parent.db.(*sql.DB).Close() -} - -// DB get `*sql.DB` from current connection -func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() sqlCommon { - return s.db -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = 2 - } else { - s.logMode = 1 - } - return s -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() - s.parent.singularTable = enable -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit int) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset int) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -func (s *DB) Order(value string, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query string, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/curd.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - clone = s.clone() - scope = clone.NewScope(result) - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error) - } else if len(c.search.assignAttrs) > 0 { - c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error) - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) - if scope.PrimaryKeyZero() { - return scope.callCallbacks(s.parent.callbacks.creates).db - } - return scope.callCallbacks(s.parent.callbacks.updates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok { - tx, err := db.Begin() - c.db = interface{}(tx).(sqlCommon) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok { - s.AddError(db.Rollback()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.clone().NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - scope := s.clone().NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors{errors: s.GetErrors()} - errors.Add(err) - if len(errors.GetErrors()) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() (errors []error) { - if errs, ok := s.Error.(errorsInterface); ok { - return errs.GetErrors() - } else if s.Error != nil { - return []error{s.Error} - } - return -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} - - for key, value := range s.values { - db.values[key] = value - } - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = &db - return &db -} - -func (s *DB) print(v ...interface{}) { - s.logger.(logger).Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) - } -} diff --git a/vendor/github.com/jinzhu/gorm/main_test.go b/vendor/github.com/jinzhu/gorm/main_test.go deleted file mode 100644 index 8ac015c8d..000000000 --- a/vendor/github.com/jinzhu/gorm/main_test.go +++ /dev/null @@ -1,774 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "os" - "reflect" - "strconv" - "testing" - "time" - - "github.com/erikstmartin/go-testdb" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/now" -) - -var ( - DB *gorm.DB - t1, t2, t3, t4, t5 time.Time -) - -func init() { - var err error - - if DB, err = OpenTestConnection(); err != nil { - panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) - } - - // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) - // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { - DB.LogMode(true) - } - - DB.DB().SetMaxIdleConns(10) - - runMigration() -} - -func OpenTestConnection() (db *gorm.DB, err error) { - switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") - case "postgres": - fmt.Println("testing postgres...") - db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") - case "mssql": - fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") - default: - fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", "/tmp/gorm.db") - } - return -} - -func TestStringPrimaryKey(t *testing.T) { - type UUIDStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.AutoMigrate(&UUIDStruct{}) - - data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { - t.Errorf("string primary key should not be populated") - } -} - -func TestExceptionsWithInvalidSql(t *testing.T) { - var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - var count1, count2 int64 - DB.Model(&User{}).Count(&count1) - if count1 <= 0 { - t.Errorf("Should find some users") - } - - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - DB.Model(&User{}).Count(&count2) - if count1 != count2 { - t.Errorf("No user should not be deleted by invalid SQL") - } -} - -func TestSetTable(t *testing.T) { - DB.Create(getPreparedUser("pluck_user1", "pluck_user")) - DB.Create(getPreparedUser("pluck_user2", "pluck_user")) - DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Error("No errors should happen if set table for pluck", err) - } - - var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - - if DB.Table("invalid_table").Find(&users).Error == nil { - t.Errorf("Should got error when table is set to an invalid table") - } - - DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { - t.Errorf("Create table with specified table") - } - - DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) - - var deletedUsers []User - DB.Table("deleted_users").Find(&deletedUsers) - if len(deletedUsers) != 1 { - t.Errorf("Query from specified table") - } - - DB.Save(getPreparedUser("normal_user", "reset_table")) - DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) - var user1, user2, user3 User - DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { - t.Errorf("unset specified table with blank string") - } -} - -type Order struct { -} - -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func TestHasTable(t *testing.T) { - type Foo struct { - Id int - Stuff string - } - DB.DropTable(&Foo{}) - - // Table should not exist at this point, HasTable should return false - if ok := DB.HasTable("foos"); ok { - t.Errorf("Table should not exist, but does") - } - if ok := DB.HasTable(&Foo{}); ok { - t.Errorf("Table should not exist, but does") - } - - // We create the table - if err := DB.CreateTable(&Foo{}).Error; err != nil { - t.Errorf("Table should be created") - } - - // And now it should exits, and HasTable should return true - if ok := DB.HasTable("foos"); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } - if ok := DB.HasTable(&Foo{}); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } -} - -func TestTableName(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - if DB.NewScope(&Order{}).TableName() != "orders" { - t.Errorf("&Order's table name should be orders") - } - - if DB.NewScope([]Order{}).TableName() != "orders" { - t.Errorf("[]Order's table name should be orders") - } - - if DB.NewScope(&[]Order{}).TableName() != "orders" { - t.Errorf("&[]Order's table name should be orders") - } - - DB.SingularTable(true) - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - if DB.NewScope(&Order{}).TableName() != "order" { - t.Errorf("&Order's singular table name should be order") - } - - if DB.NewScope([]Order{}).TableName() != "order" { - t.Errorf("[]Order's singular table name should be order") - } - - if DB.NewScope(&[]Order{}).TableName() != "order" { - t.Errorf("&[]Order's singular table name should be order") - } - - if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { - t.Errorf("&Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(Cart{}).TableName() != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { - t.Errorf("&[]Cart's singular table name should be shopping_cart") - } - - if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { - t.Errorf("[]Cart's singular table name should be shopping_cart") - } - DB.SingularTable(false) -} - -func TestNullValues(t *testing.T) { - DB.DropTable(&NullValue{}) - DB.AutoMigrate(&NullValue{}) - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv NullValue - DB.First(&nv, "name = ?", "hello") - - if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-2", Valid: true}, - Gender: &sql.NullString{String: "F", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv2 NullValue - DB.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-3", Valid: false}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { - t.Errorf("Can't save because of name can't be null") - } -} - -func TestNullValuesWithFirstOrCreate(t *testing.T) { - var nv1 = NullValue{ - Name: sql.NullString{String: "first_or_create", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - } - - var nv2 NullValue - if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { - t.Errorf("first or create with nullvalues") - } - - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Age.Int64 != 18 { - t.Errorf("should update age to 18") - } -} - -func TestTransaction(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - tx.Rollback() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx2 := DB.Begin() - u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx2.Commit() - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} - -func TestRow(t *testing.T) { - user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() - var age int64 - row.Scan(&age) - if age != 10 { - t.Errorf("Scan with Row") - } -} - -func TestRows(t *testing.T) { - user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ - } - - if count != 2 { - t.Errorf("Should found two records") - } -} - -func TestScanRows(t *testing.T) { - user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - type Result struct { - Name string - Age int - } - - var results []Result - for rows.Next() { - var result Result - if err := DB.ScanRows(rows, &result); err != nil { - t.Errorf("should get no error, but got %v", err) - } - results = append(results, result) - } - - if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") - } -} - -func TestScan(t *testing.T) { - user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Age int - } - - var res result - DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) - if res.Name != user3.Name { - t.Errorf("Scan into struct should work") - } - - var doubleAgeRes result - DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes) - if doubleAgeRes.Age != res.Age*2 { - t.Errorf("Scan double age as age") - } - - var ress []result - DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Scan into struct map") - } -} - -func TestRaw(t *testing.T) { - user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Email string - } - - var ress []result - DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Raw with scan") - } - - rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() - count := 0 - for rows.Next() { - count++ - } - if count != 1 { - t.Errorf("Raw with Rows should find one record with name 3") - } - - DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { - t.Error("Raw sql to update records") - } -} - -func TestGroup(t *testing.T) { - rows, err := DB.Select("name").Table("users").Group("name").Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - rows.Scan(&name) - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestJoins(t *testing.T) { - var user = User{ - Name: "joins", - CreditCard: CreditCard{Number: "411111111111"}, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users1 []User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) - if len(users1) != 2 { - t.Errorf("should find two users using left join") - } - - var users2 []User - DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) - if len(users2) != 1 { - t.Errorf("should find one users using left join with conditions") - } - - var users3 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) - if len(users3) != 1 { - t.Errorf("should find one users using multiple left join conditions") - } - - var users4 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) - if len(users4) != 0 { - t.Errorf("should find no user when searching with unexisting credit card") - } -} - -func TestJoinsWithSelect(t *testing.T) { - type result struct { - Name string - Email string - } - - user := User{ - Name: "joins_with_select", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var results []result - DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join select") - } -} - -func TestHaving(t *testing.T) { - rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - var total int64 - rows.Scan(&name, &total) - - if name == "2" && total != 1 { - t.Errorf("Should have one user having name 2") - } - if name == "3" && total != 2 { - t.Errorf("Should have two users having name 3") - } - } - } else { - t.Errorf("Should not raise any error") - } -} - -func DialectHasTzSupport() bool { - // NB: mssql and FoundationDB do not support time zones. - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { - return false - } - return true -} - -func TestTimeWithZone(t *testing.T) { - var format = "2006-01-02 15:04:05 -0700" - var times []time.Time - GMT8, _ := time.LoadLocation("Asia/Shanghai") - times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) - times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - - for index, vtime := range times { - name := "time_with_zone_" + strconv.Itoa(index) - user := User{Name: name, Birthday: vtime} - - if !DialectHasTzSupport() { - // If our driver dialect doesn't support TZ's, just use UTC for everything here. - user.Birthday = vtime.UTC() - } - - DB.Save(&user) - expectedBirthday := "2013-02-18 17:51:49 +0000" - foundBirthday := user.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - var findUser, findUser2, findUser3 User - DB.First(&findUser, "name = ?", name) - foundBirthday = findUser.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { - t.Errorf("User should be found") - } - - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { - t.Errorf("User should not be found") - } - } -} - -func TestHstore(t *testing.T) { - type Details struct { - Id int64 - Bulk postgres.Hstore - } - - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip() - } - - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { - fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") - panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) - } - - DB.Exec("drop table details") - - if err := DB.CreateTable(&Details{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" - bulk := map[string]*string{ - "bankAccountId": &bankAccountId, - "phoneNumber": &phoneNumber, - "opinion": &opinion, - } - d := Details{Bulk: bulk} - DB.Save(&d) - - var d2 Details - if err := DB.First(&d2).Error; err != nil { - t.Errorf("Got error when tried to fetch details: %+v", err) - } - - for k := range bulk { - if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") - } - } else { - t.Errorf("Details should be existed") - } - } -} - -func TestSetAndGet(t *testing.T) { - if value, ok := DB.Set("hello", "world").Get("hello"); !ok { - t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } - } - - if _, ok := DB.Get("non_existing"); ok { - t.Errorf("Get non existing key should return error") - } -} - -func TestCompatibilityMode(t *testing.T) { - DB, _ := gorm.Open("testdb", "") - testdb.SetQueryFunc(func(query string) (driver.Rows, error) { - columns := []string{"id", "name", "age"} - result := ` - 1,Tim,20 - 2,Joe,25 - 3,Bob,30 - ` - return testdb.RowsFromCSVString(columns, result), nil - }) - - var users []User - DB.Find(&users) - if (users[0].Name != "Tim") || len(users) != 3 { - t.Errorf("Unexcepted result returned") - } -} - -func TestOpenExistingDB(t *testing.T) { - DB.Save(&User{Name: "jnfeinstein"}) - dialect := os.Getenv("GORM_DIALECT") - - db, err := gorm.Open(dialect, DB.DB()) - if err != nil { - t.Errorf("Should have wrapped the existing DB connection") - } - - var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { - t.Errorf("Should have found existing record") - } -} - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} - -func BenchmarkGorm(b *testing.B) { - b.N = 2000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - DB.Save(&email) - // Query - DB.First(&BigEmail{}, "email = ?", e) - // Update - DB.Model(&email).UpdateColumn("email", "new-"+e) - // Delete - DB.Delete(&email) - } -} - -func BenchmarkRawSql(b *testing.B) { - DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - DB.SetMaxIdleConns(10) - insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - deleteSql := "DELETE FROM orders WHERE id = $1" - - b.N = 2000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + "benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) - // Query - rows, _ := DB.Query(querySql, email.Email) - rows.Close() - // Update - DB.Exec(updateSql, "new-"+e, time.Now(), id) - // Delete - DB.Exec(deleteSql, id) - } -} diff --git a/vendor/github.com/jinzhu/gorm/migration_test.go b/vendor/github.com/jinzhu/gorm/migration_test.go deleted file mode 100644 index 38e5c1c2e..000000000 --- a/vendor/github.com/jinzhu/gorm/migration_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Role struct { - Name string -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -} - -func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } - - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } - - values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} - for _, value := range values { - DB.DropTable(value) - } - - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } -} - -func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } - - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } - - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } - - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } -} - -type BigEmail struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func (b BigEmail) TableName() string { - return "emails" -} - -func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) - - scope := DB.NewScope(&BigEmail{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { - t.Errorf("Failed to create index") - } - - var bigemail BigEmail - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } -} diff --git a/vendor/github.com/jinzhu/gorm/model.go b/vendor/github.com/jinzhu/gorm/model.go deleted file mode 100644 index f37ff7eaa..000000000 --- a/vendor/github.com/jinzhu/gorm/model.go +++ /dev/null @@ -1,14 +0,0 @@ -package gorm - -import "time" - -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} diff --git a/vendor/github.com/jinzhu/gorm/model_struct.go b/vendor/github.com/jinzhu/gorm/model_struct.go deleted file mode 100644 index 6df615d1b..000000000 --- a/vendor/github.com/jinzhu/gorm/model_struct.go +++ /dev/null @@ -1,542 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - defaultTableName string -} - -// TableName get model's table name -func (s *ModelStruct) TableName(db *DB) string { - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship -} - -func (structField *StructField) clone() *StructField { - return &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, - TagSettings: structField.TagSettings, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - Relationship: structField.Relationship, - } -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value - } - - modelStruct.ModelType = reflectType - - // Set default table name - if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { - modelStruct.defaultTableName = tabler.TableName() - } else { - tableName := ToDBName(reflectType.Name()) - if scope.db == nil || !scope.db.parent.singularTable { - tableName = inflection.Plural(tableName) - } - modelStruct.defaultTableName = tableName - } - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if fieldStruct.Tag.Get("sql") == "-" { - field.IsIgnored = true - } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetStructFields() { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if subField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - foreignField.IsForeignKey = true - // source foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - foreignField.IsForeignKey = true - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = value - } else { - field.DBName = ToDBName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Set(reflectType, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go b/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go deleted file mode 100644 index 8b275d182..000000000 --- a/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go +++ /dev/null @@ -1,381 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "sort" - "testing" -) - -type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` - SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` - LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` -} - -type Tag struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` -} - -func compareTags(tags []Tag, contents []string) bool { - var tagContents []string - for _, tag := range tags { - tagContents = append(tagContents, tag.Value) - } - sort.Strings(tagContents) - sort.Strings(contents) - return reflect.DeepEqual(tagContents, contents) -} - -func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - Tags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - - DB.Save(&blog) - if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) - if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("Tags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "Tags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("Tags").Find(&blog1) - if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog).Association("Tags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "Tags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("Tags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("Tags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "Tags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("Tags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog).Association("Tags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "Tags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog).Association("Tags").Clear() - if DB.Model(&blog).Association("Tags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("shared_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - SharedTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) - if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("SharedTags").Find(&blog1) - if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("SharedTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - DB.Model(&blog2).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("SharedTags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "SharedTags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog2).Association("SharedTags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "SharedTags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog2).Association("SharedTags").Clear() - if DB.Model(&blog).Association("SharedTags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("locale_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - LocaleTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) - if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog should has 0 tags after ZH Blog Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if len(tags) != 0 { - t.Errorf("Should find 0 tags with Related for EN Blog") - } - - var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) - if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("LocaleTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related for EN Blog") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag4"}) { - t.Errorf("Should find 1 tags with Related for EN Blog") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) - - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) - if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - DB.Model(&blog2).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) - if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { - t.Errorf("EN Blog's tags should be changed after Replace") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Replace") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after Replace") - } - - // Delete - DB.Model(&blog).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") - } - - DB.Model(&blog2).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { - t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") - } - - // Clear - DB.Model(&blog2).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") - } - - DB.Model(&blog).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 0 { - t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared") - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/pointer_test.go b/vendor/github.com/jinzhu/gorm/pointer_test.go deleted file mode 100644 index 2a68a5ab2..000000000 --- a/vendor/github.com/jinzhu/gorm/pointer_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm_test - -import "testing" - -type PointerStruct struct { - ID int64 - Name *string - Num *int -} - -type NormalStruct struct { - ID int64 - Name string - Num int -} - -func TestPointerFields(t *testing.T) { - DB.DropTable(&PointerStruct{}) - DB.AutoMigrate(&PointerStruct{}) - var name = "pointer struct 1" - var num = 100 - pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { - t.Errorf("Failed to save pointer struct") - } - - var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { - t.Errorf("Failed to query saved pointer struct") - } - - var tableName = DB.NewScope(&PointerStruct{}).TableName() - - var normalStruct NormalStruct - DB.Table(tableName).First(&normalStruct) - if normalStruct.Name != name || normalStruct.Num != num { - t.Errorf("Failed to query saved Normal struct") - } - - var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Error("Failed to save nil pointer struct", err) - } - - var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var partialNilPointerStruct1 = PointerStruct{Num: &num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct3 NormalStruct - if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Error("Failed to query saved partial pointer struct", err) - } - - var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Error("Failed to query saved partial pointer struct", err) - } -} diff --git a/vendor/github.com/jinzhu/gorm/polymorphic_test.go b/vendor/github.com/jinzhu/gorm/polymorphic_test.go deleted file mode 100644 index df573f97b..000000000 --- a/vendor/github.com/jinzhu/gorm/polymorphic_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package gorm_test - -import ( - "reflect" - "sort" - "testing" -) - -type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` -} - -type Dog struct { - Id int - Name string - Toys []Toy `gorm:"polymorphic:Owner;"` -} - -type Toy struct { - Id int - Name string - OwnerId int - OwnerType string -} - -var compareToys = func(toys []Toy, contents []string) bool { - var toyContents []string - for _, toy := range toys { - toyContents = append(toyContents, toy.Name) - } - sort.Strings(toyContents) - sort.Strings(contents) - return reflect.DeepEqual(toyContents, contents) -} - -func TestPolymorphic(t *testing.T) { - cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} - DB.Save(&cat).Save(&dog) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Dog's toys count should be 2") - } - - // Query - var catToys []Toy - if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(catToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if catToys[0].Name != cat.Toy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - var dogToys []Toy - if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { - t.Errorf("Did not find any polymorphic has many associations") - } else if len(dogToys) != len(dog.Toys) { - t.Errorf("Should have found all polymorphic has many associations") - } - - var catToy Toy - DB.Model(&cat).Association("Toy").Find(&catToy) - if catToy.Name != cat.Toy.Name { - t.Errorf("Should find has one polymorphic association") - } - - var dogToys1 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys1) - if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { - t.Errorf("Should find has many polymorphic association") - } - - // Append - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - var catToy2 Toy - DB.Model(&cat).Association("Toy").Find(&catToy2) - if catToy2.Name != "cat toy 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Should return two polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 3", - }) - - var dogToys2 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys2) - if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { - t.Errorf("Dog's toys should be updated with Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Replace - DB.Model(&cat).Association("Toy").Replace(&Toy{ - Name: "cat toy 3", - }) - - var catToy3 Toy - DB.Model(&cat).Association("Toy").Find(&catToy3) - if catToy3.Name != "cat toy 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Replace(&Toy{ - Name: "dog toy 4", - }, []Toy{ - {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, - }) - - var dogToys3 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys3) - if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { - t.Errorf("Dog's toys should be updated with Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Delete - DB.Model(&cat).Association("Toy").Delete(&catToy2) - - var catToy4 Toy - DB.Model(&cat).Association("Toy").Find(&catToy4) - if catToy4.Name != "cat toy 3" { - t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should be 4") - } - - DB.Model(&cat).Association("Toy").Delete(&catToy3) - - if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { - t.Errorf("Toy should be deleted with Delete") - } - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys count should be 0 after Delete") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete cat's toy") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys2) - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete unrelated toys") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys3) - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys count should be deleted with Delete") - } - - // Clear - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys should be added with Append") - } - - DB.Model(&cat).Association("Toy").Clear() - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys should be cleared with Clear") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 8", - }) - - if DB.Model(&dog).Association("Toys").Count() != 1 { - t.Errorf("Dog's toys should be added with Append") - } - - DB.Model(&dog).Association("Toys").Clear() - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys should be cleared with Clear") - } -} diff --git a/vendor/github.com/jinzhu/gorm/preload_test.go b/vendor/github.com/jinzhu/gorm/preload_test.go deleted file mode 100644 index 5c49ecc21..000000000 --- a/vendor/github.com/jinzhu/gorm/preload_test.go +++ /dev/null @@ -1,1327 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "encoding/json" - "os" - "reflect" - "testing" - - "github.com/jinzhu/gorm" -) - -func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") -} - -func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } - - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } - - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } - - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } - - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } -} - -func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) - - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("user2") - DB.Save(user2) - - user3 := getPreloadUser("user3") - DB.Save(user3) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } - - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } - } -} - -func TestNestedPreload1(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []*Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - { - Level1s: []*Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []*Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - Name string - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload4(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -// Slice: []Level3 -func TestNestedPreload5(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload6(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload7(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload8(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload9(t *testing.T) { - type ( - Level0 struct { - ID uint - Value string - Level1ID uint - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level2_1 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - Level2_1 Level2_1 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value1-1", - Level0s: []Level0{{Value: "Level0-1"}}, - }, - { - Value: "value2-2", - Level0s: []Level0{{Value: "Level0-2"}}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - {Value: "value3-3"}, - {Value: "value4-4"}, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelA1 struct { - ID uint - Value string -} - -type LevelA2 struct { - ID uint - Value string - LevelA3s []*LevelA3 -} - -type LevelA3 struct { - ID uint - Value string - LevelA1ID sql.NullInt64 - LevelA1 *LevelA1 - LevelA2ID sql.NullInt64 - LevelA2 *LevelA2 -} - -func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) - - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { - t.Error(err) - } - - levelA1 := &LevelA1{Value: "foo"} - if err := DB.Save(levelA1).Error; err != nil { - t.Error(err) - } - - want := []*LevelA2{ - { - Value: "bar", - LevelA3s: []*LevelA3{ - { - Value: "qux", - LevelA1: levelA1, - }, - }, - }, - { - Value: "bar 2", - }, - } - for _, levelA2 := range want { - if err := DB.Save(levelA2).Error; err != nil { - t.Error(err) - } - } - - var got []*LevelA2 - if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelB1 struct { - ID uint - Value string - LevelB3s []*LevelB3 -} - -type LevelB2 struct { - ID uint - Value string -} - -type LevelB3 struct { - ID uint - Value string - LevelB1ID sql.NullInt64 - LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` -} - -func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { - t.Error(err) - } - - levelB1 := &LevelB1{Value: "foo"} - if err := DB.Create(levelB1).Error; err != nil { - t.Error(err) - } - - levelB3 := &LevelB3{ - Value: "bar", - LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, - } - if err := DB.Create(levelB3).Error; err != nil { - t.Error(err) - } - levelB1.LevelB3s = []*LevelB3{levelB3} - - want := []*LevelB1{levelB1} - var got []*LevelB1 - if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - return - } - - type ( - Level1 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - } - Level2 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - Level1s []Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ - {Value: "ru", LanguageCode: "ru"}, - {Value: "en", LanguageCode: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ - {Value: "zh", LanguageCode: "zh"}, - {Value: "de", LanguageCode: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []Level1{ruLevel1} - got2.Level1s = []Level1{zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } - - if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForNestedPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Bob", - Level2: &Level2{ - Value: "Foo", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level3{ - Value: "Tom", - Level2: &Level2{ - Value: "Bar", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level3 - if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) - } - - var got4 []Level3 - if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level3 - DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level2.Level1s = []*Level1{&ruLevel1} - got2.Level2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) - } -} - -func TestNestedManyToManyPreload(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2s []Level2 `gorm:"many2many:level2_level3;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2s: []Level2{ - { - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, { - Value: "Tom", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2: &Level2{ - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - level1Zh := &Level1{Value: "zh"} - level1Ru := &Level1{Value: "ru"} - level1En := &Level1{Value: "en"} - - level21 := &Level2{ - Value: "Level2-1", - Level1s: []*Level1{level1Zh, level1Ru}, - } - - level22 := &Level2{ - Value: "Level2-2", - Level1s: []*Level1{level1Zh, level1En}, - } - - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } - - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } - - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } -} - -func TestNestedManyToManyPreload4(t *testing.T) { - type ( - Level4 struct { - ID uint - Value string - Level3ID uint - } - Level3 struct { - ID uint - Value string - Level4s []*Level4 - } - Level2 struct { - ID uint - Value string - Level3s []*Level3 `gorm:"many2many:level2_level3;"` - } - Level1 struct { - ID uint - Value string - Level2s []*Level2 `gorm:"many2many:level1_level2;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - dummy := Level1{ - Value: "Level1", - Level2s: []*Level2{{ - Value: "Level2", - Level3s: []*Level3{{ - Value: "Level3", - Level4s: []*Level4{{ - Value: "Level4", - }}, - }}, - }}, - } - - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - if err := DB.Save(&dummy).Error; err != nil { - t.Error(err) - } - - var level1 Level1 - if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } -} - -func TestNilPointerSlice(t *testing.T) { - type ( - Level3 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level3ID uint - Level3 *Level3 - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level1{Value: "Bob", Level2: &Level2{ - Value: "en", - Level3: &Level3{ - Value: "native", - }, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level1{Value: "Tom", Level2: nil} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got []Level1 - if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { - t.Error(err) - } - - if len(got) != 2 { - t.Errorf("got %v items, expected 2", len(got)) - } - - if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) - } - - if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) - } -} - -func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r -} diff --git a/vendor/github.com/jinzhu/gorm/query_test.go b/vendor/github.com/jinzhu/gorm/query_test.go deleted file mode 100644 index 7dc3d91b9..000000000 --- a/vendor/github.com/jinzhu/gorm/query_test.go +++ /dev/null @@ -1,636 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/now" - - "testing" - "time" -) - -func TestFirstAndLast(t *testing.T) { - DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) - DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) - - var user1, user2, user3, user4 User - DB.First(&user1) - DB.Order("id").Limit(1).Find(&user2) - - DB.Last(&user3) - DB.Order("id desc").Limit(1).Find(&user4) - if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should by order by primary key") - } - - var users []User - DB.First(&users) - if len(users) != 1 { - t.Errorf("Find first record as slice") - } - - var user User - if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { - t.Errorf("Should not raise any error when order with Join table") - } - - if user.Email != "" { - t.Errorf("User's Email should be blank as no one set it") - } -} - -func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { - DB.Save(&Animal{Name: "animal1"}) - DB.Save(&Animal{Name: "animal2"}) - - var animal1, animal2, animal3, animal4 Animal - DB.First(&animal1) - DB.Order("counter").Limit(1).Find(&animal2) - - DB.Last(&animal3) - DB.Order("counter desc").Limit(1).Find(&animal4) - if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should work correctly") - } -} - -func TestUIntPrimaryKey(t *testing.T) { - var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } - - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } -} - -func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) - - var address AddressByZipCode - DB.First(&address, "00501") - if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") - } -} - -func TestFindAsSliceOfPointers(t *testing.T) { - DB.Save(&User{Name: "user"}) - - var users []User - DB.Find(&users) - - var userPointers []*User - DB.Find(&userPointers) - - if len(users) == 0 || len(users) != len(userPointers) { - t.Errorf("Find slice of pointers") - } -} - -func TestSearchWithPlainSQL(t *testing.T) { - user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") - - if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("Search with plain SQL") - } - - if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { - t.Errorf("Search with plan SQL (regexp)") - } - - var users []User - DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) - if len(users) != 2 { - t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) - } - - DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) - } - - scopedb.Where("age <> ?", 20).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users age != 20, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", "2002-10-10").Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) - } - - scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) - } - - DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users, but got %v", len(users)) - } - - DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users, but got %v", len(users)) - } - - DB.Where("id in (?)", user1.Id).Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users, but got %v", len(users)) - } - - if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { - t.Errorf("Should not get RecordNotFound error when looking for none existing records") - } -} - -func TestSearchWithStruct(t *testing.T) { - user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where(user1.Id).First(&User{}).RecordNotFound() { - t.Errorf("Search with primary key") - } - - if DB.First(&User{}, user1.Id).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - var users []User - DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) - } - - var user User - DB.First(&user, &User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline pointer of struct") - } - - DB.First(&user, User{Name: user1.Name}) - if user.Id == 0 || user.Name != user.Name { - t.Errorf("Search first record with inline struct") - } - - DB.Where(&User{Name: user1.Name}).First(&user) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with where struct") - } - - DB.Find(&users, &User{Name: user2.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline struct") - } -} - -func TestSearchWithMap(t *testing.T) { - companyID := 1 - user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) - - var user User - DB.First(&user, map[string]interface{}{"name": user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline map") - } - - user = User{} - DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) - if user.Id == 0 || user.Name != user2.Name { - t.Errorf("Search first record with where map") - } - - var users []User - DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user3.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) - if len(users) != 0 { - t.Errorf("Search all records with inline map containing null value finding 0 records") - } - - DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) - if len(users) != 1 { - t.Errorf("Search all records with inline map containing null value finding 1 record") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) - if len(users) != 1 { - t.Errorf("Search all records with inline multiple value map") - } -} - -func TestSearchWithEmptyChain(t *testing.T) { - user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where("").Where("").First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty strings") - } - - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty struct") - } - - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty map") - } -} - -func TestSelect(t *testing.T) { - user1 := User{Name: "SelectUser1"} - DB.Save(&user1) - - var user User - DB.Where("name = ?", user1.Name).Select("name").Find(&user) - if user.Id != 0 { - t.Errorf("Should not have ID because only selected name, %+v", user.Id) - } - - if user.Name != user1.Name { - t.Errorf("Should have user Name when selected it") - } -} - -func TestOrderAndPluck(t *testing.T) { - user1 := User{Name: "OrderPluckUser1", Age: 1} - user2 := User{Name: "OrderPluckUser2", Age: 10} - user3 := User{Name: "OrderPluckUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") - - var ages []int64 - scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { - t.Errorf("The first age should be 20 when order with age desc") - } - - var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) - if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } - - var ages3, ages4 []int64 - scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) - if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should work") - } - - var names []string - var ages5 []int64 - scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if names != nil && ages5 != nil { - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { - t.Errorf("Order with multiple orders") - } - } else { - t.Errorf("Order with multiple orders") - } - - DB.Model(User{}).Select("name, age").Find(&[]User{}) -} - -func TestLimit(t *testing.T) { - user1 := User{Name: "LimitUser1", Age: 1} - user2 := User{Name: "LimitUser2", Age: 10} - user3 := User{Name: "LimitUser3", Age: 20} - user4 := User{Name: "LimitUser4", Age: 10} - user5 := User{Name: "LimitUser5", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) - - var users1, users2, users3 []User - DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) - - if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") - } -} - -func TestOffset(t *testing.T) { - for i := 0; i < 20; i++ { - DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) - } - var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) - - if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should work") - } -} - -func TestOr(t *testing.T) { - user1 := User{Name: "OrUser1", Age: 1} - user2 := User{Name: "OrUser2", Age: 10} - user3 := User{Name: "OrUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users []User - DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) - if len(users) != 2 { - t.Errorf("Find users with or") - } -} - -func TestCount(t *testing.T) { - user1 := User{Name: "CountUser1", Age: 1} - user2 := User{Name: "CountUser2", Age: 10} - user3 := User{Name: "CountUser3", Age: 20} - - DB.Save(&user1).Save(&user2).Save(&user3) - var count, count1, count2 int64 - var users []User - - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { - t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) - } - - if count != int64(len(users)) { - t.Errorf("Count() method should get correct value") - } - - DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) - if count1 != 1 || count2 != 3 { - t.Errorf("Multiple count in chain") - } -} - -func TestNot(t *testing.T) { - DB.Create(getPreparedUser("user1", "not")) - DB.Create(getPreparedUser("user2", "not")) - DB.Create(getPreparedUser("user3", "not")) - - user4 := getPreparedUser("user4", "not") - user4.Company = Company{} - DB.Create(user4) - - DB := DB.Where("role = ?", "not") - - var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User - if DB.Find(&users1).RowsAffected != 4 { - t.Errorf("should find 4 not users") - } - DB.Not(users1[0].Id).Find(&users2) - - if len(users1)-len(users2) != 1 { - t.Errorf("Should ignore the first users with Not") - } - - DB.Not([]int{}).Find(&users3) - if len(users1)-len(users3) != 0 { - t.Errorf("Should find all users with a blank condition") - } - - var name3Count int64 - DB.Table("users").Where("name = ?", "user3").Count(&name3Count) - DB.Not("name", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name = ?", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name <> ?", "user3").Find(&users4) - if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(User{Name: "user3"}).Find(&users5) - - if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) - if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) - if len(users1)-len(users7) != 2 { // not user3 or user4 - t.Errorf("Should find all user's name not equal to 3 who do not have company id") - } - - DB.Not("name", []string{"user3"}).Find(&users8) - if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - var name2Count int64 - DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users9) - if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") - } -} - -func TestFillSmallerStruct(t *testing.T) { - user1 := User{Name: "SmallerUser", Age: 100} - DB.Save(&user1) - type SimpleUser struct { - Name string - Id int64 - UpdatedAt time.Time - CreatedAt time.Time - } - - var simpleUser SimpleUser - DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) - - if simpleUser.Id == 0 || simpleUser.Name == "" { - t.Errorf("Should fill data correctly into smaller struct") - } -} - -func TestFindOrInitialize(t *testing.T) { - var user1, user2, user3, user4, user5, user6 User - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) - if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) - if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) - if user3.Name != "find or init 2" || user3.Id != 0 { - t.Errorf("user should be initialized with inline search value") - } - - DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and attrs") - } - - DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and assign attrs") - } - - DB.Save(&User{Name: "find or init", Age: 33}) - DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { - t.Errorf("user should be found with FirstOrInit") - } - - DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } -} - -func TestFindOrCreate(t *testing.T) { - var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) - if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) - if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) - if user3.Name != "find or create 2" || user3.Id == 0 { - t.Errorf("user should be created with inline search value") - } - - DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) - if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and attrs") - } - - updatedAt1 := user4.UpdatedAt - DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) - if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateAt should be changed when update values with assign") - } - - DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) - if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) - if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Find(&user7) - if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) - if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { - t.Errorf("embedded struct email should be saved") - } - - if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { - t.Errorf("embedded struct credit card should be saved") - } -} - -func TestSelectWithEscapedFieldName(t *testing.T) { - user1 := User{Name: "EscapedFieldNameUser", Age: 1} - user2 := User{Name: "EscapedFieldNameUser", Age: 10} - user3 := User{Name: "EscapedFieldNameUser", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var names []string - DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) - - if len(names) != 3 { - t.Errorf("Expected 3 name, but got: %d", len(names)) - } -} - -func TestSelectWithVariables(t *testing.T) { - DB.Save(&User{Name: "jinzhu"}) - - rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() - - if !rows.Next() { - t.Errorf("Should have returned at least one row") - } else { - columns, _ := rows.Columns() - if !reflect.DeepEqual(columns, []string{"fake"}) { - t.Errorf("Should only contains one column") - } - } -} - -func TestSelectWithArrayInput(t *testing.T) { - DB.Save(&User{Name: "jinzhu", Age: 42}) - - var user User - DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) - - if user.Name != "jinzhu" || user.Age != 42 { - t.Errorf("Should have selected both age and name") - } -} diff --git a/vendor/github.com/jinzhu/gorm/scaner_test.go b/vendor/github.com/jinzhu/gorm/scaner_test.go deleted file mode 100644 index 214105481..000000000 --- a/vendor/github.com/jinzhu/gorm/scaner_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/json" - "testing" -) - -func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { - t.Errorf("Should create table with slice values correctly: %s", err) - } - - r1 := RecordWithSlice{ - Strings: ExampleStringSlice{"a", "b", "c"}, - Structs: ExampleStructSlice{ - {"name1", "value1"}, - {"name2", "value2"}, - }, - } - - if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") - } - - var r2 RecordWithSlice - - if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") - } - - if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { - t.Errorf("Should have serialised and deserialised a string array") - } - - if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { - t.Errorf("Should have serialised and deserialised a struct array") - } -} - -type RecordWithSlice struct { - ID uint64 - Strings ExampleStringSlice `sql:"type:text"` - Structs ExampleStructSlice `sql:"type:text"` -} - -type ExampleStringSlice []string - -func (l ExampleStringSlice) Value() (driver.Value, error) { - return json.Marshal(l) -} - -func (l *ExampleStringSlice) Scan(input interface{}) error { - return json.Unmarshal(input.([]byte), l) -} - -type ExampleStruct struct { - Name string - Value string -} - -type ExampleStructSlice []ExampleStruct - -func (l ExampleStructSlice) Value() (driver.Value, error) { - return json.Marshal(l) -} - -func (l *ExampleStructSlice) Scan(input interface{}) error { - return json.Unmarshal(input.([]byte), l) -} diff --git a/vendor/github.com/jinzhu/gorm/scope.go b/vendor/github.com/jinzhu/gorm/scope.go deleted file mode 100644 index 844df85c7..000000000 --- a/vendor/github.com/jinzhu/gorm/scope.go +++ /dev/null @@ -1,1246 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - "time" - - "reflect" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() sqlCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - if expr, ok := value.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - selectFields []*Field - values = make([]interface{}, len(columns)) - selectedColumnsMap = map[string]int{} - resetFields = map[*Field]int{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[field] = index - } - - selectedColumnsMap[column] = fieldIndex - break - } - } - } - - scope.Err(rows.Scan(values...)) - - for field, index := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - // if string is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - // is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } - return "" - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { - sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.countingQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - orders = append(orders, scope.quoteIfPossible(order)) - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal { - hasUpdate = true - if err == ErrUnaddressable { - fmt.Println(err) - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - scope.Search.Select("count(*)") - scope.Search.countingQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { - return false - } - return true && !scope.HasError() -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(query.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - scope.Err(query.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - if name == "INDEX" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - if name == "UNIQUE_INDEX" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } - - for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - } - - for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - for _, value := range values { - indirectValue := reflect.ValueOf(value) - for indirectValue.Kind() == reflect.Ptr { - indirectValue = indirectValue.Elem() - } - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - for _, column := range columns { - result = append(result, object.FieldByName(column).Interface()) - } - results = append(results, result) - } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, indirectValue.FieldByName(column).Interface()) - } - results = append(results, result) - } - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() { - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} diff --git a/vendor/github.com/jinzhu/gorm/scope_test.go b/vendor/github.com/jinzhu/gorm/scope_test.go deleted file mode 100644 index 42458995d..000000000 --- a/vendor/github.com/jinzhu/gorm/scope_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package gorm_test - -import ( - "github.com/jinzhu/gorm" - "testing" -) - -func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) -} - -func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) -} - -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { - return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) - } -} - -func TestScopes(t *testing.T) { - user1 := User{Name: "ScopeUser1", Age: 1} - user2 := User{Name: "ScopeUser2", Age: 1} - user3 := User{Name: "ScopeUser3", Age: 2} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users1, users2, users3 []User - DB.Scopes(NameIn1And2).Find(&users1) - if len(users1) != 2 { - t.Errorf("Should found two users's name in 1, 2") - } - - DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) - if len(users2) != 1 { - t.Errorf("Should found one user's name is 2") - } - - DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) - if len(users3) != 2 { - t.Errorf("Should found two users's name in 1, 3") - } -} diff --git a/vendor/github.com/jinzhu/gorm/search.go b/vendor/github.com/jinzhu/gorm/search.go deleted file mode 100644 index 078bd4298..000000000 --- a/vendor/github.com/jinzhu/gorm/search.go +++ /dev/null @@ -1,149 +0,0 @@ -package gorm - -import "fmt" - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []string - preload []searchPreload - offset int - limit int - group string - tableName string - raw bool - Unscoped bool - countingQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value string, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - if value != "" { - s.orders = []string{value} - } else { - s.orders = []string{} - } - } else if value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit int) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset int) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query string, values ...interface{}) *search { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/search_test.go b/vendor/github.com/jinzhu/gorm/search_test.go deleted file mode 100644 index 4db7ab6a5..000000000 --- a/vendor/github.com/jinzhu/gorm/search_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import ( - "reflect" - "testing" -) - -func TestCloneSearch(t *testing.T) { - s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") - - s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") - - if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { - t.Errorf("Where should be copied") - } - - if reflect.DeepEqual(s.orders, s1.orders) { - t.Errorf("Order should be copied") - } - - if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { - t.Errorf("InitAttrs should be copied") - } - - if reflect.DeepEqual(s.Select, s1.Select) { - t.Errorf("selectStr should be copied") - } -} diff --git a/vendor/github.com/jinzhu/gorm/test_all.sh b/vendor/github.com/jinzhu/gorm/test_all.sh deleted file mode 100755 index 6c5593b37..000000000 --- a/vendor/github.com/jinzhu/gorm/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "sqlite") - -for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test -done diff --git a/vendor/github.com/jinzhu/gorm/update_test.go b/vendor/github.com/jinzhu/gorm/update_test.go deleted file mode 100644 index bdf010912..000000000 --- a/vendor/github.com/jinzhu/gorm/update_test.go +++ /dev/null @@ -1,435 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T) { - product1 := Product{Code: "product1code"} - product2 := Product{Code: "product2code"} - - DB.Save(&product1).Save(&product2).Update("code", "product2newcode") - - if product2.Code != "product2newcode" { - t.Errorf("Record should be updated") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt1 := product1.UpdatedAt - - if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { - t.Errorf("Product1 should not be updated") - } - - if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") - - var product4 Product - DB.First(&product4, product1.Id) - if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { - t.Errorf("Product1's code should be updated") - } - - if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { - t.Errorf("Product should not be changed to 789") - } - - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update with CamelCase") - } - - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update_column with CamelCase") - } - - var products []Product - DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { - t.Error("RowsAffected should be correct when do batch update") - } - - DB.First(&product4, product4.Id) - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("Update with expression") - } - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Update with expression should update UpdatedAt") - } -} - -func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - DB.Save(&animal) - updatedAt1 := animal.UpdatedAt - - DB.Save(&animal).Update("name", "Francis") - - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - - var animals []Animal - DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { - t.Error("RowsAffected should be correct when do batch update") - } - - animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) - DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) - if animal.Name != "galeone" { - t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) - } - - // When changing a field with a default value, the change must occur - animal.Name = "amazing horse" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "amazing horse" { - t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) - } - - // When changing a field with a default value with blank value - animal.Name = "" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) - } -} - -func TestUpdates(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 10} - DB.Save(&product1).Save(&product2) - DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) - if product1.Code != "product1newcode" || product1.Price != 100 { - t.Errorf("Record should be updated also with map") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { - t.Errorf("Product2 should not be updated") - } - - if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { - t.Errorf("Product1 should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) - if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("product2's code should be updated") - } - - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100 { - t.Errorf("Updates with expression") - } - // product4's UpdatedAt will be reset when updating - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Updates with expression should update UpdatedAt") - } -} - -func TestUpdateColumn(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 20} - DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) - if product2.Code != "product2newcode" || product2.Price != 100 { - t.Errorf("product 2 should be updated with update column") - } - - var product3 Product - DB.First(&product3, product1.Id) - if product3.Code != "product1code" || product3.Price != 10 { - t.Errorf("product 1 should not be updated") - } - - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - DB.Model(product2).UpdateColumn("code", "update_column_new") - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated with update column") - } - - DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("UpdateColumn with expression") - } - if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateColumn with expression should not update UpdatedAt") - } -} - -func TestSelectWithUpdate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestSelectWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestOmitWithUpdate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships that not omited") - } -} - -func TestOmitWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships not omited") - } -} - -func TestSelectWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } -} - -func TestOmitWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should omit name column when update user") - } -} - -func TestUpdateColumnsSkipsAssociations(t *testing.T) { - user := getPreparedUser("update_columns_user", "special_role") - user.Age = 99 - address1 := "first street" - user.BillingAddress = Address{Address1: address1} - DB.Save(user) - - // Update a single field of the user and verify that the changed address is not stored. - newAge := int64(100) - user.BillingAddress.Address1 = "second street" - db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) - } - - // Verify that Age now=`newAge`. - freshUser := &User{Id: user.Id} - DB.First(freshUser) - if freshUser.Age != newAge { - t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) - } - - // Verify that user's BillingAddress.Address1 is not changed and is still "first street". - DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) - if freshUser.BillingAddress.Address1 != address1 { - t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) - } -} - -func TestUpdatesWithBlankValues(t *testing.T) { - product := Product{Code: "product1", Price: 10} - DB.Save(&product) - - DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) - - var product1 Product - DB.First(&product1, product.Id) - - if product1.Code != "product1" || product1.Price != 100 { - t.Errorf("product's code should not be updated") - } -} - -func TestUpdateDecodeVirtualAttributes(t *testing.T) { - var user = User{ - Name: "jinzhu", - IgnoreMe: 88, - } - - DB.Save(&user) - - DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) - - if user.IgnoreMe != 100 { - t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") - } -} diff --git a/vendor/github.com/jinzhu/gorm/utils.go b/vendor/github.com/jinzhu/gorm/utils.go deleted file mode 100644 index dc69e8046..000000000 --- a/vendor/github.com/jinzhu/gorm/utils.go +++ /dev/null @@ -1,264 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - -// SQL expression -type expr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/vendor/github.com/jinzhu/gorm/utils_test.go b/vendor/github.com/jinzhu/gorm/utils_test.go deleted file mode 100644 index 07f5b17f4..000000000 --- a/vendor/github.com/jinzhu/gorm/utils_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} diff --git a/vendor/github.com/jinzhu/inflection/LICENSE b/vendor/github.com/jinzhu/inflection/LICENSE deleted file mode 100644 index a1ca9a0ff..000000000 --- a/vendor/github.com/jinzhu/inflection/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2015 - Jinzhu - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/jinzhu/inflection/README.md b/vendor/github.com/jinzhu/inflection/README.md deleted file mode 100644 index 4dd0f2d9f..000000000 --- a/vendor/github.com/jinzhu/inflection/README.md +++ /dev/null @@ -1,55 +0,0 @@ -Inflection -========= - -Inflection pluralizes and singularizes English nouns - -## Basic Usage - -```go -inflection.Plural("person") => "people" -inflection.Plural("Person") => "People" -inflection.Plural("PERSON") => "PEOPLE" -inflection.Plural("bus") => "buses" -inflection.Plural("BUS") => "BUSES" -inflection.Plural("Bus") => "Buses" - -inflection.Singular("people") => "person" -inflection.Singular("People") => "Person" -inflection.Singular("PEOPLE") => "PERSON" -inflection.Singular("buses") => "bus" -inflection.Singular("BUSES") => "BUS" -inflection.Singular("Buses") => "Bus" - -inflection.Plural("FancyPerson") => "FancyPeople" -inflection.Singular("FancyPeople") => "FancyPerson" -``` - -## Register Rules - -Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb) - -If you want to register more rules, follow: - -``` -inflection.AddUncountable("fish") -inflection.AddIrregular("person", "people") -inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses" -inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" -``` - -## Supporting the project - -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) - - -## Author - -**jinzhu** - -* -* -* - -## License - -Released under the [MIT License](http://www.opensource.org/licenses/MIT). diff --git a/vendor/github.com/jinzhu/inflection/inflections.go b/vendor/github.com/jinzhu/inflection/inflections.go deleted file mode 100644 index 606263bb7..000000000 --- a/vendor/github.com/jinzhu/inflection/inflections.go +++ /dev/null @@ -1,273 +0,0 @@ -/* -Package inflection pluralizes and singularizes English nouns. - - inflection.Plural("person") => "people" - inflection.Plural("Person") => "People" - inflection.Plural("PERSON") => "PEOPLE" - - inflection.Singular("people") => "person" - inflection.Singular("People") => "Person" - inflection.Singular("PEOPLE") => "PERSON" - - inflection.Plural("FancyPerson") => "FancydPeople" - inflection.Singular("FancyPeople") => "FancydPerson" - -Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb) - -If you want to register more rules, follow: - - inflection.AddUncountable("fish") - inflection.AddIrregular("person", "people") - inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses" - inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" -*/ -package inflection - -import ( - "regexp" - "strings" -) - -type inflection struct { - regexp *regexp.Regexp - replace string -} - -// Regular is a regexp find replace inflection -type Regular struct { - find string - replace string -} - -// Irregular is a hard replace inflection, -// containing both singular and plural forms -type Irregular struct { - singular string - plural string -} - -// RegularSlice is a slice of Regular inflections -type RegularSlice []Regular - -// IrregularSlice is a slice of Irregular inflections -type IrregularSlice []Irregular - -var pluralInflections = RegularSlice{ - {"([a-z])$", "${1}s"}, - {"s$", "s"}, - {"^(ax|test)is$", "${1}es"}, - {"(octop|vir)us$", "${1}i"}, - {"(octop|vir)i$", "${1}i"}, - {"(alias|status)$", "${1}es"}, - {"(bu)s$", "${1}ses"}, - {"(buffal|tomat)o$", "${1}oes"}, - {"([ti])um$", "${1}a"}, - {"([ti])a$", "${1}a"}, - {"sis$", "ses"}, - {"(?:([^f])fe|([lr])f)$", "${1}${2}ves"}, - {"(hive)$", "${1}s"}, - {"([^aeiouy]|qu)y$", "${1}ies"}, - {"(x|ch|ss|sh)$", "${1}es"}, - {"(matr|vert|ind)(?:ix|ex)$", "${1}ices"}, - {"^(m|l)ouse$", "${1}ice"}, - {"^(m|l)ice$", "${1}ice"}, - {"^(ox)$", "${1}en"}, - {"^(oxen)$", "${1}"}, - {"(quiz)$", "${1}zes"}, -} - -var singularInflections = RegularSlice{ - {"s$", ""}, - {"(ss)$", "${1}"}, - {"(n)ews$", "${1}ews"}, - {"([ti])a$", "${1}um"}, - {"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"}, - {"(^analy)(sis|ses)$", "${1}sis"}, - {"([^f])ves$", "${1}fe"}, - {"(hive)s$", "${1}"}, - {"(tive)s$", "${1}"}, - {"([lr])ves$", "${1}f"}, - {"([^aeiouy]|qu)ies$", "${1}y"}, - {"(s)eries$", "${1}eries"}, - {"(m)ovies$", "${1}ovie"}, - {"(c)ookies$", "${1}ookie"}, - {"(x|ch|ss|sh)es$", "${1}"}, - {"^(m|l)ice$", "${1}ouse"}, - {"(bus)(es)?$", "${1}"}, - {"(o)es$", "${1}"}, - {"(shoe)s$", "${1}"}, - {"(cris|test)(is|es)$", "${1}is"}, - {"^(a)x[ie]s$", "${1}xis"}, - {"(octop|vir)(us|i)$", "${1}us"}, - {"(alias|status)(es)?$", "${1}"}, - {"^(ox)en", "${1}"}, - {"(vert|ind)ices$", "${1}ex"}, - {"(matr)ices$", "${1}ix"}, - {"(quiz)zes$", "${1}"}, - {"(database)s$", "${1}"}, -} - -var irregularInflections = IrregularSlice{ - {"person", "people"}, - {"man", "men"}, - {"child", "children"}, - {"sex", "sexes"}, - {"move", "moves"}, - {"mombie", "mombies"}, -} - -var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"} - -var compiledPluralMaps []inflection -var compiledSingularMaps []inflection - -func compile() { - compiledPluralMaps = []inflection{} - compiledSingularMaps = []inflection{} - for _, uncountable := range uncountableInflections { - inf := inflection{ - regexp: regexp.MustCompile("^(?i)(" + uncountable + ")$"), - replace: "${1}", - } - compiledPluralMaps = append(compiledPluralMaps, inf) - compiledSingularMaps = append(compiledSingularMaps, inf) - } - - for _, value := range irregularInflections { - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)}, - inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)}, - inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural}, - } - compiledPluralMaps = append(compiledPluralMaps, infs...) - } - - for _, value := range irregularInflections { - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)}, - inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)}, - inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular}, - } - compiledSingularMaps = append(compiledSingularMaps, infs...) - } - - for i := len(pluralInflections) - 1; i >= 0; i-- { - value := pluralInflections[i] - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, - inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, - inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, - } - compiledPluralMaps = append(compiledPluralMaps, infs...) - } - - for i := len(singularInflections) - 1; i >= 0; i-- { - value := singularInflections[i] - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, - inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, - inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, - } - compiledSingularMaps = append(compiledSingularMaps, infs...) - } -} - -func init() { - compile() -} - -// AddPlural adds a plural inflection -func AddPlural(find, replace string) { - pluralInflections = append(pluralInflections, Regular{find, replace}) - compile() -} - -// AddSingular adds a singular inflection -func AddSingular(find, replace string) { - singularInflections = append(singularInflections, Regular{find, replace}) - compile() -} - -// AddIrregular adds an irregular inflection -func AddIrregular(singular, plural string) { - irregularInflections = append(irregularInflections, Irregular{singular, plural}) - compile() -} - -// AddUncountable adds an uncountable inflection -func AddUncountable(values ...string) { - uncountableInflections = append(uncountableInflections, values...) - compile() -} - -// GetPlural retrieves the plural inflection values -func GetPlural() RegularSlice { - plurals := make(RegularSlice, len(pluralInflections)) - copy(plurals, pluralInflections) - return plurals -} - -// GetSingular retrieves the singular inflection values -func GetSingular() RegularSlice { - singulars := make(RegularSlice, len(singularInflections)) - copy(singulars, singularInflections) - return singulars -} - -// GetIrregular retrieves the irregular inflection values -func GetIrregular() IrregularSlice { - irregular := make(IrregularSlice, len(irregularInflections)) - copy(irregular, irregularInflections) - return irregular -} - -// GetUncountable retrieves the uncountable inflection values -func GetUncountable() []string { - uncountables := make([]string, len(uncountableInflections)) - copy(uncountables, uncountableInflections) - return uncountables -} - -// SetPlural sets the plural inflections slice -func SetPlural(inflections RegularSlice) { - pluralInflections = inflections - compile() -} - -// SetSingular sets the singular inflections slice -func SetSingular(inflections RegularSlice) { - singularInflections = inflections - compile() -} - -// SetIrregular sets the irregular inflections slice -func SetIrregular(inflections IrregularSlice) { - irregularInflections = inflections - compile() -} - -// SetUncountable sets the uncountable inflections slice -func SetUncountable(inflections []string) { - uncountableInflections = inflections - compile() -} - -// Plural converts a word to its plural form -func Plural(str string) string { - for _, inflection := range compiledPluralMaps { - if inflection.regexp.MatchString(str) { - return inflection.regexp.ReplaceAllString(str, inflection.replace) - } - } - return str -} - -// Singular converts a word to its singular form -func Singular(str string) string { - for _, inflection := range compiledSingularMaps { - if inflection.regexp.MatchString(str) { - return inflection.regexp.ReplaceAllString(str, inflection.replace) - } - } - return str -} diff --git a/vendor/github.com/jinzhu/inflection/inflections_test.go b/vendor/github.com/jinzhu/inflection/inflections_test.go deleted file mode 100644 index 689e1dfb1..000000000 --- a/vendor/github.com/jinzhu/inflection/inflections_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package inflection - -import ( - "strings" - "testing" -) - -var inflections = map[string]string{ - "star": "stars", - "STAR": "STARS", - "Star": "Stars", - "bus": "buses", - "fish": "fish", - "mouse": "mice", - "query": "queries", - "ability": "abilities", - "agency": "agencies", - "movie": "movies", - "archive": "archives", - "index": "indices", - "wife": "wives", - "safe": "saves", - "half": "halves", - "move": "moves", - "salesperson": "salespeople", - "person": "people", - "spokesman": "spokesmen", - "man": "men", - "woman": "women", - "basis": "bases", - "diagnosis": "diagnoses", - "diagnosis_a": "diagnosis_as", - "datum": "data", - "medium": "media", - "stadium": "stadia", - "analysis": "analyses", - "node_child": "node_children", - "child": "children", - "experience": "experiences", - "day": "days", - "comment": "comments", - "foobar": "foobars", - "newsletter": "newsletters", - "old_news": "old_news", - "news": "news", - "series": "series", - "species": "species", - "quiz": "quizzes", - "perspective": "perspectives", - "ox": "oxen", - "photo": "photos", - "buffalo": "buffaloes", - "tomato": "tomatoes", - "dwarf": "dwarves", - "elf": "elves", - "information": "information", - "equipment": "equipment", - "criterion": "criteria", -} - -// storage is used to restore the state of the global variables -// on each test execution, to ensure no global state pollution -type storage struct { - singulars RegularSlice - plurals RegularSlice - irregulars IrregularSlice - uncountables []string -} - -var backup = storage{} - -func init() { - AddIrregular("criterion", "criteria") - copy(backup.singulars, singularInflections) - copy(backup.plurals, pluralInflections) - copy(backup.irregulars, irregularInflections) - copy(backup.uncountables, uncountableInflections) -} - -func restore() { - copy(singularInflections, backup.singulars) - copy(pluralInflections, backup.plurals) - copy(irregularInflections, backup.irregulars) - copy(uncountableInflections, backup.uncountables) -} - -func TestPlural(t *testing.T) { - for key, value := range inflections { - if v := Plural(strings.ToUpper(key)); v != strings.ToUpper(value) { - t.Errorf("%v's plural should be %v, but got %v", strings.ToUpper(key), strings.ToUpper(value), v) - } - - if v := Plural(strings.Title(key)); v != strings.Title(value) { - t.Errorf("%v's plural should be %v, but got %v", strings.Title(key), strings.Title(value), v) - } - - if v := Plural(key); v != value { - t.Errorf("%v's plural should be %v, but got %v", key, value, v) - } - } -} - -func TestSingular(t *testing.T) { - for key, value := range inflections { - if v := Singular(strings.ToUpper(value)); v != strings.ToUpper(key) { - t.Errorf("%v's singular should be %v, but got %v", strings.ToUpper(value), strings.ToUpper(key), v) - } - - if v := Singular(strings.Title(value)); v != strings.Title(key) { - t.Errorf("%v's singular should be %v, but got %v", strings.Title(value), strings.Title(key), v) - } - - if v := Singular(value); v != key { - t.Errorf("%v's singular should be %v, but got %v", value, key, v) - } - } -} - -func TestAddPlural(t *testing.T) { - defer restore() - ln := len(pluralInflections) - AddPlural("", "") - if ln+1 != len(pluralInflections) { - t.Errorf("Expected len %d, got %d", ln+1, len(pluralInflections)) - } -} - -func TestAddSingular(t *testing.T) { - defer restore() - ln := len(singularInflections) - AddSingular("", "") - if ln+1 != len(singularInflections) { - t.Errorf("Expected len %d, got %d", ln+1, len(singularInflections)) - } -} - -func TestAddIrregular(t *testing.T) { - defer restore() - ln := len(irregularInflections) - AddIrregular("", "") - if ln+1 != len(irregularInflections) { - t.Errorf("Expected len %d, got %d", ln+1, len(irregularInflections)) - } -} - -func TestAddUncountable(t *testing.T) { - defer restore() - ln := len(uncountableInflections) - AddUncountable("", "") - if ln+2 != len(uncountableInflections) { - t.Errorf("Expected len %d, got %d", ln+2, len(uncountableInflections)) - } -} - -func TestGetPlural(t *testing.T) { - plurals := GetPlural() - if len(plurals) != len(pluralInflections) { - t.Errorf("Expected len %d, got %d", len(plurals), len(pluralInflections)) - } -} - -func TestGetSingular(t *testing.T) { - singular := GetSingular() - if len(singular) != len(singularInflections) { - t.Errorf("Expected len %d, got %d", len(singular), len(singularInflections)) - } -} - -func TestGetIrregular(t *testing.T) { - irregular := GetIrregular() - if len(irregular) != len(irregularInflections) { - t.Errorf("Expected len %d, got %d", len(irregular), len(irregularInflections)) - } -} - -func TestGetUncountable(t *testing.T) { - uncountables := GetUncountable() - if len(uncountables) != len(uncountableInflections) { - t.Errorf("Expected len %d, got %d", len(uncountables), len(uncountableInflections)) - } -} - -func TestSetPlural(t *testing.T) { - defer restore() - SetPlural(RegularSlice{{}, {}}) - if len(pluralInflections) != 2 { - t.Errorf("Expected len 2, got %d", len(pluralInflections)) - } -} - -func TestSetSingular(t *testing.T) { - defer restore() - SetSingular(RegularSlice{{}, {}}) - if len(singularInflections) != 2 { - t.Errorf("Expected len 2, got %d", len(singularInflections)) - } -} - -func TestSetIrregular(t *testing.T) { - defer restore() - SetIrregular(IrregularSlice{{}, {}}) - if len(irregularInflections) != 2 { - t.Errorf("Expected len 2, got %d", len(irregularInflections)) - } -} - -func TestSetUncountable(t *testing.T) { - defer restore() - SetUncountable([]string{"", ""}) - if len(uncountableInflections) != 2 { - t.Errorf("Expected len 2, got %d", len(uncountableInflections)) - } -} diff --git a/vendor/github.com/jmoiron/sqlx/.gitignore b/vendor/github.com/jmoiron/sqlx/.gitignore new file mode 100644 index 000000000..529841cf1 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +tags +environ diff --git a/vendor/github.com/jmoiron/sqlx/LICENSE b/vendor/github.com/jmoiron/sqlx/LICENSE new file mode 100644 index 000000000..0d31edfa7 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/LICENSE @@ -0,0 +1,23 @@ + Copyright (c) 2013, Jason Moiron + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + diff --git a/vendor/github.com/jmoiron/sqlx/README.md b/vendor/github.com/jmoiron/sqlx/README.md new file mode 100644 index 000000000..5c1bb3cb9 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/README.md @@ -0,0 +1,183 @@ +# sqlx + +[![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) + +sqlx is a library which provides a set of extensions on go's standard +`database/sql` library. The sqlx versions of `sql.DB`, `sql.TX`, `sql.Stmt`, +et al. all leave the underlying interfaces untouched, so that their interfaces +are a superset on the standard ones. This makes it relatively painless to +integrate existing codebases using database/sql with sqlx. + +Major additional concepts are: + +* Marshal rows into structs (with embedded struct support), maps, and slices +* Named parameter support including prepared statements +* `Get` and `Select` to go quickly from query to struct/slice + +In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx), +there is also some [standard documentation](http://jmoiron.github.io/sqlx/) that +explains how to use `database/sql` along with sqlx. + +## Recent Changes + +* sqlx/types.JsonText has been renamed to JSONText to follow Go naming conventions. + +This breaks backwards compatibility, but it's in a way that is trivially fixable +(`s/JsonText/JSONText/g`). The `types` package is both experimental and not in +active development currently. + +* Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905). + +### Backwards Compatibility + +There is no Go1-like promise of absolute stability, but I take the issue seriously +and will maintain the library in a compatible state unless vital bugs prevent me +from doing so. Since [#59](https://github.com/jmoiron/sqlx/issues/59) and +[#60](https://github.com/jmoiron/sqlx/issues/60) necessitated breaking behavior, +a wider API cleanup was done at the time of fixing. It's possible this will happen +in future; if it does, a git tag will be provided for users requiring the old +behavior to continue to use it until such a time as they can migrate. + +## install + + go get github.com/jmoiron/sqlx + +## issues + +Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of +`Columns()` does not fully qualify column names in queries like: + +```sql +SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id; +``` + +making a struct or map destination ambiguous. Use `AS` in your queries +to give columns distinct names, `rows.Scan` to scan them manually, or +`SliceScan` to get a slice of results. + +## usage + +Below is an example which shows some common use cases for sqlx. Check +[sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more +usage. + + +```go +package main + +import ( + _ "github.com/lib/pq" + "database/sql" + "github.com/jmoiron/sqlx" + "log" +) + +var schema = ` +CREATE TABLE person ( + first_name text, + last_name text, + email text +); + +CREATE TABLE place ( + country text, + city text NULL, + telcode integer +)` + +type Person struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string +} + +type Place struct { + Country string + City sql.NullString + TelCode int +} + +func main() { + // this Pings the database trying to connect, panics on error + // use sqlx.Open() for sql.Open() semantics + db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable") + if err != nil { + log.Fatalln(err) + } + + // exec the schema or fail; multi-statement Exec behavior varies between + // database drivers; pq will exec them all, sqlite3 won't, ymmv + db.MustExec(schema) + + tx := db.MustBegin() + tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1") + tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852") + tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65") + // Named queries can use structs, so if you have an existing struct (i.e. person := &Person{}) that you have populated, you can pass it in as &person + tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"}) + tx.Commit() + + // Query the database, storing results in a []Person (wrapped in []interface{}) + people := []Person{} + db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") + jason, john := people[0], people[1] + + fmt.Printf("%#v\n%#v", jason, john) + // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} + // Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"} + + // You can also get a single result, a la QueryRow + jason = Person{} + err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason") + fmt.Printf("%#v\n", jason) + // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} + + // if you have null fields and use SELECT *, you must use sql.Null* in your struct + places := []Place{} + err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + fmt.Println(err) + return + } + usa, singsing, honkers := places[0], places[1], places[2] + + fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers) + // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} + // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} + // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} + + // Loop through rows using only one struct + place := Place{} + rows, err := db.Queryx("SELECT * FROM place") + for rows.Next() { + err := rows.StructScan(&place) + if err != nil { + log.Fatalln(err) + } + fmt.Printf("%#v\n", place) + } + // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} + // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} + // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} + + // Named queries, using `:name` as the bindvar. Automatic bindvar support + // which takes into account the dbtype based on the driverName on sqlx.Open/Connect + _, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`, + map[string]interface{}{ + "first": "Bin", + "last": "Smuth", + "email": "bensmith@allblacks.nz", + }) + + // Selects Mr. Smith from the database + rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"}) + + // Named queries can also use structs. Their bind names follow the same rules + // as the name -> db mapping, so struct fields are lowercased and the `db` tag + // is taken into consideration. + rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason) +} +``` + diff --git a/vendor/github.com/jmoiron/sqlx/bind.go b/vendor/github.com/jmoiron/sqlx/bind.go new file mode 100644 index 000000000..10f7bdf84 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/bind.go @@ -0,0 +1,207 @@ +package sqlx + +import ( + "bytes" + "errors" + "reflect" + "strconv" + "strings" + + "github.com/jmoiron/sqlx/reflectx" +) + +// Bindvar types supported by Rebind, BindMap and BindStruct. +const ( + UNKNOWN = iota + QUESTION + DOLLAR + NAMED +) + +// BindType returns the bindtype for a given database given a drivername. +func BindType(driverName string) int { + switch driverName { + case "postgres", "pgx": + return DOLLAR + case "mysql": + return QUESTION + case "sqlite3": + return QUESTION + case "oci8", "ora", "goracle": + return NAMED + } + return UNKNOWN +} + +// FIXME: this should be able to be tolerant of escaped ?'s in queries without +// losing much speed, and should be to avoid confusion. + +// Rebind a query from the default bindtype (QUESTION) to the target bindtype. +func Rebind(bindType int, query string) string { + switch bindType { + case QUESTION, UNKNOWN: + return query + } + + // Add space enough for 10 params before we have to allocate + rqb := make([]byte, 0, len(query)+10) + + var i, j int + + for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { + rqb = append(rqb, query[:i]...) + + switch bindType { + case DOLLAR: + rqb = append(rqb, '$') + case NAMED: + rqb = append(rqb, ':', 'a', 'r', 'g') + } + + j++ + rqb = strconv.AppendInt(rqb, int64(j), 10) + + query = query[i+1:] + } + + return string(append(rqb, query...)) +} + +// Experimental implementation of Rebind which uses a bytes.Buffer. The code is +// much simpler and should be more resistant to odd unicode, but it is twice as +// slow. Kept here for benchmarking purposes and to possibly replace Rebind if +// problems arise with its somewhat naive handling of unicode. +func rebindBuff(bindType int, query string) string { + if bindType != DOLLAR { + return query + } + + b := make([]byte, 0, len(query)) + rqb := bytes.NewBuffer(b) + j := 1 + for _, r := range query { + if r == '?' { + rqb.WriteRune('$') + rqb.WriteString(strconv.Itoa(j)) + j++ + } else { + rqb.WriteRune(r) + } + } + + return rqb.String() +} + +// In expands slice values in args, returning the modified query string +// and a new arg list that can be executed by a database. The `query` should +// use the `?` bindVar. The return value uses the `?` bindVar. +func In(query string, args ...interface{}) (string, []interface{}, error) { + // argMeta stores reflect.Value and length for slices and + // the value itself for non-slice arguments + type argMeta struct { + v reflect.Value + i interface{} + length int + } + + var flatArgsCount int + var anySlices bool + + meta := make([]argMeta, len(args)) + + for i, arg := range args { + v := reflect.ValueOf(arg) + t := reflectx.Deref(v.Type()) + + if t.Kind() == reflect.Slice { + meta[i].length = v.Len() + meta[i].v = v + + anySlices = true + flatArgsCount += meta[i].length + + if meta[i].length == 0 { + return "", nil, errors.New("empty slice passed to 'in' query") + } + } else { + meta[i].i = arg + flatArgsCount++ + } + } + + // don't do any parsing if there aren't any slices; note that this means + // some errors that we might have caught below will not be returned. + if !anySlices { + return query, args, nil + } + + newArgs := make([]interface{}, 0, flatArgsCount) + buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount)) + + var arg, offset int + + for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { + if arg >= len(meta) { + // if an argument wasn't passed, lets return an error; this is + // not actually how database/sql Exec/Query works, but since we are + // creating an argument list programmatically, we want to be able + // to catch these programmer errors earlier. + return "", nil, errors.New("number of bindVars exceeds arguments") + } + + argMeta := meta[arg] + arg++ + + // not a slice, continue. + // our questionmark will either be written before the next expansion + // of a slice or after the loop when writing the rest of the query + if argMeta.length == 0 { + offset = offset + i + 1 + newArgs = append(newArgs, argMeta.i) + continue + } + + // write everything up to and including our ? character + buf.WriteString(query[:offset+i+1]) + + for si := 1; si < argMeta.length; si++ { + buf.WriteString(", ?") + } + + newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) + + // slice the query and reset the offset. this avoids some bookkeeping for + // the write after the loop + query = query[offset+i+1:] + offset = 0 + } + + buf.WriteString(query) + + if arg < len(meta) { + return "", nil, errors.New("number of bindVars less than number arguments") + } + + return buf.String(), newArgs, nil +} + +func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { + switch val := v.Interface().(type) { + case []interface{}: + args = append(args, val...) + case []int: + for i := range val { + args = append(args, val[i]) + } + case []string: + for i := range val { + args = append(args, val[i]) + } + default: + for si := 0; si < vlen; si++ { + args = append(args, v.Index(si).Interface()) + } + } + + return args +} diff --git a/vendor/github.com/jmoiron/sqlx/doc.go b/vendor/github.com/jmoiron/sqlx/doc.go new file mode 100644 index 000000000..e2b4e60b2 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/doc.go @@ -0,0 +1,12 @@ +// Package sqlx provides general purpose extensions to database/sql. +// +// It is intended to seamlessly wrap database/sql and provide convenience +// methods which are useful in the development of database driven applications. +// None of the underlying database/sql methods are changed. Instead all extended +// behavior is implemented through new methods defined on wrapper types. +// +// Additions include scanning into structs, named query support, rebinding +// queries for different drivers, convenient shorthands for common error handling +// and more. +// +package sqlx diff --git a/vendor/github.com/jmoiron/sqlx/named.go b/vendor/github.com/jmoiron/sqlx/named.go new file mode 100644 index 000000000..dd899d351 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named.go @@ -0,0 +1,344 @@ +package sqlx + +// Named Query Support +// +// * BindMap - bind query bindvars to map/struct args +// * NamedExec, NamedQuery - named query w/ struct or map +// * NamedStmt - a pre-compiled named query which is a prepared statement +// +// Internal Interfaces: +// +// * compileNamedQuery - rebind a named query, returning a query and list of names +// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist +// +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strconv" + "unicode" + + "github.com/jmoiron/sqlx/reflectx" +) + +// NamedStmt is a prepared statement that executes named queries. Prepare it +// how you would execute a NamedQuery, but pass in a struct or map when executing. +type NamedStmt struct { + Params []string + QueryString string + Stmt *Stmt +} + +// Close closes the named statement. +func (n *NamedStmt) Close() error { + return n.Stmt.Close() +} + +// Exec executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.Exec(args...) +} + +// Query executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.Query(args...) +} + +// QueryRow executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRow(arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowx(args...) +} + +// MustExec execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) MustExec(arg interface{}) sql.Result { + res, err := n.Exec(arg) + if err != nil { + panic(err) + } + return res +} + +// Queryx using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { + r, err := n.Query(arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowx(arg interface{}) *Row { + return n.QueryRow(arg) +} + +// Select using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { + rows, err := n.Queryx(arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// Get using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { + r := n.QueryRowx(arg) + return r.scanAny(dest, false) +} + +// Unsafe creates an unsafe version of the NamedStmt +func (n *NamedStmt) Unsafe() *NamedStmt { + r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} + r.Stmt.unsafe = true + return r +} + +// A union interface of preparer and binder, required to be able to prepare +// named statements (as the bindtype must be determined). +type namedPreparer interface { + Preparer + binder +} + +func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := Preparex(p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { + if maparg, ok := arg.(map[string]interface{}); ok { + return bindMapArgs(names, maparg) + } + return bindArgs(names, arg, m) +} + +// private interface to generate a list of interfaces from a given struct +// type, given a list of names to pull out of the struct. Used by public +// BindStruct interface. +func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { + arglist := make([]interface{}, 0, len(names)) + + // grab the indirected value of arg + v := reflect.ValueOf(arg) + for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { + v = v.Elem() + } + + fields := m.TraversalsByName(v.Type(), names) + for i, t := range fields { + if len(t) == 0 { + return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) + } + val := reflectx.FieldByIndexesReadOnly(v, t) + arglist = append(arglist, val.Interface()) + } + + return arglist, nil +} + +// like bindArgs, but for maps. +func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { + arglist := make([]interface{}, 0, len(names)) + + for _, name := range names { + val, ok := arg[name] + if !ok { + return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) + } + arglist = append(arglist, val) + } + return arglist, nil +} + +// bindStruct binds a named parameter query with fields from a struct argument. +// The rules for binding field names to parameter names follow the same +// conventions as for StructScan, including obeying the `db` struct tags. +func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { + bound, names, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return "", []interface{}{}, err + } + + arglist, err := bindArgs(names, arg, m) + if err != nil { + return "", []interface{}{}, err + } + + return bound, arglist, nil +} + +// bindMap binds a named parameter query with a map of arguments. +func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { + bound, names, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return "", []interface{}{}, err + } + + arglist, err := bindMapArgs(names, args) + return bound, arglist, err +} + +// -- Compilation of Named Queries + +// Allow digits and letters in bind params; additionally runes are +// checked against underscores, meaning that bind params can have be +// alphanumeric with underscores. Mind the difference between unicode +// digits and numbers, where '5' is a digit but '五' is not. +var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} + +// FIXME: this function isn't safe for unicode named params, as a failing test +// can testify. This is not a regression but a failure of the original code +// as well. It should be modified to range over runes in a string rather than +// bytes, even though this is less convenient and slower. Hopefully the +// addition of the prepared NamedStmt (which will only do this once) will make +// up for the slightly slower ad-hoc NamedExec/NamedQuery. + +// compile a NamedQuery into an unbound query (using the '?' bindvar) and +// a list of names. +func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { + names = make([]string, 0, 10) + rebound := make([]byte, 0, len(qs)) + + inName := false + last := len(qs) - 1 + currentVar := 1 + name := make([]byte, 0, 10) + + for i, b := range qs { + // a ':' while we're in a name is an error + if b == ':' { + // if this is the second ':' in a '::' escape sequence, append a ':' + if inName && i > 0 && qs[i-1] == ':' { + rebound = append(rebound, ':') + inName = false + continue + } else if inName { + err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) + return query, names, err + } + inName = true + name = []byte{} + // if we're in a name, and this is an allowed character, continue + } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { + // append the byte to the name if we are in a name and not on the last byte + name = append(name, b) + // if we're in a name and it's not an allowed character, the name is done + } else if inName { + inName = false + // if this is the final byte of the string and it is part of the name, then + // make sure to add it to the name + if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { + name = append(name, b) + } + // add the string representation to the names list + names = append(names, string(name)) + // add a proper bindvar for the bindType + switch bindType { + // oracle only supports named type bind vars even for positional + case NAMED: + rebound = append(rebound, ':') + rebound = append(rebound, name...) + case QUESTION, UNKNOWN: + rebound = append(rebound, '?') + case DOLLAR: + rebound = append(rebound, '$') + for _, b := range strconv.Itoa(currentVar) { + rebound = append(rebound, byte(b)) + } + currentVar++ + } + // add this byte to string unless it was not part of the name + if i != last { + rebound = append(rebound, b) + } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { + rebound = append(rebound, b) + } + } else { + // this is a normal byte and should just go onto the rebound query + rebound = append(rebound, b) + } + } + + return string(rebound), names, err +} + +// BindNamed binds a struct or a map to a query with named parameters. +// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. +func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(bindType, query, arg, mapper()) +} + +// Named takes a query using named parameters and an argument and +// returns a new query with a list of args that can be executed by +// a database. The return value uses the `?` bindvar. +func Named(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(QUESTION, query, arg, mapper()) +} + +func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { + if maparg, ok := arg.(map[string]interface{}); ok { + return bindMap(bindType, query, maparg) + } + return bindStruct(bindType, query, arg, m) +} + +// NamedQuery binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.Queryx(q, args...) +} + +// NamedExec uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.Exec(q, args...) +} diff --git a/vendor/github.com/jmoiron/sqlx/named_context.go b/vendor/github.com/jmoiron/sqlx/named_context.go new file mode 100644 index 000000000..9405007e2 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named_context.go @@ -0,0 +1,132 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" +) + +// A union interface of contextPreparer and binder, required to be able to +// prepare named statements with context (as the bindtype must be determined). +type namedPreparerContext interface { + PreparerContext + binder +} + +func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := PreparexContext(ctx, p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +// ExecContext executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.ExecContext(ctx, args...) +} + +// QueryContext executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.QueryContext(ctx, args...) +} + +// QueryRowContext executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowxContext(ctx, args...) +} + +// MustExecContext execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { + res, err := n.ExecContext(ctx, arg) + if err != nil { + panic(err) + } + return res +} + +// QueryxContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { + r, err := n.QueryContext(ctx, arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { + return n.QueryRowContext(ctx, arg) +} + +// SelectContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { + rows, err := n.QueryxContext(ctx, arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// GetContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { + r := n.QueryRowxContext(ctx, arg) + return r.scanAny(dest, false) +} + +// NamedQueryContext binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.QueryxContext(ctx, q, args...) +} + +// NamedExecContext uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.ExecContext(ctx, q, args...) +} diff --git a/vendor/github.com/jmoiron/sqlx/named_context_test.go b/vendor/github.com/jmoiron/sqlx/named_context_test.go new file mode 100644 index 000000000..87e94ac22 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named_context_test.go @@ -0,0 +1,136 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "testing" +) + +func TestNamedContextQueries(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + test := Test{t} + var ns *NamedStmt + var err error + + ctx := context.Background() + + // Check that invalid preparations fail + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + ns, err = db.PrepareNamedContext(ctx, "invalid sql") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + // Check closing works as anticipated + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") + test.Error(err) + err = ns.Close() + test.Error(err) + + ns, err = db.PrepareNamedContext(ctx, ` + SELECT first_name, last_name, email + FROM person WHERE first_name=:first_name AND email=:email`) + test.Error(err) + + // test Queryx w/ uses Query + p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} + + rows, err := ns.QueryxContext(ctx, p) + test.Error(err) + for rows.Next() { + var p2 Person + rows.StructScan(&p2) + if p.FirstName != p2.FirstName { + t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) + } + if p.LastName != p2.LastName { + t.Errorf("got %s, expected %s", p.LastName, p2.LastName) + } + if p.Email != p2.Email { + t.Errorf("got %s, expected %s", p.Email, p2.Email) + } + } + + // test Select + people := make([]Person, 0, 5) + err = ns.SelectContext(ctx, &people, p) + test.Error(err) + + if len(people) != 1 { + t.Errorf("got %d results, expected %d", len(people), 1) + } + if p.FirstName != people[0].FirstName { + t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) + } + if p.LastName != people[0].LastName { + t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) + } + if p.Email != people[0].Email { + t.Errorf("got %s, expected %s", p.Email, people[0].Email) + } + + // test Exec + ns, err = db.PrepareNamedContext(ctx, ` + INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email)`) + test.Error(err) + + js := Person{ + FirstName: "Julien", + LastName: "Savea", + Email: "jsavea@ab.co.nz", + } + _, err = ns.ExecContext(ctx, js) + test.Error(err) + + // Make sure we can pull him out again + p2 := Person{} + db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + if p2.Email != js.Email { + t.Errorf("expected %s, got %s", js.Email, p2.Email) + } + + // test Txn NamedStmts + tx := db.MustBeginTx(ctx, nil) + txns := tx.NamedStmtContext(ctx, ns) + + // We're going to add Steven in this txn + sl := Person{ + FirstName: "Steven", + LastName: "Luatua", + Email: "sluatua@ab.co.nz", + } + + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + // then rollback... + tx.Rollback() + // looking for Steven after a rollback should fail + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + if err != sql.ErrNoRows { + t.Errorf("expected no rows error, got %v", err) + } + + // now do the same, but commit + tx = db.MustBeginTx(ctx, nil) + txns = tx.NamedStmtContext(ctx, ns) + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + tx.Commit() + + // looking for Steven after a Commit should succeed + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + test.Error(err) + if p2.Email != sl.Email { + t.Errorf("expected %s, got %s", sl.Email, p2.Email) + } + + }) +} diff --git a/vendor/github.com/jmoiron/sqlx/named_test.go b/vendor/github.com/jmoiron/sqlx/named_test.go new file mode 100644 index 000000000..d3459a86f --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named_test.go @@ -0,0 +1,227 @@ +package sqlx + +import ( + "database/sql" + "testing" +) + +func TestCompileQuery(t *testing.T) { + table := []struct { + Q, R, D, N string + V []string + }{ + // basic test for named parameters, invalid char ',' terminating + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + V: []string{"name", "age", "first", "last"}, + }, + // This query tests a named parameter ending the string as well as numbers + { + Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, + N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, + N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, + D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, + N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + V: []string{"first_name", "last_name"}, + }, + /* This unicode awareness test sadly fails, because of our byte-wise worldview. + * We could certainly iterate by Rune instead, though it's a great deal slower, + * it's probably the RightWay(tm) + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + N: []string{"name", "age", "first", "last"}, + }, + */ + } + + for _, test := range table { + qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) + if err != nil { + t.Error(err) + } + if qr != test.R { + t.Errorf("expected %s, got %s", test.R, qr) + } + if len(names) != len(test.V) { + t.Errorf("expected %#v, got %#v", test.V, names) + } else { + for i, name := range names { + if name != test.V[i] { + t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) + } + } + } + qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) + if qd != test.D { + t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) + } + + qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) + if qq != test.N { + t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) + } + } +} + +type Test struct { + t *testing.T +} + +func (t Test) Error(err error, msg ...interface{}) { + if err != nil { + if len(msg) == 0 { + t.t.Error(err) + } else { + t.t.Error(msg...) + } + } +} + +func (t Test) Errorf(err error, format string, args ...interface{}) { + if err != nil { + t.t.Errorf(format, args...) + } +} + +func TestNamedQueries(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + test := Test{t} + var ns *NamedStmt + var err error + + // Check that invalid preparations fail + ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + ns, err = db.PrepareNamed("invalid sql") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + // Check closing works as anticipated + ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name") + test.Error(err) + err = ns.Close() + test.Error(err) + + ns, err = db.PrepareNamed(` + SELECT first_name, last_name, email + FROM person WHERE first_name=:first_name AND email=:email`) + test.Error(err) + + // test Queryx w/ uses Query + p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} + + rows, err := ns.Queryx(p) + test.Error(err) + for rows.Next() { + var p2 Person + rows.StructScan(&p2) + if p.FirstName != p2.FirstName { + t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) + } + if p.LastName != p2.LastName { + t.Errorf("got %s, expected %s", p.LastName, p2.LastName) + } + if p.Email != p2.Email { + t.Errorf("got %s, expected %s", p.Email, p2.Email) + } + } + + // test Select + people := make([]Person, 0, 5) + err = ns.Select(&people, p) + test.Error(err) + + if len(people) != 1 { + t.Errorf("got %d results, expected %d", len(people), 1) + } + if p.FirstName != people[0].FirstName { + t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) + } + if p.LastName != people[0].LastName { + t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) + } + if p.Email != people[0].Email { + t.Errorf("got %s, expected %s", p.Email, people[0].Email) + } + + // test Exec + ns, err = db.PrepareNamed(` + INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email)`) + test.Error(err) + + js := Person{ + FirstName: "Julien", + LastName: "Savea", + Email: "jsavea@ab.co.nz", + } + _, err = ns.Exec(js) + test.Error(err) + + // Make sure we can pull him out again + p2 := Person{} + db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + if p2.Email != js.Email { + t.Errorf("expected %s, got %s", js.Email, p2.Email) + } + + // test Txn NamedStmts + tx := db.MustBegin() + txns := tx.NamedStmt(ns) + + // We're going to add Steven in this txn + sl := Person{ + FirstName: "Steven", + LastName: "Luatua", + Email: "sluatua@ab.co.nz", + } + + _, err = txns.Exec(sl) + test.Error(err) + // then rollback... + tx.Rollback() + // looking for Steven after a rollback should fail + err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + if err != sql.ErrNoRows { + t.Errorf("expected no rows error, got %v", err) + } + + // now do the same, but commit + tx = db.MustBegin() + txns = tx.NamedStmt(ns) + _, err = txns.Exec(sl) + test.Error(err) + tx.Commit() + + // looking for Steven after a Commit should succeed + err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + test.Error(err) + if p2.Email != sl.Email { + t.Errorf("expected %s, got %s", sl.Email, p2.Email) + } + + }) +} diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/README.md b/vendor/github.com/jmoiron/sqlx/reflectx/README.md new file mode 100644 index 000000000..f01d3d1f0 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/reflectx/README.md @@ -0,0 +1,17 @@ +# reflectx + +The sqlx package has special reflect needs. In particular, it needs to: + +* be able to map a name to a field +* understand embedded structs +* understand mapping names to fields by a particular tag +* user specified name -> field mapping functions + +These behaviors mimic the behaviors by the standard library marshallers and also the +behavior of standard Go accessors. + +The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is +addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct +tags in the ways that are vital to most marshallers, and they are slow. + +This reflectx package extends reflect to achieve these goals. diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go new file mode 100644 index 000000000..f2802b80b --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go @@ -0,0 +1,422 @@ +// Package reflectx implements extensions to the standard reflect lib suitable +// for implementing marshalling and unmarshalling packages. The main Mapper type +// allows for Go-compatible named attribute access, including accessing embedded +// struct attributes and the ability to use functions and struct tags to +// customize field names. +// +package reflectx + +import ( + "reflect" + "runtime" + "strings" + "sync" +) + +// A FieldInfo is metadata for a struct field. +type FieldInfo struct { + Index []int + Path string + Field reflect.StructField + Zero reflect.Value + Name string + Options map[string]string + Embedded bool + Children []*FieldInfo + Parent *FieldInfo +} + +// A StructMap is an index of field metadata for a struct. +type StructMap struct { + Tree *FieldInfo + Index []*FieldInfo + Paths map[string]*FieldInfo + Names map[string]*FieldInfo +} + +// GetByPath returns a *FieldInfo for a given string path. +func (f StructMap) GetByPath(path string) *FieldInfo { + return f.Paths[path] +} + +// GetByTraversal returns a *FieldInfo for a given integer path. It is +// analogous to reflect.FieldByIndex, but using the cached traversal +// rather than re-executing the reflect machinery each time. +func (f StructMap) GetByTraversal(index []int) *FieldInfo { + if len(index) == 0 { + return nil + } + + tree := f.Tree + for _, i := range index { + if i >= len(tree.Children) || tree.Children[i] == nil { + return nil + } + tree = tree.Children[i] + } + return tree +} + +// Mapper is a general purpose mapper of names to struct fields. A Mapper +// behaves like most marshallers in the standard library, obeying a field tag +// for name mapping but also providing a basic transform function. +type Mapper struct { + cache map[reflect.Type]*StructMap + tagName string + tagMapFunc func(string) string + mapFunc func(string) string + mutex sync.Mutex +} + +// NewMapper returns a new mapper using the tagName as its struct field tag. +// If tagName is the empty string, it is ignored. +func NewMapper(tagName string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + } +} + +// NewMapperTagFunc returns a new mapper which contains a mapper for field names +// AND a mapper for tag values. This is useful for tags like json which can +// have values like "name,omitempty". +func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: mapFunc, + tagMapFunc: tagMapFunc, + } +} + +// NewMapperFunc returns a new mapper which optionally obeys a field tag and +// a struct field name mapper func given by f. Tags will take precedence, but +// for any other field, the mapped name will be f(field.Name) +func NewMapperFunc(tagName string, f func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: f, + } +} + +// TypeMap returns a mapping of field strings to int slices representing +// the traversal down the struct to reach the field. +func (m *Mapper) TypeMap(t reflect.Type) *StructMap { + m.mutex.Lock() + mapping, ok := m.cache[t] + if !ok { + mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) + m.cache[t] = mapping + } + m.mutex.Unlock() + return mapping +} + +// FieldMap returns the mapper's mapping of field names to reflect values. Panics +// if v's Kind is not Struct, or v is not Indirectable to a struct kind. +func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + r[tagName] = FieldByIndexes(v, fi.Index) + } + return r +} + +// FieldByName returns a field by its mapped name as a reflect.Value. +// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. +// Returns zero Value if the name is not found. +func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + fi, ok := tm.Names[name] + if !ok { + return v + } + return FieldByIndexes(v, fi.Index) +} + +// FieldsByName returns a slice of values corresponding to the slice of names +// for the value. Panics if v's Kind is not Struct or v is not Indirectable +// to a struct Kind. Returns zero Value for each name not found. +func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + vals := make([]reflect.Value, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + vals = append(vals, *new(reflect.Value)) + } else { + vals = append(vals, FieldByIndexes(v, fi.Index)) + } + } + return vals +} + +// TraversalsByName returns a slice of int slices which represent the struct +// traversals for each mapped name. Panics if t is not a struct or Indirectable +// to a struct. Returns empty int slice for each name not found. +func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + + r := make([][]int, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + r = append(r, []int{}) + } else { + r = append(r, fi.Index) + } + } + return r +} + +// FieldByIndexes returns a value for the field given by the struct traversal +// for the given value. +func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + // if this is a pointer and it's nil, allocate a new value and set it + if v.Kind() == reflect.Ptr && v.IsNil() { + alloc := reflect.New(Deref(v.Type())) + v.Set(alloc) + } + if v.Kind() == reflect.Map && v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + } + return v +} + +// FieldByIndexesReadOnly returns a value for a particular struct traversal, +// but is not concerned with allocating nil pointers because the value is +// going to be used for reading and not setting. +func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + } + return v +} + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// -- helpers & utilities -- + +type kinder interface { + Kind() reflect.Kind +} + +// mustBe checks a value against a kind, panicing with a reflect.ValueError +// if the kind isn't that which is required. +func mustBe(v kinder, expected reflect.Kind) { + if k := v.Kind(); k != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: k}) + } +} + +// methodName returns the caller of the function calling methodName +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +type typeQueue struct { + t reflect.Type + fi *FieldInfo + pp string // Parent path +} + +// A copying append that creates a new slice each time. +func apnd(is []int, i int) []int { + x := make([]int, len(is)+1) + for p, n := range is { + x[p] = n + } + x[len(x)-1] = i + return x +} + +type mapf func(string) string + +// parseName parses the tag and the target name for the given field using +// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the +// field's name to a target name, and tagMapFunc for mapping the tag to +// a target name. +func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { + // first, set the fieldName to the field's name + fieldName = field.Name + // if a mapFunc is set, use that to override the fieldName + if mapFunc != nil { + fieldName = mapFunc(fieldName) + } + + // if there's no tag to look for, return the field name + if tagName == "" { + return "", fieldName + } + + // if this tag is not set using the normal convention in the tag, + // then return the fieldname.. this check is done because according + // to the reflect documentation: + // If the tag does not have the conventional format, + // the value returned by Get is unspecified. + // which doesn't sound great. + if !strings.Contains(string(field.Tag), tagName+":") { + return "", fieldName + } + + // at this point we're fairly sure that we have a tag, so lets pull it out + tag = field.Tag.Get(tagName) + + // if we have a mapper function, call it on the whole tag + // XXX: this is a change from the old version, which pulled out the name + // before the tagMapFunc could be run, but I think this is the right way + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + // finally, split the options from the name + parts := strings.Split(tag, ",") + fieldName = parts[0] + + return tag, fieldName +} + +// parseOptions parses options out of a tag string, skipping the name +func parseOptions(tag string) map[string]string { + parts := strings.Split(tag, ",") + options := make(map[string]string, len(parts)) + if len(parts) > 1 { + for _, opt := range parts[1:] { + // short circuit potentially expensive split op + if strings.Contains(opt, "=") { + kv := strings.Split(opt, "=") + options[kv[0]] = kv[1] + continue + } + options[opt] = "" + } + } + return options +} + +// getMapping returns a mapping for the t type, using the tagName, mapFunc and +// tagMapFunc to determine the canonical names of fields. +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { + m := []*FieldInfo{} + + root := &FieldInfo{} + queue := []typeQueue{} + queue = append(queue, typeQueue{Deref(t), root, ""}) + +QueueLoop: + for len(queue) != 0 { + // pop the first item off of the queue + tq := queue[0] + queue = queue[1:] + + // ignore recursive field + for p := tq.fi.Parent; p != nil; p = p.Parent { + if tq.fi.Field.Type == p.Field.Type { + continue QueueLoop + } + } + + nChildren := 0 + if tq.t.Kind() == reflect.Struct { + nChildren = tq.t.NumField() + } + tq.fi.Children = make([]*FieldInfo, nChildren) + + // iterate through all of its fields + for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + + f := tq.t.Field(fieldPos) + + // parse the tag and the target name using the mapping options for this field + tag, name := parseName(f, tagName, mapFunc, tagMapFunc) + + // if the name is "-", disabled via a tag, skip it + if name == "-" { + continue + } + + fi := FieldInfo{ + Field: f, + Name: name, + Zero: reflect.New(f.Type).Elem(), + Options: parseOptions(tag), + } + + // if the path is empty this path is just the name + if tq.pp == "" { + fi.Path = fi.Name + } else { + fi.Path = tq.pp + "." + fi.Name + } + + // skip unexported fields + if len(f.PkgPath) != 0 && !f.Anonymous { + continue + } + + // bfs search of anonymous embedded structs + if f.Anonymous { + pp := tq.pp + if tag != "" { + pp = fi.Path + } + + fi.Embedded = true + fi.Index = apnd(tq.fi.Index, fieldPos) + nChildren := 0 + ft := Deref(f.Type) + if ft.Kind() == reflect.Struct { + nChildren = ft.NumField() + } + fi.Children = make([]*FieldInfo, nChildren) + queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) + } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) + queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) + } + + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Parent = tq.fi + tq.fi.Children[fieldPos] = &fi + m = append(m, &fi) + } + } + + flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} + for _, fi := range flds.Index { + flds.Paths[fi.Path] = fi + if fi.Name != "" && !fi.Embedded { + flds.Names[fi.Path] = fi + } + } + + return flds +} diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go new file mode 100644 index 000000000..b702f9cd1 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go @@ -0,0 +1,905 @@ +package reflectx + +import ( + "reflect" + "strings" + "testing" +) + +func ival(v reflect.Value) int { + return v.Interface().(int) +} + +func TestBasic(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + fv := reflect.ValueOf(f) + m := NewMapperFunc("", func(s string) string { return s }) + + v := m.FieldByName(fv, "A") + if ival(v) != f.A { + t.Errorf("Expecting %d, got %d", ival(v), f.A) + } + v = m.FieldByName(fv, "B") + if ival(v) != f.B { + t.Errorf("Expecting %d, got %d", f.B, ival(v)) + } + v = m.FieldByName(fv, "C") + if ival(v) != f.C { + t.Errorf("Expecting %d, got %d", f.C, ival(v)) + } +} + +func TestBasicEmbedded(t *testing.T) { + type Foo struct { + A int + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int + C int `db:"-"` + } + + type Baz struct { + A int + Bar `db:"Bar"` + } + + m := NewMapperFunc("db", func(s string) string { return s }) + + z := Baz{} + z.A = 1 + z.B = 2 + z.C = 4 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.Index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "A") + if ival(v) != z.A { + t.Errorf("Expecting %d, got %d", z.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.B") + if ival(v) != z.Bar.B { + t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) + } + v = m.FieldByName(zv, "Bar.A") + if ival(v) != z.Bar.Foo.A { + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.C") + if _, ok := v.Interface().(int); ok { + t.Errorf("Expecting Bar.C to not exist") + } + + fi := fields.GetByPath("Bar.C") + if fi != nil { + t.Errorf("Bar.C should not exist") + } +} + +func TestEmbeddedSimple(t *testing.T) { + type UUID [16]byte + type MyID struct { + UUID + } + type Item struct { + ID MyID + } + z := Item{} + + m := NewMapper("db") + m.TypeMap(reflect.TypeOf(z)) +} + +func TestBasicEmbeddedWithTags(t *testing.T) { + type Foo struct { + A int `db:"a"` + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int `db:"b"` + } + + type Baz struct { + A int `db:"a"` + Bar // `db:""` is implied for an embedded struct + } + + m := NewMapper("db") + + z := Baz{} + z.A = 1 + z.B = 2 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "a") + if ival(v) != z.Bar.Foo.A { // the dominant field + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "b") + if ival(v) != z.B { + t.Errorf("Expecting %d, got %d", z.B, ival(v)) + } +} + +func TestFlatTags(t *testing.T) { + m := NewMapper("db") + + type Asset struct { + Title string `db:"title"` + } + type Post struct { + Author string `db:"author,required"` + Asset Asset `db:""` + } + // Post columns: (author title) + + post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } +} + +func TestNestedStruct(t *testing.T) { + m := NewMapper("db") + + type Details struct { + Active bool `db:"active"` + } + type Asset struct { + Title string `db:"title"` + Details Details `db:"details"` + } + type Post struct { + Author string `db:"author,required"` + Asset `db:"asset"` + } + // Post columns: (author asset.title asset.details.active) + + post := Post{ + Author: "Joe", + Asset: Asset{Title: "Hello", Details: Details{Active: true}}, + } + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if _, ok := v.Interface().(string); ok { + t.Errorf("Expecting field to not exist") + } + v = m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset.details.active") + if v.Interface().(bool) != post.Asset.Details.Active { + t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) + } +} + +func TestInlineStruct(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type Employee struct { + Name string + ID int + } + type Boss Employee + type person struct { + Employee `db:"employee"` + Boss `db:"boss"` + } + // employees columns: (employee.name employee.id boss.name boss.id) + + em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} + ev := reflect.ValueOf(em) + + fields := m.TypeMap(reflect.TypeOf(em)) + if len(fields.Index) != 6 { + t.Errorf("Expecting 6 fields") + } + + v := m.FieldByName(ev, "employee.name") + if v.Interface().(string) != em.Employee.Name { + t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) + } + v = m.FieldByName(ev, "boss.id") + if ival(v) != em.Boss.ID { + t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) + } +} + +func TestRecursiveStruct(t *testing.T) { + type Person struct { + Parent *Person + } + m := NewMapperFunc("db", strings.ToLower) + var p *Person + m.TypeMap(reflect.TypeOf(p)) +} + +func TestFieldsEmbedded(t *testing.T) { + m := NewMapper("db") + + type Person struct { + Name string `db:"name,size=64"` + } + type Place struct { + Name string `db:"name"` + } + type Article struct { + Title string `db:"title"` + } + type PP struct { + Person `db:"person,required"` + Place `db:",someflag"` + Article `db:",required"` + } + // PP columns: (person.name name title) + + pp := PP{} + pp.Person.Name = "Peter" + pp.Place.Name = "Toronto" + pp.Article.Title = "Best city ever" + + fields := m.TypeMap(reflect.TypeOf(pp)) + // for i, f := range fields { + // log.Println(i, f) + // } + + ppv := reflect.ValueOf(pp) + + v := m.FieldByName(ppv, "person.name") + if v.Interface().(string) != pp.Person.Name { + t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "name") + if v.Interface().(string) != pp.Place.Name { + t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "title") + if v.Interface().(string) != pp.Article.Title { + t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) + } + + fi := fields.GetByPath("person") + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + if !fi.Embedded { + t.Errorf("Expecting field to be embedded") + } + if len(fi.Index) != 1 || fi.Index[0] != 0 { + t.Errorf("Expecting index to be [0]") + } + + fi = fields.GetByPath("person.name") + if fi == nil { + t.Errorf("Expecting person.name to exist") + } + if fi.Path != "person.name" { + t.Errorf("Expecting %s, got %s", "person.name", fi.Path) + } + if fi.Options["size"] != "64" { + t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) + } + + fi = fields.GetByTraversal([]int{1, 0}) + if fi == nil { + t.Errorf("Expecting traveral to exist") + } + if fi.Path != "name" { + t.Errorf("Expecting %s, got %s", "name", fi.Path) + } + + fi = fields.GetByTraversal([]int{2}) + if fi == nil { + t.Errorf("Expecting traversal to exist") + } + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + + trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) + if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { + t.Errorf("Expecting traversal: %v", trs) + } +} + +func TestPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + type Asset struct { + Title string + } + type Post struct { + *Asset `db:"asset"` + Author string + } + + post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 3 { + t.Errorf("Expecting 3 fields") + } + + v := m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestNamedPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type User struct { + Name string + } + + type Asset struct { + Title string + + Owner *User `db:"owner"` + } + type Post struct { + Author string + + Asset1 *Asset `db:"asset1"` + Asset2 *Asset `db:"asset2"` + } + + post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 9 { + t.Errorf("Expecting 9 fields") + } + + v := m.FieldByName(pv, "asset1.title") + if v.Interface().(string) != post.Asset1.Title { + t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset1.owner.name") + if v.Interface().(string) != post.Asset1.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.title") + if v.Interface().(string) != post.Asset2.Title { + t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.owner.name") + if v.Interface().(string) != post.Asset2.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestFieldMap(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + m := NewMapperFunc("db", strings.ToLower) + + fm := m.FieldMap(reflect.ValueOf(f)) + + if len(fm) != 3 { + t.Errorf("Expecting %d keys, got %d", 3, len(fm)) + } + if fm["a"].Interface().(int) != 1 { + t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) + } + if fm["b"].Interface().(int) != 2 { + t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) + } + if fm["c"].Interface().(int) != 3 { + t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) + } +} + +func TestTagNameMapping(t *testing.T) { + type Strategy struct { + StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` + StrategyName string + } + + m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string { + if strings.Contains(value, ",") { + return strings.Split(value, ",")[0] + } + return value + }) + strategy := Strategy{"1", "Alpah"} + mapping := m.TypeMap(reflect.TypeOf(strategy)) + + for _, key := range []string{"strategy_id", "STRATEGYNAME"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } +} + +func TestMapping(t *testing.T) { + type Person struct { + ID int + Name string + WearsGlasses bool `db:"wears_glasses"` + } + + m := NewMapperFunc("db", strings.ToLower) + p := Person{1, "Jason", true} + mapping := m.TypeMap(reflect.TypeOf(p)) + + for _, key := range []string{"id", "name", "wears_glasses"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type SportsPerson struct { + Weight int + Age int + Person + } + s := SportsPerson{Weight: 100, Age: 30, Person: p} + mapping = m.TypeMap(reflect.TypeOf(s)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type RugbyPlayer struct { + Position int + IsIntense bool `db:"is_intense"` + IsAllBlack bool `db:"-"` + SportsPerson + } + r := RugbyPlayer{12, true, false, s} + mapping = m.TypeMap(reflect.TypeOf(r)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + if fi := mapping.GetByPath("isallblack"); fi != nil { + t.Errorf("Expecting to ignore `IsAllBlack` field") + } +} + +func TestGetByTraversal(t *testing.T) { + type C struct { + C0 int + C1 int + } + type B struct { + B0 string + B1 *C + } + type A struct { + A0 int + A1 B + } + + testCases := []struct { + Index []int + ExpectedName string + ExpectNil bool + }{ + { + Index: []int{0}, + ExpectedName: "A0", + }, + { + Index: []int{1, 0}, + ExpectedName: "B0", + }, + { + Index: []int{1, 1, 1}, + ExpectedName: "C1", + }, + { + Index: []int{3, 4, 5}, + ExpectNil: true, + }, + { + Index: []int{}, + ExpectNil: true, + }, + { + Index: nil, + ExpectNil: true, + }, + } + + m := NewMapperFunc("db", func(n string) string { return n }) + tm := m.TypeMap(reflect.TypeOf(A{})) + + for i, tc := range testCases { + fi := tm.GetByTraversal(tc.Index) + if tc.ExpectNil { + if fi != nil { + t.Errorf("%d: expected nil, got %v", i, fi) + } + continue + } + + if fi == nil { + t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) + continue + } + + if fi.Name != tc.ExpectedName { + t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) + } + } +} + +// TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName +func TestMapperMethodsByName(t *testing.T) { + type C struct { + C0 string + C1 int + } + type B struct { + B0 *C `db:"B0"` + B1 C `db:"B1"` + B2 string `db:"B2"` + } + type A struct { + A0 *B `db:"A0"` + B `db:"A1"` + A2 int + a3 int + } + + val := &A{ + A0: &B{ + B0: &C{C0: "0", C1: 1}, + B1: C{C0: "2", C1: 3}, + B2: "4", + }, + B: B{ + B0: nil, + B1: C{C0: "5", C1: 6}, + B2: "7", + }, + A2: 8, + } + + testCases := []struct { + Name string + ExpectInvalid bool + ExpectedValue interface{} + ExpectedIndexes []int + }{ + { + Name: "A0.B0.C0", + ExpectedValue: "0", + ExpectedIndexes: []int{0, 0, 0}, + }, + { + Name: "A0.B0.C1", + ExpectedValue: 1, + ExpectedIndexes: []int{0, 0, 1}, + }, + { + Name: "A0.B1.C0", + ExpectedValue: "2", + ExpectedIndexes: []int{0, 1, 0}, + }, + { + Name: "A0.B1.C1", + ExpectedValue: 3, + ExpectedIndexes: []int{0, 1, 1}, + }, + { + Name: "A0.B2", + ExpectedValue: "4", + ExpectedIndexes: []int{0, 2}, + }, + { + Name: "A1.B0.C0", + ExpectedValue: "", + ExpectedIndexes: []int{1, 0, 0}, + }, + { + Name: "A1.B0.C1", + ExpectedValue: 0, + ExpectedIndexes: []int{1, 0, 1}, + }, + { + Name: "A1.B1.C0", + ExpectedValue: "5", + ExpectedIndexes: []int{1, 1, 0}, + }, + { + Name: "A1.B1.C1", + ExpectedValue: 6, + ExpectedIndexes: []int{1, 1, 1}, + }, + { + Name: "A1.B2", + ExpectedValue: "7", + ExpectedIndexes: []int{1, 2}, + }, + { + Name: "A2", + ExpectedValue: 8, + ExpectedIndexes: []int{2}, + }, + { + Name: "XYZ", + ExpectInvalid: true, + ExpectedIndexes: []int{}, + }, + { + Name: "a3", + ExpectInvalid: true, + ExpectedIndexes: []int{}, + }, + } + + // build the names array from the test cases + names := make([]string, len(testCases)) + for i, tc := range testCases { + names[i] = tc.Name + } + m := NewMapperFunc("db", func(n string) string { return n }) + v := reflect.ValueOf(val) + values := m.FieldsByName(v, names) + if len(values) != len(testCases) { + t.Errorf("expected %d values, got %d", len(testCases), len(values)) + t.FailNow() + } + indexes := m.TraversalsByName(v.Type(), names) + if len(indexes) != len(testCases) { + t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) + t.FailNow() + } + for i, val := range values { + tc := testCases[i] + traversal := indexes[i] + if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { + t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) + t.FailNow() + } + val = reflect.Indirect(val) + if tc.ExpectInvalid { + if val.IsValid() { + t.Errorf("%d: expected zero value, got %v", i, val) + } + continue + } + if !val.IsValid() { + t.Errorf("%d: expected valid value, got %v", i, val) + continue + } + actualValue := reflect.Indirect(val).Interface() + if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { + t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) + } + } +} + +func TestFieldByIndexes(t *testing.T) { + type C struct { + C0 bool + C1 string + C2 int + C3 map[string]int + } + type B struct { + B1 C + B2 *C + } + type A struct { + A1 B + A2 *B + } + testCases := []struct { + value interface{} + indexes []int + expectedValue interface{} + readOnly bool + }{ + { + value: A{ + A1: B{B1: C{C0: true}}, + }, + indexes: []int{0, 0, 0}, + expectedValue: true, + readOnly: true, + }, + { + value: A{ + A2: &B{B2: &C{C1: "answer"}}, + }, + indexes: []int{1, 1, 1}, + expectedValue: "answer", + readOnly: true, + }, + { + value: &A{}, + indexes: []int{1, 1, 3}, + expectedValue: map[string]int{}, + }, + } + + for i, tc := range testCases { + checkResults := func(v reflect.Value) { + if tc.expectedValue == nil { + if !v.IsNil() { + t.Errorf("%d: expected nil, actual %v", i, v.Interface()) + } + } else { + if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { + t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) + } + } + } + + checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes)) + if tc.readOnly { + checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes)) + } + } +} + +func TestMustBe(t *testing.T) { + typ := reflect.TypeOf(E1{}) + mustBe(typ, reflect.Struct) + + defer func() { + if r := recover(); r != nil { + valueErr, ok := r.(*reflect.ValueError) + if !ok { + t.Errorf("unexpected Method: %s", valueErr.Method) + t.Error("expected panic with *reflect.ValueError") + return + } + if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { + } + if valueErr.Kind != reflect.String { + t.Errorf("unexpected Kind: %s", valueErr.Kind) + } + } else { + t.Error("expected panic") + } + }() + + typ = reflect.TypeOf("string") + mustBe(typ, reflect.Struct) + t.Error("got here, didn't expect to") +} + +type E1 struct { + A int +} +type E2 struct { + E1 + B int +} +type E3 struct { + E2 + C int +} +type E4 struct { + E3 + D int +} + +func BenchmarkFieldNameL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("D") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldNameL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("A") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(1) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(0) + f = f.Field(0) + f = f.Field(0) + f = f.Field(0) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldByIndexL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + idx := []int{0, 0, 0, 0} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := FieldByIndexes(v, idx) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx.go b/vendor/github.com/jmoiron/sqlx/sqlx.go new file mode 100644 index 000000000..4859d5ac8 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx.go @@ -0,0 +1,1035 @@ +package sqlx + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + + "io/ioutil" + "path/filepath" + "reflect" + "strings" + "sync" + + "github.com/jmoiron/sqlx/reflectx" +) + +// Although the NameMapper is convenient, in practice it should not +// be relied on except for application code. If you are writing a library +// that uses sqlx, you should be aware that the name mappings you expect +// can be overridden by your user's application. + +// NameMapper is used to map column names to struct field names. By default, +// it uses strings.ToLower to lowercase struct field names. It can be set +// to whatever you want, but it is encouraged to be set before sqlx is used +// as name-to-field mappings are cached after first use on a type. +var NameMapper = strings.ToLower +var origMapper = reflect.ValueOf(NameMapper) + +// Rather than creating on init, this is created when necessary so that +// importers have time to customize the NameMapper. +var mpr *reflectx.Mapper + +// mprMu protects mpr. +var mprMu sync.Mutex + +// mapper returns a valid mapper using the configured NameMapper func. +func mapper() *reflectx.Mapper { + mprMu.Lock() + defer mprMu.Unlock() + + if mpr == nil { + mpr = reflectx.NewMapperFunc("db", NameMapper) + } else if origMapper != reflect.ValueOf(NameMapper) { + // if NameMapper has changed, create a new mapper + mpr = reflectx.NewMapperFunc("db", NameMapper) + origMapper = reflect.ValueOf(NameMapper) + } + return mpr +} + +// isScannable takes the reflect.Type and the actual dest value and returns +// whether or not it's Scannable. Something is scannable if: +// * it is not a struct +// * it implements sql.Scanner +// * it has no exported fields +func isScannable(t reflect.Type) bool { + if reflect.PtrTo(t).Implements(_scannerInterface) { + return true + } + if t.Kind() != reflect.Struct { + return true + } + + // it's not important that we use the right mapper for this particular object, + // we're only concerned on how many exported fields this struct has + m := mapper() + if len(m.TypeMap(t).Index) == 0 { + return true + } + return false +} + +// ColScanner is an interface used by MapScan and SliceScan +type ColScanner interface { + Columns() ([]string, error) + Scan(dest ...interface{}) error + Err() error +} + +// Queryer is an interface used by Get and Select +type Queryer interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Queryx(query string, args ...interface{}) (*Rows, error) + QueryRowx(query string, args ...interface{}) *Row +} + +// Execer is an interface used by MustExec and LoadFile +type Execer interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// Binder is an interface for something which can bind queries (Tx, DB) +type binder interface { + DriverName() string + Rebind(string) string + BindNamed(string, interface{}) (string, []interface{}, error) +} + +// Ext is a union interface which can bind, query, and exec, used by +// NamedQuery and NamedExec. +type Ext interface { + binder + Queryer + Execer +} + +// Preparer is an interface used by Preparex. +type Preparer interface { + Prepare(query string) (*sql.Stmt, error) +} + +// determine if any of our extensions are unsafe +func isUnsafe(i interface{}) bool { + switch v := i.(type) { + case Row: + return v.unsafe + case *Row: + return v.unsafe + case Rows: + return v.unsafe + case *Rows: + return v.unsafe + case NamedStmt: + return v.Stmt.unsafe + case *NamedStmt: + return v.Stmt.unsafe + case Stmt: + return v.unsafe + case *Stmt: + return v.unsafe + case qStmt: + return v.unsafe + case *qStmt: + return v.unsafe + case DB: + return v.unsafe + case *DB: + return v.unsafe + case Tx: + return v.unsafe + case *Tx: + return v.unsafe + case sql.Rows, *sql.Rows: + return false + default: + return false + } +} + +func mapperFor(i interface{}) *reflectx.Mapper { + switch i.(type) { + case DB: + return i.(DB).Mapper + case *DB: + return i.(*DB).Mapper + case Tx: + return i.(Tx).Mapper + case *Tx: + return i.(*Tx).Mapper + default: + return mapper() + } +} + +var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// Row is a reimplementation of sql.Row in order to gain access to the underlying +// sql.Rows.Columns() data, necessary for StructScan. +type Row struct { + err error + unsafe bool + rows *sql.Rows + Mapper *reflectx.Mapper +} + +// Scan is a fixed implementation of sql.Row.Scan, which does not discard the +// underlying error from the internal rows object if it exists. +func (r *Row) Scan(dest ...interface{}) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + defer r.rows.Close() + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + if err := r.rows.Close(); err != nil { + return err + } + return nil +} + +// Columns returns the underlying sql.Rows.Columns(), or the deferred error usually +// returned by Row.Scan() +func (r *Row) Columns() ([]string, error) { + if r.err != nil { + return []string{}, r.err + } + return r.rows.Columns() +} + +// Err returns the error encountered while scanning. +func (r *Row) Err() error { + return r.err +} + +// DB is a wrapper around sql.DB which keeps track of the driverName upon Open, +// used mostly to automatically bind named queries using the right bindvars. +type DB struct { + *sql.DB + driverName string + unsafe bool + Mapper *reflectx.Mapper +} + +// NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The +// driverName of the original database is required for named query support. +func NewDb(db *sql.DB, driverName string) *DB { + return &DB{DB: db, driverName: driverName, Mapper: mapper()} +} + +// DriverName returns the driverName passed to the Open function for this DB. +func (db *DB) DriverName() string { + return db.driverName +} + +// Open is the same as sql.Open, but returns an *sqlx.DB instead. +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err +} + +// MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. +func MustOpen(driverName, dataSourceName string) *DB { + db, err := Open(driverName, dataSourceName) + if err != nil { + panic(err) + } + return db +} + +// MapperFunc sets a new mapper for this db using the default sqlx struct tag +// and the provided mapper function. +func (db *DB) MapperFunc(mf func(string) string) { + db.Mapper = reflectx.NewMapperFunc("db", mf) +} + +// Rebind transforms a query from QUESTION to the DB driver's bindvar type. +func (db *DB) Rebind(query string) string { + return Rebind(BindType(db.driverName), query) +} + +// Unsafe returns a version of DB which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its +// safety behavior. +func (db *DB) Unsafe() *DB { + return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} +} + +// BindNamed binds a query using the DB driver's bindvar type. +func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) +} + +// NamedQuery using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { + return NamedQuery(db, query, arg) +} + +// NamedExec using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { + return NamedExec(db, query, arg) +} + +// Select using this DB. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { + return Select(db, dest, query, args...) +} + +// Get using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { + return Get(db, dest, query, args...) +} + +// MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +func (db *DB) MustBegin() *Tx { + tx, err := db.Beginx() + if err != nil { + panic(err) + } + return tx +} + +// Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. +func (db *DB) Beginx() (*Tx, error) { + tx, err := db.DB.Begin() + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// Queryx queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.Query(query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowx queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryRowx(query string, args ...interface{}) *Row { + rows, err := db.DB.Query(query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustExec (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) MustExec(query string, args ...interface{}) sql.Result { + return MustExec(db, query, args...) +} + +// Preparex returns an sqlx.Stmt instead of a sql.Stmt +func (db *DB) Preparex(query string) (*Stmt, error) { + return Preparex(db, query) +} + +// PrepareNamed returns an sqlx.NamedStmt +func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { + return prepareNamed(db, query) +} + +// Tx is an sqlx wrapper around sql.Tx with extra functionality +type Tx struct { + *sql.Tx + driverName string + unsafe bool + Mapper *reflectx.Mapper +} + +// DriverName returns the driverName used by the DB which began this transaction. +func (tx *Tx) DriverName() string { + return tx.driverName +} + +// Rebind a query within a transaction's bindvar type. +func (tx *Tx) Rebind(query string) string { + return Rebind(BindType(tx.driverName), query) +} + +// Unsafe returns a version of Tx which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +func (tx *Tx) Unsafe() *Tx { + return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} +} + +// BindNamed binds a query within a transaction's bindvar type. +func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) +} + +// NamedQuery within a transaction. +// Any named placeholder parameters are replaced with fields from arg. +func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { + return NamedQuery(tx, query, arg) +} + +// NamedExec a named query within a transaction. +// Any named placeholder parameters are replaced with fields from arg. +func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { + return NamedExec(tx, query, arg) +} + +// Select within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { + return Select(tx, dest, query, args...) +} + +// Queryx within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := tx.Tx.Query(query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err +} + +// QueryRowx within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { + rows, err := tx.Tx.Query(query, args...) + return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + +// Get within a transaction. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { + return Get(tx, dest, query, args...) +} + +// MustExec runs MustExec within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { + return MustExec(tx, query, args...) +} + +// Preparex a statement within a transaction. +func (tx *Tx) Preparex(query string) (*Stmt, error) { + return Preparex(tx, query) +} + +// Stmtx returns a version of the prepared statement which runs within a transaction. Provided +// stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) Stmtx(stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper} +} + +// NamedStmt returns a version of the prepared statement which runs within a transaction. +func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.Stmtx(stmt.Stmt), + } +} + +// PrepareNamed returns an sqlx.NamedStmt +func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { + return prepareNamed(tx, query) +} + +// Stmt is an sqlx wrapper around sql.Stmt with extra functionality +type Stmt struct { + *sql.Stmt + unsafe bool + Mapper *reflectx.Mapper +} + +// Unsafe returns a version of Stmt which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +func (s *Stmt) Unsafe() *Stmt { + return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} +} + +// Select using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) Select(dest interface{}, args ...interface{}) error { + return Select(&qStmt{s}, dest, "", args...) +} + +// Get using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (s *Stmt) Get(dest interface{}, args ...interface{}) error { + return Get(&qStmt{s}, dest, "", args...) +} + +// MustExec (panic) using this statement. Note that the query portion of the error +// output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) MustExec(args ...interface{}) sql.Result { + return MustExec(&qStmt{s}, "", args...) +} + +// QueryRowx using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryRowx(args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowx("", args...) +} + +// Queryx using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.Queryx("", args...) +} + +// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by +// implementing those interfaces and ignoring the `query` argument. +type qStmt struct{ *Stmt } + +func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.Query(args...) +} + +func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.Query(args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { + rows, err := q.Stmt.Query(args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.Exec(args...) +} + +// Rows is a wrapper around sql.Rows which caches costly reflect operations +// during a looped StructScan +type Rows struct { + *sql.Rows + unsafe bool + Mapper *reflectx.Mapper + // these fields cache memory use for a rows during iteration w/ structScan + started bool + fields [][]int + values []interface{} +} + +// SliceScan using this Rows. +func (r *Rows) SliceScan() ([]interface{}, error) { + return SliceScan(r) +} + +// MapScan using this Rows. +func (r *Rows) MapScan(dest map[string]interface{}) error { + return MapScan(r, dest) +} + +// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. +// Use this and iterate over Rows manually when the memory load of Select() might be +// prohibitive. *Rows.StructScan caches the reflect work of matching up column +// positions to fields to avoid that overhead per scan, which means it is not safe +// to run StructScan on the same Rows instance with different struct types. +func (r *Rows) StructScan(dest interface{}) error { + v := reflect.ValueOf(dest) + + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + + v = reflect.Indirect(v) + + if !r.started { + columns, err := r.Columns() + if err != nil { + return err + } + m := r.Mapper + + r.fields = m.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(r.fields); err != nil && !r.unsafe { + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) + } + r.values = make([]interface{}, len(columns)) + r.started = true + } + + err := fieldsByTraversal(v, r.fields, r.values, true) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + err = r.Scan(r.values...) + if err != nil { + return err + } + return r.Err() +} + +// Connect to a database and verify with a ping. +func Connect(driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.Ping() + return db, err +} + +// MustConnect connects to a database and panics on error. +func MustConnect(driverName, dataSourceName string) *DB { + db, err := Connect(driverName, dataSourceName) + if err != nil { + panic(err) + } + return db +} + +// Preparex prepares a statement. +func Preparex(p Preparer, query string) (*Stmt, error) { + s, err := p.Prepare(query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// Select executes a query using the provided Queryer, and StructScans each row +// into dest, which must be a slice. If the slice elements are scannable, then +// the result set must have only one column. Otherwise, StructScan is used. +// The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. +func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { + rows, err := q.Queryx(query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// Get does a QueryRow using the provided Queryer, and scans the resulting row +// to dest. If dest is scannable, the result must only have one column. Otherwise, +// StructScan is used. Get will return sql.ErrNoRows like row.Scan would. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowx(query, args...) + return r.scanAny(dest, false) +} + +// LoadFile exec's every statement in a file (as a single call to Exec). +// LoadFile may return a nil *sql.Result if errors are encountered locating or +// reading the file at path. LoadFile reads the entire file into memory, so it +// is not suitable for loading large data dumps, but can be useful for initializing +// schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFile(e Execer, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.Exec(string(contents)) + return &res, err +} + +// MustExec execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. +func MustExec(e Execer, query string, args ...interface{}) sql.Result { + res, err := e.Exec(query, args...) + if err != nil { + panic(err) + } + return res +} + +// SliceScan using this Rows. +func (r *Row) SliceScan() ([]interface{}, error) { + return SliceScan(r) +} + +// MapScan using this Rows. +func (r *Row) MapScan(dest map[string]interface{}) error { + return MapScan(r, dest) +} + +func (r *Row) scanAny(dest interface{}, structOnly bool) error { + if r.err != nil { + return r.err + } + if r.rows == nil { + r.err = sql.ErrNoRows + return r.err + } + defer r.rows.Close() + + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if v.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + + base := reflectx.Deref(v.Type()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + columns, err := r.Columns() + if err != nil { + return err + } + + if scannable && len(columns) > 1 { + return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) + } + + if scannable { + return r.Scan(dest) + } + + m := r.Mapper + + fields := m.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(fields); err != nil && !r.unsafe { + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) + } + values := make([]interface{}, len(columns)) + + err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + return r.Scan(values...) +} + +// StructScan a single Row into dest. +func (r *Row) StructScan(dest interface{}) error { + return r.scanAny(dest, true) +} + +// SliceScan a row, returning a []interface{} with values similar to MapScan. +// This function is primarily intended for use where the number of columns +// is not known. Because you can pass an []interface{} directly to Scan, +// it's recommended that you do that as it will not have to allocate new +// slices per row. +func SliceScan(r ColScanner) ([]interface{}, error) { + // ignore r.started, since we needn't use reflect for anything. + columns, err := r.Columns() + if err != nil { + return []interface{}{}, err + } + + values := make([]interface{}, len(columns)) + for i := range values { + values[i] = new(interface{}) + } + + err = r.Scan(values...) + + if err != nil { + return values, err + } + + for i := range columns { + values[i] = *(values[i].(*interface{})) + } + + return values, r.Err() +} + +// MapScan scans a single Row into the dest map[string]interface{}. +// Use this to get results for SQL that might not be under your control +// (for instance, if you're building an interface for an SQL server that +// executes SQL from input). Please do not use this as a primary interface! +// This will modify the map sent to it in place, so reuse the same map with +// care. Columns which occur more than once in the result will overwrite +// each other! +func MapScan(r ColScanner, dest map[string]interface{}) error { + // ignore r.started, since we needn't use reflect for anything. + columns, err := r.Columns() + if err != nil { + return err + } + + values := make([]interface{}, len(columns)) + for i := range values { + values[i] = new(interface{}) + } + + err = r.Scan(values...) + if err != nil { + return err + } + + for i, column := range columns { + dest[column] = *(values[i].(*interface{})) + } + + return r.Err() +} + +type rowsi interface { + Close() error + Columns() ([]string, error) + Err() error + Next() bool + Scan(...interface{}) error +} + +// structOnlyError returns an error appropriate for type when a non-scannable +// struct is expected but something else is given +func structOnlyError(t reflect.Type) error { + isStruct := t.Kind() == reflect.Struct + isScanner := reflect.PtrTo(t).Implements(_scannerInterface) + if !isStruct { + return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) + } + if isScanner { + return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) + } + return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) +} + +// scanAll scans all rows into a destination, which must be a slice of any +// type. If the destination slice type is a Struct, then StructScan will be +// used on each row. If the destination is some other kind of base type, then +// each row must only have one column which can scan into that type. This +// allows you to do something like: +// +// rows, _ := db.Query("select id from people;") +// var ids []int +// scanAll(rows, &ids, false) +// +// and ids will be a list of the id results. I realize that this is a desirable +// interface to expose to users, but for now it will only be exposed via changes +// to `Get` and `Select`. The reason that this has been implemented like this is +// this is the only way to not duplicate reflect work in the new API while +// maintaining backwards compatibility. +func scanAll(rows rowsi, dest interface{}, structOnly bool) error { + var v, vp reflect.Value + + value := reflect.ValueOf(dest) + + // json.Unmarshal returns errors for these + if value.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if value.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + direct := reflect.Indirect(value) + + slice, err := baseType(value.Type(), reflect.Slice) + if err != nil { + return err + } + + isPtr := slice.Elem().Kind() == reflect.Ptr + base := reflectx.Deref(slice.Elem()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // if it's a base type make sure it only has 1 column; if not return an error + if scannable && len(columns) > 1 { + return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) + } + + if !scannable { + var values []interface{} + var m *reflectx.Mapper + + switch rows.(type) { + case *Rows: + m = rows.(*Rows).Mapper + default: + m = mapper() + } + + fields := m.TraversalsByName(base, columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) + } + values = make([]interface{}, len(columns)) + + for rows.Next() { + // create a new struct type (which returns PtrTo) and indirect it + vp = reflect.New(base) + v = reflect.Indirect(vp) + + err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } + + // scan into the struct field pointers and append to our results + err = rows.Scan(values...) + if err != nil { + return err + } + + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, v)) + } + } + } else { + for rows.Next() { + vp = reflect.New(base) + err = rows.Scan(vp.Interface()) + if err != nil { + return err + } + // append + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, reflect.Indirect(vp))) + } + } + } + + return rows.Err() +} + +// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately +// it doesn't really feel like it's named properly. There is an incongruency +// between this and the way that StructScan (which might better be ScanStruct +// anyway) works on a rows object. + +// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. +// StructScan will scan in the entire rows result, so if you do not want to +// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. +// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. +func StructScan(rows rowsi, dest interface{}) error { + return scanAll(rows, dest, true) + +} + +// reflect helpers + +func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { + t = reflectx.Deref(t) + if t.Kind() != expected { + return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) + } + return t, nil +} + +// fieldsByName fills a values interface with fields from the passed value based +// on the traversals in int. If ptrs is true, return addresses instead of values. +// We write this instead of using FieldsByName to save allocations and map lookups +// when iterating over many rows. Empty traversals will get an interface pointer. +// Because of the necessity of requesting ptrs or values, it's considered a bit too +// specialized for inclusion in reflectx itself. +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return errors.New("argument not a struct") + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + values[i] = new(interface{}) + continue + } + f := reflectx.FieldByIndexes(v, traversal) + if ptrs { + values[i] = f.Addr().Interface() + } else { + values[i] = f.Interface() + } + } + return nil +} + +func missingFields(transversals [][]int) (field int, err error) { + for i, t := range transversals { + if len(t) == 0 { + return i, errors.New("missing field") + } + } + return 0, nil +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context.go b/vendor/github.com/jmoiron/sqlx/sqlx_context.go new file mode 100644 index 000000000..0b1714514 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx_context.go @@ -0,0 +1,335 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" +) + +// ConnectContext to a database and verify with a ping. +func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.PingContext(ctx) + return db, err +} + +// QueryerContext is an interface used by GetContext and SelectContext +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) + QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row +} + +// PreparerContext is an interface used by PreparexContext. +type PreparerContext interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// ExecerContext is an interface used by MustExecContext and LoadFileContext +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// ExtContext is a union interface which can bind, query, and exec, with Context +// used by NamedQueryContext and NamedExecContext. +type ExtContext interface { + binder + QueryerContext + ExecerContext +} + +// SelectContext executes a query using the provided Queryer, and StructScans +// each row into dest, which must be a slice. If the slice elements are +// scannable, then the result set must have only one column. Otherwise, +// StructScan is used. The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. +func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + rows, err := q.QueryxContext(ctx, query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// PreparexContext prepares a statement. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { + s, err := p.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// GetContext does a QueryRow using the provided Queryer, and scans the +// resulting row to dest. If dest is scannable, the result must only have one +// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like +// row.Scan would. Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowxContext(ctx, query, args...) + return r.scanAny(dest, false) +} + +// LoadFileContext exec's every statement in a file (as a single call to Exec). +// LoadFileContext may return a nil *sql.Result if errors are encountered +// locating or reading the file at path. LoadFile reads the entire file into +// memory, so it is not suitable for loading large data dumps, but can be useful +// for initializing schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.ExecContext(ctx, string(contents)) + return &res, err +} + +// MustExecContext execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. +func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { + res, err := e.ExecContext(ctx, query, args...) + if err != nil { + panic(err) + } + return res +} + +// PrepareNamedContext returns an sqlx.NamedStmt +func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { + return prepareNamedContext(ctx, db, query) +} + +// NamedQueryContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { + return NamedQueryContext(ctx, db, query, arg) +} + +// NamedExecContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, db, query, arg) +} + +// SelectContext using this DB. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, db, dest, query, args...) +} + +// GetContext using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, db, dest, query, args...) +} + +// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { + return PreparexContext(ctx, db, query) +} + +// QueryxContext queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowxContext queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.DB.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// MustBeginContext is canceled. +func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + panic(err) + } + return tx +} + +// MustExecContext (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, db, query, args...) +} + +// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an +// *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// BeginxContext is canceled. +func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// StmtxContext returns a version of the prepared statement which runs within a +// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} +} + +// NamedStmtContext returns a version of the prepared statement which runs +// within a transaction. +func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.StmtxContext(ctx, stmt.Stmt), + } +} + +// MustExecContext runs MustExecContext within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, tx, query, args...) +} + +// QueryxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := tx.Tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err +} + +// SelectContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, tx, dest, query, args...) +} + +// GetContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, tx, dest, query, args...) +} + +// QueryRowxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.Tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + +// NamedExecContext using this Tx. +// Any named placeholder parameters are replaced with fields from arg. +func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, tx, query, arg) +} + +// SelectContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return SelectContext(ctx, &qStmt{s}, dest, "", args...) +} + +// GetContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return GetContext(ctx, &qStmt{s}, dest, "", args...) +} + +// MustExecContext (panic) using this statement. Note that the query portion of +// the error output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { + return MustExecContext(ctx, &qStmt{s}, "", args...) +} + +// QueryRowxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowxContext(ctx, "", args...) +} + +// QueryxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.QueryxContext(ctx, "", args...) +} + +func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.QueryContext(ctx, args...) +} + +func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := q.Stmt.QueryContext(ctx, args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.ExecContext(ctx, args...) +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go new file mode 100644 index 000000000..85e112bd5 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go @@ -0,0 +1,1344 @@ +// +build go1.8 + +// The following environment variables, if set, will be used: +// +// * SQLX_SQLITE_DSN +// * SQLX_POSTGRES_DSN +// * SQLX_MYSQL_DSN +// +// Set any of these variables to 'skip' to skip them. Note that for MySQL, +// the string '?parseTime=True' will be appended to the DSN if it's not there +// already. +// +package sqlx + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx/reflectx" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +func MultiExecContext(ctx context.Context, e ExecerContext, query string) { + stmts := strings.Split(query, ";\n") + if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { + stmts = stmts[:len(stmts)-1] + } + for _, s := range stmts { + _, err := e.ExecContext(ctx, s) + if err != nil { + fmt.Println(err, s) + } + } +} + +func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { + runner := func(ctx context.Context, db *DB, t *testing.T, create, drop string) { + defer func() { + MultiExecContext(ctx, db, drop) + }() + + MultiExecContext(ctx, db, create) + test(ctx, db, t) + } + + if TestPostgres { + create, drop := schema.Postgres() + runner(ctx, pgdb, t, create, drop) + } + if TestSqlite { + create, drop := schema.Sqlite3() + runner(ctx, sldb, t, create, drop) + } + if TestMysql { + create, drop := schema.MySQL() + runner(ctx, mysqldb, t, create, drop) + } +} + +func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { + tx := db.MustBeginTx(ctx, nil) + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + if db.DriverName() == "mysql" { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + } else { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + } + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.Commit() +} + +// Test a new backwards compatible feature, that missing scan destinations +// will silently scan into sql.RawText rather than failing/panicing +func TestMissingNamesContextContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + type PersonPlus struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string + //AddedAt time.Time `db:"added_at"` + } + + // test Select first + pps := []PersonPlus{} + // pps lacks added_at destination + err := db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err == nil { + t.Error("Expected missing name from Select to fail, but it did not.") + } + + // test Get + pp := PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected missing name Get to fail, but it did not.") + } + + // test naked StructScan + pps = []PersonPlus{} + rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rows.Next() + err = StructScan(rows, &pps) + if err == nil { + t.Error("Expected missing name in StructScan to fail, but it did not.") + } + rows.Close() + + // now try various things with unsafe set. + db = db.Unsafe() + pps = []PersonPlus{} + err = db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err != nil { + t.Error(err) + } + + // test Get + pp = PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Error(err) + } + + // test naked StructScan + pps = []PersonPlus{} + rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rowsx.Next() + err = StructScan(rowsx, &pps) + if err != nil { + t.Error(err) + } + rowsx.Close() + + // test Named stmt + if !isUnsafe(db) { + t.Error("Expected db to be unsafe, but it isn't") + } + nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // its internal stmt should be marked unsafe + if !nstmt.Stmt.unsafe { + t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + // test it with a safe db + db.unsafe = false + if isUnsafe(db) { + t.Error("expected db to be safe but it isn't") + } + nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // it should be safe + if isUnsafe(nstmt) { + t.Error("NamedStmt did not inherit safety") + } + nstmt.Unsafe() + if !isUnsafe(nstmt) { + t.Error("expected newly unsafed NamedStmt to be unsafe") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + }) +} + +func TestEmbeddedStructsContextContext(t *testing.T) { + type Loop1 struct{ Person } + type Loop2 struct{ Loop1 } + type Loop3 struct{ Loop2 } + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + peopleAndPlaces := []PersonPlace{} + err := db.SelectContext( + ctx, + &peopleAndPlaces, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlaces { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test embedded structs with StructScan + rows, err := db.QueryxContext( + ctx, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Error(err) + } + + perp := PersonPlace{} + rows.Next() + err = rows.StructScan(&perp) + if err != nil { + t.Error(err) + } + + if len(perp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(perp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + + rows.Close() + + // test the same for embedded pointer structs + peopleAndPlacesPtrs := []PersonPlacePtr{} + err = db.SelectContext( + ctx, + &peopleAndPlacesPtrs, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlacesPtrs { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test "deep nesting" + l3s := []Loop3{} + err = db.SelectContext(ctx, &l3s, `select * from person`) + if err != nil { + t.Fatal(err) + } + for _, l3 := range l3s { + if len(l3.Loop2.Loop1.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + } + + // test "embed conflicts" + ec := []EmbedConflict{} + err = db.SelectContext(ctx, &ec, `select * from person`) + // I'm torn between erroring here or having some kind of working behavior + // in order to allow for more flexibility in destination structs + if err != nil { + t.Errorf("Was not expecting an error on embed conflicts.") + } + }) +} + +func TestJoinQueryContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Employee + Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Employee.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Employee.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestJoinQueryNamedPointerStructsContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Emp1 *Employee `db:"emp1"` + Emp2 *Employee `db:"emp2"` + *Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", + boss.id "boss.id", boss.name "boss.name" FROM employees AS emp + JOIN employees AS boss ON emp.boss_id = boss.id + `) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestSelectSliceMapTimeContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + rows, err := db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + _, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + m := map[string]interface{}{} + err := rows.MapScan(m) + if err != nil { + t.Error(err) + } + } + + }) +} + +func TestNilReceiverContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + var p *Person + err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected error when getting into nil struct ptr.") + } + var pp *[]Person + err = db.SelectContext(ctx, pp, "SELECT * FROM person") + if err == nil { + t.Error("Expected an error when selecting into nil slice ptr.") + } + }) +} + +func TestNamedQueryContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE place ( + id integer PRIMARY KEY, + name text NULL + ); + CREATE TABLE person ( + first_name text NULL, + last_name text NULL, + email text NULL + ); + CREATE TABLE placeperson ( + first_name text NULL, + last_name text NULL, + email text NULL, + place_id integer NULL + ); + CREATE TABLE jsperson ( + "FIRST" text NULL, + last_name text NULL, + "EMAIL" text NULL + );`, + drop: ` + drop table person; + drop table jsperson; + drop table place; + drop table placeperson; + `, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type Person struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + } + + p := Person{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + _, err := db.NamedExecContext(ctx, q1, p) + if err != nil { + log.Fatal(err) + } + + p2 := &Person{} + rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } + } + + // these are tests for #73; they verify that named queries work if you've + // changed the db mapper. This code checks both NamedQuery "ad-hoc" style + // queries and NamedStmt queries, which use different code paths internally. + old := *db.Mapper + + type JSONPerson struct { + FirstName sql.NullString `json:"FIRST"` + LastName sql.NullString `json:"last_name"` + Email sql.NullString + } + + jp := JSONPerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "smith", Valid: true}, + Email: sql.NullString{String: "ben@smith.com", Valid: true}, + } + + db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) + + // prepare queries for case sensitivity to test our ToUpper function. + // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line + // strings are `` we use "" by default and swap out for MySQL + pdb := func(s string, db *DB) string { + if db.DriverName() == "mysql" { + return strings.Replace(s, `"`, "`", -1) + } + return s + } + + q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) + if err != nil { + t.Fatal(err, db.DriverName()) + } + + // Checks that a person pulled out of the db matches the one we put in + check := func(t *testing.T, rows *Rows) { + jp = JSONPerson{} + for rows.Next() { + err = rows.StructScan(&jp) + if err != nil { + t.Error(err) + } + if jp.FirstName.String != "ben" { + t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) + } + if jp.LastName.String != "smith" { + t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) + } + if jp.Email.String != "ben@smith.com" { + t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) + } + } + } + + ns, err := db.PrepareNamed(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db)) + + if err != nil { + t.Fatal(err) + } + rows, err = ns.QueryxContext(ctx, jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + // Check exactly the same thing, but with db.NamedQuery, which does not go + // through the PrepareNamed/NamedStmt path. + rows, err = db.NamedQueryContext(ctx, pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db), jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + db.Mapper = &old + + // Test nested structs + type Place struct { + ID int `db:"id"` + Name sql.NullString `db:"name"` + } + type PlacePerson struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + Place Place `db:"place"` + } + + pl := Place{ + Name: sql.NullString{String: "myplace", Valid: true}, + } + + pp := PlacePerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + _, err = db.NamedExecContext(ctx, q2, pl) + if err != nil { + log.Fatal(err) + } + + id := 1 + pp.Place.ID = id + + q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q3, pp) + if err != nil { + log.Fatal(err) + } + + pp2 := &PlacePerson{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + first_name, + last_name, + email, + place.id AS "place.id", + place.name AS "place.name" + FROM placeperson + INNER JOIN place ON place.id = placeperson.place_id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp2) + if err != nil { + t.Error(err) + } + if pp2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + if pp2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + if pp2.Place.Name.String != "myplace" { + t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + if pp2.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + } + } + }) +} + +func TestNilInsertsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE tt ( + id integer, + value text NULL DEFAULT NULL + );`, + drop: "drop table tt;", + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type TT struct { + ID int + Value *string + } + var v, v2 TT + r := db.Rebind + + db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) + db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) + if v.ID != 1 { + t.Errorf("Expecting id of 1, got %v", v.ID) + } + if v.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + + v.ID = 2 + // NOTE: this incidentally uncovered a bug which was that named queries with + // pointer destinations would not work if the passed value here was not addressable, + // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for + // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly + // function. This next line is important as it provides the only coverage for this. + db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + + db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) + if v.ID != v2.ID { + t.Errorf("%v != %v", v.ID, v2.ID) + } + if v2.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + }) +} + +func TestScanErrorContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE kv ( + k text, + v integer + );`, + drop: `drop table kv;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type WrongTypes struct { + K int + V string + } + _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") + if err != nil { + t.Error(err) + } + for rows.Next() { + var wt WrongTypes + err := rows.StructScan(&wt) + if err == nil { + t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) + } + } + }) +} + +// FIXME: this function is kinda big but it slows things down to be constantly +// loading and reloading the schema.. + +func TestUsageContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + slicemembers := []SliceMember{} + err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + people := []Person{} + + err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") + if err != nil { + t.Fatal(err) + } + + jason, john := people[0], people[1] + if jason.FirstName != "Jason" { + t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) + } + if jason.LastName != "Moiron" { + t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) + } + if jason.Email != "jmoiron@jmoiron.net" { + t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) + } + if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { + t.Errorf("John Doe's person record not what expected: Got %v\n", john) + } + + jason = Person{} + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + + if err != nil { + t.Fatal(err) + } + if jason.FirstName != "Jason" { + t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) + } + + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + if err == nil { + t.Errorf("Expecting an error, got nil\n") + } + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows, got %v\n", err) + } + + // The following tests check statement reuse, which was actually a problem + // due to copying being done when creating Stmt's which was eventually removed + stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + + row := stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + row = stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") + if err == nil { + t.Error("Expected an error") + } + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") + if err == nil { + t.Fatal(err) + } + + stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + tx, err := db.Beginx() + if err != nil { + t.Fatal(err) + } + tstmt2 := tx.Stmtx(stmt2) + row2 := tstmt2.QueryRowx("Jason") + err = row2.StructScan(&jason) + if err != nil { + t.Error(err) + } + tx.Commit() + + places := []*Place{} + err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers := places[0], places[1], places[2] + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + placesptr := []PlacePtr{} + err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Error(err) + } + //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) + + // if you have null fields and use SELECT *, you must use sql.Null* in your struct + // this test also verifies that you can use either a []Struct{} or a []*Struct{} + places2 := []Place{} + err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] + + // this should return a type error that &p is not a pointer to a struct slice + p := Place{} + err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") + } + + // this should be an error + pl := []Place{} + err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") + } + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + if err != nil { + t.Error(err) + } + + places = []*Place{} + err = stmt.SelectContext(ctx, &places, 10) + if len(places) != 2 { + t.Error("Expected 2 places, got 0.") + } + if err != nil { + t.Fatal(err) + } + singsing, honkers = places[0], places[1] + if singsing.TelCode != 65 || honkers.TelCode != 852 { + t.Errorf("Expected the right telcodes, got %#v", places) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + place := Place{} + for rows.Next() { + err = rows.StructScan(&place) + if err != nil { + t.Fatal(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + m := map[string]interface{}{} + for rows.Next() { + err = rows.MapScan(m) + if err != nil { + t.Fatal(err) + } + _, ok := m["country"] + if !ok { + t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + s, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + if len(s) != 3 { + t.Errorf("Expected 3 columns in result, got %d\n", len(s)) + } + } + + // test advanced querying + // test that NamedExec works with a map as well as a struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + "first": "Bin", + "last": "Smuth", + "email": "bensmith@allblacks.nz", + }) + if err != nil { + t.Fatal(err) + } + + // ensure that if the named param happens right at the end it still works + // ensure that NamedQuery works with a map[string]interface{} + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + if err != nil { + t.Fatal(err) + } + + ben := &Person{} + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Bin" { + t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) + } + if ben.LastName != "Smuth" { + t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) + } + } + + ben.FirstName = "Ben" + ben.LastName = "Smith" + ben.Email = "binsmuth@allblacks.nz" + + // Insert via a named query using the struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + + if err != nil { + t.Fatal(err) + } + + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + // ensure that Get does not panic on emppty result set + person := &Person{} + err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + if err == nil { + t.Fatal("Should have got an error for Get on non-existant row.") + } + + // lets test prepared statements some more + + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.QueryxContext(ctx, "Ben") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + + john = Person{} + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Error(err) + } + err = stmt.GetContext(ctx, &john, "John") + if err != nil { + t.Error(err) + } + + // test name mapping + // THIS USED TO WORK BUT WILL NO LONGER WORK. + db.MapperFunc(strings.ToUpper) + rsa := CPlace{} + err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + t.Error(err, "in db:", db.DriverName()) + } + db.MapperFunc(strings.ToLower) + + // create a copy and change the mapper, then verify the copy behaves + // differently from the original. + dbCopy := NewDb(db.DB, db.DriverName()) + dbCopy.MapperFunc(strings.ToUpper) + err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + fmt.Println(db.DriverName()) + t.Error(err) + } + + err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") + if err == nil { + t.Error("Expected no error, got ", err) + } + + // test base type slices + var sdest []string + rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") + if err != nil { + t.Error(err) + } + err = scanAll(rows, &sdest, false) + if err != nil { + t.Error(err) + } + + // test Get with base types + var count int + err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if count != len(sdest) { + t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) + } + + // test Get and Select with time.Time, #84 + var addedAt time.Time + err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") + if err != nil { + t.Error(err) + } + + var addedAts []time.Time + err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") + if err != nil { + t.Error(err) + } + + // test it on a double pointer + var pcount *int + err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if *pcount != count { + t.Errorf("expected %d = %d", *pcount, count) + } + + // test Select... + sdest = []string{} + err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + if err != nil { + t.Error(err) + } + expected := []string{"Ben", "Bin", "Jason", "John"} + for i, got := range sdest { + if got != expected[i] { + t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) + } + } + + var nsdest []sql.NullString + err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") + if err != nil { + t.Error(err) + } + for _, val := range nsdest { + if val.Valid && val.String != "New York" { + t.Errorf("expected single valid result to be `New York`, but got %s", val.String) + } + } + }) +} + +// tests that sqlx will not panic when the wrong driver is passed because +// of an automatic nil dereference in sqlx.Open(), which was fixed. +func TestDoNotPanicOnConnectContext(t *testing.T) { + _, err := ConnectContext(context.Background(), "bogus", "hehe") + if err == nil { + t.Errorf("Should return error when using bogus driverName") + } +} + +func TestEmbeddedMapsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE message ( + string text, + properties text + );`, + drop: `drop table message;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + messages := []Message{ + {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, + {"Thanks, Joy", PropertyMap{"pull": "request"}}, + } + q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + for _, m := range messages { + _, err := db.NamedExecContext(ctx, q1, m) + if err != nil { + t.Fatal(err) + } + } + var count int + err := db.GetContext(ctx, &count, "SELECT count(*) FROM message") + if err != nil { + t.Fatal(err) + } + if count != len(messages) { + t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) + } + + var m Message + err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") + if err != nil { + t.Fatal(err) + } + if m.Properties == nil { + t.Fatal("Expected m.Properties to not be nil, but it was.") + } + }) +} + +func TestIssue197Context(t *testing.T) { + // this test actually tests for a bug in database/sql: + // https://github.com/golang/go/issues/13905 + // this potentially makes _any_ named type that is an alias for []byte + // unsafe to use in a lot of different ways (basically, unsafe to hold + // onto after loading from the database). + t.Skip() + + type mybyte []byte + type Var struct{ Raw json.RawMessage } + type Var2 struct{ Raw []byte } + type Var3 struct{ Raw mybyte } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + var err error + var v, q Var + if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v2, q2 Var2 + if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v3, q3 Var3 + if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + t.Fatal(err) + } + if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + t.Fatal(err) + } + t.Fail() + }) +} + +func TestInContext(t *testing.T) { + // some quite normal situations + type tr struct { + q string + args []interface{} + c int + } + tests := []tr{ + {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, + 7}, + {"SELECT * FROM foo WHERE x in (?)", + []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, + 8}, + } + for _, test := range tests { + q, a, err := In(test.q, test.args...) + if err != nil { + t.Error(err) + } + if len(a) != test.c { + t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) + } + if strings.Count(q, "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + } + } + + // too many bindVars, but no slices, so short circuits parsing + // i'm not sure if this is the right behavior; this query/arg combo + // might not work, but we shouldn't parse if we don't need to + { + orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + q, a, err := In(orig, "foo", "bar", "baz") + if err != nil { + t.Error(err) + } + if len(a) != 3 { + t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) + } + if q != orig { + t.Error("Expected unchanged query.") + } + } + + tests = []tr{ + // too many bindvars; slice present so should return error during parse + {"SELECT * FROM foo WHERE x = ? and y = ?", + []interface{}{"foo", []int{1, 2, 3}, "bar"}, + 0}, + // empty slice, should return error before parse + {"SELECT * FROM foo WHERE x = ?", + []interface{}{[]int{}}, + 0}, + // too *few* bindvars, should return an error + {"SELECT * FROM foo WHERE x = ? AND y in (?)", + []interface{}{[]int{1, 2, 3}}, + 0}, + } + for _, test := range tests { + _, _, err := In(test.q, test.args...) + if err == nil { + t.Error("Expected an error, but got nil.") + } + } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + telcodes := []int{852, 65} + q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + query, args, err := In(q, telcodes) + if err != nil { + t.Error(err) + } + query = db.Rebind(query) + places := []Place{} + err = db.SelectContext(ctx, &places, query, args...) + if err != nil { + t.Error(err) + } + if len(places) != 2 { + t.Fatalf("Expecting 2 results, got %d", len(places)) + } + if places[0].TelCode != 65 { + t.Errorf("Expecting singapore first, but got %#v", places[0]) + } + if places[1].TelCode != 852 { + t.Errorf("Expecting hong kong second, but got %#v", places[1]) + } + }) +} + +func TestEmbeddedLiteralsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE x ( + k text + );`, + drop: `drop table x;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type t1 struct { + K *string + } + type t2 struct { + Inline struct { + F string + } + K *string + } + + db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + + target := t1{} + err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target.K != "one" { + t.Error("Expected target.K to be `one`, got ", target.K) + } + + target2 := t2{} + err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target2.K != "one" { + t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) + } + }) +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_test.go new file mode 100644 index 000000000..5752773a0 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx_test.go @@ -0,0 +1,1792 @@ +// The following environment variables, if set, will be used: +// +// * SQLX_SQLITE_DSN +// * SQLX_POSTGRES_DSN +// * SQLX_MYSQL_DSN +// +// Set any of these variables to 'skip' to skip them. Note that for MySQL, +// the string '?parseTime=True' will be appended to the DSN if it's not there +// already. +// +package sqlx + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "log" + "os" + "reflect" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx/reflectx" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +/* compile time checks that Db, Tx, Stmt (qStmt) implement expected interfaces */ +var _, _ Ext = &DB{}, &Tx{} +var _, _ ColScanner = &Row{}, &Rows{} +var _ Queryer = &qStmt{} +var _ Execer = &qStmt{} + +var TestPostgres = true +var TestSqlite = true +var TestMysql = true + +var sldb *DB +var pgdb *DB +var mysqldb *DB +var active = []*DB{} + +func init() { + ConnectAll() +} + +func ConnectAll() { + var err error + + pgdsn := os.Getenv("SQLX_POSTGRES_DSN") + mydsn := os.Getenv("SQLX_MYSQL_DSN") + sqdsn := os.Getenv("SQLX_SQLITE_DSN") + + TestPostgres = pgdsn != "skip" + TestMysql = mydsn != "skip" + TestSqlite = sqdsn != "skip" + + if !strings.Contains(mydsn, "parseTime=true") { + mydsn += "?parseTime=true" + } + + if TestPostgres { + pgdb, err = Connect("postgres", pgdsn) + if err != nil { + fmt.Printf("Disabling PG tests:\n %v\n", err) + TestPostgres = false + } + } else { + fmt.Println("Disabling Postgres tests.") + } + + if TestMysql { + mysqldb, err = Connect("mysql", mydsn) + if err != nil { + fmt.Printf("Disabling MySQL tests:\n %v", err) + TestMysql = false + } + } else { + fmt.Println("Disabling MySQL tests.") + } + + if TestSqlite { + sldb, err = Connect("sqlite3", sqdsn) + if err != nil { + fmt.Printf("Disabling SQLite:\n %v", err) + TestSqlite = false + } + } else { + fmt.Println("Disabling SQLite tests.") + } +} + +type Schema struct { + create string + drop string +} + +func (s Schema) Postgres() (string, string) { + return s.create, s.drop +} + +func (s Schema) MySQL() (string, string) { + return strings.Replace(s.create, `"`, "`", -1), s.drop +} + +func (s Schema) Sqlite3() (string, string) { + return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop +} + +var defaultSchema = Schema{ + create: ` +CREATE TABLE person ( + first_name text, + last_name text, + email text, + added_at timestamp default now() +); + +CREATE TABLE place ( + country text, + city text NULL, + telcode integer +); + +CREATE TABLE capplace ( + "COUNTRY" text, + "CITY" text NULL, + "TELCODE" integer +); + +CREATE TABLE nullperson ( + first_name text NULL, + last_name text NULL, + email text NULL +); + +CREATE TABLE employees ( + name text, + id integer, + boss_id integer +); + +`, + drop: ` +drop table person; +drop table place; +drop table capplace; +drop table nullperson; +drop table employees; +`, +} + +type Person struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string + AddedAt time.Time `db:"added_at"` +} + +type Person2 struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString +} + +type Place struct { + Country string + City sql.NullString + TelCode int +} + +type PlacePtr struct { + Country string + City *string + TelCode int +} + +type PersonPlace struct { + Person + Place +} + +type PersonPlacePtr struct { + *Person + *Place +} + +type EmbedConflict struct { + FirstName string `db:"first_name"` + Person +} + +type SliceMember struct { + Country string + City sql.NullString + TelCode int + People []Person `db:"-"` + Addresses []Place `db:"-"` +} + +// Note that because of field map caching, we need a new type here +// if we've used Place already somewhere in sqlx +type CPlace Place + +func MultiExec(e Execer, query string) { + stmts := strings.Split(query, ";\n") + if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { + stmts = stmts[:len(stmts)-1] + } + for _, s := range stmts { + _, err := e.Exec(s) + if err != nil { + fmt.Println(err, s) + } + } +} + +func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T)) { + runner := func(db *DB, t *testing.T, create, drop string) { + defer func() { + MultiExec(db, drop) + }() + + MultiExec(db, create) + test(db, t) + } + + if TestPostgres { + create, drop := schema.Postgres() + runner(pgdb, t, create, drop) + } + if TestSqlite { + create, drop := schema.Sqlite3() + runner(sldb, t, create, drop) + } + if TestMysql { + create, drop := schema.MySQL() + runner(mysqldb, t, create, drop) + } +} + +func loadDefaultFixture(db *DB, t *testing.T) { + tx := db.MustBegin() + tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + if db.DriverName() == "mysql" { + tx.MustExec(tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + } else { + tx.MustExec(tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + } + tx.MustExec(tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") + tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") + tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.Commit() +} + +// Test a new backwards compatible feature, that missing scan destinations +// will silently scan into sql.RawText rather than failing/panicing +func TestMissingNames(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + type PersonPlus struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string + //AddedAt time.Time `db:"added_at"` + } + + // test Select first + pps := []PersonPlus{} + // pps lacks added_at destination + err := db.Select(&pps, "SELECT * FROM person") + if err == nil { + t.Error("Expected missing name from Select to fail, but it did not.") + } + + // test Get + pp := PersonPlus{} + err = db.Get(&pp, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected missing name Get to fail, but it did not.") + } + + // test naked StructScan + pps = []PersonPlus{} + rows, err := db.Query("SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rows.Next() + err = StructScan(rows, &pps) + if err == nil { + t.Error("Expected missing name in StructScan to fail, but it did not.") + } + rows.Close() + + // now try various things with unsafe set. + db = db.Unsafe() + pps = []PersonPlus{} + err = db.Select(&pps, "SELECT * FROM person") + if err != nil { + t.Error(err) + } + + // test Get + pp = PersonPlus{} + err = db.Get(&pp, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Error(err) + } + + // test naked StructScan + pps = []PersonPlus{} + rowsx, err := db.Queryx("SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rowsx.Next() + err = StructScan(rowsx, &pps) + if err != nil { + t.Error(err) + } + rowsx.Close() + + // test Named stmt + if !isUnsafe(db) { + t.Error("Expected db to be unsafe, but it isn't") + } + nstmt, err := db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // its internal stmt should be marked unsafe + if !nstmt.Stmt.unsafe { + t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") + } + pps = []PersonPlus{} + err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + // test it with a safe db + db.unsafe = false + if isUnsafe(db) { + t.Error("expected db to be safe but it isn't") + } + nstmt, err = db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // it should be safe + if isUnsafe(nstmt) { + t.Error("NamedStmt did not inherit safety") + } + nstmt.Unsafe() + if !isUnsafe(nstmt) { + t.Error("expected newly unsafed NamedStmt to be unsafe") + } + pps = []PersonPlus{} + err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + }) +} + +func TestEmbeddedStructs(t *testing.T) { + type Loop1 struct{ Person } + type Loop2 struct{ Loop1 } + type Loop3 struct{ Loop2 } + + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + peopleAndPlaces := []PersonPlace{} + err := db.Select( + &peopleAndPlaces, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlaces { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test embedded structs with StructScan + rows, err := db.Queryx( + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Error(err) + } + + perp := PersonPlace{} + rows.Next() + err = rows.StructScan(&perp) + if err != nil { + t.Error(err) + } + + if len(perp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(perp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + + rows.Close() + + // test the same for embedded pointer structs + peopleAndPlacesPtrs := []PersonPlacePtr{} + err = db.Select( + &peopleAndPlacesPtrs, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlacesPtrs { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test "deep nesting" + l3s := []Loop3{} + err = db.Select(&l3s, `select * from person`) + if err != nil { + t.Fatal(err) + } + for _, l3 := range l3s { + if len(l3.Loop2.Loop1.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + } + + // test "embed conflicts" + ec := []EmbedConflict{} + err = db.Select(&ec, `select * from person`) + // I'm torn between erroring here or having some kind of working behavior + // in order to allow for more flexibility in destination structs + if err != nil { + t.Errorf("Was not expecting an error on embed conflicts.") + } + }) +} + +func TestJoinQuery(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + + var employees []struct { + Employee + Boss `db:"boss"` + } + + err := db.Select( + &employees, + `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Employee.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Employee.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestJoinQueryNamedPointerStructs(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + + var employees []struct { + Emp1 *Employee `db:"emp1"` + Emp2 *Employee `db:"emp2"` + *Boss `db:"boss"` + } + + err := db.Select( + &employees, + `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", + boss.id "boss.id", boss.name "boss.name" FROM employees AS emp + JOIN employees AS boss ON emp.boss_id = boss.id + `) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestSelectSliceMapTime(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + rows, err := db.Queryx("SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + _, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + } + + rows, err = db.Queryx("SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + m := map[string]interface{}{} + err := rows.MapScan(m) + if err != nil { + t.Error(err) + } + } + + }) +} + +func TestNilReceiver(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + var p *Person + err := db.Get(p, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected error when getting into nil struct ptr.") + } + var pp *[]Person + err = db.Select(pp, "SELECT * FROM person") + if err == nil { + t.Error("Expected an error when selecting into nil slice ptr.") + } + }) +} + +func TestNamedQuery(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE place ( + id integer PRIMARY KEY, + name text NULL + ); + CREATE TABLE person ( + first_name text NULL, + last_name text NULL, + email text NULL + ); + CREATE TABLE placeperson ( + first_name text NULL, + last_name text NULL, + email text NULL, + place_id integer NULL + ); + CREATE TABLE jsperson ( + "FIRST" text NULL, + last_name text NULL, + "EMAIL" text NULL + );`, + drop: ` + drop table person; + drop table jsperson; + drop table place; + drop table placeperson; + `, + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + type Person struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + } + + p := Person{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + _, err := db.NamedExec(q1, p) + if err != nil { + log.Fatal(err) + } + + p2 := &Person{} + rows, err := db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", p) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } + } + + // these are tests for #73; they verify that named queries work if you've + // changed the db mapper. This code checks both NamedQuery "ad-hoc" style + // queries and NamedStmt queries, which use different code paths internally. + old := *db.Mapper + + type JSONPerson struct { + FirstName sql.NullString `json:"FIRST"` + LastName sql.NullString `json:"last_name"` + Email sql.NullString + } + + jp := JSONPerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "smith", Valid: true}, + Email: sql.NullString{String: "ben@smith.com", Valid: true}, + } + + db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) + + // prepare queries for case sensitivity to test our ToUpper function. + // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line + // strings are `` we use "" by default and swap out for MySQL + pdb := func(s string, db *DB) string { + if db.DriverName() == "mysql" { + return strings.Replace(s, `"`, "`", -1) + } + return s + } + + q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + _, err = db.NamedExec(pdb(q1, db), jp) + if err != nil { + t.Fatal(err, db.DriverName()) + } + + // Checks that a person pulled out of the db matches the one we put in + check := func(t *testing.T, rows *Rows) { + jp = JSONPerson{} + for rows.Next() { + err = rows.StructScan(&jp) + if err != nil { + t.Error(err) + } + if jp.FirstName.String != "ben" { + t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) + } + if jp.LastName.String != "smith" { + t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) + } + if jp.Email.String != "ben@smith.com" { + t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) + } + } + } + + ns, err := db.PrepareNamed(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db)) + + if err != nil { + t.Fatal(err) + } + rows, err = ns.Queryx(jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + // Check exactly the same thing, but with db.NamedQuery, which does not go + // through the PrepareNamed/NamedStmt path. + rows, err = db.NamedQuery(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db), jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + db.Mapper = &old + + // Test nested structs + type Place struct { + ID int `db:"id"` + Name sql.NullString `db:"name"` + } + type PlacePerson struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + Place Place `db:"place"` + } + + pl := Place{ + Name: sql.NullString{String: "myplace", Valid: true}, + } + + pp := PlacePerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + _, err = db.NamedExec(q2, pl) + if err != nil { + log.Fatal(err) + } + + id := 1 + pp.Place.ID = id + + q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExec(q3, pp) + if err != nil { + log.Fatal(err) + } + + pp2 := &PlacePerson{} + rows, err = db.NamedQuery(` + SELECT + first_name, + last_name, + email, + place.id AS "place.id", + place.name AS "place.name" + FROM placeperson + INNER JOIN place ON place.id = placeperson.place_id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp2) + if err != nil { + t.Error(err) + } + if pp2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + if pp2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + if pp2.Place.Name.String != "myplace" { + t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + if pp2.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + } + } + }) +} + +func TestNilInserts(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE tt ( + id integer, + value text NULL DEFAULT NULL + );`, + drop: "drop table tt;", + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + type TT struct { + ID int + Value *string + } + var v, v2 TT + r := db.Rebind + + db.MustExec(r(`INSERT INTO tt (id) VALUES (1)`)) + db.Get(&v, r(`SELECT * FROM tt`)) + if v.ID != 1 { + t.Errorf("Expecting id of 1, got %v", v.ID) + } + if v.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + + v.ID = 2 + // NOTE: this incidentally uncovered a bug which was that named queries with + // pointer destinations would not work if the passed value here was not addressable, + // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for + // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly + // function. This next line is important as it provides the only coverage for this. + db.NamedExec(`INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + + db.Get(&v2, r(`SELECT * FROM tt WHERE id=2`)) + if v.ID != v2.ID { + t.Errorf("%v != %v", v.ID, v2.ID) + } + if v2.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + }) +} + +func TestScanError(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE kv ( + k text, + v integer + );`, + drop: `drop table kv;`, + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + type WrongTypes struct { + K int + V string + } + _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + if err != nil { + t.Error(err) + } + + rows, err := db.Queryx("SELECT * FROM kv") + if err != nil { + t.Error(err) + } + for rows.Next() { + var wt WrongTypes + err := rows.StructScan(&wt) + if err == nil { + t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) + } + } + }) +} + +// FIXME: this function is kinda big but it slows things down to be constantly +// loading and reloading the schema.. + +func TestUsage(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + slicemembers := []SliceMember{} + err := db.Select(&slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + people := []Person{} + + err = db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") + if err != nil { + t.Fatal(err) + } + + jason, john := people[0], people[1] + if jason.FirstName != "Jason" { + t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) + } + if jason.LastName != "Moiron" { + t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) + } + if jason.Email != "jmoiron@jmoiron.net" { + t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) + } + if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { + t.Errorf("John Doe's person record not what expected: Got %v\n", john) + } + + jason = Person{} + err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + + if err != nil { + t.Fatal(err) + } + if jason.FirstName != "Jason" { + t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) + } + + err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + if err == nil { + t.Errorf("Expecting an error, got nil\n") + } + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows, got %v\n", err) + } + + // The following tests check statement reuse, which was actually a problem + // due to copying being done when creating Stmt's which was eventually removed + stmt1, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + + row := stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + row = stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + + err = stmt1.Get(&jason, "DoesNotExist User") + if err == nil { + t.Error("Expected an error") + } + err = stmt1.Get(&jason, "DoesNotExist User 2") + if err == nil { + t.Fatal(err) + } + + stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + tx, err := db.Beginx() + if err != nil { + t.Fatal(err) + } + tstmt2 := tx.Stmtx(stmt2) + row2 := tstmt2.QueryRowx("Jason") + err = row2.StructScan(&jason) + if err != nil { + t.Error(err) + } + tx.Commit() + + places := []*Place{} + err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers := places[0], places[1], places[2] + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + placesptr := []PlacePtr{} + err = db.Select(&placesptr, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Error(err) + } + //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) + + // if you have null fields and use SELECT *, you must use sql.Null* in your struct + // this test also verifies that you can use either a []Struct{} or a []*Struct{} + places2 := []Place{} + err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] + + // this should return a type error that &p is not a pointer to a struct slice + p := Place{} + err = db.Select(&p, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") + } + + // this should be an error + pl := []Place{} + err = db.Select(pl, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") + } + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + stmt, err := db.Preparex(db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + if err != nil { + t.Error(err) + } + + places = []*Place{} + err = stmt.Select(&places, 10) + if len(places) != 2 { + t.Error("Expected 2 places, got 0.") + } + if err != nil { + t.Fatal(err) + } + singsing, honkers = places[0], places[1] + if singsing.TelCode != 65 || honkers.TelCode != 852 { + t.Errorf("Expected the right telcodes, got %#v", places) + } + + rows, err := db.Queryx("SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + place := Place{} + for rows.Next() { + err = rows.StructScan(&place) + if err != nil { + t.Fatal(err) + } + } + + rows, err = db.Queryx("SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + m := map[string]interface{}{} + for rows.Next() { + err = rows.MapScan(m) + if err != nil { + t.Fatal(err) + } + _, ok := m["country"] + if !ok { + t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) + } + } + + rows, err = db.Queryx("SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + s, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + if len(s) != 3 { + t.Errorf("Expected 3 columns in result, got %d\n", len(s)) + } + } + + // test advanced querying + // test that NamedExec works with a map as well as a struct + _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + "first": "Bin", + "last": "Smuth", + "email": "bensmith@allblacks.nz", + }) + if err != nil { + t.Fatal(err) + } + + // ensure that if the named param happens right at the end it still works + // ensure that NamedQuery works with a map[string]interface{} + rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + if err != nil { + t.Fatal(err) + } + + ben := &Person{} + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Bin" { + t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) + } + if ben.LastName != "Smuth" { + t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) + } + } + + ben.FirstName = "Ben" + ben.LastName = "Smith" + ben.Email = "binsmuth@allblacks.nz" + + // Insert via a named query using the struct + _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + + if err != nil { + t.Fatal(err) + } + + rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", ben) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + // ensure that Get does not panic on emppty result set + person := &Person{} + err = db.Get(person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + if err == nil { + t.Fatal("Should have got an error for Get on non-existant row.") + } + + // lets test prepared statements some more + + stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Queryx("Ben") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + + john = Person{} + stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Error(err) + } + err = stmt.Get(&john, "John") + if err != nil { + t.Error(err) + } + + // test name mapping + // THIS USED TO WORK BUT WILL NO LONGER WORK. + db.MapperFunc(strings.ToUpper) + rsa := CPlace{} + err = db.Get(&rsa, "SELECT * FROM capplace;") + if err != nil { + t.Error(err, "in db:", db.DriverName()) + } + db.MapperFunc(strings.ToLower) + + // create a copy and change the mapper, then verify the copy behaves + // differently from the original. + dbCopy := NewDb(db.DB, db.DriverName()) + dbCopy.MapperFunc(strings.ToUpper) + err = dbCopy.Get(&rsa, "SELECT * FROM capplace;") + if err != nil { + fmt.Println(db.DriverName()) + t.Error(err) + } + + err = db.Get(&rsa, "SELECT * FROM cappplace;") + if err == nil { + t.Error("Expected no error, got ", err) + } + + // test base type slices + var sdest []string + rows, err = db.Queryx("SELECT email FROM person ORDER BY email ASC;") + if err != nil { + t.Error(err) + } + err = scanAll(rows, &sdest, false) + if err != nil { + t.Error(err) + } + + // test Get with base types + var count int + err = db.Get(&count, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if count != len(sdest) { + t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) + } + + // test Get and Select with time.Time, #84 + var addedAt time.Time + err = db.Get(&addedAt, "SELECT added_at FROM person LIMIT 1;") + if err != nil { + t.Error(err) + } + + var addedAts []time.Time + err = db.Select(&addedAts, "SELECT added_at FROM person;") + if err != nil { + t.Error(err) + } + + // test it on a double pointer + var pcount *int + err = db.Get(&pcount, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if *pcount != count { + t.Errorf("expected %d = %d", *pcount, count) + } + + // test Select... + sdest = []string{} + err = db.Select(&sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + if err != nil { + t.Error(err) + } + expected := []string{"Ben", "Bin", "Jason", "John"} + for i, got := range sdest { + if got != expected[i] { + t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) + } + } + + var nsdest []sql.NullString + err = db.Select(&nsdest, "SELECT city FROM place ORDER BY city ASC") + if err != nil { + t.Error(err) + } + for _, val := range nsdest { + if val.Valid && val.String != "New York" { + t.Errorf("expected single valid result to be `New York`, but got %s", val.String) + } + } + }) +} + +type Product struct { + ProductID int +} + +// tests that sqlx will not panic when the wrong driver is passed because +// of an automatic nil dereference in sqlx.Open(), which was fixed. +func TestDoNotPanicOnConnect(t *testing.T) { + _, err := Connect("bogus", "hehe") + if err == nil { + t.Errorf("Should return error when using bogus driverName") + } +} + +func TestRebind(t *testing.T) { + q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + + s1 := Rebind(DOLLAR, q1) + s2 := Rebind(DOLLAR, q2) + + if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` { + t.Errorf("q1 failed") + } + + if s2 != `INSERT INTO foo (a, b, c) VALUES ($1, $2, "foo"), ("Hi", $3, $4)` { + t.Errorf("q2 failed") + } + + s1 = Rebind(NAMED, q1) + s2 = Rebind(NAMED, q2) + + ex1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ` + + `(:arg1, :arg2, :arg3, :arg4, :arg5, :arg6, :arg7, :arg8, :arg9, :arg10)` + if s1 != ex1 { + t.Error("q1 failed on Named params") + } + + ex2 := `INSERT INTO foo (a, b, c) VALUES (:arg1, :arg2, "foo"), ("Hi", :arg3, :arg4)` + if s2 != ex2 { + t.Error("q2 failed on Named params") + } +} + +func TestBindMap(t *testing.T) { + // Test that it works.. + q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + am := map[string]interface{}{ + "name": "Jason Moiron", + "age": 30, + "first": "Jason", + "last": "Moiron", + } + + bq, args, _ := bindMap(QUESTION, q1, am) + expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` + if bq != expect { + t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) + } + + if args[0].(string) != "Jason Moiron" { + t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) + } + + if args[1].(int) != 30 { + t.Errorf("Expected 30, got %v\n", args[1]) + } + + if args[2].(string) != "Jason" { + t.Errorf("Expected Jason, got %v\n", args[2]) + } + + if args[3].(string) != "Moiron" { + t.Errorf("Expected Moiron, got %v\n", args[3]) + } +} + +// Test for #117, embedded nil maps + +type Message struct { + Text string `db:"string"` + Properties PropertyMap `db:"properties"` // Stored as JSON in the database +} + +type PropertyMap map[string]string + +// Implement driver.Valuer and sql.Scanner interfaces on PropertyMap +func (p PropertyMap) Value() (driver.Value, error) { + if len(p) == 0 { + return nil, nil + } + return json.Marshal(p) +} + +func (p PropertyMap) Scan(src interface{}) error { + v := reflect.ValueOf(src) + if !v.IsValid() || v.IsNil() { + return nil + } + if data, ok := src.([]byte); ok { + return json.Unmarshal(data, &p) + } + return fmt.Errorf("Could not not decode type %T -> %T", src, p) +} + +func TestEmbeddedMaps(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE message ( + string text, + properties text + );`, + drop: `drop table message;`, + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + messages := []Message{ + {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, + {"Thanks, Joy", PropertyMap{"pull": "request"}}, + } + q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + for _, m := range messages { + _, err := db.NamedExec(q1, m) + if err != nil { + t.Fatal(err) + } + } + var count int + err := db.Get(&count, "SELECT count(*) FROM message") + if err != nil { + t.Fatal(err) + } + if count != len(messages) { + t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) + } + + var m Message + err = db.Get(&m, "SELECT * FROM message LIMIT 1;") + if err != nil { + t.Fatal(err) + } + if m.Properties == nil { + t.Fatal("Expected m.Properties to not be nil, but it was.") + } + }) +} + +func TestIssue197(t *testing.T) { + // this test actually tests for a bug in database/sql: + // https://github.com/golang/go/issues/13905 + // this potentially makes _any_ named type that is an alias for []byte + // unsafe to use in a lot of different ways (basically, unsafe to hold + // onto after loading from the database). + t.Skip() + + type mybyte []byte + type Var struct{ Raw json.RawMessage } + type Var2 struct{ Raw []byte } + type Var3 struct{ Raw mybyte } + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + var err error + var v, q Var + if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v2, q2 Var2 + if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v3, q3 Var3 + if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + t.Fatal(err) + } + if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + t.Fatal(err) + } + t.Fail() + }) +} + +func TestIn(t *testing.T) { + // some quite normal situations + type tr struct { + q string + args []interface{} + c int + } + tests := []tr{ + {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, + 7}, + {"SELECT * FROM foo WHERE x in (?)", + []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, + 8}, + } + for _, test := range tests { + q, a, err := In(test.q, test.args...) + if err != nil { + t.Error(err) + } + if len(a) != test.c { + t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) + } + if strings.Count(q, "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + } + } + + // too many bindVars, but no slices, so short circuits parsing + // i'm not sure if this is the right behavior; this query/arg combo + // might not work, but we shouldn't parse if we don't need to + { + orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + q, a, err := In(orig, "foo", "bar", "baz") + if err != nil { + t.Error(err) + } + if len(a) != 3 { + t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) + } + if q != orig { + t.Error("Expected unchanged query.") + } + } + + tests = []tr{ + // too many bindvars; slice present so should return error during parse + {"SELECT * FROM foo WHERE x = ? and y = ?", + []interface{}{"foo", []int{1, 2, 3}, "bar"}, + 0}, + // empty slice, should return error before parse + {"SELECT * FROM foo WHERE x = ?", + []interface{}{[]int{}}, + 0}, + // too *few* bindvars, should return an error + {"SELECT * FROM foo WHERE x = ? AND y in (?)", + []interface{}{[]int{1, 2, 3}}, + 0}, + } + for _, test := range tests { + _, _, err := In(test.q, test.args...) + if err == nil { + t.Error("Expected an error, but got nil.") + } + } + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + //tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + telcodes := []int{852, 65} + q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + query, args, err := In(q, telcodes) + if err != nil { + t.Error(err) + } + query = db.Rebind(query) + places := []Place{} + err = db.Select(&places, query, args...) + if err != nil { + t.Error(err) + } + if len(places) != 2 { + t.Fatalf("Expecting 2 results, got %d", len(places)) + } + if places[0].TelCode != 65 { + t.Errorf("Expecting singapore first, but got %#v", places[0]) + } + if places[1].TelCode != 852 { + t.Errorf("Expecting hong kong second, but got %#v", places[1]) + } + }) +} + +func TestBindStruct(t *testing.T) { + var err error + + q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + + type tt struct { + Name string + Age int + First string + Last string + } + + type tt2 struct { + Field1 string `db:"field_1"` + Field2 string `db:"field_2"` + } + + type tt3 struct { + tt2 + Name string + } + + am := tt{"Jason Moiron", 30, "Jason", "Moiron"} + + bq, args, _ := bindStruct(QUESTION, q1, am, mapper()) + expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` + if bq != expect { + t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) + } + + if args[0].(string) != "Jason Moiron" { + t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) + } + + if args[1].(int) != 30 { + t.Errorf("Expected 30, got %v\n", args[1]) + } + + if args[2].(string) != "Jason" { + t.Errorf("Expected Jason, got %v\n", args[2]) + } + + if args[3].(string) != "Moiron" { + t.Errorf("Expected Moiron, got %v\n", args[3]) + } + + am2 := tt2{"Hello", "World"} + bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2, mapper()) + expect = `INSERT INTO foo (a, b) VALUES (?, ?)` + if bq != expect { + t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) + } + + if args[0].(string) != "World" { + t.Errorf("Expected 'World', got %s\n", args[0].(string)) + } + if args[1].(string) != "Hello" { + t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) + } + + am3 := tt3{Name: "Hello!"} + am3.Field1 = "Hello" + am3.Field2 = "World" + + bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)", am3, mapper()) + + if err != nil { + t.Fatal(err) + } + + expect = `INSERT INTO foo (a, b, c) VALUES (?, ?, ?)` + if bq != expect { + t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) + } + + if args[0].(string) != "Hello!" { + t.Errorf("Expected 'Hello!', got %s\n", args[0].(string)) + } + if args[1].(string) != "Hello" { + t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) + } + if args[2].(string) != "World" { + t.Errorf("Expected 'World', got %s\n", args[0].(string)) + } +} + +func TestEmbeddedLiterals(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE x ( + k text + );`, + drop: `drop table x;`, + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + type t1 struct { + K *string + } + type t2 struct { + Inline struct { + F string + } + K *string + } + + db.MustExec(db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + + target := t1{} + err := db.Get(&target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target.K != "one" { + t.Error("Expected target.K to be `one`, got ", target.K) + } + + target2 := t2{} + err = db.Get(&target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target2.K != "one" { + t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) + } + }) +} + +func BenchmarkBindStruct(b *testing.B) { + b.StopTimer() + q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + type t struct { + Name string + Age int + First string + Last string + } + am := t{"Jason Moiron", 30, "Jason", "Moiron"} + b.StartTimer() + for i := 0; i < b.N; i++ { + bindStruct(DOLLAR, q1, am, mapper()) + } +} + +func BenchmarkBindMap(b *testing.B) { + b.StopTimer() + q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + am := map[string]interface{}{ + "name": "Jason Moiron", + "age": 30, + "first": "Jason", + "last": "Moiron", + } + b.StartTimer() + for i := 0; i < b.N; i++ { + bindMap(DOLLAR, q1, am) + } +} + +func BenchmarkIn(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...) + } +} + +func BenchmarkIn1k(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]interface{} + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + +func BenchmarkIn1kInt(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]int + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + +func BenchmarkIn1kString(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]string + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + +func BenchmarkRebind(b *testing.B) { + b.StopTimer() + q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` + q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + b.StartTimer() + + for i := 0; i < b.N; i++ { + Rebind(DOLLAR, q1) + Rebind(DOLLAR, q2) + } +} + +func BenchmarkRebindBuffer(b *testing.B) { + b.StopTimer() + q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` + q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + b.StartTimer() + + for i := 0; i < b.N; i++ { + rebindBuff(DOLLAR, q1) + rebindBuff(DOLLAR, q2) + } +} diff --git a/vendor/github.com/lib/pq/hstore/hstore.go b/vendor/github.com/lib/pq/hstore/hstore.go deleted file mode 100644 index 72d5abf51..000000000 --- a/vendor/github.com/lib/pq/hstore/hstore.go +++ /dev/null @@ -1,118 +0,0 @@ -package hstore - -import ( - "database/sql" - "database/sql/driver" - "strings" -) - -// A wrapper for transferring Hstore values back and forth easily. -type Hstore struct { - Map map[string]sql.NullString -} - -// escapes and quotes hstore keys/values -// s should be a sql.NullString or string -func hQuote(s interface{}) string { - var str string - switch v := s.(type) { - case sql.NullString: - if !v.Valid { - return "NULL" - } - str = v.String - case string: - str = v - default: - panic("not a string or sql.NullString") - } - - str = strings.Replace(str, "\\", "\\\\", -1) - return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` -} - -// Scan implements the Scanner interface. -// -// Note h.Map is reallocated before the scan to clear existing values. If the -// hstore column's database value is NULL, then h.Map is set to nil instead. -func (h *Hstore) Scan(value interface{}) error { - if value == nil { - h.Map = nil - return nil - } - h.Map = make(map[string]sql.NullString) - var b byte - pair := [][]byte{{}, {}} - pi := 0 - inQuote := false - didQuote := false - sawSlash := false - bindex := 0 - for bindex, b = range value.([]byte) { - if sawSlash { - pair[pi] = append(pair[pi], b) - sawSlash = false - continue - } - - switch b { - case '\\': - sawSlash = true - continue - case '"': - inQuote = !inQuote - if !didQuote { - didQuote = true - } - continue - default: - if !inQuote { - switch b { - case ' ', '\t', '\n', '\r': - continue - case '=': - continue - case '>': - pi = 1 - didQuote = false - continue - case ',': - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - pair[0] = []byte{} - pair[1] = []byte{} - pi = 0 - continue - } - } - } - pair[pi] = append(pair[pi], b) - } - if bindex > 0 { - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - } - return nil -} - -// Value implements the driver Valuer interface. Note if h.Map is nil, the -// database column value will be set to NULL. -func (h Hstore) Value() (driver.Value, error) { - if h.Map == nil { - return nil, nil - } - parts := []string{} - for key, val := range h.Map { - thispart := hQuote(key) + "=>" + hQuote(val) - parts = append(parts, thispart) - } - return []byte(strings.Join(parts, ",")), nil -} diff --git a/vendor/github.com/lib/pq/hstore/hstore_test.go b/vendor/github.com/lib/pq/hstore/hstore_test.go deleted file mode 100644 index 1c9f2bd49..000000000 --- a/vendor/github.com/lib/pq/hstore/hstore_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package hstore - -import ( - "database/sql" - "os" - "testing" - - _ "github.com/lib/pq" -) - -type Fatalistic interface { - Fatal(args ...interface{}) -} - -func openTestConn(t Fatalistic) *sql.DB { - datname := os.Getenv("PGDATABASE") - sslmode := os.Getenv("PGSSLMODE") - - if datname == "" { - os.Setenv("PGDATABASE", "pqgotest") - } - - if sslmode == "" { - os.Setenv("PGSSLMODE", "disable") - } - - conn, err := sql.Open("postgres", "") - if err != nil { - t.Fatal(err) - } - - return conn -} - -func TestHstore(t *testing.T) { - db := openTestConn(t) - defer db.Close() - - // quitely create hstore if it doesn't exist - _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") - if err != nil { - t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) - } - - hs := Hstore{} - - // test for null-valued hstores - err = db.QueryRow("SELECT NULL::hstore").Scan(&hs) - if err != nil { - t.Fatal(err) - } - if hs.Map != nil { - t.Fatalf("expected null map") - } - - err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) - if err != nil { - t.Fatalf("re-query null map failed: %s", err.Error()) - } - if hs.Map != nil { - t.Fatalf("expected null map") - } - - // test for empty hstores - err = db.QueryRow("SELECT ''::hstore").Scan(&hs) - if err != nil { - t.Fatal(err) - } - if hs.Map == nil { - t.Fatalf("expected empty map, got null map") - } - if len(hs.Map) != 0 { - t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) - } - - err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) - if err != nil { - t.Fatalf("re-query empty map failed: %s", err.Error()) - } - if hs.Map == nil { - t.Fatalf("expected empty map, got null map") - } - if len(hs.Map) != 0 { - t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) - } - - // a few example maps to test out - hsOnePair := Hstore{ - Map: map[string]sql.NullString{ - "key1": {String: "value1", Valid: true}, - }, - } - - hsThreePairs := Hstore{ - Map: map[string]sql.NullString{ - "key1": {String: "value1", Valid: true}, - "key2": {String: "value2", Valid: true}, - "key3": {String: "value3", Valid: true}, - }, - } - - hsSmorgasbord := Hstore{ - Map: map[string]sql.NullString{ - "nullstring": {String: "NULL", Valid: true}, - "actuallynull": {String: "", Valid: false}, - "NULL": {String: "NULL string key", Valid: true}, - "withbracket": {String: "value>42", Valid: true}, - "withequal": {String: "value=42", Valid: true}, - `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, - `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, - "embedded1": {String: "value1=>x1", Valid: true}, - "embedded2": {String: `"value2"=>x2`, Valid: true}, - "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, - "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, - }, - } - - // test encoding in query params, then decoding during Scan - testBidirectional := func(h Hstore) { - err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs) - if err != nil { - t.Fatalf("re-query %d-pair map failed: %s", len(h.Map), err.Error()) - } - if hs.Map == nil { - t.Fatalf("expected %d-pair map, got null map", len(h.Map)) - } - if len(hs.Map) != len(h.Map) { - t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map)) - } - - for key, val := range hs.Map { - otherval, found := h.Map[key] - if !found { - t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map)) - } - if otherval.Valid != val.Valid { - t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map)) - } - if otherval.String != val.String { - t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map)) - } - } - } - - testBidirectional(hsOnePair) - testBidirectional(hsThreePairs) - testBidirectional(hsSmorgasbord) -}