From 1ad6d6ef8913014c7369b4587425a56e5017850b Mon Sep 17 00:00:00 2001
From: Vlad Ungureanu
Date: Sun, 1 Oct 2017 21:01:26 -0700
Subject: [PATCH] Use gorm for DB interactions (#18)

* Swap to using golang/dep as dependency manager
* Run dep prune
* Vendor github.com/jinzhu/gorm
* Swap raw sqlx logic for gorm
---
 Gopkg.lock | 17 +-
 Gopkg.toml | 35 +-
 persist/repo.go | 61 -
 persist/schema.go | 114 +-
 persist/user.go | 78 -
 server/endpoints/hook.go | 22 +-
 server/endpoints/repositories.go | 58 +-
 server/endpoints/token.go | 31 +-
 server/init.go | 12 +-
 server/server.go | 6 +-
 .../github.com/jinzhu/gorm/.codeclimate.yml | 11 +
 vendor/github.com/jinzhu/gorm/.gitignore | 2 +
 vendor/github.com/jinzhu/gorm/CONTRIBUTING.md | 52 +
 vendor/github.com/jinzhu/gorm/License | 21 +
 vendor/github.com/jinzhu/gorm/README.md | 46 +
 vendor/github.com/jinzhu/gorm/association.go | 359 ++++
 .../jinzhu/gorm/association_test.go | 842 ++++++++
 vendor/github.com/jinzhu/gorm/callback.go | 237 +++
 .../github.com/jinzhu/gorm/callback_create.go | 144 ++
 .../github.com/jinzhu/gorm/callback_delete.go | 53 +
 .../github.com/jinzhu/gorm/callback_query.go | 93 +
 .../jinzhu/gorm/callback_query_preload.go | 310 +++
 .../github.com/jinzhu/gorm/callback_save.go | 92 +
 .../jinzhu/gorm/callback_system_test.go | 112 ++
 .../github.com/jinzhu/gorm/callback_update.go | 104 +
 .../github.com/jinzhu/gorm/callbacks_test.go | 177 ++
 vendor/github.com/jinzhu/gorm/create_test.go | 164 ++
 .../jinzhu/gorm/customize_column_test.go | 280 +++
 vendor/github.com/jinzhu/gorm/delete_test.go | 68 +
 vendor/github.com/jinzhu/gorm/dialect.go | 100 +
 .../github.com/jinzhu/gorm/dialect_common.go | 137 ++
 .../github.com/jinzhu/gorm/dialect_mysql.go | 113 ++
 .../jinzhu/gorm/dialect_postgres.go | 132 ++
 .../github.com/jinzhu/gorm/dialect_sqlite3.go | 106 +
 .../jinzhu/gorm/dialects/postgres/postgres.go | 54 +
 .../jinzhu/gorm/embedded_struct_test.go | 48 +
 vendor/github.com/jinzhu/gorm/errors.go | 58 +
 vendor/github.com/jinzhu/gorm/field.go | 58 +
 vendor/github.com/jinzhu/gorm/field_test.go | 49 +
 vendor/github.com/jinzhu/gorm/interface.go | 19 +
 .../jinzhu/gorm/join_table_handler.go | 204 ++
 .../github.com/jinzhu/gorm/join_table_test.go | 72 +
 vendor/github.com/jinzhu/gorm/logger.go | 99 +
 vendor/github.com/jinzhu/gorm/main.go | 700 +++++++
 vendor/github.com/jinzhu/gorm/main_test.go | 774 +++++++
 .../github.com/jinzhu/gorm/migration_test.go | 349 ++++
 vendor/github.com/jinzhu/gorm/model.go | 14 +
 vendor/github.com/jinzhu/gorm/model_struct.go | 542 +++++
 .../jinzhu/gorm/multi_primary_keys_test.go | 381 ++++
 vendor/github.com/jinzhu/gorm/pointer_test.go | 84 +
 .../jinzhu/gorm/polymorphic_test.go | 219 ++
 vendor/github.com/jinzhu/gorm/preload_test.go | 1327 ++++++++++++
 vendor/github.com/jinzhu/gorm/query_test.go | 636 ++++++
 vendor/github.com/jinzhu/gorm/scaner_test.go | 70 +
 vendor/github.com/jinzhu/gorm/scope.go | 1246 ++++++++++++
 vendor/github.com/jinzhu/gorm/scope_test.go | 43 +
 vendor/github.com/jinzhu/gorm/search.go | 149 ++
 vendor/github.com/jinzhu/gorm/search_test.go | 30 +
 vendor/github.com/jinzhu/gorm/test_all.sh | 5 +
 vendor/github.com/jinzhu/gorm/update_test.go | 435 ++++
 vendor/github.com/jinzhu/gorm/utils.go | 264 +++
 vendor/github.com/jinzhu/gorm/utils_test.go | 30 +
 vendor/github.com/jinzhu/inflection/LICENSE | 21 +
 vendor/github.com/jinzhu/inflection/README.md | 55 +
 .../jinzhu/inflection/inflections.go | 273 +++
 .../jinzhu/inflection/inflections_test.go | 213 ++
 vendor/github.com/jmoiron/sqlx/.gitignore | 24 -
 vendor/github.com/jmoiron/sqlx/LICENSE
| 23 - vendor/github.com/jmoiron/sqlx/README.md | 183 -- vendor/github.com/jmoiron/sqlx/bind.go | 207 -- vendor/github.com/jmoiron/sqlx/doc.go | 12 - vendor/github.com/jmoiron/sqlx/named.go | 344 ---- .../github.com/jmoiron/sqlx/named_context.go | 132 -- .../jmoiron/sqlx/named_context_test.go | 136 -- vendor/github.com/jmoiron/sqlx/named_test.go | 227 --- .../jmoiron/sqlx/reflectx/README.md | 17 - .../jmoiron/sqlx/reflectx/reflect.go | 422 ---- .../jmoiron/sqlx/reflectx/reflect_test.go | 905 --------- vendor/github.com/jmoiron/sqlx/sqlx.go | 1035 ---------- .../github.com/jmoiron/sqlx/sqlx_context.go | 335 --- .../jmoiron/sqlx/sqlx_context_test.go | 1344 ------------- vendor/github.com/jmoiron/sqlx/sqlx_test.go | 1792 ----------------- vendor/github.com/lib/pq/hstore/hstore.go | 118 ++ .../github.com/lib/pq/hstore/hstore_test.go | 148 ++ 84 files changed, 12643 insertions(+), 7471 deletions(-) delete mode 100644 persist/repo.go delete mode 100644 persist/user.go create mode 100644 vendor/github.com/jinzhu/gorm/.codeclimate.yml create mode 100644 vendor/github.com/jinzhu/gorm/.gitignore create mode 100644 vendor/github.com/jinzhu/gorm/CONTRIBUTING.md create mode 100644 vendor/github.com/jinzhu/gorm/License create mode 100644 vendor/github.com/jinzhu/gorm/README.md create mode 100644 vendor/github.com/jinzhu/gorm/association.go create mode 100644 vendor/github.com/jinzhu/gorm/association_test.go create mode 100644 vendor/github.com/jinzhu/gorm/callback.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_create.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_delete.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_query.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_query_preload.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_save.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_system_test.go create mode 100644 vendor/github.com/jinzhu/gorm/callback_update.go create mode 100644 vendor/github.com/jinzhu/gorm/callbacks_test.go create mode 100644 vendor/github.com/jinzhu/gorm/create_test.go create mode 100644 vendor/github.com/jinzhu/gorm/customize_column_test.go create mode 100644 vendor/github.com/jinzhu/gorm/delete_test.go create mode 100644 vendor/github.com/jinzhu/gorm/dialect.go create mode 100644 vendor/github.com/jinzhu/gorm/dialect_common.go create mode 100644 vendor/github.com/jinzhu/gorm/dialect_mysql.go create mode 100644 vendor/github.com/jinzhu/gorm/dialect_postgres.go create mode 100644 vendor/github.com/jinzhu/gorm/dialect_sqlite3.go create mode 100644 vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go create mode 100644 vendor/github.com/jinzhu/gorm/embedded_struct_test.go create mode 100644 vendor/github.com/jinzhu/gorm/errors.go create mode 100644 vendor/github.com/jinzhu/gorm/field.go create mode 100644 vendor/github.com/jinzhu/gorm/field_test.go create mode 100644 vendor/github.com/jinzhu/gorm/interface.go create mode 100644 vendor/github.com/jinzhu/gorm/join_table_handler.go create mode 100644 vendor/github.com/jinzhu/gorm/join_table_test.go create mode 100644 vendor/github.com/jinzhu/gorm/logger.go create mode 100644 vendor/github.com/jinzhu/gorm/main.go create mode 100644 vendor/github.com/jinzhu/gorm/main_test.go create mode 100644 vendor/github.com/jinzhu/gorm/migration_test.go create mode 100644 vendor/github.com/jinzhu/gorm/model.go create mode 100644 vendor/github.com/jinzhu/gorm/model_struct.go create mode 100644 vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go create mode 
100644 vendor/github.com/jinzhu/gorm/pointer_test.go create mode 100644 vendor/github.com/jinzhu/gorm/polymorphic_test.go create mode 100644 vendor/github.com/jinzhu/gorm/preload_test.go create mode 100644 vendor/github.com/jinzhu/gorm/query_test.go create mode 100644 vendor/github.com/jinzhu/gorm/scaner_test.go create mode 100644 vendor/github.com/jinzhu/gorm/scope.go create mode 100644 vendor/github.com/jinzhu/gorm/scope_test.go create mode 100644 vendor/github.com/jinzhu/gorm/search.go create mode 100644 vendor/github.com/jinzhu/gorm/search_test.go create mode 100755 vendor/github.com/jinzhu/gorm/test_all.sh create mode 100644 vendor/github.com/jinzhu/gorm/update_test.go create mode 100644 vendor/github.com/jinzhu/gorm/utils.go create mode 100644 vendor/github.com/jinzhu/gorm/utils_test.go create mode 100644 vendor/github.com/jinzhu/inflection/LICENSE create mode 100644 vendor/github.com/jinzhu/inflection/README.md create mode 100644 vendor/github.com/jinzhu/inflection/inflections.go create mode 100644 vendor/github.com/jinzhu/inflection/inflections_test.go delete mode 100644 vendor/github.com/jmoiron/sqlx/.gitignore delete mode 100644 vendor/github.com/jmoiron/sqlx/LICENSE delete mode 100644 vendor/github.com/jmoiron/sqlx/README.md delete mode 100644 vendor/github.com/jmoiron/sqlx/bind.go delete mode 100644 vendor/github.com/jmoiron/sqlx/doc.go delete mode 100644 vendor/github.com/jmoiron/sqlx/named.go delete mode 100644 vendor/github.com/jmoiron/sqlx/named_context.go delete mode 100644 vendor/github.com/jmoiron/sqlx/named_context_test.go delete mode 100644 vendor/github.com/jmoiron/sqlx/named_test.go delete mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/README.md delete mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/reflect.go delete mode 100644 vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go delete mode 100644 vendor/github.com/jmoiron/sqlx/sqlx.go delete mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_context.go delete mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_context_test.go delete mode 100644 vendor/github.com/jmoiron/sqlx/sqlx_test.go create mode 100644 vendor/github.com/lib/pq/hstore/hstore.go create mode 100644 vendor/github.com/lib/pq/hstore/hstore_test.go diff --git a/Gopkg.lock b/Gopkg.lock index 5c77bdb5e..aed2cea81 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -103,9 +103,16 @@ version = "v3.1.1" [[projects]] - name = "github.com/jmoiron/sqlx" - packages = [".","reflectx"] - revision = "d9bd385d68c068f1fabb5057e3dedcbcbb039d0f" + name = "github.com/jinzhu/gorm" + packages = [".","dialects/postgres"] + revision = "5174cc5c242a728b435ea2be8a2f7f998e15429b" + version = "v1.0" + +[[projects]] + branch = "master" + name = "github.com/jinzhu/inflection" + packages = ["."] + revision = "1c35d901db3da928c72a72d8458480cc9ade058f" [[projects]] name = "github.com/labstack/echo" @@ -121,7 +128,7 @@ [[projects]] name = "github.com/lib/pq" - packages = [".","oid"] + packages = [".","hstore","oid"] revision = "b77235e3890a962fe8a6f8c4c7198679ca7814e7" [[projects]] @@ -280,6 +287,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "40bc7063ffd9b03f5f059775cab1296e4b6efc4fcd2e316eedf49e91dfb7e6c9" + inputs-digest = "3f288cd94a1d8e2d6de01bba4ea5f70692f19f52d3ec13677a5dc8496f385e20" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 15d93117b..df7934aef 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,26 +1,3 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 
-# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - [[constraint]] name = "github.com/google/go-github" revision = "511f540f1887d30b88cee4a2fcd1f2922754acf4" @@ -29,18 +6,10 @@ name = "github.com/ipfans/echo-session" version = "3.1.1" -[[constraint]] - name = "github.com/jmoiron/sqlx" - revision = "d9bd385d68c068f1fabb5057e3dedcbcbb039d0f" - [[constraint]] name = "github.com/labstack/echo" version = "3.2.3" -[[constraint]] - name = "github.com/lib/pq" - revision = "b77235e3890a962fe8a6f8c4c7198679ca7814e7" - [[constraint]] name = "github.com/pkg/errors" version = "0.8.0" @@ -72,3 +41,7 @@ [[constraint]] name = "gopkg.in/yaml.v2" revision = "eb3733d160e74a9c7e442f435eb3bea458e1d19f" + +[[constraint]] + name = "github.com/jinzhu/gorm" + version = "1.0.0" diff --git a/persist/repo.go b/persist/repo.go deleted file mode 100644 index 95c72ccc7..000000000 --- a/persist/repo.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2017 Palantir Technologies, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package persist - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -type Repository struct { - ID int `db:"id"` - Name string `db:"name"` - EnabledBy string `db:"enabled_by"` - EnabledAt int64 `db:"enabled_at"` - HookID int `db:"hook_id"` -} - -func (*Repository) InsertStmt() string { - return "INSERT INTO REPOS (id, name, enabled_by, enabled_at, hook_id) VALUES (:id, :name, :enabled_by, :enabled_at, :hook_id)" -} - -func (*Repository) DeleteStmt() string { - return "DELETE FROM REPOS WHERE id = :id" -} - -func (*Repository) UpdateStmt() string { - return "TODO" -} - -func GetRepositoryByID(db *sqlx.DB, repoID int) (*Repository, error) { - r := &Repository{} - - q := fmt.Sprintf("SELECT * FROM REPOS WHERE id=%d", repoID) - row := db.QueryRowx(q) - err := row.StructScan(r) - - if err != nil { - if strings.Contains(err.Error(), "no rows in result set") { - return nil, nil - } - - return nil, errors.Wrapf(err, "cannot get repo %d", repoID) - } - - return r, nil -} diff --git a/persist/schema.go b/persist/schema.go index 5ebd484b7..280435ece 100644 --- a/persist/schema.go +++ b/persist/schema.go @@ -15,108 +15,28 @@ package persist import ( - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -var metaSchema = ` -CREATE TABLE IF NOT EXISTS schema ( - version INTEGER PRIMARY KEY CHECK (version > 0) -); - -CREATE UNIQUE INDEX IF NOT EXISTS schema_one_row ON schema((TRUE));` - -var currSchemaVersion = 1 - -var schema = ` -CREATE TABLE IF NOT EXISTS USERS ( - github_id INTEGER PRIMARY KEY UNIQUE, - name TEXT, - token TEXT -); - -CREATE TABLE IF NOT EXISTS REPOS ( - id INTEGER PRIMARY KEY UNIQUE, - name TEXT, - enabled_by TEXT, - enabled_at BIGINT, - hook_id INTEGER -); -` - -// Persistable are structs that can be persisted to a DB and -// are compatible with associated utility methods in this package -type Persistable interface { - InsertStmt() string - DeleteStmt() string - UpdateStmt() string -} + "time" -// Put persists a Persistable to the given DB -func Put(db *sqlx.DB, p Persistable) error { - _, err := db.NamedExec(p.InsertStmt(), p) - if err != nil { - return errors.Wrapf(err, "failed persisting %v", p) - } - return nil -} + "github.com/jinzhu/gorm" +) -// Delete deletes a Persistable from the given DB -func Delete(db *sqlx.DB, p Persistable) error { - _, err := db.NamedExec(p.DeleteStmt(), p) - if err != nil { - return errors.Wrapf(err, "failed deleting %v", p) - } - return nil +type Repository struct { + gorm.Model + GitHubID int `gorm:"column:github_id"` + Name string + EnabledBy User + EnabledAt time.Time + HookID int } -func Update(db *sqlx.DB, p Persistable) error { - _, err := db.NamedExec(p.UpdateStmt(), p) - if err != nil { - return errors.Wrapf(err, "failed updating %v", p) - } - return nil +type User struct { + gorm.Model + GitHubID int `gorm:"column:github_id"` + Name string + Token string } // InitializeSchema initializes the schema for storing artifact data -func InitializeSchema(db *sqlx.DB) error { - version, err := getSchemaVersion(db) - if err != nil { - return errors.Wrapf(err, "failed to determine database schema version") - } - if version != currSchemaVersion { - err := migrateSchema(db, currSchemaVersion) - if err != nil { - return errors.Wrapf(err, "failed migrating database schema from version %d to %d", version, currSchemaVersion) - } - } - _, err = db.Exec(schema) - if err != nil { - return errors.Wrapf(err, "failed initializing database schema") - } - return nil -} - -func getSchemaVersion(db *sqlx.DB) (int, 
error) { - _, err := db.Exec(metaSchema) - if err != nil { - return 0, errors.Wrapf(err, "failed initializing schema table") - } - var version []int - err = db.Select(&version, "SELECT version FROM schema") - if err != nil { - return 0, errors.Wrapf(err, "failed querying schema table") - } - if len(version) == 0 { - _, err := db.Exec("INSERT INTO schema (version) VALUES ($1)", currSchemaVersion) - if err != nil { - return 0, errors.Wrapf(err, "failed setting schema version in database") - } - version = append(version, currSchemaVersion) - } - return version[0], nil -} - -func migrateSchema(db *sqlx.DB, schemaVersion int) error { - return errors.New("SCHEMA MIGRATION NOT IMPLEMENTED AT THIS TIME :(") +func InitializeSchema(db *gorm.DB) { + db.AutoMigrate(&Repository{}, &User{}) } diff --git a/persist/user.go b/persist/user.go deleted file mode 100644 index bf6caeeb3..000000000 --- a/persist/user.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2017 Palantir Technologies, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package persist - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -type User struct { - GithubID int `db:"github_id"` - Name string `db:"name"` - Token string `db:"token"` -} - -func (*User) InsertStmt() string { - return "INSERT INTO USERS (github_id, name, token) VALUES (:github_id, :name, :token)" -} - -func (*User) DeleteStmt() string { - return "DELETE FROM USERS WHERE github_id = :github_id" -} - -func (*User) UpdateStmt() string { - return "UPDATE USERS SET token = :token WHERE github_id = :github_id" -} - -func GetUserByName(db *sqlx.DB, name string) (*User, error) { - u := &User{} - - q := fmt.Sprintf("SELECT * FROM USERS WHERE name='%s'", name) - row := db.QueryRowx(q) - err := row.StructScan(u) - - if err != nil { - if strings.Contains(err.Error(), "no rows in result set") { - return nil, nil - } - - return nil, errors.Wrapf(err, "cannot get user %s", name) - } - - return u, nil -} - -func GetUserByID(db *sqlx.DB, id int) (*User, error) { - u := &User{} - - q := fmt.Sprintf("SELECT * FROM USERS WHERE github_id='%d'", id) - row := db.QueryRowx(q) - err := row.StructScan(u) - - if err != nil { - return nil, errors.Wrapf(err, "cannot get user %d", id) - } - - return u, nil -} - -func UpdateUserToken(db *sqlx.DB, id int, token string) error { - q := fmt.Sprintf("UPDATE USERS SET token='%s' WHERE github_id='%d'", token, id) - return db.QueryRowx(q).Err() -} diff --git a/server/endpoints/hook.go b/server/endpoints/hook.go index f4f873081..d5e42ce67 100644 --- a/server/endpoints/hook.go +++ b/server/endpoints/hook.go @@ -20,7 +20,7 @@ import ( "strings" "github.com/google/go-github/github" - "github.com/jmoiron/sqlx" + "github.com/jinzhu/gorm" "github.com/labstack/echo" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -31,7 +31,7 @@ import ( "github.com/palantir/bulldozer/server/config" ) -func Hook(db *sqlx.DB, secret string) echo.HandlerFunc { +func Hook(db *gorm.DB, secret string) echo.HandlerFunc { return 
func(c echo.Context) error { logger := log.FromContext(c) @@ -42,18 +42,18 @@ func Hook(db *sqlx.DB, secret string) echo.HandlerFunc { logger.Debugf("ProcessHook returned %+v", result) - dbRepo, err := persist.GetRepositoryByID(db, result.RepoID) - if err != nil { - return errors.Wrapf(err, "cannot get repo with id %d from database", result.RepoID) + var dbRepo persist.Repository + res := db.Where("github_id = ?", result.RepoID).First(&dbRepo) + if err := res.Error; err != nil { + return errors.Wrap(err, "cannot get repository from db") } - - if dbRepo == nil { - return errors.Wrapf(err, "repository with ID not enabled", result.RepoID) + if res.RecordNotFound() { + return c.String(http.StatusOK, "Repository not enabled for bulldozer") } - user, err := persist.GetUserByName(db, dbRepo.EnabledBy) - if err != nil { - return errors.Wrapf(err, "cannot get user %s from database", dbRepo.EnabledBy) + var user persist.User + if err := db.First(&dbRepo.EnabledBy).Error; err != nil { + return errors.Wrapf(err, "cannot get user %s from database", dbRepo.EnabledBy.Name) } ghClient := gh.FromToken(c, user.Token) diff --git a/server/endpoints/repositories.go b/server/endpoints/repositories.go index cac5a534b..5a36c2f3d 100644 --- a/server/endpoints/repositories.go +++ b/server/endpoints/repositories.go @@ -21,7 +21,7 @@ import ( "time" "github.com/google/go-github/github" - "github.com/jmoiron/sqlx" + "github.com/jinzhu/gorm" "github.com/labstack/echo" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -41,7 +41,7 @@ type Repository struct { EnabledAt string `json:"enabledAt,omitempty"` } -func worker(c echo.Context, db *sqlx.DB, wg *sync.WaitGroup, repo *github.Repository, repoc chan *Repository, user *github.User, client *gh.Client) { +func worker(c echo.Context, db *gorm.DB, wg *sync.WaitGroup, repo *github.Repository, repoc chan *Repository, user *github.User, client *gh.Client) { logger := log.FromContext(c) defer wg.Done() @@ -57,18 +57,19 @@ func worker(c echo.Context, db *sqlx.DB, wg *sync.WaitGroup, repo *github.Reposi isAdmin = perm.GetPermission() == "admin" } - repository, err := persist.GetRepositoryByID(db, repo.GetID()) - if err != nil { + var repository persist.Repository + result := db.Where("github_id = ?", repo.GetID()).First(&repository) + if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { logger.WithFields(logrus.Fields{ "repo": repo.GetFullName(), - }).Error(errors.Wrap(err, "Cannot get repository from database")) + }).Error(errors.Wrap(err, "Cannot get repository from db")) return } - if repository != nil { + if !result.RecordNotFound() { isEnabled = true - enabledBy = repository.EnabledBy - enabledAt = time.Unix(repository.EnabledAt, 0).Format(time.RFC3339) + enabledBy = repository.EnabledBy.Name + enabledAt = repository.EnabledAt.Format(time.RFC3339) } repoc <- &Repository{ @@ -82,7 +83,7 @@ func worker(c echo.Context, db *sqlx.DB, wg *sync.WaitGroup, repo *github.Reposi } } -func Repositories(db *sqlx.DB) echo.HandlerFunc { +func Repositories(db *gorm.DB) echo.HandlerFunc { return func(c echo.Context) error { var repositories []*Repository var wg sync.WaitGroup @@ -122,7 +123,7 @@ func Repositories(db *sqlx.DB) echo.HandlerFunc { } } -func RepositoryEnable(db *sqlx.DB, webHookURL string, webHookSecret string) echo.HandlerFunc { +func RepositoryEnable(db *gorm.DB, webHookURL string, webHookSecret string) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) @@ -136,6 +137,12 @@ func RepositoryEnable(db *sqlx.DB, webHookURL 
string, webHookSecret string) echo return errors.Wrap(err, "cannot get current user") } + var dbUser persist.User + result := db.Where("github_id = ?", user.GetID()).First(&dbUser) + if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { + return errors.Wrap(err, "cannot get current user from db") + } + owner := c.Param("owner") name := c.Param("name") repo, _, err := client.Repositories.Get(client.Ctx, owner, name) @@ -182,20 +189,20 @@ func RepositoryEnable(db *sqlx.DB, webHookURL string, webHookSecret string) echo }).Info("Created hook on repository") dbRepo := &persist.Repository{ - ID: repo.GetID(), + GitHubID: repo.GetID(), Name: repo.GetFullName(), - EnabledAt: time.Now().UTC().Unix(), - EnabledBy: user.GetLogin(), + EnabledAt: time.Now().UTC(), + EnabledBy: dbUser, HookID: hook.GetID(), } - err = persist.Put(db, dbRepo) - if err != nil { + result = db.Create(dbRepo) + if err := result.Error; err != nil { _, e := client.Repositories.DeleteHook(client.Ctx, owner, name, hook.GetID()) if e != nil { - logger.Error(errors.Wrapf(err, "cannot delete hook on %s/%s (repo not saved to DB)", owner, name)) + logger.Error(errors.Wrapf(err, "cannot delete hook on %s/%s (repo not saved to db)", owner, name)) } - return errors.Wrapf(err, "cannot add %s/%s to the database", owner, name) + return errors.Wrapf(err, "cannot add %s/%s to the db", owner, name) } data := struct { @@ -220,7 +227,7 @@ func RepositoryEnable(db *sqlx.DB, webHookURL string, webHookSecret string) echo } } -func RepositoryDisable(db *sqlx.DB) echo.HandlerFunc { +func RepositoryDisable(db *gorm.DB) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) @@ -256,13 +263,15 @@ func RepositoryDisable(db *sqlx.DB) echo.HandlerFunc { "user": user.GetLogin(), }).Debug("Deleting hook from repository") - dbRepo, err := persist.GetRepositoryByID(db, repo.GetID()) - if err != nil { - return errors.Wrapf(err, "cannot get repo with ID %d from database", repo.GetID()) + var repository persist.Repository + result := db.Where("github_id = ?", repo.GetID()).First(&repository) + if err := result.Error; err != nil { + return errors.Wrap(err, "cannot get repository from db") } - _, err = client.Repositories.DeleteHook(client.Ctx, owner, name, dbRepo.HookID) + + _, err = client.Repositories.DeleteHook(client.Ctx, owner, name, repository.HookID) if err != nil { - return errors.Wrapf(err, "cannot delete hook %d for %s/%s via %s", owner, name, dbRepo.HookID, user.GetLogin()) + return errors.Wrapf(err, "cannot delete hook %d for %s/%s via %s", owner, name, repository.HookID, user.GetLogin()) } logger.WithFields(logrus.Fields{ @@ -270,8 +279,7 @@ func RepositoryDisable(db *sqlx.DB) echo.HandlerFunc { "user": user.GetLogin(), }).Info("Deleted hook from repository") - err = persist.Delete(db, dbRepo) - if err != nil { + if err := db.Delete(&repository).Error; err != nil { return errors.Wrapf(err, "cannot remove %s/%s from database", owner, name) } diff --git a/server/endpoints/token.go b/server/endpoints/token.go index 31ba757f2..d7ae2e817 100644 --- a/server/endpoints/token.go +++ b/server/endpoints/token.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/jmoiron/sqlx" + "github.com/jinzhu/gorm" "github.com/labstack/echo" "github.com/pkg/errors" @@ -28,10 +28,9 @@ import ( "github.com/palantir/bulldozer/persist" ) -func Token(db *sqlx.DB) echo.HandlerFunc { +func Token(db *gorm.DB) echo.HandlerFunc { return func(c echo.Context) error { logger := log.FromContext(c) - token, err := 
auth.GithubOauthConfig.Exchange(context.TODO(), c.QueryParam("code")) if err != nil { return errors.Wrap(err, "Cannot get code from GitHub") @@ -44,22 +43,22 @@ func Token(db *sqlx.DB) echo.HandlerFunc { return errors.Wrap(err, "Cannot get user from token") } - user, err := persist.GetUserByID(db, u.GetID()) - if err != nil { - dbUser := &persist.User{ - GithubID: u.GetID(), + var user persist.User + result := db.Where("github_id = ?", u.GetID()).First(&user) + if err := result.Error; err != nil && err != gorm.ErrRecordNotFound { + return errors.Wrap(err, "cannot get user from db") + } + if result.RecordNotFound() { + db.Create(&persist.User{ + GitHubID: u.GetID(), Name: u.GetLogin(), Token: accessToken, - } - if err := persist.Put(db, dbUser); err != nil { - return errors.Wrapf(err, "Cannot add %s to the database", u.GetLogin()) - } + }) } else { - if user.Token != accessToken { - if err := persist.UpdateUserToken(db, u.GetID(), accessToken); err != nil { - return errors.Wrapf(err, "Cannot update token for user %s", u.GetLogin()) - } - logger.Debugf("Updated token for user %s", u.GetLogin()) + logger.Infof("%+v", user) + user.Token = accessToken + if err := db.Save(&user).Error; err != nil { + return errors.Wrapf(err, "cannot update token for user %s", u.GetLogin()) } } diff --git a/server/init.go b/server/init.go index 3971bad2c..ebb220e84 100644 --- a/server/init.go +++ b/server/init.go @@ -17,8 +17,8 @@ package server import ( "fmt" - "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" // postgres bindings + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" // Import for side-effects "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -26,7 +26,7 @@ import ( "github.com/palantir/bulldozer/server/config" ) -func InitDB(dbc *config.DatabaseConfig) (*sqlx.DB, error) { +func InitDB(dbc *config.DatabaseConfig) (*gorm.DB, error) { connectStr := fmt.Sprintf("host=%s dbname=%s user=%s sslmode=%s", dbc.Host, dbc.DBName, dbc.Username, dbc.SSLMode) log.WithFields(log.Fields{ "connectionString": connectStr, @@ -36,11 +36,11 @@ func InitDB(dbc *config.DatabaseConfig) (*sqlx.DB, error) { connectStr += fmt.Sprintf(" password=%s", dbc.Password) } - db, err := sqlx.Connect("postgres", connectStr) + db, err := gorm.Open("postgres", connectStr) if err != nil { return nil, errors.Wrapf(err, "failed connecting to postgres") } - err = persist.InitializeSchema(db) - return db, errors.Wrap(err, "failed to initialize schema") + persist.InitializeSchema(db) + return db, nil } diff --git a/server/server.go b/server/server.go index 8b6766b65..4acc27882 100644 --- a/server/server.go +++ b/server/server.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/ipfans/echo-session" - "github.com/jmoiron/sqlx" + "github.com/jinzhu/gorm" "github.com/labstack/echo" "github.com/labstack/echo/middleware" "github.com/pkg/errors" @@ -38,7 +38,7 @@ type Server struct { e *echo.Echo } -func New(db *sqlx.DB, startup *config.Startup) *Server { +func New(db *gorm.DB, startup *config.Startup) *Server { e := echo.New() e.Use(bm.ContextMiddleware) @@ -58,7 +58,7 @@ func New(db *sqlx.DB, startup *config.Startup) *Server { return &Server{startup.Server, e} } -func registerEndpoints(startup *config.Startup, e *echo.Echo, db *sqlx.DB) { +func registerEndpoints(startup *config.Startup, e *echo.Echo, db *gorm.DB) { e.Static("/", startup.AssetDir) e.GET("/repositories", func(c echo.Context) error { diff --git a/vendor/github.com/jinzhu/gorm/.codeclimate.yml b/vendor/github.com/jinzhu/gorm/.codeclimate.yml new file 
mode 100644 index 000000000..51aba50cb --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/.codeclimate.yml @@ -0,0 +1,11 @@ +--- +engines: + gofmt: + enabled: true + govet: + enabled: true + golint: + enabled: true +ratings: + paths: + - "**.go" diff --git a/vendor/github.com/jinzhu/gorm/.gitignore b/vendor/github.com/jinzhu/gorm/.gitignore new file mode 100644 index 000000000..01dc5ce07 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/.gitignore @@ -0,0 +1,2 @@ +documents +_book diff --git a/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md b/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md new file mode 100644 index 000000000..c54d572d2 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md @@ -0,0 +1,52 @@ +# How to Contribute + +## Bug Report + +- Do a search on GitHub under Issues in case it has already been reported +- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE* + +## Feature Request + +- Feature request with pull request is welcome +- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work + +## Pull Request + +- Prefer single commit pull request, that make the git history can be a bit easier to follow. +- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future + +## Contributing to Documentation + +- You are welcome ;) +- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow. +- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides + +### Executable script template + +```go +package main + +import ( + _ "github.com/mattn/go-sqlite3" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/jinzhu/gorm" +) + +var db gorm.DB + +func init() { + var err error + db, err = gorm.Open("sqlite3", "test.db") + // db, err := gorm.Open("postgres", "user=username dbname=password sslmode=disable") + // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") + if err != nil { + panic(err) + } + db.LogMode(true) +} + +func main() { + // Your code +} +``` diff --git a/vendor/github.com/jinzhu/gorm/License b/vendor/github.com/jinzhu/gorm/License new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/jinzhu/gorm/README.md b/vendor/github.com/jinzhu/gorm/README.md new file mode 100644 index 000000000..c3f209c9d --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/README.md @@ -0,0 +1,46 @@ +# GORM + +The fantastic ORM library for Golang, aims to be developer friendly. + +[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) + +## Overview + +* Full-Featured ORM (almost) +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) +* Callbacks (Before/After Create/Save/Update/Delete/Find) +* Preloading (eager loading) +* Transactions +* Composite Primary Key +* SQL Builder +* Auto Migrations +* Logger +* Extendable, write Plugins based on GORM callbacks +* Every feature comes with tests +* Developer Friendly + +## Getting Started + +* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) + +## Upgrading To V1.0 + +* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) + +# Author + +**jinzhu** + +* +* +* + +# Contributors + +https://github.com/jinzhu/gorm/graphs/contributors + +## License + +Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). diff --git a/vendor/github.com/jinzhu/gorm/association.go b/vendor/github.com/jinzhu/gorm/association.go new file mode 100644 index 000000000..cd8fd9125 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/association.go @@ -0,0 +1,359 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" +) + +// Association Mode contains some helper methods to handle relationship things easily. +type Association struct { + Error error + scope *Scope + column string + field *Field +} + +// Find find out all related associations +func (association *Association) Find(value interface{}) *Association { + association.scope.related(value, association.column) + return association.setErr(association.scope.db.Error) +} + +// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to +func (association *Association) Append(values ...interface{}) *Association { + if relationship := association.field.Relationship; relationship.Kind == "has_one" { + return association.Replace(values...) + } + return association.saveAssociations(values...) +} + +// Replace replace current associations with new one +func (association *Association) Replace(values ...interface{}) *Association { + var ( + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field + newDB = scope.NewDB() + ) + + // Append new values + association.field.Set(reflect.Zero(association.field.Field.Type())) + association.saveAssociations(values...) 
+ + // Belongs To + if relationship.Kind == "belongs_to" { + // Set foreign key to be null when clearing value (length equals 0) + if len(values) == 0 { + // Set foreign key to be nil + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) + } + } else { + // Polymorphic Relations + if relationship.PolymorphicDBName != "" { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + } + + // Delete Relations except new created + if len(values) > 0 { + var associationForeignFieldNames []string + if relationship.Kind == "many_to_many" { + // if many to many relations, get association fields name from association foreign keys + associationScope := scope.New(reflect.New(field.Type()).Interface()) + for _, dbName := range relationship.AssociationForeignFieldNames { + if field, ok := associationScope.FieldByName(dbName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + } else { + // If other relations, use primary keys + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + + newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) + + if len(newPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) + newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) + } + } + + if relationship.Kind == "many_to_many" { + // if many to many relations, delete related relations from join table + var sourceForeignFieldNames []string + + for _, dbName := range relationship.ForeignFieldNames { + if field, ok := scope.FieldByName(dbName); ok { + sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) + } + } + + if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) + + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + } + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) 
+ var foreignKeyMap = map[string]interface{}{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + fieldValue := reflect.New(association.field.Field.Type()).Interface() + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + } + } + return association +} + +// Delete remove relationship between source & passed arguments, but won't delete those arguments +func (association *Association) Delete(values ...interface{}) *Association { + var ( + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field + newDB = scope.NewDB() + ) + + if len(values) == 0 { + return association + } + + var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) + deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) + } + + deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) + + if relationship.Kind == "many_to_many" { + // source value's foreign keys + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + // get association's foreign fields name + var associationScope = scope.New(reflect.New(field.Type()).Interface()) + var associationForeignFieldNames []string + for _, associationDBName := range relationship.AssociationForeignFieldNames { + if field, ok := associationScope.FieldByName(associationDBName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + + // association value's foreign keys + deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) + newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) + + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + } else { + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + + if relationship.Kind == "belongs_to" { + // find with deleting relation's foreign keys + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) 
+ newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + + // set foreign key to be null if there are some records affected + modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() + if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { + if results.RowsAffected > 0 { + scope.updatedAttrsWithValues(foreignKeyMap) + } + } else { + association.setErr(results.Error) + } + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // find all relations + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + + // only include those deleting relations + newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), + toQueryValues(deletingPrimaryKeys)..., + ) + + // set matched relation's foreign key to be null + fieldValue := reflect.New(association.field.Field.Type()).Interface() + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + } + } + + // Remove deleted records from source's field + if association.Error == nil { + if field.Kind() == reflect.Slice { + leftValues := reflect.Zero(field.Type()) + + for i := 0; i < field.Len(); i++ { + reflectValue := field.Index(i) + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] + var isDeleted = false + for _, pk := range deletingPrimaryKeys { + if equalAsString(primaryKey, pk) { + isDeleted = true + break + } + } + if !isDeleted { + leftValues = reflect.Append(leftValues, reflectValue) + } + } + + association.field.Set(leftValues) + } else if field.Kind() == reflect.Struct { + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] + for _, pk := range deletingPrimaryKeys { + if equalAsString(primaryKey, pk) { + association.field.Set(reflect.Zero(field.Type())) + break + } + } + } + } + + return association +} + +// Clear remove relationship between source & current associations, won't delete those associations +func (association *Association) Clear() *Association { + return association.Replace() +} + +// Count return the count of current associations +func (association *Association) Count() int { + var ( + count = 0 + relationship = association.field.Relationship + scope = association.scope + fieldValue = association.field.Field.Interface() + query = scope.DB() + ) + + if relationship.Kind == "many_to_many" { + query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) + } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + } else if relationship.Kind == "belongs_to" { + primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + } + + if 
relationship.PolymorphicType != "" { + query = query.Where( + fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), + scope.TableName(), + ) + } + + query.Model(fieldValue).Count(&count) + return count +} + +// saveAssociations save passed values as associations +func (association *Association) saveAssociations(values ...interface{}) *Association { + var ( + scope = association.scope + field = association.field + relationship = field.Relationship + ) + + saveAssociation := func(reflectValue reflect.Value) { + // value has to been pointer + if reflectValue.Kind() != reflect.Ptr { + reflectPtr := reflect.New(reflectValue.Type()) + reflectPtr.Elem().Set(reflectValue) + reflectValue = reflectPtr + } + + // value has to been saved for many2many + if relationship.Kind == "many_to_many" { + if scope.New(reflectValue.Interface()).PrimaryKeyZero() { + association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + } + } + + // Assign Fields + var fieldType = field.Field.Type() + var setFieldBackToValue, setSliceFieldBackToValue bool + if reflectValue.Type().AssignableTo(fieldType) { + field.Set(reflectValue) + } else if reflectValue.Type().Elem().AssignableTo(fieldType) { + // if field's type is struct, then need to set value back to argument after save + setFieldBackToValue = true + field.Set(reflectValue.Elem()) + } else if fieldType.Kind() == reflect.Slice { + if reflectValue.Type().AssignableTo(fieldType.Elem()) { + field.Set(reflect.Append(field.Field, reflectValue)) + } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { + // if field's type is slice of struct, then need to set value back to argument after save + setSliceFieldBackToValue = true + field.Set(reflect.Append(field.Field, reflectValue.Elem())) + } + } + + if relationship.Kind == "many_to_many" { + association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) + } else { + association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) + + if setFieldBackToValue { + reflectValue.Elem().Set(field.Field) + } else if setSliceFieldBackToValue { + reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) + } + } + } + + for _, value := range values { + reflectValue := reflect.ValueOf(value) + indirectReflectValue := reflect.Indirect(reflectValue) + if indirectReflectValue.Kind() == reflect.Struct { + saveAssociation(reflectValue) + } else if indirectReflectValue.Kind() == reflect.Slice { + for i := 0; i < indirectReflectValue.Len(); i++ { + saveAssociation(indirectReflectValue.Index(i)) + } + } else { + association.setErr(errors.New("invalid value type")) + } + } + return association +} + +func (association *Association) setErr(err error) *Association { + if err != nil { + association.Error = err + } + return association +} diff --git a/vendor/github.com/jinzhu/gorm/association_test.go b/vendor/github.com/jinzhu/gorm/association_test.go new file mode 100644 index 000000000..52d2303f6 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/association_test.go @@ -0,0 +1,842 @@ +package gorm_test + +import ( + "fmt" + "reflect" + "sort" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestBelongsTo(t *testing.T) { + post := Post{ + Title: "post belongs to", + Body: "body belongs to", + Category: Category{Name: "Category 1"}, + MainCategory: Category{Name: "Main Category 1"}, + } + + if err := DB.Save(&post).Error; err != nil { + t.Error("Got errors when save post", 
err) + } + + if post.Category.ID == 0 || post.MainCategory.ID == 0 { + t.Errorf("Category's primary key should be updated") + } + + if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { + t.Errorf("post's foreign key should be updated") + } + + // Query + var category1 Category + DB.Model(&post).Association("Category").Find(&category1) + if category1.Name != "Category 1" { + t.Errorf("Query belongs to relations with Association") + } + + var mainCategory1 Category + DB.Model(&post).Association("MainCategory").Find(&mainCategory1) + if mainCategory1.Name != "Main Category 1" { + t.Errorf("Query belongs to relations with Association") + } + + var category11 Category + DB.Model(&post).Related(&category11) + if category11.Name != "Category 1" { + t.Errorf("Query belongs to relations with Related") + } + + if DB.Model(&post).Association("Category").Count() != 1 { + t.Errorf("Post's category count should be 1") + } + + if DB.Model(&post).Association("MainCategory").Count() != 1 { + t.Errorf("Post's main category count should be 1") + } + + // Append + var category2 = Category{ + Name: "Category 2", + } + DB.Model(&post).Association("Category").Append(&category2) + + if category2.ID == 0 { + t.Errorf("Category should has ID when created with Append") + } + + var category21 Category + DB.Model(&post).Related(&category21) + + if category21.Name != "Category 2" { + t.Errorf("Category should be updated with Append") + } + + if DB.Model(&post).Association("Category").Count() != 1 { + t.Errorf("Post's category count should be 1") + } + + // Replace + var category3 = Category{ + Name: "Category 3", + } + DB.Model(&post).Association("Category").Replace(&category3) + + if category3.ID == 0 { + t.Errorf("Category should has ID when created with Replace") + } + + var category31 Category + DB.Model(&post).Related(&category31) + if category31.Name != "Category 3" { + t.Errorf("Category should be updated with Replace") + } + + if DB.Model(&post).Association("Category").Count() != 1 { + t.Errorf("Post's category count should be 1") + } + + // Delete + DB.Model(&post).Association("Category").Delete(&category2) + if DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should not delete any category when Delete a unrelated Category") + } + + if post.Category.Name == "" { + t.Errorf("Post's category should not be reseted when Delete a unrelated Category") + } + + DB.Model(&post).Association("Category").Delete(&category3) + + if post.Category.Name != "" { + t.Errorf("Post's category should be reseted after Delete") + } + + var category41 Category + DB.Model(&post).Related(&category41) + if category41.Name != "" { + t.Errorf("Category should be deleted with Delete") + } + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Delete, but got %v", count) + } + + // Clear + DB.Model(&post).Association("Category").Append(&Category{ + Name: "Category 2", + }) + + if DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should find category after append") + } + + if post.Category.Name == "" { + t.Errorf("Post's category should has value after Append") + } + + DB.Model(&post).Association("Category").Clear() + + if post.Category.Name != "" { + t.Errorf("Post's category should be cleared after Clear") + } + + if !DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should not find any category after Clear") + } + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + 
t.Errorf("Post's category count should be 0 after Clear, but got %v", count) + } + + // Check Association mode with soft delete + category6 := Category{ + Name: "Category 6", + } + DB.Model(&post).Association("Category").Append(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 after Append, but got %v", count) + } + + DB.Delete(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) + } + + if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { + t.Errorf("Post's category is not findable after Delete") + } + + if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { + t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) + } +} + +func TestBelongsToOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestBelongsToOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` + ProfileID int + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasOne(t *testing.T) { + user := User{ + Name: "has one", + CreditCard: CreditCard{Number: "411111111111"}, + } + + if err := DB.Save(&user).Error; err != nil { + t.Error("Got errors when save user", err.Error()) + } + + if user.CreditCard.UserId.Int64 == 0 { + t.Errorf("CreditCard's foreign key should be updated") + } + + // Query + var creditCard1 CreditCard + DB.Model(&user).Related(&creditCard1) + + if creditCard1.Number != "411111111111" { + t.Errorf("Query has one relations with Related") + } + + var creditCard11 CreditCard + DB.Model(&user).Association("CreditCard").Find(&creditCard11) + + if creditCard11.Number != "411111111111" { + t.Errorf("Query has one relations with Related") + } + + if DB.Model(&user).Association("CreditCard").Count() != 1 { + t.Errorf("User's credit card count should be 1") + } + + // Append + var creditcard2 = CreditCard{ + Number: "411111111112", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard2) + + if creditcard2.ID == 0 { + t.Errorf("Creditcard should has ID when created with Append") + } + + var creditcard21 CreditCard + DB.Model(&user).Related(&creditcard21) + if creditcard21.Number != 
"411111111112" { + t.Errorf("CreditCard should be updated with Append") + } + + if DB.Model(&user).Association("CreditCard").Count() != 1 { + t.Errorf("User's credit card count should be 1") + } + + // Replace + var creditcard3 = CreditCard{ + Number: "411111111113", + } + DB.Model(&user).Association("CreditCard").Replace(&creditcard3) + + if creditcard3.ID == 0 { + t.Errorf("Creditcard should has ID when created with Replace") + } + + var creditcard31 CreditCard + DB.Model(&user).Related(&creditcard31) + if creditcard31.Number != "411111111113" { + t.Errorf("CreditCard should be updated with Replace") + } + + if DB.Model(&user).Association("CreditCard").Count() != 1 { + t.Errorf("User's credit card count should be 1") + } + + // Delete + DB.Model(&user).Association("CreditCard").Delete(&creditcard2) + var creditcard4 CreditCard + DB.Model(&user).Related(&creditcard4) + if creditcard4.Number != "411111111113" { + t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") + } + + if DB.Model(&user).Association("CreditCard").Count() != 1 { + t.Errorf("User's credit card count should be 1") + } + + DB.Model(&user).Association("CreditCard").Delete(&creditcard3) + if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { + t.Errorf("Should delete credit card with Delete") + } + + if DB.Model(&user).Association("CreditCard").Count() != 0 { + t.Errorf("User's credit card count should be 0 after Delete") + } + + // Clear + var creditcard5 = CreditCard{ + Number: "411111111115", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard5) + + if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { + t.Errorf("Should added credit card with Append") + } + + if DB.Model(&user).Association("CreditCard").Count() != 1 { + t.Errorf("User's credit card count should be 1") + } + + DB.Model(&user).Association("CreditCard").Clear() + if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { + t.Errorf("Credit card should be deleted with Clear") + } + + if DB.Model(&user).Association("CreditCard").Count() != 0 { + t.Errorf("User's credit card count should be 0 after Clear") + } + + // Check Association mode with soft delete + var creditcard6 = CreditCard{ + Number: "411111111116", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 after Append, but got %v", count) + } + + DB.Delete(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { + t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) + } + + if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { + t.Errorf("User's creditcard is not findable after Delete") + } + + if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { + t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) + } +} + +func TestHasOneOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind 
!= "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasOneOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasMany(t *testing.T) { + post := Post{ + Title: "post has many", + Body: "body has many", + Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, + } + + if err := DB.Save(&post).Error; err != nil { + t.Error("Got errors when save post", err) + } + + for _, comment := range post.Comments { + if comment.PostId == 0 { + t.Errorf("comment's PostID should be updated") + } + } + + var compareComments = func(comments []Comment, contents []string) bool { + var commentContents []string + for _, comment := range comments { + commentContents = append(commentContents, comment.Content) + } + sort.Strings(commentContents) + sort.Strings(contents) + return reflect.DeepEqual(commentContents, contents) + } + + // Query + if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { + t.Errorf("Comment 1 should be saved") + } + + var comments1 []Comment + DB.Model(&post).Association("Comments").Find(&comments1) + if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { + t.Errorf("Query has many relations with Association") + } + + var comments11 []Comment + DB.Model(&post).Related(&comments11) + if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { + t.Errorf("Query has many relations with Related") + } + + if DB.Model(&post).Association("Comments").Count() != 2 { + t.Errorf("Post's comments count should be 2") + } + + // Append + DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) + + var comments2 []Comment + DB.Model(&post).Related(&comments2) + if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { + t.Errorf("Append new record to has many relations") + } + + if DB.Model(&post).Association("Comments").Count() != 3 { + t.Errorf("Post's comments count should be 3 after Append") + } + + // Delete + DB.Model(&post).Association("Comments").Delete(comments11) + + var comments3 []Comment + DB.Model(&post).Related(&comments3) + if !compareComments(comments3, []string{"Comment 3"}) { + t.Errorf("Delete an existing resource for has many relations") + } + + if DB.Model(&post).Association("Comments").Count() != 1 { + t.Errorf("Post's comments count should be 1 after Delete 2") + } + + // Replace + DB.Model(&Post{Id: 999}).Association("Comments").Replace() + + var comments4 []Comment + DB.Model(&post).Related(&comments4) + if len(comments4) == 0 { + t.Errorf("Replace for other resource should not clear all comments") + } + + DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) + + var comments41 []Comment + DB.Model(&post).Related(&comments41) + if !compareComments(comments41, 
[]string{"Comment 4", "Comment 5"}) { + t.Errorf("Replace has many relations") + } + + // Clear + DB.Model(&Post{Id: 999}).Association("Comments").Clear() + + var comments5 []Comment + DB.Model(&post).Related(&comments5) + if len(comments5) == 0 { + t.Errorf("Clear should not clear all comments") + } + + DB.Model(&post).Association("Comments").Clear() + + var comments51 []Comment + DB.Model(&post).Related(&comments51) + if len(comments51) != 0 { + t.Errorf("Clear has many relations") + } + + // Check Association mode with soft delete + var comment6 = Comment{ + Content: "comment 6", + } + DB.Model(&post).Association("Comments").Append(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 after Append, but got %v", count) + } + + DB.Delete(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 0 { + t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) + } + + var comments6 []Comment + if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { + t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) + } + + if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) + } + + var comments61 []Comment + if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) + } +} + +func TestHasManyOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasManyOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestManyToMany(t *testing.T) { + DB.Raw("delete from languages") + var languages = []Language{{Name: "ZH"}, {Name: "EN"}} + user := User{Name: "Many2Many", Languages: languages} + DB.Save(&user) + + // Query + var newLanguages []Language + DB.Model(&user).Related(&newLanguages, "Languages") + if len(newLanguages) != len([]string{"ZH", "EN"}) { + t.Errorf("Query many to many relations") + } + + DB.Model(&user).Association("Languages").Find(&newLanguages) + if len(newLanguages) != len([]string{"ZH", "EN"}) { + t.Errorf("Should be able to find many to many relations") + } + + if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { + t.Errorf("Count should 
return correct result") + } + + // Append + DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) + if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { + t.Errorf("New record should be saved when append") + } + + languageA := Language{Name: "AA"} + DB.Save(&languageA) + DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) + + languageC := Language{Name: "CC"} + DB.Save(&languageC) + DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) + + DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) + + totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} + + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { + t.Errorf("All appended languages should be saved") + } + + // Delete + user.Languages = []Language{} + DB.Model(&user).Association("Languages").Find(&user.Languages) + + var language Language + DB.Where("name = ?", "EE").First(&language) + DB.Model(&user).Association("Languages").Delete(language, &language) + + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { + t.Errorf("Relations should be deleted with Delete") + } + if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { + t.Errorf("Language EE should not be deleted") + } + + DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) + + user2 := User{Name: "Many2Many_User2", Languages: languages} + DB.Save(&user2) + + DB.Model(&user).Association("Languages").Delete(languages, &languages) + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { + t.Errorf("Relations should be deleted with Delete") + } + + if DB.Model(&user2).Association("Languages").Count() == 0 { + t.Errorf("Other user's relations should not be deleted") + } + + // Replace + var languageB Language + DB.Where("name = ?", "BB").First(&languageB) + DB.Model(&user).Association("Languages").Replace(languageB) + if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { + t.Errorf("Relations should be replaced") + } + + DB.Model(&user).Association("Languages").Replace() + if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { + t.Errorf("Relations should be replaced with empty") + } + + DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) + if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { + t.Errorf("Relations should be replaced") + } + + // Clear + DB.Model(&user).Association("Languages").Clear() + if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { + t.Errorf("Relations should be cleared") + } + + // Check Association mode with soft delete + var language6 = Language{ + Name: "language 6", + } + DB.Model(&user).Association("Languages").Append(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 after Append, but got %v", count) + } + + DB.Delete(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 0 { + t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) + } + + var languages6 []Language + if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { + 
t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) + } + + if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) + } + + var languages61 []Language + if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) + } +} + +func TestRelated(t *testing.T) { + user := User{ + Name: "jinzhu", + BillingAddress: Address{Address1: "Billing Address - Address 1"}, + ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, + Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, + CreditCard: CreditCard{Number: "1234567890"}, + Company: Company{Name: "company1"}, + } + + if err := DB.Save(&user).Error; err != nil { + t.Errorf("No error should happen when saving user") + } + + if user.CreditCard.ID == 0 { + t.Errorf("After user save, credit card should have id") + } + + if user.BillingAddress.ID == 0 { + t.Errorf("After user save, billing address should have id") + } + + if user.Emails[0].Id == 0 { + t.Errorf("After user save, billing address should have id") + } + + var emails []Email + DB.Model(&user).Related(&emails) + if len(emails) != 2 { + t.Errorf("Should have two emails") + } + + var emails2 []Email + DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) + if len(emails2) != 1 { + t.Errorf("Should have two emails") + } + + var emails3 []*Email + DB.Model(&user).Related(&emails3) + if len(emails3) != 2 { + t.Errorf("Should have two emails") + } + + var user1 User + DB.Model(&user).Related(&user1.Emails) + if len(user1.Emails) != 2 { + t.Errorf("Should have only one email match related condition") + } + + var address1 Address + DB.Model(&user).Related(&address1, "BillingAddressId") + if address1.Address1 != "Billing Address - Address 1" { + t.Errorf("Should get billing address from user correctly") + } + + user1 = User{} + DB.Model(&address1).Related(&user1, "BillingAddressId") + if DB.NewRecord(user1) { + t.Errorf("Should get user from address correctly") + } + + var user2 User + DB.Model(&emails[0]).Related(&user2) + if user2.Id != user.Id || user2.Name != user.Name { + t.Errorf("Should get user from email correctly") + } + + var creditcard CreditCard + var user3 User + DB.First(&creditcard, "number = ?", "1234567890") + DB.Model(&creditcard).Related(&user3) + if user3.Id != user.Id || user3.Name != user.Name { + t.Errorf("Should get user from credit card correctly") + } + + if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { + t.Errorf("RecordNotFound for Related") + } + + var company Company + if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { + t.Errorf("RecordNotFound for Related") + } +} + +func TestForeignKey(t *testing.T) { + for _, structField := range DB.NewScope(&User{}).GetStructFields() { + for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Email{}).GetStructFields() { + for _, foreignKey := range []string{"UserId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign 
key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Post{}).GetStructFields() { + for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { + for _, foreignKey := range []string{"PostId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback.go b/vendor/github.com/jinzhu/gorm/callback.go new file mode 100644 index 000000000..93198a71a --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback.go @@ -0,0 +1,237 @@ +package gorm + +import ( + "fmt" +) + +// DefaultCallback default callbacks defined by gorm +var DefaultCallback = &Callback{} + +// Callback is a struct that contains all CURD callbacks +// Field `creates` contains callbacks will be call when creating object +// Field `updates` contains callbacks will be call when updating object +// Field `deletes` contains callbacks will be call when deleting object +// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... +// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... +// Field `processors` contains all callback processors, will be used to generate above callbacks in order +type Callback struct { + creates []*func(scope *Scope) + updates []*func(scope *Scope) + deletes []*func(scope *Scope) + queries []*func(scope *Scope) + rowQueries []*func(scope *Scope) + processors []*CallbackProcessor +} + +// CallbackProcessor contains callback informations +type CallbackProcessor struct { + name string // current callback's name + before string // register current callback before a callback + after string // register current callback after a callback + replace bool // replace callbacks with same name + remove bool // delete callbacks with same name + kind string // callback type: create, update, delete, query, row_query + processor *func(scope *Scope) // callback handler + parent *Callback +} + +func (c *Callback) clone() *Callback { + return &Callback{ + creates: c.creates, + updates: c.updates, + deletes: c.deletes, + queries: c.queries, + rowQueries: c.rowQueries, + processors: c.processors, + } +} + +// Create could be used to register callbacks for creating object +// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { +// // business logic +// ... +// +// // set error if some thing wrong happened, will rollback the creating +// scope.Err(errors.New("error")) +// }) +func (c *Callback) Create() *CallbackProcessor { + return &CallbackProcessor{kind: "create", parent: c} +} + +// Update could be used to register callbacks for updating object, refer `Create` for usage +func (c *Callback) Update() *CallbackProcessor { + return &CallbackProcessor{kind: "update", parent: c} +} + +// Delete could be used to register callbacks for deleting object, refer `Create` for usage +func (c *Callback) Delete() *CallbackProcessor { + return &CallbackProcessor{kind: "delete", parent: c} +} + +// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... 
+// Refer `Create` for usage +func (c *Callback) Query() *CallbackProcessor { + return &CallbackProcessor{kind: "query", parent: c} +} + +// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage +func (c *Callback) RowQuery() *CallbackProcessor { + return &CallbackProcessor{kind: "row_query", parent: c} +} + +// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` +func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { + cp.after = callbackName + return cp +} + +// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` +func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { + cp.before = callbackName + return cp +} + +// Register a new callback, refer `Callbacks.Create` +func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { + cp.name = callbackName + cp.processor = &callback + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() +} + +// Remove a registered callback +// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") +func (cp *CallbackProcessor) Remove(callbackName string) { + fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.name = callbackName + cp.remove = true + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() +} + +// Replace a registered callback with new callback +// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { +// scope.SetColumn("Created", now) +// scope.SetColumn("Updated", now) +// }) +func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { + fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.name = callbackName + cp.processor = &callback + cp.replace = true + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() +} + +// Get registered callback +// db.Callback().Create().Get("gorm:create") +func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { + for _, p := range cp.parent.processors { + if p.name == callbackName && p.kind == cp.kind && !cp.remove { + return *p.processor + } + } + return nil +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +// sortProcessors sort callback processors based on its before, after, remove, replace +func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { + var ( + allNames, sortedNames []string + sortCallbackProcessor func(c *CallbackProcessor) + ) + + for _, cp := range cps { + // show warning message the callback name already exists + if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { + fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + } + allNames = append(allNames, cp.name) + } + + sortCallbackProcessor = func(c *CallbackProcessor) { + if getRIndex(sortedNames, c.name) == -1 { // if not sorted + if c.before != "" { // if defined before callback + if index := getRIndex(sortedNames, c.before); index != -1 { + // if before callback already sorted, append current callback just after it + sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) 
+ } else if index := getRIndex(allNames, c.before); index != -1 { + // if before callback exists but haven't sorted, append current callback to last + sortedNames = append(sortedNames, c.name) + sortCallbackProcessor(cps[index]) + } + } + + if c.after != "" { // if defined after callback + if index := getRIndex(sortedNames, c.after); index != -1 { + // if after callback already sorted, append current callback just before it + sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) + } else if index := getRIndex(allNames, c.after); index != -1 { + // if after callback exists but haven't sorted + cp := cps[index] + // set after callback's before callback to current callback + if cp.before == "" { + cp.before = c.name + } + sortCallbackProcessor(cp) + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sortedNames, c.name) == -1 { + sortedNames = append(sortedNames, c.name) + } + } + } + + for _, cp := range cps { + sortCallbackProcessor(cp) + } + + var sortedFuncs []*func(scope *Scope) + for _, name := range sortedNames { + if index := getRIndex(allNames, name); !cps[index].remove { + sortedFuncs = append(sortedFuncs, cps[index].processor) + } + } + + return sortedFuncs +} + +// reorder all registered processors, and reset CURD callbacks +func (c *Callback) reorder() { + var creates, updates, deletes, queries, rowQueries []*CallbackProcessor + + for _, processor := range c.processors { + if processor.name != "" { + switch processor.kind { + case "create": + creates = append(creates, processor) + case "update": + updates = append(updates, processor) + case "delete": + deletes = append(deletes, processor) + case "query": + queries = append(queries, processor) + case "row_query": + rowQueries = append(rowQueries, processor) + } + } + } + + c.creates = sortProcessors(creates) + c.updates = sortProcessors(updates) + c.deletes = sortProcessors(deletes) + c.queries = sortProcessors(queries) + c.rowQueries = sortProcessors(rowQueries) +} diff --git a/vendor/github.com/jinzhu/gorm/callback_create.go b/vendor/github.com/jinzhu/gorm/callback_create.go new file mode 100644 index 000000000..e3cd2f0b4 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_create.go @@ -0,0 +1,144 @@ +package gorm + +import ( + "fmt" + "strings" +) + +// Define callbacks for creating +func init() { + DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) + DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) + DefaultCallback.Create().Register("gorm:create", createCallback) + DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) + DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) + DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) +} + +// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating +func beforeCreateCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("BeforeSave") + } + if !scope.HasError() { + scope.CallMethod("BeforeCreate") + } +} + +// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when 
creating +func updateTimeStampForCreateCallback(scope *Scope) { + if !scope.HasError() { + now := NowFunc() + scope.SetColumn("CreatedAt", now) + scope.SetColumn("UpdatedAt", now) + } +} + +// createCallback the callback used to insert data into database +func createCallback(scope *Scope) { + if !scope.HasError() { + defer scope.trace(NowFunc()) + + var ( + columns, placeholders []string + blankColumnsWithDefaultValue []string + ) + + for _, field := range scope.Fields() { + if scope.changeableField(field) { + if field.IsNormal { + if !field.IsPrimaryKey || !field.IsBlank { + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else { + columns = append(columns, scope.Quote(field.DBName)) + placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) + } + } + } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { + for _, foreignKey := range field.Relationship.ForeignDBNames { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { + columns = append(columns, scope.Quote(foreignField.DBName)) + placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) + } + } + } + } + } + + var ( + returningColumn = "*" + quotedTableName = scope.QuotedTableName() + primaryField = scope.PrimaryField() + extraOption string + ) + + if str, ok := scope.Get("gorm:insert_option"); ok { + extraOption = fmt.Sprint(str) + } + + if primaryField != nil { + returningColumn = scope.Quote(primaryField.DBName) + } + + lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + + if len(columns) == 0 { + scope.Raw(fmt.Sprintf( + "INSERT INTO %v DEFAULT VALUES%v%v", + quotedTableName, + addExtraSpaceIfExist(extraOption), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), + )) + } else { + scope.Raw(fmt.Sprintf( + "INSERT INTO %v (%v) VALUES (%v)%v%v", + scope.QuotedTableName(), + strings.Join(columns, ","), + strings.Join(placeholders, ","), + addExtraSpaceIfExist(extraOption), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), + )) + } + + // execute create sql + if lastInsertIDReturningSuffix == "" || primaryField == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) + } + } + } + } else { + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + scope.db.RowsAffected = 1 + } + } + } +} + +// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object +func forceReloadAfterCreateCallback(scope *Scope) { + if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { + scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value) + } +} + +// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating +func afterCreateCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterCreate") + } + if !scope.HasError() { + 
scope.CallMethod("AfterSave") + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_delete.go b/vendor/github.com/jinzhu/gorm/callback_delete.go new file mode 100644 index 000000000..c8ffcc821 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_delete.go @@ -0,0 +1,53 @@ +package gorm + +import "fmt" + +// Define callbacks for deleting +func init() { + DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) + DefaultCallback.Delete().Register("gorm:delete", deleteCallback) + DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) + DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) +} + +// beforeDeleteCallback will invoke `BeforeDelete` method before deleting +func beforeDeleteCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("BeforeDelete") + } +} + +// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) +func deleteCallback(scope *Scope) { + if !scope.HasError() { + var extraOption string + if str, ok := scope.Get("gorm:delete_option"); ok { + extraOption = fmt.Sprint(str) + } + + if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET deleted_at=%v%v%v", + scope.QuotedTableName(), + scope.AddToVars(NowFunc()), + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() + } else { + scope.Raw(fmt.Sprintf( + "DELETE FROM %v%v%v", + scope.QuotedTableName(), + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() + } + } +} + +// afterDeleteCallback will invoke `AfterDelete` method after deleting +func afterDeleteCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterDelete") + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_query.go b/vendor/github.com/jinzhu/gorm/callback_query.go new file mode 100644 index 000000000..93782b1dc --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_query.go @@ -0,0 +1,93 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" +) + +// Define callbacks for querying +func init() { + DefaultCallback.Query().Register("gorm:query", queryCallback) + DefaultCallback.Query().Register("gorm:preload", preloadCallback) + DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) +} + +// queryCallback used to query data from database +func queryCallback(scope *Scope) { + defer scope.trace(NowFunc()) + + var ( + isSlice, isPtr bool + resultType reflect.Type + results = scope.IndirectValue() + ) + + if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { + if primaryField := scope.PrimaryField(); primaryField != nil { + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) + } + } + + if value, ok := scope.Get("gorm:query_destination"); ok { + results = reflect.Indirect(reflect.ValueOf(value)) + } + + if kind := results.Kind(); kind == reflect.Slice { + isSlice = true + resultType = results.Type().Elem() + results.Set(reflect.MakeSlice(results.Type(), 0, 0)) + + if resultType.Kind() == reflect.Ptr { + isPtr = true + resultType = resultType.Elem() + } + } else if kind != reflect.Struct { + scope.Err(errors.New("unsupported destination, should be slice or struct")) + return + } + + scope.prepareQuerySQL() + + if !scope.HasError() { + 
scope.db.RowsAffected = 0 + if str, ok := scope.Get("gorm:query_option"); ok { + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) + } + + if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + scope.db.RowsAffected++ + + elem := results + if isSlice { + elem = reflect.New(resultType).Elem() + } + + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) + + if isSlice { + if isPtr { + results.Set(reflect.Append(results, elem.Addr())) + } else { + results.Set(reflect.Append(results, elem)) + } + } + } + + if scope.db.RowsAffected == 0 && !isSlice { + scope.Err(ErrRecordNotFound) + } + } + } +} + +// afterQueryCallback will invoke `AfterFind` method after querying +func afterQueryCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterFind") + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go new file mode 100644 index 000000000..5746f533a --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_query_preload.go @@ -0,0 +1,310 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// preloadCallback used to preload associations +func preloadCallback(scope *Scope) { + if scope.Search.preload == nil || scope.HasError() { + return + } + + var ( + preloadedMap = map[string]bool{} + fields = scope.Fields() + ) + + for _, preload := range scope.Search.preload { + var ( + preloadFields = strings.Split(preload.schema, ".") + currentScope = scope + currentFields = fields + ) + + for idx, preloadField := range preloadFields { + var currentPreloadConditions []interface{} + + // if not preloaded + if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { + + // assign search conditions to last preload + if idx == len(preloadFields)-1 { + currentPreloadConditions = preload.conditions + } + + for _, field := range currentFields { + if field.Name != preloadField || field.Relationship == nil { + continue + } + + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, currentPreloadConditions) + case "has_many": + currentScope.handleHasManyPreload(field, currentPreloadConditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, currentPreloadConditions) + case "many_to_many": + currentScope.handleManyToManyPreload(field, currentPreloadConditions) + default: + scope.Err(errors.New("unsupported relation")) + } + + preloadedMap[preloadKey] = true + break + } + + if !preloadedMap[preloadKey] { + scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) + return + } + } + + // preload next level + if idx < len(preloadFields)-1 { + currentScope = currentScope.getColumnAsScope(preloadField) + currentFields = currentScope.Fields() + } + } + } +} + +func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { + var ( + preloadDB = scope.NewDB() + preloadConditions []interface{} + ) + + for _, condition := range conditions { + if scopes, ok := condition.(func(*DB) *DB); ok { + preloadDB = scopes(preloadDB) + } else { + preloadConditions = append(preloadConditions, condition) + } + } + + return preloadDB, preloadConditions +} + +// handleHasOnePreload used to preload has one associations +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { + relation := 
field.Relationship + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { + indirectValue.FieldByName(field.Name).Set(result) + break + } + } + } else { + scope.Err(field.Set(result)) + } + } +} + +// handleHasManyPreload used to preload has many associations +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + if indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { + objectField := object.FieldByName(field.Name) + objectField.Set(reflect.Append(objectField, result)) + break + } + } + } + } else { + scope.Err(field.Set(resultsValue)) + } +} + +// handleBelongsToPreload used to preload belongs to associations +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = 
scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { + value := getValueFromFields(result, relation.AssociationForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.Err(field.Set(result)) + } + } +} + +// handleManyToManyPreload used to preload many to many associations +func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { + var ( + relation = field.Relationship + joinTableHandler = relation.JoinTableHandler + fieldType = field.Struct.Type.Elem() + foreignKeyValue interface{} + foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() + linkHash = map[string][]reflect.Value{} + isPtr bool + ) + + if fieldType.Kind() == reflect.Ptr { + isPtr = true + fieldType = fieldType.Elem() + } + + var sourceKeys = []string{} + for _, key := range joinTableHandler.SourceForeignKeys() { + sourceKeys = append(sourceKeys, key.DBName) + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // generate query with join table + newScope := scope.New(reflect.New(fieldType).Interface()) + preloadDB = preloadDB.Table(newScope.TableName()).Select("*") + preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) + + // preload inline conditions + if len(preloadConditions) > 0 { + preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) + } + + rows, err := preloadDB.Rows() + + if scope.Err(err) != nil { + return + } + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + var ( + elem = reflect.New(fieldType).Elem() + fields = scope.New(elem.Addr().Interface()).Fields() + ) + + // register foreign keys in join tables + var joinTableFields []*Field + for _, sourceKey := range sourceKeys { + joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) + } + + scope.scan(rows, columns, append(fields, joinTableFields...)) + + var foreignKeys = make([]interface{}, len(sourceKeys)) + // generate hashed forkey keys in join table + for idx, joinTableField := range joinTableFields { + if !joinTableField.Field.IsNil() { + foreignKeys[idx] = joinTableField.Field.Elem().Interface() + } + } + hashedSourceKeys := toString(foreignKeys) + + if isPtr { + linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) + } else { + linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) + } + } + + // assign find results + var ( + indirectScopeValue = scope.IndirectValue() + fieldsSourceMap = map[string]reflect.Value{} + foreignFieldNames = []string{} + ) + + for _, dbName := range relation.ForeignFieldNames { + if field, ok := scope.FieldByName(dbName); ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } + } + + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) + } + } else if indirectScopeValue.IsValid() { + fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = 
indirectScopeValue.FieldByName(field.Name) + } + + for source, link := range linkHash { + fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_save.go b/vendor/github.com/jinzhu/gorm/callback_save.go new file mode 100644 index 000000000..5ffe53b97 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_save.go @@ -0,0 +1,92 @@ +package gorm + +import "reflect" + +func beginTransactionCallback(scope *Scope) { + scope.Begin() +} + +func commitOrRollbackTransactionCallback(scope *Scope) { + scope.CommitOrRollback() +} + +func saveBeforeAssociationsCallback(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } + for _, field := range scope.Fields() { + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + scope.Err(scope.NewDB().Save(fieldValue).Error) + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } + } +} + +func saveAfterAssociationsCallback(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } + for _, field := range scope.Fields() { + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if relationship := field.Relationship; relationship != nil && + (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + value := field.Field + + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + newDB := scope.NewDB() + elem := value.Index(i).Addr().Interface() + newScope := newDB.NewScope(elem) + + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + } + + scope.Err(newDB.Save(elem).Error) + + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + } + scope.Err(scope.NewDB().Save(elem).Error) + } + } + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_system_test.go b/vendor/github.com/jinzhu/gorm/callback_system_test.go new file mode 100644 index 000000000..13ca3f428 --- /dev/null +++ 
b/vendor/github.com/jinzhu/gorm/callback_system_test.go @@ -0,0 +1,112 @@ +package gorm + +import ( + "reflect" + "runtime" + "strings" + "testing" +) + +func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { + var names []string + for _, f := range funcs { + fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") + names = append(names, fnames[len(fnames)-1]) + } + return reflect.DeepEqual(names, fnames) +} + +func create(s *Scope) {} +func beforeCreate1(s *Scope) {} +func beforeCreate2(s *Scope) {} +func afterCreate1(s *Scope) {} +func afterCreate2(s *Scope) {} + +func TestRegisterCallback(t *testing.T) { + var callback = &Callback{} + + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("before_create2", beforeCreate2) + callback.Create().Register("create", create) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Register("after_create2", afterCreate2) + + if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + t.Errorf("register callback") + } +} + +func TestRegisterCallbackWithOrder(t *testing.T) { + var callback1 = &Callback{} + callback1.Create().Register("before_create1", beforeCreate1) + callback1.Create().Register("create", create) + callback1.Create().Register("after_create1", afterCreate1) + callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) + if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + t.Errorf("register callback with order") + } + + var callback2 = &Callback{} + + callback2.Update().Register("create", create) + callback2.Update().Before("create").Register("before_create1", beforeCreate1) + callback2.Update().After("after_create2").Register("after_create1", afterCreate1) + callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) + callback2.Update().Register("after_create2", afterCreate2) + + if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + t.Errorf("register callback with order") + } +} + +func TestRegisterCallbackWithComplexOrder(t *testing.T) { + var callback1 = &Callback{} + + callback1.Query().Before("after_create1").After("before_create1").Register("create", create) + callback1.Query().Register("before_create1", beforeCreate1) + callback1.Query().Register("after_create1", afterCreate1) + + if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { + t.Errorf("register callback with order") + } + + var callback2 = &Callback{} + + callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) + callback2.Delete().Before("create").Register("before_create1", beforeCreate1) + callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) + callback2.Delete().Register("after_create1", afterCreate1) + callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) + + if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + t.Errorf("register callback with order") + } +} + +func replaceCreate(s *Scope) {} + +func TestReplaceCallback(t *testing.T) { + var callback = &Callback{} + + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + 
callback.Create().Register("after_create1", afterCreate1) + callback.Create().Replace("create", replaceCreate) + + if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { + t.Errorf("replace callback") + } +} + +func TestRemoveCallback(t *testing.T) { + var callback = &Callback{} + + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Remove("create") + + if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { + t.Errorf("remove callback") + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_update.go b/vendor/github.com/jinzhu/gorm/callback_update.go new file mode 100644 index 000000000..aa27b5fb7 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_update.go @@ -0,0 +1,104 @@ +package gorm + +import ( + "fmt" + "strings" +) + +// Define callbacks for updating +func init() { + DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) + DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) + DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) + DefaultCallback.Update().Register("gorm:update", updateCallback) + DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) + DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) +} + +// assignUpdatingAttributesCallback assign updating attributes to model +func assignUpdatingAttributesCallback(scope *Scope) { + if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { + if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { + scope.InstanceSet("gorm:update_attrs", updateMaps) + } else { + scope.SkipLeft() + } + } +} + +// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating +func beforeUpdateCallback(scope *Scope) { + if _, ok := scope.Get("gorm:update_column"); !ok { + if !scope.HasError() { + scope.CallMethod("BeforeSave") + } + if !scope.HasError() { + scope.CallMethod("BeforeUpdate") + } + } +} + +// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating +func updateTimeStampForUpdateCallback(scope *Scope) { + if _, ok := scope.Get("gorm:update_column"); !ok { + scope.SetColumn("UpdatedAt", NowFunc()) + } +} + +// updateCallback the callback used to update data to database +func updateCallback(scope *Scope) { + if !scope.HasError() { + var sqls []string + + if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { + for column, value := range updateAttrs.(map[string]interface{}) { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) + } + } else { + for _, field := range scope.Fields() { + if scope.changeableField(field) { + if !field.IsPrimaryKey && field.IsNormal { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + for _, foreignKey := 
range relationship.ForeignDBNames { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { + sqls = append(sqls, + fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) + } + } + } + } + } + } + + var extraOption string + if str, ok := scope.Get("gorm:update_option"); ok { + extraOption = fmt.Sprint(str) + } + + if len(sqls) > 0 { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET %v%v%v", + scope.QuotedTableName(), + strings.Join(sqls, ", "), + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() + } + } +} + +// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating +func afterUpdateCallback(scope *Scope) { + if _, ok := scope.Get("gorm:update_column"); !ok { + if !scope.HasError() { + scope.CallMethod("AfterUpdate") + } + if !scope.HasError() { + scope.CallMethod("AfterSave") + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/callbacks_test.go b/vendor/github.com/jinzhu/gorm/callbacks_test.go new file mode 100644 index 000000000..a58913d76 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callbacks_test.go @@ -0,0 +1,177 @@ +package gorm_test + +import ( + "errors" + + "github.com/jinzhu/gorm" + + "reflect" + "testing" +) + +func (s *Product) BeforeCreate() (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return +} + +func (s *Product) BeforeUpdate() (err error) { + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return +} + +func (s *Product) BeforeSave() (err error) { + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return +} + +func (s *Product) AfterFind() { + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 +} + +func (s *Product) AfterCreate(tx *gorm.DB) { + tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) +} + +func (s *Product) AfterUpdate() { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 +} + +func (s *Product) AfterSave() (err error) { + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return +} + +func (s *Product) BeforeDelete() (err error) { + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete() (err error) { + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return +} + +func (s *Product) GetCallTimes() []int64 { + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) + + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) + } + + p.Price = 200 + DB.Save(&p) + if 
!reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 2 { + t.Errorf("AfterFind callbacks should work with slice") + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } + + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Errorf("Can't find a deleted record") + } +} + +func TestCallbacksWithErrors(t *testing.T) { + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Errorf("An error from before create callbacks happened when create with invalid value") + } + + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Errorf("Should not save record that have errors") + } + + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Errorf("An error from after create callbacks happened when create with invalid value") + } + + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) + + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Errorf("An error from before update callbacks happened when update with invalid value") + } + + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Errorf("Record Should not be updated due to errors happened in before update callback") + } + + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Errorf("Record Should not be updated due to errors happened in before update callback") + } + + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Errorf("An error from before save callbacks happened when update with invalid value") + } + + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Errorf("An error from before delete callbacks happened when delete") + } + + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Errorf("An error from before delete callbacks happened") + } + + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Errorf("Record should be reverted if get an error in after save callback") + } + + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record should be found") + } + + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") + } +} diff --git a/vendor/github.com/jinzhu/gorm/create_test.go b/vendor/github.com/jinzhu/gorm/create_test.go new file mode 100644 index 000000000..dc82de50d --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/create_test.go @@ -0,0 +1,164 @@ +package gorm_test + +import ( + "os" + "reflect" + "testing" + "time" +) + +func TestCreate(t *testing.T) { + float := 35.03554004971999 + user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: 
[]byte{'f', 'a', 'k', '4'}, Latitude: float} + + if !DB.NewRecord(user) || !DB.NewRecord(&user) { + t.Error("User should be new record before create") + } + + if count := DB.Save(&user).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + if DB.NewRecord(user) || DB.NewRecord(&user) { + t.Error("User should not new record after save") + } + + var newUser User + DB.First(&newUser, user.Id) + + if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { + t.Errorf("User's PasswordHash should be saved ([]byte)") + } + + if newUser.Age != 18 { + t.Errorf("User's Age should be saved (int)") + } + + if newUser.UserNum != Num(111) { + t.Errorf("User's UserNum should be saved (custom type)") + } + + if newUser.Latitude != float { + t.Errorf("Float64 should not be changed after save") + } + + if user.CreatedAt.IsZero() { + t.Errorf("Should have created_at after create") + } + + if newUser.CreatedAt.IsZero() { + t.Errorf("Should have created_at after create") + } + + DB.Model(user).Update("name", "create_user_new_name") + DB.First(&user, user.Id) + if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { + t.Errorf("CreatedAt should not be changed after update") + } +} + +func TestCreateWithNoGORMPrimayKey(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") + } + + jt := JoinTable{From: 1, To: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} + +func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { + animal := Animal{Name: "Ferdinand"} + if DB.Save(&animal).Error != nil { + t.Errorf("No error should happen when create a record without std primary key") + } + + if animal.Counter == 0 { + t.Errorf("No std primary key should be filled value after create") + } + + if animal.Name != "Ferdinand" { + t.Errorf("Default value should be overrided") + } + + // Test create with default value not overrided + an := Animal{From: "nerdz"} + + if DB.Save(&an).Error != nil { + t.Errorf("No error should happen when create an record without std primary key") + } + + // We must fetch the value again, to have the default fields updated + // (We can't do this in the update statements, since sql default can be expressions + // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" + DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) + + if an.Name != "galeone" { + t.Errorf("Default value should fill the field. 
But got %v", an.Name) + } +} + +func TestAnonymousScanner(t *testing.T) { + user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} + DB.Save(&user) + + var user2 User + DB.First(&user2, "name = ?", "anonymous_scanner") + if user2.Role.Name != "admin" { + t.Errorf("Should be able to get anonymous scanner") + } + + if !user2.IsAdmin() { + t.Errorf("Should be able to get anonymous scanner") + } +} + +func TestAnonymousField(t *testing.T) { + user := User{Name: "anonymous_field", Company: Company{Name: "company"}} + DB.Save(&user) + + var user2 User + DB.First(&user2, "name = ?", "anonymous_field") + DB.Model(&user2).Related(&user2.Company) + if user2.Company.Name != "company" { + t.Errorf("Should be able to get anonymous field") + } +} + +func TestSelectWithCreate(t *testing.T) { + user := getPreparedUser("select_user", "select_with_create") + DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) + + var queryuser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) + + if queryuser.Name != user.Name || queryuser.Age == user.Age { + t.Errorf("Should only create users with name column") + } + + if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || + queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { + t.Errorf("Should only create selected relationships") + } +} + +func TestOmitWithCreate(t *testing.T) { + user := getPreparedUser("omit_user", "omit_with_create") + DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) + + var queryuser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) + + if queryuser.Name == user.Name || queryuser.Age != user.Age { + t.Errorf("Should only create users with age column") + } + + if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || + queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { + t.Errorf("Should not create omited relationships") + } +} diff --git a/vendor/github.com/jinzhu/gorm/customize_column_test.go b/vendor/github.com/jinzhu/gorm/customize_column_test.go new file mode 100644 index 000000000..177b4a5de --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/customize_column_test.go @@ -0,0 +1,280 @@ +package gorm_test + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date time.Time `gorm:"column:mapped_time"` +} + +// Make sure an ignored field does not interfere with another field's custom +// column name that matches the ignored field. 
+type CustomColumnAndIgnoredFieldClash struct { + Body string `sql:"-"` + RawBody string `gorm:"column:body"` +} + +func TestCustomizeColumn(t *testing.T) { + col := "mapped_name" + DB.DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + scope := DB.NewScope(&CustomizeColumn{}) + if !scope.Dialect().HasColumn(scope.TableName(), col) { + t.Errorf("CustomizeColumn should have column %s", col) + } + + col = "mapped_id" + if scope.PrimaryKey() != col { + t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) + } + + expected := "foo" + cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, 666) + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { + t.Errorf("Should not raise error: %s", err) + } +} + +type CustomizePerson struct { + IdPerson string `gorm:"column:idPerson;primary_key:true"` + Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` +} + +type CustomizeAccount struct { + IdAccount string `gorm:"column:idAccount;primary_key:true"` + Name string +} + +func TestManyToManyWithCustomizedColumn(t *testing.T) { + DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") + DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) + + account := CustomizeAccount{IdAccount: "account", Name: "id1"} + person := CustomizePerson{ + IdPerson: "person", + Accounts: []CustomizeAccount{account}, + } + + if err := DB.Create(&account).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if err := DB.Create(&person).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + var person1 CustomizePerson + scope := DB.NewScope(nil) + if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { + t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) + } + + if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { + t.Errorf("should preload correct accounts") + } +} + +type CustomizeUser struct { + gorm.Model + Email string `sql:"column:email_address"` +} + +type CustomizeInvitation struct { + gorm.Model + Address string `sql:"column:invitation"` + Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` +} + +func TestOneToOneWithCustomizedColumn(t *testing.T) { + DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) + DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) + + user := CustomizeUser{ + Email: "hello@example.com", + } + invitation := CustomizeInvitation{ + Address: "hello@example.com", + } + + DB.Create(&user) + DB.Create(&invitation) + + var invitation2 CustomizeInvitation + if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if invitation2.Person.Email != user.Email { + t.Errorf("Should preload one to one 
relation with customize foreign keys") + } +} + +type PromotionDiscount struct { + gorm.Model + Name string + Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` + Rule *PromotionRule `gorm:"ForeignKey:discount_id"` + Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` +} + +type PromotionBenefit struct { + gorm.Model + Name string + PromotionID uint + Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` +} + +type PromotionCoupon struct { + gorm.Model + Code string + DiscountID uint + Discount PromotionDiscount +} + +type PromotionRule struct { + gorm.Model + Name string + Begin *time.Time + End *time.Time + DiscountID uint + Discount *PromotionDiscount +} + +func TestOneToManyWithCustomizedColumn(t *testing.T) { + DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) + + discount := PromotionDiscount{ + Name: "Happy New Year", + Coupons: []*PromotionCoupon{ + {Code: "newyear1"}, + {Code: "newyear2"}, + }, + } + + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + var discount1 PromotionDiscount + if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if len(discount.Coupons) != 2 { + t.Errorf("should find two coupons") + } + + var coupon PromotionCoupon + if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if coupon.Discount.Name != "Happy New Year" { + t.Errorf("should preload discount from coupon") + } +} + +func TestHasOneWithPartialCustomizedColumn(t *testing.T) { + DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) + + var begin = time.Now() + var end = time.Now().Add(24 * time.Hour) + discount := PromotionDiscount{ + Name: "Happy New Year 2", + Rule: &PromotionRule{ + Name: "time_limited", + Begin: &begin, + End: &end, + }, + } + + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + var discount1 PromotionDiscount + if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { + t.Errorf("Should be able to preload Rule") + } + + var rule PromotionRule + if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if rule.Discount.Name != "Happy New Year 2" { + t.Errorf("should preload discount from rule") + } +} + +func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { + DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) + + discount := PromotionDiscount{ + Name: "Happy New Year 3", + Benefits: []PromotionBenefit{ + {Name: "free cod"}, + {Name: "free shipping"}, + }, + } + + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + var discount1 PromotionDiscount + if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if len(discount.Benefits) != 2 { + t.Errorf("should find two benefits") + } + + var benefit PromotionBenefit + if err := 
DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if benefit.Discount.Name != "Happy New Year 3" { + t.Errorf("should preload discount from coupon") + } +} diff --git a/vendor/github.com/jinzhu/gorm/delete_test.go b/vendor/github.com/jinzhu/gorm/delete_test.go new file mode 100644 index 000000000..d3de0a6d9 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/delete_test.go @@ -0,0 +1,68 @@ +package gorm_test + +import ( + "testing" + "time" +) + +func TestDelete(t *testing.T) { + user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} + DB.Save(&user1) + DB.Save(&user2) + + if err := DB.Delete(&user1).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } + + if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } + + if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { + t.Errorf("Other users that not deleted should be found-able") + } +} + +func TestInlineDelete(t *testing.T) { + user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} + DB.Save(&user1) + DB.Save(&user2) + + if DB.Delete(&User{}, user1.Id).Error != nil { + t.Errorf("No error should happen when delete a record") + } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } + + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } +} + +func TestSoftDelete(t *testing.T) { + type User struct { + Id int64 + Name string + DeletedAt *time.Time + } + DB.AutoMigrate(&User{}) + + user := User{Name: "soft_delete"} + DB.Save(&user) + DB.Delete(&user) + + if DB.First(&User{}, "name = ?", user.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) + } + + DB.Unscoped().Delete(&user) + if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { + t.Errorf("Can't find permanently deleted record") + } +} diff --git a/vendor/github.com/jinzhu/gorm/dialect.go b/vendor/github.com/jinzhu/gorm/dialect.go new file mode 100644 index 000000000..6c9405da3 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialect.go @@ -0,0 +1,100 @@ +package gorm + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" + "strings" +) + +// Dialect interface contains behaviors that differ across SQL database +type Dialect interface { + // GetName get dialect's name + GetName() string + + // SetDB set db for dialect + SetDB(db *sql.DB) + + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 + BindVar(i int) string + // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name + Quote(key string) string + // DataTypeOf return data's sql type + DataTypeOf(field *StructField) string + + // HasIndex check has index or not + HasIndex(tableName string, indexName string) bool + // HasForeignKey check has foreign key or not + HasForeignKey(tableName string, foreignKeyName string) bool + // RemoveIndex remove index + RemoveIndex(tableName string, 
indexName string) error + // HasTable check has table or not + HasTable(tableName string) bool + // HasColumn check has column or not + HasColumn(tableName string, columnName string) bool + + // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case + LimitAndOffsetSQL(limit, offset int) string + // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` + SelectFromDummyTable() string + // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` + LastInsertIDReturningSuffix(tableName, columnName string) string +} + +var dialectsMap = map[string]Dialect{} + +func newDialect(name string, db *sql.DB) Dialect { + if value, ok := dialectsMap[name]; ok { + dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) + dialect.SetDB(db) + return dialect + } + + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) + commontDialect := &commonDialect{} + commontDialect.SetDB(db) + return commontDialect +} + +// RegisterDialect register new dialect +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// ParseFieldStructForDialect parse field struct for dialect +func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { + // Get redirected field type + var reflectType = field.Struct.Type + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + + // Get redirected field value + fieldValue = reflect.Indirect(reflect.New(reflectType)) + + // Get scanner's real value + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } + } + getScannerValue(fieldValue) + + // Default Size + if num, ok := field.TagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } else { + size = 255 + } + + // Default type from tag setting + additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { + additionalType = additionalType + " DEFAULT " + value + } + + return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_common.go b/vendor/github.com/jinzhu/gorm/dialect_common.go new file mode 100644 index 000000000..f009271b3 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialect_common.go @@ -0,0 +1,137 @@ +package gorm + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +type commonDialect struct { + db *sql.DB +} + +func init() { + RegisterDialect("common", &commonDialect{}) +} + +func (commonDialect) GetName() string { + return "common" +} + +func (s *commonDialect) SetDB(db *sql.DB) { + s.db = db +} + +func (commonDialect) BindVar(i int) string { + return "$$" // ? 
+} + +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) +} + +func (commonDialect) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "BOOLEAN" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + sqlType = "INTEGER AUTO_INCREMENT" + } else { + sqlType = "INTEGER" + } + case reflect.Int64, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + sqlType = "BIGINT AUTO_INCREMENT" + } else { + sqlType = "BIGINT" + } + case reflect.Float32, reflect.Float64: + sqlType = "FLOAT" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("VARCHAR(%d)", size) + } else { + sqlType = "VARCHAR(65532)" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "TIMESTAMP" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("BINARY(%d)", size) + } else { + sqlType = "BINARY(65532)" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s commonDialect) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) + return count > 0 +} + +func (s commonDialect) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + return err +} + +func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { + return false +} + +func (s commonDialect) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) + return count > 0 +} + +func (s commonDialect) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? 
AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s commonDialect) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) + return +} + +func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { + if limit > 0 || offset > 0 { + if limit >= 0 { + sql += fmt.Sprintf(" LIMIT %d", limit) + } + if offset >= 0 { + sql += fmt.Sprintf(" OFFSET %d", offset) + } + } + return +} + +func (commonDialect) SelectFromDummyTable() string { + return "" +} + +func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/vendor/github.com/jinzhu/gorm/dialect_mysql.go new file mode 100644 index 000000000..6fade59d2 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialect_mysql.go @@ -0,0 +1,113 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type mysql struct { + commonDialect +} + +func init() { + RegisterDialect("mysql", &mysql{}) +} + +func (mysql) GetName() string { + return "mysql" +} + +func (mysql) Quote(key string) string { + return fmt.Sprintf("`%s`", key) +} + +// Get Data Type for MySQL Dialect +func (mysql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int AUTO_INCREMENT" + } else { + sqlType = "int" + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int unsigned AUTO_INCREMENT" + } else { + sqlType = "int unsigned" + } + case reflect.Int64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint AUTO_INCREMENT" + } else { + sqlType = "bigint" + } + case reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint unsigned AUTO_INCREMENT" + } else { + sqlType = "bigint unsigned" + } + case reflect.Float32, reflect.Float64: + sqlType = "double" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "longtext" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + if _, ok := field.TagSettings["NOT NULL"]; ok { + sqlType = "timestamp" + } else { + sqlType = "timestamp NULL" + } + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varbinary(%d)", size) + } else { + sqlType = "longblob" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s mysql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? 
AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s mysql) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) + return +} + +func (mysql) SelectFromDummyTable() string { + return "FROM DUAL" +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/vendor/github.com/jinzhu/gorm/dialect_postgres.go new file mode 100644 index 000000000..09ac59616 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialect_postgres.go @@ -0,0 +1,132 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type postgres struct { + commonDialect +} + +func init() { + RegisterDialect("postgres", &postgres{}) +} + +func (postgres) GetName() string { + return "postgres" +} + +func (postgres) BindVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (postgres) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "serial" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigserial" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "numeric" + case reflect.String: + if _, ok := field.TagSettings["SIZE"]; !ok { + size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different + } + + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "timestamp with time zone" + } + case reflect.Map: + if dataValue.Type().Name() == "Hstore" { + sqlType = "hstore" + } + default: + if isByteArrayOrSlice(dataValue) { + sqlType = "bytea" + } else if isUUID(dataValue) { + sqlType = "uuid" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s postgres) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + return count > 0 +} + +func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count) + return count > 0 +} + +func (s postgres) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + return count > 0 +} + +func (s postgres) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, 
columnName).Scan(&count) + return count > 0 +} + +func (s postgres) currentDatabase() (name string) { + s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) + return +} + +func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", tableName, key) +} + +func (postgres) SupportLastInsertID() bool { + return false +} + +func isByteArrayOrSlice(value reflect.Value) bool { + return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) +} + +func isUUID(value reflect.Value) bool { + if value.Kind() != reflect.Array || value.Type().Len() != 16 { + return false + } + typename := value.Type().Name() + lower := strings.ToLower(typename) + return "uuid" == lower || "guid" == lower +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go new file mode 100644 index 000000000..5c262aaf2 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go @@ -0,0 +1,106 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type sqlite3 struct { + commonDialect +} + +func init() { + RegisterDialect("sqlite", &sqlite3{}) + RegisterDialect("sqlite3", &sqlite3{}) +} + +func (sqlite3) GetName() string { + return "sqlite3" +} + +// Get Data Type for Sqlite Dialect +func (sqlite3) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "real" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetime" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + sqlType = "blob" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s sqlite3) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? 
AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) currentDatabase() (name string) { + var ( + ifaces = make([]interface{}, 3) + pointers = make([]*string, 3) + i int + ) + for i = 0; i < 3; i++ { + ifaces[i] = &pointers[i] + } + if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { + return + } + if pointers[1] != nil { + name = *pointers[1] + } + return +} diff --git a/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go b/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go new file mode 100644 index 000000000..adeeec7bf --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go @@ -0,0 +1,54 @@ +package postgres + +import ( + "database/sql" + "database/sql/driver" + + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" +) + +type Hstore map[string]*string + +// Value get value of Hstore +func (h Hstore) Value() (driver.Value, error) { + hstore := hstore.Hstore{Map: map[string]sql.NullString{}} + if len(h) == 0 { + return nil, nil + } + + for key, value := range h { + var s sql.NullString + if value != nil { + s.String = *value + s.Valid = true + } + hstore.Map[key] = s + } + return hstore.Value() +} + +// Scan scan value into Hstore +func (h *Hstore) Scan(value interface{}) error { + hstore := hstore.Hstore{} + + if err := hstore.Scan(value); err != nil { + return err + } + + if len(hstore.Map) == 0 { + return nil + } + + *h = Hstore{} + for k := range hstore.Map { + if hstore.Map[k].Valid { + s := hstore.Map[k].String + (*h)[k] = &s + } else { + (*h)[k] = nil + } + } + + return nil +} diff --git a/vendor/github.com/jinzhu/gorm/embedded_struct_test.go b/vendor/github.com/jinzhu/gorm/embedded_struct_test.go new file mode 100644 index 000000000..7be75d990 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/embedded_struct_test.go @@ -0,0 +1,48 @@ +package gorm_test + +import "testing" + +type BasePost struct { + Id int64 + Title string + URL string +} + +type HNPost struct { + BasePost + Upvotes int32 +} + +type EngadgetPost struct { + BasePost BasePost `gorm:"embedded"` + ImageUrl string +} + +func TestSaveAndQueryEmbeddedStruct(t *testing.T) { + DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) + var news HNPost + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if news.Title != "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + var egNews EngadgetPost + if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if egNews.BasePost.Title != "engadget_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + if DB.NewScope(&HNPost{}).PrimaryField() == nil { + t.Errorf("primary key with embedded struct should works") + } + + for _, field := range DB.NewScope(&HNPost{}).Fields() { + if field.Name == "BasePost" { + t.Errorf("scope Fields should not contain embedded struct") + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/errors.go b/vendor/github.com/jinzhu/gorm/errors.go new file mode 100644 index 000000000..ce3a25c0f --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/errors.go @@ -0,0 +1,58 @@ +package gorm + +import ( + 
"errors" + "strings" +) + +var ( + // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + ErrRecordNotFound = errors.New("record not found") + // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + ErrInvalidSQL = errors.New("invalid SQL") + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("no valid transaction") + // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` + ErrCantStartTransaction = errors.New("can't start transaction") + // ErrUnaddressable unaddressable value + ErrUnaddressable = errors.New("using unaddressable value") +) + +type errorsInterface interface { + GetErrors() []error +} + +// Errors contains all happened errors +type Errors struct { + errors []error +} + +// GetErrors get all happened errors +func (errs Errors) GetErrors() []error { + return errs.errors +} + +// Add add an error +func (errs *Errors) Add(err error) { + if errors, ok := err.(errorsInterface); ok { + for _, err := range errors.GetErrors() { + errs.Add(err) + } + } else { + for _, e := range errs.errors { + if err == e { + return + } + } + errs.errors = append(errs.errors, err) + } +} + +// Error format happened errors +func (errs Errors) Error() string { + var errors = []string{} + for _, e := range errs.errors { + errors = append(errors, e.Error()) + } + return strings.Join(errors, "; ") +} diff --git a/vendor/github.com/jinzhu/gorm/field.go b/vendor/github.com/jinzhu/gorm/field.go new file mode 100644 index 000000000..11c410b0f --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/field.go @@ -0,0 +1,58 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" +) + +// Field model field definition +type Field struct { + *StructField + IsBlank bool + Field reflect.Value +} + +// Set set a value to the field +func (field *Field) Set(value interface{}) (err error) { + if !field.Field.IsValid() { + return errors.New("field value not valid") + } + + if !field.Field.CanAddr() { + return ErrUnaddressable + } + + reflectValue, ok := value.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(value) + } + + fieldValue := field.Field + if reflectValue.IsValid() { + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else { + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.Struct.Type.Elem())) + } + fieldValue = fieldValue.Elem() + } + + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err = scanner.Scan(reflectValue.Interface()) + } else { + err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) + } + } + } else { + field.Field.Set(reflect.Zero(field.Field.Type())) + } + + field.IsBlank = isBlank(field.Field) + return err +} diff --git a/vendor/github.com/jinzhu/gorm/field_test.go b/vendor/github.com/jinzhu/gorm/field_test.go new file mode 100644 index 000000000..30e9a778d --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/field_test.go @@ -0,0 +1,49 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +type CalculateField struct { + gorm.Model + Name string + Children []CalculateFieldChild + Category CalculateFieldCategory + 
EmbeddedField +} + +type EmbeddedField struct { + EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` +} + +type CalculateFieldChild struct { + gorm.Model + CalculateFieldID uint + Name string +} + +type CalculateFieldCategory struct { + gorm.Model + CalculateFieldID uint + Name string +} + +func TestCalculateField(t *testing.T) { + var field CalculateField + var scope = DB.NewScope(&field) + if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { + t.Errorf("Should calculate fields correctly for the first time") + } + + if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { + t.Errorf("Should calculate fields correctly for the first time") + } + + if field, ok := scope.FieldByName("embedded_name"); !ok { + t.Errorf("should find embedded field") + } else if _, ok := field.TagSettings["NOT NULL"]; !ok { + t.Errorf("should find embedded field's tag settings") + } +} diff --git a/vendor/github.com/jinzhu/gorm/interface.go b/vendor/github.com/jinzhu/gorm/interface.go new file mode 100644 index 000000000..7b02aa664 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/interface.go @@ -0,0 +1,19 @@ +package gorm + +import "database/sql" + +type sqlCommon interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +type sqlDb interface { + Begin() (*sql.Tx, error) +} + +type sqlTx interface { + Commit() error + Rollback() error +} diff --git a/vendor/github.com/jinzhu/gorm/join_table_handler.go b/vendor/github.com/jinzhu/gorm/join_table_handler.go new file mode 100644 index 000000000..18c12a859 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/join_table_handler.go @@ -0,0 +1,204 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// JoinTableHandlerInterface is an interface for how to handle many2many relations +type JoinTableHandlerInterface interface { + // initialize join table handler + Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) + // Table return join table's table name + Table(db *DB) string + // Add create relationship in join table for source and destination + Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + // Delete delete relationship in join table for sources + Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + // JoinWith query with `Join` conditions + JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + // SourceForeignKeys return source foreign keys + SourceForeignKeys() []JoinTableForeignKey + // DestinationForeignKeys return destination foreign keys + DestinationForeignKeys() []JoinTableForeignKey +} + +// JoinTableForeignKey join table foreign key struct +type JoinTableForeignKey struct { + DBName string + AssociationDBName string +} + +// JoinTableSource is a struct that contains model type and foreign keys +type JoinTableSource struct { + ModelType reflect.Type + ForeignKeys []JoinTableForeignKey +} + +// JoinTableHandler default join table handler +type JoinTableHandler struct { + TableName string `sql:"-"` + Source JoinTableSource `sql:"-"` + Destination JoinTableSource `sql:"-"` +} + +// SourceForeignKeys return source foreign keys +func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { + return s.Source.ForeignKeys +} + +// DestinationForeignKeys 
return destination foreign keys +func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { + return s.Destination.ForeignKeys +} + +// Setup initialize a default join table handler +func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { + s.TableName = tableName + + s.Source = JoinTableSource{ModelType: source} + for idx, dbName := range relationship.ForeignFieldNames { + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ + DBName: relationship.ForeignDBNames[idx], + AssociationDBName: dbName, + }) + } + + s.Destination = JoinTableSource{ModelType: destination} + for idx, dbName := range relationship.AssociationForeignFieldNames { + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: relationship.AssociationForeignDBNames[idx], + AssociationDBName: dbName, + }) + } +} + +// Table return join table's table name +func (s JoinTableHandler) Table(db *DB) string { + return s.TableName +} + +func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { + values := map[string]interface{}{} + + for _, source := range sources { + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } + } + } + } + return values +} + +// Add create relationship in join table for source and destination +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { + scope := db.NewScope("") + searchMap := s.getSearchMap(db, source, destination) + + var assignColumns, binVars, conditions []string + var values []interface{} + for key, value := range searchMap { + assignColumns = append(assignColumns, scope.Quote(key)) + binVars = append(binVars, `?`) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + values = append(values, value) + } + + for _, value := range values { + values = append(values, value) + } + + quotedTable := scope.Quote(handler.Table(db)) + sql := fmt.Sprintf( + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", + quotedTable, + strings.Join(assignColumns, ","), + strings.Join(binVars, ","), + scope.Dialect().SelectFromDummyTable(), + quotedTable, + strings.Join(conditions, " AND "), + ) + + return db.Exec(sql, values...).Error +} + +// Delete delete relationship in join table for sources +func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { + var ( + scope = db.NewScope(nil) + conditions []string + values []interface{} + ) + + for key, value := range s.getSearchMap(db, sources...) 
{ + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + values = append(values, value) + } + + return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error +} + +// JoinWith query with `Join` conditions +func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { + var ( + scope = db.NewScope(source) + tableName = handler.Table(db) + quotedTableName = scope.Quote(tableName) + joinConditions []string + values []interface{} + ) + + if s.Source.ModelType == scope.GetModelStruct().ModelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + for _, foreignKey := range s.Destination.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + var foreignDBNames []string + var foreignFieldNames []string + + for _, foreignKey := range s.Source.ForeignKeys { + foreignDBNames = append(foreignDBNames, foreignKey.DBName) + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } + } + + foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) + + var condString string + if len(foreignFieldValues) > 0 { + var quotedForeignDBNames []string + for _, dbName := range foreignDBNames { + quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) + } + + condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) + + keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) + values = append(values, toQueryValues(keys)) + } else { + condString = fmt.Sprintf("1 <> 1") + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). + Where(condString, toQueryValues(foreignFieldValues)...) 
+ } + + db.Error = errors.New("wrong source type for join table handler") + return db +} diff --git a/vendor/github.com/jinzhu/gorm/join_table_test.go b/vendor/github.com/jinzhu/gorm/join_table_test.go new file mode 100644 index 000000000..1a83a9c87 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/join_table_test.go @@ -0,0 +1,72 @@ +package gorm_test + +import ( + "fmt" + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +type Person struct { + Id int + Name string + Addresses []*Address `gorm:"many2many:person_addresses;"` +} + +type PersonAddress struct { + gorm.JoinTableHandler + PersonID int + AddressID int + DeletedAt *time.Time + CreatedAt time.Time +} + +func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { + return db.Where(map[string]interface{}{ + "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), + "address_id": db.NewScope(associationValue).PrimaryKeyValue(), + }).Assign(map[string]interface{}{ + "person_id": foreignValue, + "address_id": associationValue, + "deleted_at": gorm.Expr("NULL"), + }).FirstOrCreate(&PersonAddress{}).Error +} + +func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { + return db.Delete(&PersonAddress{}).Error +} + +func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { + table := pa.Table(db) + return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) +} + +func TestJoinTable(t *testing.T) { + DB.Exec("drop table person_addresses;") + DB.AutoMigrate(&Person{}) + DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) + + address1 := &Address{Address1: "address 1"} + address2 := &Address{Address1: "address 2"} + person := &Person{Name: "person", Addresses: []*Address{address1, address2}} + DB.Save(person) + + DB.Model(person).Association("Addresses").Delete(address1) + + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { + t.Errorf("Should found one address") + } + + if DB.Model(person).Association("Addresses").Count() != 1 { + t.Errorf("Should found one address") + } + + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { + t.Errorf("Found two addresses with Unscoped") + } + + if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { + t.Errorf("Should deleted all addresses") + } +} diff --git a/vendor/github.com/jinzhu/gorm/logger.go b/vendor/github.com/jinzhu/gorm/logger.go new file mode 100644 index 000000000..2c4ccbbc4 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/logger.go @@ -0,0 +1,99 @@ +package gorm + +import ( + "database/sql/driver" + "fmt" + "log" + "os" + "reflect" + "regexp" + "time" + "unicode" +) + +var ( + defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} + sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) +) + +type logger interface { + Print(v ...interface{}) +} + +// LogWriter log writer interface +type LogWriter interface { + Println(v ...interface{}) +} + +// Logger default logger +type Logger struct { + LogWriter +} + +// Print format & print log +func (logger Logger) Print(values ...interface{}) { + if len(values) > 1 { + level := values[0] + currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" + source := 
fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) + messages := []interface{}{source, currentTime} + + if level == "sql" { + // duration + messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) + // sql + var sql string + var formattedValues []string + + for _, value := range values[4].([]interface{}) { + indirectValue := reflect.Indirect(reflect.ValueOf(value)) + if indirectValue.IsValid() { + value = indirectValue.Interface() + if t, ok := value.(time.Time); ok { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) + } else if b, ok := value.([]byte); ok { + if str := string(b); isPrintable(str) { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) + } else { + formattedValues = append(formattedValues, "''") + } + } else if r, ok := value.(driver.Valuer); ok { + if value, err := r.Value(); err == nil && value != nil { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } else { + formattedValues = append(formattedValues, "NULL") + } + } else { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } + } else { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } + } + + var formattedValuesLength = len(formattedValues) + for index, value := range sqlRegexp.Split(values[3].(string), -1) { + sql += value + if index < formattedValuesLength { + sql += formattedValues[index] + } + } + + messages = append(messages, sql) + } else { + messages = append(messages, "\033[31;1m") + messages = append(messages, values[2:]...) + messages = append(messages, "\033[0m") + } + logger.Println(messages...) + } +} + +func isPrintable(s string) bool { + for _, r := range s { + if !unicode.IsPrint(r) { + return false + } + } + return true +} diff --git a/vendor/github.com/jinzhu/gorm/main.go b/vendor/github.com/jinzhu/gorm/main.go new file mode 100644 index 000000000..cd4455551 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/main.go @@ -0,0 +1,700 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +// DB contains information for current db connection +type DB struct { + Value interface{} + Error error + RowsAffected int64 + callbacks *Callback + db sqlCommon + parent *DB + search *search + logMode int + logger logger + dialect Dialect + singularTable bool + source string + values map[string]interface{} + joinTableHandlers map[string]JoinTableHandler +} + +// Open initialize a new db connection, need to import driver first, e.g: +// +// import _ "github.com/go-sql-driver/mysql" +// func main() { +// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// } +// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with +// import _ "github.com/jinzhu/gorm/dialects/mysql" +// // import _ "github.com/jinzhu/gorm/dialects/postgres" +// // import _ "github.com/jinzhu/gorm/dialects/sqlite" +// // import _ "github.com/jinzhu/gorm/dialects/mssql" +func Open(dialect string, args ...interface{}) (*DB, error) { + var db DB + var err error + + if len(args) == 0 { + err = errors.New("invalid database source") + } else { + var source string + var dbSQL sqlCommon + + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) + } + dbSQL, err = sql.Open(driver, 
source) + case sqlCommon: + source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() + dbSQL = value + } + + db = DB{ + dialect: newDialect(dialect, dbSQL.(*sql.DB)), + logger: defaultLogger, + callbacks: DefaultCallback, + source: source, + values: map[string]interface{}{}, + db: dbSQL, + } + db.parent = &db + + if err == nil { + err = db.DB().Ping() // Send a ping to make sure the database connection is alive. + } + } + + return &db, err +} + +// Close close current db connection +func (s *DB) Close() error { + return s.parent.db.(*sql.DB).Close() +} + +// DB get `*sql.DB` from current connection +func (s *DB) DB() *sql.DB { + return s.db.(*sql.DB) +} + +// New clone a new db connection without search conditions +func (s *DB) New() *DB { + clone := s.clone() + clone.search = nil + clone.Value = nil + return clone +} + +// NewScope create a scope for current operation +func (s *DB) NewScope(value interface{}) *Scope { + dbClone := s.clone() + dbClone.Value = value + return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} +} + +// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db +} + +// Callback return `Callbacks` container, you could add/change/delete callbacks with it +// db.Callback().Create().Register("update_created_at", updateCreated) +// Refer https://jinzhu.github.io/gorm/development.html#callbacks +func (s *DB) Callback() *Callback { + s.parent.callbacks = s.parent.callbacks.clone() + return s.parent.callbacks +} + +// SetLogger replace default logger +func (s *DB) SetLogger(log logger) { + s.logger = log +} + +// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs +func (s *DB) LogMode(enable bool) *DB { + if enable { + s.logMode = 2 + } else { + s.logMode = 1 + } + return s +} + +// SingularTable use singular table by default +func (s *DB) SingularTable(enable bool) { + modelStructsMap = newModelStructsMap() + s.parent.singularTable = enable +} + +// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query +func (s *DB) Where(query interface{}, args ...interface{}) *DB { + return s.clone().search.Where(query, args...).db +} + +// Or filter records that match before conditions or this one, similar to `Where` +func (s *DB) Or(query interface{}, args ...interface{}) *DB { + return s.clone().search.Or(query, args...).db +} + +// Not filter records that don't match current conditions, similar to `Where` +func (s *DB) Not(query interface{}, args ...interface{}) *DB { + return s.clone().search.Not(query, args...).db +} + +// Limit specify the number of records to be retrieved +func (s *DB) Limit(limit int) *DB { + return s.clone().search.Limit(limit).db +} + +// Offset specify the number of records to skip before starting to return the records +func (s *DB) Offset(offset int) *DB { + return s.clone().search.Offset(offset).db +} + +// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions +func (s *DB) Order(value string, reorder ...bool) *DB { + return s.clone().search.Order(value, reorder...).db +} + +// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; +// When creating/updating, specify fields that you want to save to database +func (s 
*DB) Select(query interface{}, args ...interface{}) *DB { + return s.clone().search.Select(query, args...).db +} + +// Omit specify fields that you want to ignore when saving to database for creating, updating +func (s *DB) Omit(columns ...string) *DB { + return s.clone().search.Omit(columns...).db +} + +// Group specify the group method on the find +func (s *DB) Group(query string) *DB { + return s.clone().search.Group(query).db +} + +// Having specify HAVING conditions for GROUP BY +func (s *DB) Having(query string, values ...interface{}) *DB { + return s.clone().search.Having(query, values...).db +} + +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (s *DB) Joins(query string, args ...interface{}) *DB { + return s.clone().search.Joins(query, args...).db +} + +// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Refer https://jinzhu.github.io/gorm/curd.html#scopes +func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { + for _, f := range funcs { + s = f(s) + } + return s +} + +// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete +func (s *DB) Unscoped() *DB { + return s.clone().search.unscoped().db +} + +// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate +func (s *DB) Attrs(attrs ...interface{}) *DB { + return s.clone().search.Attrs(attrs...).db +} + +// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate +func (s *DB) Assign(attrs ...interface{}) *DB { + return s.clone().search.Assign(attrs...).db +} + +// First find first record that match given conditions, order by primary key +func (s *DB) First(out interface{}, where ...interface{}) *DB { + newScope := s.clone().NewScope(out) + newScope.Search.Limit(1) + return newScope.Set("gorm:order_by_primary_key", "ASC"). + inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + +// Last find last record that match given conditions, order by primary key +func (s *DB) Last(out interface{}, where ...interface{}) *DB { + newScope := s.clone().NewScope(out) + newScope.Search.Limit(1) + return newScope.Set("gorm:order_by_primary_key", "DESC"). 
+ inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + +// Find find records that match given conditions +func (s *DB) Find(out interface{}, where ...interface{}) *DB { + return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + +// Scan scan value to a struct +func (s *DB) Scan(dest interface{}) *DB { + return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db +} + +// Row return `*sql.Row` with given conditions +func (s *DB) Row() *sql.Row { + return s.NewScope(s.Value).row() +} + +// Rows return `*sql.Rows` with given conditions +func (s *DB) Rows() (*sql.Rows, error) { + return s.NewScope(s.Value).rows() +} + +// ScanRows scan `*sql.Rows` to give struct +func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { + var ( + clone = s.clone() + scope = clone.NewScope(result) + columns, err = rows.Columns() + ) + + if clone.AddError(err) == nil { + scope.scan(rows, columns, scope.Fields()) + } + + return clone.Error +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) +func (s *DB) Pluck(column string, value interface{}) *DB { + return s.NewScope(s.Value).pluck(column, value).db +} + +// Count get how many records for a model +func (s *DB) Count(value interface{}) *DB { + return s.NewScope(s.Value).count(value).db +} + +// Related get related associations +func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { + return s.clone().NewScope(s.Value).related(value, foreignKeys...).db +} + +// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) +// https://jinzhu.github.io/gorm/curd.html#firstorinit +func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { + c := s.clone() + if result := c.First(out, where...); result.Error != nil { + if !result.RecordNotFound() { + return result + } + c.NewScope(out).inlineCondition(where...).initialize() + } else { + c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) + } + return c +} + +// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) +// https://jinzhu.github.io/gorm/curd.html#firstorcreate +func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { + c := s.clone() + if result := c.First(out, where...); result.Error != nil { + if !result.RecordNotFound() { + return result + } + c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error) + } else if len(c.search.assignAttrs) > 0 { + c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error) + } + return c +} + +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +func (s *DB) Update(attrs ...interface{}) *DB { + return s.Updates(toSearchableMap(attrs...), true) +} + +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { + return s.clone().NewScope(s.Value). + Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). + InstanceSet("gorm:update_interface", values). 
+ callCallbacks(s.parent.callbacks.updates).db +} + +// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +func (s *DB) UpdateColumn(attrs ...interface{}) *DB { + return s.UpdateColumns(toSearchableMap(attrs...)) +} + +// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +func (s *DB) UpdateColumns(values interface{}) *DB { + return s.clone().NewScope(s.Value). + Set("gorm:update_column", true). + Set("gorm:save_associations", false). + InstanceSet("gorm:update_interface", values). + callCallbacks(s.parent.callbacks.updates).db +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (s *DB) Save(value interface{}) *DB { + scope := s.clone().NewScope(value) + if scope.PrimaryKeyZero() { + return scope.callCallbacks(s.parent.callbacks.creates).db + } + return scope.callCallbacks(s.parent.callbacks.updates).db +} + +// Create insert the value into database +func (s *DB) Create(value interface{}) *DB { + scope := s.clone().NewScope(value) + return scope.callCallbacks(s.parent.callbacks.creates).db +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (s *DB) Delete(value interface{}, where ...interface{}) *DB { + return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db +} + +// Raw use raw sql as conditions, won't run it unless invoked by other methods +// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) +func (s *DB) Raw(sql string, values ...interface{}) *DB { + return s.clone().search.Raw(true).Where(sql, values...).db +} + +// Exec execute raw sql +func (s *DB) Exec(sql string, values ...interface{}) *DB { + scope := s.clone().NewScope(nil) + generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") + scope.Raw(generatedSQL) + return scope.Exec().db +} + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (s *DB) Model(value interface{}) *DB { + c := s.clone() + c.Value = value + return c +} + +// Table specify the table you would like to run db operations +func (s *DB) Table(name string) *DB { + clone := s.clone() + clone.search.Table(name) + clone.Value = nil + return clone +} + +// Debug start debug mode +func (s *DB) Debug() *DB { + return s.clone().LogMode(true) +} + +// Begin begin a transaction +func (s *DB) Begin() *DB { + c := s.clone() + if db, ok := c.db.(sqlDb); ok { + tx, err := db.Begin() + c.db = interface{}(tx).(sqlCommon) + c.AddError(err) + } else { + c.AddError(ErrCantStartTransaction) + } + return c +} + +// Commit commit a transaction +func (s *DB) Commit() *DB { + if db, ok := s.db.(sqlTx); ok { + s.AddError(db.Commit()) + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + +// Rollback rollback a transaction +func (s *DB) Rollback() *DB { + if db, ok := s.db.(sqlTx); ok { + s.AddError(db.Rollback()) + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + +// NewRecord check if value's primary key is blank +func (s *DB) NewRecord(value interface{}) bool { + return 
s.clone().NewScope(value).PrimaryKeyZero() +} + +// RecordNotFound check if returning ErrRecordNotFound error +func (s *DB) RecordNotFound() bool { + for _, err := range s.GetErrors() { + if err == ErrRecordNotFound { + return true + } + } + return false +} + +// CreateTable create table for models +func (s *DB) CreateTable(models ...interface{}) *DB { + db := s.Unscoped() + for _, model := range models { + db = db.NewScope(model).createTable().db + } + return db +} + +// DropTable drop table for models +func (s *DB) DropTable(values ...interface{}) *DB { + db := s.clone() + for _, value := range values { + if tableName, ok := value.(string); ok { + db = db.Table(tableName) + } + + db = db.NewScope(value).dropTable().db + } + return db +} + +// DropTableIfExists drop table if it is exist +func (s *DB) DropTableIfExists(values ...interface{}) *DB { + db := s.clone() + for _, value := range values { + if s.HasTable(value) { + db.AddError(s.DropTable(value).Error) + } + } + return db +} + +// HasTable check has table or not +func (s *DB) HasTable(value interface{}) bool { + var ( + scope = s.clone().NewScope(value) + tableName string + ) + + if name, ok := value.(string); ok { + tableName = name + } else { + tableName = scope.TableName() + } + + has := scope.Dialect().HasTable(tableName) + s.AddError(scope.db.Error) + return has +} + +// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data +func (s *DB) AutoMigrate(values ...interface{}) *DB { + db := s.Unscoped() + for _, value := range values { + db = db.NewScope(value).autoMigrate().db + } + return db +} + +// ModifyColumn modify column to type +func (s *DB) ModifyColumn(column string, typ string) *DB { + scope := s.clone().NewScope(s.Value) + scope.modifyColumn(column, typ) + return scope.db +} + +// DropColumn drop a column +func (s *DB) DropColumn(column string) *DB { + scope := s.clone().NewScope(s.Value) + scope.dropColumn(column) + return scope.db +} + +// AddIndex add index for columns with given name +func (s *DB) AddIndex(indexName string, columns ...string) *DB { + scope := s.Unscoped().NewScope(s.Value) + scope.addIndex(false, indexName, columns...) + return scope.db +} + +// AddUniqueIndex add unique index for columns with given name +func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { + scope := s.Unscoped().NewScope(s.Value) + scope.addIndex(true, indexName, columns...) 
+ return scope.db +} + +// RemoveIndex remove index with name +func (s *DB) RemoveIndex(indexName string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeIndex(indexName) + return scope.db +} + +// AddForeignKey Add foreign key to the given scope, e.g: +// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") +func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { + scope := s.clone().NewScope(s.Value) + scope.addForeignKey(field, dest, onDelete, onUpdate) + return scope.db +} + +// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode +func (s *DB) Association(column string) *Association { + var err error + scope := s.clone().NewScope(s.Value) + + if primaryField := scope.PrimaryField(); primaryField.IsBlank { + err = errors.New("primary key can't be nil") + } else { + if field, ok := scope.FieldByName(column); ok { + if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { + err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) + } else { + return &Association{scope: scope, column: column, field: field} + } + } else { + err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) + } + } + + return &Association{Error: err} +} + +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (s *DB) Preload(column string, conditions ...interface{}) *DB { + return s.clone().search.Preload(column, conditions...).db +} + +// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting +func (s *DB) Set(name string, value interface{}) *DB { + return s.clone().InstantSet(name, value) +} + +// InstantSet instant set setting, will affect current db +func (s *DB) InstantSet(name string, value interface{}) *DB { + s.values[name] = value + return s +} + +// Get get setting by name +func (s *DB) Get(name string) (value interface{}, ok bool) { + value, ok = s.values[name] + return +} + +// SetJoinTableHandler set a model's join table handler for a relation +func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + scope := s.NewScope(source) + for _, field := range scope.GetModelStruct().StructFields { + if field.Name == column || field.DBName == column { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + source := (&Scope{Value: source}).GetModelStruct().ModelType + destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType + handler.Setup(field.Relationship, many2many, source, destination) + field.Relationship.JoinTableHandler = handler + if table := handler.Table(s); scope.Dialect().HasTable(table) { + s.Table(table).AutoMigrate(handler) + } + } + } + } +} + +// AddError add error to the db +func (s *DB) AddError(err error) error { + if err != nil { + if err != ErrRecordNotFound { + if s.logMode == 0 { + go s.print(fileWithLineNum(), err) + } else { + s.log(err) + } + + errors := Errors{errors: s.GetErrors()} + errors.Add(err) + if len(errors.GetErrors()) > 1 { + err = errors + } + } + + s.Error = err + } + return err +} + +// GetErrors get happened errors from the db +func (s *DB) GetErrors() (errors []error) { + if errs, ok := s.Error.(errorsInterface); ok { + return errs.GetErrors() + } else if s.Error != nil { + return 
[]error{s.Error} + } + return +} + +//////////////////////////////////////////////////////////////////////////////// +// Private Methods For *gorm.DB +//////////////////////////////////////////////////////////////////////////////// + +func (s *DB) clone() *DB { + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + + for key, value := range s.values { + db.values[key] = value + } + + if s.search == nil { + db.search = &search{limit: -1, offset: -1} + } else { + db.search = s.search.clone() + } + + db.search.db = &db + return &db +} + +func (s *DB) print(v ...interface{}) { + s.logger.(logger).Print(v...) +} + +func (s *DB) log(v ...interface{}) { + if s != nil && s.logMode == 2 { + s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) + } +} + +func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { + if s.logMode == 2 { + s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) + } +} diff --git a/vendor/github.com/jinzhu/gorm/main_test.go b/vendor/github.com/jinzhu/gorm/main_test.go new file mode 100644 index 000000000..8ac015c8d --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/main_test.go @@ -0,0 +1,774 @@ +package gorm_test + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "os" + "reflect" + "strconv" + "testing" + "time" + + "github.com/erikstmartin/go-testdb" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mssql" + _ "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/dialects/postgres" + _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/now" +) + +var ( + DB *gorm.DB + t1, t2, t3, t4, t5 time.Time +) + +func init() { + var err error + + if DB, err = OpenTestConnection(); err != nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) + } + + // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) + // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) + if os.Getenv("DEBUG") == "true" { + DB.LogMode(true) + } + + DB.DB().SetMaxIdleConns(10) + + runMigration() +} + +func OpenTestConnection() (db *gorm.DB, err error) { + switch os.Getenv("GORM_DIALECT") { + case "mysql": + // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; + // CREATE DATABASE gorm; + // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; + fmt.Println("testing mysql...") + db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") + case "postgres": + fmt.Println("testing postgres...") + db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "foundation": + fmt.Println("testing foundation...") + db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + case "mssql": + fmt.Println("testing mssql...") + db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") + default: + fmt.Println("testing sqlite3...") + db, err = gorm.Open("sqlite3", "/tmp/gorm.db") + } + return +} + +func TestStringPrimaryKey(t *testing.T) { + type UUIDStruct struct { + ID string `gorm:"primary_key"` + Name string + } + DB.AutoMigrate(&UUIDStruct{}) + + data := UUIDStruct{ID: "uuid", Name: "hello"} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { + t.Errorf("string primary key should not be populated") + } +} + +func TestExceptionsWithInvalidSql(t *testing.T) { + var columns []string + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with 
invalid SQL") + } + + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + var count1, count2 int64 + DB.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + DB.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("No user should not be deleted by invalid SQL") + } +} + +func TestSetTable(t *testing.T) { + DB.Create(getPreparedUser("pluck_user1", "pluck_user")) + DB.Create(getPreparedUser("pluck_user2", "pluck_user")) + DB.Create(getPreparedUser("pluck_user3", "pluck_user")) + + if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { + t.Error("No errors should happen if set table for pluck", err) + } + + var users []User + if DB.Table("users").Find(&[]User{}).Error != nil { + t.Errorf("No errors should happen if set table for find") + } + + if DB.Table("invalid_table").Find(&users).Error == nil { + t.Errorf("Should got error when table is set to an invalid table") + } + + DB.Exec("drop table deleted_users;") + if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { + t.Errorf("Create table with specified table") + } + + DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) + + var deletedUsers []User + DB.Table("deleted_users").Find(&deletedUsers) + if len(deletedUsers) != 1 { + t.Errorf("Query from specified table") + } + + DB.Save(getPreparedUser("normal_user", "reset_table")) + DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) + var user1, user2, user3 User + DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) + if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { + t.Errorf("unset specified table with blank string") + } +} + +type Order struct { +} + +type Cart struct { +} + +func (c Cart) TableName() string { + return "shopping_cart" +} + +func TestHasTable(t *testing.T) { + type Foo struct { + Id int + Stuff string + } + DB.DropTable(&Foo{}) + + // Table should not exist at this point, HasTable should return false + if ok := DB.HasTable("foos"); ok { + t.Errorf("Table should not exist, but does") + } + if ok := DB.HasTable(&Foo{}); ok { + t.Errorf("Table should not exist, but does") + } + + // We create the table + if err := DB.CreateTable(&Foo{}).Error; err != nil { + t.Errorf("Table should be created") + } + + // And now it should exits, and HasTable should return true + if ok := DB.HasTable("foos"); !ok { + t.Errorf("Table should exist, but HasTable informs it does not") + } + if ok := DB.HasTable(&Foo{}); !ok { + t.Errorf("Table should exist, but HasTable informs it does not") + } +} + +func TestTableName(t *testing.T) { + DB := DB.Model("") + if DB.NewScope(Order{}).TableName() != "orders" { + t.Errorf("Order's table name should be orders") + } + + if DB.NewScope(&Order{}).TableName() != "orders" { + t.Errorf("&Order's table name should be orders") + } + + if DB.NewScope([]Order{}).TableName() != "orders" { + t.Errorf("[]Order's table name should be orders") + } + + if DB.NewScope(&[]Order{}).TableName() != "orders" { + t.Errorf("&[]Order's table name should be orders") + } + 
+ DB.SingularTable(true) + if DB.NewScope(Order{}).TableName() != "order" { + t.Errorf("Order's singular table name should be order") + } + + if DB.NewScope(&Order{}).TableName() != "order" { + t.Errorf("&Order's singular table name should be order") + } + + if DB.NewScope([]Order{}).TableName() != "order" { + t.Errorf("[]Order's singular table name should be order") + } + + if DB.NewScope(&[]Order{}).TableName() != "order" { + t.Errorf("&[]Order's singular table name should be order") + } + + if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { + t.Errorf("&Cart's singular table name should be shopping_cart") + } + + if DB.NewScope(Cart{}).TableName() != "shopping_cart" { + t.Errorf("Cart's singular table name should be shopping_cart") + } + + if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { + t.Errorf("&[]Cart's singular table name should be shopping_cart") + } + + if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { + t.Errorf("[]Cart's singular table name should be shopping_cart") + } + DB.SingularTable(false) +} + +func TestNullValues(t *testing.T) { + DB.DropTable(&NullValue{}) + DB.AutoMigrate(&NullValue{}) + + if err := DB.Save(&NullValue{ + Name: sql.NullString{String: "hello", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 100.11, Valid: true}, + AddedAt: NullTime{Time: time.Now(), Valid: true}, + }).Error; err != nil { + t.Errorf("Not error should raise when test null value") + } + + var nv NullValue + DB.First(&nv, "name = ?", "hello") + + if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { + t.Errorf("Should be able to fetch null value") + } + + if err := DB.Save(&NullValue{ + Name: sql.NullString{String: "hello-2", Valid: true}, + Gender: &sql.NullString{String: "F", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: false}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 100.11, Valid: true}, + AddedAt: NullTime{Time: time.Now(), Valid: false}, + }).Error; err != nil { + t.Errorf("Not error should raise when test null value") + } + + var nv2 NullValue + DB.First(&nv2, "name = ?", "hello-2") + if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { + t.Errorf("Should be able to fetch null value") + } + + if err := DB.Save(&NullValue{ + Name: sql.NullString{String: "hello-3", Valid: false}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: false}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 100.11, Valid: true}, + AddedAt: NullTime{Time: time.Now(), Valid: false}, + }).Error; err == nil { + t.Errorf("Can't save because of name can't be null") + } +} + +func TestNullValuesWithFirstOrCreate(t *testing.T) { + var nv1 = NullValue{ + Name: sql.NullString{String: "first_or_create", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + } + + var nv2 NullValue + if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { + t.Errorf("first or create with nullvalues") + } + + if err := DB.Where(nv1).Assign(NullValue{Age: 
sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if nv2.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } +} + +func TestTransaction(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + tx2 := DB.Begin() + u2 := User{Name: "transcation-2"} + if err := tx2.Save(&u2).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx2.Commit() + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } +} + +func TestRow(t *testing.T) { + user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() + var age int64 + row.Scan(&age) + if age != 10 { + t.Errorf("Scan with Row") + } +} + +func TestRows(t *testing.T) { + user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + count := 0 + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } + + if count != 2 { + t.Errorf("Should found two records") + } +} + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") + } +} + +func TestScan(t *testing.T) { + user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Age int + } + + var res result + DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) + if res.Name != user3.Name { + t.Errorf("Scan into struct should work") + } + + var doubleAgeRes result + DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes) + if doubleAgeRes.Age != res.Age*2 { + t.Errorf("Scan double age as age") + } + + var ress []result + DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) + if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + t.Errorf("Scan into struct map") + } +} + +func TestRaw(t *testing.T) { + user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Email string + } + + var ress []result + DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) + if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + t.Errorf("Raw with scan") + } + + rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() + count := 0 + for rows.Next() { + count++ + } + if count != 1 { + t.Errorf("Raw with Rows should find one record with name 3") + } + + DB.Exec("update users set name=? 
where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + t.Error("Raw sql to update records") + } +} + +func TestGroup(t *testing.T) { + rows, err := DB.Select("name").Table("users").Group("name").Rows() + + if err == nil { + defer rows.Close() + for rows.Next() { + var name string + rows.Scan(&name) + } + } else { + t.Errorf("Should not raise any error") + } +} + +func TestJoins(t *testing.T) { + var user = User{ + Name: "joins", + CreditCard: CreditCard{Number: "411111111111"}, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users1 []User + DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) + if len(users1) != 2 { + t.Errorf("should find two users using left join") + } + + var users2 []User + DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions") + } + + var users3 []User + DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions") + } + + var users4 []User + DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card") + } +} + +func TestJoinsWithSelect(t *testing.T) { + type result struct { + Name string + Email string + } + + user := User{ + Name: "joins_with_select", + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var results []result + DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) + if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { + t.Errorf("Should find all two emails with Join select") + } +} + +func TestHaving(t *testing.T) { + rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() + + if err == nil { + defer rows.Close() + for rows.Next() { + var name string + var total int64 + rows.Scan(&name, &total) + + if name == "2" && total != 1 { + t.Errorf("Should have one user having name 2") + } + if name == "3" && total != 2 { + t.Errorf("Should have two users having name 3") + } + } + } else { + t.Errorf("Should not raise any error") + } +} + +func DialectHasTzSupport() bool { + // NB: mssql and FoundationDB do not support time zones. 
+ if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { + return false + } + return true +} + +func TestTimeWithZone(t *testing.T) { + var format = "2006-01-02 15:04:05 -0700" + var times []time.Time + GMT8, _ := time.LoadLocation("Asia/Shanghai") + times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) + times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) + + for index, vtime := range times { + name := "time_with_zone_" + strconv.Itoa(index) + user := User{Name: name, Birthday: vtime} + + if !DialectHasTzSupport() { + // If our driver dialect doesn't support TZ's, just use UTC for everything here. + user.Birthday = vtime.UTC() + } + + DB.Save(&user) + expectedBirthday := "2013-02-18 17:51:49 +0000" + foundBirthday := user.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) + } + + var findUser, findUser2, findUser3 User + DB.First(&findUser, "name = ?", name) + foundBirthday = findUser.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) + } + + if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + t.Errorf("User should be found") + } + + if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { + t.Errorf("User should not be found") + } + } +} + +func TestHstore(t *testing.T) { + type Details struct { + Id int64 + Bulk postgres.Hstore + } + + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { + t.Skip() + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { + fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") + panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) + } + + DB.Exec("drop table details") + + if err := DB.CreateTable(&Details{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + + bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" + bulk := map[string]*string{ + "bankAccountId": &bankAccountId, + "phoneNumber": &phoneNumber, + "opinion": &opinion, + } + d := Details{Bulk: bulk} + DB.Save(&d) + + var d2 Details + if err := DB.First(&d2).Error; err != nil { + t.Errorf("Got error when tried to fetch details: %+v", err) + } + + for k := range bulk { + if r, ok := d2.Bulk[k]; ok { + if res, _ := bulk[k]; *res != *r { + t.Errorf("Details should be equal") + } + } else { + t.Errorf("Details should be existed") + } + } +} + +func TestSetAndGet(t *testing.T) { + if value, ok := DB.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else { + if value.(string) != "world" { + t.Errorf("Setted value should not be changed") + } + } + + if _, ok := DB.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} + +func TestCompatibilityMode(t *testing.T) { + DB, _ := gorm.Open("testdb", "") + testdb.SetQueryFunc(func(query string) (driver.Rows, error) { + columns := []string{"id", "name", "age"} + result := ` + 1,Tim,20 + 2,Joe,25 + 
3,Bob,30 + ` + return testdb.RowsFromCSVString(columns, result), nil + }) + + var users []User + DB.Find(&users) + if (users[0].Name != "Tim") || len(users) != 3 { + t.Errorf("Unexcepted result returned") + } +} + +func TestOpenExistingDB(t *testing.T) { + DB.Save(&User{Name: "jnfeinstein"}) + dialect := os.Getenv("GORM_DIALECT") + + db, err := gorm.Open(dialect, DB.DB()) + if err != nil { + t.Errorf("Should have wrapped the existing DB connection") + } + + var user User + if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { + t.Errorf("Should have found existing record") + } +} + +func TestDdlErrors(t *testing.T) { + var err error + + if err = DB.Close(); err != nil { + t.Errorf("Closing DDL test db connection err=%s", err) + } + defer func() { + // Reopen DB connection. + if DB, err = OpenTestConnection(); err != nil { + t.Fatalf("Failed re-opening db connection: %s", err) + } + }() + + if err := DB.Find(&User{}).Error; err == nil { + t.Errorf("Expected operation on closed db to produce an error, but err was nil") + } +} + +func BenchmarkGorm(b *testing.B) { + b.N = 2000 + for x := 0; x < b.N; x++ { + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} + // Insert + DB.Save(&email) + // Query + DB.First(&BigEmail{}, "email = ?", e) + // Update + DB.Model(&email).UpdateColumn("email", "new-"+e) + // Delete + DB.Delete(&email) + } +} + +func BenchmarkRawSql(b *testing.B) { + DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") + DB.SetMaxIdleConns(10) + insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" + querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" + updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" + deleteSql := "DELETE FROM orders WHERE id = $1" + + b.N = 2000 + for x := 0; x < b.N; x++ { + var id int64 + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} + // Insert + DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) + // Query + rows, _ := DB.Query(querySql, email.Email) + rows.Close() + // Update + DB.Exec(updateSql, "new-"+e, time.Now(), id) + // Delete + DB.Exec(deleteSql, id) + } +} diff --git a/vendor/github.com/jinzhu/gorm/migration_test.go b/vendor/github.com/jinzhu/gorm/migration_test.go new file mode 100644 index 000000000..38e5c1c2e --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/migration_test.go @@ -0,0 +1,349 @@ +package gorm_test + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +type User struct { + Id int64 + Age int64 + UserNum Num + Name string `sql:"size:255"` + Email string + Birthday time.Time // Time + CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically + UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically + Emails []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressID sql.NullInt64 // Embedded struct's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct's foreign key + CreditCard CreditCard + Latitude float64 + Languages []Language `gorm:"many2many:user_languages;"` + CompanyID *int + Company Company + Role + PasswordHash 
[]byte + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` + Ignored struct{ Name string } `sql:"-"` + IgnoredPointer *User `sql:"-"` +} + +type CreditCard struct { + ID int8 + Number string + UserId sql.NullInt64 + CreatedAt time.Time `sql:"not null"` + UpdatedAt time.Time + DeletedAt *time.Time +} + +type Email struct { + Id int16 + UserId int + Email string `sql:"type:varchar(100);"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type Address struct { + ID int + Address1 string + Address2 string + Post string + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type Language struct { + gorm.Model + Name string + Users []User `gorm:"many2many:user_languages;"` +} + +type Product struct { + Id int64 + Code string + Price int64 + CreatedAt time.Time + UpdatedAt time.Time + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +type Company struct { + Id int64 + Name string + Owner *User `sql:"-"` +} + +type Role struct { + Name string +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type Animal struct { + Counter uint64 `gorm:"primary_key:yes"` + Name string `sql:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age time.Time `sql:"DEFAULT:current_timestamp"` + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time +} + +type JoinTable struct { + From uint64 + To uint64 + Time time.Time `sql:"default: null"` +} + +type Post struct { + Id int64 + CategoryId sql.NullInt64 + MainCategoryId int64 + Title string + Body string + Comments []*Comment + Category Category + MainCategory Category +} + +type Category struct { + gorm.Model + Name string +} + +type Comment struct { + gorm.Model + PostId int64 + Content string + Post Post +} + +// Scanner +type NullValue struct { + Id int64 + Name sql.NullString `sql:"not null"` + Gender *sql.NullString `sql:"not null"` + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + AddedAt NullTime +} + +type NullTime struct { + Time time.Time + Valid bool +} + +func (nt *NullTime) Scan(value interface{}) error { + if value == nil { + nt.Valid = false + return nil + } + nt.Time, nt.Valid = value.(time.Time), true + return nil +} + +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +func getPreparedUser(name string, role string) *User { + var company Company + DB.Where(Company{Name: role}).FirstOrCreate(&company) + + return &User{ + Name: name, + Age: 20, + Role: Role{role}, + BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, + ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, + CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, + Emails: []Email{ + {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: 
fmt.Sprintf("user_%v@example2.com", name)}, + }, + Company: company, + Languages: []Language{ + {Name: fmt.Sprintf("lang_1_%v", name)}, + {Name: fmt.Sprintf("lang_2_%v", name)}, + }, + } +} + +func runMigration() { + if err := DB.DropTableIfExists(&User{}).Error; err != nil { + fmt.Printf("Got error when try to delete table users, %+v\n", err) + } + + for _, table := range []string{"animals", "user_languages"} { + DB.Exec(fmt.Sprintf("drop table %v;", table)) + } + + values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} + for _, value := range values { + DB.DropTable(value) + } + + if err := DB.AutoMigrate(values...).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } +} + +func TestIndexes(t *testing.T) { + if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + scope := DB.NewScope(&Email{}) + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + t.Errorf("Email should have index idx_email_email") + } + + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } + + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + t.Errorf("Email's index idx_email_email should be deleted") + } + + if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } + + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } + + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } + + if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } + + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { + t.Errorf("Should get to create duplicate record when having unique index") + } + + var user = User{Name: "sample_user"} + DB.Save(&user) + if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { + t.Errorf("Should get no error when append two emails for user") + } + + if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { + t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") + } + + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } + + if scope.Dialect().HasIndex(scope.TableName(), 
"idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } + + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { + t.Errorf("Should be able to create duplicated emails after remove unique index") + } +} + +type BigEmail struct { + Id int64 + UserId int64 + Email string `sql:"index:idx_email_agent"` + UserAgent string `sql:"index:idx_email_agent"` + RegisteredAt time.Time `sql:"unique_index"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func (b BigEmail) TableName() string { + return "emails" +} + +func TestAutoMigration(t *testing.T) { + DB.AutoMigrate(&Address{}) + if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { + t.Errorf("Auto Migrate should not raise any error") + } + + DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) + + scope := DB.NewScope(&BigEmail{}) + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { + t.Errorf("Failed to create index") + } + + if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { + t.Errorf("Failed to create index") + } + + var bigemail BigEmail + DB.First(&bigemail, "user_agent = ?", "pc") + if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { + t.Error("Big Emails should be saved and fetched correctly") + } +} diff --git a/vendor/github.com/jinzhu/gorm/model.go b/vendor/github.com/jinzhu/gorm/model.go new file mode 100644 index 000000000..f37ff7eaa --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/model.go @@ -0,0 +1,14 @@ +package gorm + +import "time" + +// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `sql:"index"` +} diff --git a/vendor/github.com/jinzhu/gorm/model_struct.go b/vendor/github.com/jinzhu/gorm/model_struct.go new file mode 100644 index 000000000..6df615d1b --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/model_struct.go @@ -0,0 +1,542 @@ +package gorm + +import ( + "database/sql" + "errors" + "go/ast" + "reflect" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" +) + +// DefaultTableNameHandler default table name handler +var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { + return defaultTableName +} + +type safeModelStructsMap struct { + m map[reflect.Type]*ModelStruct + l *sync.RWMutex +} + +func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { + s.l.Lock() + defer s.l.Unlock() + s.m[key] = value +} + +func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { + s.l.RLock() + defer s.l.RUnlock() + return s.m[key] +} + +func newModelStructsMap() *safeModelStructsMap { + return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} +} + +var modelStructsMap = newModelStructsMap() + +// ModelStruct model definition +type ModelStruct struct { + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string +} + +// TableName get model's table name +func (s *ModelStruct) TableName(db *DB) string { + return DefaultTableNameHandler(db, s.defaultTableName) +} + +// StructField model field's struct definition +type StructField struct { + DBName string + Name string 
+ Names []string + IsPrimaryKey bool + IsNormal bool + IsIgnored bool + IsScanner bool + HasDefaultValue bool + Tag reflect.StructTag + TagSettings map[string]string + Struct reflect.StructField + IsForeignKey bool + Relationship *Relationship +} + +func (structField *StructField) clone() *StructField { + return &StructField{ + DBName: structField.DBName, + Name: structField.Name, + Names: structField.Names, + IsPrimaryKey: structField.IsPrimaryKey, + IsNormal: structField.IsNormal, + IsIgnored: structField.IsIgnored, + IsScanner: structField.IsScanner, + HasDefaultValue: structField.HasDefaultValue, + Tag: structField.Tag, + TagSettings: structField.TagSettings, + Struct: structField.Struct, + IsForeignKey: structField.IsForeignKey, + Relationship: structField.Relationship, + } +} + +// Relationship described the relationship between models +type Relationship struct { + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface +} + +func getForeignField(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + return field + } + } + return nil +} + +// GetModelStruct get value's model struct, relationships based on struct and tag definition +func (scope *Scope) GetModelStruct() *ModelStruct { + var modelStruct ModelStruct + // Scope value can't be nil + if scope.Value == nil { + return &modelStruct + } + + reflectType := reflect.ValueOf(scope.Value).Type() + for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + + // Scope value need to be a struct + if reflectType.Kind() != reflect.Struct { + return &modelStruct + } + + // Get Cached model struct + if value := modelStructsMap.Get(reflectType); value != nil { + return value + } + + modelStruct.ModelType = reflectType + + // Set default table name + if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { + modelStruct.defaultTableName = tabler.TableName() + } else { + tableName := ToDBName(reflectType.Name()) + if scope.db == nil || !scope.db.parent.singularTable { + tableName = inflection.Plural(tableName) + } + modelStruct.defaultTableName = tableName + } + + // Get all fields + for i := 0; i < reflectType.NumField(); i++ { + if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { + field := &StructField{ + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), + } + + // is ignored field + if fieldStruct.Tag.Get("sql") == "-" { + field.IsIgnored = true + } else { + if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } + + if _, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + } + + indirectType := fieldStruct.Type + for indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() + } + + fieldValue := reflect.New(indirectType).Interface() + if _, isScanner := fieldValue.(sql.Scanner); isScanner { + // is scanner + field.IsScanner, field.IsNormal = true, true + } else if _, isTime := fieldValue.(*time.Time); isTime { + // is time + field.IsNormal = true + } else if _, ok := field.TagSettings["EMBEDDED"]; ok || 
fieldStruct.Anonymous { + // is embedded struct + for _, subField := range scope.New(fieldValue).GetStructFields() { + subField = subField.clone() + subField.Names = append([]string{fieldStruct.Name}, subField.Names...) + if subField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) + } + modelStruct.StructFields = append(modelStruct.StructFields, subField) + } + continue + } else { + // build relationships + switch indirectType.Kind() { + case reflect.Slice: + defer func(field *StructField) { + var ( + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + foreignKeys []string + associationForeignKeys []string + elemType = field.Struct.Type + ) + + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + } + + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + } + + for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if elemType.Kind() == reflect.Struct { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" + + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) + } + } + + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + // join table foreign keys for source + joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + } + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } + + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + relationship.JoinTableHandler = &joinTableHandler + field.Relationship = relationship + } else { + // User has many comments, associationType is User, comment use UserID as foreign key + var associationType = reflectType.Name() + var toFields = toScope.GetStructFields() + relationship.Kind = "has_many" + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + // Dog has many toys, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('dogs') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = 
polymorphicType.DBName + polymorphicType.IsForeignKey = true + } + } + + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+field.Name) + associationForeignKeys = append(associationForeignKeys, field.Name) + } + } else { + // generate foreign keys from defined association foreign keys + for _, scopeFieldName := range associationForeignKeys { + if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } + } + } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{scope.PrimaryKey()} + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { + if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { + // source foreign keys + foreignField.IsForeignKey = true + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + field.Relationship = relationship + } + } + } else { + field.IsNormal = true + } + }(field) + case reflect.Struct: + defer func(field *StructField) { + var ( + // user has one profile, associationType is User, profile use UserID as foreign key + // user belongs to profile, associationType is Profile, user use ProfileID as foreign key + associationType = reflectType.Name() + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + toFields = toScope.GetStructFields() + tagForeignKeys []string + tagAssociationForeignKeys []string + ) + + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + } + + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + // Cat has one toy, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('cats') as foreign key + if polymorphicType := 
getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + } + } + + // Has One + { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, primaryField := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys form association foreign keys + for _, associationForeignKey := range tagAssociationForeignKeys { + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } + } + } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{scope.PrimaryKey()} + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + foreignField.IsForeignKey = true + // source foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "has_one" + field.Relationship = relationship + } else { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + + if len(foreignKeys) == 0 { + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, primaryField := range toScope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.Name+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys with association foreign keys + for _, associationForeignKey := range associationForeignKeys { + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + foreignKeys = append(foreignKeys, field.Name+foreignField.Name) + associationForeignKeys = 
append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, field.Name) { + associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } + } + } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{toScope.PrimaryKey()} + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + foreignField.IsForeignKey = true + + // association foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + + // source foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "belongs_to" + field.Relationship = relationship + } + } + }(field) + default: + field.IsNormal = true + } + } + } + + // Even it is ignored, also possible to decode db value into the field + if value, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToDBName(fieldStruct.Name) + } + + modelStruct.StructFields = append(modelStruct.StructFields, field) + } + } + + if len(modelStruct.PrimaryFields) == 0 { + if field := getForeignField("id", modelStruct.StructFields); field != nil { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } + } + + modelStructsMap.Set(reflectType, &modelStruct) + + return &modelStruct +} + +// GetStructFields get model's field structs +func (scope *Scope) GetStructFields() (fields []*StructField) { + return scope.GetModelStruct().StructFields +} + +func parseTagSetting(tags reflect.StructTag) map[string]string { + setting := map[string]string{} + for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + tags := strings.Split(str, ";") + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if len(v) >= 2 { + setting[k] = strings.Join(v[1:], ":") + } else { + setting[k] = k + } + } + } + return setting +} diff --git a/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go b/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go new file mode 100644 index 000000000..8b275d182 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/multi_primary_keys_test.go @@ -0,0 +1,381 @@ +package gorm_test + +import ( + "os" + "reflect" + "sort" + "testing" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` + 
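+	// LocaleTags joins through locale_blog_tags with the composite source key (id, locale), so tag links are scoped to a single locale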
LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + Blogs []*Blog `gorm:"many2many:blogs_tags"` +} + +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Errorf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("Tags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "Tags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "Tags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Errorf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Related(&tags3, "Tags") + if !compareTags(tags3, []string{"tag6"}) { + t.Errorf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Errorf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Related(&tags4, "Tags") + if !compareTags(tags4, []string{"tag6"}) { + t.Errorf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Errorf("All tags should be cleared") + } + } +} + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("shared_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Errorf("Blog should has two tags") + } + + // Append + var 
tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Errorf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Related(&tags3, "SharedTags") + if !compareTags(tags3, []string{"tag6"}) { + t.Errorf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Errorf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Related(&tags4, "SharedTags") + if !compareTags(tags4, []string{"tag6"}) { + t.Errorf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Errorf("All tags should be cleared") + } + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("locale_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + 
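+		// Append went through the ZH blog, so its LocaleTags slice should now hold tag1, tag2 and tag3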
t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if len(tags) != 0 { + t.Errorf("Should find 0 tags with Related for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related for EN Blog") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag4"}) { + t.Errorf("Should find 1 tags with Related for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? 
AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Errorf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared") + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/pointer_test.go b/vendor/github.com/jinzhu/gorm/pointer_test.go new file mode 100644 index 000000000..2a68a5ab2 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/pointer_test.go @@ -0,0 +1,84 @@ +package gorm_test + +import "testing" + +type PointerStruct struct { + ID int64 + Name *string + Num *int +} + +type NormalStruct struct { + ID int64 + Name string + Num int +} + +func TestPointerFields(t *testing.T) { + DB.DropTable(&PointerStruct{}) + DB.AutoMigrate(&PointerStruct{}) + var name = "pointer struct 1" + var num = 100 + pointerStruct := PointerStruct{Name: &name, Num: &num} + if DB.Create(&pointerStruct).Error != nil { + t.Errorf("Failed to save pointer struct") + } + + var pointerStructResult PointerStruct + if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { + t.Errorf("Failed to query saved pointer struct") + } + + var tableName = DB.NewScope(&PointerStruct{}).TableName() + + var normalStruct NormalStruct + DB.Table(tableName).First(&normalStruct) + if normalStruct.Name != name || normalStruct.Num != num { + t.Errorf("Failed to query saved Normal struct") + } + + var nilPointerStruct = PointerStruct{} + if err := DB.Create(&nilPointerStruct).Error; err != nil { + t.Error("Failed to save nil pointer struct", err) + } + + var pointerStruct2 PointerStruct + if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + t.Error("Failed to query saved nil pointer struct", err) + } + + var normalStruct2 NormalStruct + if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + t.Error("Failed to query 
saved nil pointer struct", err) + } + + var partialNilPointerStruct1 = PointerStruct{Num: &num} + if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { + t.Error("Failed to save partial nil pointer struct", err) + } + + var pointerStruct3 PointerStruct + if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { + t.Error("Failed to query saved partial nil pointer struct", err) + } + + var normalStruct3 NormalStruct + if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { + t.Error("Failed to query saved partial pointer struct", err) + } + + var partialNilPointerStruct2 = PointerStruct{Name: &name} + if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { + t.Error("Failed to save partial nil pointer struct", err) + } + + var pointerStruct4 PointerStruct + if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { + t.Error("Failed to query saved partial nil pointer struct", err) + } + + var normalStruct4 NormalStruct + if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { + t.Error("Failed to query saved partial pointer struct", err) + } +} diff --git a/vendor/github.com/jinzhu/gorm/polymorphic_test.go b/vendor/github.com/jinzhu/gorm/polymorphic_test.go new file mode 100644 index 000000000..df573f97b --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/polymorphic_test.go @@ -0,0 +1,219 @@ +package gorm_test + +import ( + "reflect" + "sort" + "testing" +) + +type Cat struct { + Id int + Name string + Toy Toy `gorm:"polymorphic:Owner;"` +} + +type Dog struct { + Id int + Name string + Toys []Toy `gorm:"polymorphic:Owner;"` +} + +type Toy struct { + Id int + Name string + OwnerId int + OwnerType string +} + +var compareToys = func(toys []Toy, contents []string) bool { + var toyContents []string + for _, toy := range toys { + toyContents = append(toyContents, toy.Name) + } + sort.Strings(toyContents) + sort.Strings(contents) + return reflect.DeepEqual(toyContents, contents) +} + +func TestPolymorphic(t *testing.T) { + cat := Cat{Name: "Mr. 
Bigglesworth", Toy: Toy{Name: "cat toy"}} + dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} + DB.Save(&cat).Save(&dog) + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys count should be 1") + } + + if DB.Model(&dog).Association("Toys").Count() != 2 { + t.Errorf("Dog's toys count should be 2") + } + + // Query + var catToys []Toy + if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { + t.Errorf("Did not find any has one polymorphic association") + } else if len(catToys) != 1 { + t.Errorf("Should have found only one polymorphic has one association") + } else if catToys[0].Name != cat.Toy.Name { + t.Errorf("Should have found the proper has one polymorphic association") + } + + var dogToys []Toy + if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { + t.Errorf("Did not find any polymorphic has many associations") + } else if len(dogToys) != len(dog.Toys) { + t.Errorf("Should have found all polymorphic has many associations") + } + + var catToy Toy + DB.Model(&cat).Association("Toy").Find(&catToy) + if catToy.Name != cat.Toy.Name { + t.Errorf("Should find has one polymorphic association") + } + + var dogToys1 []Toy + DB.Model(&dog).Association("Toys").Find(&dogToys1) + if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { + t.Errorf("Should find has many polymorphic association") + } + + // Append + DB.Model(&cat).Association("Toy").Append(&Toy{ + Name: "cat toy 2", + }) + + var catToy2 Toy + DB.Model(&cat).Association("Toy").Find(&catToy2) + if catToy2.Name != "cat toy 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys count should be 1 after Append") + } + + if DB.Model(&dog).Association("Toys").Count() != 2 { + t.Errorf("Should return two polymorphic has many associations") + } + + DB.Model(&dog).Association("Toys").Append(&Toy{ + Name: "dog toy 3", + }) + + var dogToys2 []Toy + DB.Model(&dog).Association("Toys").Find(&dogToys2) + if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { + t.Errorf("Dog's toys should be updated with Append") + } + + if DB.Model(&dog).Association("Toys").Count() != 3 { + t.Errorf("Should return three polymorphic has many associations") + } + + // Replace + DB.Model(&cat).Association("Toy").Replace(&Toy{ + Name: "cat toy 3", + }) + + var catToy3 Toy + DB.Model(&cat).Association("Toy").Find(&catToy3) + if catToy3.Name != "cat toy 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys count should be 1 after Replace") + } + + if DB.Model(&dog).Association("Toys").Count() != 3 { + t.Errorf("Should return three polymorphic has many associations") + } + + DB.Model(&dog).Association("Toys").Replace(&Toy{ + Name: "dog toy 4", + }, []Toy{ + {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, + }) + + var dogToys3 []Toy + DB.Model(&dog).Association("Toys").Find(&dogToys3) + if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { + t.Errorf("Dog's toys should be updated with Replace") + } + + if DB.Model(&dog).Association("Toys").Count() != 4 { + t.Errorf("Should return three polymorphic has many associations") + } + + // Delete + DB.Model(&cat).Association("Toy").Delete(&catToy2) + + var catToy4 Toy + DB.Model(&cat).Association("Toy").Find(&catToy4) + if catToy4.Name != "cat toy 3" { + t.Errorf("Should not 
update has one polymorphic association when Delete a unrelated Toy") + } + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys count should be 1") + } + + if DB.Model(&dog).Association("Toys").Count() != 4 { + t.Errorf("Dog's toys count should be 4") + } + + DB.Model(&cat).Association("Toy").Delete(&catToy3) + + if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { + t.Errorf("Toy should be deleted with Delete") + } + + if DB.Model(&cat).Association("Toy").Count() != 0 { + t.Errorf("Cat's toys count should be 0 after Delete") + } + + if DB.Model(&dog).Association("Toys").Count() != 4 { + t.Errorf("Dog's toys count should not be changed when delete cat's toy") + } + + DB.Model(&dog).Association("Toys").Delete(&dogToys2) + + if DB.Model(&dog).Association("Toys").Count() != 4 { + t.Errorf("Dog's toys count should not be changed when delete unrelated toys") + } + + DB.Model(&dog).Association("Toys").Delete(&dogToys3) + + if DB.Model(&dog).Association("Toys").Count() != 0 { + t.Errorf("Dog's toys count should be deleted with Delete") + } + + // Clear + DB.Model(&cat).Association("Toy").Append(&Toy{ + Name: "cat toy 2", + }) + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys should be added with Append") + } + + DB.Model(&cat).Association("Toy").Clear() + + if DB.Model(&cat).Association("Toy").Count() != 0 { + t.Errorf("Cat's toys should be cleared with Clear") + } + + DB.Model(&dog).Association("Toys").Append(&Toy{ + Name: "dog toy 8", + }) + + if DB.Model(&dog).Association("Toys").Count() != 1 { + t.Errorf("Dog's toys should be added with Append") + } + + DB.Model(&dog).Association("Toys").Clear() + + if DB.Model(&dog).Association("Toys").Count() != 0 { + t.Errorf("Dog's toys should be cleared with Clear") + } +} diff --git a/vendor/github.com/jinzhu/gorm/preload_test.go b/vendor/github.com/jinzhu/gorm/preload_test.go new file mode 100644 index 000000000..5c49ecc21 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/preload_test.go @@ -0,0 +1,1327 @@ +package gorm_test + +import ( + "database/sql" + "encoding/json" + "os" + "reflect" + "testing" + + "github.com/jinzhu/gorm" +) + +func getPreloadUser(name string) *User { + return getPreparedUser(name, "Preload") +} + +func checkUserHasPreloadData(user User, t *testing.T) { + u := getPreloadUser(user.Name) + if user.BillingAddress.Address1 != u.BillingAddress.Address1 { + t.Error("Failed to preload user's BillingAddress") + } + + if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { + t.Error("Failed to preload user's ShippingAddress") + } + + if user.CreditCard.Number != u.CreditCard.Number { + t.Error("Failed to preload user's CreditCard") + } + + if user.Company.Name != u.Company.Name { + t.Error("Failed to preload user's Company") + } + + if len(user.Emails) != len(u.Emails) { + t.Error("Failed to preload user's Emails") + } else { + var found int + for _, e1 := range u.Emails { + for _, e2 := range user.Emails { + if e1.Email == e2.Email { + found++ + break + } + } + } + if found != len(u.Emails) { + t.Error("Failed to preload user's email details") + } + } +} + +func TestPreload(t *testing.T) { + user1 := getPreloadUser("user1") + DB.Save(user1) + + preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). 
+ Preload("CreditCard").Preload("Emails").Preload("Company") + var user User + preloadDB.Find(&user) + checkUserHasPreloadData(user, t) + + user2 := getPreloadUser("user2") + DB.Save(user2) + + user3 := getPreloadUser("user3") + DB.Save(user3) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + checkUserHasPreloadData(user, t) + } + + var users2 []*User + preloadDB.Find(&users2) + + for _, user := range users2 { + checkUserHasPreloadData(*user, t) + } + + var users3 []*User + preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) + + for _, user := range users3 { + if user.Name == user3.Name { + if len(user.Emails) != 1 { + t.Errorf("should only preload one emails for user3 when with condition") + } + } else if len(user.Emails) != 0 { + t.Errorf("should not preload any emails for other users when with condition") + } + } +} + +func TestNestedPreload1(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []*Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + Name string + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if 
!reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload4(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +// Slice: []Level3 +func TestNestedPreload5(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload6(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload7(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + 
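+	// drop and re-migrate the three tables so this case starts from an empty schema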
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload8(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload9(t *testing.T) { + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level2_1{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + { + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + { + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + {Value: "value3-3"}, + {Value: "value4-4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", 
toJSONString(got), toJSONString(want)) + } +} + +type LevelA1 struct { + ID uint + Value string +} + +type LevelA2 struct { + ID uint + Value string + LevelA3s []*LevelA3 +} + +type LevelA3 struct { + ID uint + Value string + LevelA1ID sql.NullInt64 + LevelA1 *LevelA1 + LevelA2ID sql.NullInt64 + LevelA2 *LevelA2 +} + +func TestNestedPreload10(t *testing.T) { + DB.DropTableIfExists(&LevelA3{}) + DB.DropTableIfExists(&LevelA2{}) + DB.DropTableIfExists(&LevelA1{}) + + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { + t.Error(err) + } + + levelA1 := &LevelA1{Value: "foo"} + if err := DB.Save(levelA1).Error; err != nil { + t.Error(err) + } + + want := []*LevelA2{ + { + Value: "bar", + LevelA3s: []*LevelA3{ + { + Value: "qux", + LevelA1: levelA1, + }, + }, + }, + { + Value: "bar 2", + }, + } + for _, levelA2 := range want { + if err := DB.Save(levelA2).Error; err != nil { + t.Error(err) + } + } + + var got []*LevelA2 + if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +type LevelB1 struct { + ID uint + Value string + LevelB3s []*LevelB3 +} + +type LevelB2 struct { + ID uint + Value string +} + +type LevelB3 struct { + ID uint + Value string + LevelB1ID sql.NullInt64 + LevelB1 *LevelB1 + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` +} + +func TestNestedPreload11(t *testing.T) { + DB.DropTableIfExists(&LevelB2{}) + DB.DropTableIfExists(&LevelB3{}) + DB.DropTableIfExists(&LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { + t.Error(err) + } + + levelB1 := &LevelB1{Value: "foo"} + if err := DB.Create(levelB1).Error; err != nil { + t.Error(err) + } + + levelB3 := &LevelB3{ + Value: "bar", + LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + } + if err := DB.Create(levelB3).Error; err != nil { + t.Error(err) + } + levelB1.LevelB3s = []*LevelB3{levelB3} + + want := []*LevelB1{levelB1} + var got []*LevelB1 + if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { + return + } + + type ( + Level1 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + Level1s []Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ + {Value: "ru", LanguageCode: "ru"}, + {Value: "en", LanguageCode: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ + {Value: "zh", LanguageCode: "zh"}, + {Value: "de", LanguageCode: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got 
%s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []Level1{ruLevel1} + got2.Level1s = []Level1{zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } + + if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { + t.Error(err) + } +} + +func TestManyToManyPreloadForNestedPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Bob", + Level2: &Level2{ + Value: "Foo", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level3{ + Value: "Tom", + Level2: &Level2{ + Value: "Bar", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level3 + if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) + } + + var got4 []Level3 + if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level3 + DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level2.Level1s = []*Level1{&ruLevel1} + got2.Level2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level3{got, got2}) { + 
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) + } +} + +func TestNestedManyToManyPreload(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2s []Level2 `gorm:"many2many:level2_level3;"` + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists("level2_level3") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Level3", + Level2s: []Level2{ + { + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, { + Value: "Tom", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + }, + } + + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedManyToManyPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Level3", + Level2: &Level2{ + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedManyToManyPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + level1Zh := &Level1{Value: "zh"} + level1Ru := &Level1{Value: "ru"} + level1En := &Level1{Value: "en"} + + level21 := &Level2{ + Value: "Level2-1", + Level1s: []*Level1{level1Zh, level1Ru}, + } + + level22 := &Level2{ + Value: "Level2-2", + Level1s: []*Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + 
}, + } + + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + +func TestNestedManyToManyPreload4(t *testing.T) { + type ( + Level4 struct { + ID uint + Value string + Level3ID uint + } + Level3 struct { + ID uint + Value string + Level4s []*Level4 + } + Level2 struct { + ID uint + Value string + Level3s []*Level3 `gorm:"many2many:level2_level3;"` + } + Level1 struct { + ID uint + Value string + Level2s []*Level2 `gorm:"many2many:level1_level2;"` + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level4{}) + DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists("level2_level3") + + dummy := Level1{ + Value: "Level1", + Level2s: []*Level2{{ + Value: "Level2", + Level3s: []*Level3{{ + Value: "Level3", + Level4s: []*Level4{{ + Value: "Level4", + }}, + }}, + }}, + } + + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + if err := DB.Save(&dummy).Error; err != nil { + t.Error(err) + } + + var level1 Level1 + if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { + t.Error(err) + } +} + +func TestManyToManyPreloadForPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if 
!reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } +} + +func TestNilPointerSlice(t *testing.T) { + type ( + Level3 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level3ID uint + Level3 *Level3 + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level1{Value: "Bob", Level2: &Level2{ + Value: "en", + Level3: &Level3{ + Value: "native", + }, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level1{Value: "Tom", Level2: nil} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got []Level1 + if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { + t.Error(err) + } + + if len(got) != 2 { + t.Errorf("got %v items, expected 2", len(got)) + } + + if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + } + + if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) + } +} + +func toJSONString(v interface{}) []byte { + r, _ := json.MarshalIndent(v, "", " ") + return r +} diff --git a/vendor/github.com/jinzhu/gorm/query_test.go b/vendor/github.com/jinzhu/gorm/query_test.go new file mode 100644 index 000000000..7dc3d91b9 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/query_test.go @@ -0,0 +1,636 @@ +package gorm_test + +import ( + "fmt" + "reflect" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/now" + + "testing" + "time" +) + +func TestFirstAndLast(t *testing.T) { + DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) + DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) + + var user1, user2, user3, user4 User + DB.First(&user1) + DB.Order("id").Limit(1).Find(&user2) + + DB.Last(&user3) + DB.Order("id desc").Limit(1).Find(&user4) + if user1.Id != user2.Id || user3.Id != user4.Id { + t.Errorf("First and Last should by order by primary key") + } + + var users []User + DB.First(&users) + if len(users) != 1 { + t.Errorf("Find first record as slice") + } + + var user User + if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { + t.Errorf("Should not raise any error when order with Join table") + } + + if user.Email != "" { + t.Errorf("User's Email should be blank as no one set it") + } +} + +func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { + DB.Save(&Animal{Name: "animal1"}) + DB.Save(&Animal{Name: "animal2"}) + + var animal1, animal2, animal3, animal4 Animal + DB.First(&animal1) + DB.Order("counter").Limit(1).Find(&animal2) + + DB.Last(&animal3) + DB.Order("counter desc").Limit(1).Find(&animal4) + if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { + t.Errorf("First and Last should work correctly") + } +} + +func TestUIntPrimaryKey(t *testing.T) { + var animal Animal + DB.First(&animal, uint64(1)) + if animal.Counter != 1 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } + + DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) + if animal.Counter != 2 { + 
t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } +} + +func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.AutoMigrate(&AddressByZipCode{}) + DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) + + var address AddressByZipCode + DB.First(&address, "00501") + if address.ZipCode != "00501" { + t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") + } +} + +func TestFindAsSliceOfPointers(t *testing.T) { + DB.Save(&User{Name: "user"}) + + var users []User + DB.Find(&users) + + var userPointers []*User + DB.Find(&userPointers) + + if len(users) == 0 || len(users) != len(userPointers) { + t.Errorf("Find slice of pointers") + } +} + +func TestSearchWithPlainSQL(t *testing.T) { + user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") + + if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + t.Errorf("Search with plain SQL") + } + + if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { + t.Errorf("Search with plan SQL (regexp)") + } + + var users []User + DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) + if len(users) != 2 { + t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) + } + + DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) + if len(users) != 3 { + t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) + } + + scopedb.Where("age <> ?", 20).Find(&users) + if len(users) != 2 { + t.Errorf("Should found 2 users age != 20, but got %v", len(users)) + } + + scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) + if len(users) != 2 { + t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + } + + scopedb.Where("birthday > ?", "2002-10-10").Find(&users) + if len(users) != 2 { + t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + } + + scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) + if len(users) != 1 { + t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + } + + DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) + if len(users) != 2 { + t.Errorf("Should found 2 users, but got %v", len(users)) + } + + DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) + if len(users) != 3 { + t.Errorf("Should found 3 users, but got %v", len(users)) + } + + DB.Where("id in (?)", user1.Id).Find(&users) + if len(users) != 1 { + t.Errorf("Should found 1 users, but got %v", len(users)) + } + + if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + + if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + + if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { + t.Errorf("Should not get RecordNotFound error when looking for none existing 
records") + } +} + +func TestSearchWithStruct(t *testing.T) { + user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + if DB.Where(user1.Id).First(&User{}).RecordNotFound() { + t.Errorf("Search with primary key") + } + + if DB.First(&User{}, user1.Id).RecordNotFound() { + t.Errorf("Search with primary key as inline condition") + } + + if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { + t.Errorf("Search with primary key as inline condition") + } + + var users []User + DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) + if len(users) != 3 { + t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) + } + + var user User + DB.First(&user, &User{Name: user1.Name}) + if user.Id == 0 || user.Name != user1.Name { + t.Errorf("Search first record with inline pointer of struct") + } + + DB.First(&user, User{Name: user1.Name}) + if user.Id == 0 || user.Name != user.Name { + t.Errorf("Search first record with inline struct") + } + + DB.Where(&User{Name: user1.Name}).First(&user) + if user.Id == 0 || user.Name != user1.Name { + t.Errorf("Search first record with where struct") + } + + DB.Find(&users, &User{Name: user2.Name}) + if len(users) != 1 { + t.Errorf("Search all records with inline struct") + } +} + +func TestSearchWithMap(t *testing.T) { + companyID := 1 + user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} + DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) + + var user User + DB.First(&user, map[string]interface{}{"name": user1.Name}) + if user.Id == 0 || user.Name != user1.Name { + t.Errorf("Search first record with inline map") + } + + user = User{} + DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) + if user.Id == 0 || user.Name != user2.Name { + t.Errorf("Search first record with where map") + } + + var users []User + DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) + if len(users) != 1 { + t.Errorf("Search all records with inline map") + } + + DB.Find(&users, map[string]interface{}{"name": user3.Name}) + if len(users) != 1 { + t.Errorf("Search all records with inline map") + } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) + if len(users) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) + if len(users) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) + if len(users) != 1 { + t.Errorf("Search all records with inline multiple value map") + } +} + +func TestSearchWithEmptyChain(t *testing.T) { + user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: 
now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + if DB.Where("").Where("").First(&User{}).Error != nil { + t.Errorf("Should not raise any error if searching with empty strings") + } + + if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { + t.Errorf("Should not raise any error if searching with empty struct") + } + + if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { + t.Errorf("Should not raise any error if searching with empty map") + } +} + +func TestSelect(t *testing.T) { + user1 := User{Name: "SelectUser1"} + DB.Save(&user1) + + var user User + DB.Where("name = ?", user1.Name).Select("name").Find(&user) + if user.Id != 0 { + t.Errorf("Should not have ID because only selected name, %+v", user.Id) + } + + if user.Name != user1.Name { + t.Errorf("Should have user Name when selected it") + } +} + +func TestOrderAndPluck(t *testing.T) { + user1 := User{Name: "OrderPluckUser1", Age: 1} + user2 := User{Name: "OrderPluckUser2", Age: 10} + user3 := User{Name: "OrderPluckUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") + + var ages []int64 + scopedb.Order("age desc").Pluck("age", &ages) + if ages[0] != 20 { + t.Errorf("The first age should be 20 when order with age desc") + } + + var ages1, ages2 []int64 + scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) + if !reflect.DeepEqual(ages1, ages2) { + t.Errorf("The first order is the primary order") + } + + var ages3, ages4 []int64 + scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) + if reflect.DeepEqual(ages3, ages4) { + t.Errorf("Reorder should work") + } + + var names []string + var ages5 []int64 + scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) + if names != nil && ages5 != nil { + if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { + t.Errorf("Order with multiple orders") + } + } else { + t.Errorf("Order with multiple orders") + } + + DB.Model(User{}).Select("name, age").Find(&[]User{}) +} + +func TestLimit(t *testing.T) { + user1 := User{Name: "LimitUser1", Age: 1} + user2 := User{Name: "LimitUser2", Age: 10} + user3 := User{Name: "LimitUser3", Age: 20} + user4 := User{Name: "LimitUser4", Age: 10} + user5 := User{Name: "LimitUser5", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) + + var users1, users2, users3 []User + DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) + + if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { + t.Errorf("Limit should works") + } +} + +func TestOffset(t *testing.T) { + for i := 0; i < 20; i++ { + DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) + } + var users1, users2, users3, users4 []User + DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work") + } +} + +func TestOr(t *testing.T) { + user1 := User{Name: "OrUser1", Age: 1} + user2 := User{Name: "OrUser2", Age: 10} + user3 := User{Name: "OrUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + var users []User + DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) + if len(users) != 2 { + t.Errorf("Find 
users with or") + } +} + +func TestCount(t *testing.T) { + user1 := User{Name: "CountUser1", Age: 1} + user2 := User{Name: "CountUser2", Age: 10} + user3 := User{Name: "CountUser3", Age: 20} + + DB.Save(&user1).Save(&user2).Save(&user3) + var count, count1, count2 int64 + var users []User + + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value") + } + + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("Multiple count in chain") + } +} + +func TestNot(t *testing.T) { + DB.Create(getPreparedUser("user1", "not")) + DB.Create(getPreparedUser("user2", "not")) + DB.Create(getPreparedUser("user3", "not")) + + user4 := getPreparedUser("user4", "not") + user4.Company = Company{} + DB.Create(user4) + + DB := DB.Where("role = ?", "not") + + var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User + if DB.Find(&users1).RowsAffected != 4 { + t.Errorf("should find 4 not users") + } + DB.Not(users1[0].Id).Find(&users2) + + if len(users1)-len(users2) != 1 { + t.Errorf("Should ignore the first users with Not") + } + + DB.Not([]int{}).Find(&users3) + if len(users1)-len(users3) != 0 { + t.Errorf("Should find all users with a blank condition") + } + + var name3Count int64 + DB.Table("users").Where("name = ?", "user3").Count(&name3Count) + DB.Not("name", "user3").Find(&users4) + if len(users1)-len(users4) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + DB.Not("name = ?", "user3").Find(&users4) + if len(users1)-len(users4) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + DB.Not("name <> ?", "user3").Find(&users4) + if len(users4) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + DB.Not(User{Name: "user3"}).Find(&users5) + + if len(users1)-len(users5) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) + if len(users1)-len(users6) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) + if len(users1)-len(users7) != 2 { // not user3 or user4 + t.Errorf("Should find all user's name not equal to 3 who do not have company id") + } + + DB.Not("name", []string{"user3"}).Find(&users8) + if len(users1)-len(users8) != int(name3Count) { + t.Errorf("Should find all users's name not equal 3") + } + + var name2Count int64 + DB.Table("users").Where("name = ?", "user2").Count(&name2Count) + DB.Not("name", []string{"user3", "user2"}).Find(&users9) + if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { + t.Errorf("Should find all users's name not equal 3") + } +} + +func TestFillSmallerStruct(t *testing.T) { + user1 := User{Name: "SmallerUser", Age: 100} + DB.Save(&user1) + type SimpleUser struct { + Name string + Id int64 + UpdatedAt time.Time + CreatedAt time.Time + } + + var simpleUser SimpleUser + DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) + + if simpleUser.Id == 0 || simpleUser.Name == "" { + t.Errorf("Should fill data correctly into smaller struct") + } +} + +func TestFindOrInitialize(t *testing.T) { + var user1, 
user2, user3, user4, user5, user6 User + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) + if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) + if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) + if user3.Name != "find or init 2" || user3.Id != 0 { + t.Errorf("user should be initialized with inline search value") + } + + DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and attrs") + } + + DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and assign attrs") + } + + DB.Save(&User{Name: "find or init", Age: 33}) + DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { + t.Errorf("user should be found with FirstOrInit") + } + + DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } +} + +func TestFindOrCreate(t *testing.T) { + var user1, user2, user3, user4, user5, user6, user7, user8 User + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) + if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) + if user3.Name != "find or create 2" || user3.Id == 0 { + t.Errorf("user should be created with inline search value") + } + + DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) + if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and attrs") + } + + updatedAt1 := user4.UpdatedAt + DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateAt should be changed when update values with assign") + } + + DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) + if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + 
DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) + if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Find(&user7) + if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) + if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { + t.Errorf("embedded struct email should be saved") + } + + if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { + t.Errorf("embedded struct credit card should be saved") + } +} + +func TestSelectWithEscapedFieldName(t *testing.T) { + user1 := User{Name: "EscapedFieldNameUser", Age: 1} + user2 := User{Name: "EscapedFieldNameUser", Age: 10} + user3 := User{Name: "EscapedFieldNameUser", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + var names []string + DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) + + if len(names) != 3 { + t.Errorf("Expected 3 name, but got: %d", len(names)) + } +} + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "jinzhu"}) + + rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } else { + columns, _ := rows.Columns() + if !reflect.DeepEqual(columns, []string{"fake"}) { + t.Errorf("Should only contains one column") + } + } +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "jinzhu", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) + + if user.Name != "jinzhu" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} diff --git a/vendor/github.com/jinzhu/gorm/scaner_test.go b/vendor/github.com/jinzhu/gorm/scaner_test.go new file mode 100644 index 000000000..214105481 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/scaner_test.go @@ -0,0 +1,70 @@ +package gorm_test + +import ( + "database/sql/driver" + "encoding/json" + "testing" +) + +func TestScannableSlices(t *testing.T) { + if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { + t.Errorf("Should create table with slice values correctly: %s", err) + } + + r1 := RecordWithSlice{ + Strings: ExampleStringSlice{"a", "b", "c"}, + Structs: ExampleStructSlice{ + {"name1", "value1"}, + {"name2", "value2"}, + }, + } + + if err := DB.Save(&r1).Error; err != nil { + t.Errorf("Should save record with slice values") + } + + var r2 RecordWithSlice + + if err := DB.Find(&r2).Error; err != nil { + t.Errorf("Should fetch record with slice values") + } + + if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { + t.Errorf("Should have serialised and deserialised a string array") + } + + if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { + t.Errorf("Should have serialised and deserialised a struct array") + } +} + +type RecordWithSlice struct { + ID uint64 + Strings ExampleStringSlice 
`sql:"type:text"` + Structs ExampleStructSlice `sql:"type:text"` +} + +type ExampleStringSlice []string + +func (l ExampleStringSlice) Value() (driver.Value, error) { + return json.Marshal(l) +} + +func (l *ExampleStringSlice) Scan(input interface{}) error { + return json.Unmarshal(input.([]byte), l) +} + +type ExampleStruct struct { + Name string + Value string +} + +type ExampleStructSlice []ExampleStruct + +func (l ExampleStructSlice) Value() (driver.Value, error) { + return json.Marshal(l) +} + +func (l *ExampleStructSlice) Scan(input interface{}) error { + return json.Unmarshal(input.([]byte), l) +} diff --git a/vendor/github.com/jinzhu/gorm/scope.go b/vendor/github.com/jinzhu/gorm/scope.go new file mode 100644 index 000000000..844df85c7 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/scope.go @@ -0,0 +1,1246 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "reflect" +) + +// Scope contain current operation's information when you perform any operation on the database +type Scope struct { + Search *search + Value interface{} + SQL string + SQLVars []interface{} + db *DB + instanceID string + primaryKeyField *Field + skipLeft bool + fields *[]*Field + selectAttrs *[]string +} + +// IndirectValue return scope's reflect value's indirect value +func (scope *Scope) IndirectValue() reflect.Value { + return indirect(reflect.ValueOf(scope.Value)) +} + +// New create a new Scope without search information +func (scope *Scope) New(value interface{}) *Scope { + return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} +} + +//////////////////////////////////////////////////////////////////////////////// +// Scope DB +//////////////////////////////////////////////////////////////////////////////// + +// DB return scope's DB connection +func (scope *Scope) DB() *DB { + return scope.db +} + +// NewDB create a new DB without search information +func (scope *Scope) NewDB() *DB { + if scope.db != nil { + db := scope.db.clone() + db.search = nil + db.Value = nil + return db + } + return nil +} + +// SQLDB return *sql.DB +func (scope *Scope) SQLDB() sqlCommon { + return scope.db.db +} + +// Dialect get dialect +func (scope *Scope) Dialect() Dialect { + return scope.db.parent.dialect +} + +// Quote used to quote string to escape them for database +func (scope *Scope) Quote(str string) string { + if strings.Index(str, ".") != -1 { + newStrs := []string{} + for _, str := range strings.Split(str, ".") { + newStrs = append(newStrs, scope.Dialect().Quote(str)) + } + return strings.Join(newStrs, ".") + } + + return scope.Dialect().Quote(str) +} + +// Err add error to Scope +func (scope *Scope) Err(err error) error { + if err != nil { + scope.db.AddError(err) + } + return err +} + +// HasError check if there are any error +func (scope *Scope) HasError() bool { + return scope.db.Error != nil +} + +// Log print log message +func (scope *Scope) Log(v ...interface{}) { + scope.db.log(v...) 
+} + +// SkipLeft skip remaining callbacks +func (scope *Scope) SkipLeft() { + scope.skipLeft = true +} + +// Fields get value's fields +func (scope *Scope) Fields() []*Field { + if scope.fields == nil { + var ( + fields []*Field + indirectScopeValue = scope.IndirectValue() + isStruct = indirectScopeValue.Kind() == reflect.Struct + ) + + for _, structField := range scope.GetModelStruct().StructFields { + if isStruct { + fieldValue := indirectScopeValue + for _, name := range structField.Names { + fieldValue = reflect.Indirect(fieldValue).FieldByName(name) + } + fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) + } else { + fields = append(fields, &Field{StructField: structField, IsBlank: true}) + } + } + scope.fields = &fields + } + + return *scope.fields +} + +// FieldByName find `gorm.Field` with field name or db name +func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + + for _, field := range scope.Fields() { + if field.Name == name || field.DBName == name { + return field, true + } + if field.DBName == dbName { + mostMatchedField = field + } + } + return mostMatchedField, mostMatchedField != nil +} + +// PrimaryFields return scope's primary fields +func (scope *Scope) PrimaryFields() (fields []*Field) { + for _, field := range scope.Fields() { + if field.IsPrimaryKey { + fields = append(fields, field) + } + } + return fields +} + +// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one +func (scope *Scope) PrimaryField() *Field { + if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { + if len(primaryFields) > 1 { + if field, ok := scope.FieldByName("id"); ok { + return field + } + } + return scope.PrimaryFields()[0] + } + return nil +} + +// PrimaryKey get main primary field's db name +func (scope *Scope) PrimaryKey() string { + if field := scope.PrimaryField(); field != nil { + return field.DBName + } + return "" +} + +// PrimaryKeyZero check main primary field's value is blank or not +func (scope *Scope) PrimaryKeyZero() bool { + field := scope.PrimaryField() + return field == nil || field.IsBlank +} + +// PrimaryKeyValue get the primary key's value +func (scope *Scope) PrimaryKeyValue() interface{} { + if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { + return field.Field.Interface() + } + return 0 +} + +// HasColumn to check if has column +func (scope *Scope) HasColumn(column string) bool { + for _, field := range scope.GetStructFields() { + if field.IsNormal && (field.Name == column || field.DBName == column) { + return true + } + } + return false +} + +// SetColumn to set the column's value, column could be field or field's name/dbname +func (scope *Scope) SetColumn(column interface{}, value interface{}) error { + var updateAttrs = map[string]interface{}{} + if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { + updateAttrs = attrs.(map[string]interface{}) + defer scope.InstanceSet("gorm:update_attrs", updateAttrs) + } + + if field, ok := column.(*Field); ok { + updateAttrs[field.DBName] = value + return field.Set(value) + } else if name, ok := column.(string); ok { + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + for _, field := range scope.Fields() { + if field.DBName == value { + updateAttrs[field.DBName] = value + return field.Set(value) + } + if (field.DBName == 
dbName) || (field.Name == name && mostMatchedField == nil) { + mostMatchedField = field + } + } + + if mostMatchedField != nil { + updateAttrs[mostMatchedField.DBName] = value + return mostMatchedField.Set(value) + } + } + return errors.New("could not convert column to field") +} + +// CallMethod call scope value's method, if it is a slice, will call its element's method one by one +func (scope *Scope) CallMethod(methodName string) { + if scope.Value == nil { + return + } + + if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < indirectScopeValue.Len(); i++ { + scope.callMethod(methodName, indirectScopeValue.Index(i)) + } + } else { + scope.callMethod(methodName, indirectScopeValue) + } +} + +// AddToVars add value as sql's vars, used to prevent SQL injection +func (scope *Scope) AddToVars(value interface{}) string { + if expr, ok := value.(*expr); ok { + exp := expr.expr + for _, arg := range expr.args { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } + return exp + } + + scope.SQLVars = append(scope.SQLVars, value) + return scope.Dialect().BindVar(len(scope.SQLVars)) +} + +// SelectAttrs return selected attributes +func (scope *Scope) SelectAttrs() []string { + if scope.selectAttrs == nil { + attrs := []string{} + for _, value := range scope.Search.selects { + if str, ok := value.(string); ok { + attrs = append(attrs, str) + } else if strs, ok := value.([]string); ok { + attrs = append(attrs, strs...) + } else if strs, ok := value.([]interface{}); ok { + for _, str := range strs { + attrs = append(attrs, fmt.Sprintf("%v", str)) + } + } + } + scope.selectAttrs = &attrs + } + return *scope.selectAttrs +} + +// OmitAttrs return omitted attributes +func (scope *Scope) OmitAttrs() []string { + return scope.Search.omits +} + +type tabler interface { + TableName() string +} + +type dbTabler interface { + TableName(*DB) string +} + +// TableName return table name +func (scope *Scope) TableName() string { + if scope.Search != nil && len(scope.Search.tableName) > 0 { + return scope.Search.tableName + } + + if tabler, ok := scope.Value.(tabler); ok { + return tabler.TableName() + } + + if tabler, ok := scope.Value.(dbTabler); ok { + return tabler.TableName(scope.db) + } + + return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) +} + +// QuotedTableName return quoted table name +func (scope *Scope) QuotedTableName() (name string) { + if scope.Search != nil && len(scope.Search.tableName) > 0 { + if strings.Index(scope.Search.tableName, " ") != -1 { + return scope.Search.tableName + } + return scope.Quote(scope.Search.tableName) + } + + return scope.Quote(scope.TableName()) +} + +// CombinedConditionSql return combined condition sql +func (scope *Scope) CombinedConditionSql() string { + return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + + scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() +} + +// Raw set raw sql +func (scope *Scope) Raw(sql string) *Scope { + scope.SQL = strings.Replace(sql, "$$", "?", -1) + return scope +} + +// Exec perform generated SQL +func (scope *Scope) Exec() *Scope { + defer scope.trace(NowFunc()) + + if !scope.HasError() { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if count, err := result.RowsAffected(); scope.Err(err) == nil { + scope.db.RowsAffected = count + } + } + } + return scope +} + +// Set set value by name +func (scope *Scope) Set(name string, value interface{}) *Scope { + scope.db.InstantSet(name, 
value) + return scope +} + +// Get get setting by name +func (scope *Scope) Get(name string) (interface{}, bool) { + return scope.db.Get(name) +} + +// InstanceID get InstanceID for scope +func (scope *Scope) InstanceID() string { + if scope.instanceID == "" { + scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) + } + return scope.instanceID +} + +// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback +func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { + return scope.Set(name+scope.InstanceID(), value) +} + +// InstanceGet get instance setting from current operation +func (scope *Scope) InstanceGet(name string) (interface{}, bool) { + return scope.Get(name + scope.InstanceID()) +} + +// Begin start a transaction +func (scope *Scope) Begin() *Scope { + if db, ok := scope.SQLDB().(sqlDb); ok { + if tx, err := db.Begin(); err == nil { + scope.db.db = interface{}(tx).(sqlCommon) + scope.InstanceSet("gorm:started_transaction", true) + } + } + return scope +} + +// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it +func (scope *Scope) CommitOrRollback() *Scope { + if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { + if db, ok := scope.db.db.(sqlTx); ok { + if scope.HasError() { + db.Rollback() + } else { + scope.Err(db.Commit()) + } + scope.db.db = scope.db.parent.db + } + } + return scope +} + +//////////////////////////////////////////////////////////////////////////////// +// Private Methods For *gorm.Scope +//////////////////////////////////////////////////////////////////////////////// + +func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { + // Only get address from non-pointer + if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { + reflectValue = reflectValue.Addr() + } + + if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { + switch method := methodValue.Interface().(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*DB): + newDB := scope.NewDB() + method(newDB) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("unsupported function %v", methodName)) + } + } +} + +var columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` + +func (scope *Scope) quoteIfPossible(str string) string { + if columnRegexp.MatchString(str) { + return scope.Quote(str) + } + return str +} + +func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { + var ( + ignored interface{} + selectFields []*Field + values = make([]interface{}, len(columns)) + selectedColumnsMap = map[string]int{} + resetFields = map[*Field]int{} + ) + + for index, column := range columns { + values[index] = &ignored + + selectFields = fields + if idx, ok := selectedColumnsMap[column]; ok { + selectFields = selectFields[idx+1:] + } + + for fieldIndex, field := range selectFields { + if field.DBName == column { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) + reflectValue.Elem().Set(field.Field.Addr()) + values[index] = reflectValue.Interface() + resetFields[field] = index + } + + 
selectedColumnsMap[column] = fieldIndex + break + } + } + } + + scope.Err(rows.Scan(values...)) + + for field, index := range resetFields { + if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { + field.Field.Set(v) + } + } +} + +func (scope *Scope) primaryCondition(value interface{}) string { + return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) +} + +func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + // if string is number + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + return scope.primaryCondition(scope.AddToVars(value)) + } else if value != "" { + str = fmt.Sprintf("(%v)", value) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return scope.primaryCondition(scope.AddToVars(value)) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} + case map[string]interface{}: + var sqls []string + for key, value := range value { + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + } + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.New(value).Fields() { + if !field.IsIgnored && !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + } + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { + var notEqualSQL string + var primaryKey = scope.PrimaryKey() + + switch value := clause["query"].(type) { + case string: + // is number + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + id, _ := strconv.Atoi(value) + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) + } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { + str = fmt.Sprintf(" NOT (%v) ", value) + notEqualSQL = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) + notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: + if reflect.ValueOf(value).Len() > 0 { + str = 
fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) + clause["args"] = []interface{}{value} + } + return "" + case map[string]interface{}: + var sqls []string + for key, value := range value { + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) + } + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.New(value).Fields() { + if !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + } + default: + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = scanner.Value() + } + str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + str = value + case []string: + str = strings.Join(value, ", ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: + values := reflect.ValueOf(arg) + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) whereSQL() (sql string) { + var ( + quotedTableName = scope.QuotedTableName() + primaryConditions, andConditions, orConditions []string + ) + + if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { + sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) + primaryConditions = append(primaryConditions, sql) + } + + if !scope.PrimaryKeyZero() { + for _, field := range scope.PrimaryFields() { + sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) + primaryConditions = append(primaryConditions, sql) + } + } + + for _, clause := range scope.Search.whereConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + for _, clause := range scope.Search.orConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + orConditions = append(orConditions, sql) + } + } + + for _, clause := range scope.Search.notConditions { + if sql := scope.buildNotCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + orSQL := strings.Join(orConditions, " OR ") + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) > 0 { + if 
len(orSQL) > 0 { + combinedSQL = combinedSQL + " OR " + orSQL + } + } else { + combinedSQL = orSQL + } + + if len(primaryConditions) > 0 { + sql = "WHERE " + strings.Join(primaryConditions, " AND ") + if len(combinedSQL) > 0 { + sql = sql + " AND (" + combinedSQL + ")" + } + } else if len(combinedSQL) > 0 { + sql = "WHERE " + combinedSQL + } + return +} + +func (scope *Scope) selectSQL() string { + if len(scope.Search.selects) == 0 { + if len(scope.Search.joinConditions) > 0 { + return fmt.Sprintf("%v.*", scope.QuotedTableName()) + } + return "*" + } + return scope.buildSelectQuery(scope.Search.selects) +} + +func (scope *Scope) orderSQL() string { + if len(scope.Search.orders) == 0 || scope.Search.countingQuery { + return "" + } + + var orders []string + for _, order := range scope.Search.orders { + orders = append(orders, scope.quoteIfPossible(order)) + } + return " ORDER BY " + strings.Join(orders, ",") +} + +func (scope *Scope) limitAndOffsetSQL() string { + return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) +} + +func (scope *Scope) groupSQL() string { + if len(scope.Search.group) == 0 { + return "" + } + return " GROUP BY " + scope.Search.group +} + +func (scope *Scope) havingSQL() string { + if len(scope.Search.havingConditions) == 0 { + return "" + } + + var andConditions []string + for _, clause := range scope.Search.havingConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) == 0 { + return "" + } + + return " HAVING " + combinedSQL +} + +func (scope *Scope) joinsSQL() string { + var joinConditions []string + for _, clause := range scope.Search.joinConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) + } + } + + return strings.Join(joinConditions, " ") + " " +} + +func (scope *Scope) prepareQuerySQL() { + if scope.Search.raw { + scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) + } else { + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) + } + return +} + +func (scope *Scope) inlineCondition(values ...interface{}) *Scope { + if len(values) > 0 { + scope.Search.Where(values[0], values[1:]...) 
+ } + return scope +} + +func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + for _, f := range funcs { + (*f)(scope) + if scope.skipLeft { + break + } + } + return scope +} + +func convertInterfaceToMap(values interface{}) map[string]interface{} { + var attrs = map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + return value + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + for _, field := range (&Scope{Value: values}).Fields() { + if !field.IsBlank { + attrs[field.DBName] = field.Field.Interface() + } + } + } + } + return attrs +} + +func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { + if scope.IndirectValue().Kind() != reflect.Struct { + return convertInterfaceToMap(value), true + } + + results = map[string]interface{}{} + + for key, value := range convertInterfaceToMap(value) { + if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { + if _, ok := value.(*expr); ok { + hasUpdate = true + results[field.DBName] = value + } else { + err := field.Set(value) + if field.IsNormal { + hasUpdate = true + if err == ErrUnaddressable { + fmt.Println(err) + results[field.DBName] = value + } else { + results[field.DBName] = field.Field.Interface() + } + } + } + } + } + return +} + +func (scope *Scope) row() *sql.Row { + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) +} + +func (scope *Scope) rows() (*sql.Rows, error) { + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) +} + +func (scope *Scope) initialize() *Scope { + for _, clause := range scope.Search.whereConditions { + scope.updatedAttrsWithValues(clause["query"]) + } + scope.updatedAttrsWithValues(scope.Search.initAttrs) + scope.updatedAttrsWithValues(scope.Search.assignAttrs) + return scope +} + +func (scope *Scope) pluck(column string, value interface{}) *Scope { + dest := reflect.Indirect(reflect.ValueOf(value)) + scope.Search.Select(column) + if dest.Kind() != reflect.Slice { + scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) + return scope + } + + rows, err := scope.rows() + if scope.Err(err) == nil { + defer rows.Close() + for rows.Next() { + elem := reflect.New(dest.Type().Elem()).Interface() + scope.Err(rows.Scan(elem)) + dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) + } + } + return scope +} + +func (scope *Scope) count(value interface{}) *Scope { + scope.Search.Select("count(*)") + scope.Search.countingQuery = true + scope.Err(scope.row().Scan(value)) + return scope +} + +func (scope *Scope) typeName() string { + typ := scope.IndirectValue().Type() + + for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + return typ.Name() +} + +// trace print sql log +func (scope *Scope) trace(t time.Time) { + if len(scope.SQL) > 0 { + scope.db.slog(scope.SQL, t, scope.SQLVars...) 
+ } +} + +func (scope *Scope) changeableField(field *Field) bool { + if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if field.Name == attr || field.DBName == attr { + return true + } + } + return false + } + + for _, attr := range scope.OmitAttrs() { + if field.Name == attr || field.DBName == attr { + return false + } + } + + return true +} + +func (scope *Scope) shouldSaveAssociations() bool { + if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { + return false + } + return true && !scope.HasError() +} + +func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { + toScope := scope.db.NewScope(value) + + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + fromField, _ := scope.FieldByName(foreignKey) + toField, _ := toScope.FieldByName(foreignKey) + + if fromField != nil { + if relationship := fromField.Relationship; relationship != nil { + if relationship.Kind == "many_to_many" { + joinTableHandler := relationship.JoinTableHandler + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) + } else if relationship.Kind == "belongs_to" { + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(foreignKey); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) + } + } + scope.Err(query.Find(value).Error) + } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + if relationship.PolymorphicType != "" { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + } + scope.Err(query.Find(value).Error) + } + } else { + sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) + scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + } + return scope + } else if toField != nil { + sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) + scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + return scope + } + } + + scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) + return scope +} + +// getTableOptions return the table options string or an empty string if the table options does not exist +func (scope *Scope) getTableOptions() string { + tableOptions, ok := scope.Get("gorm:table_options") + if !ok { + return "" + } + return tableOptions.(string) +} + +func (scope *Scope) createJoinTable(field *StructField) { + if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { + joinTableHandler := relationship.JoinTableHandler + joinTable := joinTableHandler.Table(scope.db) + if !scope.Dialect().HasTable(joinTable) { + toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} + + var sqlTypes, primaryKeys []string + for idx, fieldName := range relationship.ForeignFieldNames { + if field, ok := scope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, 
scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) + } + } + + for idx, fieldName := range relationship.AssociationForeignFieldNames { + if field, ok := toScope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) + } + } + + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + } + scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) + } +} + +func (scope *Scope) createTable() *Scope { + var tags []string + var primaryKeys []string + var primaryKeyInColumnType = false + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + sqlTag := scope.Dialect().DataTypeOf(field) + + // Check if the primary key constraint was specified as + // part of the column type. If so, we can only support + // one column as the primary key. + if strings.Contains(strings.ToLower(sqlTag), "primary key") { + primaryKeyInColumnType = true + } + + tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) + } + + if field.IsPrimaryKey { + primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) + } + scope.createJoinTable(field) + } + + var primaryKeyStr string + if len(primaryKeys) > 0 && !primaryKeyInColumnType { + primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + } + + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + + scope.autoIndex() + return scope +} + +func (scope *Scope) dropTable() *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() + return scope +} + +func (scope *Scope) modifyColumn(column string, typ string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() +} + +func (scope *Scope) dropColumn(column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { + if scope.Dialect().HasIndex(scope.TableName(), indexName) { + return + } + + var columns []string + for _, name := range column { + columns = append(columns, scope.quoteIfPossible(name)) + } + + sqlCreate := "CREATE INDEX" + if unique { + sqlCreate = "CREATE UNIQUE INDEX" + } + + scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() +} + +func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { + var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + + if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` + scope.Raw(fmt.Sprintf(query, 
scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() +} + +func (scope *Scope) removeIndex(indexName string) { + scope.Dialect().RemoveIndex(scope.TableName(), indexName) +} + +func (scope *Scope) autoMigrate() *Scope { + tableName := scope.TableName() + quotedTableName := scope.QuotedTableName() + + if !scope.Dialect().HasTable(tableName) { + scope.createTable() + } else { + for _, field := range scope.GetModelStruct().StructFields { + if !scope.Dialect().HasColumn(tableName, field.DBName) { + if field.IsNormal { + sqlTag := scope.Dialect().DataTypeOf(field) + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() + } + } + scope.createJoinTable(field) + } + scope.autoIndex() + } + return scope +} + +func (scope *Scope) autoIndex() *Scope { + var indexes = map[string][]string{} + var uniqueIndexes = map[string][]string{} + + for _, field := range scope.GetStructFields() { + if name, ok := field.TagSettings["INDEX"]; ok { + if name == "INDEX" { + name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + } + indexes[name] = append(indexes[name], field.DBName) + } + + if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + if name == "UNIQUE_INDEX" { + name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + } + uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) + } + } + + for name, columns := range indexes { + scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + } + + for name, columns := range uniqueIndexes { + scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + } + + return scope +} + +func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + for _, value := range values { + indirectValue := reflect.ValueOf(value) + for indirectValue.Kind() == reflect.Ptr { + indirectValue = indirectValue.Elem() + } + + switch indirectValue.Kind() { + case reflect.Slice: + for i := 0; i < indirectValue.Len(); i++ { + var result []interface{} + var object = indirect(indirectValue.Index(i)) + for _, column := range columns { + result = append(result, object.FieldByName(column).Interface()) + } + results = append(results, result) + } + case reflect.Struct: + var result []interface{} + for _, column := range columns { + result = append(result, indirectValue.FieldByName(column).Interface()) + } + results = append(results, result) + } + } + return +} + +func (scope *Scope) getColumnAsScope(column string) *Scope { + indirectScopeValue := scope.IndirectValue() + + switch indirectScopeValue.Kind() { + case reflect.Slice: + if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { + fieldType := fieldStruct.Type + if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() + + for i := 0; i < indirectScopeValue.Len(); i++ { + result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) + + if result.Kind() == reflect.Slice { + for j := 0; j < result.Len(); j++ { + if elem := result.Index(j); elem.CanAddr() { + results = reflect.Append(results, elem.Addr()) + } + } + } else if result.CanAddr() { + results = reflect.Append(results, result.Addr()) + } + } + return scope.New(results.Interface()) + } + case reflect.Struct: + if field := indirectScopeValue.FieldByName(column); field.CanAddr() { + return 
scope.New(field.Addr().Interface()) + } + } + return nil +} diff --git a/vendor/github.com/jinzhu/gorm/scope_test.go b/vendor/github.com/jinzhu/gorm/scope_test.go new file mode 100644 index 000000000..42458995d --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/scope_test.go @@ -0,0 +1,43 @@ +package gorm_test + +import ( + "github.com/jinzhu/gorm" + "testing" +) + +func NameIn1And2(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) +} + +func NameIn2And3(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) +} + +func NameIn(names []string) func(d *gorm.DB) *gorm.DB { + return func(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + user1 := User{Name: "ScopeUser1", Age: 1} + user2 := User{Name: "ScopeUser2", Age: 1} + user3 := User{Name: "ScopeUser3", Age: 2} + DB.Save(&user1).Save(&user2).Save(&user3) + + var users1, users2, users3 []User + DB.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should found two users's name in 1, 2") + } + + DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should found one user's name is 2") + } + + DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should found two users's name in 1, 3") + } +} diff --git a/vendor/github.com/jinzhu/gorm/search.go b/vendor/github.com/jinzhu/gorm/search.go new file mode 100644 index 000000000..078bd4298 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/search.go @@ -0,0 +1,149 @@ +package gorm + +import "fmt" + +type search struct { + db *DB + whereConditions []map[string]interface{} + orConditions []map[string]interface{} + notConditions []map[string]interface{} + havingConditions []map[string]interface{} + joinConditions []map[string]interface{} + initAttrs []interface{} + assignAttrs []interface{} + selects map[string]interface{} + omits []string + orders []string + preload []searchPreload + offset int + limit int + group string + tableName string + raw bool + Unscoped bool + countingQuery bool +} + +type searchPreload struct { + schema string + conditions []interface{} +} + +func (s *search) clone() *search { + clone := *s + return &clone +} + +func (s *search) Where(query interface{}, values ...interface{}) *search { + s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *search) Not(query interface{}, values ...interface{}) *search { + s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *search) Or(query interface{}, values ...interface{}) *search { + s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *search) Attrs(attrs ...interface{}) *search { + s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) + return s +} + +func (s *search) Assign(attrs ...interface{}) *search { + s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) + return s +} + +func (s *search) Order(value string, reorder ...bool) *search { + if len(reorder) > 0 && reorder[0] { + if value != "" { + s.orders = []string{value} + } else { + s.orders = []string{} + } + } else if value != "" { + s.orders = append(s.orders, value) + } + return s +} + +func (s *search) Select(query interface{}, args ...interface{}) *search { + s.selects = map[string]interface{}{"query": query, "args": args} 
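+	// Note: this assignment replaces any previously stored select clause;
+	// unlike Where/Or/Not, Select does not accumulate across calls.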
+ return s +} + +func (s *search) Omit(columns ...string) *search { + s.omits = columns + return s +} + +func (s *search) Limit(limit int) *search { + s.limit = limit + return s +} + +func (s *search) Offset(offset int) *search { + s.offset = offset + return s +} + +func (s *search) Group(query string) *search { + s.group = s.getInterfaceAsSQL(query) + return s +} + +func (s *search) Having(query string, values ...interface{}) *search { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *search) Joins(query string, values ...interface{}) *search { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *search) Preload(schema string, values ...interface{}) *search { + var preloads []searchPreload + for _, preload := range s.preload { + if preload.schema != schema { + preloads = append(preloads, preload) + } + } + preloads = append(preloads, searchPreload{schema, values}) + s.preload = preloads + return s +} + +func (s *search) Raw(b bool) *search { + s.raw = b + return s +} + +func (s *search) unscoped() *search { + s.Unscoped = true + return s +} + +func (s *search) Table(name string) *search { + s.tableName = name + return s +} + +func (s *search) getInterfaceAsSQL(value interface{}) (str string) { + switch value.(type) { + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + str = fmt.Sprintf("%v", value) + default: + s.db.AddError(ErrInvalidSQL) + } + + if str == "-1" { + return "" + } + return +} diff --git a/vendor/github.com/jinzhu/gorm/search_test.go b/vendor/github.com/jinzhu/gorm/search_test.go new file mode 100644 index 000000000..4db7ab6a5 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/search_test.go @@ -0,0 +1,30 @@ +package gorm + +import ( + "reflect" + "testing" +) + +func TestCloneSearch(t *testing.T) { + s := new(search) + s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") + + s1 := s.clone() + s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") + + if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { + t.Errorf("Where should be copied") + } + + if reflect.DeepEqual(s.orders, s1.orders) { + t.Errorf("Order should be copied") + } + + if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { + t.Errorf("InitAttrs should be copied") + } + + if reflect.DeepEqual(s.Select, s1.Select) { + t.Errorf("selectStr should be copied") + } +} diff --git a/vendor/github.com/jinzhu/gorm/test_all.sh b/vendor/github.com/jinzhu/gorm/test_all.sh new file mode 100755 index 000000000..6c5593b37 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/test_all.sh @@ -0,0 +1,5 @@ +dialects=("postgres" "mysql" "sqlite") + +for dialect in "${dialects[@]}" ; do + GORM_DIALECT=${dialect} go test +done diff --git a/vendor/github.com/jinzhu/gorm/update_test.go b/vendor/github.com/jinzhu/gorm/update_test.go new file mode 100644 index 000000000..bdf010912 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/update_test.go @@ -0,0 +1,435 @@ +package gorm_test + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +func TestUpdate(t *testing.T) { + product1 := Product{Code: "product1code"} + product2 := Product{Code: "product2code"} + + DB.Save(&product1).Save(&product2).Update("code", "product2newcode") + + if product2.Code != "product2newcode" { + t.Errorf("Record should be updated") + } + + DB.First(&product1, product1.Id) + DB.First(&product2, 
product2.Id) + updatedAt1 := product1.UpdatedAt + + if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { + t.Errorf("Product1 should not be updated") + } + + if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { + t.Errorf("Product2's code should be updated") + } + + if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { + t.Errorf("Product2's code should be updated") + } + + DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") + + var product4 Product + DB.First(&product4, product1.Id) + if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updatedAt should be updated if something changed") + } + + if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { + t.Errorf("Product1's code should be updated") + } + + if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { + t.Errorf("Product should not be changed to 789") + } + + if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update with CamelCase") + } + + if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update_column with CamelCase") + } + + var products []Product + DB.Find(&products) + if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { + t.Error("RowsAffected should be correct when do batch update") + } + + DB.First(&product4, product4.Id) + updatedAt4 := product4.UpdatedAt + DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100-50 { + t.Errorf("Update with expression") + } + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { + t.Errorf("Update with expression should update UpdatedAt") + } +} + +func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { + animal := Animal{Name: "Ferdinand"} + DB.Save(&animal) + updatedAt1 := animal.UpdatedAt + + DB.Save(&animal).Update("name", "Francis") + + if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updatedAt should not be updated if nothing changed") + } + + var animals []Animal + DB.Find(&animals) + if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + t.Error("RowsAffected should be correct when do batch update") + } + + animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) + DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched + DB.First(&animal, animal.Counter) + if animal.Name != "galeone" { + t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) + } + + // When changing a field with a default value, the change must occur + animal.Name = "amazing horse" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "amazing horse" { + t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) + } + + // When changing a field with a default value with blank value + animal.Name = "" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "" { + t.Errorf("Update a filed to blank with a default value should occur. 
But got %v\n", animal.Name) + } +} + +func TestUpdates(t *testing.T) { + product1 := Product{Code: "product1code", Price: 10} + product2 := Product{Code: "product2code", Price: 10} + DB.Save(&product1).Save(&product2) + DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) + if product1.Code != "product1newcode" || product1.Price != 100 { + t.Errorf("Record should be updated also with map") + } + + DB.First(&product1, product1.Id) + DB.First(&product2, product2.Id) + updatedAt2 := product2.UpdatedAt + + if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { + t.Errorf("Product2 should not be updated") + } + + if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { + t.Errorf("Product1 should be updated") + } + + DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) + if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { + t.Errorf("Product2's code should be updated") + } + + var product4 Product + DB.First(&product4, product2.Id) + if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updatedAt should be updated if something changed") + } + + if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { + t.Errorf("product2's code should be updated") + } + + updatedAt4 := product4.UpdatedAt + DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100 { + t.Errorf("Updates with expression") + } + // product4's UpdatedAt will be reset when updating + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { + t.Errorf("Updates with expression should update UpdatedAt") + } +} + +func TestUpdateColumn(t *testing.T) { + product1 := Product{Code: "product1code", Price: 10} + product2 := Product{Code: "product2code", Price: 20} + DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) + if product2.Code != "product2newcode" || product2.Price != 100 { + t.Errorf("product 2 should be updated with update column") + } + + var product3 Product + DB.First(&product3, product1.Id) + if product3.Code != "product1code" || product3.Price != 10 { + t.Errorf("product 1 should not be updated") + } + + DB.First(&product2, product2.Id) + updatedAt2 := product2.UpdatedAt + DB.Model(product2).UpdateColumn("code", "update_column_new") + var product4 Product + DB.First(&product4, product2.Id) + if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updatedAt should not be updated with update column") + } + + DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100-50 { + t.Errorf("UpdateColumn with expression") + } + if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateColumn with expression should not update UpdatedAt") + } +} + +func TestSelectWithUpdate(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update") + DB.Create(user) + + var reloadUser User + DB.First(&reloadUser, user.Id) + reloadUser.Name = "new_name" + reloadUser.Age = 50 + reloadUser.BillingAddress = Address{Address1: "New Billing Address"} + reloadUser.ShippingAddress 
= Address{Address1: "New ShippingAddress Address"} + reloadUser.CreditCard = CreditCard{Number: "987654321"} + reloadUser.Emails = []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + } + reloadUser.Company = Company{Name: "new company"} + + DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || + queryUser.ShippingAddressId != user.ShippingAddressId || + queryUser.CreditCard.ID == user.CreditCard.ID || + len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { + t.Errorf("Should only update selected relationships") + } +} + +func TestSelectWithUpdateWithMap(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{ + "Name": "new_name", + "Age": 50, + "BillingAddress": Address{Address1: "New Billing Address"}, + "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, + "CreditCard": CreditCard{Number: "987654321"}, + "Emails": []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + }, + "Company": Company{Name: "new company"}, + } + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || + queryUser.ShippingAddressId != user.ShippingAddressId || + queryUser.CreditCard.ID == user.CreditCard.ID || + len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { + t.Errorf("Should only update selected relationships") + } +} + +func TestOmitWithUpdate(t *testing.T) { + user := getPreparedUser("omit_user", "omit_with_update") + DB.Create(user) + + var reloadUser User + DB.First(&reloadUser, user.Id) + reloadUser.Name = "new_name" + reloadUser.Age = 50 + reloadUser.BillingAddress = Address{Address1: "New Billing Address"} + reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} + reloadUser.CreditCard = CreditCard{Number: "987654321"} + reloadUser.Emails = []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + } + reloadUser.Company = Company{Name: "new company"} + + DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). 
+ Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || + queryUser.ShippingAddressId == user.ShippingAddressId || + queryUser.CreditCard.ID != user.CreditCard.ID || + len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { + t.Errorf("Should only update relationships that not omited") + } +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{ + "Name": "new_name", + "Age": 50, + "BillingAddress": Address{Address1: "New Billing Address"}, + "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, + "CreditCard": CreditCard{Number: "987654321"}, + "Emails": []Email{ + {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, + }, + "Company": Company{Name: "new company"}, + } + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) + + var queryUser User + DB.Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should only update users with name column") + } + + if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || + queryUser.ShippingAddressId == user.ShippingAddressId || + queryUser.CreditCard.ID != user.CreditCard.ID || + len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { + t.Errorf("Should only update relationships not omited") + } +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) + + var queryUser User + DB.First(&queryUser, user.Id) + + if queryUser.Name == user.Name || queryUser.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := getPreparedUser("select_user", "select_with_update_map") + DB.Create(user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var reloadUser User + DB.First(&reloadUser, user.Id) + DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) + + var queryUser User + DB.First(&queryUser, user.Id) + + if queryUser.Name != user.Name || queryUser.Age == user.Age { + t.Errorf("Should omit name column when update user") + } +} + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := getPreparedUser("update_columns_user", "special_role") + user.Age = 99 + address1 := "first street" + user.BillingAddress = Address{Address1: address1} + DB.Save(user) + + // Update a single field of the user and verify that the changed address is not stored. 
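+	// UpdateColumns is expected to write only the given columns and to skip
+	// association callbacks, so BillingAddress should keep its original Address1.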
+ newAge := int64(100) + user.BillingAddress.Address1 = "second street" + db := DB.Model(user).UpdateColumns(User{Age: newAge}) + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) + } + + // Verify that Age now=`newAge`. + freshUser := &User{Id: user.Id} + DB.First(freshUser) + if freshUser.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) + } + + // Verify that user's BillingAddress.Address1 is not changed and is still "first street". + DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) + if freshUser.BillingAddress.Address1 != address1 { + t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) + } +} + +func TestUpdatesWithBlankValues(t *testing.T) { + product := Product{Code: "product1", Price: 10} + DB.Save(&product) + + DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) + + var product1 Product + DB.First(&product1, product.Id) + + if product1.Code != "product1" || product1.Price != 100 { + t.Errorf("product's code should not be updated") + } +} + +func TestUpdateDecodeVirtualAttributes(t *testing.T) { + var user = User{ + Name: "jinzhu", + IgnoreMe: 88, + } + + DB.Save(&user) + + DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) + + if user.IgnoreMe != 100 { + t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") + } +} diff --git a/vendor/github.com/jinzhu/gorm/utils.go b/vendor/github.com/jinzhu/gorm/utils.go new file mode 100644 index 000000000..dc69e8046 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/utils.go @@ -0,0 +1,264 @@ +package gorm + +import ( + "bytes" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "runtime" + "strings" + "sync" + "time" +) + +// NowFunc returns current time, this function is exported in order to be able +// to give the flexibility to the developer to customize it according to their +// needs, e.g: +// gorm.NowFunc = func() time.Time { +// return time.Now().UTC() +// } +var NowFunc = func() time.Time { + return time.Now() +} + +// Copied from golint +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialismsReplacer *strings.Replacer + +func init() { + var commonInitialismsForReplacer []string + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) 
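+	// The replacer maps each initialism to its capitalized form
+	// (e.g. "HTTP" -> "Http", "ID" -> "Id") so ToDBName can later insert
+	// underscores on ordinary case boundaries.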
+} + +type safeMap struct { + m map[string]string + l *sync.RWMutex +} + +func (s *safeMap) Set(key string, value string) { + s.l.Lock() + defer s.l.Unlock() + s.m[key] = value +} + +func (s *safeMap) Get(key string) string { + s.l.RLock() + defer s.l.RUnlock() + return s.m[key] +} + +func newSafeMap() *safeMap { + return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} +} + +var smap = newSafeMap() + +type strCase bool + +const ( + lower strCase = false + upper strCase = true +) + +// ToDBName convert string to db name +func ToDBName(name string) string { + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase strCase + ) + + for i, v := range value[:len(value)-1] { + nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') + if i > 0 { + if currCase == upper { + if lastCase == upper && nextCase == upper { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} + +// SQL expression +type expr struct { + expr string + args []interface{} +} + +// Expr generate raw SQL expression, for example: +// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) +func Expr(expression string, args ...interface{}) *expr { + return &expr{expr: expression, args: args} +} + +func indirect(reflectValue reflect.Value) reflect.Value { + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + return reflectValue +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } + return strings.Join(newColumns, ",") +} + +func toQueryValues(values [][]interface{}) (results []interface{}) { + for _, value := range values { + for _, v := range value { + results = append(results, v) + } + } + return +} + +func fileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +} + +func isBlank(value reflect.Value) bool { + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) +} + +func toSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + 
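+			// this assertion succeeds for any non-nil value, so a single
+			// struct (or other non-map) argument is passed through unchanged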
result = attr + } + } + return +} + +func equalAsString(a interface{}, b interface{}) bool { + return toString(a) == toString(b) +} + +func toString(str interface{}) string { + if values, ok := str.([]interface{}); ok { + var results []string + for _, value := range values { + results = append(results, toString(value)) + } + return strings.Join(results, "_") + } else if bytes, ok := str.([]byte); ok { + return string(bytes) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) + } + return "" +} + +func makeSlice(elemType reflect.Type) interface{} { + if elemType.Kind() == reflect.Slice { + elemType = elemType.Elem() + } + sliceType := reflect.SliceOf(elemType) + slice := reflect.New(sliceType) + slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) + return slice.Interface() +} + +func strInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// getValueFromFields return given fields's value +func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { + // If value is a nil pointer, Indirect returns a zero Value! + // Therefor we need to check for a zero value, + // as FieldByName could panic + if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { + for _, fieldName := range fieldNames { + if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + result := fieldValue.Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } + } + } + return +} + +func addExtraSpaceIfExist(str string) string { + if str != "" { + return " " + str + } + return "" +} diff --git a/vendor/github.com/jinzhu/gorm/utils_test.go b/vendor/github.com/jinzhu/gorm/utils_test.go new file mode 100644 index 000000000..07f5b17f4 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/utils_test.go @@ -0,0 +1,30 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestToDBNameGenerateFriendlyName(t *testing.T) { + var maps = map[string]string{ + "": "", + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "SKU_ID": "sku_id", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + for key, value := range maps { + if gorm.ToDBName(key) != value { + t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) + } + } +} diff --git a/vendor/github.com/jinzhu/inflection/LICENSE b/vendor/github.com/jinzhu/inflection/LICENSE new file mode 100644 index 000000000..a1ca9a0ff --- /dev/null +++ b/vendor/github.com/jinzhu/inflection/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2015 - Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, 
subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/jinzhu/inflection/README.md b/vendor/github.com/jinzhu/inflection/README.md new file mode 100644 index 000000000..4dd0f2d9f --- /dev/null +++ b/vendor/github.com/jinzhu/inflection/README.md @@ -0,0 +1,55 @@ +Inflection +========= + +Inflection pluralizes and singularizes English nouns + +## Basic Usage + +```go +inflection.Plural("person") => "people" +inflection.Plural("Person") => "People" +inflection.Plural("PERSON") => "PEOPLE" +inflection.Plural("bus") => "buses" +inflection.Plural("BUS") => "BUSES" +inflection.Plural("Bus") => "Buses" + +inflection.Singular("people") => "person" +inflection.Singular("People") => "Person" +inflection.Singular("PEOPLE") => "PERSON" +inflection.Singular("buses") => "bus" +inflection.Singular("BUSES") => "BUS" +inflection.Singular("Buses") => "Bus" + +inflection.Plural("FancyPerson") => "FancyPeople" +inflection.Singular("FancyPeople") => "FancyPerson" +``` + +## Register Rules + +Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb) + +If you want to register more rules, follow: + +``` +inflection.AddUncountable("fish") +inflection.AddIrregular("person", "people") +inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses" +inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" +``` + +## Supporting the project + +[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) + + +## Author + +**jinzhu** + +* +* +* + +## License + +Released under the [MIT License](http://www.opensource.org/licenses/MIT). diff --git a/vendor/github.com/jinzhu/inflection/inflections.go b/vendor/github.com/jinzhu/inflection/inflections.go new file mode 100644 index 000000000..606263bb7 --- /dev/null +++ b/vendor/github.com/jinzhu/inflection/inflections.go @@ -0,0 +1,273 @@ +/* +Package inflection pluralizes and singularizes English nouns. 
+ + inflection.Plural("person") => "people" + inflection.Plural("Person") => "People" + inflection.Plural("PERSON") => "PEOPLE" + + inflection.Singular("people") => "person" + inflection.Singular("People") => "Person" + inflection.Singular("PEOPLE") => "PERSON" + + inflection.Plural("FancyPerson") => "FancydPeople" + inflection.Singular("FancyPeople") => "FancydPerson" + +Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb) + +If you want to register more rules, follow: + + inflection.AddUncountable("fish") + inflection.AddIrregular("person", "people") + inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses" + inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" +*/ +package inflection + +import ( + "regexp" + "strings" +) + +type inflection struct { + regexp *regexp.Regexp + replace string +} + +// Regular is a regexp find replace inflection +type Regular struct { + find string + replace string +} + +// Irregular is a hard replace inflection, +// containing both singular and plural forms +type Irregular struct { + singular string + plural string +} + +// RegularSlice is a slice of Regular inflections +type RegularSlice []Regular + +// IrregularSlice is a slice of Irregular inflections +type IrregularSlice []Irregular + +var pluralInflections = RegularSlice{ + {"([a-z])$", "${1}s"}, + {"s$", "s"}, + {"^(ax|test)is$", "${1}es"}, + {"(octop|vir)us$", "${1}i"}, + {"(octop|vir)i$", "${1}i"}, + {"(alias|status)$", "${1}es"}, + {"(bu)s$", "${1}ses"}, + {"(buffal|tomat)o$", "${1}oes"}, + {"([ti])um$", "${1}a"}, + {"([ti])a$", "${1}a"}, + {"sis$", "ses"}, + {"(?:([^f])fe|([lr])f)$", "${1}${2}ves"}, + {"(hive)$", "${1}s"}, + {"([^aeiouy]|qu)y$", "${1}ies"}, + {"(x|ch|ss|sh)$", "${1}es"}, + {"(matr|vert|ind)(?:ix|ex)$", "${1}ices"}, + {"^(m|l)ouse$", "${1}ice"}, + {"^(m|l)ice$", "${1}ice"}, + {"^(ox)$", "${1}en"}, + {"^(oxen)$", "${1}"}, + {"(quiz)$", "${1}zes"}, +} + +var singularInflections = RegularSlice{ + {"s$", ""}, + {"(ss)$", "${1}"}, + {"(n)ews$", "${1}ews"}, + {"([ti])a$", "${1}um"}, + {"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"}, + {"(^analy)(sis|ses)$", "${1}sis"}, + {"([^f])ves$", "${1}fe"}, + {"(hive)s$", "${1}"}, + {"(tive)s$", "${1}"}, + {"([lr])ves$", "${1}f"}, + {"([^aeiouy]|qu)ies$", "${1}y"}, + {"(s)eries$", "${1}eries"}, + {"(m)ovies$", "${1}ovie"}, + {"(c)ookies$", "${1}ookie"}, + {"(x|ch|ss|sh)es$", "${1}"}, + {"^(m|l)ice$", "${1}ouse"}, + {"(bus)(es)?$", "${1}"}, + {"(o)es$", "${1}"}, + {"(shoe)s$", "${1}"}, + {"(cris|test)(is|es)$", "${1}is"}, + {"^(a)x[ie]s$", "${1}xis"}, + {"(octop|vir)(us|i)$", "${1}us"}, + {"(alias|status)(es)?$", "${1}"}, + {"^(ox)en", "${1}"}, + {"(vert|ind)ices$", "${1}ex"}, + {"(matr)ices$", "${1}ix"}, + {"(quiz)zes$", "${1}"}, + {"(database)s$", "${1}"}, +} + +var irregularInflections = IrregularSlice{ + {"person", "people"}, + {"man", "men"}, + {"child", "children"}, + {"sex", "sexes"}, + {"move", "moves"}, + {"mombie", "mombies"}, +} + +var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"} + +var compiledPluralMaps []inflection +var compiledSingularMaps []inflection + +func compile() { + compiledPluralMaps = []inflection{} + compiledSingularMaps = []inflection{} + for _, uncountable := range uncountableInflections { + inf := inflection{ + regexp: 
regexp.MustCompile("^(?i)(" + uncountable + ")$"), + replace: "${1}", + } + compiledPluralMaps = append(compiledPluralMaps, inf) + compiledSingularMaps = append(compiledSingularMaps, inf) + } + + for _, value := range irregularInflections { + infs := []inflection{ + inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)}, + inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)}, + inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural}, + } + compiledPluralMaps = append(compiledPluralMaps, infs...) + } + + for _, value := range irregularInflections { + infs := []inflection{ + inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)}, + inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)}, + inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular}, + } + compiledSingularMaps = append(compiledSingularMaps, infs...) + } + + for i := len(pluralInflections) - 1; i >= 0; i-- { + value := pluralInflections[i] + infs := []inflection{ + inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, + inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, + inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, + } + compiledPluralMaps = append(compiledPluralMaps, infs...) + } + + for i := len(singularInflections) - 1; i >= 0; i-- { + value := singularInflections[i] + infs := []inflection{ + inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, + inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, + inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, + } + compiledSingularMaps = append(compiledSingularMaps, infs...) + } +} + +func init() { + compile() +} + +// AddPlural adds a plural inflection +func AddPlural(find, replace string) { + pluralInflections = append(pluralInflections, Regular{find, replace}) + compile() +} + +// AddSingular adds a singular inflection +func AddSingular(find, replace string) { + singularInflections = append(singularInflections, Regular{find, replace}) + compile() +} + +// AddIrregular adds an irregular inflection +func AddIrregular(singular, plural string) { + irregularInflections = append(irregularInflections, Irregular{singular, plural}) + compile() +} + +// AddUncountable adds an uncountable inflection +func AddUncountable(values ...string) { + uncountableInflections = append(uncountableInflections, values...) 
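+	// recompile the cached regexp rules so the new uncountables take effect immediately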
+ compile() +} + +// GetPlural retrieves the plural inflection values +func GetPlural() RegularSlice { + plurals := make(RegularSlice, len(pluralInflections)) + copy(plurals, pluralInflections) + return plurals +} + +// GetSingular retrieves the singular inflection values +func GetSingular() RegularSlice { + singulars := make(RegularSlice, len(singularInflections)) + copy(singulars, singularInflections) + return singulars +} + +// GetIrregular retrieves the irregular inflection values +func GetIrregular() IrregularSlice { + irregular := make(IrregularSlice, len(irregularInflections)) + copy(irregular, irregularInflections) + return irregular +} + +// GetUncountable retrieves the uncountable inflection values +func GetUncountable() []string { + uncountables := make([]string, len(uncountableInflections)) + copy(uncountables, uncountableInflections) + return uncountables +} + +// SetPlural sets the plural inflections slice +func SetPlural(inflections RegularSlice) { + pluralInflections = inflections + compile() +} + +// SetSingular sets the singular inflections slice +func SetSingular(inflections RegularSlice) { + singularInflections = inflections + compile() +} + +// SetIrregular sets the irregular inflections slice +func SetIrregular(inflections IrregularSlice) { + irregularInflections = inflections + compile() +} + +// SetUncountable sets the uncountable inflections slice +func SetUncountable(inflections []string) { + uncountableInflections = inflections + compile() +} + +// Plural converts a word to its plural form +func Plural(str string) string { + for _, inflection := range compiledPluralMaps { + if inflection.regexp.MatchString(str) { + return inflection.regexp.ReplaceAllString(str, inflection.replace) + } + } + return str +} + +// Singular converts a word to its singular form +func Singular(str string) string { + for _, inflection := range compiledSingularMaps { + if inflection.regexp.MatchString(str) { + return inflection.regexp.ReplaceAllString(str, inflection.replace) + } + } + return str +} diff --git a/vendor/github.com/jinzhu/inflection/inflections_test.go b/vendor/github.com/jinzhu/inflection/inflections_test.go new file mode 100644 index 000000000..689e1dfb1 --- /dev/null +++ b/vendor/github.com/jinzhu/inflection/inflections_test.go @@ -0,0 +1,213 @@ +package inflection + +import ( + "strings" + "testing" +) + +var inflections = map[string]string{ + "star": "stars", + "STAR": "STARS", + "Star": "Stars", + "bus": "buses", + "fish": "fish", + "mouse": "mice", + "query": "queries", + "ability": "abilities", + "agency": "agencies", + "movie": "movies", + "archive": "archives", + "index": "indices", + "wife": "wives", + "safe": "saves", + "half": "halves", + "move": "moves", + "salesperson": "salespeople", + "person": "people", + "spokesman": "spokesmen", + "man": "men", + "woman": "women", + "basis": "bases", + "diagnosis": "diagnoses", + "diagnosis_a": "diagnosis_as", + "datum": "data", + "medium": "media", + "stadium": "stadia", + "analysis": "analyses", + "node_child": "node_children", + "child": "children", + "experience": "experiences", + "day": "days", + "comment": "comments", + "foobar": "foobars", + "newsletter": "newsletters", + "old_news": "old_news", + "news": "news", + "series": "series", + "species": "species", + "quiz": "quizzes", + "perspective": "perspectives", + "ox": "oxen", + "photo": "photos", + "buffalo": "buffaloes", + "tomato": "tomatoes", + "dwarf": "dwarves", + "elf": "elves", + "information": "information", + "equipment": "equipment", + "criterion": 
"criteria", +} + +// storage is used to restore the state of the global variables +// on each test execution, to ensure no global state pollution +type storage struct { + singulars RegularSlice + plurals RegularSlice + irregulars IrregularSlice + uncountables []string +} + +var backup = storage{} + +func init() { + AddIrregular("criterion", "criteria") + copy(backup.singulars, singularInflections) + copy(backup.plurals, pluralInflections) + copy(backup.irregulars, irregularInflections) + copy(backup.uncountables, uncountableInflections) +} + +func restore() { + copy(singularInflections, backup.singulars) + copy(pluralInflections, backup.plurals) + copy(irregularInflections, backup.irregulars) + copy(uncountableInflections, backup.uncountables) +} + +func TestPlural(t *testing.T) { + for key, value := range inflections { + if v := Plural(strings.ToUpper(key)); v != strings.ToUpper(value) { + t.Errorf("%v's plural should be %v, but got %v", strings.ToUpper(key), strings.ToUpper(value), v) + } + + if v := Plural(strings.Title(key)); v != strings.Title(value) { + t.Errorf("%v's plural should be %v, but got %v", strings.Title(key), strings.Title(value), v) + } + + if v := Plural(key); v != value { + t.Errorf("%v's plural should be %v, but got %v", key, value, v) + } + } +} + +func TestSingular(t *testing.T) { + for key, value := range inflections { + if v := Singular(strings.ToUpper(value)); v != strings.ToUpper(key) { + t.Errorf("%v's singular should be %v, but got %v", strings.ToUpper(value), strings.ToUpper(key), v) + } + + if v := Singular(strings.Title(value)); v != strings.Title(key) { + t.Errorf("%v's singular should be %v, but got %v", strings.Title(value), strings.Title(key), v) + } + + if v := Singular(value); v != key { + t.Errorf("%v's singular should be %v, but got %v", value, key, v) + } + } +} + +func TestAddPlural(t *testing.T) { + defer restore() + ln := len(pluralInflections) + AddPlural("", "") + if ln+1 != len(pluralInflections) { + t.Errorf("Expected len %d, got %d", ln+1, len(pluralInflections)) + } +} + +func TestAddSingular(t *testing.T) { + defer restore() + ln := len(singularInflections) + AddSingular("", "") + if ln+1 != len(singularInflections) { + t.Errorf("Expected len %d, got %d", ln+1, len(singularInflections)) + } +} + +func TestAddIrregular(t *testing.T) { + defer restore() + ln := len(irregularInflections) + AddIrregular("", "") + if ln+1 != len(irregularInflections) { + t.Errorf("Expected len %d, got %d", ln+1, len(irregularInflections)) + } +} + +func TestAddUncountable(t *testing.T) { + defer restore() + ln := len(uncountableInflections) + AddUncountable("", "") + if ln+2 != len(uncountableInflections) { + t.Errorf("Expected len %d, got %d", ln+2, len(uncountableInflections)) + } +} + +func TestGetPlural(t *testing.T) { + plurals := GetPlural() + if len(plurals) != len(pluralInflections) { + t.Errorf("Expected len %d, got %d", len(plurals), len(pluralInflections)) + } +} + +func TestGetSingular(t *testing.T) { + singular := GetSingular() + if len(singular) != len(singularInflections) { + t.Errorf("Expected len %d, got %d", len(singular), len(singularInflections)) + } +} + +func TestGetIrregular(t *testing.T) { + irregular := GetIrregular() + if len(irregular) != len(irregularInflections) { + t.Errorf("Expected len %d, got %d", len(irregular), len(irregularInflections)) + } +} + +func TestGetUncountable(t *testing.T) { + uncountables := GetUncountable() + if len(uncountables) != len(uncountableInflections) { + t.Errorf("Expected len %d, got %d", 
len(uncountables), len(uncountableInflections)) + } +} + +func TestSetPlural(t *testing.T) { + defer restore() + SetPlural(RegularSlice{{}, {}}) + if len(pluralInflections) != 2 { + t.Errorf("Expected len 2, got %d", len(pluralInflections)) + } +} + +func TestSetSingular(t *testing.T) { + defer restore() + SetSingular(RegularSlice{{}, {}}) + if len(singularInflections) != 2 { + t.Errorf("Expected len 2, got %d", len(singularInflections)) + } +} + +func TestSetIrregular(t *testing.T) { + defer restore() + SetIrregular(IrregularSlice{{}, {}}) + if len(irregularInflections) != 2 { + t.Errorf("Expected len 2, got %d", len(irregularInflections)) + } +} + +func TestSetUncountable(t *testing.T) { + defer restore() + SetUncountable([]string{"", ""}) + if len(uncountableInflections) != 2 { + t.Errorf("Expected len 2, got %d", len(uncountableInflections)) + } +} diff --git a/vendor/github.com/jmoiron/sqlx/.gitignore b/vendor/github.com/jmoiron/sqlx/.gitignore deleted file mode 100644 index 529841cf1..000000000 --- a/vendor/github.com/jmoiron/sqlx/.gitignore +++ /dev/null @@ -1,24 +0,0 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe -tags -environ diff --git a/vendor/github.com/jmoiron/sqlx/LICENSE b/vendor/github.com/jmoiron/sqlx/LICENSE deleted file mode 100644 index 0d31edfa7..000000000 --- a/vendor/github.com/jmoiron/sqlx/LICENSE +++ /dev/null @@ -1,23 +0,0 @@ - Copyright (c) 2013, Jason Moiron - - Permission is hereby granted, free of charge, to any person - obtaining a copy of this software and associated documentation - files (the "Software"), to deal in the Software without - restriction, including without limitation the rights to use, - copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following - conditions: - - The above copyright notice and this permission notice shall be - included in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - OTHER DEALINGS IN THE SOFTWARE. - diff --git a/vendor/github.com/jmoiron/sqlx/README.md b/vendor/github.com/jmoiron/sqlx/README.md deleted file mode 100644 index 5c1bb3cb9..000000000 --- a/vendor/github.com/jmoiron/sqlx/README.md +++ /dev/null @@ -1,183 +0,0 @@ -# sqlx - -[![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) - -sqlx is a library which provides a set of extensions on go's standard -`database/sql` library. The sqlx versions of `sql.DB`, `sql.TX`, `sql.Stmt`, -et al. 
all leave the underlying interfaces untouched, so that their interfaces -are a superset on the standard ones. This makes it relatively painless to -integrate existing codebases using database/sql with sqlx. - -Major additional concepts are: - -* Marshal rows into structs (with embedded struct support), maps, and slices -* Named parameter support including prepared statements -* `Get` and `Select` to go quickly from query to struct/slice - -In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx), -there is also some [standard documentation](http://jmoiron.github.io/sqlx/) that -explains how to use `database/sql` along with sqlx. - -## Recent Changes - -* sqlx/types.JsonText has been renamed to JSONText to follow Go naming conventions. - -This breaks backwards compatibility, but it's in a way that is trivially fixable -(`s/JsonText/JSONText/g`). The `types` package is both experimental and not in -active development currently. - -* Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905). - -### Backwards Compatibility - -There is no Go1-like promise of absolute stability, but I take the issue seriously -and will maintain the library in a compatible state unless vital bugs prevent me -from doing so. Since [#59](https://github.com/jmoiron/sqlx/issues/59) and -[#60](https://github.com/jmoiron/sqlx/issues/60) necessitated breaking behavior, -a wider API cleanup was done at the time of fixing. It's possible this will happen -in future; if it does, a git tag will be provided for users requiring the old -behavior to continue to use it until such a time as they can migrate. - -## install - - go get github.com/jmoiron/sqlx - -## issues - -Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of -`Columns()` does not fully qualify column names in queries like: - -```sql -SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id; -``` - -making a struct or map destination ambiguous. Use `AS` in your queries -to give columns distinct names, `rows.Scan` to scan them manually, or -`SliceScan` to get a slice of results. - -## usage - -Below is an example which shows some common use cases for sqlx. Check -[sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more -usage. 
- - -```go -package main - -import ( - _ "github.com/lib/pq" - "database/sql" - "github.com/jmoiron/sqlx" - "log" -) - -var schema = ` -CREATE TABLE person ( - first_name text, - last_name text, - email text -); - -CREATE TABLE place ( - country text, - city text NULL, - telcode integer -)` - -type Person struct { - FirstName string `db:"first_name"` - LastName string `db:"last_name"` - Email string -} - -type Place struct { - Country string - City sql.NullString - TelCode int -} - -func main() { - // this Pings the database trying to connect, panics on error - // use sqlx.Open() for sql.Open() semantics - db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable") - if err != nil { - log.Fatalln(err) - } - - // exec the schema or fail; multi-statement Exec behavior varies between - // database drivers; pq will exec them all, sqlite3 won't, ymmv - db.MustExec(schema) - - tx := db.MustBegin() - tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net") - tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net") - tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1") - tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852") - tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65") - // Named queries can use structs, so if you have an existing struct (i.e. person := &Person{}) that you have populated, you can pass it in as &person - tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"}) - tx.Commit() - - // Query the database, storing results in a []Person (wrapped in []interface{}) - people := []Person{} - db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") - jason, john := people[0], people[1] - - fmt.Printf("%#v\n%#v", jason, john) - // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} - // Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"} - - // You can also get a single result, a la QueryRow - jason = Person{} - err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason") - fmt.Printf("%#v\n", jason) - // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} - - // if you have null fields and use SELECT *, you must use sql.Null* in your struct - places := []Place{} - err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - fmt.Println(err) - return - } - usa, singsing, honkers := places[0], places[1], places[2] - - fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers) - // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} - // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} - // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} - - // Loop through rows using only one struct - place := Place{} - rows, err := db.Queryx("SELECT * FROM place") - for rows.Next() { - err := rows.StructScan(&place) - if err != nil { - log.Fatalln(err) - } - fmt.Printf("%#v\n", place) - } - // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} - // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} - // Place{Country:"Singapore", 
City:sql.NullString{String:"", Valid:false}, TelCode:65} - - // Named queries, using `:name` as the bindvar. Automatic bindvar support - // which takes into account the dbtype based on the driverName on sqlx.Open/Connect - _, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`, - map[string]interface{}{ - "first": "Bin", - "last": "Smuth", - "email": "bensmith@allblacks.nz", - }) - - // Selects Mr. Smith from the database - rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"}) - - // Named queries can also use structs. Their bind names follow the same rules - // as the name -> db mapping, so struct fields are lowercased and the `db` tag - // is taken into consideration. - rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason) -} -``` - diff --git a/vendor/github.com/jmoiron/sqlx/bind.go b/vendor/github.com/jmoiron/sqlx/bind.go deleted file mode 100644 index 10f7bdf84..000000000 --- a/vendor/github.com/jmoiron/sqlx/bind.go +++ /dev/null @@ -1,207 +0,0 @@ -package sqlx - -import ( - "bytes" - "errors" - "reflect" - "strconv" - "strings" - - "github.com/jmoiron/sqlx/reflectx" -) - -// Bindvar types supported by Rebind, BindMap and BindStruct. -const ( - UNKNOWN = iota - QUESTION - DOLLAR - NAMED -) - -// BindType returns the bindtype for a given database given a drivername. -func BindType(driverName string) int { - switch driverName { - case "postgres", "pgx": - return DOLLAR - case "mysql": - return QUESTION - case "sqlite3": - return QUESTION - case "oci8", "ora", "goracle": - return NAMED - } - return UNKNOWN -} - -// FIXME: this should be able to be tolerant of escaped ?'s in queries without -// losing much speed, and should be to avoid confusion. - -// Rebind a query from the default bindtype (QUESTION) to the target bindtype. -func Rebind(bindType int, query string) string { - switch bindType { - case QUESTION, UNKNOWN: - return query - } - - // Add space enough for 10 params before we have to allocate - rqb := make([]byte, 0, len(query)+10) - - var i, j int - - for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { - rqb = append(rqb, query[:i]...) - - switch bindType { - case DOLLAR: - rqb = append(rqb, '$') - case NAMED: - rqb = append(rqb, ':', 'a', 'r', 'g') - } - - j++ - rqb = strconv.AppendInt(rqb, int64(j), 10) - - query = query[i+1:] - } - - return string(append(rqb, query...)) -} - -// Experimental implementation of Rebind which uses a bytes.Buffer. The code is -// much simpler and should be more resistant to odd unicode, but it is twice as -// slow. Kept here for benchmarking purposes and to possibly replace Rebind if -// problems arise with its somewhat naive handling of unicode. -func rebindBuff(bindType int, query string) string { - if bindType != DOLLAR { - return query - } - - b := make([]byte, 0, len(query)) - rqb := bytes.NewBuffer(b) - j := 1 - for _, r := range query { - if r == '?' { - rqb.WriteRune('$') - rqb.WriteString(strconv.Itoa(j)) - j++ - } else { - rqb.WriteRune(r) - } - } - - return rqb.String() -} - -// In expands slice values in args, returning the modified query string -// and a new arg list that can be executed by a database. The `query` should -// use the `?` bindVar. The return value uses the `?` bindVar. 
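The `In` helper documented above pairs naturally with `BindType` and `Rebind`: expand the slice into individual `?` bindvars, then rewrite them for the target driver. A minimal, database-free sketch (the table and column names are placeholders):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmoiron/sqlx"
)

func main() {
	ids := []int{4, 6, 9}

	// Expand the slice argument: the single '?' becomes '?, ?, ?'.
	query, args, err := sqlx.In("SELECT * FROM person WHERE id IN (?)", ids)
	if err != nil {
		log.Fatalln(err)
	}

	// Rewrite the '?' bindvars into the driver's style, e.g. $1..$3 for postgres.
	query = sqlx.Rebind(sqlx.BindType("postgres"), query)

	fmt.Println(query) // SELECT * FROM person WHERE id IN ($1, $2, $3)
	fmt.Println(args)  // [4 6 9]
}
```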
-func In(query string, args ...interface{}) (string, []interface{}, error) { - // argMeta stores reflect.Value and length for slices and - // the value itself for non-slice arguments - type argMeta struct { - v reflect.Value - i interface{} - length int - } - - var flatArgsCount int - var anySlices bool - - meta := make([]argMeta, len(args)) - - for i, arg := range args { - v := reflect.ValueOf(arg) - t := reflectx.Deref(v.Type()) - - if t.Kind() == reflect.Slice { - meta[i].length = v.Len() - meta[i].v = v - - anySlices = true - flatArgsCount += meta[i].length - - if meta[i].length == 0 { - return "", nil, errors.New("empty slice passed to 'in' query") - } - } else { - meta[i].i = arg - flatArgsCount++ - } - } - - // don't do any parsing if there aren't any slices; note that this means - // some errors that we might have caught below will not be returned. - if !anySlices { - return query, args, nil - } - - newArgs := make([]interface{}, 0, flatArgsCount) - buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount)) - - var arg, offset int - - for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { - if arg >= len(meta) { - // if an argument wasn't passed, lets return an error; this is - // not actually how database/sql Exec/Query works, but since we are - // creating an argument list programmatically, we want to be able - // to catch these programmer errors earlier. - return "", nil, errors.New("number of bindVars exceeds arguments") - } - - argMeta := meta[arg] - arg++ - - // not a slice, continue. - // our questionmark will either be written before the next expansion - // of a slice or after the loop when writing the rest of the query - if argMeta.length == 0 { - offset = offset + i + 1 - newArgs = append(newArgs, argMeta.i) - continue - } - - // write everything up to and including our ? character - buf.WriteString(query[:offset+i+1]) - - for si := 1; si < argMeta.length; si++ { - buf.WriteString(", ?") - } - - newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) - - // slice the query and reset the offset. this avoids some bookkeeping for - // the write after the loop - query = query[offset+i+1:] - offset = 0 - } - - buf.WriteString(query) - - if arg < len(meta) { - return "", nil, errors.New("number of bindVars less than number arguments") - } - - return buf.String(), newArgs, nil -} - -func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { - switch val := v.Interface().(type) { - case []interface{}: - args = append(args, val...) - case []int: - for i := range val { - args = append(args, val[i]) - } - case []string: - for i := range val { - args = append(args, val[i]) - } - default: - for si := 0; si < vlen; si++ { - args = append(args, v.Index(si).Interface()) - } - } - - return args -} diff --git a/vendor/github.com/jmoiron/sqlx/doc.go b/vendor/github.com/jmoiron/sqlx/doc.go deleted file mode 100644 index e2b4e60b2..000000000 --- a/vendor/github.com/jmoiron/sqlx/doc.go +++ /dev/null @@ -1,12 +0,0 @@ -// Package sqlx provides general purpose extensions to database/sql. -// -// It is intended to seamlessly wrap database/sql and provide convenience -// methods which are useful in the development of database driven applications. -// None of the underlying database/sql methods are changed. Instead all extended -// behavior is implemented through new methods defined on wrapper types. 
-// -// Additions include scanning into structs, named query support, rebinding -// queries for different drivers, convenient shorthands for common error handling -// and more. -// -package sqlx diff --git a/vendor/github.com/jmoiron/sqlx/named.go b/vendor/github.com/jmoiron/sqlx/named.go deleted file mode 100644 index dd899d351..000000000 --- a/vendor/github.com/jmoiron/sqlx/named.go +++ /dev/null @@ -1,344 +0,0 @@ -package sqlx - -// Named Query Support -// -// * BindMap - bind query bindvars to map/struct args -// * NamedExec, NamedQuery - named query w/ struct or map -// * NamedStmt - a pre-compiled named query which is a prepared statement -// -// Internal Interfaces: -// -// * compileNamedQuery - rebind a named query, returning a query and list of names -// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist -// -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strconv" - "unicode" - - "github.com/jmoiron/sqlx/reflectx" -) - -// NamedStmt is a prepared statement that executes named queries. Prepare it -// how you would execute a NamedQuery, but pass in a struct or map when executing. -type NamedStmt struct { - Params []string - QueryString string - Stmt *Stmt -} - -// Close closes the named statement. -func (n *NamedStmt) Close() error { - return n.Stmt.Close() -} - -// Exec executes a named statement using the struct passed. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return *new(sql.Result), err - } - return n.Stmt.Exec(args...) -} - -// Query executes a named statement using the struct argument, returning rows. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return nil, err - } - return n.Stmt.Query(args...) -} - -// QueryRow executes a named statement against the database. Because sqlx cannot -// create a *sql.Row with an error condition pre-set for binding errors, sqlx -// returns a *sqlx.Row instead. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryRow(arg interface{}) *Row { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return &Row{err: err} - } - return n.Stmt.QueryRowx(args...) -} - -// MustExec execs a NamedStmt, panicing on error -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) MustExec(arg interface{}) sql.Result { - res, err := n.Exec(arg) - if err != nil { - panic(err) - } - return res -} - -// Queryx using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { - r, err := n.Query(arg) - if err != nil { - return nil, err - } - return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err -} - -// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is -// an alias for QueryRow. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryRowx(arg interface{}) *Row { - return n.QueryRow(arg) -} - -// Select using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. 
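A short sketch of how a `NamedStmt` is typically used: prepare once, then bind either a struct or a map on each execution. It assumes the Postgres driver, the placeholder DSN, and the `person` table from the README example above.

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"
)

type Person struct {
	FirstName string `db:"first_name"`
	LastName  string `db:"last_name"`
	Email     string
}

func main() {
	// Placeholder DSN and schema from the README example above.
	db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable")
	if err != nil {
		log.Fatalln(err)
	}

	// Prepare once; the :first_name bindvar is rewritten for the driver.
	stmt, err := db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name")
	if err != nil {
		log.Fatalln(err)
	}
	defer stmt.Close()

	// Bind parameters from a struct...
	var people []Person
	if err := stmt.Select(&people, Person{FirstName: "Jason"}); err != nil {
		log.Fatalln(err)
	}

	// ...or from a map.
	var jason Person
	if err := stmt.Get(&jason, map[string]interface{}{"first_name": "Jason"}); err != nil {
		log.Fatalln(err)
	}
	fmt.Printf("%d match(es); first: %#v\n", len(people), jason)
}
```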
-func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { - rows, err := n.Queryx(arg) - if err != nil { - return err - } - // if something happens here, we want to make sure the rows are Closed - defer rows.Close() - return scanAll(rows, dest, false) -} - -// Get using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { - r := n.QueryRowx(arg) - return r.scanAny(dest, false) -} - -// Unsafe creates an unsafe version of the NamedStmt -func (n *NamedStmt) Unsafe() *NamedStmt { - r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} - r.Stmt.unsafe = true - return r -} - -// A union interface of preparer and binder, required to be able to prepare -// named statements (as the bindtype must be determined). -type namedPreparer interface { - Preparer - binder -} - -func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { - bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) - if err != nil { - return nil, err - } - stmt, err := Preparex(p, q) - if err != nil { - return nil, err - } - return &NamedStmt{ - QueryString: q, - Params: args, - Stmt: stmt, - }, nil -} - -func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { - if maparg, ok := arg.(map[string]interface{}); ok { - return bindMapArgs(names, maparg) - } - return bindArgs(names, arg, m) -} - -// private interface to generate a list of interfaces from a given struct -// type, given a list of names to pull out of the struct. Used by public -// BindStruct interface. -func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { - arglist := make([]interface{}, 0, len(names)) - - // grab the indirected value of arg - v := reflect.ValueOf(arg) - for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { - v = v.Elem() - } - - fields := m.TraversalsByName(v.Type(), names) - for i, t := range fields { - if len(t) == 0 { - return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) - } - val := reflectx.FieldByIndexesReadOnly(v, t) - arglist = append(arglist, val.Interface()) - } - - return arglist, nil -} - -// like bindArgs, but for maps. -func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { - arglist := make([]interface{}, 0, len(names)) - - for _, name := range names { - val, ok := arg[name] - if !ok { - return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) - } - arglist = append(arglist, val) - } - return arglist, nil -} - -// bindStruct binds a named parameter query with fields from a struct argument. -// The rules for binding field names to parameter names follow the same -// conventions as for StructScan, including obeying the `db` struct tags. -func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) - if err != nil { - return "", []interface{}{}, err - } - - arglist, err := bindArgs(names, arg, m) - if err != nil { - return "", []interface{}{}, err - } - - return bound, arglist, nil -} - -// bindMap binds a named parameter query with a map of arguments. 
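The map path above is the one `Named` and `NamedExec` take when handed a `map[string]interface{}`. A database-free sketch of that compile-and-bind step; the table and column names are illustrative only:

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmoiron/sqlx"
)

func main() {
	arg := map[string]interface{}{
		"first_name": "Bin",
		"last_name":  "Smuth",
	}

	// Named compiles the :name bindvars to '?' and pulls the matching
	// values out of the map, in query order.
	query, args, err := sqlx.Named(
		"UPDATE person SET last_name=:last_name WHERE first_name=:first_name", arg)
	if err != nil {
		log.Fatalln(err)
	}

	fmt.Println(query) // UPDATE person SET last_name=? WHERE first_name=?
	fmt.Println(args)  // [Smuth Bin]
}
```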
-func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) - if err != nil { - return "", []interface{}{}, err - } - - arglist, err := bindMapArgs(names, args) - return bound, arglist, err -} - -// -- Compilation of Named Queries - -// Allow digits and letters in bind params; additionally runes are -// checked against underscores, meaning that bind params can have be -// alphanumeric with underscores. Mind the difference between unicode -// digits and numbers, where '5' is a digit but '五' is not. -var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} - -// FIXME: this function isn't safe for unicode named params, as a failing test -// can testify. This is not a regression but a failure of the original code -// as well. It should be modified to range over runes in a string rather than -// bytes, even though this is less convenient and slower. Hopefully the -// addition of the prepared NamedStmt (which will only do this once) will make -// up for the slightly slower ad-hoc NamedExec/NamedQuery. - -// compile a NamedQuery into an unbound query (using the '?' bindvar) and -// a list of names. -func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { - names = make([]string, 0, 10) - rebound := make([]byte, 0, len(qs)) - - inName := false - last := len(qs) - 1 - currentVar := 1 - name := make([]byte, 0, 10) - - for i, b := range qs { - // a ':' while we're in a name is an error - if b == ':' { - // if this is the second ':' in a '::' escape sequence, append a ':' - if inName && i > 0 && qs[i-1] == ':' { - rebound = append(rebound, ':') - inName = false - continue - } else if inName { - err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) - return query, names, err - } - inName = true - name = []byte{} - // if we're in a name, and this is an allowed character, continue - } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { - // append the byte to the name if we are in a name and not on the last byte - name = append(name, b) - // if we're in a name and it's not an allowed character, the name is done - } else if inName { - inName = false - // if this is the final byte of the string and it is part of the name, then - // make sure to add it to the name - if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { - name = append(name, b) - } - // add the string representation to the names list - names = append(names, string(name)) - // add a proper bindvar for the bindType - switch bindType { - // oracle only supports named type bind vars even for positional - case NAMED: - rebound = append(rebound, ':') - rebound = append(rebound, name...) - case QUESTION, UNKNOWN: - rebound = append(rebound, '?') - case DOLLAR: - rebound = append(rebound, '$') - for _, b := range strconv.Itoa(currentVar) { - rebound = append(rebound, byte(b)) - } - currentVar++ - } - // add this byte to string unless it was not part of the name - if i != last { - rebound = append(rebound, b) - } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { - rebound = append(rebound, b) - } - } else { - // this is a normal byte and should just go onto the rebound query - rebound = append(rebound, b) - } - } - - return string(rebound), names, err -} - -// BindNamed binds a struct or a map to a query with named parameters. 
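When the argument is a struct, the bind names follow the same rules described above for `StructScan` column mapping: the `db` tag wins, otherwise the lowercased field name. A database-free sketch reusing the README's `Person` shape:

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmoiron/sqlx"
)

type Person struct {
	FirstName string `db:"first_name"`
	LastName  string `db:"last_name"`
	Email     string
}

func main() {
	p := Person{FirstName: "Jane", LastName: "Citizen", Email: "jane.citzen@example.com"}

	// :first_name and :last_name resolve through the `db` tags; :email
	// resolves through the lowercased field name.
	query, args, err := sqlx.Named(
		"INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", p)
	if err != nil {
		log.Fatalln(err)
	}

	fmt.Println(query) // INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)
	fmt.Println(args)  // [Jane Citizen jane.citzen@example.com]
}
```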
-// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. -func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { - return bindNamedMapper(bindType, query, arg, mapper()) -} - -// Named takes a query using named parameters and an argument and -// returns a new query with a list of args that can be executed by -// a database. The return value uses the `?` bindvar. -func Named(query string, arg interface{}) (string, []interface{}, error) { - return bindNamedMapper(QUESTION, query, arg, mapper()) -} - -func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { - if maparg, ok := arg.(map[string]interface{}); ok { - return bindMap(bindType, query, maparg) - } - return bindStruct(bindType, query, arg, m) -} - -// NamedQuery binds a named query and then runs Query on the result using the -// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with -// map[string]interface{} types. -func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { - q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) - if err != nil { - return nil, err - } - return e.Queryx(q, args...) -} - -// NamedExec uses BindStruct to get a query executable by the driver and -// then runs Exec on the result. Returns an error from the binding -// or the query excution itself. -func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { - q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) - if err != nil { - return nil, err - } - return e.Exec(q, args...) -} diff --git a/vendor/github.com/jmoiron/sqlx/named_context.go b/vendor/github.com/jmoiron/sqlx/named_context.go deleted file mode 100644 index 9405007e2..000000000 --- a/vendor/github.com/jmoiron/sqlx/named_context.go +++ /dev/null @@ -1,132 +0,0 @@ -// +build go1.8 - -package sqlx - -import ( - "context" - "database/sql" -) - -// A union interface of contextPreparer and binder, required to be able to -// prepare named statements with context (as the bindtype must be determined). -type namedPreparerContext interface { - PreparerContext - binder -} - -func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { - bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) - if err != nil { - return nil, err - } - stmt, err := PreparexContext(ctx, p, q) - if err != nil { - return nil, err - } - return &NamedStmt{ - QueryString: q, - Params: args, - Stmt: stmt, - }, nil -} - -// ExecContext executes a named statement using the struct passed. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return *new(sql.Result), err - } - return n.Stmt.ExecContext(ctx, args...) -} - -// QueryContext executes a named statement using the struct argument, returning rows. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return nil, err - } - return n.Stmt.QueryContext(ctx, args...) -} - -// QueryRowContext executes a named statement against the database. 
Because sqlx cannot -// create a *sql.Row with an error condition pre-set for binding errors, sqlx -// returns a *sqlx.Row instead. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { - args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) - if err != nil { - return &Row{err: err} - } - return n.Stmt.QueryRowxContext(ctx, args...) -} - -// MustExecContext execs a NamedStmt, panicing on error -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { - res, err := n.ExecContext(ctx, arg) - if err != nil { - panic(err) - } - return res -} - -// QueryxContext using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { - r, err := n.QueryContext(ctx, arg) - if err != nil { - return nil, err - } - return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err -} - -// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is -// an alias for QueryRow. -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { - return n.QueryRowContext(ctx, arg) -} - -// SelectContext using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { - rows, err := n.QueryxContext(ctx, arg) - if err != nil { - return err - } - // if something happens here, we want to make sure the rows are Closed - defer rows.Close() - return scanAll(rows, dest, false) -} - -// GetContext using this NamedStmt -// Any named placeholder parameters are replaced with fields from arg. -func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { - r := n.QueryRowxContext(ctx, arg) - return r.scanAny(dest, false) -} - -// NamedQueryContext binds a named query and then runs Query on the result using the -// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with -// map[string]interface{} types. -func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { - q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) - if err != nil { - return nil, err - } - return e.QueryxContext(ctx, q, args...) -} - -// NamedExecContext uses BindStruct to get a query executable by the driver and -// then runs Exec on the result. Returns an error from the binding -// or the query excution itself. -func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { - q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) - if err != nil { - return nil, err - } - return e.ExecContext(ctx, q, args...) 
-} diff --git a/vendor/github.com/jmoiron/sqlx/named_context_test.go b/vendor/github.com/jmoiron/sqlx/named_context_test.go deleted file mode 100644 index 87e94ac22..000000000 --- a/vendor/github.com/jmoiron/sqlx/named_context_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// +build go1.8 - -package sqlx - -import ( - "context" - "database/sql" - "testing" -) - -func TestNamedContextQueries(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - test := Test{t} - var ns *NamedStmt - var err error - - ctx := context.Background() - - // Check that invalid preparations fail - ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") - if err == nil { - t.Error("Expected an error with invalid prepared statement.") - } - - ns, err = db.PrepareNamedContext(ctx, "invalid sql") - if err == nil { - t.Error("Expected an error with invalid prepared statement.") - } - - // Check closing works as anticipated - ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") - test.Error(err) - err = ns.Close() - test.Error(err) - - ns, err = db.PrepareNamedContext(ctx, ` - SELECT first_name, last_name, email - FROM person WHERE first_name=:first_name AND email=:email`) - test.Error(err) - - // test Queryx w/ uses Query - p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} - - rows, err := ns.QueryxContext(ctx, p) - test.Error(err) - for rows.Next() { - var p2 Person - rows.StructScan(&p2) - if p.FirstName != p2.FirstName { - t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) - } - if p.LastName != p2.LastName { - t.Errorf("got %s, expected %s", p.LastName, p2.LastName) - } - if p.Email != p2.Email { - t.Errorf("got %s, expected %s", p.Email, p2.Email) - } - } - - // test Select - people := make([]Person, 0, 5) - err = ns.SelectContext(ctx, &people, p) - test.Error(err) - - if len(people) != 1 { - t.Errorf("got %d results, expected %d", len(people), 1) - } - if p.FirstName != people[0].FirstName { - t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) - } - if p.LastName != people[0].LastName { - t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) - } - if p.Email != people[0].Email { - t.Errorf("got %s, expected %s", p.Email, people[0].Email) - } - - // test Exec - ns, err = db.PrepareNamedContext(ctx, ` - INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email)`) - test.Error(err) - - js := Person{ - FirstName: "Julien", - LastName: "Savea", - Email: "jsavea@ab.co.nz", - } - _, err = ns.ExecContext(ctx, js) - test.Error(err) - - // Make sure we can pull him out again - p2 := Person{} - db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) - if p2.Email != js.Email { - t.Errorf("expected %s, got %s", js.Email, p2.Email) - } - - // test Txn NamedStmts - tx := db.MustBeginTx(ctx, nil) - txns := tx.NamedStmtContext(ctx, ns) - - // We're going to add Steven in this txn - sl := Person{ - FirstName: "Steven", - LastName: "Luatua", - Email: "sluatua@ab.co.nz", - } - - _, err = txns.ExecContext(ctx, sl) - test.Error(err) - // then rollback... 
- tx.Rollback() - // looking for Steven after a rollback should fail - err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) - if err != sql.ErrNoRows { - t.Errorf("expected no rows error, got %v", err) - } - - // now do the same, but commit - tx = db.MustBeginTx(ctx, nil) - txns = tx.NamedStmtContext(ctx, ns) - _, err = txns.ExecContext(ctx, sl) - test.Error(err) - tx.Commit() - - // looking for Steven after a Commit should succeed - err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) - test.Error(err) - if p2.Email != sl.Email { - t.Errorf("expected %s, got %s", sl.Email, p2.Email) - } - - }) -} diff --git a/vendor/github.com/jmoiron/sqlx/named_test.go b/vendor/github.com/jmoiron/sqlx/named_test.go deleted file mode 100644 index d3459a86f..000000000 --- a/vendor/github.com/jmoiron/sqlx/named_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package sqlx - -import ( - "database/sql" - "testing" -) - -func TestCompileQuery(t *testing.T) { - table := []struct { - Q, R, D, N string - V []string - }{ - // basic test for named parameters, invalid char ',' terminating - { - Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, - N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - V: []string{"name", "age", "first", "last"}, - }, - // This query tests a named parameter ending the string as well as numbers - { - Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, - N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, - V: []string{"name1", "name2"}, - }, - { - Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, - N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, - V: []string{"name1", "name2"}, - }, - { - Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, - R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, - D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, - N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, - V: []string{"first_name", "last_name"}, - }, - /* This unicode awareness test sadly fails, because of our byte-wise worldview. 
- * We could certainly iterate by Rune instead, though it's a great deal slower, - * it's probably the RightWay(tm) - { - Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, - R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, - N: []string{"name", "age", "first", "last"}, - }, - */ - } - - for _, test := range table { - qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) - if err != nil { - t.Error(err) - } - if qr != test.R { - t.Errorf("expected %s, got %s", test.R, qr) - } - if len(names) != len(test.V) { - t.Errorf("expected %#v, got %#v", test.V, names) - } else { - for i, name := range names { - if name != test.V[i] { - t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) - } - } - } - qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) - if qd != test.D { - t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) - } - - qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) - if qq != test.N { - t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) - } - } -} - -type Test struct { - t *testing.T -} - -func (t Test) Error(err error, msg ...interface{}) { - if err != nil { - if len(msg) == 0 { - t.t.Error(err) - } else { - t.t.Error(msg...) - } - } -} - -func (t Test) Errorf(err error, format string, args ...interface{}) { - if err != nil { - t.t.Errorf(format, args...) - } -} - -func TestNamedQueries(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - test := Test{t} - var ns *NamedStmt - var err error - - // Check that invalid preparations fail - ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name") - if err == nil { - t.Error("Expected an error with invalid prepared statement.") - } - - ns, err = db.PrepareNamed("invalid sql") - if err == nil { - t.Error("Expected an error with invalid prepared statement.") - } - - // Check closing works as anticipated - ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name") - test.Error(err) - err = ns.Close() - test.Error(err) - - ns, err = db.PrepareNamed(` - SELECT first_name, last_name, email - FROM person WHERE first_name=:first_name AND email=:email`) - test.Error(err) - - // test Queryx w/ uses Query - p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} - - rows, err := ns.Queryx(p) - test.Error(err) - for rows.Next() { - var p2 Person - rows.StructScan(&p2) - if p.FirstName != p2.FirstName { - t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) - } - if p.LastName != p2.LastName { - t.Errorf("got %s, expected %s", p.LastName, p2.LastName) - } - if p.Email != p2.Email { - t.Errorf("got %s, expected %s", p.Email, p2.Email) - } - } - - // test Select - people := make([]Person, 0, 5) - err = ns.Select(&people, p) - test.Error(err) - - if len(people) != 1 { - t.Errorf("got %d results, expected %d", len(people), 1) - } - if p.FirstName != people[0].FirstName { - t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) - } - if p.LastName != people[0].LastName { - t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) - } - if p.Email != people[0].Email { - t.Errorf("got %s, expected %s", p.Email, people[0].Email) - } - - // test Exec - ns, err = db.PrepareNamed(` - INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email)`) - test.Error(err) - - js := Person{ - FirstName: "Julien", - LastName: "Savea", - Email: "jsavea@ab.co.nz", 
- } - _, err = ns.Exec(js) - test.Error(err) - - // Make sure we can pull him out again - p2 := Person{} - db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) - if p2.Email != js.Email { - t.Errorf("expected %s, got %s", js.Email, p2.Email) - } - - // test Txn NamedStmts - tx := db.MustBegin() - txns := tx.NamedStmt(ns) - - // We're going to add Steven in this txn - sl := Person{ - FirstName: "Steven", - LastName: "Luatua", - Email: "sluatua@ab.co.nz", - } - - _, err = txns.Exec(sl) - test.Error(err) - // then rollback... - tx.Rollback() - // looking for Steven after a rollback should fail - err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) - if err != sql.ErrNoRows { - t.Errorf("expected no rows error, got %v", err) - } - - // now do the same, but commit - tx = db.MustBegin() - txns = tx.NamedStmt(ns) - _, err = txns.Exec(sl) - test.Error(err) - tx.Commit() - - // looking for Steven after a Commit should succeed - err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) - test.Error(err) - if p2.Email != sl.Email { - t.Errorf("expected %s, got %s", sl.Email, p2.Email) - } - - }) -} diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/README.md b/vendor/github.com/jmoiron/sqlx/reflectx/README.md deleted file mode 100644 index f01d3d1f0..000000000 --- a/vendor/github.com/jmoiron/sqlx/reflectx/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# reflectx - -The sqlx package has special reflect needs. In particular, it needs to: - -* be able to map a name to a field -* understand embedded structs -* understand mapping names to fields by a particular tag -* user specified name -> field mapping functions - -These behaviors mimic the behaviors by the standard library marshallers and also the -behavior of standard Go accessors. - -The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is -addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct -tags in the ways that are vital to most marshallers, and they are slow. - -This reflectx package extends reflect to achieve these goals. diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go deleted file mode 100644 index f2802b80b..000000000 --- a/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go +++ /dev/null @@ -1,422 +0,0 @@ -// Package reflectx implements extensions to the standard reflect lib suitable -// for implementing marshalling and unmarshalling packages. The main Mapper type -// allows for Go-compatible named attribute access, including accessing embedded -// struct attributes and the ability to use functions and struct tags to -// customize field names. -// -package reflectx - -import ( - "reflect" - "runtime" - "strings" - "sync" -) - -// A FieldInfo is metadata for a struct field. -type FieldInfo struct { - Index []int - Path string - Field reflect.StructField - Zero reflect.Value - Name string - Options map[string]string - Embedded bool - Children []*FieldInfo - Parent *FieldInfo -} - -// A StructMap is an index of field metadata for a struct. -type StructMap struct { - Tree *FieldInfo - Index []*FieldInfo - Paths map[string]*FieldInfo - Names map[string]*FieldInfo -} - -// GetByPath returns a *FieldInfo for a given string path. -func (f StructMap) GetByPath(path string) *FieldInfo { - return f.Paths[path] -} - -// GetByTraversal returns a *FieldInfo for a given integer path. 
It is -// analogous to reflect.FieldByIndex, but using the cached traversal -// rather than re-executing the reflect machinery each time. -func (f StructMap) GetByTraversal(index []int) *FieldInfo { - if len(index) == 0 { - return nil - } - - tree := f.Tree - for _, i := range index { - if i >= len(tree.Children) || tree.Children[i] == nil { - return nil - } - tree = tree.Children[i] - } - return tree -} - -// Mapper is a general purpose mapper of names to struct fields. A Mapper -// behaves like most marshallers in the standard library, obeying a field tag -// for name mapping but also providing a basic transform function. -type Mapper struct { - cache map[reflect.Type]*StructMap - tagName string - tagMapFunc func(string) string - mapFunc func(string) string - mutex sync.Mutex -} - -// NewMapper returns a new mapper using the tagName as its struct field tag. -// If tagName is the empty string, it is ignored. -func NewMapper(tagName string) *Mapper { - return &Mapper{ - cache: make(map[reflect.Type]*StructMap), - tagName: tagName, - } -} - -// NewMapperTagFunc returns a new mapper which contains a mapper for field names -// AND a mapper for tag values. This is useful for tags like json which can -// have values like "name,omitempty". -func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { - return &Mapper{ - cache: make(map[reflect.Type]*StructMap), - tagName: tagName, - mapFunc: mapFunc, - tagMapFunc: tagMapFunc, - } -} - -// NewMapperFunc returns a new mapper which optionally obeys a field tag and -// a struct field name mapper func given by f. Tags will take precedence, but -// for any other field, the mapped name will be f(field.Name) -func NewMapperFunc(tagName string, f func(string) string) *Mapper { - return &Mapper{ - cache: make(map[reflect.Type]*StructMap), - tagName: tagName, - mapFunc: f, - } -} - -// TypeMap returns a mapping of field strings to int slices representing -// the traversal down the struct to reach the field. -func (m *Mapper) TypeMap(t reflect.Type) *StructMap { - m.mutex.Lock() - mapping, ok := m.cache[t] - if !ok { - mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) - m.cache[t] = mapping - } - m.mutex.Unlock() - return mapping -} - -// FieldMap returns the mapper's mapping of field names to reflect values. Panics -// if v's Kind is not Struct, or v is not Indirectable to a struct kind. -func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { - v = reflect.Indirect(v) - mustBe(v, reflect.Struct) - - r := map[string]reflect.Value{} - tm := m.TypeMap(v.Type()) - for tagName, fi := range tm.Names { - r[tagName] = FieldByIndexes(v, fi.Index) - } - return r -} - -// FieldByName returns a field by its mapped name as a reflect.Value. -// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. -// Returns zero Value if the name is not found. -func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { - v = reflect.Indirect(v) - mustBe(v, reflect.Struct) - - tm := m.TypeMap(v.Type()) - fi, ok := tm.Names[name] - if !ok { - return v - } - return FieldByIndexes(v, fi.Index) -} - -// FieldsByName returns a slice of values corresponding to the slice of names -// for the value. Panics if v's Kind is not Struct or v is not Indirectable -// to a struct Kind. Returns zero Value for each name not found. 
-func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { - v = reflect.Indirect(v) - mustBe(v, reflect.Struct) - - tm := m.TypeMap(v.Type()) - vals := make([]reflect.Value, 0, len(names)) - for _, name := range names { - fi, ok := tm.Names[name] - if !ok { - vals = append(vals, *new(reflect.Value)) - } else { - vals = append(vals, FieldByIndexes(v, fi.Index)) - } - } - return vals -} - -// TraversalsByName returns a slice of int slices which represent the struct -// traversals for each mapped name. Panics if t is not a struct or Indirectable -// to a struct. Returns empty int slice for each name not found. -func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { - t = Deref(t) - mustBe(t, reflect.Struct) - tm := m.TypeMap(t) - - r := make([][]int, 0, len(names)) - for _, name := range names { - fi, ok := tm.Names[name] - if !ok { - r = append(r, []int{}) - } else { - r = append(r, fi.Index) - } - } - return r -} - -// FieldByIndexes returns a value for the field given by the struct traversal -// for the given value. -func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { - for _, i := range indexes { - v = reflect.Indirect(v).Field(i) - // if this is a pointer and it's nil, allocate a new value and set it - if v.Kind() == reflect.Ptr && v.IsNil() { - alloc := reflect.New(Deref(v.Type())) - v.Set(alloc) - } - if v.Kind() == reflect.Map && v.IsNil() { - v.Set(reflect.MakeMap(v.Type())) - } - } - return v -} - -// FieldByIndexesReadOnly returns a value for a particular struct traversal, -// but is not concerned with allocating nil pointers because the value is -// going to be used for reading and not setting. -func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { - for _, i := range indexes { - v = reflect.Indirect(v).Field(i) - } - return v -} - -// Deref is Indirect for reflect.Types -func Deref(t reflect.Type) reflect.Type { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t -} - -// -- helpers & utilities -- - -type kinder interface { - Kind() reflect.Kind -} - -// mustBe checks a value against a kind, panicing with a reflect.ValueError -// if the kind isn't that which is required. -func mustBe(v kinder, expected reflect.Kind) { - if k := v.Kind(); k != expected { - panic(&reflect.ValueError{Method: methodName(), Kind: k}) - } -} - -// methodName returns the caller of the function calling methodName -func methodName() string { - pc, _, _, _ := runtime.Caller(2) - f := runtime.FuncForPC(pc) - if f == nil { - return "unknown method" - } - return f.Name() -} - -type typeQueue struct { - t reflect.Type - fi *FieldInfo - pp string // Parent path -} - -// A copying append that creates a new slice each time. -func apnd(is []int, i int) []int { - x := make([]int, len(is)+1) - for p, n := range is { - x[p] = n - } - x[len(x)-1] = i - return x -} - -type mapf func(string) string - -// parseName parses the tag and the target name for the given field using -// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the -// field's name to a target name, and tagMapFunc for mapping the tag to -// a target name. 
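`TraversalsByName` and `FieldByIndexes` above are the pieces sqlx uses to go from column names to struct fields, including through embedded structs. A small sketch with made-up types, showing the index paths and how a traversal is applied to a value:

```go
package main

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/jmoiron/sqlx/reflectx"
)

type Timestamps struct {
	CreatedAt string `db:"created_at"`
}

type Post struct {
	Timestamps        // embedded: its fields are promoted into the mapping
	Title      string `db:"title"`
}

func main() {
	m := reflectx.NewMapperFunc("db", strings.ToLower)

	// Each traversal is the index path reflect would follow to reach the
	// field; unknown names come back as empty slices.
	trs := m.TraversalsByName(reflect.TypeOf(Post{}), []string{"title", "created_at", "missing"})
	fmt.Println(trs) // [[1] [0 0] []]

	// FieldByIndexes applies a traversal to a value.
	p := Post{Timestamps: Timestamps{CreatedAt: "2017-10-01"}, Title: "hello"}
	fmt.Println(reflectx.FieldByIndexes(reflect.ValueOf(p), trs[1]).Interface()) // 2017-10-01
}
```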
-func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { - // first, set the fieldName to the field's name - fieldName = field.Name - // if a mapFunc is set, use that to override the fieldName - if mapFunc != nil { - fieldName = mapFunc(fieldName) - } - - // if there's no tag to look for, return the field name - if tagName == "" { - return "", fieldName - } - - // if this tag is not set using the normal convention in the tag, - // then return the fieldname.. this check is done because according - // to the reflect documentation: - // If the tag does not have the conventional format, - // the value returned by Get is unspecified. - // which doesn't sound great. - if !strings.Contains(string(field.Tag), tagName+":") { - return "", fieldName - } - - // at this point we're fairly sure that we have a tag, so lets pull it out - tag = field.Tag.Get(tagName) - - // if we have a mapper function, call it on the whole tag - // XXX: this is a change from the old version, which pulled out the name - // before the tagMapFunc could be run, but I think this is the right way - if tagMapFunc != nil { - tag = tagMapFunc(tag) - } - - // finally, split the options from the name - parts := strings.Split(tag, ",") - fieldName = parts[0] - - return tag, fieldName -} - -// parseOptions parses options out of a tag string, skipping the name -func parseOptions(tag string) map[string]string { - parts := strings.Split(tag, ",") - options := make(map[string]string, len(parts)) - if len(parts) > 1 { - for _, opt := range parts[1:] { - // short circuit potentially expensive split op - if strings.Contains(opt, "=") { - kv := strings.Split(opt, "=") - options[kv[0]] = kv[1] - continue - } - options[opt] = "" - } - } - return options -} - -// getMapping returns a mapping for the t type, using the tagName, mapFunc and -// tagMapFunc to determine the canonical names of fields. -func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { - m := []*FieldInfo{} - - root := &FieldInfo{} - queue := []typeQueue{} - queue = append(queue, typeQueue{Deref(t), root, ""}) - -QueueLoop: - for len(queue) != 0 { - // pop the first item off of the queue - tq := queue[0] - queue = queue[1:] - - // ignore recursive field - for p := tq.fi.Parent; p != nil; p = p.Parent { - if tq.fi.Field.Type == p.Field.Type { - continue QueueLoop - } - } - - nChildren := 0 - if tq.t.Kind() == reflect.Struct { - nChildren = tq.t.NumField() - } - tq.fi.Children = make([]*FieldInfo, nChildren) - - // iterate through all of its fields - for fieldPos := 0; fieldPos < nChildren; fieldPos++ { - - f := tq.t.Field(fieldPos) - - // parse the tag and the target name using the mapping options for this field - tag, name := parseName(f, tagName, mapFunc, tagMapFunc) - - // if the name is "-", disabled via a tag, skip it - if name == "-" { - continue - } - - fi := FieldInfo{ - Field: f, - Name: name, - Zero: reflect.New(f.Type).Elem(), - Options: parseOptions(tag), - } - - // if the path is empty this path is just the name - if tq.pp == "" { - fi.Path = fi.Name - } else { - fi.Path = tq.pp + "." 
+ fi.Name - } - - // skip unexported fields - if len(f.PkgPath) != 0 && !f.Anonymous { - continue - } - - // bfs search of anonymous embedded structs - if f.Anonymous { - pp := tq.pp - if tag != "" { - pp = fi.Path - } - - fi.Embedded = true - fi.Index = apnd(tq.fi.Index, fieldPos) - nChildren := 0 - ft := Deref(f.Type) - if ft.Kind() == reflect.Struct { - nChildren = ft.NumField() - } - fi.Children = make([]*FieldInfo, nChildren) - queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) - } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { - fi.Index = apnd(tq.fi.Index, fieldPos) - fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) - queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) - } - - fi.Index = apnd(tq.fi.Index, fieldPos) - fi.Parent = tq.fi - tq.fi.Children[fieldPos] = &fi - m = append(m, &fi) - } - } - - flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} - for _, fi := range flds.Index { - flds.Paths[fi.Path] = fi - if fi.Name != "" && !fi.Embedded { - flds.Names[fi.Path] = fi - } - } - - return flds -} diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go deleted file mode 100644 index b702f9cd1..000000000 --- a/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go +++ /dev/null @@ -1,905 +0,0 @@ -package reflectx - -import ( - "reflect" - "strings" - "testing" -) - -func ival(v reflect.Value) int { - return v.Interface().(int) -} - -func TestBasic(t *testing.T) { - type Foo struct { - A int - B int - C int - } - - f := Foo{1, 2, 3} - fv := reflect.ValueOf(f) - m := NewMapperFunc("", func(s string) string { return s }) - - v := m.FieldByName(fv, "A") - if ival(v) != f.A { - t.Errorf("Expecting %d, got %d", ival(v), f.A) - } - v = m.FieldByName(fv, "B") - if ival(v) != f.B { - t.Errorf("Expecting %d, got %d", f.B, ival(v)) - } - v = m.FieldByName(fv, "C") - if ival(v) != f.C { - t.Errorf("Expecting %d, got %d", f.C, ival(v)) - } -} - -func TestBasicEmbedded(t *testing.T) { - type Foo struct { - A int - } - - type Bar struct { - Foo // `db:""` is implied for an embedded struct - B int - C int `db:"-"` - } - - type Baz struct { - A int - Bar `db:"Bar"` - } - - m := NewMapperFunc("db", func(s string) string { return s }) - - z := Baz{} - z.A = 1 - z.B = 2 - z.C = 4 - z.Bar.Foo.A = 3 - - zv := reflect.ValueOf(z) - fields := m.TypeMap(reflect.TypeOf(z)) - - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } - - // for _, fi := range fields.Index { - // log.Println(fi) - // } - - v := m.FieldByName(zv, "A") - if ival(v) != z.A { - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } - v = m.FieldByName(zv, "Bar.B") - if ival(v) != z.Bar.B { - t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) - } - v = m.FieldByName(zv, "Bar.A") - if ival(v) != z.Bar.Foo.A { - t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) - } - v = m.FieldByName(zv, "Bar.C") - if _, ok := v.Interface().(int); ok { - t.Errorf("Expecting Bar.C to not exist") - } - - fi := fields.GetByPath("Bar.C") - if fi != nil { - t.Errorf("Bar.C should not exist") - } -} - -func TestEmbeddedSimple(t *testing.T) { - type UUID [16]byte - type MyID struct { - UUID - } - type Item struct { - ID MyID - } - z := Item{} - - m := NewMapper("db") - m.TypeMap(reflect.TypeOf(z)) -} - -func TestBasicEmbeddedWithTags(t *testing.T) { - type Foo struct { - A int `db:"a"` - } - - type Bar struct { - Foo // 
`db:""` is implied for an embedded struct - B int `db:"b"` - } - - type Baz struct { - A int `db:"a"` - Bar // `db:""` is implied for an embedded struct - } - - m := NewMapper("db") - - z := Baz{} - z.A = 1 - z.B = 2 - z.Bar.Foo.A = 3 - - zv := reflect.ValueOf(z) - fields := m.TypeMap(reflect.TypeOf(z)) - - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } - - // for _, fi := range fields.index { - // log.Println(fi) - // } - - v := m.FieldByName(zv, "a") - if ival(v) != z.Bar.Foo.A { // the dominant field - t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) - } - v = m.FieldByName(zv, "b") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } -} - -func TestFlatTags(t *testing.T) { - m := NewMapper("db") - - type Asset struct { - Title string `db:"title"` - } - type Post struct { - Author string `db:"author,required"` - Asset Asset `db:""` - } - // Post columns: (author title) - - post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} - pv := reflect.ValueOf(post) - - v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } - v = m.FieldByName(pv, "title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } -} - -func TestNestedStruct(t *testing.T) { - m := NewMapper("db") - - type Details struct { - Active bool `db:"active"` - } - type Asset struct { - Title string `db:"title"` - Details Details `db:"details"` - } - type Post struct { - Author string `db:"author,required"` - Asset `db:"asset"` - } - // Post columns: (author asset.title asset.details.active) - - post := Post{ - Author: "Joe", - Asset: Asset{Title: "Hello", Details: Details{Active: true}}, - } - pv := reflect.ValueOf(post) - - v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } - v = m.FieldByName(pv, "title") - if _, ok := v.Interface().(string); ok { - t.Errorf("Expecting field to not exist") - } - v = m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } - v = m.FieldByName(pv, "asset.details.active") - if v.Interface().(bool) != post.Asset.Details.Active { - t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) - } -} - -func TestInlineStruct(t *testing.T) { - m := NewMapperTagFunc("db", strings.ToLower, nil) - - type Employee struct { - Name string - ID int - } - type Boss Employee - type person struct { - Employee `db:"employee"` - Boss `db:"boss"` - } - // employees columns: (employee.name employee.id boss.name boss.id) - - em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} - ev := reflect.ValueOf(em) - - fields := m.TypeMap(reflect.TypeOf(em)) - if len(fields.Index) != 6 { - t.Errorf("Expecting 6 fields") - } - - v := m.FieldByName(ev, "employee.name") - if v.Interface().(string) != em.Employee.Name { - t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) - } - v = m.FieldByName(ev, "boss.id") - if ival(v) != em.Boss.ID { - t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) - } -} - -func TestRecursiveStruct(t *testing.T) { - type Person struct { - Parent *Person - } - m := NewMapperFunc("db", strings.ToLower) - var p *Person - m.TypeMap(reflect.TypeOf(p)) -} - -func TestFieldsEmbedded(t *testing.T) { - m 
:= NewMapper("db") - - type Person struct { - Name string `db:"name,size=64"` - } - type Place struct { - Name string `db:"name"` - } - type Article struct { - Title string `db:"title"` - } - type PP struct { - Person `db:"person,required"` - Place `db:",someflag"` - Article `db:",required"` - } - // PP columns: (person.name name title) - - pp := PP{} - pp.Person.Name = "Peter" - pp.Place.Name = "Toronto" - pp.Article.Title = "Best city ever" - - fields := m.TypeMap(reflect.TypeOf(pp)) - // for i, f := range fields { - // log.Println(i, f) - // } - - ppv := reflect.ValueOf(pp) - - v := m.FieldByName(ppv, "person.name") - if v.Interface().(string) != pp.Person.Name { - t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) - } - - v = m.FieldByName(ppv, "name") - if v.Interface().(string) != pp.Place.Name { - t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) - } - - v = m.FieldByName(ppv, "title") - if v.Interface().(string) != pp.Article.Title { - t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) - } - - fi := fields.GetByPath("person") - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } - if !fi.Embedded { - t.Errorf("Expecting field to be embedded") - } - if len(fi.Index) != 1 || fi.Index[0] != 0 { - t.Errorf("Expecting index to be [0]") - } - - fi = fields.GetByPath("person.name") - if fi == nil { - t.Errorf("Expecting person.name to exist") - } - if fi.Path != "person.name" { - t.Errorf("Expecting %s, got %s", "person.name", fi.Path) - } - if fi.Options["size"] != "64" { - t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) - } - - fi = fields.GetByTraversal([]int{1, 0}) - if fi == nil { - t.Errorf("Expecting traveral to exist") - } - if fi.Path != "name" { - t.Errorf("Expecting %s, got %s", "name", fi.Path) - } - - fi = fields.GetByTraversal([]int{2}) - if fi == nil { - t.Errorf("Expecting traversal to exist") - } - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } - - trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) - if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { - t.Errorf("Expecting traversal: %v", trs) - } -} - -func TestPtrFields(t *testing.T) { - m := NewMapperTagFunc("db", strings.ToLower, nil) - type Asset struct { - Title string - } - type Post struct { - *Asset `db:"asset"` - Author string - } - - post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} - pv := reflect.ValueOf(post) - - fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 3 { - t.Errorf("Expecting 3 fields") - } - - v := m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } - v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } -} - -func TestNamedPtrFields(t *testing.T) { - m := NewMapperTagFunc("db", strings.ToLower, nil) - - type User struct { - Name string - } - - type Asset struct { - Title string - - Owner *User `db:"owner"` - } - type Post struct { - Author string - - Asset1 *Asset `db:"asset1"` - Asset2 *Asset `db:"asset2"` - } - - post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil - pv := reflect.ValueOf(post) - - fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 9 { - t.Errorf("Expecting 
9 fields") - } - - v := m.FieldByName(pv, "asset1.title") - if v.Interface().(string) != post.Asset1.Title { - t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) - } - v = m.FieldByName(pv, "asset1.owner.name") - if v.Interface().(string) != post.Asset1.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) - } - v = m.FieldByName(pv, "asset2.title") - if v.Interface().(string) != post.Asset2.Title { - t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) - } - v = m.FieldByName(pv, "asset2.owner.name") - if v.Interface().(string) != post.Asset2.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) - } - v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } -} - -func TestFieldMap(t *testing.T) { - type Foo struct { - A int - B int - C int - } - - f := Foo{1, 2, 3} - m := NewMapperFunc("db", strings.ToLower) - - fm := m.FieldMap(reflect.ValueOf(f)) - - if len(fm) != 3 { - t.Errorf("Expecting %d keys, got %d", 3, len(fm)) - } - if fm["a"].Interface().(int) != 1 { - t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) - } - if fm["b"].Interface().(int) != 2 { - t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) - } - if fm["c"].Interface().(int) != 3 { - t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) - } -} - -func TestTagNameMapping(t *testing.T) { - type Strategy struct { - StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` - StrategyName string - } - - m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string { - if strings.Contains(value, ",") { - return strings.Split(value, ",")[0] - } - return value - }) - strategy := Strategy{"1", "Alpah"} - mapping := m.TypeMap(reflect.TypeOf(strategy)) - - for _, key := range []string{"strategy_id", "STRATEGYNAME"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } - } -} - -func TestMapping(t *testing.T) { - type Person struct { - ID int - Name string - WearsGlasses bool `db:"wears_glasses"` - } - - m := NewMapperFunc("db", strings.ToLower) - p := Person{1, "Jason", true} - mapping := m.TypeMap(reflect.TypeOf(p)) - - for _, key := range []string{"id", "name", "wears_glasses"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } - } - - type SportsPerson struct { - Weight int - Age int - Person - } - s := SportsPerson{Weight: 100, Age: 30, Person: p} - mapping = m.TypeMap(reflect.TypeOf(s)) - for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } - } - - type RugbyPlayer struct { - Position int - IsIntense bool `db:"is_intense"` - IsAllBlack bool `db:"-"` - SportsPerson - } - r := RugbyPlayer{12, true, false, s} - mapping = m.TypeMap(reflect.TypeOf(r)) - for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } - } - - if fi := mapping.GetByPath("isallblack"); fi != nil { - t.Errorf("Expecting to ignore `IsAllBlack` field") - } -} - -func TestGetByTraversal(t *testing.T) { - type C struct { - C0 int - C1 int - } - type B struct { - 
B0 string - B1 *C - } - type A struct { - A0 int - A1 B - } - - testCases := []struct { - Index []int - ExpectedName string - ExpectNil bool - }{ - { - Index: []int{0}, - ExpectedName: "A0", - }, - { - Index: []int{1, 0}, - ExpectedName: "B0", - }, - { - Index: []int{1, 1, 1}, - ExpectedName: "C1", - }, - { - Index: []int{3, 4, 5}, - ExpectNil: true, - }, - { - Index: []int{}, - ExpectNil: true, - }, - { - Index: nil, - ExpectNil: true, - }, - } - - m := NewMapperFunc("db", func(n string) string { return n }) - tm := m.TypeMap(reflect.TypeOf(A{})) - - for i, tc := range testCases { - fi := tm.GetByTraversal(tc.Index) - if tc.ExpectNil { - if fi != nil { - t.Errorf("%d: expected nil, got %v", i, fi) - } - continue - } - - if fi == nil { - t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) - continue - } - - if fi.Name != tc.ExpectedName { - t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) - } - } -} - -// TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName -func TestMapperMethodsByName(t *testing.T) { - type C struct { - C0 string - C1 int - } - type B struct { - B0 *C `db:"B0"` - B1 C `db:"B1"` - B2 string `db:"B2"` - } - type A struct { - A0 *B `db:"A0"` - B `db:"A1"` - A2 int - a3 int - } - - val := &A{ - A0: &B{ - B0: &C{C0: "0", C1: 1}, - B1: C{C0: "2", C1: 3}, - B2: "4", - }, - B: B{ - B0: nil, - B1: C{C0: "5", C1: 6}, - B2: "7", - }, - A2: 8, - } - - testCases := []struct { - Name string - ExpectInvalid bool - ExpectedValue interface{} - ExpectedIndexes []int - }{ - { - Name: "A0.B0.C0", - ExpectedValue: "0", - ExpectedIndexes: []int{0, 0, 0}, - }, - { - Name: "A0.B0.C1", - ExpectedValue: 1, - ExpectedIndexes: []int{0, 0, 1}, - }, - { - Name: "A0.B1.C0", - ExpectedValue: "2", - ExpectedIndexes: []int{0, 1, 0}, - }, - { - Name: "A0.B1.C1", - ExpectedValue: 3, - ExpectedIndexes: []int{0, 1, 1}, - }, - { - Name: "A0.B2", - ExpectedValue: "4", - ExpectedIndexes: []int{0, 2}, - }, - { - Name: "A1.B0.C0", - ExpectedValue: "", - ExpectedIndexes: []int{1, 0, 0}, - }, - { - Name: "A1.B0.C1", - ExpectedValue: 0, - ExpectedIndexes: []int{1, 0, 1}, - }, - { - Name: "A1.B1.C0", - ExpectedValue: "5", - ExpectedIndexes: []int{1, 1, 0}, - }, - { - Name: "A1.B1.C1", - ExpectedValue: 6, - ExpectedIndexes: []int{1, 1, 1}, - }, - { - Name: "A1.B2", - ExpectedValue: "7", - ExpectedIndexes: []int{1, 2}, - }, - { - Name: "A2", - ExpectedValue: 8, - ExpectedIndexes: []int{2}, - }, - { - Name: "XYZ", - ExpectInvalid: true, - ExpectedIndexes: []int{}, - }, - { - Name: "a3", - ExpectInvalid: true, - ExpectedIndexes: []int{}, - }, - } - - // build the names array from the test cases - names := make([]string, len(testCases)) - for i, tc := range testCases { - names[i] = tc.Name - } - m := NewMapperFunc("db", func(n string) string { return n }) - v := reflect.ValueOf(val) - values := m.FieldsByName(v, names) - if len(values) != len(testCases) { - t.Errorf("expected %d values, got %d", len(testCases), len(values)) - t.FailNow() - } - indexes := m.TraversalsByName(v.Type(), names) - if len(indexes) != len(testCases) { - t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) - t.FailNow() - } - for i, val := range values { - tc := testCases[i] - traversal := indexes[i] - if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { - t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) - t.FailNow() - } - val = reflect.Indirect(val) - if tc.ExpectInvalid { - if val.IsValid() { - t.Errorf("%d: expected zero value, got %v", i, val) - } - continue - } 
- if !val.IsValid() { - t.Errorf("%d: expected valid value, got %v", i, val) - continue - } - actualValue := reflect.Indirect(val).Interface() - if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { - t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) - } - } -} - -func TestFieldByIndexes(t *testing.T) { - type C struct { - C0 bool - C1 string - C2 int - C3 map[string]int - } - type B struct { - B1 C - B2 *C - } - type A struct { - A1 B - A2 *B - } - testCases := []struct { - value interface{} - indexes []int - expectedValue interface{} - readOnly bool - }{ - { - value: A{ - A1: B{B1: C{C0: true}}, - }, - indexes: []int{0, 0, 0}, - expectedValue: true, - readOnly: true, - }, - { - value: A{ - A2: &B{B2: &C{C1: "answer"}}, - }, - indexes: []int{1, 1, 1}, - expectedValue: "answer", - readOnly: true, - }, - { - value: &A{}, - indexes: []int{1, 1, 3}, - expectedValue: map[string]int{}, - }, - } - - for i, tc := range testCases { - checkResults := func(v reflect.Value) { - if tc.expectedValue == nil { - if !v.IsNil() { - t.Errorf("%d: expected nil, actual %v", i, v.Interface()) - } - } else { - if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { - t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) - } - } - } - - checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes)) - if tc.readOnly { - checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes)) - } - } -} - -func TestMustBe(t *testing.T) { - typ := reflect.TypeOf(E1{}) - mustBe(typ, reflect.Struct) - - defer func() { - if r := recover(); r != nil { - valueErr, ok := r.(*reflect.ValueError) - if !ok { - t.Errorf("unexpected Method: %s", valueErr.Method) - t.Error("expected panic with *reflect.ValueError") - return - } - if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { - } - if valueErr.Kind != reflect.String { - t.Errorf("unexpected Kind: %s", valueErr.Kind) - } - } else { - t.Error("expected panic") - } - }() - - typ = reflect.TypeOf("string") - mustBe(typ, reflect.Struct) - t.Error("got here, didn't expect to") -} - -type E1 struct { - A int -} -type E2 struct { - E1 - B int -} -type E3 struct { - E2 - C int -} -type E4 struct { - E3 - D int -} - -func BenchmarkFieldNameL1(b *testing.B) { - e4 := E4{D: 1} - for i := 0; i < b.N; i++ { - v := reflect.ValueOf(e4) - f := v.FieldByName("D") - if f.Interface().(int) != 1 { - b.Fatal("Wrong value.") - } - } -} - -func BenchmarkFieldNameL4(b *testing.B) { - e4 := E4{} - e4.A = 1 - for i := 0; i < b.N; i++ { - v := reflect.ValueOf(e4) - f := v.FieldByName("A") - if f.Interface().(int) != 1 { - b.Fatal("Wrong value.") - } - } -} - -func BenchmarkFieldPosL1(b *testing.B) { - e4 := E4{D: 1} - for i := 0; i < b.N; i++ { - v := reflect.ValueOf(e4) - f := v.Field(1) - if f.Interface().(int) != 1 { - b.Fatal("Wrong value.") - } - } -} - -func BenchmarkFieldPosL4(b *testing.B) { - e4 := E4{} - e4.A = 1 - for i := 0; i < b.N; i++ { - v := reflect.ValueOf(e4) - f := v.Field(0) - f = f.Field(0) - f = f.Field(0) - f = f.Field(0) - if f.Interface().(int) != 1 { - b.Fatal("Wrong value.") - } - } -} - -func BenchmarkFieldByIndexL4(b *testing.B) { - e4 := E4{} - e4.A = 1 - idx := []int{0, 0, 0, 0} - for i := 0; i < b.N; i++ { - v := reflect.ValueOf(e4) - f := FieldByIndexes(v, idx) - if f.Interface().(int) != 1 { - b.Fatal("Wrong value.") - } - } -} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx.go b/vendor/github.com/jmoiron/sqlx/sqlx.go deleted file mode 100644 index 4859d5ac8..000000000 --- 
a/vendor/github.com/jmoiron/sqlx/sqlx.go +++ /dev/null @@ -1,1035 +0,0 @@ -package sqlx - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - - "io/ioutil" - "path/filepath" - "reflect" - "strings" - "sync" - - "github.com/jmoiron/sqlx/reflectx" -) - -// Although the NameMapper is convenient, in practice it should not -// be relied on except for application code. If you are writing a library -// that uses sqlx, you should be aware that the name mappings you expect -// can be overridden by your user's application. - -// NameMapper is used to map column names to struct field names. By default, -// it uses strings.ToLower to lowercase struct field names. It can be set -// to whatever you want, but it is encouraged to be set before sqlx is used -// as name-to-field mappings are cached after first use on a type. -var NameMapper = strings.ToLower -var origMapper = reflect.ValueOf(NameMapper) - -// Rather than creating on init, this is created when necessary so that -// importers have time to customize the NameMapper. -var mpr *reflectx.Mapper - -// mprMu protects mpr. -var mprMu sync.Mutex - -// mapper returns a valid mapper using the configured NameMapper func. -func mapper() *reflectx.Mapper { - mprMu.Lock() - defer mprMu.Unlock() - - if mpr == nil { - mpr = reflectx.NewMapperFunc("db", NameMapper) - } else if origMapper != reflect.ValueOf(NameMapper) { - // if NameMapper has changed, create a new mapper - mpr = reflectx.NewMapperFunc("db", NameMapper) - origMapper = reflect.ValueOf(NameMapper) - } - return mpr -} - -// isScannable takes the reflect.Type and the actual dest value and returns -// whether or not it's Scannable. Something is scannable if: -// * it is not a struct -// * it implements sql.Scanner -// * it has no exported fields -func isScannable(t reflect.Type) bool { - if reflect.PtrTo(t).Implements(_scannerInterface) { - return true - } - if t.Kind() != reflect.Struct { - return true - } - - // it's not important that we use the right mapper for this particular object, - // we're only concerned on how many exported fields this struct has - m := mapper() - if len(m.TypeMap(t).Index) == 0 { - return true - } - return false -} - -// ColScanner is an interface used by MapScan and SliceScan -type ColScanner interface { - Columns() ([]string, error) - Scan(dest ...interface{}) error - Err() error -} - -// Queryer is an interface used by Get and Select -type Queryer interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - Queryx(query string, args ...interface{}) (*Rows, error) - QueryRowx(query string, args ...interface{}) *Row -} - -// Execer is an interface used by MustExec and LoadFile -type Execer interface { - Exec(query string, args ...interface{}) (sql.Result, error) -} - -// Binder is an interface for something which can bind queries (Tx, DB) -type binder interface { - DriverName() string - Rebind(string) string - BindNamed(string, interface{}) (string, []interface{}, error) -} - -// Ext is a union interface which can bind, query, and exec, used by -// NamedQuery and NamedExec. -type Ext interface { - binder - Queryer - Execer -} - -// Preparer is an interface used by Preparex. 
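// Illustrative sketch of the sqlx API being removed here, not from this patch:
// the Queryer and Execer interfaces above are what let callers write helpers that
// accept either a *sqlx.DB or a *sqlx.Tx. The "person" table and the package name
// are assumed for the example only.
package example

import "github.com/jmoiron/sqlx"

// countPeople works against any Queryer, so it runs unchanged inside or outside
// a transaction.
func countPeople(q sqlx.Queryer) (int, error) {
	var n int
	// Get scans a single-row, single-column result into a basic type.
	err := sqlx.Get(q, &n, "SELECT COUNT(*) FROM person")
	return n, err
}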
-type Preparer interface { - Prepare(query string) (*sql.Stmt, error) -} - -// determine if any of our extensions are unsafe -func isUnsafe(i interface{}) bool { - switch v := i.(type) { - case Row: - return v.unsafe - case *Row: - return v.unsafe - case Rows: - return v.unsafe - case *Rows: - return v.unsafe - case NamedStmt: - return v.Stmt.unsafe - case *NamedStmt: - return v.Stmt.unsafe - case Stmt: - return v.unsafe - case *Stmt: - return v.unsafe - case qStmt: - return v.unsafe - case *qStmt: - return v.unsafe - case DB: - return v.unsafe - case *DB: - return v.unsafe - case Tx: - return v.unsafe - case *Tx: - return v.unsafe - case sql.Rows, *sql.Rows: - return false - default: - return false - } -} - -func mapperFor(i interface{}) *reflectx.Mapper { - switch i.(type) { - case DB: - return i.(DB).Mapper - case *DB: - return i.(*DB).Mapper - case Tx: - return i.(Tx).Mapper - case *Tx: - return i.(*Tx).Mapper - default: - return mapper() - } -} - -var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() -var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// Row is a reimplementation of sql.Row in order to gain access to the underlying -// sql.Rows.Columns() data, necessary for StructScan. -type Row struct { - err error - unsafe bool - rows *sql.Rows - Mapper *reflectx.Mapper -} - -// Scan is a fixed implementation of sql.Row.Scan, which does not discard the -// underlying error from the internal rows object if it exists. -func (r *Row) Scan(dest ...interface{}) error { - if r.err != nil { - return r.err - } - - // TODO(bradfitz): for now we need to defensively clone all - // []byte that the driver returned (not permitting - // *RawBytes in Rows.Scan), since we're about to close - // the Rows in our defer, when we return from this function. - // the contract with the driver.Next(...) interface is that it - // can return slices into read-only temporary memory that's - // only valid until the next Scan/Close. But the TODO is that - // for a lot of drivers, this copy will be unnecessary. We - // should provide an optional interface for drivers to - // implement to say, "don't worry, the []bytes that I return - // from Next will not be modified again." (for instance, if - // they were obtained from the network anyway) But for now we - // don't care. - defer r.rows.Close() - for _, dp := range dest { - if _, ok := dp.(*sql.RawBytes); ok { - return errors.New("sql: RawBytes isn't allowed on Row.Scan") - } - } - - if !r.rows.Next() { - if err := r.rows.Err(); err != nil { - return err - } - return sql.ErrNoRows - } - err := r.rows.Scan(dest...) - if err != nil { - return err - } - // Make sure the query can be processed to completion with no errors. - if err := r.rows.Close(); err != nil { - return err - } - return nil -} - -// Columns returns the underlying sql.Rows.Columns(), or the deferred error usually -// returned by Row.Scan() -func (r *Row) Columns() ([]string, error) { - if r.err != nil { - return []string{}, r.err - } - return r.rows.Columns() -} - -// Err returns the error encountered while scanning. -func (r *Row) Err() error { - return r.err -} - -// DB is a wrapper around sql.DB which keeps track of the driverName upon Open, -// used mostly to automatically bind named queries using the right bindvars. -type DB struct { - *sql.DB - driverName string - unsafe bool - Mapper *reflectx.Mapper -} - -// NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The -// driverName of the original database is required for named query support. 
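// Illustrative sketch of the sqlx API being removed here, not from this patch:
// NewDb (below) wraps an already-open *sql.DB and needs the driver name so that
// bindvar rebinding and named queries keep working; sqlx.Open does both steps in
// one call. The Postgres driver and DSN are placeholders.
package example

import (
	"database/sql"

	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq" // assumed driver for the example
)

func wrapExisting() (*sqlx.DB, error) {
	raw, err := sql.Open("postgres", "dbname=example sslmode=disable")
	if err != nil {
		return nil, err
	}
	// The driver name lets the wrapper pick the right bindvar style later.
	return sqlx.NewDb(raw, "postgres"), nil
}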
-func NewDb(db *sql.DB, driverName string) *DB { - return &DB{DB: db, driverName: driverName, Mapper: mapper()} -} - -// DriverName returns the driverName passed to the Open function for this DB. -func (db *DB) DriverName() string { - return db.driverName -} - -// Open is the same as sql.Open, but returns an *sqlx.DB instead. -func Open(driverName, dataSourceName string) (*DB, error) { - db, err := sql.Open(driverName, dataSourceName) - if err != nil { - return nil, err - } - return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err -} - -// MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. -func MustOpen(driverName, dataSourceName string) *DB { - db, err := Open(driverName, dataSourceName) - if err != nil { - panic(err) - } - return db -} - -// MapperFunc sets a new mapper for this db using the default sqlx struct tag -// and the provided mapper function. -func (db *DB) MapperFunc(mf func(string) string) { - db.Mapper = reflectx.NewMapperFunc("db", mf) -} - -// Rebind transforms a query from QUESTION to the DB driver's bindvar type. -func (db *DB) Rebind(query string) string { - return Rebind(BindType(db.driverName), query) -} - -// Unsafe returns a version of DB which will silently succeed to scan when -// columns in the SQL result have no fields in the destination struct. -// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its -// safety behavior. -func (db *DB) Unsafe() *DB { - return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} -} - -// BindNamed binds a query using the DB driver's bindvar type. -func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { - return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) -} - -// NamedQuery using this DB. -// Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { - return NamedQuery(db, query, arg) -} - -// NamedExec using this DB. -// Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { - return NamedExec(db, query, arg) -} - -// Select using this DB. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { - return Select(db, dest, query, args...) -} - -// Get using this DB. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { - return Get(db, dest, query, args...) -} - -// MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead -// of an *sql.Tx. -func (db *DB) MustBegin() *Tx { - tx, err := db.Beginx() - if err != nil { - panic(err) - } - return tx -} - -// Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. -func (db *DB) Beginx() (*Tx, error) { - tx, err := db.DB.Begin() - if err != nil { - return nil, err - } - return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err -} - -// Queryx queries the database and returns an *sqlx.Rows. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { - r, err := db.DB.Query(query, args...) 
- if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err -} - -// QueryRowx queries the database and returns an *sqlx.Row. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryRowx(query string, args ...interface{}) *Row { - rows, err := db.DB.Query(query, args...) - return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} -} - -// MustExec (panic) runs MustExec using this database. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) MustExec(query string, args ...interface{}) sql.Result { - return MustExec(db, query, args...) -} - -// Preparex returns an sqlx.Stmt instead of a sql.Stmt -func (db *DB) Preparex(query string) (*Stmt, error) { - return Preparex(db, query) -} - -// PrepareNamed returns an sqlx.NamedStmt -func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { - return prepareNamed(db, query) -} - -// Tx is an sqlx wrapper around sql.Tx with extra functionality -type Tx struct { - *sql.Tx - driverName string - unsafe bool - Mapper *reflectx.Mapper -} - -// DriverName returns the driverName used by the DB which began this transaction. -func (tx *Tx) DriverName() string { - return tx.driverName -} - -// Rebind a query within a transaction's bindvar type. -func (tx *Tx) Rebind(query string) string { - return Rebind(BindType(tx.driverName), query) -} - -// Unsafe returns a version of Tx which will silently succeed to scan when -// columns in the SQL result have no fields in the destination struct. -func (tx *Tx) Unsafe() *Tx { - return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} -} - -// BindNamed binds a query within a transaction's bindvar type. -func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { - return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) -} - -// NamedQuery within a transaction. -// Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { - return NamedQuery(tx, query, arg) -} - -// NamedExec a named query within a transaction. -// Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { - return NamedExec(tx, query, arg) -} - -// Select within a transaction. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { - return Select(tx, dest, query, args...) -} - -// Queryx within a transaction. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { - r, err := tx.Tx.Query(query, args...) - if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err -} - -// QueryRowx within a transaction. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { - rows, err := tx.Tx.Query(query, args...) - return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} -} - -// Get within a transaction. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { - return Get(tx, dest, query, args...) -} - -// MustExec runs MustExec within a transaction. 
-// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { - return MustExec(tx, query, args...) -} - -// Preparex a statement within a transaction. -func (tx *Tx) Preparex(query string) (*Stmt, error) { - return Preparex(tx, query) -} - -// Stmtx returns a version of the prepared statement which runs within a transaction. Provided -// stmt can be either *sql.Stmt or *sqlx.Stmt. -func (tx *Tx) Stmtx(stmt interface{}) *Stmt { - var s *sql.Stmt - switch v := stmt.(type) { - case Stmt: - s = v.Stmt - case *Stmt: - s = v.Stmt - case sql.Stmt: - s = &v - case *sql.Stmt: - s = v - default: - panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) - } - return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper} -} - -// NamedStmt returns a version of the prepared statement which runs within a transaction. -func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { - return &NamedStmt{ - QueryString: stmt.QueryString, - Params: stmt.Params, - Stmt: tx.Stmtx(stmt.Stmt), - } -} - -// PrepareNamed returns an sqlx.NamedStmt -func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { - return prepareNamed(tx, query) -} - -// Stmt is an sqlx wrapper around sql.Stmt with extra functionality -type Stmt struct { - *sql.Stmt - unsafe bool - Mapper *reflectx.Mapper -} - -// Unsafe returns a version of Stmt which will silently succeed to scan when -// columns in the SQL result have no fields in the destination struct. -func (s *Stmt) Unsafe() *Stmt { - return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} -} - -// Select using the prepared statement. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) Select(dest interface{}, args ...interface{}) error { - return Select(&qStmt{s}, dest, "", args...) -} - -// Get using the prepared statement. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func (s *Stmt) Get(dest interface{}, args ...interface{}) error { - return Get(&qStmt{s}, dest, "", args...) -} - -// MustExec (panic) using this statement. Note that the query portion of the error -// output will be blank, as Stmt does not expose its query. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) MustExec(args ...interface{}) sql.Result { - return MustExec(&qStmt{s}, "", args...) -} - -// QueryRowx using this statement. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) QueryRowx(args ...interface{}) *Row { - qs := &qStmt{s} - return qs.QueryRowx("", args...) -} - -// Queryx using this statement. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { - qs := &qStmt{s} - return qs.Queryx("", args...) -} - -// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by -// implementing those interfaces and ignoring the `query` argument. -type qStmt struct{ *Stmt } - -func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { - return q.Stmt.Query(args...) -} - -func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { - r, err := q.Stmt.Query(args...) - if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err -} - -func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { - rows, err := q.Stmt.Query(args...) 
- return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} -} - -func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { - return q.Stmt.Exec(args...) -} - -// Rows is a wrapper around sql.Rows which caches costly reflect operations -// during a looped StructScan -type Rows struct { - *sql.Rows - unsafe bool - Mapper *reflectx.Mapper - // these fields cache memory use for a rows during iteration w/ structScan - started bool - fields [][]int - values []interface{} -} - -// SliceScan using this Rows. -func (r *Rows) SliceScan() ([]interface{}, error) { - return SliceScan(r) -} - -// MapScan using this Rows. -func (r *Rows) MapScan(dest map[string]interface{}) error { - return MapScan(r, dest) -} - -// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. -// Use this and iterate over Rows manually when the memory load of Select() might be -// prohibitive. *Rows.StructScan caches the reflect work of matching up column -// positions to fields to avoid that overhead per scan, which means it is not safe -// to run StructScan on the same Rows instance with different struct types. -func (r *Rows) StructScan(dest interface{}) error { - v := reflect.ValueOf(dest) - - if v.Kind() != reflect.Ptr { - return errors.New("must pass a pointer, not a value, to StructScan destination") - } - - v = reflect.Indirect(v) - - if !r.started { - columns, err := r.Columns() - if err != nil { - return err - } - m := r.Mapper - - r.fields = m.TraversalsByName(v.Type(), columns) - // if we are not unsafe and are missing fields, return an error - if f, err := missingFields(r.fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s in %T", columns[f], dest) - } - r.values = make([]interface{}, len(columns)) - r.started = true - } - - err := fieldsByTraversal(v, r.fields, r.values, true) - if err != nil { - return err - } - // scan into the struct field pointers and append to our results - err = r.Scan(r.values...) - if err != nil { - return err - } - return r.Err() -} - -// Connect to a database and verify with a ping. -func Connect(driverName, dataSourceName string) (*DB, error) { - db, err := Open(driverName, dataSourceName) - if err != nil { - return db, err - } - err = db.Ping() - return db, err -} - -// MustConnect connects to a database and panics on error. -func MustConnect(driverName, dataSourceName string) *DB { - db, err := Connect(driverName, dataSourceName) - if err != nil { - panic(err) - } - return db -} - -// Preparex prepares a statement. -func Preparex(p Preparer, query string) (*Stmt, error) { - s, err := p.Prepare(query) - if err != nil { - return nil, err - } - return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err -} - -// Select executes a query using the provided Queryer, and StructScans each row -// into dest, which must be a slice. If the slice elements are scannable, then -// the result set must have only one column. Otherwise, StructScan is used. -// The *sql.Rows are closed automatically. -// Any placeholder parameters are replaced with supplied args. -func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { - rows, err := q.Queryx(query, args...) - if err != nil { - return err - } - // if something happens here, we want to make sure the rows are Closed - defer rows.Close() - return scanAll(rows, dest, false) -} - -// Get does a QueryRow using the provided Queryer, and scans the resulting row -// to dest. 
If dest is scannable, the result must only have one column. Otherwise, -// StructScan is used. Get will return sql.ErrNoRows like row.Scan would. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { - r := q.QueryRowx(query, args...) - return r.scanAny(dest, false) -} - -// LoadFile exec's every statement in a file (as a single call to Exec). -// LoadFile may return a nil *sql.Result if errors are encountered locating or -// reading the file at path. LoadFile reads the entire file into memory, so it -// is not suitable for loading large data dumps, but can be useful for initializing -// schemas or loading indexes. -// -// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 -// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting -// this by requiring something with DriverName() and then attempting to split the -// queries will be difficult to get right, and its current driver-specific behavior -// is deemed at least not complex in its incorrectness. -func LoadFile(e Execer, path string) (*sql.Result, error) { - realpath, err := filepath.Abs(path) - if err != nil { - return nil, err - } - contents, err := ioutil.ReadFile(realpath) - if err != nil { - return nil, err - } - res, err := e.Exec(string(contents)) - return &res, err -} - -// MustExec execs the query using e and panics if there was an error. -// Any placeholder parameters are replaced with supplied args. -func MustExec(e Execer, query string, args ...interface{}) sql.Result { - res, err := e.Exec(query, args...) - if err != nil { - panic(err) - } - return res -} - -// SliceScan using this Rows. -func (r *Row) SliceScan() ([]interface{}, error) { - return SliceScan(r) -} - -// MapScan using this Rows. -func (r *Row) MapScan(dest map[string]interface{}) error { - return MapScan(r, dest) -} - -func (r *Row) scanAny(dest interface{}, structOnly bool) error { - if r.err != nil { - return r.err - } - if r.rows == nil { - r.err = sql.ErrNoRows - return r.err - } - defer r.rows.Close() - - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return errors.New("must pass a pointer, not a value, to StructScan destination") - } - if v.IsNil() { - return errors.New("nil pointer passed to StructScan destination") - } - - base := reflectx.Deref(v.Type()) - scannable := isScannable(base) - - if structOnly && scannable { - return structOnlyError(base) - } - - columns, err := r.Columns() - if err != nil { - return err - } - - if scannable && len(columns) > 1 { - return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) - } - - if scannable { - return r.Scan(dest) - } - - m := r.Mapper - - fields := m.TraversalsByName(v.Type(), columns) - // if we are not unsafe and are missing fields, return an error - if f, err := missingFields(fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s in %T", columns[f], dest) - } - values := make([]interface{}, len(columns)) - - err = fieldsByTraversal(v, fields, values, true) - if err != nil { - return err - } - // scan into the struct field pointers and append to our results - return r.Scan(values...) -} - -// StructScan a single Row into dest. -func (r *Row) StructScan(dest interface{}) error { - return r.scanAny(dest, true) -} - -// SliceScan a row, returning a []interface{} with values similar to MapScan. 
-// This function is primarily intended for use where the number of columns -// is not known. Because you can pass an []interface{} directly to Scan, -// it's recommended that you do that as it will not have to allocate new -// slices per row. -func SliceScan(r ColScanner) ([]interface{}, error) { - // ignore r.started, since we needn't use reflect for anything. - columns, err := r.Columns() - if err != nil { - return []interface{}{}, err - } - - values := make([]interface{}, len(columns)) - for i := range values { - values[i] = new(interface{}) - } - - err = r.Scan(values...) - - if err != nil { - return values, err - } - - for i := range columns { - values[i] = *(values[i].(*interface{})) - } - - return values, r.Err() -} - -// MapScan scans a single Row into the dest map[string]interface{}. -// Use this to get results for SQL that might not be under your control -// (for instance, if you're building an interface for an SQL server that -// executes SQL from input). Please do not use this as a primary interface! -// This will modify the map sent to it in place, so reuse the same map with -// care. Columns which occur more than once in the result will overwrite -// each other! -func MapScan(r ColScanner, dest map[string]interface{}) error { - // ignore r.started, since we needn't use reflect for anything. - columns, err := r.Columns() - if err != nil { - return err - } - - values := make([]interface{}, len(columns)) - for i := range values { - values[i] = new(interface{}) - } - - err = r.Scan(values...) - if err != nil { - return err - } - - for i, column := range columns { - dest[column] = *(values[i].(*interface{})) - } - - return r.Err() -} - -type rowsi interface { - Close() error - Columns() ([]string, error) - Err() error - Next() bool - Scan(...interface{}) error -} - -// structOnlyError returns an error appropriate for type when a non-scannable -// struct is expected but something else is given -func structOnlyError(t reflect.Type) error { - isStruct := t.Kind() == reflect.Struct - isScanner := reflect.PtrTo(t).Implements(_scannerInterface) - if !isStruct { - return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) - } - if isScanner { - return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) - } - return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) -} - -// scanAll scans all rows into a destination, which must be a slice of any -// type. If the destination slice type is a Struct, then StructScan will be -// used on each row. If the destination is some other kind of base type, then -// each row must only have one column which can scan into that type. This -// allows you to do something like: -// -// rows, _ := db.Query("select id from people;") -// var ids []int -// scanAll(rows, &ids, false) -// -// and ids will be a list of the id results. I realize that this is a desirable -// interface to expose to users, but for now it will only be exposed via changes -// to `Get` and `Select`. The reason that this has been implemented like this is -// this is the only way to not duplicate reflect work in the new API while -// maintaining backwards compatibility. 
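// Illustrative sketch of the sqlx API being removed here, not from this patch:
// scanAll (below) is what backs Select and Get. The destination must be a pointer
// to a slice; struct elements go through StructScan, while basic-type elements
// require a single-column result. The person table and Person type are placeholders.
package example

import "github.com/jmoiron/sqlx"

type Person struct {
	ID        int    `db:"id"`
	FirstName string `db:"first_name"`
}

func loadPeople(db *sqlx.DB) ([]Person, []int, error) {
	// Struct destination: columns are matched to fields via the db tag mapper.
	var people []Person
	if err := db.Select(&people, "SELECT id, first_name FROM person"); err != nil {
		return nil, nil, err
	}
	// Basic-type destination: the query must return exactly one column.
	var ids []int
	if err := db.Select(&ids, "SELECT id FROM person"); err != nil {
		return nil, nil, err
	}
	return people, ids, nil
}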
-func scanAll(rows rowsi, dest interface{}, structOnly bool) error { - var v, vp reflect.Value - - value := reflect.ValueOf(dest) - - // json.Unmarshal returns errors for these - if value.Kind() != reflect.Ptr { - return errors.New("must pass a pointer, not a value, to StructScan destination") - } - if value.IsNil() { - return errors.New("nil pointer passed to StructScan destination") - } - direct := reflect.Indirect(value) - - slice, err := baseType(value.Type(), reflect.Slice) - if err != nil { - return err - } - - isPtr := slice.Elem().Kind() == reflect.Ptr - base := reflectx.Deref(slice.Elem()) - scannable := isScannable(base) - - if structOnly && scannable { - return structOnlyError(base) - } - - columns, err := rows.Columns() - if err != nil { - return err - } - - // if it's a base type make sure it only has 1 column; if not return an error - if scannable && len(columns) > 1 { - return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) - } - - if !scannable { - var values []interface{} - var m *reflectx.Mapper - - switch rows.(type) { - case *Rows: - m = rows.(*Rows).Mapper - default: - m = mapper() - } - - fields := m.TraversalsByName(base, columns) - // if we are not unsafe and are missing fields, return an error - if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { - return fmt.Errorf("missing destination name %s in %T", columns[f], dest) - } - values = make([]interface{}, len(columns)) - - for rows.Next() { - // create a new struct type (which returns PtrTo) and indirect it - vp = reflect.New(base) - v = reflect.Indirect(vp) - - err = fieldsByTraversal(v, fields, values, true) - if err != nil { - return err - } - - // scan into the struct field pointers and append to our results - err = rows.Scan(values...) - if err != nil { - return err - } - - if isPtr { - direct.Set(reflect.Append(direct, vp)) - } else { - direct.Set(reflect.Append(direct, v)) - } - } - } else { - for rows.Next() { - vp = reflect.New(base) - err = rows.Scan(vp.Interface()) - if err != nil { - return err - } - // append - if isPtr { - direct.Set(reflect.Append(direct, vp)) - } else { - direct.Set(reflect.Append(direct, reflect.Indirect(vp))) - } - } - } - - return rows.Err() -} - -// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately -// it doesn't really feel like it's named properly. There is an incongruency -// between this and the way that StructScan (which might better be ScanStruct -// anyway) works on a rows object. - -// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. -// StructScan will scan in the entire rows result, so if you do not want to -// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. -// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. -func StructScan(rows rowsi, dest interface{}) error { - return scanAll(rows, dest, true) - -} - -// reflect helpers - -func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { - t = reflectx.Deref(t) - if t.Kind() != expected { - return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) - } - return t, nil -} - -// fieldsByName fills a values interface with fields from the passed value based -// on the traversals in int. If ptrs is true, return addresses instead of values. -// We write this instead of using FieldsByName to save allocations and map lookups -// when iterating over many rows. Empty traversals will get an interface pointer. 
-// Because of the necessity of requesting ptrs or values, it's considered a bit too -// specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { - v = reflect.Indirect(v) - if v.Kind() != reflect.Struct { - return errors.New("argument not a struct") - } - - for i, traversal := range traversals { - if len(traversal) == 0 { - values[i] = new(interface{}) - continue - } - f := reflectx.FieldByIndexes(v, traversal) - if ptrs { - values[i] = f.Addr().Interface() - } else { - values[i] = f.Interface() - } - } - return nil -} - -func missingFields(transversals [][]int) (field int, err error) { - for i, t := range transversals { - if len(t) == 0 { - return i, errors.New("missing field") - } - } - return 0, nil -} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context.go b/vendor/github.com/jmoiron/sqlx/sqlx_context.go deleted file mode 100644 index 0b1714514..000000000 --- a/vendor/github.com/jmoiron/sqlx/sqlx_context.go +++ /dev/null @@ -1,335 +0,0 @@ -// +build go1.8 - -package sqlx - -import ( - "context" - "database/sql" - "fmt" - "io/ioutil" - "path/filepath" - "reflect" -) - -// ConnectContext to a database and verify with a ping. -func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { - db, err := Open(driverName, dataSourceName) - if err != nil { - return db, err - } - err = db.PingContext(ctx) - return db, err -} - -// QueryerContext is an interface used by GetContext and SelectContext -type QueryerContext interface { - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) - QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row -} - -// PreparerContext is an interface used by PreparexContext. -type PreparerContext interface { - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) -} - -// ExecerContext is an interface used by MustExecContext and LoadFileContext -type ExecerContext interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) -} - -// ExtContext is a union interface which can bind, query, and exec, with Context -// used by NamedQueryContext and NamedExecContext. -type ExtContext interface { - binder - QueryerContext - ExecerContext -} - -// SelectContext executes a query using the provided Queryer, and StructScans -// each row into dest, which must be a slice. If the slice elements are -// scannable, then the result set must have only one column. Otherwise, -// StructScan is used. The *sql.Rows are closed automatically. -// Any placeholder parameters are replaced with supplied args. -func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { - rows, err := q.QueryxContext(ctx, query, args...) - if err != nil { - return err - } - // if something happens here, we want to make sure the rows are Closed - defer rows.Close() - return scanAll(rows, dest, false) -} - -// PreparexContext prepares a statement. -// -// The provided context is used for the preparation of the statement, not for -// the execution of the statement. 
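// Illustrative sketch of the sqlx API being removed here, not from this patch:
// as the comment above notes, the context given to PreparexContext only bounds
// preparation, and each execution is bounded separately. The table and the
// Postgres-style $1 placeholder are assumptions for the example.
package example

import (
	"context"
	"time"

	"github.com/jmoiron/sqlx"
)

func firstNameByID(ctx context.Context, db *sqlx.DB, id int) (string, error) {
	stmt, err := sqlx.PreparexContext(ctx, db, "SELECT first_name FROM person WHERE id = $1")
	if err != nil {
		return "", err
	}
	defer stmt.Close()

	// Each execution gets its own deadline, independent of preparation.
	execCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	var name string
	err = stmt.GetContext(execCtx, &name, id)
	return name, err
}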
-func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { - s, err := p.PrepareContext(ctx, query) - if err != nil { - return nil, err - } - return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err -} - -// GetContext does a QueryRow using the provided Queryer, and scans the -// resulting row to dest. If dest is scannable, the result must only have one -// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like -// row.Scan would. Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { - r := q.QueryRowxContext(ctx, query, args...) - return r.scanAny(dest, false) -} - -// LoadFileContext exec's every statement in a file (as a single call to Exec). -// LoadFileContext may return a nil *sql.Result if errors are encountered -// locating or reading the file at path. LoadFile reads the entire file into -// memory, so it is not suitable for loading large data dumps, but can be useful -// for initializing schemas or loading indexes. -// -// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 -// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting -// this by requiring something with DriverName() and then attempting to split the -// queries will be difficult to get right, and its current driver-specific behavior -// is deemed at least not complex in its incorrectness. -func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { - realpath, err := filepath.Abs(path) - if err != nil { - return nil, err - } - contents, err := ioutil.ReadFile(realpath) - if err != nil { - return nil, err - } - res, err := e.ExecContext(ctx, string(contents)) - return &res, err -} - -// MustExecContext execs the query using e and panics if there was an error. -// Any placeholder parameters are replaced with supplied args. -func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { - res, err := e.ExecContext(ctx, query, args...) - if err != nil { - panic(err) - } - return res -} - -// PrepareNamedContext returns an sqlx.NamedStmt -func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { - return prepareNamedContext(ctx, db, query) -} - -// NamedQueryContext using this DB. -// Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { - return NamedQueryContext(ctx, db, query, arg) -} - -// NamedExecContext using this DB. -// Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { - return NamedExecContext(ctx, db, query, arg) -} - -// SelectContext using this DB. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { - return SelectContext(ctx, db, dest, query, args...) -} - -// GetContext using this DB. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. 
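// Illustrative sketch of the sqlx API being removed here, not from this patch:
// GetContext (below) surfaces sql.ErrNoRows when nothing matches, just as
// row.Scan would. The table and placeholder style are assumptions for the example.
package example

import (
	"context"
	"database/sql"
	"time"

	"github.com/jmoiron/sqlx"
)

func emailByID(db *sqlx.DB, id int) (string, bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var email string
	err := db.GetContext(ctx, &email, "SELECT email FROM person WHERE id = $1", id)
	if err == sql.ErrNoRows {
		return "", false, nil // an empty result is reported as a distinct case
	}
	if err != nil {
		return "", false, err
	}
	return email, true, nil
}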
-func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { - return GetContext(ctx, db, dest, query, args...) -} - -// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. -// -// The provided context is used for the preparation of the statement, not for -// the execution of the statement. -func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { - return PreparexContext(ctx, db, query) -} - -// QueryxContext queries the database and returns an *sqlx.Rows. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - r, err := db.DB.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err -} - -// QueryRowxContext queries the database and returns an *sqlx.Row. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { - rows, err := db.DB.QueryContext(ctx, query, args...) - return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} -} - -// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead -// of an *sql.Tx. -// -// The provided context is used until the transaction is committed or rolled -// back. If the context is canceled, the sql package will roll back the -// transaction. Tx.Commit will return an error if the context provided to -// MustBeginContext is canceled. -func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { - tx, err := db.BeginTxx(ctx, opts) - if err != nil { - panic(err) - } - return tx -} - -// MustExecContext (panic) runs MustExec using this database. -// Any placeholder parameters are replaced with supplied args. -func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { - return MustExecContext(ctx, db, query, args...) -} - -// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an -// *sql.Tx. -// -// The provided context is used until the transaction is committed or rolled -// back. If the context is canceled, the sql package will roll back the -// transaction. Tx.Commit will return an error if the context provided to -// BeginxContext is canceled. -func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - tx, err := db.DB.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err -} - -// StmtxContext returns a version of the prepared statement which runs within a -// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. -func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { - var s *sql.Stmt - switch v := stmt.(type) { - case Stmt: - s = v.Stmt - case *Stmt: - s = v.Stmt - case sql.Stmt: - s = &v - case *sql.Stmt: - s = v - default: - panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) - } - return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} -} - -// NamedStmtContext returns a version of the prepared statement which runs -// within a transaction. 
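// Illustrative sketch of the sqlx API being removed here, not from this patch:
// named statements and NamedExec bind ":name" placeholders from struct fields
// (via the db tag) or map keys. The table and struct are assumptions for the example.
package example

import (
	"context"

	"github.com/jmoiron/sqlx"
)

type NewPerson struct {
	FirstName string `db:"first_name"`
	LastName  string `db:"last_name"`
}

func insertPerson(ctx context.Context, db *sqlx.DB, p NewPerson) error {
	// :first_name and :last_name are filled from the tagged struct fields.
	_, err := db.NamedExecContext(ctx,
		"INSERT INTO person (first_name, last_name) VALUES (:first_name, :last_name)", p)
	return err
}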
-func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { - return &NamedStmt{ - QueryString: stmt.QueryString, - Params: stmt.Params, - Stmt: tx.StmtxContext(ctx, stmt.Stmt), - } -} - -// MustExecContext runs MustExecContext within a transaction. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { - return MustExecContext(ctx, tx, query, args...) -} - -// QueryxContext within a transaction and context. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - r, err := tx.Tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err -} - -// SelectContext within a transaction and context. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { - return SelectContext(ctx, tx, dest, query, args...) -} - -// GetContext within a transaction and context. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { - return GetContext(ctx, tx, dest, query, args...) -} - -// QueryRowxContext within a transaction and context. -// Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { - rows, err := tx.Tx.QueryContext(ctx, query, args...) - return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} -} - -// NamedExecContext using this Tx. -// Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { - return NamedExecContext(ctx, tx, query, arg) -} - -// SelectContext using the prepared statement. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { - return SelectContext(ctx, &qStmt{s}, dest, "", args...) -} - -// GetContext using the prepared statement. -// Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. -func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { - return GetContext(ctx, &qStmt{s}, dest, "", args...) -} - -// MustExecContext (panic) using this statement. Note that the query portion of -// the error output will be blank, as Stmt does not expose its query. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { - return MustExecContext(ctx, &qStmt{s}, "", args...) -} - -// QueryRowxContext using this statement. -// Any placeholder parameters are replaced with supplied args. -func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { - qs := &qStmt{s} - return qs.QueryRowxContext(ctx, "", args...) -} - -// QueryxContext using this statement. -// Any placeholder parameters are replaced with supplied args. 
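// Illustrative sketch of the sqlx API being removed here, not from this patch:
// QueryxContext plus Rows.StructScan is the streaming alternative to SelectContext
// when holding the whole result set in memory is too costly. PersonRow and the
// query are placeholders.
package example

import (
	"context"

	"github.com/jmoiron/sqlx"
)

// PersonRow mirrors the hypothetical person table used in the sketches above.
type PersonRow struct {
	ID        int    `db:"id"`
	FirstName string `db:"first_name"`
}

func eachPerson(ctx context.Context, db *sqlx.DB, visit func(PersonRow) error) error {
	rows, err := db.QueryxContext(ctx, "SELECT id, first_name FROM person")
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var r PersonRow
		// StructScan caches the column-to-field traversal after the first row.
		if err := rows.StructScan(&r); err != nil {
			return err
		}
		if err := visit(r); err != nil {
			return err
		}
	}
	return rows.Err()
}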
-func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { - qs := &qStmt{s} - return qs.QueryxContext(ctx, "", args...) -} - -func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - return q.Stmt.QueryContext(ctx, args...) -} - -func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - r, err := q.Stmt.QueryContext(ctx, args...) - if err != nil { - return nil, err - } - return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err -} - -func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { - rows, err := q.Stmt.QueryContext(ctx, args...) - return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} -} - -func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - return q.Stmt.ExecContext(ctx, args...) -} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go deleted file mode 100644 index 85e112bd5..000000000 --- a/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go +++ /dev/null @@ -1,1344 +0,0 @@ -// +build go1.8 - -// The following environment variables, if set, will be used: -// -// * SQLX_SQLITE_DSN -// * SQLX_POSTGRES_DSN -// * SQLX_MYSQL_DSN -// -// Set any of these variables to 'skip' to skip them. Note that for MySQL, -// the string '?parseTime=True' will be appended to the DSN if it's not there -// already. -// -package sqlx - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "log" - "strings" - "testing" - "time" - - _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx/reflectx" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" -) - -func MultiExecContext(ctx context.Context, e ExecerContext, query string) { - stmts := strings.Split(query, ";\n") - if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { - stmts = stmts[:len(stmts)-1] - } - for _, s := range stmts { - _, err := e.ExecContext(ctx, s) - if err != nil { - fmt.Println(err, s) - } - } -} - -func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { - runner := func(ctx context.Context, db *DB, t *testing.T, create, drop string) { - defer func() { - MultiExecContext(ctx, db, drop) - }() - - MultiExecContext(ctx, db, create) - test(ctx, db, t) - } - - if TestPostgres { - create, drop := schema.Postgres() - runner(ctx, pgdb, t, create, drop) - } - if TestSqlite { - create, drop := schema.Sqlite3() - runner(ctx, sldb, t, create, drop) - } - if TestMysql { - create, drop := schema.MySQL() - runner(ctx, mysqldb, t, create, drop) - } -} - -func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { - tx := db.MustBeginTx(ctx, nil) - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") - if db.DriverName() 
== "mysql" { - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") - } else { - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") - } - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") - tx.Commit() -} - -// Test a new backwards compatible feature, that missing scan destinations -// will silently scan into sql.RawText rather than failing/panicing -func TestMissingNamesContextContext(t *testing.T) { - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - type PersonPlus struct { - FirstName string `db:"first_name"` - LastName string `db:"last_name"` - Email string - //AddedAt time.Time `db:"added_at"` - } - - // test Select first - pps := []PersonPlus{} - // pps lacks added_at destination - err := db.SelectContext(ctx, &pps, "SELECT * FROM person") - if err == nil { - t.Error("Expected missing name from Select to fail, but it did not.") - } - - // test Get - pp := PersonPlus{} - err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") - if err == nil { - t.Error("Expected missing name Get to fail, but it did not.") - } - - // test naked StructScan - pps = []PersonPlus{} - rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") - if err != nil { - t.Fatal(err) - } - rows.Next() - err = StructScan(rows, &pps) - if err == nil { - t.Error("Expected missing name in StructScan to fail, but it did not.") - } - rows.Close() - - // now try various things with unsafe set. 
- db = db.Unsafe() - pps = []PersonPlus{} - err = db.SelectContext(ctx, &pps, "SELECT * FROM person") - if err != nil { - t.Error(err) - } - - // test Get - pp = PersonPlus{} - err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") - if err != nil { - t.Error(err) - } - - // test naked StructScan - pps = []PersonPlus{} - rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") - if err != nil { - t.Fatal(err) - } - rowsx.Next() - err = StructScan(rowsx, &pps) - if err != nil { - t.Error(err) - } - rowsx.Close() - - // test Named stmt - if !isUnsafe(db) { - t.Error("Expected db to be unsafe, but it isn't") - } - nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) - if err != nil { - t.Fatal(err) - } - // its internal stmt should be marked unsafe - if !nstmt.Stmt.unsafe { - t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") - } - pps = []PersonPlus{} - err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) - if err != nil { - t.Fatal(err) - } - if len(pps) != 1 { - t.Errorf("Expected 1 person back, got %d", len(pps)) - } - - // test it with a safe db - db.unsafe = false - if isUnsafe(db) { - t.Error("expected db to be safe but it isn't") - } - nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) - if err != nil { - t.Fatal(err) - } - // it should be safe - if isUnsafe(nstmt) { - t.Error("NamedStmt did not inherit safety") - } - nstmt.Unsafe() - if !isUnsafe(nstmt) { - t.Error("expected newly unsafed NamedStmt to be unsafe") - } - pps = []PersonPlus{} - err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) - if err != nil { - t.Fatal(err) - } - if len(pps) != 1 { - t.Errorf("Expected 1 person back, got %d", len(pps)) - } - - }) -} - -func TestEmbeddedStructsContextContext(t *testing.T) { - type Loop1 struct{ Person } - type Loop2 struct{ Loop1 } - type Loop3 struct{ Loop2 } - - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - peopleAndPlaces := []PersonPlace{} - err := db.SelectContext( - ctx, - &peopleAndPlaces, - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Fatal(err) - } - for _, pp := range peopleAndPlaces { - if len(pp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(pp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - } - - // test embedded structs with StructScan - rows, err := db.QueryxContext( - ctx, - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Error(err) - } - - perp := PersonPlace{} - rows.Next() - err = rows.StructScan(&perp) - if err != nil { - t.Error(err) - } - - if len(perp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(perp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - - rows.Close() - - // test the same for embedded pointer structs - peopleAndPlacesPtrs := []PersonPlacePtr{} - err = db.SelectContext( - ctx, - &peopleAndPlacesPtrs, - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Fatal(err) - } - for _, pp := range peopleAndPlacesPtrs { - if len(pp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(pp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - } - - // test "deep 
nesting" - l3s := []Loop3{} - err = db.SelectContext(ctx, &l3s, `select * from person`) - if err != nil { - t.Fatal(err) - } - for _, l3 := range l3s { - if len(l3.Loop2.Loop1.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - } - - // test "embed conflicts" - ec := []EmbedConflict{} - err = db.SelectContext(ctx, &ec, `select * from person`) - // I'm torn between erroring here or having some kind of working behavior - // in order to allow for more flexibility in destination structs - if err != nil { - t.Errorf("Was not expecting an error on embed conflicts.") - } - }) -} - -func TestJoinQueryContext(t *testing.T) { - type Employee struct { - Name string - ID int64 - // BossID is an id into the employee table - BossID sql.NullInt64 `db:"boss_id"` - } - type Boss Employee - - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - - var employees []struct { - Employee - Boss `db:"boss"` - } - - err := db.SelectContext(ctx, - &employees, - `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees - JOIN employees AS boss ON employees.boss_id = boss.id`) - if err != nil { - t.Fatal(err) - } - - for _, em := range employees { - if len(em.Employee.Name) == 0 { - t.Errorf("Expected non zero lengthed name.") - } - if em.Employee.BossID.Int64 != em.Boss.ID { - t.Errorf("Expected boss ids to match") - } - } - }) -} - -func TestJoinQueryNamedPointerStructsContext(t *testing.T) { - type Employee struct { - Name string - ID int64 - // BossID is an id into the employee table - BossID sql.NullInt64 `db:"boss_id"` - } - type Boss Employee - - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - - var employees []struct { - Emp1 *Employee `db:"emp1"` - Emp2 *Employee `db:"emp2"` - *Boss `db:"boss"` - } - - err := db.SelectContext(ctx, - &employees, - `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", - emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", - boss.id "boss.id", boss.name "boss.name" FROM employees AS emp - JOIN employees AS boss ON emp.boss_id = boss.id - `) - if err != nil { - t.Fatal(err) - } - - for _, em := range employees { - if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { - t.Errorf("Expected non zero lengthed name.") - } - if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { - t.Errorf("Expected boss ids to match") - } - } - }) -} - -func TestSelectSliceMapTimeContext(t *testing.T) { - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - rows, err := db.QueryxContext(ctx, "SELECT * FROM person") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - _, err := rows.SliceScan() - if err != nil { - t.Error(err) - } - } - - rows, err = db.QueryxContext(ctx, "SELECT * FROM person") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - m := map[string]interface{}{} - err := rows.MapScan(m) - if err != nil { - t.Error(err) - } - } - - }) -} - -func TestNilReceiverContext(t *testing.T) { - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - var p *Person - err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") - if err == nil { - t.Error("Expected error when getting 
into nil struct ptr.") - } - var pp *[]Person - err = db.SelectContext(ctx, pp, "SELECT * FROM person") - if err == nil { - t.Error("Expected an error when selecting into nil slice ptr.") - } - }) -} - -func TestNamedQueryContext(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE place ( - id integer PRIMARY KEY, - name text NULL - ); - CREATE TABLE person ( - first_name text NULL, - last_name text NULL, - email text NULL - ); - CREATE TABLE placeperson ( - first_name text NULL, - last_name text NULL, - email text NULL, - place_id integer NULL - ); - CREATE TABLE jsperson ( - "FIRST" text NULL, - last_name text NULL, - "EMAIL" text NULL - );`, - drop: ` - drop table person; - drop table jsperson; - drop table place; - drop table placeperson; - `, - } - - RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { - type Person struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString - } - - p := Person{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "doe", Valid: true}, - Email: sql.NullString{String: "ben@doe.com", Valid: true}, - } - - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` - _, err := db.NamedExecContext(ctx, q1, p) - if err != nil { - log.Fatal(err) - } - - p2 := &Person{} - rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(p2) - if err != nil { - t.Error(err) - } - if p2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + p2.FirstName.String) - } - if p2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + p2.LastName.String) - } - } - - // these are tests for #73; they verify that named queries work if you've - // changed the db mapper. This code checks both NamedQuery "ad-hoc" style - // queries and NamedStmt queries, which use different code paths internally. - old := *db.Mapper - - type JSONPerson struct { - FirstName sql.NullString `json:"FIRST"` - LastName sql.NullString `json:"last_name"` - Email sql.NullString - } - - jp := JSONPerson{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "smith", Valid: true}, - Email: sql.NullString{String: "ben@smith.com", Valid: true}, - } - - db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) - - // prepare queries for case sensitivity to test our ToUpper function. 
- // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line - // strings are `` we use "" by default and swap out for MySQL - pdb := func(s string, db *DB) string { - if db.DriverName() == "mysql" { - return strings.Replace(s, `"`, "`", -1) - } - return s - } - - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` - _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) - if err != nil { - t.Fatal(err, db.DriverName()) - } - - // Checks that a person pulled out of the db matches the one we put in - check := func(t *testing.T, rows *Rows) { - jp = JSONPerson{} - for rows.Next() { - err = rows.StructScan(&jp) - if err != nil { - t.Error(err) - } - if jp.FirstName.String != "ben" { - t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) - } - if jp.LastName.String != "smith" { - t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) - } - if jp.Email.String != "ben@smith.com" { - t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) - } - } - } - - ns, err := db.PrepareNamed(pdb(` - SELECT * FROM jsperson - WHERE - "FIRST"=:FIRST AND - last_name=:last_name AND - "EMAIL"=:EMAIL - `, db)) - - if err != nil { - t.Fatal(err) - } - rows, err = ns.QueryxContext(ctx, jp) - if err != nil { - t.Fatal(err) - } - - check(t, rows) - - // Check exactly the same thing, but with db.NamedQuery, which does not go - // through the PrepareNamed/NamedStmt path. - rows, err = db.NamedQueryContext(ctx, pdb(` - SELECT * FROM jsperson - WHERE - "FIRST"=:FIRST AND - last_name=:last_name AND - "EMAIL"=:EMAIL - `, db), jp) - if err != nil { - t.Fatal(err) - } - - check(t, rows) - - db.Mapper = &old - - // Test nested structs - type Place struct { - ID int `db:"id"` - Name sql.NullString `db:"name"` - } - type PlacePerson struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString - Place Place `db:"place"` - } - - pl := Place{ - Name: sql.NullString{String: "myplace", Valid: true}, - } - - pp := PlacePerson{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "doe", Valid: true}, - Email: sql.NullString{String: "ben@doe.com", Valid: true}, - } - - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` - _, err = db.NamedExecContext(ctx, q2, pl) - if err != nil { - log.Fatal(err) - } - - id := 1 - pp.Place.ID = id - - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` - _, err = db.NamedExecContext(ctx, q3, pp) - if err != nil { - log.Fatal(err) - } - - pp2 := &PlacePerson{} - rows, err = db.NamedQueryContext(ctx, ` - SELECT - first_name, - last_name, - email, - place.id AS "place.id", - place.name AS "place.name" - FROM placeperson - INNER JOIN place ON place.id = placeperson.place_id - WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp2) - if err != nil { - t.Error(err) - } - if pp2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) - } - if pp2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + pp2.LastName.String) - } - if pp2.Place.Name.String != "myplace" { - t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) - } - if pp2.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, 
pp2.Place.ID) - } - } - }) -} - -func TestNilInsertsContext(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE tt ( - id integer, - value text NULL DEFAULT NULL - );`, - drop: "drop table tt;", - } - - RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { - type TT struct { - ID int - Value *string - } - var v, v2 TT - r := db.Rebind - - db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) - db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) - if v.ID != 1 { - t.Errorf("Expecting id of 1, got %v", v.ID) - } - if v.Value != nil { - t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) - } - - v.ID = 2 - // NOTE: this incidentally uncovered a bug which was that named queries with - // pointer destinations would not work if the passed value here was not addressable, - // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for - // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly - // function. This next line is important as it provides the only coverage for this. - db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) - - db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) - if v.ID != v2.ID { - t.Errorf("%v != %v", v.ID, v2.ID) - } - if v2.Value != nil { - t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) - } - }) -} - -func TestScanErrorContext(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE kv ( - k text, - v integer - );`, - drop: `drop table kv;`, - } - - RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { - type WrongTypes struct { - K int - V string - } - _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) - if err != nil { - t.Error(err) - } - - rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") - if err != nil { - t.Error(err) - } - for rows.Next() { - var wt WrongTypes - err := rows.StructScan(&wt) - if err == nil { - t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) - } - } - }) -} - -// FIXME: this function is kinda big but it slows things down to be constantly -// loading and reloading the schema.. 
- -func TestUsageContext(t *testing.T) { - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - slicemembers := []SliceMember{} - err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Fatal(err) - } - - people := []Person{} - - err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") - if err != nil { - t.Fatal(err) - } - - jason, john := people[0], people[1] - if jason.FirstName != "Jason" { - t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) - } - if jason.LastName != "Moiron" { - t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) - } - if jason.Email != "jmoiron@jmoiron.net" { - t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) - } - if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { - t.Errorf("John Doe's person record not what expected: Got %v\n", john) - } - - jason = Person{} - err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") - - if err != nil { - t.Fatal(err) - } - if jason.FirstName != "Jason" { - t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) - } - - err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") - if err == nil { - t.Errorf("Expecting an error, got nil\n") - } - if err != sql.ErrNoRows { - t.Errorf("Expected sql.ErrNoRows, got %v\n", err) - } - - // The following tests check statement reuse, which was actually a problem - // due to copying being done when creating Stmt's which was eventually removed - stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - jason = Person{} - - row := stmt1.QueryRowx("DoesNotExist") - row.Scan(&jason) - row = stmt1.QueryRowx("DoesNotExist") - row.Scan(&jason) - - err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") - if err == nil { - t.Error("Expected an error") - } - err = stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") - if err == nil { - t.Fatal(err) - } - - stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - jason = Person{} - tx, err := db.Beginx() - if err != nil { - t.Fatal(err) - } - tstmt2 := tx.Stmtx(stmt2) - row2 := tstmt2.QueryRowx("Jason") - err = row2.StructScan(&jason) - if err != nil { - t.Error(err) - } - tx.Commit() - - places := []*Place{} - err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") - if err != nil { - t.Fatal(err) - } - - usa, singsing, honkers := places[0], places[1], places[2] - - if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { - t.Errorf("Expected integer telcodes to work, got %#v", places) - } - - placesptr := []PlacePtr{} - err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Error(err) - } - //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) - - // if you have null fields and use SELECT *, you must use sql.Null* in your struct - // this test also verifies that you can use either a []Struct{} or a []*Struct{} - places2 := []Place{} - err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Fatal(err) - } - - usa, singsing, honkers = &places2[0], &places2[1], &places2[2] - - // 
this should return a type error that &p is not a pointer to a struct slice - p := Place{} - err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") - if err == nil { - t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") - } - - // this should be an error - pl := []Place{} - err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") - if err == nil { - t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") - } - - if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { - t.Errorf("Expected integer telcodes to work, got %#v", places) - } - - stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) - if err != nil { - t.Error(err) - } - - places = []*Place{} - err = stmt.SelectContext(ctx, &places, 10) - if len(places) != 2 { - t.Error("Expected 2 places, got 0.") - } - if err != nil { - t.Fatal(err) - } - singsing, honkers = places[0], places[1] - if singsing.TelCode != 65 || honkers.TelCode != 852 { - t.Errorf("Expected the right telcodes, got %#v", places) - } - - rows, err := db.QueryxContext(ctx, "SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - place := Place{} - for rows.Next() { - err = rows.StructScan(&place) - if err != nil { - t.Fatal(err) - } - } - - rows, err = db.QueryxContext(ctx, "SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - m := map[string]interface{}{} - for rows.Next() { - err = rows.MapScan(m) - if err != nil { - t.Fatal(err) - } - _, ok := m["country"] - if !ok { - t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) - } - } - - rows, err = db.QueryxContext(ctx, "SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - s, err := rows.SliceScan() - if err != nil { - t.Error(err) - } - if len(s) != 3 { - t.Errorf("Expected 3 columns in result, got %d\n", len(s)) - } - } - - // test advanced querying - // test that NamedExec works with a map as well as a struct - _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ - "first": "Bin", - "last": "Smuth", - "email": "bensmith@allblacks.nz", - }) - if err != nil { - t.Fatal(err) - } - - // ensure that if the named param happens right at the end it still works - // ensure that NamedQuery works with a map[string]interface{} - rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) - if err != nil { - t.Fatal(err) - } - - ben := &Person{} - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Bin" { - t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) - } - if ben.LastName != "Smuth" { - t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) - } - } - - ben.FirstName = "Ben" - ben.LastName = "Smith" - ben.Email = "binsmuth@allblacks.nz" - - // Insert via a named query using the struct - _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) - - if err != nil { - t.Fatal(err) - } - - rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) - if err != nil { - t.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Ben" { - t.Fatal("Expected first 
name of `Ben`, got " + ben.FirstName) - } - if ben.LastName != "Smith" { - t.Fatal("Expected first name of `Smith`, got " + ben.LastName) - } - } - // ensure that Get does not panic on emppty result set - person := &Person{} - err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") - if err == nil { - t.Fatal("Should have got an error for Get on non-existant row.") - } - - // lets test prepared statements some more - - stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - rows, err = stmt.QueryxContext(ctx, "Ben") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Ben" { - t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) - } - if ben.LastName != "Smith" { - t.Fatal("Expected first name of `Smith`, got " + ben.LastName) - } - } - - john = Person{} - stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Error(err) - } - err = stmt.GetContext(ctx, &john, "John") - if err != nil { - t.Error(err) - } - - // test name mapping - // THIS USED TO WORK BUT WILL NO LONGER WORK. - db.MapperFunc(strings.ToUpper) - rsa := CPlace{} - err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") - if err != nil { - t.Error(err, "in db:", db.DriverName()) - } - db.MapperFunc(strings.ToLower) - - // create a copy and change the mapper, then verify the copy behaves - // differently from the original. - dbCopy := NewDb(db.DB, db.DriverName()) - dbCopy.MapperFunc(strings.ToUpper) - err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") - if err != nil { - fmt.Println(db.DriverName()) - t.Error(err) - } - - err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") - if err == nil { - t.Error("Expected no error, got ", err) - } - - // test base type slices - var sdest []string - rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") - if err != nil { - t.Error(err) - } - err = scanAll(rows, &sdest, false) - if err != nil { - t.Error(err) - } - - // test Get with base types - var count int - err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") - if err != nil { - t.Error(err) - } - if count != len(sdest) { - t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) - } - - // test Get and Select with time.Time, #84 - var addedAt time.Time - err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") - if err != nil { - t.Error(err) - } - - var addedAts []time.Time - err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") - if err != nil { - t.Error(err) - } - - // test it on a double pointer - var pcount *int - err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") - if err != nil { - t.Error(err) - } - if *pcount != count { - t.Errorf("expected %d = %d", *pcount, count) - } - - // test Select... 
- sdest = []string{} - err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") - if err != nil { - t.Error(err) - } - expected := []string{"Ben", "Bin", "Jason", "John"} - for i, got := range sdest { - if got != expected[i] { - t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) - } - } - - var nsdest []sql.NullString - err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") - if err != nil { - t.Error(err) - } - for _, val := range nsdest { - if val.Valid && val.String != "New York" { - t.Errorf("expected single valid result to be `New York`, but got %s", val.String) - } - } - }) -} - -// tests that sqlx will not panic when the wrong driver is passed because -// of an automatic nil dereference in sqlx.Open(), which was fixed. -func TestDoNotPanicOnConnectContext(t *testing.T) { - _, err := ConnectContext(context.Background(), "bogus", "hehe") - if err == nil { - t.Errorf("Should return error when using bogus driverName") - } -} - -func TestEmbeddedMapsContext(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE message ( - string text, - properties text - );`, - drop: `drop table message;`, - } - - RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { - messages := []Message{ - {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, - {"Thanks, Joy", PropertyMap{"pull": "request"}}, - } - q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` - for _, m := range messages { - _, err := db.NamedExecContext(ctx, q1, m) - if err != nil { - t.Fatal(err) - } - } - var count int - err := db.GetContext(ctx, &count, "SELECT count(*) FROM message") - if err != nil { - t.Fatal(err) - } - if count != len(messages) { - t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) - } - - var m Message - err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") - if err != nil { - t.Fatal(err) - } - if m.Properties == nil { - t.Fatal("Expected m.Properties to not be nil, but it was.") - } - }) -} - -func TestIssue197Context(t *testing.T) { - // this test actually tests for a bug in database/sql: - // https://github.com/golang/go/issues/13905 - // this potentially makes _any_ named type that is an alias for []byte - // unsafe to use in a lot of different ways (basically, unsafe to hold - // onto after loading from the database). 
- t.Skip() - - type mybyte []byte - type Var struct{ Raw json.RawMessage } - type Var2 struct{ Raw []byte } - type Var3 struct{ Raw mybyte } - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - var err error - var v, q Var - if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { - t.Fatal(err) - } - if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { - t.Fatal(err) - } - - var v2, q2 Var2 - if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { - t.Fatal(err) - } - if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { - t.Fatal(err) - } - - var v3, q3 Var3 - if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { - t.Fatal(err) - } - if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { - t.Fatal(err) - } - t.Fail() - }) -} - -func TestInContext(t *testing.T) { - // some quite normal situations - type tr struct { - q string - args []interface{} - c int - } - tests := []tr{ - {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", - []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, - 7}, - {"SELECT * FROM foo WHERE x in (?)", - []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, - 8}, - } - for _, test := range tests { - q, a, err := In(test.q, test.args...) - if err != nil { - t.Error(err) - } - if len(a) != test.c { - t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) - } - if strings.Count(q, "?") != test.c { - t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) - } - } - - // too many bindVars, but no slices, so short circuits parsing - // i'm not sure if this is the right behavior; this query/arg combo - // might not work, but we shouldn't parse if we don't need to - { - orig := "SELECT * FROM foo WHERE x = ? AND y = ?" - q, a, err := In(orig, "foo", "bar", "baz") - if err != nil { - t.Error(err) - } - if len(a) != 3 { - t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) - } - if q != orig { - t.Error("Expected unchanged query.") - } - } - - tests = []tr{ - // too many bindvars; slice present so should return error during parse - {"SELECT * FROM foo WHERE x = ? and y = ?", - []interface{}{"foo", []int{1, 2, 3}, "bar"}, - 0}, - // empty slice, should return error before parse - {"SELECT * FROM foo WHERE x = ?", - []interface{}{[]int{}}, - 0}, - // too *few* bindvars, should return an error - {"SELECT * FROM foo WHERE x = ? AND y in (?)", - []interface{}{[]int{1, 2, 3}}, - 0}, - } - for _, test := range tests { - _, _, err := In(test.q, test.args...) - if err == nil { - t.Error("Expected an error, but got nil.") - } - } - RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { - loadDefaultFixtureContext(ctx, db, t) - //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") - telcodes := []int{852, 65} - q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" - query, args, err := In(q, telcodes) - if err != nil { - t.Error(err) - } - query = db.Rebind(query) - places := []Place{} - err = db.SelectContext(ctx, &places, query, args...) 
- if err != nil { - t.Error(err) - } - if len(places) != 2 { - t.Fatalf("Expecting 2 results, got %d", len(places)) - } - if places[0].TelCode != 65 { - t.Errorf("Expecting singapore first, but got %#v", places[0]) - } - if places[1].TelCode != 852 { - t.Errorf("Expecting hong kong second, but got %#v", places[1]) - } - }) -} - -func TestEmbeddedLiteralsContext(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE x ( - k text - );`, - drop: `drop table x;`, - } - - RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { - type t1 struct { - K *string - } - type t2 struct { - Inline struct { - F string - } - K *string - } - - db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") - - target := t1{} - err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") - if err != nil { - t.Error(err) - } - if *target.K != "one" { - t.Error("Expected target.K to be `one`, got ", target.K) - } - - target2 := t2{} - err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") - if err != nil { - t.Error(err) - } - if *target2.K != "one" { - t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) - } - }) -} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_test.go deleted file mode 100644 index 5752773a0..000000000 --- a/vendor/github.com/jmoiron/sqlx/sqlx_test.go +++ /dev/null @@ -1,1792 +0,0 @@ -// The following environment variables, if set, will be used: -// -// * SQLX_SQLITE_DSN -// * SQLX_POSTGRES_DSN -// * SQLX_MYSQL_DSN -// -// Set any of these variables to 'skip' to skip them. Note that for MySQL, -// the string '?parseTime=True' will be appended to the DSN if it's not there -// already. 
-// -package sqlx - -import ( - "database/sql" - "database/sql/driver" - "encoding/json" - "fmt" - "log" - "os" - "reflect" - "strings" - "testing" - "time" - - _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx/reflectx" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" -) - -/* compile time checks that Db, Tx, Stmt (qStmt) implement expected interfaces */ -var _, _ Ext = &DB{}, &Tx{} -var _, _ ColScanner = &Row{}, &Rows{} -var _ Queryer = &qStmt{} -var _ Execer = &qStmt{} - -var TestPostgres = true -var TestSqlite = true -var TestMysql = true - -var sldb *DB -var pgdb *DB -var mysqldb *DB -var active = []*DB{} - -func init() { - ConnectAll() -} - -func ConnectAll() { - var err error - - pgdsn := os.Getenv("SQLX_POSTGRES_DSN") - mydsn := os.Getenv("SQLX_MYSQL_DSN") - sqdsn := os.Getenv("SQLX_SQLITE_DSN") - - TestPostgres = pgdsn != "skip" - TestMysql = mydsn != "skip" - TestSqlite = sqdsn != "skip" - - if !strings.Contains(mydsn, "parseTime=true") { - mydsn += "?parseTime=true" - } - - if TestPostgres { - pgdb, err = Connect("postgres", pgdsn) - if err != nil { - fmt.Printf("Disabling PG tests:\n %v\n", err) - TestPostgres = false - } - } else { - fmt.Println("Disabling Postgres tests.") - } - - if TestMysql { - mysqldb, err = Connect("mysql", mydsn) - if err != nil { - fmt.Printf("Disabling MySQL tests:\n %v", err) - TestMysql = false - } - } else { - fmt.Println("Disabling MySQL tests.") - } - - if TestSqlite { - sldb, err = Connect("sqlite3", sqdsn) - if err != nil { - fmt.Printf("Disabling SQLite:\n %v", err) - TestSqlite = false - } - } else { - fmt.Println("Disabling SQLite tests.") - } -} - -type Schema struct { - create string - drop string -} - -func (s Schema) Postgres() (string, string) { - return s.create, s.drop -} - -func (s Schema) MySQL() (string, string) { - return strings.Replace(s.create, `"`, "`", -1), s.drop -} - -func (s Schema) Sqlite3() (string, string) { - return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop -} - -var defaultSchema = Schema{ - create: ` -CREATE TABLE person ( - first_name text, - last_name text, - email text, - added_at timestamp default now() -); - -CREATE TABLE place ( - country text, - city text NULL, - telcode integer -); - -CREATE TABLE capplace ( - "COUNTRY" text, - "CITY" text NULL, - "TELCODE" integer -); - -CREATE TABLE nullperson ( - first_name text NULL, - last_name text NULL, - email text NULL -); - -CREATE TABLE employees ( - name text, - id integer, - boss_id integer -); - -`, - drop: ` -drop table person; -drop table place; -drop table capplace; -drop table nullperson; -drop table employees; -`, -} - -type Person struct { - FirstName string `db:"first_name"` - LastName string `db:"last_name"` - Email string - AddedAt time.Time `db:"added_at"` -} - -type Person2 struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString -} - -type Place struct { - Country string - City sql.NullString - TelCode int -} - -type PlacePtr struct { - Country string - City *string - TelCode int -} - -type PersonPlace struct { - Person - Place -} - -type PersonPlacePtr struct { - *Person - *Place -} - -type EmbedConflict struct { - FirstName string `db:"first_name"` - Person -} - -type SliceMember struct { - Country string - City sql.NullString - TelCode int - People []Person `db:"-"` - Addresses []Place `db:"-"` -} - -// Note that because of field map caching, we need a new type here -// if we've used Place already somewhere in sqlx -type CPlace Place - 
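// Illustrative sketch only, not applied by this patch: the sqlx helpers and
// fixtures deleted in the hunks above exercise the raw Select/Get/NamedExec
// style that this change swaps out for gorm (vendored elsewhere in this
// patch). Below is a minimal, hedged example of how the same fixture-and-
// lookup flow might look through gorm v1. The Person model, the
// POSTGRES_DSN environment variable, and the main wrapper are hypothetical
// names chosen for comparison, not code from this repository.
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/postgres"
)

// Person mirrors the shape used by the deleted sqlx fixtures; gorm's default
// snake_case mapping yields the same first_name / last_name / email columns.
type Person struct {
	FirstName string
	LastName  string
	Email     string
}

func main() {
	// Assumes a Postgres DSN in an environment variable, much as the
	// deleted tests read SQLX_POSTGRES_DSN.
	db, err := gorm.Open("postgres", os.Getenv("POSTGRES_DSN"))
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Create the table and insert a row, replacing the hand-written
	// CREATE TABLE / INSERT statements in the deleted fixtures.
	db.AutoMigrate(&Person{})
	db.Create(&Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"})

	// Query back by column, replacing db.Get(&p, "SELECT ... WHERE first_name=?").
	var p Person
	if err := db.Where("first_name = ?", "Jason").First(&p).Error; err != nil {
		log.Fatal(err)
	}
	fmt.Println(p.Email)
}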
-func MultiExec(e Execer, query string) { - stmts := strings.Split(query, ";\n") - if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { - stmts = stmts[:len(stmts)-1] - } - for _, s := range stmts { - _, err := e.Exec(s) - if err != nil { - fmt.Println(err, s) - } - } -} - -func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T)) { - runner := func(db *DB, t *testing.T, create, drop string) { - defer func() { - MultiExec(db, drop) - }() - - MultiExec(db, create) - test(db, t) - } - - if TestPostgres { - create, drop := schema.Postgres() - runner(pgdb, t, create, drop) - } - if TestSqlite { - create, drop := schema.Sqlite3() - runner(sldb, t, create, drop) - } - if TestMysql { - create, drop := schema.MySQL() - runner(mysqldb, t, create, drop) - } -} - -func loadDefaultFixture(db *DB, t *testing.T) { - tx := db.MustBegin() - tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") - tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") - tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") - if db.DriverName() == "mysql" { - tx.MustExec(tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") - } else { - tx.MustExec(tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") - } - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") - tx.Commit() -} - -// Test a new backwards compatible feature, that missing scan destinations -// will silently scan into sql.RawText rather than failing/panicing -func TestMissingNames(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - type PersonPlus struct { - FirstName string `db:"first_name"` - LastName string `db:"last_name"` - Email string - //AddedAt time.Time `db:"added_at"` - } - - // test Select first - pps := []PersonPlus{} - // pps lacks added_at destination - err := db.Select(&pps, "SELECT * FROM person") - if err == nil { - t.Error("Expected missing name from Select to fail, but it did not.") - } - - // test Get - pp := PersonPlus{} - err = db.Get(&pp, "SELECT * FROM person LIMIT 1") - if err == nil { - t.Error("Expected missing name Get to fail, but it did not.") - } - - // test naked StructScan - pps = []PersonPlus{} - rows, err := db.Query("SELECT * FROM person LIMIT 1") - if err != nil { - t.Fatal(err) - } - rows.Next() - err = StructScan(rows, &pps) - if err == nil { - t.Error("Expected missing name in StructScan to fail, but it did not.") - } - rows.Close() - - // now try various things with unsafe set. 
- db = db.Unsafe() - pps = []PersonPlus{} - err = db.Select(&pps, "SELECT * FROM person") - if err != nil { - t.Error(err) - } - - // test Get - pp = PersonPlus{} - err = db.Get(&pp, "SELECT * FROM person LIMIT 1") - if err != nil { - t.Error(err) - } - - // test naked StructScan - pps = []PersonPlus{} - rowsx, err := db.Queryx("SELECT * FROM person LIMIT 1") - if err != nil { - t.Fatal(err) - } - rowsx.Next() - err = StructScan(rowsx, &pps) - if err != nil { - t.Error(err) - } - rowsx.Close() - - // test Named stmt - if !isUnsafe(db) { - t.Error("Expected db to be unsafe, but it isn't") - } - nstmt, err := db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) - if err != nil { - t.Fatal(err) - } - // its internal stmt should be marked unsafe - if !nstmt.Stmt.unsafe { - t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") - } - pps = []PersonPlus{} - err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) - if err != nil { - t.Fatal(err) - } - if len(pps) != 1 { - t.Errorf("Expected 1 person back, got %d", len(pps)) - } - - // test it with a safe db - db.unsafe = false - if isUnsafe(db) { - t.Error("expected db to be safe but it isn't") - } - nstmt, err = db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) - if err != nil { - t.Fatal(err) - } - // it should be safe - if isUnsafe(nstmt) { - t.Error("NamedStmt did not inherit safety") - } - nstmt.Unsafe() - if !isUnsafe(nstmt) { - t.Error("expected newly unsafed NamedStmt to be unsafe") - } - pps = []PersonPlus{} - err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) - if err != nil { - t.Fatal(err) - } - if len(pps) != 1 { - t.Errorf("Expected 1 person back, got %d", len(pps)) - } - - }) -} - -func TestEmbeddedStructs(t *testing.T) { - type Loop1 struct{ Person } - type Loop2 struct{ Loop1 } - type Loop3 struct{ Loop2 } - - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - peopleAndPlaces := []PersonPlace{} - err := db.Select( - &peopleAndPlaces, - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Fatal(err) - } - for _, pp := range peopleAndPlaces { - if len(pp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(pp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - } - - // test embedded structs with StructScan - rows, err := db.Queryx( - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Error(err) - } - - perp := PersonPlace{} - rows.Next() - err = rows.StructScan(&perp) - if err != nil { - t.Error(err) - } - - if len(perp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(perp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - - rows.Close() - - // test the same for embedded pointer structs - peopleAndPlacesPtrs := []PersonPlacePtr{} - err = db.Select( - &peopleAndPlacesPtrs, - `SELECT person.*, place.* FROM - person natural join place`) - if err != nil { - t.Fatal(err) - } - for _, pp := range peopleAndPlacesPtrs { - if len(pp.Person.FirstName) == 0 { - t.Errorf("Expected non zero lengthed first name.") - } - if len(pp.Place.Country) == 0 { - t.Errorf("Expected non zero lengthed country.") - } - } - - // test "deep nesting" - l3s := []Loop3{} - err = db.Select(&l3s, `select * from person`) - if err != nil { - t.Fatal(err) - } - for _, l3 := range l3s { - if len(l3.Loop2.Loop1.Person.FirstName) == 0 { - 
t.Errorf("Expected non zero lengthed first name.") - } - } - - // test "embed conflicts" - ec := []EmbedConflict{} - err = db.Select(&ec, `select * from person`) - // I'm torn between erroring here or having some kind of working behavior - // in order to allow for more flexibility in destination structs - if err != nil { - t.Errorf("Was not expecting an error on embed conflicts.") - } - }) -} - -func TestJoinQuery(t *testing.T) { - type Employee struct { - Name string - ID int64 - // BossID is an id into the employee table - BossID sql.NullInt64 `db:"boss_id"` - } - type Boss Employee - - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - - var employees []struct { - Employee - Boss `db:"boss"` - } - - err := db.Select( - &employees, - `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees - JOIN employees AS boss ON employees.boss_id = boss.id`) - if err != nil { - t.Fatal(err) - } - - for _, em := range employees { - if len(em.Employee.Name) == 0 { - t.Errorf("Expected non zero lengthed name.") - } - if em.Employee.BossID.Int64 != em.Boss.ID { - t.Errorf("Expected boss ids to match") - } - } - }) -} - -func TestJoinQueryNamedPointerStructs(t *testing.T) { - type Employee struct { - Name string - ID int64 - // BossID is an id into the employee table - BossID sql.NullInt64 `db:"boss_id"` - } - type Boss Employee - - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - - var employees []struct { - Emp1 *Employee `db:"emp1"` - Emp2 *Employee `db:"emp2"` - *Boss `db:"boss"` - } - - err := db.Select( - &employees, - `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", - emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", - boss.id "boss.id", boss.name "boss.name" FROM employees AS emp - JOIN employees AS boss ON emp.boss_id = boss.id - `) - if err != nil { - t.Fatal(err) - } - - for _, em := range employees { - if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { - t.Errorf("Expected non zero lengthed name.") - } - if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { - t.Errorf("Expected boss ids to match") - } - } - }) -} - -func TestSelectSliceMapTime(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - rows, err := db.Queryx("SELECT * FROM person") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - _, err := rows.SliceScan() - if err != nil { - t.Error(err) - } - } - - rows, err = db.Queryx("SELECT * FROM person") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - m := map[string]interface{}{} - err := rows.MapScan(m) - if err != nil { - t.Error(err) - } - } - - }) -} - -func TestNilReceiver(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - var p *Person - err := db.Get(p, "SELECT * FROM person LIMIT 1") - if err == nil { - t.Error("Expected error when getting into nil struct ptr.") - } - var pp *[]Person - err = db.Select(pp, "SELECT * FROM person") - if err == nil { - t.Error("Expected an error when selecting into nil slice ptr.") - } - }) -} - -func TestNamedQuery(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE place ( - id integer PRIMARY KEY, - name text NULL - ); - CREATE TABLE person ( - first_name text NULL, - last_name text NULL, - email text NULL - ); - CREATE TABLE placeperson ( - first_name text NULL, - last_name text NULL, - email text NULL, - place_id integer NULL - ); 
- CREATE TABLE jsperson ( - "FIRST" text NULL, - last_name text NULL, - "EMAIL" text NULL - );`, - drop: ` - drop table person; - drop table jsperson; - drop table place; - drop table placeperson; - `, - } - - RunWithSchema(schema, t, func(db *DB, t *testing.T) { - type Person struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString - } - - p := Person{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "doe", Valid: true}, - Email: sql.NullString{String: "ben@doe.com", Valid: true}, - } - - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` - _, err := db.NamedExec(q1, p) - if err != nil { - log.Fatal(err) - } - - p2 := &Person{} - rows, err := db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", p) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(p2) - if err != nil { - t.Error(err) - } - if p2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + p2.FirstName.String) - } - if p2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + p2.LastName.String) - } - } - - // these are tests for #73; they verify that named queries work if you've - // changed the db mapper. This code checks both NamedQuery "ad-hoc" style - // queries and NamedStmt queries, which use different code paths internally. - old := *db.Mapper - - type JSONPerson struct { - FirstName sql.NullString `json:"FIRST"` - LastName sql.NullString `json:"last_name"` - Email sql.NullString - } - - jp := JSONPerson{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "smith", Valid: true}, - Email: sql.NullString{String: "ben@smith.com", Valid: true}, - } - - db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) - - // prepare queries for case sensitivity to test our ToUpper function. - // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line - // strings are `` we use "" by default and swap out for MySQL - pdb := func(s string, db *DB) string { - if db.DriverName() == "mysql" { - return strings.Replace(s, `"`, "`", -1) - } - return s - } - - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` - _, err = db.NamedExec(pdb(q1, db), jp) - if err != nil { - t.Fatal(err, db.DriverName()) - } - - // Checks that a person pulled out of the db matches the one we put in - check := func(t *testing.T, rows *Rows) { - jp = JSONPerson{} - for rows.Next() { - err = rows.StructScan(&jp) - if err != nil { - t.Error(err) - } - if jp.FirstName.String != "ben" { - t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) - } - if jp.LastName.String != "smith" { - t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) - } - if jp.Email.String != "ben@smith.com" { - t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) - } - } - } - - ns, err := db.PrepareNamed(pdb(` - SELECT * FROM jsperson - WHERE - "FIRST"=:FIRST AND - last_name=:last_name AND - "EMAIL"=:EMAIL - `, db)) - - if err != nil { - t.Fatal(err) - } - rows, err = ns.Queryx(jp) - if err != nil { - t.Fatal(err) - } - - check(t, rows) - - // Check exactly the same thing, but with db.NamedQuery, which does not go - // through the PrepareNamed/NamedStmt path. 
- rows, err = db.NamedQuery(pdb(` - SELECT * FROM jsperson - WHERE - "FIRST"=:FIRST AND - last_name=:last_name AND - "EMAIL"=:EMAIL - `, db), jp) - if err != nil { - t.Fatal(err) - } - - check(t, rows) - - db.Mapper = &old - - // Test nested structs - type Place struct { - ID int `db:"id"` - Name sql.NullString `db:"name"` - } - type PlacePerson struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString - Place Place `db:"place"` - } - - pl := Place{ - Name: sql.NullString{String: "myplace", Valid: true}, - } - - pp := PlacePerson{ - FirstName: sql.NullString{String: "ben", Valid: true}, - LastName: sql.NullString{String: "doe", Valid: true}, - Email: sql.NullString{String: "ben@doe.com", Valid: true}, - } - - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` - _, err = db.NamedExec(q2, pl) - if err != nil { - log.Fatal(err) - } - - id := 1 - pp.Place.ID = id - - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` - _, err = db.NamedExec(q3, pp) - if err != nil { - log.Fatal(err) - } - - pp2 := &PlacePerson{} - rows, err = db.NamedQuery(` - SELECT - first_name, - last_name, - email, - place.id AS "place.id", - place.name AS "place.name" - FROM placeperson - INNER JOIN place ON place.id = placeperson.place_id - WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp2) - if err != nil { - t.Error(err) - } - if pp2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) - } - if pp2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + pp2.LastName.String) - } - if pp2.Place.Name.String != "myplace" { - t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) - } - if pp2.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) - } - } - }) -} - -func TestNilInserts(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE tt ( - id integer, - value text NULL DEFAULT NULL - );`, - drop: "drop table tt;", - } - - RunWithSchema(schema, t, func(db *DB, t *testing.T) { - type TT struct { - ID int - Value *string - } - var v, v2 TT - r := db.Rebind - - db.MustExec(r(`INSERT INTO tt (id) VALUES (1)`)) - db.Get(&v, r(`SELECT * FROM tt`)) - if v.ID != 1 { - t.Errorf("Expecting id of 1, got %v", v.ID) - } - if v.Value != nil { - t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) - } - - v.ID = 2 - // NOTE: this incidentally uncovered a bug which was that named queries with - // pointer destinations would not work if the passed value here was not addressable, - // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for - // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly - // function. This next line is important as it provides the only coverage for this. 
- db.NamedExec(`INSERT INTO tt (id, value) VALUES (:id, :value)`, v) - - db.Get(&v2, r(`SELECT * FROM tt WHERE id=2`)) - if v.ID != v2.ID { - t.Errorf("%v != %v", v.ID, v2.ID) - } - if v2.Value != nil { - t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) - } - }) -} - -func TestScanError(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE kv ( - k text, - v integer - );`, - drop: `drop table kv;`, - } - - RunWithSchema(schema, t, func(db *DB, t *testing.T) { - type WrongTypes struct { - K int - V string - } - _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) - if err != nil { - t.Error(err) - } - - rows, err := db.Queryx("SELECT * FROM kv") - if err != nil { - t.Error(err) - } - for rows.Next() { - var wt WrongTypes - err := rows.StructScan(&wt) - if err == nil { - t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) - } - } - }) -} - -// FIXME: this function is kinda big but it slows things down to be constantly -// loading and reloading the schema.. - -func TestUsage(t *testing.T) { - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - slicemembers := []SliceMember{} - err := db.Select(&slicemembers, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Fatal(err) - } - - people := []Person{} - - err = db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") - if err != nil { - t.Fatal(err) - } - - jason, john := people[0], people[1] - if jason.FirstName != "Jason" { - t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) - } - if jason.LastName != "Moiron" { - t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) - } - if jason.Email != "jmoiron@jmoiron.net" { - t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) - } - if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { - t.Errorf("John Doe's person record not what expected: Got %v\n", john) - } - - jason = Person{} - err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") - - if err != nil { - t.Fatal(err) - } - if jason.FirstName != "Jason" { - t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) - } - - err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") - if err == nil { - t.Errorf("Expecting an error, got nil\n") - } - if err != sql.ErrNoRows { - t.Errorf("Expected sql.ErrNoRows, got %v\n", err) - } - - // The following tests check statement reuse, which was actually a problem - // due to copying being done when creating Stmt's which was eventually removed - stmt1, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - jason = Person{} - - row := stmt1.QueryRowx("DoesNotExist") - row.Scan(&jason) - row = stmt1.QueryRowx("DoesNotExist") - row.Scan(&jason) - - err = stmt1.Get(&jason, "DoesNotExist User") - if err == nil { - t.Error("Expected an error") - } - err = stmt1.Get(&jason, "DoesNotExist User 2") - if err == nil { - t.Fatal(err) - } - - stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - jason = Person{} - tx, err := db.Beginx() - if err != nil { - t.Fatal(err) - } - tstmt2 := tx.Stmtx(stmt2) - row2 := tstmt2.QueryRowx("Jason") - err = row2.StructScan(&jason) - if err != nil { - t.Error(err) - } - tx.Commit() - - places := []*Place{} - err = db.Select(&places, "SELECT telcode FROM place ORDER BY 
telcode ASC") - if err != nil { - t.Fatal(err) - } - - usa, singsing, honkers := places[0], places[1], places[2] - - if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { - t.Errorf("Expected integer telcodes to work, got %#v", places) - } - - placesptr := []PlacePtr{} - err = db.Select(&placesptr, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Error(err) - } - //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) - - // if you have null fields and use SELECT *, you must use sql.Null* in your struct - // this test also verifies that you can use either a []Struct{} or a []*Struct{} - places2 := []Place{} - err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") - if err != nil { - t.Fatal(err) - } - - usa, singsing, honkers = &places2[0], &places2[1], &places2[2] - - // this should return a type error that &p is not a pointer to a struct slice - p := Place{} - err = db.Select(&p, "SELECT * FROM place ORDER BY telcode ASC") - if err == nil { - t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") - } - - // this should be an error - pl := []Place{} - err = db.Select(pl, "SELECT * FROM place ORDER BY telcode ASC") - if err == nil { - t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") - } - - if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { - t.Errorf("Expected integer telcodes to work, got %#v", places) - } - - stmt, err := db.Preparex(db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) - if err != nil { - t.Error(err) - } - - places = []*Place{} - err = stmt.Select(&places, 10) - if len(places) != 2 { - t.Error("Expected 2 places, got 0.") - } - if err != nil { - t.Fatal(err) - } - singsing, honkers = places[0], places[1] - if singsing.TelCode != 65 || honkers.TelCode != 852 { - t.Errorf("Expected the right telcodes, got %#v", places) - } - - rows, err := db.Queryx("SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - place := Place{} - for rows.Next() { - err = rows.StructScan(&place) - if err != nil { - t.Fatal(err) - } - } - - rows, err = db.Queryx("SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - m := map[string]interface{}{} - for rows.Next() { - err = rows.MapScan(m) - if err != nil { - t.Fatal(err) - } - _, ok := m["country"] - if !ok { - t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) - } - } - - rows, err = db.Queryx("SELECT * FROM place") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - s, err := rows.SliceScan() - if err != nil { - t.Error(err) - } - if len(s) != 3 { - t.Errorf("Expected 3 columns in result, got %d\n", len(s)) - } - } - - // test advanced querying - // test that NamedExec works with a map as well as a struct - _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ - "first": "Bin", - "last": "Smuth", - "email": "bensmith@allblacks.nz", - }) - if err != nil { - t.Fatal(err) - } - - // ensure that if the named param happens right at the end it still works - // ensure that NamedQuery works with a map[string]interface{} - rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) - if err != nil { - t.Fatal(err) - } - - ben := &Person{} - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Bin" { - t.Fatal("Expected 
first name of `Bin`, got " + ben.FirstName) - } - if ben.LastName != "Smuth" { - t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) - } - } - - ben.FirstName = "Ben" - ben.LastName = "Smith" - ben.Email = "binsmuth@allblacks.nz" - - // Insert via a named query using the struct - _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) - - if err != nil { - t.Fatal(err) - } - - rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", ben) - if err != nil { - t.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Ben" { - t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) - } - if ben.LastName != "Smith" { - t.Fatal("Expected first name of `Smith`, got " + ben.LastName) - } - } - // ensure that Get does not panic on emppty result set - person := &Person{} - err = db.Get(person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") - if err == nil { - t.Fatal("Should have got an error for Get on non-existant row.") - } - - // lets test prepared statements some more - - stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Fatal(err) - } - rows, err = stmt.Queryx("Ben") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(ben) - if err != nil { - t.Fatal(err) - } - if ben.FirstName != "Ben" { - t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) - } - if ben.LastName != "Smith" { - t.Fatal("Expected first name of `Smith`, got " + ben.LastName) - } - } - - john = Person{} - stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) - if err != nil { - t.Error(err) - } - err = stmt.Get(&john, "John") - if err != nil { - t.Error(err) - } - - // test name mapping - // THIS USED TO WORK BUT WILL NO LONGER WORK. - db.MapperFunc(strings.ToUpper) - rsa := CPlace{} - err = db.Get(&rsa, "SELECT * FROM capplace;") - if err != nil { - t.Error(err, "in db:", db.DriverName()) - } - db.MapperFunc(strings.ToLower) - - // create a copy and change the mapper, then verify the copy behaves - // differently from the original. 
- dbCopy := NewDb(db.DB, db.DriverName()) - dbCopy.MapperFunc(strings.ToUpper) - err = dbCopy.Get(&rsa, "SELECT * FROM capplace;") - if err != nil { - fmt.Println(db.DriverName()) - t.Error(err) - } - - err = db.Get(&rsa, "SELECT * FROM cappplace;") - if err == nil { - t.Error("Expected no error, got ", err) - } - - // test base type slices - var sdest []string - rows, err = db.Queryx("SELECT email FROM person ORDER BY email ASC;") - if err != nil { - t.Error(err) - } - err = scanAll(rows, &sdest, false) - if err != nil { - t.Error(err) - } - - // test Get with base types - var count int - err = db.Get(&count, "SELECT count(*) FROM person;") - if err != nil { - t.Error(err) - } - if count != len(sdest) { - t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) - } - - // test Get and Select with time.Time, #84 - var addedAt time.Time - err = db.Get(&addedAt, "SELECT added_at FROM person LIMIT 1;") - if err != nil { - t.Error(err) - } - - var addedAts []time.Time - err = db.Select(&addedAts, "SELECT added_at FROM person;") - if err != nil { - t.Error(err) - } - - // test it on a double pointer - var pcount *int - err = db.Get(&pcount, "SELECT count(*) FROM person;") - if err != nil { - t.Error(err) - } - if *pcount != count { - t.Errorf("expected %d = %d", *pcount, count) - } - - // test Select... - sdest = []string{} - err = db.Select(&sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") - if err != nil { - t.Error(err) - } - expected := []string{"Ben", "Bin", "Jason", "John"} - for i, got := range sdest { - if got != expected[i] { - t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) - } - } - - var nsdest []sql.NullString - err = db.Select(&nsdest, "SELECT city FROM place ORDER BY city ASC") - if err != nil { - t.Error(err) - } - for _, val := range nsdest { - if val.Valid && val.String != "New York" { - t.Errorf("expected single valid result to be `New York`, but got %s", val.String) - } - } - }) -} - -type Product struct { - ProductID int -} - -// tests that sqlx will not panic when the wrong driver is passed because -// of an automatic nil dereference in sqlx.Open(), which was fixed. -func TestDoNotPanicOnConnect(t *testing.T) { - _, err := Connect("bogus", "hehe") - if err == nil { - t.Errorf("Should return error when using bogus driverName") - } -} - -func TestRebind(t *testing.T) { - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` - - s1 := Rebind(DOLLAR, q1) - s2 := Rebind(DOLLAR, q2) - - if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` { - t.Errorf("q1 failed") - } - - if s2 != `INSERT INTO foo (a, b, c) VALUES ($1, $2, "foo"), ("Hi", $3, $4)` { - t.Errorf("q2 failed") - } - - s1 = Rebind(NAMED, q1) - s2 = Rebind(NAMED, q2) - - ex1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ` + - `(:arg1, :arg2, :arg3, :arg4, :arg5, :arg6, :arg7, :arg8, :arg9, :arg10)` - if s1 != ex1 { - t.Error("q1 failed on Named params") - } - - ex2 := `INSERT INTO foo (a, b, c) VALUES (:arg1, :arg2, "foo"), ("Hi", :arg3, :arg4)` - if s2 != ex2 { - t.Error("q2 failed on Named params") - } -} - -func TestBindMap(t *testing.T) { - // Test that it works.. 
- q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` - am := map[string]interface{}{ - "name": "Jason Moiron", - "age": 30, - "first": "Jason", - "last": "Moiron", - } - - bq, args, _ := bindMap(QUESTION, q1, am) - expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` - if bq != expect { - t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) - } - - if args[0].(string) != "Jason Moiron" { - t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) - } - - if args[1].(int) != 30 { - t.Errorf("Expected 30, got %v\n", args[1]) - } - - if args[2].(string) != "Jason" { - t.Errorf("Expected Jason, got %v\n", args[2]) - } - - if args[3].(string) != "Moiron" { - t.Errorf("Expected Moiron, got %v\n", args[3]) - } -} - -// Test for #117, embedded nil maps - -type Message struct { - Text string `db:"string"` - Properties PropertyMap `db:"properties"` // Stored as JSON in the database -} - -type PropertyMap map[string]string - -// Implement driver.Valuer and sql.Scanner interfaces on PropertyMap -func (p PropertyMap) Value() (driver.Value, error) { - if len(p) == 0 { - return nil, nil - } - return json.Marshal(p) -} - -func (p PropertyMap) Scan(src interface{}) error { - v := reflect.ValueOf(src) - if !v.IsValid() || v.IsNil() { - return nil - } - if data, ok := src.([]byte); ok { - return json.Unmarshal(data, &p) - } - return fmt.Errorf("Could not not decode type %T -> %T", src, p) -} - -func TestEmbeddedMaps(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE message ( - string text, - properties text - );`, - drop: `drop table message;`, - } - - RunWithSchema(schema, t, func(db *DB, t *testing.T) { - messages := []Message{ - {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, - {"Thanks, Joy", PropertyMap{"pull": "request"}}, - } - q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` - for _, m := range messages { - _, err := db.NamedExec(q1, m) - if err != nil { - t.Fatal(err) - } - } - var count int - err := db.Get(&count, "SELECT count(*) FROM message") - if err != nil { - t.Fatal(err) - } - if count != len(messages) { - t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) - } - - var m Message - err = db.Get(&m, "SELECT * FROM message LIMIT 1;") - if err != nil { - t.Fatal(err) - } - if m.Properties == nil { - t.Fatal("Expected m.Properties to not be nil, but it was.") - } - }) -} - -func TestIssue197(t *testing.T) { - // this test actually tests for a bug in database/sql: - // https://github.com/golang/go/issues/13905 - // this potentially makes _any_ named type that is an alias for []byte - // unsafe to use in a lot of different ways (basically, unsafe to hold - // onto after loading from the database). 
- t.Skip() - - type mybyte []byte - type Var struct{ Raw json.RawMessage } - type Var2 struct{ Raw []byte } - type Var3 struct{ Raw mybyte } - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - var err error - var v, q Var - if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { - t.Fatal(err) - } - if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { - t.Fatal(err) - } - - var v2, q2 Var2 - if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { - t.Fatal(err) - } - if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { - t.Fatal(err) - } - - var v3, q3 Var3 - if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { - t.Fatal(err) - } - if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { - t.Fatal(err) - } - t.Fail() - }) -} - -func TestIn(t *testing.T) { - // some quite normal situations - type tr struct { - q string - args []interface{} - c int - } - tests := []tr{ - {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", - []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, - 7}, - {"SELECT * FROM foo WHERE x in (?)", - []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, - 8}, - } - for _, test := range tests { - q, a, err := In(test.q, test.args...) - if err != nil { - t.Error(err) - } - if len(a) != test.c { - t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) - } - if strings.Count(q, "?") != test.c { - t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) - } - } - - // too many bindVars, but no slices, so short circuits parsing - // i'm not sure if this is the right behavior; this query/arg combo - // might not work, but we shouldn't parse if we don't need to - { - orig := "SELECT * FROM foo WHERE x = ? AND y = ?" - q, a, err := In(orig, "foo", "bar", "baz") - if err != nil { - t.Error(err) - } - if len(a) != 3 { - t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) - } - if q != orig { - t.Error("Expected unchanged query.") - } - } - - tests = []tr{ - // too many bindvars; slice present so should return error during parse - {"SELECT * FROM foo WHERE x = ? and y = ?", - []interface{}{"foo", []int{1, 2, 3}, "bar"}, - 0}, - // empty slice, should return error before parse - {"SELECT * FROM foo WHERE x = ?", - []interface{}{[]int{}}, - 0}, - // too *few* bindvars, should return an error - {"SELECT * FROM foo WHERE x = ? AND y in (?)", - []interface{}{[]int{1, 2, 3}}, - 0}, - } - for _, test := range tests { - _, _, err := In(test.q, test.args...) - if err == nil { - t.Error("Expected an error, but got nil.") - } - } - RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { - loadDefaultFixture(db, t) - //tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") - telcodes := []int{852, 65} - q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" - query, args, err := In(q, telcodes) - if err != nil { - t.Error(err) - } - query = db.Rebind(query) - places := []Place{} - err = db.Select(&places, query, args...) 
- if err != nil { - t.Error(err) - } - if len(places) != 2 { - t.Fatalf("Expecting 2 results, got %d", len(places)) - } - if places[0].TelCode != 65 { - t.Errorf("Expecting singapore first, but got %#v", places[0]) - } - if places[1].TelCode != 852 { - t.Errorf("Expecting hong kong second, but got %#v", places[1]) - } - }) -} - -func TestBindStruct(t *testing.T) { - var err error - - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` - - type tt struct { - Name string - Age int - First string - Last string - } - - type tt2 struct { - Field1 string `db:"field_1"` - Field2 string `db:"field_2"` - } - - type tt3 struct { - tt2 - Name string - } - - am := tt{"Jason Moiron", 30, "Jason", "Moiron"} - - bq, args, _ := bindStruct(QUESTION, q1, am, mapper()) - expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` - if bq != expect { - t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) - } - - if args[0].(string) != "Jason Moiron" { - t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) - } - - if args[1].(int) != 30 { - t.Errorf("Expected 30, got %v\n", args[1]) - } - - if args[2].(string) != "Jason" { - t.Errorf("Expected Jason, got %v\n", args[2]) - } - - if args[3].(string) != "Moiron" { - t.Errorf("Expected Moiron, got %v\n", args[3]) - } - - am2 := tt2{"Hello", "World"} - bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2, mapper()) - expect = `INSERT INTO foo (a, b) VALUES (?, ?)` - if bq != expect { - t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) - } - - if args[0].(string) != "World" { - t.Errorf("Expected 'World', got %s\n", args[0].(string)) - } - if args[1].(string) != "Hello" { - t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) - } - - am3 := tt3{Name: "Hello!"} - am3.Field1 = "Hello" - am3.Field2 = "World" - - bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)", am3, mapper()) - - if err != nil { - t.Fatal(err) - } - - expect = `INSERT INTO foo (a, b, c) VALUES (?, ?, ?)` - if bq != expect { - t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) - } - - if args[0].(string) != "Hello!" 
{ - t.Errorf("Expected 'Hello!', got %s\n", args[0].(string)) - } - if args[1].(string) != "Hello" { - t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) - } - if args[2].(string) != "World" { - t.Errorf("Expected 'World', got %s\n", args[0].(string)) - } -} - -func TestEmbeddedLiterals(t *testing.T) { - var schema = Schema{ - create: ` - CREATE TABLE x ( - k text - );`, - drop: `drop table x;`, - } - - RunWithSchema(schema, t, func(db *DB, t *testing.T) { - type t1 struct { - K *string - } - type t2 struct { - Inline struct { - F string - } - K *string - } - - db.MustExec(db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") - - target := t1{} - err := db.Get(&target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") - if err != nil { - t.Error(err) - } - if *target.K != "one" { - t.Error("Expected target.K to be `one`, got ", target.K) - } - - target2 := t2{} - err = db.Get(&target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") - if err != nil { - t.Error(err) - } - if *target2.K != "one" { - t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) - } - }) -} - -func BenchmarkBindStruct(b *testing.B) { - b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` - type t struct { - Name string - Age int - First string - Last string - } - am := t{"Jason Moiron", 30, "Jason", "Moiron"} - b.StartTimer() - for i := 0; i < b.N; i++ { - bindStruct(DOLLAR, q1, am, mapper()) - } -} - -func BenchmarkBindMap(b *testing.B) { - b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` - am := map[string]interface{}{ - "name": "Jason Moiron", - "age": 30, - "first": "Jason", - "last": "Moiron", - } - b.StartTimer() - for i := 0; i < b.N; i++ { - bindMap(DOLLAR, q1, am) - } -} - -func BenchmarkIn(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` - - for i := 0; i < b.N; i++ { - _, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...) - } -} - -func BenchmarkIn1k(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` - - var vals [1000]interface{} - - for i := 0; i < b.N; i++ { - _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) - } -} - -func BenchmarkIn1kInt(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` - - var vals [1000]int - - for i := 0; i < b.N; i++ { - _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) - } -} - -func BenchmarkIn1kString(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` - - var vals [1000]string - - for i := 0; i < b.N; i++ { - _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) 
- } -} - -func BenchmarkRebind(b *testing.B) { - b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` - b.StartTimer() - - for i := 0; i < b.N; i++ { - Rebind(DOLLAR, q1) - Rebind(DOLLAR, q2) - } -} - -func BenchmarkRebindBuffer(b *testing.B) { - b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` - b.StartTimer() - - for i := 0; i < b.N; i++ { - rebindBuff(DOLLAR, q1) - rebindBuff(DOLLAR, q2) - } -} diff --git a/vendor/github.com/lib/pq/hstore/hstore.go b/vendor/github.com/lib/pq/hstore/hstore.go new file mode 100644 index 000000000..72d5abf51 --- /dev/null +++ b/vendor/github.com/lib/pq/hstore/hstore.go @@ -0,0 +1,118 @@ +package hstore + +import ( + "database/sql" + "database/sql/driver" + "strings" +) + +// A wrapper for transferring Hstore values back and forth easily. +type Hstore struct { + Map map[string]sql.NullString +} + +// escapes and quotes hstore keys/values +// s should be a sql.NullString or string +func hQuote(s interface{}) string { + var str string + switch v := s.(type) { + case sql.NullString: + if !v.Valid { + return "NULL" + } + str = v.String + case string: + str = v + default: + panic("not a string or sql.NullString") + } + + str = strings.Replace(str, "\\", "\\\\", -1) + return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` +} + +// Scan implements the Scanner interface. +// +// Note h.Map is reallocated before the scan to clear existing values. If the +// hstore column's database value is NULL, then h.Map is set to nil instead. +func (h *Hstore) Scan(value interface{}) error { + if value == nil { + h.Map = nil + return nil + } + h.Map = make(map[string]sql.NullString) + var b byte + pair := [][]byte{{}, {}} + pi := 0 + inQuote := false + didQuote := false + sawSlash := false + bindex := 0 + for bindex, b = range value.([]byte) { + if sawSlash { + pair[pi] = append(pair[pi], b) + sawSlash = false + continue + } + + switch b { + case '\\': + sawSlash = true + continue + case '"': + inQuote = !inQuote + if !didQuote { + didQuote = true + } + continue + default: + if !inQuote { + switch b { + case ' ', '\t', '\n', '\r': + continue + case '=': + continue + case '>': + pi = 1 + didQuote = false + continue + case ',': + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + pair[0] = []byte{} + pair[1] = []byte{} + pi = 0 + continue + } + } + } + pair[pi] = append(pair[pi], b) + } + if bindex > 0 { + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + } + return nil +} + +// Value implements the driver Valuer interface. Note if h.Map is nil, the +// database column value will be set to NULL. 
+func (h Hstore) Value() (driver.Value, error) { + if h.Map == nil { + return nil, nil + } + parts := []string{} + for key, val := range h.Map { + thispart := hQuote(key) + "=>" + hQuote(val) + parts = append(parts, thispart) + } + return []byte(strings.Join(parts, ",")), nil +} diff --git a/vendor/github.com/lib/pq/hstore/hstore_test.go b/vendor/github.com/lib/pq/hstore/hstore_test.go new file mode 100644 index 000000000..1c9f2bd49 --- /dev/null +++ b/vendor/github.com/lib/pq/hstore/hstore_test.go @@ -0,0 +1,148 @@ +package hstore + +import ( + "database/sql" + "os" + "testing" + + _ "github.com/lib/pq" +) + +type Fatalistic interface { + Fatal(args ...interface{}) +} + +func openTestConn(t Fatalistic) *sql.DB { + datname := os.Getenv("PGDATABASE") + sslmode := os.Getenv("PGSSLMODE") + + if datname == "" { + os.Setenv("PGDATABASE", "pqgotest") + } + + if sslmode == "" { + os.Setenv("PGSSLMODE", "disable") + } + + conn, err := sql.Open("postgres", "") + if err != nil { + t.Fatal(err) + } + + return conn +} + +func TestHstore(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // quitely create hstore if it doesn't exist + _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") + if err != nil { + t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) + } + + hs := Hstore{} + + // test for null-valued hstores + err = db.QueryRow("SELECT NULL::hstore").Scan(&hs) + if err != nil { + t.Fatal(err) + } + if hs.Map != nil { + t.Fatalf("expected null map") + } + + err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) + if err != nil { + t.Fatalf("re-query null map failed: %s", err.Error()) + } + if hs.Map != nil { + t.Fatalf("expected null map") + } + + // test for empty hstores + err = db.QueryRow("SELECT ''::hstore").Scan(&hs) + if err != nil { + t.Fatal(err) + } + if hs.Map == nil { + t.Fatalf("expected empty map, got null map") + } + if len(hs.Map) != 0 { + t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) + } + + err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) + if err != nil { + t.Fatalf("re-query empty map failed: %s", err.Error()) + } + if hs.Map == nil { + t.Fatalf("expected empty map, got null map") + } + if len(hs.Map) != 0 { + t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) + } + + // a few example maps to test out + hsOnePair := Hstore{ + Map: map[string]sql.NullString{ + "key1": {String: "value1", Valid: true}, + }, + } + + hsThreePairs := Hstore{ + Map: map[string]sql.NullString{ + "key1": {String: "value1", Valid: true}, + "key2": {String: "value2", Valid: true}, + "key3": {String: "value3", Valid: true}, + }, + } + + hsSmorgasbord := Hstore{ + Map: map[string]sql.NullString{ + "nullstring": {String: "NULL", Valid: true}, + "actuallynull": {String: "", Valid: false}, + "NULL": {String: "NULL string key", Valid: true}, + "withbracket": {String: "value>42", Valid: true}, + "withequal": {String: "value=42", Valid: true}, + `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, + `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, + "embedded1": {String: "value1=>x1", Valid: true}, + "embedded2": {String: `"value2"=>x2`, Valid: true}, + "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, + "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, + }, + } + + // test encoding in query params, then decoding during Scan + testBidirectional := func(h Hstore) { + err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs) + if err != nil { + t.Fatalf("re-query %d-pair map failed: %s", 
len(h.Map), err.Error()) + } + if hs.Map == nil { + t.Fatalf("expected %d-pair map, got null map", len(h.Map)) + } + if len(hs.Map) != len(h.Map) { + t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map)) + } + + for key, val := range hs.Map { + otherval, found := h.Map[key] + if !found { + t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map)) + } + if otherval.Valid != val.Valid { + t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map)) + } + if otherval.String != val.String { + t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map)) + } + } + } + + testBidirectional(hsOnePair) + testBidirectional(hsThreePairs) + testBidirectional(hsSmorgasbord) +}
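
Reviewer note (not part of the patch): the newly vendored hstore package documents a symmetric round trip — `Hstore` implements `driver.Valuer` for writes and `sql.Scanner` for reads over a `map[string]sql.NullString`. The sketch below is a minimal illustration of that usage, not code from this change; the `profiles` table, its columns, and reading connection settings from the usual PG* environment variables are assumptions chosen to mirror how the vendored tests connect.

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq"
	"github.com/lib/pq/hstore"
)

func main() {
	// Connection parameters come from PGDATABASE, PGSSLMODE, etc.,
	// as in openTestConn above (assumption for this sketch).
	db, err := sql.Open("postgres", "")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Hypothetical table with an hstore column; the extension must already
	// be installed in the target database.
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS profiles (id serial, attrs hstore)`); err != nil {
		log.Fatal(err)
	}

	// Hstore implements driver.Valuer, so it can be passed directly as a
	// query parameter; invalid NullStrings are stored as NULL hstore values.
	attrs := hstore.Hstore{Map: map[string]sql.NullString{
		"color":  {String: "blue", Valid: true},
		"weight": {Valid: false},
	}}
	if _, err := db.Exec(`INSERT INTO profiles (attrs) VALUES ($1)`, attrs); err != nil {
		log.Fatal(err)
	}

	// Hstore also implements sql.Scanner, so the same type reads the value back.
	var got hstore.Hstore
	if err := db.QueryRow(`SELECT attrs FROM profiles LIMIT 1`).Scan(&got); err != nil {
		log.Fatal(err)
	}
	fmt.Println(got.Map["color"].String) // prints "blue"
}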