From 3e641a5da9247ac45218514210ab7f388dc41001 Mon Sep 17 00:00:00 2001
From: Philip Potter
Date: Fri, 15 Dec 2023 15:42:08 +0000
Subject: [PATCH] Consistent hashing: consistently retry another host (#119)

* Consistent hashing: consistently retry another host

This implements a retry policy to consistently retry another host within the SRV set if we get a failure talking to the original host. We still use the original fallback strategy if this retry fails. (Should we?)

* extract consistent hashing code

This extracts the consistent hashing code into a new package and writes more tests for it. This removes a bunch of duplication from ConsistentHashingMode.

As a consequence this changes the output of the consistent hashing algorithm (because I've embedded the Key into a top-level CacheKey struct which includes Attempt for retry support). At this stage this is safe because nothing live is using it.

* refactor

rename method
avoid implicit return
remove unneeded conditional on m.DomainsToCache

* lint
---
 pkg/consistent/consistent.go            |  45 ++++++++++
 pkg/consistent/consistent_test.go       |  74 ++++++++++++++++
 pkg/download/consistent_hashing.go      |  77 +++++++++--------
 pkg/download/consistent_hashing_test.go | 107 ++++++++++++++++++++----
 4 files changed, 251 insertions(+), 52 deletions(-)
 create mode 100644 pkg/consistent/consistent.go
 create mode 100644 pkg/consistent/consistent_test.go

diff --git a/pkg/consistent/consistent.go b/pkg/consistent/consistent.go
new file mode 100644
index 0000000..0171a7d
--- /dev/null
+++ b/pkg/consistent/consistent.go
@@ -0,0 +1,45 @@
+// Package consistent implements consistent hashing for cache nodes.
+package consistent
+
+import (
+	"fmt"
+	"slices"
+
+	"github.com/dgryski/go-jump"
+	"github.com/mitchellh/hashstructure/v2"
+)
+
+type cacheKey struct {
+	Key     any
+	Attempt int
+}
+
+// HashBucket returns a bucket from [0,buckets). If you want to implement a
+// retry, you can pass previousBuckets, which indicates buckets which must be
+// avoided in the output. HashBucket will modify the previousBuckets slice by
+// sorting it.
+func HashBucket(key any, buckets int, previousBuckets ...int) (int, error) {
+	if len(previousBuckets) >= buckets {
+		return -1, fmt.Errorf("no more buckets left: %d buckets available but %d already attempted", buckets, len(previousBuckets))
+	}
+	// we set IgnoreZeroValue so that we can add fields to the hash key
+	// later without breaking things.
+	// note that it's not safe to share a HashOptions so we create a fresh one each time.
+	hashopts := &hashstructure.HashOptions{IgnoreZeroValue: true}
+	hash, err := hashstructure.Hash(cacheKey{Key: key, Attempt: len(previousBuckets)}, hashstructure.FormatV2, hashopts)
+	if err != nil {
+		return -1, fmt.Errorf("error calculating hash of key: %w", err)
+	}
+
+	// jump is an implementation of Google's Jump Consistent Hash.
+	//
+	// See http://arxiv.org/abs/1406.2294 for details.
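+	//
+	// To skip buckets that have already been attempted, we hash into the
+	// reduced range [0, buckets-len(previousBuckets)) and then shift the
+	// result up past each previously-chosen bucket in ascending order, so
+	// the returned bucket is never one of previousBuckets.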
+	bucket := int(jump.Hash(hash, buckets-len(previousBuckets)))
+	slices.Sort(previousBuckets)
+	for _, prev := range previousBuckets {
+		if bucket >= prev {
+			bucket++
+		}
+	}
+	return bucket, nil
+}
diff --git a/pkg/consistent/consistent_test.go b/pkg/consistent/consistent_test.go
new file mode 100644
index 0000000..086d764
--- /dev/null
+++ b/pkg/consistent/consistent_test.go
@@ -0,0 +1,74 @@
+package consistent_test
+
+import (
+	"slices"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/replicate/pget/pkg/consistent"
+)
+
+func TestHashingDoesNotChangeWhenZeroValueFieldsAreAdded(t *testing.T) {
+	a, err := consistent.HashBucket(struct{}{}, 1024)
+	require.NoError(t, err)
+	b, err := consistent.HashBucket(struct{ I int }{}, 1024)
+	require.NoError(t, err)
+
+	assert.Equal(t, a, b)
+}
+
+func TestRetriesScatterBuckets(t *testing.T) {
+	// This test is tricky! We want an example of hash keys which map to the
+	// same bucket, but after one retry map to different buckets.
+	//
+	// These two keys happen to have this property for 10 buckets:
+	strA := "abcdefg"
+	strB := "1234567"
+	a, err := consistent.HashBucket(strA, 10)
+	require.NoError(t, err)
+	b, err := consistent.HashBucket(strB, 10)
+	require.NoError(t, err)
+
+	// strA and strB map to the same bucket
+	require.Equal(t, a, b)
+
+	aRetry, err := consistent.HashBucket(strA, 10, a)
+	require.NoError(t, err)
+	bRetry, err := consistent.HashBucket(strB, 10, b)
+	require.NoError(t, err)
+
+	// but after retry they map to different buckets
+	assert.NotEqual(t, aRetry, bRetry)
+}
+
+func FuzzRetriesMustNotRepeatIndices(f *testing.F) {
+	f.Add("test.replicate.delivery", 5)
+	f.Add("test.replicate.delivery", 0)
+	f.Fuzz(func(t *testing.T, key string, excessBuckets int) {
+		if excessBuckets < 0 {
+			t.Skip("invalid value")
+		}
+		attempts := 20
+		buckets := attempts + excessBuckets
+		if buckets < 0 {
+			t.Skip("integer overflow")
+		}
+		previous := []int{}
+		for i := 0; i < attempts; i++ {
+			next, err := consistent.HashBucket(key, buckets, previous...)
+ require.NoError(t, err) + + // we must be in range + assert.Less(t, next, buckets) + assert.GreaterOrEqual(t, next, 0) + + // we shouldn't repeat any previous value + assert.NotContains(t, previous, next) + + previous = append(previous, next) + slices.Sort(previous) + } + }) +} diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index 311bbf1..ed147b4 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -9,12 +9,11 @@ import ( "net/url" "strconv" - jump "github.com/dgryski/go-jump" - "github.com/mitchellh/hashstructure/v2" "golang.org/x/sync/errgroup" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/config" + "github.com/replicate/pget/pkg/consistent" "github.com/replicate/pget/pkg/logging" ) @@ -252,7 +251,7 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, if err != nil { return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err) } - err = m.consistentHashIfNeeded(req, start, end) + cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end) if err != nil { return nil, err } @@ -262,7 +261,28 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, resp, err := m.Client.Do(req) if err != nil { - return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) + if errors.Is(err, client.ErrStrategyFallback) { + origErr := err + req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) + if err != nil { + return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err) + } + _, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex) + if err != nil { + // return origErr so that we can use our regular fallback strategy + return nil, origErr + } + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request") + + resp, err = m.Client.Do(req) + if err != nil { + // return origErr so that we can use our regular fallback strategy + return nil, origErr + } + } else { + return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) + } } if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status) @@ -271,39 +291,24 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, return resp, nil } -func (m *ConsistentHashingMode) consistentHashIfNeeded(req *http.Request, start int64, end int64) error { +func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, start int64, end int64, previousPodIndexes ...int) (int, error) { logger := logging.GetLogger() - for _, host := range m.DomainsToCache { - if host == req.URL.Host { - if start/m.SliceSize != end/m.SliceSize { - return fmt.Errorf("can't make a range request across a slice boundary: %d-%d straddles a slice boundary (slice size is %d)", start, end, m.SliceSize) - } - slice := start / m.SliceSize - - key := CacheKey{URL: req.URL, Slice: slice} - // we set IgnoreZeroValue so that we can add fields to the hash key - // later without breaking things. - // note that it's not safe to share a HashOptions so we create a fresh one each time. 
- hashopts := &hashstructure.HashOptions{IgnoreZeroValue: true} - hash, err := hashstructure.Hash(key, hashstructure.FormatV2, hashopts) - if err != nil { - return fmt.Errorf("error calculating hash of key") - } + if start/m.SliceSize != end/m.SliceSize { + return 0, fmt.Errorf("can't make a range request across a slice boundary: %d-%d straddles a slice boundary (slice size is %d)", start, end, m.SliceSize) + } + slice := start / m.SliceSize - logger.Debug().Uint64("hash_sum", hash).Int("len_cache_hosts", len(m.CacheHosts)).Msg("consistent hashing") - - // jump is an implementation of Google's Jump Consistent Hash. - // - // See http://arxiv.org/abs/1406.2294 for details. - cachePodIndex := int(jump.Hash(hash, len(m.CacheHosts))) - cacheHost := m.CacheHosts[cachePodIndex] - logger.Debug().Str("cache_key", fmt.Sprintf("%+v", key)).Int64("start", start).Int64("end", end).Int64("slice_size", m.SliceSize).Int("bucket", cachePodIndex).Msg("consistent hashing") - if cacheHost != "" { - req.URL.Scheme = "http" - req.URL.Host = cacheHost - } - return nil - } + key := CacheKey{URL: req.URL, Slice: slice} + + cachePodIndex, err := consistent.HashBucket(key, len(m.CacheHosts), previousPodIndexes...) + if err != nil { + return -1, err + } + cacheHost := m.CacheHosts[cachePodIndex] + logger.Debug().Str("cache_key", fmt.Sprintf("%+v", key)).Int64("start", start).Int64("end", end).Int64("slice_size", m.SliceSize).Int("bucket", cachePodIndex).Msg("consistent hashing") + if cacheHost != "" { + req.URL.Scheme = "http" + req.URL.Host = cacheHost } - return nil + return cachePodIndex, nil } diff --git a/pkg/download/consistent_hashing_test.go b/pkg/download/consistent_hashing_test.go index 4d7e568..b6a3456 100644 --- a/pkg/download/consistent_hashing_test.go +++ b/pkg/download/consistent_hashing_test.go @@ -53,7 +53,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 2, minChunkSize: 1, - expectedOutput: "0001110000001110", + expectedOutput: "1111110000000000", }, { name: "3 hosts", @@ -61,7 +61,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 3, minChunkSize: 1, - expectedOutput: "0001110002221110", + expectedOutput: "2221110000002222", }, { name: "4 hosts", @@ -69,7 +69,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 4, minChunkSize: 1, - expectedOutput: "0001113333331110", + expectedOutput: "3331113333332222", }, { name: "5 hosts", @@ -77,7 +77,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 5, minChunkSize: 1, - expectedOutput: "0001114443331110", + expectedOutput: "3334443333332224", }, { name: "6 hosts", @@ -85,7 +85,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 6, minChunkSize: 1, - expectedOutput: "0001114443331115", + expectedOutput: "3334443333335554", }, { name: "7 hosts", @@ -93,7 +93,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 7, minChunkSize: 1, - expectedOutput: "0006664443336665", + expectedOutput: "3334446666665556", }, { name: "8 hosts", @@ -101,7 +101,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 8, minChunkSize: 1, - expectedOutput: "0006664443336667", + expectedOutput: "3334446666667776", }, { name: "test when fileSize % sliceSize == 0", @@ -109,7 +109,7 @@ var chTestCases = []chTestCase{ sliceSize: 4, numCacheHosts: 8, minChunkSize: 1, - expectedOutput: "0000666644443333", + expectedOutput: "3333444466666666", }, { name: "when minChunkSize == sliceSize", @@ -117,7 +117,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 8, minChunkSize: 3, - 
expectedOutput: "0006664443336667", + expectedOutput: "3334446666667776", }, { name: "test when concurrency > file size", @@ -125,7 +125,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 8, minChunkSize: 3, - expectedOutput: "0006664443336667", + expectedOutput: "3334446666667776", }, { name: "test when concurrency < number of slices", @@ -133,7 +133,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 8, minChunkSize: 3, - expectedOutput: "0006664443336667", + expectedOutput: "3334446666667776", }, { name: "test when minChunkSize == file size", @@ -141,7 +141,7 @@ var chTestCases = []chTestCase{ sliceSize: 16, numCacheHosts: 8, minChunkSize: 16, - expectedOutput: "0000000000000000", + expectedOutput: "3333333333333333", }, { name: "test when minChunkSize > file size", @@ -149,7 +149,7 @@ var chTestCases = []chTestCase{ sliceSize: 24, numCacheHosts: 8, minChunkSize: 24, - expectedOutput: "0000000000000000", + expectedOutput: "3333333333333333", }, { name: "if minChunkSize > sliceSize, sliceSize overrides it", @@ -157,7 +157,7 @@ var chTestCases = []chTestCase{ sliceSize: 3, numCacheHosts: 8, minChunkSize: 24, - expectedOutput: "0006664443336667", + expectedOutput: "3334446666667776", }, } @@ -178,7 +178,7 @@ func TestConsistentHashing(t *testing.T) { MaxConcurrency: tc.concurrency, MinChunkSize: tc.minChunkSize, CacheHosts: hostnames[0:tc.numCacheHosts], - DomainsToCache: []string{"test.replicate.delivery"}, + DomainsToCache: []string{"test.replicate.com"}, SliceSize: tc.sliceSize, } @@ -188,7 +188,8 @@ func TestConsistentHashing(t *testing.T) { strategy, err := download.GetConsistentHashingMode(opts) assert.NoError(t, err) - reader, _, err := strategy.Fetch(ctx, "http://test.replicate.delivery/hello.txt") + assert.Equal(t, tc.numCacheHosts, len(strategy.Options.CacheHosts)) + reader, _, err := strategy.Fetch(ctx, "http://test.replicate.com/hello.txt") assert.NoError(t, err) bytes, err := io.ReadAll(reader) assert.NoError(t, err) @@ -198,6 +199,80 @@ func TestConsistentHashing(t *testing.T) { } } +func TestConsistentHashRetries(t *testing.T) { + hostnames := make([]string, len(testFSes)) + for i, fs := range testFSes { + ts := httptest.NewServer(http.FileServer(http.FS(fs))) + defer ts.Close() + url, err := url.Parse(ts.URL) + require.NoError(t, err) + hostnames[i] = url.Host + } + // deliberately "break" one cache host + hostnames[0] = "localhost:1" + + opts := download.Options{ + Client: client.Options{}, + MaxConcurrency: 8, + MinChunkSize: 1, + CacheHosts: hostnames, + DomainsToCache: []string{"fake.replicate.delivery"}, + SliceSize: 1, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) + + reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") + assert.NoError(t, err) + bytes, err := io.ReadAll(reader) + assert.NoError(t, err) + + // with a functional hostnames[0], we'd see 0344760706165500, but instead we + // should fall back to this. Note that each 0 value has been changed to a + // different index; we don't want every request that previously hit 0 to hit + // the same new host. 
+ assert.Equal(t, "3344761726165516", string(bytes)) +} + +// with only two hosts, we should *always* fall back to the other host +func TestConsistentHashRetriesTwoHosts(t *testing.T) { + hostnames := make([]string, 2) + for i, fs := range testFSes[0:1] { + ts := httptest.NewServer(http.FileServer(http.FS(fs))) + defer ts.Close() + url, err := url.Parse(ts.URL) + require.NoError(t, err) + hostnames[i] = url.Host + } + hostnames[1] = "localhost:1" + + opts := download.Options{ + Client: client.Options{}, + MaxConcurrency: 8, + MinChunkSize: 1, + CacheHosts: hostnames, + DomainsToCache: []string{"testing.replicate.delivery"}, + SliceSize: 1, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) + + reader, _, err := strategy.Fetch(ctx, "http://testing.replicate.delivery/hello.txt") + assert.NoError(t, err) + bytes, err := io.ReadAll(reader) + assert.NoError(t, err) + + assert.Equal(t, "0000000000000000", string(bytes)) +} + func TestConsistentHashingHasFallback(t *testing.T) { server := httptest.NewServer(http.FileServer(http.FS(testFSes[0]))) defer server.Close()