diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 55c8605..84051bf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,12 +2,7 @@ name: tests on: push: - branches: - - '**' - - '*' pull_request: - branches: - - master permissions: contents: read @@ -27,7 +22,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Run build of the application - name: Run build diff --git a/README.md b/README.md index b8f7a6c..9d75d5c 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ db, err := gorm.Open(mysql.New(mysql.Config{ import ( _ "example.com/my_mysql_driver" "gorm.io/gorm" + "gorm.io/driver/mysql" ) db, err := gorm.Open(mysql.New(mysql.Config{ diff --git a/error_translator.go b/error_translator.go index 79f6646..bf91276 100644 --- a/error_translator.go +++ b/error_translator.go @@ -6,15 +6,18 @@ import ( "gorm.io/gorm" ) -var errCodes = map[string]uint16{ - "uniqueConstraint": 1062, +// The error codes to map mysql errors to gorm errors, here is the mysql error codes reference https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html. +var errCodes = map[uint16]error{ + 1062: gorm.ErrDuplicatedKey, + 1452: gorm.ErrForeignKeyViolated, } func (dialector Dialector) Translate(err error) error { if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[mysqlErr.Number]; found { + return translatedErr } + return mysqlErr } return err diff --git a/error_translator_test.go b/error_translator_test.go new file mode 100644 index 0000000..a23e6db --- /dev/null +++ b/error_translator_test.go @@ -0,0 +1,58 @@ +package mysql + +import ( + "errors" + "testing" + + "gorm.io/gorm" + + "github.com/go-sql-driver/mysql" +) + +func TestDialector_Translate(t *testing.T) { + normalErr := errors.New("normal error") + + type fields struct { + Config *Config + } + type args struct { + err error + } + tests := []struct { + name string + fields fields + args args + want error + }{ + { + name: "it should translate error to ErrDuplicatedKey when the error number is 1062", + args: args{err: &mysql.MySQLError{Number: uint16(1062)}}, + want: gorm.ErrDuplicatedKey, + }, + { + name: "it should translate error to ErrForeignKeyViolated when the error number is 1452", + args: args{err: &mysql.MySQLError{Number: uint16(1452)}}, + want: gorm.ErrForeignKeyViolated, + }, + { + name: "it should not translate the error when the error number is not registered in translated error codes", + args: args{err: &mysql.MySQLError{Number: uint16(8888)}}, + want: &mysql.MySQLError{Number: uint16(8888)}, + }, + { + name: "it should not translate the error when the error is not a mysql error", + args: args{err: normalErr}, + want: normalErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dialector := Dialector{ + Config: tt.fields.Config, + } + if err := dialector.Translate(tt.args.err); !errors.Is(err, tt.want) { + t.Errorf("Translate() got error = %v, want error %v", err, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index 1d0b80d..fc63353 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/go-sql-driver/mysql v1.7.0 - gorm.io/gorm v1.25.1 + gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 ) diff --git a/go.sum b/go.sum index e3169ce..5db00fb 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64= gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 h1:sC1Xj4TYrLqg1n3AN10w871An7wJM0gzgcm8jkIkECQ= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= \ No newline at end of file diff --git a/migrator.go b/migrator.go index bce776d..d74903e 100644 --- a/migrator.go +++ b/migrator.go @@ -119,6 +119,31 @@ func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, co }) } +func (m Migrator) AddColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // avoid using the same name field + f := stmt.Schema.LookUpField(name) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", name) + } + + if !f.IgnoreMigration { + fieldType := m.FullDataTypeOf(f) + columnName := clause.Column{Name: f.DBName} + values := []interface{}{m.CurrentTable(stmt), columnName, fieldType} + var alterSql strings.Builder + alterSql.WriteString("ALTER TABLE ? ADD ? ?") + if f.PrimaryKey || strings.Contains(strings.ToLower(fieldType.SQL), "auto_increment") { + alterSql.WriteString(", ADD PRIMARY KEY (?)") + values = append(values, columnName) + } + return m.DB.Exec(alterSql.String(), values...).Error + } + + return nil + }) +} + func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil {