diff --git a/kit.go b/kit.go index 2ee6eba..bc1cf95 100644 --- a/kit.go +++ b/kit.go @@ -163,10 +163,13 @@ func (mk *mockit) MysqlExecExpect(ep ExpectParam, tb testing.TB) { mk.sqlEx.ExpectCommit() } +var errorInterface = reflect.TypeOf((*error)(nil)).Elem() + func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) { var ( - rows *sqlmock.Rows - err error + rows *sqlmock.Rows + errExpected bool + err error ) p := ep.(*expectParam) if len(p.method) == 0 { @@ -177,9 +180,14 @@ func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) { } else { if len(p.returns) != 1 { tb.Fatal("the query must return one result") + } + typ := reflect.TypeOf(p.returns[0]) + if typ.Implements(errorInterface) { + errExpected = true + } else { + rows, err = sqlmock.NewRowsFromInterface(p.returns[0], mk.ormTag) } - rows, err = sqlmock.NewRowsFromInterface(p.returns[0], mk.ormTag) } if err != nil { tb.Fatalf("new rows failed:%s", err) @@ -188,7 +196,11 @@ func (mk *mockit) MysqlQueryExpect(ep ExpectParam, tb testing.TB) { for _, v := range p.args { args = append(args, v) } - mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnRows(rows) + if errExpected { + mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnError(p.returns[0].(error)) + } else { + mk.sqlEx.ExpectQuery(p.method).WithArgs(args...).WillReturnRows(rows) + } } func (mk *mockit) InterfaceExpect(ep ExpectParam, tb testing.TB) { diff --git a/kit_test.go b/kit_test.go index ba04c75..2478ed1 100644 --- a/kit_test.go +++ b/kit_test.go @@ -201,6 +201,20 @@ func TestMockMySql(t *testing.T) { assert.Nil(t, err) assert.EqualValues(t, expectedReturns, res) }) + t.Run("error return", func(t *testing.T) { + relationId := int64(200) + pageIndex := 2 + pageSize := 20 + errExpected := gorm.ErrRecordNotFound + p := NewExpectParam(). + WithMethod("SELECT (.+) FROM `demo` WHERE relation_id = (.+)"). + WithArgs(relationId, false). + WithReturns(errExpected) + kit.MysqlQueryExpect(p, t) + res, err := srv.mysql.List(ctx, relationId, pageIndex, pageSize) + assert.Equal(t, errExpected, err) + assert.Nil(t, res) + }) } func TestMockRedis(t *testing.T) {