Skip to content

Commit

Permalink
Merge pull request #5 from Rennbon/main
Browse files Browse the repository at this point in the history
feature: sql支持error return
  • Loading branch information
Rennbon authored Jan 7, 2022
2 parents cf3236b + a519590 commit 862ac54
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
20 changes: 16 additions & 4 deletions kit.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,13 @@ func (mk *mockit) MysqlExecExpect(ep ExpectParam, tb testing.TB) {
mk.sqlEx.ExpectCommit()
}

var errorInterface = reflect.TypeOf((*error)(nil)).Elem()

func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) {
var (
rows *sqlmock.Rows
err error
rows *sqlmock.Rows
errExpected bool
err error
)
p := ep.(*expectParam)
if len(p.method) == 0 {
Expand All @@ -177,9 +180,14 @@ func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) {
} else {
if len(p.returns) != 1 {
tb.Fatal("the query must return one result")
}

typ := reflect.TypeOf(p.returns[0])
if typ.Implements(errorInterface) {
errExpected = true
} else {
rows, err = sqlmock.NewRowsFromInterface(p.returns[0], mk.ormTag)
}
rows, err = sqlmock.NewRowsFromInterface(p.returns[0], mk.ormTag)
}
if err != nil {
tb.Fatalf("new rows failed:%s", err)
Expand All @@ -188,7 +196,11 @@ func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) {
for _, v := range p.args {
args = append(args, v)
}
mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnRows(rows)
if errExpected {
mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnError(p.returns[0].(error))
} else {
mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnRows(rows)
}
}

func (mk *mockit) InterfaceExpect(ep ExpectParam, tb testing.TB) {
Expand Down
14 changes: 14 additions & 0 deletions kit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,20 @@ func TestMockMySql(t *testing.T) {
assert.Nil(t, err)
assert.EqualValues(t, expectedReturns, res)
})
t.Run("error return", func(t *testing.T) {
relationId := int64(200)
pageIndex := 2
pageSize := 20
errExpected := gorm.ErrRecordNotFound
p := NewExpectParam().
WithMethod("SELECT (.+) FROM `demo` WHERE relation_id = (.+)").
WithArgs(relationId, false).
WithReturns(errExpected)
kit.MysqlQueryExpect(p, t)
res, err := srv.mysql.List(ctx, relationId, pageIndex, pageSize)
assert.Equal(t, errExpected, err)
assert.Nil(t, res)
})
}

func TestMockRedis(t *testing.T) {
Expand Down

0 comments on commit 862ac54

Please sign in to comment.