From 1954bede0614b10b7c23a2d6bf267b0994903b9c Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 14 May 2024 16:24:11 -0500 Subject: [PATCH] Fixed race conditions in leakybucket --- algorithms.go | 219 ++++++++++++++++++++++++--------------------- cache.go | 1 - cache_manager.go | 19 ++-- lrucache.go | 20 ----- mock_cache_test.go | 5 -- otter.go | 23 ----- store.go | 37 -------- store_test.go | 5 +- 8 files changed, 131 insertions(+), 198 deletions(-) diff --git a/algorithms.go b/algorithms.go index a118294..a84c84c 100644 --- a/algorithms.go +++ b/algorithms.go @@ -34,8 +34,6 @@ type rateContext struct { CacheItem *CacheItem Store Store Cache Cache - // TODO: Remove - InstanceID string } // ### NOTE ### @@ -50,8 +48,6 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() var ok bool - // TODO: Remove - //fmt.Printf("[%s] tokenBucket()\n", ctx.InstanceID) // Get rate limit from cache hashKey := ctx.Request.HashKey() @@ -69,9 +65,8 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { } // If no item was found, or the item is expired. - if ctx.CacheItem == nil || ctx.CacheItem.IsExpired() { - // Initialize the Token bucket item - rl, err := InitTokenBucketItem(ctx) + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initTokenBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { // Someone else added a new token bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. @@ -82,7 +77,6 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { // Gain exclusive rights to this item while we calculate the rate limit ctx.CacheItem.mutex.Lock() - defer ctx.CacheItem.mutex.Unlock() t, ok := ctx.CacheItem.Value.(*TokenBucketItem) if !ok { @@ -91,15 +85,18 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { if ctx.Store != nil { ctx.Store.Remove(ctx, hashKey) } + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() ctx.CacheItem = nil - - rl, err := InitTokenBucketItem(ctx) + rl, err := initTokenBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { return tokenBucket(ctx) } return rl, err } + defer ctx.CacheItem.mutex.Unlock() + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { t.Remaining = ctx.Request.Limit t.Limit = ctx.Request.Limit @@ -213,8 +210,8 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { return rl, nil } -// InitTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. -func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { +// initTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. 
+func initTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { createdAt := *ctx.Request.CreatedAt expire := createdAt + ctx.Request.Duration @@ -254,14 +251,14 @@ func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { // If the cache item already exists, update it if ctx.CacheItem != nil { ctx.CacheItem.mutex.Lock() - ctx.CacheItem.Algorithm = Algorithm_TOKEN_BUCKET + ctx.CacheItem.Algorithm = ctx.Request.Algorithm ctx.CacheItem.ExpireAt = expire in, ok := ctx.CacheItem.Value.(*TokenBucketItem) if !ok { - // Likely the store gave us the wrong cache type + // Likely the store gave us the wrong cache item ctx.CacheItem.mutex.Unlock() ctx.CacheItem = nil - return InitTokenBucketItem(ctx) + return initTokenBucketItem(ctx) } *in = t ctx.CacheItem.mutex.Unlock() @@ -286,134 +283,136 @@ func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { } // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket -func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func leakyBucket(ctx rateContext) (resp *RateLimitResp, err error) { leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() + var ok bool - // TODO(thrawn01): Test for race conditions, and fix - - if r.Burst == 0 { - r.Burst = r.Limit + if ctx.Request.Burst == 0 { + ctx.Request.Burst = ctx.Request.Limit } - createdAt := *r.CreatedAt - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) - if s != nil && !ok { + if ctx.Store != nil && !ok { // Cache missed, check our store for the item. - if item, ok = s.Get(ctx, r); ok { - if !c.Add(item) { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.Add(ctx.CacheItem) { // Someone else added a new leaky bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. - return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } } } - if !ok { - rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + // If no item was found, or the item is expired. + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initLeakyBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { // Someone else added a new leaky bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. - return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } return rl, err } + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() + // Item found in cache or store. - b, ok := item.Value.(*LeakyBucketItem) + t, ok := ctx.CacheItem.Value.(*LeakyBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? 
- c.Remove(hashKey) - if s != nil { - s.Remove(ctx, hashKey) + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - - rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initLeakyBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { - return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } return rl, err } - // Gain exclusive rights to this item while we calculate the rate limit - b.mutex.Lock() - defer b.mutex.Unlock() + defer ctx.CacheItem.mutex.Unlock() - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - b.Remaining = float64(r.Burst) + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = float64(ctx.Request.Burst) } // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) + if t.Burst != ctx.Request.Burst { + if ctx.Request.Burst > int64(t.Remaining) { + t.Remaining = float64(ctx.Request.Burst) } - b.Burst = r.Burst + t.Burst = ctx.Request.Burst } - b.Limit = r.Limit - b.Duration = r.Duration + t.Limit = ctx.Request.Limit + t.Duration = ctx.Request.Duration - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } // Calculate the rate using the entire duration of the gregorian interval // IE: Minute = 60,000 milliseconds, etc.. etc.. - rate = float64(d) / float64(r.Limit) + rate = float64(d) / float64(ctx.Request.Limit) // Update the duration to be the end of the gregorian interval duration = expire - (n.UnixNano() / 1000000) } - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + createdAt := *ctx.Request.CreatedAt + if ctx.Request.Hits != 0 { + ctx.CacheItem.ExpireAt = createdAt + duration } // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt + elapsed := createdAt - t.UpdatedAt leak := float64(elapsed) / rate if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt + t.Remaining += leak + t.UpdatedAt = createdAt } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) + if int64(t.Remaining) > t.Burst { + t.Remaining = float64(t.Burst) } rl := &RateLimitResp{ - Limit: b.Limit, - Remaining: int64(b.Remaining), + Limit: t.Limit, + Remaining: int64(t.Remaining), Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), + ResetTime: createdAt + (t.Limit-int64(t.Remaining))*int64(rate), } // TODO: Feature missing: check for Duration change between item/request. 
- if s != nil && reqState.IsOwner { + if ctx.Store != nil && ctx.ReqState.IsOwner { defer func() { - s.OnChange(ctx, r, item) + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) }() } // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { + if int64(t.Remaining) == 0 && ctx.Request.Hits > 0 { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -421,24 +420,24 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat } // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) + if int64(t.Remaining) == ctx.Request.Hits { + t.Remaining = 0 + rl.Remaining = int64(t.Remaining) rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // If requested is more than available, then return over the limit // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { + if ctx.Request.Hits > int64(t.Remaining) { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + t.Remaining = 0 rl.Remaining = 0 } @@ -446,25 +445,25 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat } // Client is only interested in retrieving the current status - if r.Hits == 0 { + if ctx.Request.Hits == 0 { return rl, nil } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) + t.Remaining -= float64(ctx.Request.Hits) + rl.Remaining = int64(t.Remaining) rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // Called by leakyBucket() when adding a new item in the store. 
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - duration := r.Duration - rate := float64(duration) / float64(r.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { +func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { + createdAt := *ctx.Request.CreatedAt + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } @@ -475,23 +474,23 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, // Create a new leaky bucket b := LeakyBucketItem{ - Remaining: float64(r.Burst - r.Hits), - Limit: r.Limit, + Remaining: float64(ctx.Request.Burst - ctx.Request.Hits), + Limit: ctx.Request.Limit, Duration: duration, UpdatedAt: createdAt, - Burst: r.Burst, + Burst: ctx.Request.Burst, } rl := RateLimitResp{ Status: Status_UNDER_LIMIT, Limit: b.Limit, - Remaining: r.Burst - r.Hits, - ResetTime: createdAt + (b.Limit-(r.Burst-r.Hits))*int64(rate), + Remaining: ctx.Request.Burst - ctx.Request.Hits, + ResetTime: createdAt + (b.Limit-(ctx.Request.Burst-ctx.Request.Hits))*int64(rate), } // Client could be requesting that we start with the bucket OVER_LIMIT - if r.Hits > r.Burst { - if reqState.IsOwner { + if ctx.Request.Hits > ctx.Request.Burst { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -500,19 +499,33 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, b.Remaining = 0 } - item := &CacheItem{ - ExpireAt: createdAt + duration, - Algorithm: r.Algorithm, - Key: r.HashKey(), - Value: &b, - } - - if !c.Add(item) { - return nil, errAlreadyExistsInCache + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = createdAt + duration + in, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initLeakyBucketItem(ctx) + } + *in = b + ctx.CacheItem.mutex.Unlock() + } else { + ctx.CacheItem = &CacheItem{ + ExpireAt: createdAt + duration, + Algorithm: ctx.Request.Algorithm, + Key: ctx.Request.HashKey(), + Value: &b, + } + if !ctx.Cache.Add(ctx.CacheItem) { + return nil, errAlreadyExistsInCache + } } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return &rl, nil diff --git a/cache.go b/cache.go index dbeea21..70631e4 100644 --- a/cache.go +++ b/cache.go @@ -20,7 +20,6 @@ import "sync" type Cache interface { Add(item *CacheItem) bool - UpdateExpiration(key string, expireAt int64) bool GetItem(key string) (value *CacheItem, ok bool) Each() chan *CacheItem Remove(key string) diff --git a/cache_manager.go b/cache_manager.go index 3514ce1..90555d4 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -59,12 +59,11 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat switch req.Algorithm { case Algorithm_TOKEN_BUCKET: rlResponse, err = tokenBucket(rateContext{ - Store: m.conf.Store, - Cache: m.cache, - ReqState: state, - Request: req, - Context: ctx, - 
InstanceID: m.conf.InstanceID, + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, }) if err != nil { msg := "Error in tokenBucket" @@ -72,7 +71,13 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat } case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, m.conf.Store, m.cache, req, state) + rlResponse, err = leakyBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + }) if err != nil { msg := "Error in leakyBucket" countError(err, msg) diff --git a/lrucache.go b/lrucache.go index 5bef041..83fad9a 100644 --- a/lrucache.go +++ b/lrucache.go @@ -119,13 +119,6 @@ func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - // TODO(thrawn01): Remove - //if entry.IsExpired() { - // c.removeElement(ele) - // metricCacheAccess.WithLabelValues("miss").Add(1) - // return - //} - metricCacheAccess.WithLabelValues("hit").Add(1) c.ll.MoveToFront(ele) return entry, true @@ -168,19 +161,6 @@ func (c *LRUCache) Size() int64 { return atomic.LoadInt64(&c.cacheLen) } -// UpdateExpiration updates the expiration time for the key -func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { - c.mu.Lock() - defer c.mu.Unlock() - - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - entry.ExpireAt = expireAt - return true - } - return false -} - func (c *LRUCache) Close() error { c.cache = nil c.ll = nil diff --git a/mock_cache_test.go b/mock_cache_test.go index 3eea640..15f12cb 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -34,11 +34,6 @@ func (m *MockCache) Add(item *guber.CacheItem) bool { return args.Bool(0) } -func (m *MockCache) UpdateExpiration(key string, expireAt int64) bool { - args := m.Called(key, expireAt) - return args.Bool(0) -} - func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) diff --git a/otter.go b/otter.go index 9f60c55..992bbfc 100644 --- a/otter.go +++ b/otter.go @@ -50,33 +50,10 @@ func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { return nil, false } - // TODO(thrawn01): Remove - //if item.IsExpired() { - // metricCacheAccess.WithLabelValues("miss").Add(1) - // // If the item is expired, just return `nil` - // // - // // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. - // // Explicit deletions in otter require a mutex, which can cause performance bottlenecks - // // under high concurrency scenarios. By allowing the item to be evicted naturally by - // // otter's eviction mechanism, we avoid impacting performance under high concurrency. - // return nil, false - //} metricCacheAccess.WithLabelValues("hit").Add(1) return item, true } -// UpdateExpiration will update an item in the cache with a new expiration date. -// returns true if the item exists in the cache and was updated. -func (o *OtterCache) UpdateExpiration(key string, expireAt int64) bool { - item, ok := o.cache.Get(key) - if !ok { - return false - } - - item.ExpireAt = expireAt - return true -} - // Each returns a channel which the call can use to iterate through // all the items in the cache. 
 func (o *OtterCache) Each() chan *CacheItem {
diff --git a/store.go b/store.go
index 089ea50..b96868f 100644
--- a/store.go
+++ b/store.go
@@ -18,7 +18,6 @@ package gubernator
 
 import (
 	"context"
-	"sync"
 )
 
 // PERSISTENT STORE DETAILS
@@ -30,7 +29,6 @@ import (
 // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage.
 
 type LeakyBucketItem struct {
-	mutex     sync.Mutex
 	Limit     int64
 	Duration  int64
 	Remaining float64
@@ -81,41 +79,6 @@ type Loader interface {
 	Save(chan *CacheItem) error
 }
 
-// TODO Remove
-//func NewMockStore() *MockStore {
-//	ml := &MockStore{
-//		Called:     make(map[string]int),
-//		CacheItems: make(map[string]*CacheItem),
-//	}
-//	ml.Called["OnChange()"] = 0
-//	ml.Called["Remove()"] = 0
-//	ml.Called["Get()"] = 0
-//	return ml
-//}
-//
-//type MockStore struct {
-//	Called     map[string]int
-//	CacheItems map[string]*CacheItem
-//}
-//
-//var _ Store = &MockStore{}
-//
-//func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) {
-//	ms.Called["OnChange()"] += 1
-//	ms.CacheItems[item.Key] = item
-//}
-//
-//func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) {
-//	ms.Called["Get()"] += 1
-//	item, ok := ms.CacheItems[r.HashKey()]
-//	return item, ok
-//}
-//
-//func (ms *MockStore) Remove(ctx context.Context, key string) {
-//	ms.Called["Remove()"] += 1
-//	delete(ms.CacheItems, key)
-//}
-
 func NewMockLoader() *MockLoader {
 	ml := &MockLoader{
 		Called: make(map[string]int),
diff --git a/store_test.go b/store_test.go
index ff29df0..f3e1122 100644
--- a/store_test.go
+++ b/store_test.go
@@ -150,8 +150,9 @@ func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitReq) (*gube
 // add items to the cache in parallel thus creating a race condition the code must then handle.
 func TestHighContentionFromStore(t *testing.T) {
 	const (
-		numGoroutines = 1_000
-		numKeys       = 400
+		// Increase these numbers to improve the chance of contention, but at the cost of test speed.
+		numGoroutines = 500
+		numKeys       = 100
 	)
 	store := &NoOpStore{}
 	srv := newV1Server(t, "localhost:0", gubernator.Config{
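
Note on the concurrency pattern this patch standardizes on: on a cache miss the algorithm optimistically creates the item, Cache.Add atomically refuses the insert when another goroutine won the race, the init function surfaces that as errAlreadyExistsInCache, and the caller retries against the winner's item; the per-CacheItem mutex then serializes the bucket arithmetic itself. Below is a minimal, self-contained Go sketch of that lookup/init/retry shape. The names cache, item, initItem, and hit are illustrative stand-ins, not the gubernator API.

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errAlreadyExists = errors.New("already exists in cache")

// item stands in for CacheItem: mutex guards the bucket state while
// one goroutine performs the rate-limit calculation.
type item struct {
	mutex sync.Mutex
	hits  int
}

// cache stands in for the Cache interface: add is atomic and reports
// whether this goroutine won the insert race for the key.
type cache struct {
	mu    sync.Mutex
	items map[string]*item
}

func (c *cache) getItem(key string) (*item, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	it, ok := c.items[key]
	return it, ok
}

func (c *cache) add(key string, it *item) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if _, ok := c.items[key]; ok {
		return false // another goroutine inserted this key first
	}
	c.items[key] = it
	return true
}

// initItem mirrors initTokenBucketItem/initLeakyBucketItem: it only
// creates; losing the insert race is reported so the caller can retry.
func (c *cache) initItem(key string) (*item, error) {
	it := &item{}
	if !c.add(key, it) {
		return nil, errAlreadyExists
	}
	return it, nil
}

// hit mirrors tokenBucket/leakyBucket: look up, initialize on a miss,
// retry if the insert race was lost, and hold the item's mutex while
// mutating its state.
func (c *cache) hit(key string) int {
	it, ok := c.getItem(key)
	if !ok {
		var err error
		if it, err = c.initItem(key); err != nil {
			// Someone else added the item before we did; retry so we
			// operate on the item that actually lives in the cache.
			return c.hit(key)
		}
	}
	it.mutex.Lock()
	defer it.mutex.Unlock()
	it.hits++
	return it.hits
}

func main() {
	c := &cache{items: make(map[string]*item)}
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.hit("rate-limit-key")
		}()
	}
	wg.Wait()
	it, _ := c.getItem("rate-limit-key")
	fmt.Println("hits:", it.hits) // always 100: no lost updates
}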