diff --git a/CHANGES b/CHANGES index 4087f68..283ef27 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,10 @@ +v6.0.0 (TBD): +- updated common helpers to be generic +- updated datastructures to be generic +- cleanup package structre and remove deprecated ones +- updated logger with formatting functionality +- modernized test harness & mocks + 5.4.0 (Jan 10, 2024) - Added `Scan` operation to Redis diff --git a/asynctask/asynctasks_test.go b/asynctask/asynctasks_test.go index b1012fb..77e6b99 100644 --- a/asynctask/asynctasks_test.go +++ b/asynctask/asynctasks_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) func TestAsyncTaskNormalOperation(t *testing.T) { @@ -29,27 +30,14 @@ func TestAsyncTaskNormalOperation(t *testing.T) { task1.Start() time.Sleep(1 * time.Second) - if !task1.IsRunning() { - t.Error("Task should be running") - } - time.Sleep(1 * time.Second) + assert.True(t, task1.IsRunning()) + time.Sleep(1 * time.Second) task1.Stop(true) - if task1.IsRunning() { - t.Error("Task should be stopped") - } - - if !onInit.Load().(bool) { - t.Error("Initialization hook not executed") - } - - if !onExecution.Load().(bool) { - t.Error("Main task function not executed") - } - - if !onStop.Load().(bool) { - t.Error("After execution function not executed") - } + assert.False(t, task1.IsRunning()) + assert.True(t, onInit.Load().(bool)) + assert.True(t, onExecution.Load().(bool)) + assert.True(t, onStop.Load().(bool)) } func TestAsyncTaskPanics(t *testing.T) { @@ -94,15 +82,10 @@ func TestAsyncTaskPanics(t *testing.T) { task3.Start() time.Sleep(time.Second * 2) task3.Stop(true) - if task1.IsRunning() { - t.Error("Task1 is running and should be stopped") - } - if task2.IsRunning() { - t.Error("Task2 is running and should be stopped") - } - if task3.IsRunning() { - t.Error("Task3 is running and should be stopped") - } + + assert.False(t, task1.IsRunning()) + assert.False(t, task2.IsRunning()) + assert.False(t, task3.IsRunning()) } func TestAsyncTaskErrors(t *testing.T) { @@ -138,9 +121,8 @@ func TestAsyncTaskErrors(t *testing.T) { task2.Start() time.Sleep(2 * time.Second) - if res.Load().(int) != 0 { - t.Error("Task should have never executed if there was an error when calling onInit()") - } + + assert.Equal(t, int(0), res.Load().(int)) } func TestAsyncTaskWakeUp(t *testing.T) { @@ -163,7 +145,5 @@ func TestAsyncTaskWakeUp(t *testing.T) { _ = task1.WakeUp() _ = task1.Stop(true) - if atomic.LoadInt32(&res) != 3 { - t.Errorf("Task shuld have executed 4 times. It ran %d times", res) - } + assert.Equal(t, int32(3), atomic.LoadInt32(&res)) } diff --git a/backoff/backoff_test.go b/backoff/backoff_test.go index 8836a3e..58c719f 100644 --- a/backoff/backoff_test.go +++ b/backoff/backoff_test.go @@ -3,32 +3,20 @@ package backoff import ( "testing" "time" + "github.com/stretchr/testify/assert" ) func TestBackoff(t *testing.T) { base := int64(10) maxAllowed := 60 * time.Second backoff := New(base, maxAllowed) - if backoff.base != base { - t.Error("It should be equals to 10") - } - if backoff.maxAllowed != maxAllowed { - t.Error("It should be equals to 60") - } - if backoff.Next() != 1*time.Second { - t.Error("It should be 1 second") - } - if backoff.Next() != 10*time.Second { - t.Error("It should be 10 seconds") - } - if backoff.Next() != 60*time.Second { - t.Error("It should be 60 seconds") - } + assert.Equal(t, base, backoff.base) + assert.Equal(t, maxAllowed, backoff.maxAllowed) + assert.Equal(t, 1*time.Second, backoff.Next()) + assert.Equal(t, 10*time.Second, backoff.Next()) + assert.Equal(t, 60*time.Second, backoff.Next()) + backoff.Reset() - if backoff.current != 0 { - t.Error("It should be zero") - } - if backoff.Next() != 1*time.Second { - t.Error("It should be 1 second") - } + assert.Equal(t, int64(0), backoff.current) + assert.Equal(t, 1*time.Second, backoff.Next()) } diff --git a/common/common.go b/common/common.go index b3bbff8..de86446 100644 --- a/common/common.go +++ b/common/common.go @@ -13,11 +13,11 @@ func Ref[T any](x T) *T { // RefOrNil returns a pointer to the value supplied if it's not the default value, nil otherwise func RefOrNil[T comparable](x T) *T { - var t T - if x == t { - return nil - } - return &x + var t T + if x == t { + return nil + } + return &x } // PointerOf performs a type-assertion to T and returns a pointer if successful, nil otherwise. @@ -93,3 +93,11 @@ func Min[T cmp.Ordered](i1 T, rest ...T) T { } return min } + +func AsInterfaceSlice[T any](in []T) []interface{} { + out := make([]interface{}, len(in)) + for idx := range in { + out[idx] = in[idx] + } + return out +} diff --git a/datastructures/boolslice/boolslice_test.go b/datastructures/boolslice/boolslice_test.go index 7432888..20e3093 100644 --- a/datastructures/boolslice/boolslice_test.go +++ b/datastructures/boolslice/boolslice_test.go @@ -3,18 +3,16 @@ package boolslice import ( "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestBoolSlice(t *testing.T) { _, err := NewBoolSlice(12) - if err == nil { - t.Error("It should return err") - } + assert.NotNil(t, err) b, err := NewBoolSlice(int(math.Pow(2, 15))) - if err != nil { - t.Error("It should not return err", err) - } + assert.Nil(t, err) i1 := 12 i2 := 20 @@ -22,96 +20,57 @@ func TestBoolSlice(t *testing.T) { i4 := 2000 i5 := 8192 - if err := b.Set(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if err := b.Set(i1); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i2); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i3); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i4); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i5); err != nil { - t.Error("It should not return err") - } - - if _, err := b.Get(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if v, _ := b.Get(i1); !v { - t.Error("It should match", i1) - } - if v, _ := b.Get(i2); !v { - t.Error("It should match", i2) - } - if v, _ := b.Get(i3); !v { - t.Error("It should match", i3) - } - if v, _ := b.Get(i4); !v { - t.Error("It should match", i4) - } - if v, _ := b.Get(i5); !v { - t.Error("It should match", i5) - } - if v, _ := b.Get(200); v { - t.Error("It should not match 200") - } - if v, _ := b.Get(5000); v { - t.Error("It should not match 5000") - } - - if len(b.Bytes()) != int(math.Pow(2, 15)/8) { - t.Error("Len should be 4096") - } - - if err := b.Clear(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if err := b.Clear(i1); err != nil { - t.Error("It should not return err") - } - - if v, _ := b.Get(i1); v { - t.Error("It should not match after cleared", i1) - } - - if _, err := Rebuild(1, nil); err.Error() != "size must be a multiple of 8" { - t.Error("It should return err") - } - - if _, err := Rebuild(8, nil); err.Error() != "data cannot be empty" { - t.Error("It should return err") - } + assert.Equal(t, ErrorOutOfBounds, b.Set(int(math.Pow(2, 15)) + 1)) + assert.Nil(t, b.Set(i1)) + assert.Nil(t, b.Set(i2)) + assert.Nil(t, b.Set(i3)) + assert.Nil(t, b.Set(i4)) + assert.Nil(t, b.Set(i5)) + + set, err := b.Get(int(math.Pow(2, 15)) + 1) + assert.False(t, set) + assert.Equal(t, ErrorOutOfBounds, err) + + for _, i := range []int{i1, i2, i3, i4, i5} { + res, err := b.Get(i) + assert.Nil(t, err) + assert.True(t, res) + } + + for _, i := range []int{200, 500} { + res, err := b.Get(i) + assert.Nil(t, err) + assert.False(t, res) + } + + assert.Equal(t, math.Pow(2, 15)/8, float64(len(b.Bytes()))) + assert.Equal(t, ErrorOutOfBounds, b.Clear(int(math.Pow(2, 15)) + 1)) + assert.Nil(t, b.Clear(i1)) + + v, err := b.Get(i1) + assert.Nil(t, err) + assert.False(t, v) + + res, err := Rebuild(1, nil) + assert.Nil(t, res) + assert.NotNil(t, err) + + res, err = Rebuild(8, nil) + assert.Nil(t, res) + assert.NotNil(t, err) rebuilt, err := Rebuild(int(math.Pow(2, 15)), b.Bytes()) - if err != nil { - t.Error("It should not return err") - } - if v, _ := rebuilt.Get(i2); !v { - t.Error("It should match", i2) - } - if v, _ := rebuilt.Get(i3); !v { - t.Error("It should match", i3) - } - if v, _ := rebuilt.Get(i4); !v { - t.Error("It should match", i4) - } - if v, _ := rebuilt.Get(i5); !v { - t.Error("It should match", i5) - } - if v, _ := rebuilt.Get(i1); v { - t.Error("It should not match 12") - } - if v, _ := rebuilt.Get(200); v { - t.Error("It should not match 200") - } - if v, _ := rebuilt.Get(5000); v { - t.Error("It should not match 5000") - } + assert.Nil(t, err) + + for _, i := range []int{i2, i3, i4, i5} { + res, err := rebuilt.Get(i) + assert.Nil(t, err) + assert.True(t, res) + } + + for _, i := range []int{200, 5000} { + res, err := rebuilt.Get(i) + assert.Nil(t, err) + assert.False(t, res) + } } diff --git a/datastructures/cache/cache_test.go b/datastructures/cache/cache_test.go index 9ff36cc..1bf9eda 100644 --- a/datastructures/cache/cache_test.go +++ b/datastructures/cache/cache_test.go @@ -6,110 +6,62 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSimpleCache(t *testing.T) { cache, err := NewSimpleLRU[string, int](5, 1*time.Second) - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } + assert.Nil(t, err) for i := 1; i <= 5; i++ { err := cache.Set(fmt.Sprintf("someKey%d", i), i) - if err != nil { - t.Errorf("Setting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } + assert.Nil(t, err) } for i := 1; i <= 5; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if err != nil { - t.Errorf("Getting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key 'someKey%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } cache.Set("someKey6", 6) // Oldest item (1) should have been removed val, err := cache.Get("someKey1") - if err == nil { - t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) - } - + assert.NotNil(t, err) asMiss, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of type Miss. Is %T", err) - } - - if asMiss.Key != "someKey1" || asMiss.Where != "LOCAL" { - t.Errorf("Incorrect data within the Miss error. Got: %+v", asMiss) - } - - if val != 0 { - t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) - } + assert.True(t, ok) + assert.Equal(t, "someKey1", asMiss.Key) + assert.Equal(t, "LOCAL", asMiss.Where) + assert.Equal(t, 0, val) // 2-6 should be available for i := 2; i <= 6; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if err != nil { - t.Errorf("Getting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key 'someKey%d' should be %d. Is %d", i, i, val) - } - } - - if len(cache.items) != 5 { - t.Error("Items size should be 5. is: ", len(cache.items)) + assert.Nil(t, err) + assert.Equal(t, i, val) } - if len(cache.ttls) != len(cache.items) { - t.Error("TTLs size should be the same size as items") - } - - if cache.lru.Len() != 5 { - t.Error("LRU size should be 5. is: ", cache.lru.Len()) - } + assert.Equal(t, 5, len(cache.items)) + assert.Equal(t, 5, len(cache.ttls)) + assert.Equal(t, 5, cache.lru.Len()) time.Sleep(2 * time.Second) // Wait for all keys to expire. + for i := 2; i <= 6; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if val != 0 { - t.Errorf("No value should have been returned for expired key 'someKey%d'.", i) - } - - if err == nil { - t.Errorf("Getting value 'someKey%d', should have raised an 'Expired' error. Got nil", i) - continue - } - - asExpiredErr, ok := err.(*Expired) - if !ok { - t.Errorf("Returned error should be of 'Expired' type. Is %T", err) - continue - } - - if asExpiredErr.Key != fmt.Sprintf("someKey%d", i) { - t.Errorf("Key in Expired error should be 'someKey%d'. Is: '%s'", i, asExpiredErr.Key) - } - - if asExpiredErr.Value != i { - t.Errorf("Value in Expired error should be %d. Is %+v", i, asExpiredErr.Value) - } + assert.Equal(t, 0, val) + assert.NotNil(t, err) + asExpired, ok := err.(*Expired) + assert.True(t, ok) + assert.Equal(t, fmt.Sprintf("someKey%d", i), asExpired.Key) + assert.Equal(t, i, asExpired.Value) ttl, ok := cache.ttls[fmt.Sprintf("someKey%d", i)] - if !ok { - t.Errorf("A ttl entry should exist for key 'someKey%d'", i) - continue - } + assert.True(t, ok) + assert.Equal(t, asExpired.When, ttl.Add(cache.ttl)) - if asExpiredErr.When != ttl.Add(cache.ttl) { - t.Errorf("Key 'someKey%d' should have expired at %+v. It did at %+v", i, ttl.Add(cache.ttl), asExpiredErr.When) - } } } @@ -141,59 +93,35 @@ func TestSimpleCacheHighConcurrency(t *testing.T) { wg.Wait() } - func TestInt64Cache(t *testing.T) { c, err := NewSimpleLRU[int64, int64](5, NoTTL) - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } + assert.Nil(t, err) for i := int64(1); i <= 5; i++ { - err := c.Set(i, i) - if err != nil { - t.Errorf("Setting value '%d', should not have raised an error. Got: %s", i, err) - } + assert.Nil(t, c.Set(i, i)) } for i := int64(1); i <= 5; i++ { val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } c.Set(6, 6) // Oldest item (1) should have been removed val, err := c.Get(1) - if err == nil { - t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) - } - + assert.NotNil(t, err) _, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of type Miss. Is %T", err) - } - - if val != 0 { - t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) - } + assert.True(t, ok) + assert.Equal(t, int64(0), val) // 2-6 should be available for i := int64(2); i <= 6; i++ { val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } - if len(c.items) != 5 { - t.Error("Items size should be 5. is: ", len(c.items)) - } + assert.Equal(t, 5, len(c.items)) } diff --git a/datastructures/cache/mocks/mocks.go b/datastructures/cache/mocks/mocks.go new file mode 100644 index 0000000..9fcf797 --- /dev/null +++ b/datastructures/cache/mocks/mocks.go @@ -0,0 +1,22 @@ +package mocks + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type LayerMock struct { + mock.Mock +} + +func (m *LayerMock) Get(ctx context.Context, key string) (string, error) { + args := m.Called(ctx, key) + return args.String(0), args.Error(1) +} + +func (m *LayerMock) Set(ctx context.Context, key string, value string) error { + args := m.Called(ctx, key, value) + return args.Error(0) +} + diff --git a/datastructures/cache/multilevel_test.go b/datastructures/cache/multilevel_test.go index 22d2907..6adc401 100644 --- a/datastructures/cache/multilevel_test.go +++ b/datastructures/cache/multilevel_test.go @@ -2,212 +2,61 @@ package cache import ( "context" - "errors" - "fmt" "testing" + "github.com/splitio/go-toolkit/v6/datastructures/cache/mocks" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) -type LayerMock struct { - getCall func(ctx context.Context, key string) (string, error) - setCall func(ctx context.Context, key string, value string) error -} - -func (m *LayerMock) Get(ctx context.Context, key string) (string, error) { - return m.getCall(ctx, key) -} - -func (m *LayerMock) Set(ctx context.Context, key string, value string) error { - return m.setCall(ctx, key, value) -} - -type callTracker struct { - calls map[string]int - t *testing.T -} - -func newCallTracker(t *testing.T) *callTracker { - return &callTracker{calls: make(map[string]int), t: t} -} - -func (c *callTracker) track(name string) { c.calls[name]++ } - -func (c *callTracker) reset() { c.calls = make(map[string]int) } - -func (c *callTracker) checkCall(name string, count int) { - c.t.Helper() - if c.calls[name] != count { - c.t.Errorf("calls for '%s' should be %d. is: %d", name, count, c.calls[name]) - } -} - -func (c *callTracker) checkTotalCalls(count int) { - c.t.Helper() - if len(c.calls) != count { - c.t.Errorf("The nomber of total calls should be '%d' and is '%d'", count, len(c.calls)) - } -} - func TestMultiLevelCache(t *testing.T) { // To test this we setup 3 layers of caching in order of querying: top -> mid -> bottom // Top layer has key1, doesn't have key2 (returns Miss), has key3 expired and errors out when requesting any other Key // Mid layer has key 2, returns a Miss on any other key, and fails the test if key1 is fetched (because it was present on top layer) // Bottom layer fails if key1 or 2 are requested, has key 3. returns Miss if any other key is requested - calls := newCallTracker(t) - topLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("top:get:%s", key)) - switch key { - case "key1": - return "value1", nil - case "key2": - return "", &Miss{Where: "layer1", Key: "key2"} - case "key3": - return "", &Expired{Key: "key3", Value: "someOtherValue"} - default: - return "", errors.New("someError") - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("top:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the top layer for key1") - break - case "key2": - break - case "key3": - break - default: - return errors.New("someError") - } - return nil - }, - } - midLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("mid:get:%s", key)) - switch key { - case "key1": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key2": - return "value2", nil - default: - return "", &Miss{Where: "layer2", Key: key} - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("mid:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the mid layer for key1") - case "key2": - t.Error("Set should not be called on the mid layer for key2") - case "key3": - default: - return errors.New("someError") - } - return nil - }, - } + ctx := context.Background() - bottomLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("bot:get:%s", key)) - switch key { - case "key1": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key2": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key3": - return "value3", nil - default: - return "", &Miss{Where: "layer3", Key: key} - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("bot:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the mid layer for key1") - case "key2": - t.Error("Set should not be called on the mid layer for key2") - default: - return errors.New("someError") - } - return nil - }, - } + topLayer := &mocks.LayerMock{} + topLayer.On("Get", ctx, "key1").Once().Return("value1", nil) + topLayer.On("Get", ctx, "key2").Once().Return("", &Miss{Where: "layer1", Key: "key2"}) + topLayer.On("Get", ctx, "key3").Once().Return("value1", &Expired{Key: "key3", Value: "someOtherValue"}) + topLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer1", Key: "key4"}) + topLayer.On("Set", ctx, "key2", "value2").Once().Return(nil) + topLayer.On("Set", ctx, "key3", "value3").Once().Return(nil) + + midLayer := &mocks.LayerMock{} + midLayer.On("Get", ctx, "key2").Once().Return("value2", nil) + midLayer.On("Get", ctx, "key3").Once().Return("", &Miss{Where: "layer2", Key: "key3"}, nil) + midLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer2", Key: "key4"}) + midLayer.On("Set", ctx, "key3", "value3").Once().Return(nil) + + bottomLayer := &mocks.LayerMock{} + bottomLayer.On("Get", ctx, "key3").Once().Return("value3", nil) + bottomLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer3", Key: "key4"}) cacheML := MultiLevelCacheImpl[string, string]{ logger: logging.NewLogger(nil), layers: []MLCLayer[string, string]{topLayer, midLayer, bottomLayer}, } - value1, err := cacheML.Get(context.TODO(), "key1") - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } - if value1 != "value1" { - t.Error("Get 'key1' should return 'value1'. Got: ", value1) - } - calls.checkCall("top:get:key1", 1) - calls.checkTotalCalls(1) - - calls.reset() - value2, err := cacheML.Get(context.TODO(), "key2") - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } - if value2 != "value2" { - t.Error("Get 'key2' should return 'value2'. Got: ", value2) - } - calls.checkCall("top:get:key2", 1) - calls.checkCall("mid:get:key2", 1) - calls.checkCall("top:set:key2", 1) - calls.checkTotalCalls(3) - - calls.reset() - value3, err := cacheML.Get(context.TODO(), "key3") - if err != nil { - t.Error("Error should be nil. Was: ", err) - } + value1, err := cacheML.Get(ctx, "key1") + assert.Nil(t, err) + assert.Equal(t, "value1", value1) - if value3 != "value3" { - t.Error("Get 'key3' should return 'value3'. Got: ", value3) - } - calls.checkCall("top:get:key3", 1) - calls.checkCall("mid:get:key3", 1) - calls.checkCall("bot:get:key3", 1) - calls.checkCall("mid:set:key3", 1) - calls.checkCall("top:set:key3", 1) - calls.checkTotalCalls(5) + value2, err := cacheML.Get(ctx, "key2") + assert.Nil(t, err) + assert.Equal(t, "value2", value2) - calls.reset() - value4, err := cacheML.Get(context.TODO(), "key4") - if err == nil { - t.Error("Error should be returned when getting nonexistant key.") - } + value3, err := cacheML.Get(ctx, "key3") + assert.Nil(t, err) + assert.Equal(t, "value3", value3) + value4, err := cacheML.Get(ctx, "key4") + assert.NotNil(t, err) asMiss, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of Miss type. Is %T", err) - } - - if asMiss.Where != "ALL_LEVELS" || asMiss.Key != "key4" { - t.Errorf("Incorrect 'Where' or 'Key'. Got: %+v", asMiss) - } - - if value4 != "" { - t.Errorf("Value returned for GET 'key4' should be nil. Is: %+v", value4) - } - calls.checkCall("top:get:key4", 1) - calls.checkCall("top:get:key4", 1) - calls.checkCall("top:get:key4", 1) - calls.checkTotalCalls(3) + assert.True(t, ok) + assert.Equal(t, "ALL_LEVELS", asMiss.Where) + assert.Equal(t, "key4", asMiss.Key) + assert.Equal(t, "", value4) } diff --git a/datastructures/queuecache/cache_test.go b/datastructures/queuecache/cache_test.go index 6e5d734..324b8e3 100644 --- a/datastructures/queuecache/cache_test.go +++ b/datastructures/queuecache/cache_test.go @@ -4,6 +4,8 @@ import ( "errors" "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestCacheBasicUsage(t *testing.T) { @@ -36,78 +38,42 @@ func TestCacheBasicUsage(t *testing.T) { for index, item := range first5 { asInt, ok := item.(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != index { - t.Error("Each number should be equal to its index") - } + assert.True(t, ok) + assert.Equal(t, index, asInt) } offset := 5 next5, err := myCache.Fetch(5) - if err != nil { - t.Error(err) - } + assert.Nil(t, err) for index, item := range next5 { asInt, ok := item.(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != index+offset { - t.Error("Each number should be equal to its index") - } + assert.True(t, ok) + assert.Equal(t, index+offset, asInt) } index = 0 myCache = New(10, fetchMore) for i := 0; i < 100; i++ { elem, err := myCache.Fetch(1) - if err != nil { - t.Error(err) - } - + assert.Nil(t, err) asInt, ok := elem[0].(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != i { - t.Error("Each number should be equal to its index") - t.Error("asInt", asInt) - t.Error("index", i) - } + assert.True(t, ok) + assert.Equal(t, i, asInt) } elems, err := myCache.Fetch(1) - if elems != nil { - t.Error("Elem should be nil and is: ", elems) - } - - if err == nil || err.Error() != "NO_MORE_DATA" { - t.Error("Error should be NO_MORE_DATA and is: ", err.Error()) - } + assert.Nil(t, elems) + assert.ErrorContains(t, err, "NO_MORE_DATA") // Set index to 0 so that refill works and restart tests. index = 0 for i := 0; i < 100; i++ { elem, err := myCache.Fetch(1) - if err != nil { - t.Error(err) - } + assert.Nil(t, err) asInt, ok := elem[0].(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != i { - t.Error("Each number should be equal to its index") - t.Error("asInt", asInt) - t.Error("index", i) - } + assert.True(t, ok) + assert.Equal(t, i, asInt) } } @@ -118,18 +84,11 @@ func TestRefillPanic(t *testing.T) { myCache := New(10, fetchMore) result, err := myCache.Fetch(5) - - if result != nil { - t.Error("Result should have been nil and is: ", result) - } - if err == nil { - t.Error("Error should not have been nil") - } + assert.Nil(t, result) + assert.NotNil(t, err) _, ok := err.(*RefillError) - if !ok { - t.Error("Returned error should have been a RefillError") - } + assert.True(t, ok) } func TestCountWorksProperly(t *testing.T) { @@ -137,25 +96,17 @@ func TestCountWorksProperly(t *testing.T) { cache.readCursor = 0 cache.writeCursor = 0 - if cache.Count() != 0 { - t.Error("Count should be 0 and is: ", cache.Count()) - } + assert.Equal(t, 0, cache.Count()) cache.readCursor = 0 cache.writeCursor = 1 - if cache.Count() != 1 { - t.Error("Count should be 1 and is: ", cache.Count()) - } + assert.Equal(t, 1, cache.Count()) cache.readCursor = 50 cache.writeCursor = 99 - if cache.Count() != 49 { - t.Error("Count should be 49 and is: ", cache.Count()) - } + assert.Equal(t, 49, cache.Count()) cache.readCursor = 50 cache.writeCursor = 20 - if cache.Count() != 70 { - t.Error("Count should be 69 and is: ", cache.Count()) - } + assert.Equal(t, 70, cache.Count()) } diff --git a/datautils/compress.go b/datautils/compress.go index ca555aa..ee2f53a 100644 --- a/datautils/compress.go +++ b/datautils/compress.go @@ -5,7 +5,7 @@ import ( "compress/gzip" "compress/zlib" "fmt" - "io/ioutil" + "io" ) const ( @@ -48,7 +48,7 @@ func Decompress(data []byte, compressType int) ([]byte, error) { return nil, err } defer gz.Close() - raw, err := ioutil.ReadAll(gz) + raw, err := io.ReadAll(gz) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func Decompress(data []byte, compressType int) ([]byte, error) { return nil, err } defer zl.Close() - raw, err := ioutil.ReadAll(zl) + raw, err := io.ReadAll(zl) if err != nil { return nil, err } diff --git a/datautils/compress_test.go b/datautils/compress_test.go index 0acfbcb..6d8973b 100644 --- a/datautils/compress_test.go +++ b/datautils/compress_test.go @@ -1,53 +1,41 @@ package datautils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestCompressDecompressError(t *testing.T) { data := "compression" _, err := Compress([]byte(data), 4) - if err == nil || err.Error() != "compression type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "compression type not found") _, err = Decompress([]byte("err"), 4) - if err == nil || err.Error() != "compression type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "compression type not found") } func TestCompressDecompressGZip(t *testing.T) { data := "compression gzip" compressed, err := Compress([]byte(data), GZip) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) decompressed, err := Decompress(compressed, GZip) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) - if string(decompressed) != data { - t.Error("It should be equal") - } + assert.Equal(t, data, string(decompressed)) } func TestCompressDecompressZLib(t *testing.T) { data := "compression zlib" compressed, err := Compress([]byte(data), Zlib) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) decompressed, err := Decompress(compressed, Zlib) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) - if string(decompressed) != data { - t.Error("It should be equal") - } + assert.Equal(t, data, string(decompressed)) } diff --git a/datautils/encode_test.go b/datautils/encode_test.go index acd9bf0..1e79b30 100644 --- a/datautils/encode_test.go +++ b/datautils/encode_test.go @@ -1,32 +1,25 @@ package datautils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestError(t *testing.T) { _, err := Encode([]byte("err"), 4) - if err == nil || err.Error() != "encode type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "encode type not found") _, err = Decode("err", 4) - if err == nil || err.Error() != "encode type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "encode type not found") } func TestB64EncodeDecode(t *testing.T) { data := "encode b64" encoded, err := Encode([]byte(data), Base64) - if err != nil { - t.Error("It should not return err") - } + assert.Nil(t, err) decoded, err := Decode(encoded, Base64) - if err != nil { - t.Error("It should not return err") - } - - if data != string(decoded) { - t.Error("It should be equal") - } + assert.Nil(t, err) + assert.Equal(t, data, string(decoded)) } diff --git a/deepcopy/deepcopy.go b/deepcopy/deepcopy.go deleted file mode 100644 index 73b83f4..0000000 --- a/deepcopy/deepcopy.go +++ /dev/null @@ -1,114 +0,0 @@ -package deepcopy - -import ( - "reflect" - "time" -) - -// Interface for delegating copy process to type -type Interface interface { - DeepCopy() interface{} -} - -// Copy creates a deep copy of whatever is passed to it and returns the copy -// in an interface{}. The returned value will need to be asserted to the -// correct type. -func Copy(src interface{}) interface{} { - if src == nil { - return nil - } - - // Make the interface a reflect.Value - original := reflect.ValueOf(src) - - // Make a copy of the same type as the original. - cpy := reflect.New(original.Type()).Elem() - - // Recursively copy the original. - copyRecursive(original, cpy) - - // Return the copy as an interface. - return cpy.Interface() -} - -// copyRecursive does the actual copying of the interface. It currently has -// limited support for what it can handle. Add as needed. -func copyRecursive(original, cpy reflect.Value) { - // check for implement deepcopy.Interface - if original.CanInterface() { - if copier, ok := original.Interface().(Interface); ok { - cpy.Set(reflect.ValueOf(copier.DeepCopy())) - return - } - } - - // handle according to original's Kind - switch original.Kind() { - case reflect.Ptr: - // Get the actual value being pointed to. - originalValue := original.Elem() - - // if it isn't valid, return. - if !originalValue.IsValid() { - return - } - cpy.Set(reflect.New(originalValue.Type())) - copyRecursive(originalValue, cpy.Elem()) - - case reflect.Interface: - // If this is a nil, don't do anything - if original.IsNil() { - return - } - // Get the value for the interface, not the pointer. - originalValue := original.Elem() - - // Get the value by calling Elem(). - copyValue := reflect.New(originalValue.Type()).Elem() - copyRecursive(originalValue, copyValue) - cpy.Set(copyValue) - - case reflect.Struct: - t, ok := original.Interface().(time.Time) - if ok { - cpy.Set(reflect.ValueOf(t)) - return - } - // Go through each field of the struct and copy it. - for i := 0; i < original.NumField(); i++ { - // The Type's StructField for a given field is checked to see if StructField.PkgPath - // is set to determine if the field is exported or not because CanSet() returns false - // for settable fields. I'm not sure why. -mohae - if original.Type().Field(i).PkgPath != "" { - continue - } - copyRecursive(original.Field(i), cpy.Field(i)) - } - - case reflect.Slice: - if original.IsNil() { - return - } - // Make a new slice and copy each element. - cpy.Set(reflect.MakeSlice(original.Type(), original.Len(), original.Cap())) - for i := 0; i < original.Len(); i++ { - copyRecursive(original.Index(i), cpy.Index(i)) - } - - case reflect.Map: - if original.IsNil() { - return - } - cpy.Set(reflect.MakeMap(original.Type())) - for _, key := range original.MapKeys() { - originalValue := original.MapIndex(key) - copyValue := reflect.New(originalValue.Type()).Elem() - copyRecursive(originalValue, copyValue) - copyKey := Copy(key.Interface()) - cpy.SetMapIndex(reflect.ValueOf(copyKey), copyValue) - } - - default: - cpy.Set(original) - } -} diff --git a/deepcopy/deepcopy_test.go b/deepcopy/deepcopy_test.go deleted file mode 100644 index f150b1a..0000000 --- a/deepcopy/deepcopy_test.go +++ /dev/null @@ -1,1110 +0,0 @@ -package deepcopy - -import ( - "fmt" - "reflect" - "testing" - "time" - "unsafe" -) - -// just basic is this working stuff -func TestSimple(t *testing.T) { - Strings := []string{"a", "b", "c"} - cpyS := Copy(Strings).([]string) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyS)).Data { - t.Error("[]string: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyBools - } - if len(cpyS) != len(Strings) { - t.Errorf("[]string: len was %d; want %d", len(cpyS), len(Strings)) - goto CopyBools - } - for i, v := range Strings { - if v != cpyS[i] { - t.Errorf("[]string: got %v at index %d of the copy; want %v", cpyS[i], i, v) - } - } - -CopyBools: - Bools := []bool{true, true, false, false} - cpyB := Copy(Bools).([]bool) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyB)).Data { - t.Error("[]bool: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyBytes - } - if len(cpyB) != len(Bools) { - t.Errorf("[]bool: len was %d; want %d", len(cpyB), len(Bools)) - goto CopyBytes - } - for i, v := range Bools { - if v != cpyB[i] { - t.Errorf("[]bool: got %v at index %d of the copy; want %v", cpyB[i], i, v) - } - } - -CopyBytes: - Bytes := []byte("hello") - cpyBt := Copy(Bytes).([]byte) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyBt)).Data { - t.Error("[]byte: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyInts - } - if len(cpyBt) != len(Bytes) { - t.Errorf("[]byte: len was %d; want %d", len(cpyBt), len(Bytes)) - goto CopyInts - } - for i, v := range Bytes { - if v != cpyBt[i] { - t.Errorf("[]byte: got %v at index %d of the copy; want %v", cpyBt[i], i, v) - } - } - -CopyInts: - Ints := []int{42} - cpyI := Copy(Ints).([]int) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyI)).Data { - t.Error("[]int: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyUints - } - if len(cpyI) != len(Ints) { - t.Errorf("[]int: len was %d; want %d", len(cpyI), len(Ints)) - goto CopyUints - } - for i, v := range Ints { - if v != cpyI[i] { - t.Errorf("[]int: got %v at index %d of the copy; want %v", cpyI[i], i, v) - } - } - -CopyUints: - Uints := []uint{1, 2, 3, 4, 5} - cpyU := Copy(Uints).([]uint) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyU)).Data { - t.Error("[]: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyFloat32s - } - if len(cpyU) != len(Uints) { - t.Errorf("[]uint: len was %d; want %d", len(cpyU), len(Uints)) - goto CopyFloat32s - } - for i, v := range Uints { - if v != cpyU[i] { - t.Errorf("[]uint: got %v at index %d of the copy; want %v", cpyU[i], i, v) - } - } - -CopyFloat32s: - Float32s := []float32{3.14} - cpyF := Copy(Float32s).([]float32) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyF)).Data { - t.Error("[]float32: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyInterfaces - } - if len(cpyF) != len(Float32s) { - t.Errorf("[]float32: len was %d; want %d", len(cpyF), len(Float32s)) - goto CopyInterfaces - } - for i, v := range Float32s { - if v != cpyF[i] { - t.Errorf("[]float32: got %v at index %d of the copy; want %v", cpyF[i], i, v) - } - } - -CopyInterfaces: - Interfaces := []interface{}{"a", 42, true, 4.32} - cpyIf := Copy(Interfaces).([]interface{}) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyIf)).Data { - t.Error("[]interfaces: expected SliceHeader data pointers to point to different locations, they didn't") - return - } - if len(cpyIf) != len(Interfaces) { - t.Errorf("[]interface{}: len was %d; want %d", len(cpyIf), len(Interfaces)) - return - } - for i, v := range Interfaces { - if v != cpyIf[i] { - t.Errorf("[]interface{}: got %v at index %d of the copy; want %v", cpyIf[i], i, v) - } - } -} - -type Basics struct { - String string - Strings []string - StringArr [4]string - Bool bool - Bools []bool - Byte byte - Bytes []byte - Int int - Ints []int - Int8 int8 - Int8s []int8 - Int16 int16 - Int16s []int16 - Int32 int32 - Int32s []int32 - Int64 int64 - Int64s []int64 - Uint uint - Uints []uint - Uint8 uint8 - Uint8s []uint8 - Uint16 uint16 - Uint16s []uint16 - Uint32 uint32 - Uint32s []uint32 - Uint64 uint64 - Uint64s []uint64 - Float32 float32 - Float32s []float32 - Float64 float64 - Float64s []float64 - Complex64 complex64 - Complex64s []complex64 - Complex128 complex128 - Complex128s []complex128 - Interface interface{} - Interfaces []interface{} -} - -// These tests test that all supported basic types are copied correctly. This -// is done by copying a struct with fields of most of the basic types as []T. -func TestMostTypes(t *testing.T) { - test := Basics{ - String: "kimchi", - Strings: []string{"uni", "ika"}, - StringArr: [4]string{"malort", "barenjager", "fernet", "salmiakki"}, - Bool: true, - Bools: []bool{true, false, true}, - Byte: 'z', - Bytes: []byte("abc"), - Int: 42, - Ints: []int{0, 1, 3, 4}, - Int8: 8, - Int8s: []int8{8, 9, 10}, - Int16: 16, - Int16s: []int16{16, 17, 18, 19}, - Int32: 32, - Int32s: []int32{32, 33}, - Int64: 64, - Int64s: []int64{64}, - Uint: 420, - Uints: []uint{11, 12, 13}, - Uint8: 81, - Uint8s: []uint8{81, 82}, - Uint16: 160, - Uint16s: []uint16{160, 161, 162, 163, 164}, - Uint32: 320, - Uint32s: []uint32{320, 321}, - Uint64: 640, - Uint64s: []uint64{6400, 6401, 6402, 6403}, - Float32: 32.32, - Float32s: []float32{32.32, 33}, - Float64: 64.1, - Float64s: []float64{64, 65, 66}, - Complex64: complex64(-64 + 12i), - Complex64s: []complex64{complex64(-65 + 11i), complex64(66 + 10i)}, - Complex128: complex128(-128 + 12i), - Complex128s: []complex128{complex128(-128 + 11i), complex128(129 + 10i)}, - Interfaces: []interface{}{42, true, "pan-galactic"}, - } - - cpy := Copy(test).(Basics) - - // see if they point to the same location - if fmt.Sprintf("%p", &cpy) == fmt.Sprintf("%p", &test) { - t.Error("address of copy was the same as original; they should be different") - return - } - - // Go through each field and check to see it got copied properly - if cpy.String != test.String { - t.Errorf("String: got %v; want %v", cpy.String, test.String) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Strings)).Data { - t.Error("Strings: address of copy was the same as original; they should be different") - goto StringArr - } - - if len(cpy.Strings) != len(test.Strings) { - t.Errorf("Strings: len was %d; want %d", len(cpy.Strings), len(test.Strings)) - goto StringArr - } - for i, v := range test.Strings { - if v != cpy.Strings[i] { - t.Errorf("Strings: got %v at index %d of the copy; want %v", cpy.Strings[i], i, v) - } - } - -StringArr: - if unsafe.Pointer(&test.StringArr) == unsafe.Pointer(&cpy.StringArr) { - t.Error("StringArr: address of copy was the same as original; they should be different") - goto Bools - } - for i, v := range test.StringArr { - if v != cpy.StringArr[i] { - t.Errorf("StringArr: got %v at index %d of the copy; want %v", cpy.StringArr[i], i, v) - } - } - -Bools: - if cpy.Bool != test.Bool { - t.Errorf("Bool: got %v; want %v", cpy.Bool, test.Bool) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Bools)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Bools)).Data { - t.Error("Bools: address of copy was the same as original; they should be different") - goto Bytes - } - if len(cpy.Bools) != len(test.Bools) { - t.Errorf("Bools: len was %d; want %d", len(cpy.Bools), len(test.Bools)) - goto Bytes - } - for i, v := range test.Bools { - if v != cpy.Bools[i] { - t.Errorf("Bools: got %v at index %d of the copy; want %v", cpy.Bools[i], i, v) - } - } - -Bytes: - if cpy.Byte != test.Byte { - t.Errorf("Byte: got %v; want %v", cpy.Byte, test.Byte) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Bytes)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Bytes)).Data { - t.Error("Bytes: address of copy was the same as original; they should be different") - goto Ints - } - if len(cpy.Bytes) != len(test.Bytes) { - t.Errorf("Bytes: len was %d; want %d", len(cpy.Bytes), len(test.Bytes)) - goto Ints - } - for i, v := range test.Bytes { - if v != cpy.Bytes[i] { - t.Errorf("Bytes: got %v at index %d of the copy; want %v", cpy.Bytes[i], i, v) - } - } - -Ints: - if cpy.Int != test.Int { - t.Errorf("Int: got %v; want %v", cpy.Int, test.Int) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Ints)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Ints)).Data { - t.Error("Ints: address of copy was the same as original; they should be different") - goto Int8s - } - if len(cpy.Ints) != len(test.Ints) { - t.Errorf("Ints: len was %d; want %d", len(cpy.Ints), len(test.Ints)) - goto Int8s - } - for i, v := range test.Ints { - if v != cpy.Ints[i] { - t.Errorf("Ints: got %v at index %d of the copy; want %v", cpy.Ints[i], i, v) - } - } - -Int8s: - if cpy.Int8 != test.Int8 { - t.Errorf("Int8: got %v; want %v", cpy.Int8, test.Int8) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int8s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int8s)).Data { - t.Error("Int8s: address of copy was the same as original; they should be different") - goto Int16s - } - if len(cpy.Int8s) != len(test.Int8s) { - t.Errorf("Int8s: len was %d; want %d", len(cpy.Int8s), len(test.Int8s)) - goto Int16s - } - for i, v := range test.Int8s { - if v != cpy.Int8s[i] { - t.Errorf("Int8s: got %v at index %d of the copy; want %v", cpy.Int8s[i], i, v) - } - } - -Int16s: - if cpy.Int16 != test.Int16 { - t.Errorf("Int16: got %v; want %v", cpy.Int16, test.Int16) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int16s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int16s)).Data { - t.Error("Int16s: address of copy was the same as original; they should be different") - goto Int32s - } - if len(cpy.Int16s) != len(test.Int16s) { - t.Errorf("Int16s: len was %d; want %d", len(cpy.Int16s), len(test.Int16s)) - goto Int32s - } - for i, v := range test.Int16s { - if v != cpy.Int16s[i] { - t.Errorf("Int16s: got %v at index %d of the copy; want %v", cpy.Int16s[i], i, v) - } - } - -Int32s: - if cpy.Int32 != test.Int32 { - t.Errorf("Int32: got %v; want %v", cpy.Int32, test.Int32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int32s)).Data { - t.Error("Int32s: address of copy was the same as original; they should be different") - goto Int64s - } - if len(cpy.Int32s) != len(test.Int32s) { - t.Errorf("Int32s: len was %d; want %d", len(cpy.Int32s), len(test.Int32s)) - goto Int64s - } - for i, v := range test.Int32s { - if v != cpy.Int32s[i] { - t.Errorf("Int32s: got %v at index %d of the copy; want %v", cpy.Int32s[i], i, v) - } - } - -Int64s: - if cpy.Int64 != test.Int64 { - t.Errorf("Int64: got %v; want %v", cpy.Int64, test.Int64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int64s)).Data { - t.Error("Int64s: address of copy was the same as original; they should be different") - goto Uints - } - if len(cpy.Int64s) != len(test.Int64s) { - t.Errorf("Int64s: len was %d; want %d", len(cpy.Int64s), len(test.Int64s)) - goto Uints - } - for i, v := range test.Int64s { - if v != cpy.Int64s[i] { - t.Errorf("Int64s: got %v at index %d of the copy; want %v", cpy.Int64s[i], i, v) - } - } - -Uints: - if cpy.Uint != test.Uint { - t.Errorf("Uint: got %v; want %v", cpy.Uint, test.Uint) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uints)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uints)).Data { - t.Error("Uints: address of copy was the same as original; they should be different") - goto Uint8s - } - if len(cpy.Uints) != len(test.Uints) { - t.Errorf("Uints: len was %d; want %d", len(cpy.Uints), len(test.Uints)) - goto Uint8s - } - for i, v := range test.Uints { - if v != cpy.Uints[i] { - t.Errorf("Uints: got %v at index %d of the copy; want %v", cpy.Uints[i], i, v) - } - } - -Uint8s: - if cpy.Uint8 != test.Uint8 { - t.Errorf("Uint8: got %v; want %v", cpy.Uint8, test.Uint8) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint8s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint8s)).Data { - t.Error("Uint8s: address of copy was the same as original; they should be different") - goto Uint16s - } - if len(cpy.Uint8s) != len(test.Uint8s) { - t.Errorf("Uint8s: len was %d; want %d", len(cpy.Uint8s), len(test.Uint8s)) - goto Uint16s - } - for i, v := range test.Uint8s { - if v != cpy.Uint8s[i] { - t.Errorf("Uint8s: got %v at index %d of the copy; want %v", cpy.Uint8s[i], i, v) - } - } - -Uint16s: - if cpy.Uint16 != test.Uint16 { - t.Errorf("Uint16: got %v; want %v", cpy.Uint16, test.Uint16) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint16s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint16s)).Data { - t.Error("Uint16s: address of copy was the same as original; they should be different") - goto Uint32s - } - if len(cpy.Uint16s) != len(test.Uint16s) { - t.Errorf("Uint16s: len was %d; want %d", len(cpy.Uint16s), len(test.Uint16s)) - goto Uint32s - } - for i, v := range test.Uint16s { - if v != cpy.Uint16s[i] { - t.Errorf("Uint16s: got %v at index %d of the copy; want %v", cpy.Uint16s[i], i, v) - } - } - -Uint32s: - if cpy.Uint32 != test.Uint32 { - t.Errorf("Uint32: got %v; want %v", cpy.Uint32, test.Uint32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint32s)).Data { - t.Error("Uint32s: address of copy was the same as original; they should be different") - goto Uint64s - } - if len(cpy.Uint32s) != len(test.Uint32s) { - t.Errorf("Uint32s: len was %d; want %d", len(cpy.Uint32s), len(test.Uint32s)) - goto Uint64s - } - for i, v := range test.Uint32s { - if v != cpy.Uint32s[i] { - t.Errorf("Uint32s: got %v at index %d of the copy; want %v", cpy.Uint32s[i], i, v) - } - } - -Uint64s: - if cpy.Uint64 != test.Uint64 { - t.Errorf("Uint64: got %v; want %v", cpy.Uint64, test.Uint64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint64s)).Data { - t.Error("Uint64s: address of copy was the same as original; they should be different") - goto Float32s - } - if len(cpy.Uint64s) != len(test.Uint64s) { - t.Errorf("Uint64s: len was %d; want %d", len(cpy.Uint64s), len(test.Uint64s)) - goto Float32s - } - for i, v := range test.Uint64s { - if v != cpy.Uint64s[i] { - t.Errorf("Uint64s: got %v at index %d of the copy; want %v", cpy.Uint64s[i], i, v) - } - } - -Float32s: - if cpy.Float32 != test.Float32 { - t.Errorf("Float32: got %v; want %v", cpy.Float32, test.Float32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Float32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Float32s)).Data { - t.Error("Float32s: address of copy was the same as original; they should be different") - goto Float64s - } - if len(cpy.Float32s) != len(test.Float32s) { - t.Errorf("Float32s: len was %d; want %d", len(cpy.Float32s), len(test.Float32s)) - goto Float64s - } - for i, v := range test.Float32s { - if v != cpy.Float32s[i] { - t.Errorf("Float32s: got %v at index %d of the copy; want %v", cpy.Float32s[i], i, v) - } - } - -Float64s: - if cpy.Float64 != test.Float64 { - t.Errorf("Float64: got %v; want %v", cpy.Float64, test.Float64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Float64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Float64s)).Data { - t.Error("Float64s: address of copy was the same as original; they should be different") - goto Complex64s - } - if len(cpy.Float64s) != len(test.Float64s) { - t.Errorf("Float64s: len was %d; want %d", len(cpy.Float64s), len(test.Float64s)) - goto Complex64s - } - for i, v := range test.Float64s { - if v != cpy.Float64s[i] { - t.Errorf("Float64s: got %v at index %d of the copy; want %v", cpy.Float64s[i], i, v) - } - } - -Complex64s: - if cpy.Complex64 != test.Complex64 { - t.Errorf("Complex64: got %v; want %v", cpy.Complex64, test.Complex64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Complex64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Complex64s)).Data { - t.Error("Complex64s: address of copy was the same as original; they should be different") - goto Complex128s - } - if len(cpy.Complex64s) != len(test.Complex64s) { - t.Errorf("Complex64s: len was %d; want %d", len(cpy.Complex64s), len(test.Complex64s)) - goto Complex128s - } - for i, v := range test.Complex64s { - if v != cpy.Complex64s[i] { - t.Errorf("Complex64s: got %v at index %d of the copy; want %v", cpy.Complex64s[i], i, v) - } - } - -Complex128s: - if cpy.Complex128 != test.Complex128 { - t.Errorf("Complex128s: got %v; want %v", cpy.Complex128s, test.Complex128s) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Complex128s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Complex128s)).Data { - t.Error("Complex128s: address of copy was the same as original; they should be different") - goto Interfaces - } - if len(cpy.Complex128s) != len(test.Complex128s) { - t.Errorf("Complex128s: len was %d; want %d", len(cpy.Complex128s), len(test.Complex128s)) - goto Interfaces - } - for i, v := range test.Complex128s { - if v != cpy.Complex128s[i] { - t.Errorf("Complex128s: got %v at index %d of the copy; want %v", cpy.Complex128s[i], i, v) - } - } - -Interfaces: - if cpy.Interface != test.Interface { - t.Errorf("Interface: got %v; want %v", cpy.Interface, test.Interface) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Interfaces)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Interfaces)).Data { - t.Error("Interfaces: address of copy was the same as original; they should be different") - return - } - if len(cpy.Interfaces) != len(test.Interfaces) { - t.Errorf("Interfaces: len was %d; want %d", len(cpy.Interfaces), len(test.Interfaces)) - return - } - for i, v := range test.Interfaces { - if v != cpy.Interfaces[i] { - t.Errorf("Interfaces: got %v at index %d of the copy; want %v", cpy.Interfaces[i], i, v) - } - } -} - -// not meant to be exhaustive -func TestComplexSlices(t *testing.T) { - orig3Int := [][][]int{[][]int{[]int{1, 2, 3}, []int{11, 22, 33}}, [][]int{[]int{7, 8, 9}, []int{66, 77, 88, 99}}} - cpyI := Copy(orig3Int).([][][]int) - if (*reflect.SliceHeader)(unsafe.Pointer(&orig3Int)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyI)).Data { - t.Error("[][][]int: address of copy was the same as original; they should be different") - return - } - if len(orig3Int) != len(cpyI) { - t.Errorf("[][][]int: len of copy was %d; want %d", len(cpyI), len(orig3Int)) - goto sliceMap - } - for i, v := range orig3Int { - if len(v) != len(cpyI[i]) { - t.Errorf("[][][]int: len of element %d was %d; want %d", i, len(cpyI[i]), len(v)) - continue - } - for j, vv := range v { - if len(vv) != len(cpyI[i][j]) { - t.Errorf("[][][]int: len of element %d:%d was %d, want %d", i, j, len(cpyI[i][j]), len(vv)) - continue - } - for k, vvv := range vv { - if vvv != cpyI[i][j][k] { - t.Errorf("[][][]int: element %d:%d:%d was %d, want %d", i, j, k, cpyI[i][j][k], vvv) - } - } - } - - } - -sliceMap: - slMap := []map[int]string{map[int]string{0: "a", 1: "b"}, map[int]string{10: "k", 11: "l", 12: "m"}} - cpyM := Copy(slMap).([]map[int]string) - if (*reflect.SliceHeader)(unsafe.Pointer(&slMap)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyM)).Data { - t.Error("[]map[int]string: address of copy was the same as original; they should be different") - } - if len(slMap) != len(cpyM) { - t.Errorf("[]map[int]string: len of copy was %d; want %d", len(cpyM), len(slMap)) - goto done - } - for i, v := range slMap { - if len(v) != len(cpyM[i]) { - t.Errorf("[]map[int]string: len of element %d was %d; want %d", i, len(cpyM[i]), len(v)) - continue - } - for k, vv := range v { - val, ok := cpyM[i][k] - if !ok { - t.Errorf("[]map[int]string: element %d was expected to have a value at key %d, it didn't", i, k) - continue - } - if val != vv { - t.Errorf("[]map[int]string: element %d, key %d: got %s, want %s", i, k, val, vv) - } - } - } -done: -} - -type A struct { - Int int - String string - UintSl []uint - NilSl []string - Map map[string]int - MapB map[string]*B - SliceB []B - B - T time.Time -} - -type B struct { - Vals []string -} - -var AStruct = A{ - Int: 42, - String: "Konichiwa", - UintSl: []uint{0, 1, 2, 3}, - Map: map[string]int{"a": 1, "b": 2}, - MapB: map[string]*B{ - "hi": &B{Vals: []string{"hello", "bonjour"}}, - "bye": &B{Vals: []string{"good-bye", "au revoir"}}, - }, - SliceB: []B{ - B{Vals: []string{"Ciao", "Aloha"}}, - }, - B: B{Vals: []string{"42"}}, - T: time.Now(), -} - -func TestStructA(t *testing.T) { - cpy := Copy(AStruct).(A) - if &cpy == &AStruct { - t.Error("expected copy to have a different address than the original; it was the same") - return - } - if cpy.Int != AStruct.Int { - t.Errorf("A.Int: got %v, want %v", cpy.Int, AStruct.Int) - } - if cpy.String != AStruct.String { - t.Errorf("A.String: got %v; want %v", cpy.String, AStruct.String) - } - if (*reflect.SliceHeader)(unsafe.Pointer(&cpy.UintSl)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.UintSl)).Data { - t.Error("A.Uintsl: expected the copies address to be different; it wasn't") - goto NilSl - } - if len(cpy.UintSl) != len(AStruct.UintSl) { - t.Errorf("A.UintSl: got len of %d, want %d", len(cpy.UintSl), len(AStruct.UintSl)) - goto NilSl - } - for i, v := range AStruct.UintSl { - if cpy.UintSl[i] != v { - t.Errorf("A.UintSl %d: got %d, want %d", i, cpy.UintSl[i], v) - } - } - -NilSl: - if cpy.NilSl != nil { - t.Error("A.NilSl: expected slice to be nil, it wasn't") - } - - if *(*uintptr)(unsafe.Pointer(&cpy.Map)) == *(*uintptr)(unsafe.Pointer(&AStruct.Map)) { - t.Error("A.Map: expected the copy's address to be different; it wasn't") - goto AMapB - } - if len(cpy.Map) != len(AStruct.Map) { - t.Errorf("A.Map: got len of %d, want %d", len(cpy.Map), len(AStruct.Map)) - goto AMapB - } - for k, v := range AStruct.Map { - val, ok := cpy.Map[k] - if !ok { - t.Errorf("A.Map: expected the key %s to exist in the copy, it didn't", k) - continue - } - if val != v { - t.Errorf("A.Map[%s]: got %d, want %d", k, val, v) - } - } - -AMapB: - if *(*uintptr)(unsafe.Pointer(&cpy.MapB)) == *(*uintptr)(unsafe.Pointer(&AStruct.MapB)) { - t.Error("A.MapB: expected the copy's address to be different; it wasn't") - goto ASliceB - } - if len(cpy.MapB) != len(AStruct.MapB) { - t.Errorf("A.MapB: got len of %d, want %d", len(cpy.MapB), len(AStruct.MapB)) - goto ASliceB - } - for k, v := range AStruct.MapB { - val, ok := cpy.MapB[k] - if !ok { - t.Errorf("A.MapB: expected the key %s to exist in the copy, it didn't", k) - continue - } - if unsafe.Pointer(val) == unsafe.Pointer(v) { - t.Errorf("A.MapB[%s]: expected the addresses of the values to be different; they weren't", k) - continue - } - // the slice headers should point to different data - if (*reflect.SliceHeader)(unsafe.Pointer(&v.Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&val.Vals)).Data { - t.Errorf("%s: expected B's SliceHeaders to point to different Data locations; they did not.", k) - continue - } - for i, vv := range v.Vals { - if vv != val.Vals[i] { - t.Errorf("A.MapB[%s].Vals[%d]: got %s want %s", k, i, vv, val.Vals[i]) - } - } - } - -ASliceB: - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.SliceB)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.SliceB)).Data { - t.Error("A.SliceB: expected the copy's address to be different; it wasn't") - goto B - } - - if len(AStruct.SliceB) != len(cpy.SliceB) { - t.Errorf("A.SliceB: got length of %d; want %d", len(cpy.SliceB), len(AStruct.SliceB)) - goto B - } - - for i := range AStruct.SliceB { - if unsafe.Pointer(&AStruct.SliceB[i]) == unsafe.Pointer(&cpy.SliceB[i]) { - t.Errorf("A.SliceB[%d]: expected them to have different addresses, they didn't", i) - continue - } - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.SliceB[i].Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.SliceB[i].Vals)).Data { - t.Errorf("A.SliceB[%d]: expected B.Vals SliceHeader.Data to point to different locations; they did not", i) - continue - } - if len(AStruct.SliceB[i].Vals) != len(cpy.SliceB[i].Vals) { - t.Errorf("A.SliceB[%d]: expected B's vals to have the same length, they didn't", i) - continue - } - for j, val := range AStruct.SliceB[i].Vals { - if val != cpy.SliceB[i].Vals[j] { - t.Errorf("A.SliceB[%d].Vals[%d]: got %v; want %v", i, j, cpy.SliceB[i].Vals[j], val) - } - } - } -B: - if unsafe.Pointer(&AStruct.B) == unsafe.Pointer(&cpy.B) { - t.Error("A.B: expected them to have different addresses, they didn't") - goto T - } - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.B.Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.B.Vals)).Data { - t.Error("A.B.Vals: expected the SliceHeaders.Data to point to different locations; they didn't") - goto T - } - if len(AStruct.B.Vals) != len(cpy.B.Vals) { - t.Error("A.B.Vals: expected their lengths to be the same, they weren't") - goto T - } - for i, v := range AStruct.B.Vals { - if v != cpy.B.Vals[i] { - t.Errorf("A.B.Vals[%d]: got %s want %s", i, cpy.B.Vals[i], v) - } - } -T: - if fmt.Sprintf("%p", &AStruct.T) == fmt.Sprintf("%p", &cpy.T) { - t.Error("A.T: expected them to have different addresses, they didn't") - return - } - if AStruct.T != cpy.T { - t.Errorf("A.T: got %v, want %v", cpy.T, AStruct.T) - } -} - -type Unexported struct { - A string - B int - aa string - bb int - cc []int - dd map[string]string -} - -func TestUnexportedFields(t *testing.T) { - u := &Unexported{ - A: "A", - B: 42, - aa: "aa", - bb: 42, - cc: []int{1, 2, 3}, - dd: map[string]string{"hello": "bonjour"}, - } - cpy := Copy(u).(*Unexported) - if cpy == u { - t.Error("expected addresses to be different, they weren't") - return - } - if u.A != cpy.A { - t.Errorf("Unexported.A: got %s want %s", cpy.A, u.A) - } - if u.B != cpy.B { - t.Errorf("Unexported.A: got %d want %d", cpy.B, u.B) - } - if cpy.aa != "" { - t.Errorf("Unexported.aa: unexported field should not be set, it was set to %s", cpy.aa) - } - if cpy.bb != 0 { - t.Errorf("Unexported.bb: unexported field should not be set, it was set to %d", cpy.bb) - } - if cpy.cc != nil { - t.Errorf("Unexported.cc: unexported field should not be set, it was set to %#v", cpy.cc) - } - if cpy.dd != nil { - t.Errorf("Unexported.dd: unexported field should not be set, it was set to %#v", cpy.dd) - } -} - -// Note: this test will fail until https://github.com/golang/go/issues/15716 is -// fixed and the version it is part of gets released. -type T struct { - time.Time -} - -func TestTimeCopy(t *testing.T) { - tests := []struct { - Y int - M time.Month - D int - h int - m int - s int - nsec int - TZ string - }{ - {2016, time.July, 4, 23, 11, 33, 3000, "America/New_York"}, - {2015, time.October, 31, 9, 44, 23, 45935, "UTC"}, - {2014, time.May, 5, 22, 01, 50, 219300, "Europe/Prague"}, - } - - for i, test := range tests { - l, err := time.LoadLocation(test.TZ) - if err != nil { - t.Errorf("%d: unexpected error: %s", i, err) - continue - } - var x T - x.Time = time.Date(test.Y, test.M, test.D, test.h, test.m, test.s, test.nsec, l) - c := Copy(x).(T) - if fmt.Sprintf("%p", &c) == fmt.Sprintf("%p", &x) { - t.Errorf("%d: expected the copy to have a different address than the original value; they were the same: %p %p", i, &c, &x) - continue - } - if x.UnixNano() != c.UnixNano() { - t.Errorf("%d: nanotime: got %v; want %v", i, c.UnixNano(), x.UnixNano()) - continue - } - if x.Location() != c.Location() { - t.Errorf("%d: location: got %q; want %q", i, c.Location(), x.Location()) - } - } -} - -func TestPointerToStruct(t *testing.T) { - type Foo struct { - Bar int - } - - f := &Foo{Bar: 42} - cpy := Copy(f) - if f == cpy { - t.Errorf("expected copy to point to a different location: orig: %p; copy: %p", f, cpy) - } - if !reflect.DeepEqual(f, cpy) { - t.Errorf("expected the copy to be equal to the original (except for memory location); it wasn't: got %#v; want %#v", f, cpy) - } -} - -func TestIssue9(t *testing.T) { - // simple pointer copy - x := 42 - testA := map[string]*int{ - "a": nil, - "b": &x, - } - copyA := Copy(testA).(map[string]*int) - if unsafe.Pointer(&testA) == unsafe.Pointer(©A) { - t.Fatalf("expected the map pointers to be different: testA: %v\tcopyA: %v", unsafe.Pointer(&testA), unsafe.Pointer(©A)) - } - if !reflect.DeepEqual(testA, copyA) { - t.Errorf("got %#v; want %#v", copyA, testA) - } - if testA["b"] == copyA["b"] { - t.Errorf("entries for 'b' pointed to the same address: %v; expected them to point to different addresses", testA["b"]) - } - - // map copy - type Foo struct { - Alpha string - } - - type Bar struct { - Beta string - Gamma int - Delta *Foo - } - - type Biz struct { - Epsilon map[int]*Bar - } - - testB := Biz{ - Epsilon: map[int]*Bar{ - 0: &Bar{}, - 1: &Bar{ - Beta: "don't panic", - Gamma: 42, - Delta: nil, - }, - 2: &Bar{ - Beta: "sudo make me a sandwich.", - Gamma: 11, - Delta: &Foo{ - Alpha: "okay.", - }, - }, - }, - } - - copyB := Copy(testB).(Biz) - if !reflect.DeepEqual(testB, copyB) { - t.Errorf("got %#v; want %#v", copyB, testB) - return - } - - // check that the maps point to different locations - if unsafe.Pointer(&testB.Epsilon) == unsafe.Pointer(©B.Epsilon) { - t.Fatalf("expected the map pointers to be different; they weren't: testB: %v\tcopyB: %v", unsafe.Pointer(&testB.Epsilon), unsafe.Pointer(©B.Epsilon)) - } - - for k, v := range testB.Epsilon { - if v == nil && copyB.Epsilon[k] == nil { - continue - } - if v == nil && copyB.Epsilon[k] != nil { - t.Errorf("%d: expected copy of a nil entry to be nil; it wasn't: %#v", k, copyB.Epsilon[k]) - continue - } - if v == copyB.Epsilon[k] { - t.Errorf("entries for '%d' pointed to the same address: %v; expected them to point to different addresses", k, v) - continue - } - if v.Beta != copyB.Epsilon[k].Beta { - t.Errorf("%d.Beta: got %q; want %q", k, copyB.Epsilon[k].Beta, v.Beta) - } - if v.Gamma != copyB.Epsilon[k].Gamma { - t.Errorf("%d.Gamma: got %d; want %d", k, copyB.Epsilon[k].Gamma, v.Gamma) - } - if v.Delta == nil && copyB.Epsilon[k].Delta == nil { - continue - } - if v.Delta == nil && copyB.Epsilon[k].Delta != nil { - t.Errorf("%d.Delta: got %#v; want nil", k, copyB.Epsilon[k].Delta) - } - if v.Delta == copyB.Epsilon[k].Delta { - t.Errorf("%d.Delta: expected the pointers to be different, they were the same: %v", k, v.Delta) - continue - } - if v.Delta.Alpha != copyB.Epsilon[k].Delta.Alpha { - t.Errorf("%d.Delta.Foo: got %q; want %q", k, v.Delta.Alpha, copyB.Epsilon[k].Delta.Alpha) - } - } - - // test that map keys are deep copied - testC := map[*Foo][]string{ - &Foo{Alpha: "Henry Dorsett Case"}: []string{ - "Cutter", - }, - &Foo{Alpha: "Molly Millions"}: []string{ - "Rose Kolodny", - "Cat Mother", - "Steppin' Razor", - }, - } - - copyC := Copy(testC).(map[*Foo][]string) - if unsafe.Pointer(&testC) == unsafe.Pointer(©C) { - t.Fatalf("expected the map pointers to be different; they weren't: testB: %v\tcopyB: %v", unsafe.Pointer(&testB.Epsilon), unsafe.Pointer(©B.Epsilon)) - } - - // make sure the lengths are the same - if len(testC) != len(copyC) { - t.Fatalf("got len %d; want %d", len(copyC), len(testC)) - } - - // check that everything was deep copied: since the key is a pointer, we check to - // see if the pointers are different but the values being pointed to are the same. - for k, v := range testC { - for kk, vv := range copyC { - if *kk == *k { - if kk == k { - t.Errorf("key pointers should be different: orig: %p; copy: %p", k, kk) - } - // check that the slices are the same but different - if !reflect.DeepEqual(v, vv) { - t.Errorf("expected slice contents to be the same; they weren't: orig: %v; copy: %v", v, vv) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&v)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&vv)).Data { - t.Errorf("expected the SliceHeaders.Data to point to different locations; they didn't: %v", (*reflect.SliceHeader)(unsafe.Pointer(&v)).Data) - } - break - } - } - } - - type Bizz struct { - *Foo - } - - testD := map[Bizz]string{ - Bizz{&Foo{"Neuromancer"}}: "Rio", - Bizz{&Foo{"Wintermute"}}: "Berne", - } - copyD := Copy(testD).(map[Bizz]string) - if len(copyD) != len(testD) { - t.Fatalf("copy had %d elements; expected %d", len(copyD), len(testD)) - } - - for k, v := range testD { - var found bool - for kk, vv := range copyD { - if reflect.DeepEqual(k, kk) { - found = true - // check that Foo points to different locations - if unsafe.Pointer(k.Foo) == unsafe.Pointer(kk.Foo) { - t.Errorf("Expected Foo to point to different locations; they didn't: orig: %p; copy %p", k.Foo, kk.Foo) - break - } - if *k.Foo != *kk.Foo { - t.Errorf("Expected copy of the key's Foo field to have the same value as the original, it wasn't: orig: %#v; copy: %#v", k.Foo, kk.Foo) - } - if v != vv { - t.Errorf("Expected the values to be the same; the weren't: got %v; want %v", vv, v) - } - } - } - if !found { - t.Errorf("expected key %v to exist in the copy; it didn't", k) - } - } -} - -type I struct { - A string -} - -func (i *I) DeepCopy() interface{} { - return &I{A: "custom copy"} -} - -type NestI struct { - I *I -} - -func TestInterface(t *testing.T) { - i := &I{A: "A"} - copied := Copy(i).(*I) - if copied.A != "custom copy" { - t.Errorf("expected value %v, but it's %v", "custom copy", copied.A) - } - // check for nesting values - ni := &NestI{I: &I{A: "A"}} - copiedNest := Copy(ni).(*NestI) - if copiedNest.I.A != "custom copy" { - t.Errorf("expected value %v, but it's %v", "custom copy", copiedNest.I.A) - } -} diff --git a/hashing/hasher_test.go b/hashing/hasher_test.go index b496589..e9532d6 100644 --- a/hashing/hasher_test.go +++ b/hashing/hasher_test.go @@ -7,14 +7,14 @@ import ( "os" "strconv" "testing" + + "github.com/stretchr/testify/assert" + ) func TestMurmurHashOnAlphanumericData(t *testing.T) { inFile, err := os.Open("../testdata/murmur3-sample-data-v2.csv") - if err != nil { - t.Error("Missing test file...") - return - } + assert.Nil(t, err) defer inFile.Close() reader := csv.NewReader(bufio.NewReader(inFile)) @@ -32,19 +32,13 @@ func TestMurmurHashOnAlphanumericData(t *testing.T) { digest, _ := strconv.ParseUint(arr[2], 10, 32) calculated := NewMurmur332Hasher(uint32(seed)).Hash([]byte(str)) - if calculated != uint32(digest) { - t.Errorf("%d: Murmur hash calculation failed for string %s. Should be %d and was %d", line, str, digest, calculated) - break - } + assert.Equal(t, calculated, uint32(digest)) } } func TestMurmurHashOnNonAlphanumericData(t *testing.T) { inFile, err := os.Open("../testdata/murmur3-sample-data-non-alpha-numeric-v2.csv") - if err != nil { - t.Error("Missing test file...") - return - } + assert.Nil(t, err) defer inFile.Close() reader := csv.NewReader(bufio.NewReader(inFile)) @@ -62,9 +56,6 @@ func TestMurmurHashOnNonAlphanumericData(t *testing.T) { digest, _ := strconv.ParseUint(arr[2], 10, 32) calculated := NewMurmur332Hasher(uint32(seed)).Hash([]byte(str)) - if calculated != uint32(digest) { - t.Errorf("%d: Murmur hash calculation failed for string %s. Should be %d and was %d", line, str, digest, calculated) - break - } + assert.Equal(t, calculated, uint32(digest)) } } diff --git a/hashing/murmur128_test.go b/hashing/murmur128_test.go index 0aacb79..30fa1ae 100644 --- a/hashing/murmur128_test.go +++ b/hashing/murmur128_test.go @@ -1,17 +1,17 @@ package hashing import ( - "os" + "os" "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestMurmur128(t *testing.T) { raw, err := os.ReadFile("../testdata/murmur3_64_uuids.csv") - if err != nil { - t.Error("error reading murmur128 test cases files: ", err.Error()) - } + assert.Nil(t, err) lines := strings.Split(string(raw), "\n") for _, line := range lines { @@ -23,8 +23,6 @@ func TestMurmur128(t *testing.T) { expected, _ := strconv.ParseInt(fields[2], 10, 64) h1, _ := Sum128WithSeed([]byte(fields[0]), uint32(seed)) - if int64(h1) != expected { - t.Errorf("Hashes don't match. Expected: %d, actual: %d", expected, uint64(h1)) - } + assert.Equal(t, expected, int64(h1)) } } diff --git a/hashing/util_test.go b/hashing/util_test.go index 19f5810..691adbb 100644 --- a/hashing/util_test.go +++ b/hashing/util_test.go @@ -1,22 +1,18 @@ package hashing -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestEncode(t *testing.T) { hash, err := Encode(nil, "something") - if hash != "" { - t.Error("Unexpected result") - } - if err == nil || err.Error() != "Hasher could not be nil" { - t.Error("Unexpected error message") - } - + assert.ErrorContains(t, err, "Hasher could not be nil") + assert.Equal(t, "", hash) + hasher := NewMurmur332Hasher(0) hash2, err := Encode(hasher, "something") - if err != nil { - t.Error("It should not return error") - } - if hash2 != "NDE0MTg0MjI2MQ==" { - t.Error("Unexpected result") - } + assert.Nil(t, err) + assert.Equal(t, "NDE0MTg0MjI2MQ==", hash2) } diff --git a/redis/helpers/helpers_test.go b/redis/helpers/helpers_test.go index d968aa0..f838af7 100644 --- a/redis/helpers/helpers_test.go +++ b/redis/helpers/helpers_test.go @@ -4,58 +4,38 @@ import ( "errors" "testing" - "github.com/splitio/go-toolkit/v6/redis" "github.com/splitio/go-toolkit/v6/redis/mocks" + "github.com/stretchr/testify/assert" ) func TestEnsureConnected(t *testing.T) { - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return nil }, - StringCall: func() string { return "PONG" }, - } - }, - } - EnsureConnected(&redisClient) + var resMock mocks.MockResultOutput + resMock.On("String").Return(pong).Once() + resMock.On("Err").Return(nil).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + EnsureConnected(&clientMock) } func TestEnsureConnectedError(t *testing.T) { - defer func() { - if r := recover(); r != nil { - if r != "Couldn't connect to redis: someError" { - t.Error("Expected \"Couldn't connect to redis: someError\". Got: ", r) - } - } - }() - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return errors.New("someError") }, - StringCall: func() string { return "" }, - } - }, - } - EnsureConnected(&redisClient) - t.Error("Should not reach this line") + var resMock mocks.MockResultOutput + resMock.On("String").Return("").Once() + resMock.On("Err").Return(errors.New("someError")).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + + assert.Panics(t, func() { EnsureConnected(&clientMock) }) } func TestEnsureConnectedNotPong(t *testing.T) { - defer func() { - if r := recover(); r != nil { - if r != "Invalid redis ping response when connecting: PANG" { - t.Error("Invalid redis ping response when connecting: PANG", r) - } - } - }() - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return nil }, - StringCall: func() string { return "PANG" }, - } - }, - } - EnsureConnected(&redisClient) - t.Error("Should not reach this line") + var resMock mocks.MockResultOutput + resMock.On("String").Return("PANG").Once() + resMock.On("Err").Return(nil).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + + assert.Panics(t, func() { EnsureConnected(&clientMock) }) } diff --git a/redis/mocks/mocks.go b/redis/mocks/mocks.go index fb84349..e24aa1f 100644 --- a/redis/mocks/mocks.go +++ b/redis/mocks/mocks.go @@ -3,324 +3,299 @@ package mocks import ( "time" + "github.com/splitio/go-toolkit/v6/common" "github.com/splitio/go-toolkit/v6/redis" + "github.com/stretchr/testify/mock" ) -// MockResultOutput mocks struct -type MockResultOutput struct { - ErrCall func() error - IntCall func() int64 - StringCall func() string - BoolCall func() bool - DurationCall func() time.Duration - ResultCall func() (int64, error) - ResultStringCall func() (string, error) - MultiCall func() ([]string, error) - MultiInterfaceCall func() ([]interface{}, error) - MapStringStringCall func() (map[string]string, error) -} - -// Int mocks Int -func (m *MockResultOutput) Int() int64 { - return m.IntCall() +type MockClient struct { + mock.Mock } -// Err mocks Err -func (m *MockResultOutput) Err() error { - return m.ErrCall() +// ClusterCountKeysInSlot implements redis.Client. +func (m *MockClient) ClusterCountKeysInSlot(slot int) redis.Result { + return m.Called(slot).Get(0).(redis.Result) } -// String mocks String -func (m *MockResultOutput) String() string { - return m.StringCall() +// ClusterKeysInSlot implements redis.Client. +func (m *MockClient) ClusterKeysInSlot(slot int, count int) redis.Result { + return m.Called(slot, count).Get(0).(redis.Result) } -// Bool mocks Bool -func (m *MockResultOutput) Bool() bool { - return m.BoolCall() +// ClusterMode implements redis.Client. +func (m *MockClient) ClusterMode() bool { + return m.Called().Bool(0) } -// Duration mocks Duration -func (m *MockResultOutput) Duration() time.Duration { - return m.DurationCall() +// ClusterSlotForKey implements redis.Client. +func (m *MockClient) ClusterSlotForKey(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Result mocks Result -func (m *MockResultOutput) Result() (int64, error) { - return m.ResultCall() +// Decr implements redis.Client. +func (m *MockClient) Decr(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// ResultString mocks ResultString -func (m *MockResultOutput) ResultString() (string, error) { - return m.ResultStringCall() +// Del implements redis.Client. +func (m *MockClient) Del(keys ...string) redis.Result { + return m.Called(common.AsInterfaceSlice(keys)...).Get(0).(redis.Result) } -// Multi mocks Multi -func (m *MockResultOutput) Multi() ([]string, error) { - return m.MultiCall() +// Eval implements redis.Client. +func (m *MockClient) Eval(script string, keys []string, args ...interface{}) redis.Result { + return m.Called(append([]interface{}{script, keys}, args...)...).Get(0).(redis.Result) } -// MultiInterface mocks MultiInterface -func (m *MockResultOutput) MultiInterface() ([]interface{}, error) { - return m.MultiInterfaceCall() +// Exists implements redis.Client. +func (m *MockClient) Exists(keys ...string) redis.Result { + return m.Called(common.AsInterfaceSlice(keys)...).Get(0).(redis.Result) } -// MapStringString mocks MapStringString -func (m *MockResultOutput) MapStringString() (map[string]string, error) { - return m.MapStringStringCall() +// Expire implements redis.Client. +func (m *MockClient) Expire(key string, value time.Duration) redis.Result { + return m.Called(key, value).Get(0).(redis.Result) } -// MpockPipeline impl -type MockPipeline struct { - LRangeCall func(key string, start, stop int64) - LTrimCall func(key string, start, stop int64) - LLenCall func(key string) - HIncrByCall func(key string, field string, value int64) - HLenCall func(key string) - SetCall func(key string, value interface{}, expiration time.Duration) - IncrCall func(key string) - DecrCall func(key string) - SAddCall func(key string, members ...interface{}) - SRemCall func(key string, members ...interface{}) - SMembersCall func(key string) - DelCall func(keys ...string) - ExecCall func() ([]redis.Result, error) +// Get implements redis.Client. +func (m *MockClient) Get(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) LRange(key string, start, stop int64) { - m.LRangeCall(key, start, stop) +// HGetAll implements redis.Client. +func (m *MockClient) HGetAll(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) LTrim(key string, start, stop int64) { - m.LTrimCall(key, start, stop) +// HIncrBy implements redis.Client. +func (m *MockClient) HIncrBy(key string, field string, value int64) redis.Result { + return m.Called(key, field, value).Get(0).(redis.Result) } -func (m *MockPipeline) LLen(key string) { - m.LLenCall(key) +// HSet implements redis.Client. +func (m *MockClient) HSet(key string, hashKey string, value interface{}) redis.Result { + return m.Called(key, hashKey, value).Get(0).(redis.Result) + } -func (m *MockPipeline) HIncrBy(key string, field string, value int64) { - m.HIncrByCall(key, field, value) +// Incr implements redis.Client. +func (m *MockClient) Incr(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) HLen(key string) { - m.HLenCall(key) +// Keys implements redis.Client. +func (m *MockClient) Keys(pattern string) redis.Result { + return m.Called(pattern).Get(0).(redis.Result) + } -func (m *MockPipeline) Set(key string, value interface{}, expiration time.Duration) { - m.SetCall(key, value, expiration) +// LLen implements redis.Client. +func (m *MockClient) LLen(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) Incr(key string) { - m.IncrCall(key) +// LRange implements redis.Client. +func (m *MockClient) LRange(key string, start int64, stop int64) redis.Result { + return m.Called(key, start, stop).Get(0).(redis.Result) } -func (m *MockPipeline) Decr(key string) { - m.DecrCall(key) +// LTrim implements redis.Client. +func (m *MockClient) LTrim(key string, start int64, stop int64) redis.Result { + return m.Called(key, start, stop).Get(0).(redis.Result) + } -func (m *MockPipeline) SAdd(key string, members ...interface{}) { - m.SAddCall(key, members...) +// MGet implements redis.Client. +func (m *MockClient) MGet(keys []string) redis.Result { + return m.Called(keys).Get(0).(redis.Result) } -func (m *MockPipeline) SRem(key string, members ...interface{}) { - m.SRemCall(key, members...) +// Ping implements redis.Client. +func (m *MockClient) Ping() redis.Result { + return m.Called().Get(0).(redis.Result) } -func (m *MockPipeline) SMembers(key string) { - m.SMembersCall(key) +// Pipeline implements redis.Client. +func (m *MockClient) Pipeline() redis.Pipeline { + return m.Called().Get(0).(redis.Pipeline) } -func (m *MockPipeline) Del(keys ...string) { - m.DelCall(keys...) +// RPush implements redis.Client. +func (m *MockClient) RPush(key string, values ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, values...)...).Get(0).(redis.Result) + } -func (m *MockPipeline) Exec() ([]redis.Result, error) { - return m.ExecCall() +// SAdd implements redis.Client. +func (m *MockClient) SAdd(key string, members ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, members...)...).Get(0).(redis.Result) } -// MockClient mocks for testing purposes -type MockClient struct { - ClusterModeCall func() bool - ClusterCountKeysInSlotCall func(slot int) redis.Result - ClusterSlotForKeyCall func(key string) redis.Result - ClusterKeysInSlotCall func(slot int, count int) redis.Result - DelCall func(keys ...string) redis.Result - GetCall func(key string) redis.Result - SetCall func(key string, value interface{}, expiration time.Duration) redis.Result - PingCall func() redis.Result - ExistsCall func(keys ...string) redis.Result - KeysCall func(pattern string) redis.Result - SMembersCall func(key string) redis.Result - SIsMemberCall func(key string, member interface{}) redis.Result - SAddCall func(key string, members ...interface{}) redis.Result - SRemCall func(key string, members ...interface{}) redis.Result - IncrCall func(key string) redis.Result - DecrCall func(key string) redis.Result - RPushCall func(key string, values ...interface{}) redis.Result - LRangeCall func(key string, start, stop int64) redis.Result - LTrimCall func(key string, start, stop int64) redis.Result - LLenCall func(key string) redis.Result - ExpireCall func(key string, value time.Duration) redis.Result - TTLCall func(key string) redis.Result - MGetCall func(keys []string) redis.Result - SCardCall func(key string) redis.Result - EvalCall func(script string, keys []string, args ...interface{}) redis.Result - HIncrByCall func(key string, field string, value int64) redis.Result - HGetAllCall func(key string) redis.Result - HSetCall func(key string, hashKey string, value interface{}) redis.Result - TypeCall func(key string) redis.Result - PipelineCall func() redis.Pipeline - ScanCall func(cursor uint64, match string, count int64) redis.Result +// SCard implements redis.Client. +func (m *MockClient) SCard(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockClient) ClusterMode() bool { - return m.ClusterModeCall() +// SIsMember implements redis.Client. +func (m *MockClient) SIsMember(key string, member interface{}) redis.Result { + return m.Called(key, member).Get(0).(redis.Result) + } -func (m *MockClient) ClusterCountKeysInSlot(slot int) redis.Result { - return m.ClusterCountKeysInSlotCall(slot) +// SMembers implements redis.Client. +func (m *MockClient) SMembers(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockClient) ClusterSlotForKey(key string) redis.Result { - return m.ClusterSlotForKeyCall(key) +// SRem implements redis.Client. +func (m *MockClient) SRem(key string, members ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, members...)...).Get(0).(redis.Result) } -func (m *MockClient) ClusterKeysInSlot(slot int, count int) redis.Result { - return m.ClusterKeysInSlotCall(slot, count) +// Scan implements redis.Client. +func (m *MockClient) Scan(cursor uint64, match string, count int64) redis.Result { + return m.Called(cursor, match, count).Get(0).(redis.Result) } -// Del mocks get -func (m *MockClient) Del(keys ...string) redis.Result { - return m.DelCall(keys...) +// Set implements redis.Client. +func (m *MockClient) Set(key string, value interface{}, expiration time.Duration) redis.Result { + return m.Called(key, value, expiration).Get(0).(redis.Result) } -// Get mocks get -func (m *MockClient) Get(key string) redis.Result { - return m.GetCall(key) +// TTL implements redis.Client. +func (m *MockClient) TTL(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Set mocks set -func (m *MockClient) Set(key string, value interface{}, expiration time.Duration) redis.Result { - return m.SetCall(key, value, expiration) +// Type implements redis.Client. +func (m *MockClient) Type(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Exists mocks set -func (m *MockClient) Exists(keys ...string) redis.Result { - return m.ExistsCall(keys...) +type MockPipeline struct { + mock.Mock } -// Ping mocks ping -func (m *MockClient) Ping() redis.Result { - return m.PingCall() +// Decr implements redis.Pipeline. +func (m *MockPipeline) Decr(key string) { + m.Called(key) } -// Keys mocks keys -func (m *MockClient) Keys(pattern string) redis.Result { - return m.KeysCall(pattern) +// Del implements redis.Pipeline. +func (m *MockPipeline) Del(keys ...string) { + m.Called(common.AsInterfaceSlice(keys)) } -// SMembers mocks SMembers -func (m *MockClient) SMembers(key string) redis.Result { - return m.SMembersCall(key) +// Exec implements redis.Pipeline. +func (m *MockPipeline) Exec() ([]redis.Result, error) { + args := m.Called() + return args.Get(0).([]redis.Result), args.Error(1) } -// SIsMember mocks SIsMember -func (m *MockClient) SIsMember(key string, member interface{}) redis.Result { - return m.SIsMemberCall(key, member) +// HIncrBy implements redis.Pipeline. +func (m *MockPipeline) HIncrBy(key string, field string, value int64) { + m.Called(key, field, value) } -// SAdd mocks SAdd -func (m *MockClient) SAdd(key string, members ...interface{}) redis.Result { - return m.SAddCall(key, members...) +// HLen implements redis.Pipeline. +func (m *MockPipeline) HLen(key string) { + m.Called(key) } -// SRem mocks SRem -func (m *MockClient) SRem(key string, members ...interface{}) redis.Result { - return m.SRemCall(key, members...) +// Incr implements redis.Pipeline. +func (m *MockPipeline) Incr(key string) { + m.Called(key) } -// Incr mocks Incr -func (m *MockClient) Incr(key string) redis.Result { - return m.IncrCall(key) +// LLen implements redis.Pipeline. +func (m *MockPipeline) LLen(key string) { + m.Called(key) } -// Decr mocks Decr -func (m *MockClient) Decr(key string) redis.Result { - return m.DecrCall(key) +// LRange implements redis.Pipeline. +func (m *MockPipeline) LRange(key string, start int64, stop int64) { + m.Called(key, start, stop) } -// RPush mocks RPush -func (m *MockClient) RPush(key string, values ...interface{}) redis.Result { - return m.RPushCall(key, values...) +// LTrim implements redis.Pipeline. +func (m *MockPipeline) LTrim(key string, start int64, stop int64) { + m.Called(key, start, stop) +} + +// SAdd implements redis.Pipeline. +func (m *MockPipeline) SAdd(key string, members ...interface{}) { + m.Called(append([]interface{}{key}, members...)...) } -// LRange mocks LRange -func (m *MockClient) LRange(key string, start, stop int64) redis.Result { - return m.LRangeCall(key, start, stop) +// SMembers implements redis.Pipeline. +func (m *MockPipeline) SMembers(key string) { + m.Called(key) } -// LTrim mocks LTrim -func (m *MockClient) LTrim(key string, start, stop int64) redis.Result { - return m.LTrimCall(key, start, stop) +// SRem implements redis.Pipeline. +func (m *MockPipeline) SRem(key string, members ...interface{}) { + m.Called(append([]interface{}{key}, members...)...) } -// LLen mocks LLen -func (m *MockClient) LLen(key string) redis.Result { - return m.LLenCall(key) +// Set implements redis.Pipeline. +func (m *MockPipeline) Set(key string, value interface{}, expiration time.Duration) { + m.Called(key, value) } -// Expire mocks Expire -func (m *MockClient) Expire(key string, value time.Duration) redis.Result { - return m.ExpireCall(key, value) +type MockResultOutput struct { + mock.Mock } -// TTL mocks TTL -func (m *MockClient) TTL(key string) redis.Result { - return m.TTLCall(key) +// Bool implements redis.Result. +func (m *MockResultOutput) Bool() bool { + return m.Called().Bool(0) } -// MGet mocks MGet -func (m *MockClient) MGet(keys []string) redis.Result { - return m.MGetCall(keys) +// Duration implements redis.Result. +func (m *MockResultOutput) Duration() time.Duration { + return m.Called().Get(0).(time.Duration) } -// SCard mocks SCard -func (m *MockClient) SCard(key string) redis.Result { - return m.SCardCall(key) +// Err implements redis.Result. +func (m *MockResultOutput) Err() error { + return m.Called().Error(0) } -// Eval mocks Eval -func (m *MockClient) Eval(script string, keys []string, args ...interface{}) redis.Result { - return m.EvalCall(script, keys, args...) +// Int implements redis.Result. +func (m *MockResultOutput) Int() int64 { + return m.Called().Get(0).(int64) } -// HIncrBy mocks HIncrByCall -func (m *MockClient) HIncrBy(key string, field string, value int64) redis.Result { - return m.HIncrByCall(key, field, value) +func (m *MockResultOutput) String() string { + return m.Called().Get(0).(string) } -// HGetAll mocks HGetAll -func (m *MockClient) HGetAll(key string) redis.Result { - return m.HGetAllCall(key) +// MapStringString implements redis.Result. +func (m *MockResultOutput) MapStringString() (map[string]string, error) { + args := m.Called() + return args.Get(0).(map[string]string), args.Error(1) } -// HSet implements HGetAll wrapper for redis -func (m *MockClient) HSet(key string, hashKey string, value interface{}) redis.Result { - return m.HSetCall(key, hashKey, value) +// Multi implements redis.Result. +func (m *MockResultOutput) Multi() ([]string, error) { + args := m.Called() + return args.Get(0).([]string), args.Error(1) } -// Type implements Type wrapper for redis with prefix -func (m *MockClient) Type(key string) redis.Result { - return m.TypeCall(key) +// MultiInterface implements redis.Result. +func (m *MockResultOutput) MultiInterface() ([]interface{}, error) { + args := m.Called() + return args.Get(0).([]interface{}), args.Error(1) } -// Pipeline mock -func (m *MockClient) Pipeline() redis.Pipeline { - return m.PipelineCall() +// Result implements redis.Result. +func (m *MockResultOutput) Result() (int64, error) { + args := m.Called() + return args.Get(0).(int64), args.Error(1) } -// Scan mock -func (m *MockClient) Scan(cursor uint64, match string, count int64) redis.Result { - return m.ScanCall(cursor, match, count) +// ResultString implements redis.Result. +func (m *MockResultOutput) ResultString() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) } diff --git a/redis/wrapper_test.go b/redis/wrapper_test.go index e5fb32b..cd7b865 100644 --- a/redis/wrapper_test.go +++ b/redis/wrapper_test.go @@ -18,29 +18,20 @@ func TestRedisWrapperKeysAndScan(t *testing.T) { } keys, err := client.Keys("utest*").Multi() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - - if len(keys) != 10 { - t.Error("should be 10 keys. Got: ", len(keys)) - } - + assert.Nil(t, err) + assert.Equal(t, 10, len(keys)) var cursor uint64 + scanKeys := make([]string, 0) for { result := client.Scan(cursor, "utest*", 10) - if result.Err() != nil { - t.Error("there should not be any error. Got: ", result.Err()) - } + assert.Nil(t, result.Err()) cursor = uint64(result.Int()) keys, err := result.Multi() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - + assert.Nil(t, err) + scanKeys = append(scanKeys, keys...) if cursor == 0 { @@ -48,10 +39,7 @@ func TestRedisWrapperKeysAndScan(t *testing.T) { } } - if len(scanKeys) != 10 { - t.Error("should be 10 keys. Got: ", len(scanKeys)) - } - + assert.Equal(t, 10, len(scanKeys)) for i := 0; i < 10; i++ { client.Del(fmt.Sprintf("utest.key-del%d", i)) } @@ -84,64 +72,24 @@ func TestRedisWrapperPipeline(t *testing.T) { pipe.Decr("key-incr") pipe.Del([]string{"key-del1", "key-del2"}...) result, err := pipe.Exec() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - - if len(result) != 14 { - t.Error("there should be 13 elements") - } + assert.Nil(t, err) + assert.Equal(t, 14, len(result)) items, _ := result[0].Multi() assert.Equal(t, []string{"e1", "e2", "e3"}, items) - if l := result[1].Int(); l != 3 { - t.Error("length should be 3. is: ", l) - } - - if i := client.LLen("key1").Int(); i != 1 { - t.Error("new length should be 1. Is: ", i) - } - - if c := result[3].Int(); c != 5 { - t.Error("count should be 5. Is: ", c) - } - - if c := result[4].Int(); c != 4 { - t.Error("count should be 5. Is: ", c) - } - - if c := result[5].Int(); c != 7 { - t.Error("count should be 5. Is: ", c) - } - - if l := result[6].Int(); l != 2 { - t.Error("hlen should be 2. is: ", l) - } - - if ib := client.HIncrBy("key-test", "field-test", 1); ib.Int() != 6 { - t.Error("new count should be 6") - } - - if ib := client.Get("key-set"); ib.String() != "field-test-1" { - t.Error("it should be field-test-1") - } - - if c := result[8].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } - if d, _ := result[9].Multi(); len(d) != 2 { - t.Error("count should be 2. Is: ", len(d)) - } - if c := result[10].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } - if c := result[11].Int(); c != 1 { - t.Error("count should be 1. Is: ", c) - } - if c := result[12].Int(); c != 0 { - t.Error("count should be zero. Is: ", c) - } - if c := result[13].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } + assert.Equal(t, int64(3), result[1].Int()) + assert.Equal(t, int64(1), client.LLen("key1").Int()) + assert.Equal(t, int64(5), result[3].Int()) + assert.Equal(t, int64(4), result[4].Int()) + assert.Equal(t, int64(7), result[5].Int()) + assert.Equal(t, int64(2), result[6].Int()) + assert.Equal(t, int64(6), client.HIncrBy("key-test", "field-test", 1).Int()) + assert.Equal(t, "field-test-1", client.Get("key-set").String()) + assert.Equal(t, int64(2), result[8].Int()) + d, _ := result[9].Multi() + assert.Equal(t, 2, len(d)) + assert.Equal(t, int64(2), result[10].Int()) + assert.Equal(t, int64(1), result[11].Int()) + assert.Equal(t, int64(0), result[12].Int()) + assert.Equal(t, int64(2), result[13].Int()) } diff --git a/sse/event_test.go b/sse/event_test.go index a64e257..d48806a 100644 --- a/sse/event_test.go +++ b/sse/event_test.go @@ -2,6 +2,9 @@ package sse import ( "testing" + + "github.com/stretchr/testify/assert" + ) func TestEventBuilder(t *testing.T) { @@ -13,30 +16,16 @@ func TestEventBuilder(t *testing.T) { builder.AddLine(":some Comment") e := builder.Build() - if e.Event() != "message" { - t.Error("event should be 'message'") - } - if e.Data() != "something" { - t.Error("data should be 'something'") - } - if e.ID() != "1234" { - t.Error("Id should be 1234") - } - if e.Retry() != 1 { - t.Error("retry should be 1234") - } - if e.IsEmpty() { - t.Error("event should not be empty") - } - if e.IsError() { - t.Error("event is not an error") - } + assert.Equal(t, "message", e.Event()) + assert.Equal(t, "something", e.Data()) + assert.Equal(t, "1234", e.ID()) + assert.Equal(t, int64(1), e.Retry()) + assert.False(t, e.IsEmpty()) + assert.False(t, e.IsEmpty()) builder.Reset() builder.AddLine("event: error") builder.AddLine("data: someError") e2 := builder.Build() - if !e2.IsError() { - t.Error("event is an error") - } + assert.True(t, e2.IsError()) } diff --git a/sse/mocks/mocks.go b/sse/mocks/mocks.go index b90ec8c..235f108 100644 --- a/sse/mocks/mocks.go +++ b/sse/mocks/mocks.go @@ -1,34 +1,31 @@ package mocks +import "github.com/stretchr/testify/mock" + type RawEventMock struct { - IDCall func() string - EventCall func() string - DataCall func() string - RetryCall func() int64 - IsErrorCall func() bool - IsEmptyCall func() bool + mock.Mock } func (r *RawEventMock) ID() string { - return r.IDCall() + return r.Called().String(0) } func (r *RawEventMock) Event() string { - return r.EventCall() + return r.Called().String(0) } func (r *RawEventMock) Data() string { - return r.DataCall() + return r.Called().String(0) } func (r *RawEventMock) Retry() int64 { - return r.RetryCall() + return r.Called().Get(0).(int64) } func (r *RawEventMock) IsError() bool { - return r.IsErrorCall() + return r.Called().Bool(0) } func (r *RawEventMock) IsEmpty() bool { - return r.IsEmptyCall() + return r.Called().Bool(0) } diff --git a/sse/sse_test.go b/sse/sse_test.go index 59e256b..a6a2ee2 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -1,7 +1,6 @@ package sse import ( - "errors" "fmt" "net/http" "net/http/httptest" @@ -10,16 +9,15 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) func TestSSEErrorConnecting(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) client, _ := NewClient("", 120, 10, logger) err := client.Do(make(map[string]string), make(map[string]string), func(e RawEvent) { t.Error("It should not execute anything") }) - asErrConecting := &ErrConnectionFailed{} - if !errors.As(err, &asErrConecting) { - t.Errorf("Unexpected type of error: %+v", err) - } + _, ok := err.(*ErrConnectionFailed) + assert.True(t, ok) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -34,25 +32,19 @@ func TestSSEErrorConnecting(t *testing.T) { mockedClient.lifecycle.Setup() err = mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) { - t.Error("Should not execute callback") + assert.Fail(t, "Should not execute callback") }) - if !errors.As(err, &asErrConecting) { - t.Errorf("Unexpected type of error: %+v", err) - } + _, ok = err.(*ErrConnectionFailed) + assert.True(t, ok) } func TestSSE(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("some") != "some" { - t.Error("It should send header") - } - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + assert.Equal(t, "some", r.Header.Get("some")) + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -79,18 +71,14 @@ func TestSSE(t *testing.T) { result = e mutextTest.Unlock() }) - if err != nil { - t.Error("sse client ended in error:", err) - } + assert.Nil(t, err) }() time.Sleep(2 * time.Second) mockedClient.Shutdown(true) mutextTest.RLock() - if result.Data() != `{"id":"YCh53QfLxO:0:0","data":"some","timestamp":1591911770828}` { - t.Error("Unexpected result: ", result.Data()) - } + assert.Equal(t, `{"id":"YCh53QfLxO:0:0","data":"some","timestamp":1591911770828}`, result.Data()) mutextTest.RUnlock() } @@ -103,11 +91,8 @@ func TestSSENoTimeout(t *testing.T) { finished := false mutexTest.Unlock() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -129,15 +114,11 @@ func TestSSENoTimeout(t *testing.T) { time.Sleep(1500 * time.Millisecond) mutexTest.RLock() - if finished { - t.Error("It should not be finished") - } + assert.False(t, finished) mutexTest.RUnlock() time.Sleep(1500 * time.Millisecond) mutexTest.RLock() - if !finished { - t.Error("It should be finished") - } + assert.True(t, finished) mutexTest.RUnlock() clientSSE.Shutdown(true) } @@ -146,11 +127,8 @@ func TestStopBlock(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -172,9 +150,7 @@ func TestStopBlock(t *testing.T) { waiter := make(chan struct{}, 1) go func() { err := mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) {}) - if err != nil { - t.Error("sse client ended in error: ", err) - } + assert.Nil(t, err) waiter <- struct{}{} }() @@ -187,11 +163,8 @@ func TestConnectionEOF(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -211,60 +184,6 @@ func TestConnectionEOF(t *testing.T) { mockedClient.lifecycle.Setup() err := mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) {}) - if err != ErrReadingStream { - t.Error("Should have triggered an ErrorReadingStreamError. Got: ", err) - } - + assert.ErrorIs(t, err, ErrReadingStream) mockedClient.Shutdown(true) } - -/* -func TestCustom(t *testing.T) { - url := `https://streaming.split.io/event-stream` - logger := logging.NewLogger(&logging.LoggerOptions{LogLevel: logging.LevelError, StandardLoggerFlags: log.Llongfile}) - client, _ := NewClient(url, 50, logger) - - ready := make(chan struct{}) - accessToken := `` - channels := "NzM2MDI5Mzc0_MTgyNTg1MTgwNg==_splits,[?occupancy=metrics.publishers]control_pri,[?occupancy=metrics.publishers]control_sec" - go func() { - err := client.Do( - map[string]string{ - "accessToken": accessToken, - "v": "1.1", - "channel": channels, - }, - func(e RawEvent) { - fmt.Printf("Event: %+v\n", e) - }) - if err != nil { - t.Error("sse error:", err) - } - ready <- struct{}{} - }() - time.Sleep(5 * time.Second) - client.Shutdown(true) - <-ready - fmt.Println(1) - go func() { - err := client.Do( - map[string]string{ - "accessToken": accessToken, - "v": "1.1", - "channel": channels, - }, - func(e RawEvent) { - fmt.Printf("Event: %+v\n", e) - }) - if err != nil { - t.Error("sse error:", err) - } - ready <- struct{}{} - }() - time.Sleep(5 * time.Second) - client.Shutdown(true) - <-ready - fmt.Println(2) - -} -*/ diff --git a/struct/jsonvalidator/validator_test.go b/struct/jsonvalidator/validator_test.go index c79eb74..3dea230 100644 --- a/struct/jsonvalidator/validator_test.go +++ b/struct/jsonvalidator/validator_test.go @@ -2,6 +2,8 @@ package jsonvalidator import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestLen0(t *testing.T) { @@ -13,12 +15,8 @@ func TestLen0(t *testing.T) { originChild := OriginChild{two: "Test", three: 1} err := ValidateConfiguration(originChild, nil) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "no configuration provided" { - t.Error("Wrong message") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "no configuration provided") } func TestSame(t *testing.T) { @@ -36,9 +34,7 @@ func TestSame(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10}}")) - if err != nil { - t.Error("Should not inform error") - } + assert.Nil(t, err) } func TestDifferentPropertyParent(t *testing.T) { @@ -56,12 +52,8 @@ func TestDifferentPropertyParent(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"four\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"four\" is not a valid property in configuration" { - t.Error("Wrong message") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"four\" is not a valid property in configuration") } func TestDifferentPropertyChild(t *testing.T) { @@ -79,12 +71,8 @@ func TestDifferentPropertyChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentParentAndChild(t *testing.T) { @@ -102,12 +90,8 @@ func TestDifferentParentAndChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"testChild\": {\"two\": \"test\", \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"testChild\" is not a valid property in configuration" { - t.Error("Wrong message, it should inform parent") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"testChild\" is not a valid property in configuration") } func TestDifferentPropertyInChild(t *testing.T) { @@ -125,12 +109,8 @@ func TestDifferentPropertyInChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentPropertyInChildBool(t *testing.T) { @@ -148,12 +128,8 @@ func TestDifferentPropertyInChildBool(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": true}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentPropertyInChildNumber(t *testing.T) { @@ -171,12 +147,8 @@ func TestDifferentPropertyInChildNumber(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestSameThirdLevel(t *testing.T) { @@ -200,11 +172,7 @@ func TestSameThirdLevel(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"child\": {\"two\": \"test\", \"three\": 10}, \"three\": 10}}")) - if err != nil { - t.Error(err.Error()) - - t.Error("Should not inform error") - } + assert.Nil(t, err) } func TestDifferenthirdLevel(t *testing.T) { @@ -228,10 +196,6 @@ func TestDifferenthirdLevel(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"child\": {\"t\": \"test\", \"three\": 10}, \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.child.t\" is not a valid property in configuration" { - t.Error("Wrong message", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.child.t\" is not a valid property in configuration") } diff --git a/struct/traits/lifecycle/lifecycle_test.go b/struct/traits/lifecycle/lifecycle_test.go index e0ee14f..52e7c6e 100644 --- a/struct/traits/lifecycle/lifecycle_test.go +++ b/struct/traits/lifecycle/lifecycle_test.go @@ -4,31 +4,19 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestLifecycleManager(t *testing.T) { m := Manager{} m.Setup() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - if !m.InitializationComplete() { - t.Error("should complete initialization correctly") - } - - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done := make(chan struct{}, 1) go func() { @@ -43,36 +31,22 @@ func TestLifecycleManager(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } - m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } - <-done // ensure that await actually waits + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) - // Start again + m.AwaitShutdownComplete() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } + assert.False(t, m.IsRunning()) - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } + <-done // ensure that await actually waits - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } + // Start again - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done = make(chan struct{}, 1) go func() { @@ -87,16 +61,13 @@ func TestLifecycleManager(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) + m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + + assert.False(t, m.IsRunning()) + <-done // ensure that await actually waits } @@ -104,22 +75,11 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { m := Manager{} m.Setup() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done := make(chan struct{}, 1) go func() { @@ -134,30 +94,18 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { } }() + m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) <-done // ensure that await actually waits // Start again - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done = make(chan struct{}, 1) go func() { @@ -172,17 +120,12 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) + <-done // ensure that await actually waits } @@ -190,15 +133,9 @@ func TestShutdownRequestWhileInitNotComplete(t *testing.T) { m := Manager{} m.Setup() - m.BeginInitialization() - if !m.BeginShutdown() { - t.Error("should accept the shutdown request") - } - - if m.InitializationComplete() { - t.Error("initialization cannot complete.") - } - + assert.True(t, m.BeginInitialization()) + assert.True(t, m.BeginShutdown()) + assert.False(t, m.InitializationComplete()) m.ShutdownComplete() // Now restart the lifecycle to see if it works properly @@ -221,14 +158,11 @@ func TestShutdownRequestWhileInitNotComplete(t *testing.T) { } } }() - m.BeginShutdown() + + assert.True(t, m.BeginShutdown()) m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) <-done // ensure that await actually waits - if atomic.LoadInt32(&executed) != 0 { - t.Error("the goroutine should have not executed further than the InitializationComplete check.") - } + assert.Equal(t, int32(0), atomic.LoadInt32(&executed)) } diff --git a/sync/atomicbool_test.go b/sync/atomicbool_test.go index 445ee8e..64d53c8 100644 --- a/sync/atomicbool_test.go +++ b/sync/atomicbool_test.go @@ -2,40 +2,20 @@ package sync import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestAtomicBool(t *testing.T) { a := NewAtomicBool(false) - if a.IsSet() { - t.Error("initial value should be false") - } - - if !a.TestAndSet() { - t.Error("compare and swap should succeed with no other concurrent access.") - } - - if a.TestAndSet() { - t.Error("compare and swap should return false if it didn't change anything.") - } - - if !a.IsSet() { - t.Error("should now be true") - } + assert.False(t, a.IsSet()) + assert.True(t, a.TestAndSet()) + assert.False(t, a.TestAndSet()) + assert.True(t, a.IsSet()) b := NewAtomicBool(true) - if !b.IsSet() { - t.Error("initial value should be true") - } - - if b.TestAndClear() != true { - t.Error("compare and swap should succeed with no other concurrent access.") - } - - if !a.TestAndClear() { - t.Error("compare and swap should return false if it didn't change anything.") - } - - if b.IsSet() { - t.Error("should now be false") - } + assert.True(t, b.IsSet()) + assert.True(t, b.TestAndClear()) + assert.False(t, b.TestAndClear()) + assert.False(t, b.IsSet()) } diff --git a/workerpool/workerpool_test.go b/workerpool/workerpool_test.go index 9606b7f..5ea67ad 100644 --- a/workerpool/workerpool_test.go +++ b/workerpool/workerpool_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) var resMutex sync.RWMutex @@ -55,23 +56,17 @@ func TestWorkerAdminConstructionAndNormalOperation(t *testing.T) { } resMutex.RLock() - if results["worker_2"] > 10 { - t.Error("Worker should have stopped working!") - } + assert.Less(t, results["worker_2"], 10) resMutex.RUnlock() + time.Sleep(time.Second * 1) errs := wa.StopAll(false) - if errs != nil { - t.Error("Not all workers stopped properly") - t.Error(errs) - } + assert.Nil(t, errs) time.Sleep(time.Second * 1) for _, i := range []int{1, 2, 3} { wName := fmt.Sprintf("worker_%d", i) - if wa.IsWorkerRunning(wName) { - t.Errorf("Worker %s should be stopped", wName) - } + assert.False(t, wa.IsWorkerRunning(wName)) } } @@ -131,21 +126,15 @@ func TestWaitingForWorkersToFinish(t *testing.T) { } resMutex.RLock() - if results["worker_2"] > 10 { - t.Error("Worker should have stopped working!") - } + assert.Less(t, results["worker_2"], 10) resMutex.RUnlock() time.Sleep(time.Second * 1) + errs := wa.StopAll(true) - if errs != nil { - t.Error("Not all workers stopped properly") - t.Error(errs) - } + assert.Nil(t, errs) for _, i := range []int{1, 2, 3, 4} { wName := fmt.Sprintf("worker_%d", i) - if wa.IsWorkerRunning(wName) { - t.Errorf("Worker %s should be stopped", wName) - } + assert.False(t, wa.IsWorkerRunning(wName)) } }