Skip to content

Commit

Permalink
Merge branch 'master' into distinguish_unique
Browse files Browse the repository at this point in the history
# Conflicts:
#	migrator.go
  • Loading branch information
black-06 committed Oct 13, 2023
2 parents 3688999 + 2a61ba0 commit 3674fd1
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 11 deletions.
7 changes: 1 addition & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ name: tests

on:
push:
branches:
- '**'
- '*'
pull_request:
branches:
- master

permissions:
contents: read
Expand All @@ -27,7 +22,7 @@ jobs:
go-version: ${{ matrix.go }}

- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4

# Run build of the application
- name: Run build
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ db, err := gorm.Open(mysql.New(mysql.Config{
import (
_ "example.com/my_mysql_driver"
"gorm.io/gorm"
"gorm.io/driver/mysql"
)

db, err := gorm.Open(mysql.New(mysql.Config{
Expand Down
11 changes: 7 additions & 4 deletions error_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ import (
"gorm.io/gorm"
)

var errCodes = map[string]uint16{
"uniqueConstraint": 1062,
// The error codes to map mysql errors to gorm errors, here is the mysql error codes reference https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html.
var errCodes = map[uint16]error{
1062: gorm.ErrDuplicatedKey,
1452: gorm.ErrForeignKeyViolated,
}

func (dialector Dialector) Translate(err error) error {
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
if mysqlErr.Number == errCodes["uniqueConstraint"] {
return gorm.ErrDuplicatedKey
if translatedErr, found := errCodes[mysqlErr.Number]; found {
return translatedErr
}
return mysqlErr
}

return err
Expand Down
58 changes: 58 additions & 0 deletions error_translator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package mysql

import (
"errors"
"testing"

"gorm.io/gorm"

"github.com/go-sql-driver/mysql"
)

func TestDialector_Translate(t *testing.T) {
normalErr := errors.New("normal error")

type fields struct {
Config *Config
}
type args struct {
err error
}
tests := []struct {
name string
fields fields
args args
want error
}{
{
name: "it should translate error to ErrDuplicatedKey when the error number is 1062",
args: args{err: &mysql.MySQLError{Number: uint16(1062)}},
want: gorm.ErrDuplicatedKey,
},
{
name: "it should translate error to ErrForeignKeyViolated when the error number is 1452",
args: args{err: &mysql.MySQLError{Number: uint16(1452)}},
want: gorm.ErrForeignKeyViolated,
},
{
name: "it should not translate the error when the error number is not registered in translated error codes",
args: args{err: &mysql.MySQLError{Number: uint16(8888)}},
want: &mysql.MySQLError{Number: uint16(8888)},
},
{
name: "it should not translate the error when the error is not a mysql error",
args: args{err: normalErr},
want: normalErr,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dialector := Dialector{
Config: tt.fields.Config,
}
if err := dialector.Translate(tt.args.err); !errors.Is(err, tt.want) {
t.Errorf("Translate() got error = %v, want error %v", err, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ go 1.14

require (
github.com/go-sql-driver/mysql v1.7.0
gorm.io/gorm v1.25.1
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 h1:sC1Xj4TYrLqg1n3AN10w871An7wJM0gzgcm8jkIkECQ=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
25 changes: 25 additions & 0 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,31 @@ func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, co
})
}

func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name)
}

if !f.IgnoreMigration {
fieldType := m.FullDataTypeOf(f)
columnName := clause.Column{Name: f.DBName}
values := []interface{}{m.CurrentTable(stmt), columnName, fieldType}
var alterSql strings.Builder
alterSql.WriteString("ALTER TABLE ? ADD ? ?")
if f.PrimaryKey || strings.Contains(strings.ToLower(fieldType.SQL), "auto_increment") {
alterSql.WriteString(", ADD PRIMARY KEY (?)")
values = append(values, columnName)
}
return m.DB.Exec(alterSql.String(), values...).Error
}

return nil
})
}

func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
Expand Down

0 comments on commit 3674fd1

Please sign in to comment.