Consistent hashing: consistently retry another host (#119)
* Consistent hashing: consistently retry another host

This implements a retry policy to consistently retry another host within the SRV
set if we get a failure talking to the original host.

We still use the original fallback strategy if this retry fails. (Should we?)
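
A minimal sketch of the behaviour this aims for, using the `consistent` package added later in this commit (the host names and key are made up for illustration): every client that fails against its first-choice host recomputes the bucket with that host excluded, so all clients agree on which host to try next.

```go
package main

import (
	"fmt"

	"github.com/replicate/pget/pkg/consistent"
)

func main() {
	// Hypothetical cache hosts discovered from the SRV set.
	hosts := []string{"cache-0.internal", "cache-1.internal", "cache-2.internal"}
	key := "example-object-key" // hypothetical cache key

	first, err := consistent.HashBucket(key, len(hosts))
	if err != nil {
		panic(err)
	}
	// Suppose the request to hosts[first] fails: exclude that bucket and
	// rehash. Every client computes the same alternate host for this key.
	retry, err := consistent.HashBucket(key, len(hosts), first)
	if err != nil {
		panic(err)
	}
	fmt.Printf("first choice: %s, retry: %s\n", hosts[first], hosts[retry])
}
```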

* extract consistent hashing code

This extracts the consistent hashing code into a new package and writes more
tests for it.

This removes a bunch of duplication from ConsistentHashingMode.

As a consequence this changes the consistent hash assignments (because I've
embedded the Key into a top-level CacheKey struct which includes Attempt for
retry support). At this stage this is safe because nothing live is using it.
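
A small sketch of why the assignments move (not part of the commit; the key string is made up and the struct mirrors the cacheKey introduced in pkg/consistent): hashstructure produces different sums for a value hashed directly and for the same value wrapped in a struct, so the buckets derived from those sums generally differ.

```go
package main

import (
	"fmt"

	"github.com/mitchellh/hashstructure/v2"
)

// cacheKey mirrors the struct added in pkg/consistent.
type cacheKey struct {
	Key     any
	Attempt int
}

func main() {
	// Fresh HashOptions for each call, since they are not safe to share.
	// Errors ignored for brevity.
	direct, _ := hashstructure.Hash("some-key", hashstructure.FormatV2,
		&hashstructure.HashOptions{IgnoreZeroValue: true})
	wrapped, _ := hashstructure.Hash(cacheKey{Key: "some-key"}, hashstructure.FormatV2,
		&hashstructure.HashOptions{IgnoreZeroValue: true})
	// The sums differ, so buckets computed before and after this change
	// will generally not line up.
	fmt.Println(direct, wrapped, direct == wrapped)
}
```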

* refactor

rename method

avoid implicit return

remove unneeded conditional on m.DomainsToCache

* lint
philandstuff authored Dec 15, 2023
1 parent 5eceeda commit 3e641a5
Showing 4 changed files with 251 additions and 52 deletions.
45 changes: 45 additions & 0 deletions pkg/consistent/consistent.go
@@ -0,0 +1,45 @@
// Package consistent implements consistent hashing for cache nodes.
package consistent

import (
	"fmt"
	"slices"

	"github.com/dgryski/go-jump"
	"github.com/mitchellh/hashstructure/v2"
)

type cacheKey struct {
	Key     any
	Attempt int
}

// HashBucket returns a bucket from [0,buckets). If you want to implement a
// retry, you can pass previousBuckets, which indicates buckets which must be
// avoided in the output. HashBucket will modify the previousBuckets slice by
// sorting it.
func HashBucket(key any, buckets int, previousBuckets ...int) (int, error) {
	if len(previousBuckets) >= buckets {
		return -1, fmt.Errorf("no more buckets left: %d buckets available but %d already attempted", buckets, len(previousBuckets))
	}
	// we set IgnoreZeroValue so that we can add fields to the hash key
	// later without breaking things.
	// note that it's not safe to share a HashOptions so we create a fresh one each time.
	hashopts := &hashstructure.HashOptions{IgnoreZeroValue: true}
	hash, err := hashstructure.Hash(cacheKey{Key: key, Attempt: len(previousBuckets)}, hashstructure.FormatV2, hashopts)
	if err != nil {
		return -1, fmt.Errorf("error calculating hash of key: %w", err)
	}

	// jump is an implementation of Google's Jump Consistent Hash.
	//
	// See http://arxiv.org/abs/1406.2294 for details.
	bucket := int(jump.Hash(hash, buckets-len(previousBuckets)))
	slices.Sort(previousBuckets)
	for _, prev := range previousBuckets {
		if bucket >= prev {
			bucket++
		}
	}
	return bucket, nil
}
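
The retry support works by hashing into a reduced range of `buckets - len(previousBuckets)` and then shifting the result upward past each previously chosen bucket in ascending order, which keeps the output inside `[0, buckets)` while guaranteeing it avoids everything in `previousBuckets`. For example, with 5 buckets and previous picks {1, 3}, the reduced-range values 0, 1, 2 map to 0, 2, 4. A usage sketch (not part of the commit; the key is made up) showing that successive retries visit every bucket exactly once before the function reports exhaustion:

```go
package main

import (
	"fmt"

	"github.com/replicate/pget/pkg/consistent"
)

func main() {
	const buckets = 5
	key := "some-object-key" // hypothetical

	tried := []int{}
	for attempt := 0; attempt < buckets; attempt++ {
		bucket, err := consistent.HashBucket(key, buckets, tried...)
		if err != nil {
			panic(err)
		}
		fmt.Printf("attempt %d -> bucket %d\n", attempt, bucket)
		tried = append(tried, bucket)
	}
	// One more call fails: every bucket has already been attempted.
	if _, err := consistent.HashBucket(key, buckets, tried...); err != nil {
		fmt.Println(err)
	}
}
```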
74 changes: 74 additions & 0 deletions pkg/consistent/consistent_test.go
@@ -0,0 +1,74 @@
package consistent_test

import (
	"slices"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/replicate/pget/pkg/consistent"
)

func TestHashingDoesNotChangeWhenZeroValueFieldsAreAdded(t *testing.T) {
	a, err := consistent.HashBucket(struct{}{}, 1024)
	require.NoError(t, err)
	b, err := consistent.HashBucket(struct{ I int }{}, 1024)
	require.NoError(t, err)

	assert.Equal(t, a, b)
}

func TestRetriesScatterBuckets(t *testing.T) {
	// This test is tricky! We want an example of hash keys which map to the
	// same bucket, but after one retry map to different buckets.
	//
	// These two keys happen to have this property for 10 buckets:
	strA := "abcdefg"
	strB := "1234567"
	a, err := consistent.HashBucket(strA, 10)
	require.NoError(t, err)
	b, err := consistent.HashBucket(strB, 10)
	require.NoError(t, err)

	// strA and strB map to the same bucket
	require.Equal(t, a, b)

	aRetry, err := consistent.HashBucket(strA, 10, a)
	require.NoError(t, err)
	bRetry, err := consistent.HashBucket(strB, 10, b)
	require.NoError(t, err)

	// but after retry they map to different buckets
	assert.NotEqual(t, aRetry, bRetry)
}

func FuzzRetriesMustNotRepeatIndices(f *testing.F) {
	f.Add("test.replicate.delivery", 5)
	f.Add("test.replicate.delivery", 0)
	f.Fuzz(func(t *testing.T, key string, excessBuckets int) {
		if excessBuckets < 0 {
			t.Skip("invalid value")
		}
		attempts := 20
		buckets := attempts + excessBuckets
		if buckets < 0 {
			t.Skip("integer overflow")
		}
		previous := []int{}
		for i := 0; i < attempts; i++ {
			next, err := consistent.HashBucket(key, buckets, previous...)
			require.NoError(t, err)

			// we must be in range
			assert.Less(t, next, buckets)
			assert.GreaterOrEqual(t, next, 0)

			// we shouldn't repeat any previous value
			assert.NotContains(t, previous, next)

			previous = append(previous, next)
			slices.Sort(previous)
		}
	})
}
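
To run the fuzz target locally, something like `go test -run='^$' -fuzz=FuzzRetriesMustNotRepeatIndices -fuzztime=30s ./pkg/consistent` (standard Go fuzzing flags; the duration is arbitrary) keeps generating keys and bucket counts beyond the two seed entries registered with `f.Add`.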
77 changes: 41 additions & 36 deletions pkg/download/consistent_hashing.go
@@ -9,12 +9,11 @@ import (
"net/url"
"strconv"

jump "github.com/dgryski/go-jump"
"github.com/mitchellh/hashstructure/v2"
"golang.org/x/sync/errgroup"

"github.com/replicate/pget/pkg/client"
"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/consistent"
"github.com/replicate/pget/pkg/logging"
)

Expand Down Expand Up @@ -252,7 +251,7 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
if err != nil {
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
err = m.consistentHashIfNeeded(req, start, end)
cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end)
if err != nil {
return nil, err
}
@@ -262,7 +261,28 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,

resp, err := m.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)
if errors.Is(err, client.ErrStrategyFallback) {
origErr := err
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
if err != nil {
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
_, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex)
if err != nil {
// return origErr so that we can use our regular fallback strategy
return nil, origErr
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")

resp, err = m.Client.Do(req)
if err != nil {
// return origErr so that we can use our regular fallback strategy
return nil, origErr
}
} else {
return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)
}
}
if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status)
@@ -271,39 +291,24 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
return resp, nil
}

func (m *ConsistentHashingMode) consistentHashIfNeeded(req *http.Request, start int64, end int64) error {
func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, start int64, end int64, previousPodIndexes ...int) (int, error) {
logger := logging.GetLogger()
for _, host := range m.DomainsToCache {
if host == req.URL.Host {
if start/m.SliceSize != end/m.SliceSize {
return fmt.Errorf("can't make a range request across a slice boundary: %d-%d straddles a slice boundary (slice size is %d)", start, end, m.SliceSize)
}
slice := start / m.SliceSize

key := CacheKey{URL: req.URL, Slice: slice}
// we set IgnoreZeroValue so that we can add fields to the hash key
// later without breaking things.
// note that it's not safe to share a HashOptions so we create a fresh one each time.
hashopts := &hashstructure.HashOptions{IgnoreZeroValue: true}
hash, err := hashstructure.Hash(key, hashstructure.FormatV2, hashopts)
if err != nil {
return fmt.Errorf("error calculating hash of key")
}
if start/m.SliceSize != end/m.SliceSize {
return 0, fmt.Errorf("can't make a range request across a slice boundary: %d-%d straddles a slice boundary (slice size is %d)", start, end, m.SliceSize)
}
slice := start / m.SliceSize

logger.Debug().Uint64("hash_sum", hash).Int("len_cache_hosts", len(m.CacheHosts)).Msg("consistent hashing")

// jump is an implementation of Google's Jump Consistent Hash.
//
// See http://arxiv.org/abs/1406.2294 for details.
cachePodIndex := int(jump.Hash(hash, len(m.CacheHosts)))
cacheHost := m.CacheHosts[cachePodIndex]
logger.Debug().Str("cache_key", fmt.Sprintf("%+v", key)).Int64("start", start).Int64("end", end).Int64("slice_size", m.SliceSize).Int("bucket", cachePodIndex).Msg("consistent hashing")
if cacheHost != "" {
req.URL.Scheme = "http"
req.URL.Host = cacheHost
}
return nil
}
key := CacheKey{URL: req.URL, Slice: slice}

cachePodIndex, err := consistent.HashBucket(key, len(m.CacheHosts), previousPodIndexes...)
if err != nil {
return -1, err
}
cacheHost := m.CacheHosts[cachePodIndex]
logger.Debug().Str("cache_key", fmt.Sprintf("%+v", key)).Int64("start", start).Int64("end", end).Int64("slice_size", m.SliceSize).Int("bucket", cachePodIndex).Msg("consistent hashing")
if cacheHost != "" {
req.URL.Scheme = "http"
req.URL.Host = cacheHost
}
return nil
return cachePodIndex, nil
}