diff --git a/batch.go b/batch.go index 808ed20..5827524 100644 --- a/batch.go +++ b/batch.go @@ -296,6 +296,7 @@ func (c *Client) strongFetchBatch(ctx context.Context, keys []string, expire tim select { case <-ctx.Done(): ch <- pair{idx: i, err: ctx.Err()} + return case <-ticker.C: // equal to time.Sleep(c.Options.LockSleep) but can be canceled } diff --git a/batch_cover_test.go b/batch_cover_test.go index b1c2fa6..d387e5e 100644 --- a/batch_cover_test.go +++ b/batch_cover_test.go @@ -123,3 +123,70 @@ func TestTagAsDeletedBatchWait(t *testing.T) { assert.Error(t, err, fmt.Errorf("wait replicas 1 failed. result replicas: 0")) } } + +func TestWeakFetchBatchCanceled(t *testing.T) { + clearCache() + rc := NewClient(rdb, NewDefaultOptions()) + n := int(rand.Int31n(20) + 10) + idxs := genIdxs(n) + keys, values1, values2 := genKeys(idxs), genValues(n, "value_"), genValues(n, "eulav_") + values3 := genValues(n, "vvvv_") + go func() { + dc2 := NewClient(rdb, NewDefaultOptions()) + v, err := dc2.FetchBatch(keys, 60*time.Second, genBatchDataFunc(values1, 400)) + assert.Nil(t, err) + assert.Equal(t, values1, v) + }() + time.Sleep(20 * time.Millisecond) + + began := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := rc.FetchBatch2(ctx, keys, 60*time.Second, genBatchDataFunc(values2, 200)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) + + ctx, cancel = context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + began = time.Now() + _, err = rc.FetchBatch2(ctx, keys, 60*time.Second, genBatchDataFunc(values3, 200)) + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) +} + +func TestStrongFetchBatchCanceled(t *testing.T) { + clearCache() + rc := NewClient(rdb, NewDefaultOptions()) + rc.Options.StrongConsistency = true + n := int(rand.Int31n(20) + 10) + idxs := genIdxs(n) + keys, values1, values2 := genKeys(idxs), genValues(n, "value_"), genValues(n, "eulav_") + values3 := genValues(n, "vvvv_") + go func() { + dc2 := NewClient(rdb, NewDefaultOptions()) + v, err := dc2.FetchBatch(keys, 60*time.Second, genBatchDataFunc(values1, 400)) + assert.Nil(t, err) + assert.Equal(t, values1, v) + }() + time.Sleep(20 * time.Millisecond) + + began := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := rc.FetchBatch2(ctx, keys, 60*time.Second, genBatchDataFunc(values2, 200)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) + + ctx, cancel = context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + began = time.Now() + _, err = rc.FetchBatch2(ctx, keys, 60*time.Second, genBatchDataFunc(values3, 200)) + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) +} diff --git a/client_test.go b/client_test.go index 82476ca..644a5f7 100644 --- a/client_test.go +++ b/client_test.go @@ -147,6 +147,37 @@ func TestStrongErrorFetch(t *testing.T) { assert.True(t, time.Since(began) < time.Duration(150)*time.Millisecond) } +func TestStrongFetchCanceled(t *testing.T) { + clearCache() + rc := NewClient(rdb, NewDefaultOptions()) + rc.Options.StrongConsistency = true + expected := "value1" + go func() { + dc2 := NewClient(rdb, NewDefaultOptions()) + v, err := dc2.Fetch(rdbKey, 60*time.Second, genDataFunc(expected, 400)) + assert.Nil(t, err) + assert.Equal(t, expected, v) + }() + time.Sleep(20 * time.Millisecond) + + began := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := rc.Fetch2(ctx, rdbKey, 60*time.Second, genDataFunc(expected, 200)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) + + ctx, cancel = context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + began = time.Now() + _, err = rc.Fetch2(ctx, rdbKey, 60*time.Second, genDataFunc(expected, 200)) + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) +} + func TestWeakErrorFetch(t *testing.T) { rc := NewClient(rdb, NewDefaultOptions()) @@ -165,6 +196,37 @@ func TestWeakErrorFetch(t *testing.T) { assert.True(t, time.Since(began) < time.Duration(150)*time.Millisecond) } +func TestWeakFetchCanceled(t *testing.T) { + rc := NewClient(rdb, NewDefaultOptions()) + + clearCache() + expected := "value1" + go func() { + dc2 := NewClient(rdb, NewDefaultOptions()) + v, err := dc2.Fetch(rdbKey, 60*time.Second, genDataFunc(expected, 400)) + assert.Nil(t, err) + assert.Equal(t, expected, v) + }() + time.Sleep(20 * time.Millisecond) + + began := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := rc.Fetch2(ctx, rdbKey, 60*time.Second, genDataFunc(expected, 200)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) + + ctx, cancel = context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + began = time.Now() + _, err = rc.Fetch2(ctx, rdbKey, 60*time.Second, genDataFunc(expected, 200)) + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, time.Since(began) < time.Duration(110)*time.Millisecond) +} + func TestRawGet(t *testing.T) { rc := NewClient(rdb, NewDefaultOptions()) _, err := rc.RawGet(ctx, "not-exists")