diff --git a/Makefile b/Makefile index 3ce5fb2..a2ef134 100644 --- a/Makefile +++ b/Makefile @@ -15,10 +15,15 @@ $(GOLANGCI_LINT): ## Download Go linter lint: $(GOLANGCI_LINT) ## Run Go linter $(GOLANGCI_LINT) run -v -c .golangci.yml ./... -.PHONY: validate -validate: lint test +.PHONY: tidy +tidy: go mod tidy && git diff --exit-code +.PHONY: validate +validate: tidy lint test bench + @echo + @echo "\033[32mEVERYTHING PASSED!\033[0m" + .PHONY: test test: ## Run unit tests and measure code coverage (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ @@ -28,7 +33,7 @@ test: ## Run unit tests and measure code coverage .PHONY: bench bench: ## Run Go benchmarks - go test ./... -bench . -benchtime 5s -timeout 0 -run='^$$' -benchmem + go test ./... -bench . -timeout 6m -run='^$$' -benchmem .PHONY: docker docker: ## Build Docker image @@ -49,7 +54,6 @@ clean-proto: ## Clean the generated source files from the protobuf sources @find . -name "*.pb.go" -type f -delete @find . -name "*.pb.*.go" -type f -delete - .PHONY: proto proto: ## Build protos ./buf.gen.yaml diff --git a/algorithms.go b/algorithms.go index aca5f9f..79ce797 100644 --- a/algorithms.go +++ b/algorithms.go @@ -18,13 +18,26 @@ package gubernator import ( "context" + "errors" "github.com/mailgun/holster/v4/clock" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/attribute" + "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/trace" ) +var errAlreadyExistsInCache = errors.New("already exists in cache") + +type rateContext struct { + context.Context + // TODO(thrawn01): Roll this into `rateContext` + ReqState RateLimitContext + + Request *RateLimitRequest + CacheItem *CacheItem + Store Store + Cache Cache +} + // ### NOTE ### // The both token and leaky follow the same semantic which allows for requests of more than the limit // to be rejected, but subsequent requests within the same window that are under the limit to succeed. @@ -33,407 +46,425 @@ import ( // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket -func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitRequest, reqState RateLimitContext) (resp *RateLimitResponse, err error) { - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. - if item, ok = s.Get(ctx, r); ok { - c.Add(item) +func tokenBucket(ctx rateContext) (resp *RateLimitResponse, err error) { + tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) + defer tokenBucketTimer.ObserveDuration() + var ok bool + + // Get rate limit from cache + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + // If not in the cache, check the store if provided + if ctx.Store != nil && !ok { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) + } } } - // Sanity checks. 
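// Sketch (illustrative, not part of this change): the retry-on-race pattern used above
// and throughout this file. AddIfNotPresent() is now the only way items enter the cache,
// so the loser of a concurrent insert simply re-reads the winner's item. The helper name
// is hypothetical; written as if it lived in package gubernator.
func loadOrRace(c Cache, item *CacheItem) *CacheItem {
	for {
		if c.AddIfNotPresent(item) {
			return item // we won the insert race; our item is now authoritative
		}
		// Someone else inserted first; use their item instead of ours.
		if winner, ok := c.GetItem(item.Key); ok {
			return winner
		}
		// The winner was evicted between the add and the get; race again.
	}
}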
- if ok { - if item.Value == nil { - msgPart := "tokenBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "tokenBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - c.Remove(hashKey) + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() - if s != nil { - s.Remove(ctx, hashKey) - } - return &RateLimitResponse{ - Status: Status_UNDER_LIMIT, - Limit: r.Limit, - Remaining: r.Limit, - ResetTime: 0, - }, nil + t, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - t, ok := item.Value.(*TokenBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - trace.SpanFromContext(ctx).AddEvent("Client switched algorithms; perhaps due to a migration?") + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return tokenBucket(ctx) + } + return rl, err + } - c.Remove(hashKey) + defer ctx.CacheItem.mutex.Unlock() - if s != nil { - s.Remove(ctx, hashKey) - } + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = ctx.Request.Limit + t.Limit = ctx.Request.Limit + t.Status = Status_UNDER_LIMIT + + if ctx.Store != nil { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + } + return &RateLimitResponse{ + Status: Status_UNDER_LIMIT, + Limit: ctx.Request.Limit, + Remaining: ctx.Request.Limit, + ResetTime: 0, + }, nil + } - return tokenBucketNewItem(ctx, s, c, r, reqState) + // Update the limit if it changed. + if t.Limit != ctx.Request.Limit { + // Add difference to remaining. + t.Remaining += ctx.Request.Limit - t.Limit + if t.Remaining < 0 { + t.Remaining = 0 } + t.Limit = ctx.Request.Limit + } + + rl := &RateLimitResponse{ + Status: t.Status, + Limit: ctx.Request.Limit, + Remaining: t.Remaining, + ResetTime: ctx.CacheItem.ExpireAt, + } - // Update the limit if it changed. - if t.Limit != r.Limit { - // Add difference to remaining. - t.Remaining += r.Limit - t.Limit - if t.Remaining < 0 { - t.Remaining = 0 + // If the duration config changed, update the new ExpireAt. 
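// Sketch (illustrative, not part of this change): worked example of the limit-change
// adjustment above. Hits already consumed carry over to the new limit: with the old
// limit at 100 and remaining at 20 (80 consumed), lowering the limit to 50 yields
// 20 + (50 - 100) = -30, clamped to 0, so the bucket stays exhausted until it resets.
func adjustRemaining(remaining, oldLimit, newLimit int64) int64 {
	remaining += newLimit - oldLimit
	if remaining < 0 {
		remaining = 0
	}
	return remaining
}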
+ if t.Duration != ctx.Request.Duration { + span := trace.SpanFromContext(ctx) + span.AddEvent("Duration changed") + expire := t.CreatedAt + ctx.Request.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err } - t.Limit = r.Limit } - rl := &RateLimitResponse{ - Status: t.Status, - Limit: r.Limit, - Remaining: t.Remaining, - ResetTime: item.ExpireAt, + // If our new duration means we are currently expired. + createdAt := *ctx.Request.CreatedAt + if expire <= createdAt { + // Renew item. + span.AddEvent("Limit has expired") + expire = createdAt + ctx.Request.Duration + t.CreatedAt = createdAt + t.Remaining = t.Limit } - // If the duration config changed, update the new ExpireAt. - if t.Duration != r.Duration { - span := trace.SpanFromContext(ctx) - span.AddEvent("Duration changed") - expire := t.CreatedAt + r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - } + ctx.CacheItem.ExpireAt = expire + t.Duration = ctx.Request.Duration + rl.ResetTime = expire + } - // If our new duration means we are currently expired. - createdAt := *r.CreatedAt - if expire <= createdAt { - // Renew item. - span.AddEvent("Limit has expired") - expire = createdAt + r.Duration - t.CreatedAt = createdAt - t.Remaining = t.Limit - } + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } - item.ExpireAt = expire - t.Duration = r.Duration - rl.ResetTime = expire - } + // Client is only interested in retrieving the current status or + // updating the rate limit config. + if ctx.Request.Hits == 0 { + return rl, nil + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() + // If we are already at the limit. + if rl.Remaining == 0 && ctx.Request.Hits > 0 { + trace.SpanFromContext(ctx).AddEvent("Already over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT + t.Status = rl.Status + return rl, nil + } - // Client is only interested in retrieving the current status or - // updating the rate limit config. - if r.Hits == 0 { - return rl, nil - } + // If requested hits takes the remainder. + if t.Remaining == ctx.Request.Hits { + trace.SpanFromContext(ctx).AddEvent("At the limit") + t.Remaining = 0 + rl.Remaining = 0 + return rl, nil + } - // If we are already at the limit. - if rl.Remaining == 0 && r.Hits > 0 { - trace.SpanFromContext(ctx).AddEvent("Already over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - t.Status = rl.Status - return rl, nil + // If requested is more than available, then return over the limit + // without updating the cache. + if ctx.Request.Hits > t.Remaining { + trace.SpanFromContext(ctx).AddEvent("Over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } - - // If requested hits takes the remainder. - if t.Remaining == r.Hits { - trace.SpanFromContext(ctx).AddEvent("At the limit") + rl.Status = Status_OVER_LIMIT + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + // DRAIN_OVER_LIMIT behavior drains the remaining counter. t.Remaining = 0 rl.Remaining = 0 - return rl, nil - } - - // If requested is more than available, then return over the limit - // without updating the cache. 
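// Sketch (illustrative, not part of this change): the two over-limit modes handled below,
// with hypothetical numbers. With remaining=5 and hits=8 the request is OVER_LIMIT either
// way; by default the 5 remaining hits survive, so a later request for 5 or fewer still
// succeeds, while Behavior_DRAIN_OVER_LIMIT empties the bucket until the window resets.
func overLimitExample(remaining, hits int64, drain bool) (Status, int64) {
	if hits > remaining {
		if drain {
			return Status_OVER_LIMIT, 0 // drain: empty the bucket
		}
		return Status_OVER_LIMIT, remaining // default: leave the remainder untouched
	}
	return Status_UNDER_LIMIT, remaining - hits
}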
- if r.Hits > t.Remaining { - trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - t.Remaining = 0 - rl.Remaining = 0 - } - return rl, nil } - - t.Remaining -= r.Hits - rl.Remaining = t.Remaining return rl, nil } - // Item is not found in cache or store, create new. - return tokenBucketNewItem(ctx, s, c, r, reqState) + t.Remaining -= ctx.Request.Hits + rl.Remaining = t.Remaining + return rl, nil } -// Called by tokenBucket() when adding a new item in the store. -func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitRequest, reqState RateLimitContext) (resp *RateLimitResponse, err error) { - createdAt := *r.CreatedAt - expire := createdAt + r.Duration +// initTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. +func initTokenBucketItem(ctx rateContext) (resp *RateLimitResponse, err error) { + createdAt := *ctx.Request.CreatedAt + expire := createdAt + ctx.Request.Duration - t := &TokenBucketItem{ - Limit: r.Limit, - Duration: r.Duration, - Remaining: r.Limit - r.Hits, + t := TokenBucketItem{ + Limit: ctx.Request.Limit, + Duration: ctx.Request.Duration, + Remaining: ctx.Request.Limit - ctx.Request.Hits, CreatedAt: createdAt, } // Add a new rate limit to the cache. - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } } - item := &CacheItem{ - Algorithm: Algorithm_TOKEN_BUCKET, - Key: r.HashKey(), - Value: t, - ExpireAt: expire, - } - rl := &RateLimitResponse{ Status: Status_UNDER_LIMIT, - Limit: r.Limit, + Limit: ctx.Request.Limit, Remaining: t.Remaining, ResetTime: expire, } // Client could be requesting that we always return OVER_LIMIT. 
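// Sketch (illustrative, not part of this change): the caller-side shape of the case the
// comment above describes. A hypothetical request where Hits exceeds Limit forces
// OVER_LIMIT even on a freshly created bucket:
var alwaysOverLimit = &RateLimitRequest{
	Name:      "example",
	UniqueKey: "account:123",
	Algorithm: Algorithm_TOKEN_BUCKET,
	Limit:     10,
	Hits:      11, // Hits > Limit, so a new bucket is created already over the limit
	Duration:  60000,
}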
- if r.Hits > r.Limit { + if ctx.Request.Hits > ctx.Request.Limit { trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT - rl.Remaining = r.Limit - t.Remaining = r.Limit + rl.Remaining = ctx.Request.Limit + t.Remaining = ctx.Request.Limit } - c.Add(item) + // If the cache item already exists, update it + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = expire + in, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initTokenBucketItem(ctx) + } + *in = t + ctx.CacheItem.mutex.Unlock() + } else { + // else create a new cache item and add it to the cache + ctx.CacheItem = &CacheItem{ + Algorithm: Algorithm_TOKEN_BUCKET, + Key: ctx.Request.HashKey(), + Value: &t, + ExpireAt: expire, + } + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + return rl, errAlreadyExistsInCache + } + } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return rl, nil } // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket -func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitRequest, reqState RateLimitContext) (resp *RateLimitResponse, err error) { - if r.Burst == 0 { - r.Burst = r.Limit - } +func leakyBucket(ctx rateContext) (resp *RateLimitResponse, err error) { + leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) + defer leakyBucketTimer.ObserveDuration() + var ok bool - createdAt := *r.CreatedAt + if ctx.Request.Burst == 0 { + ctx.Request.Burst = ctx.Request.Limit + } // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. - if item, ok = s.Get(ctx, r); ok { - c.Add(item) + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + if ctx.Store != nil && !ok { + // Cache missed, check our store for the item. + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "leakyBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "leakyBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. 
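// Sketch (illustrative, not part of this change): the Burst default above means a zero
// Burst becomes Limit, so a full bucket can absorb Limit hits at once. A hypothetical
// request that caps spikes below the average rate instead:
var lowBurst = &RateLimitRequest{
	Name:      "example",
	UniqueKey: "account:123",
	Algorithm: Algorithm_LEAKY_BUCKET,
	Limit:     100,   // 100 hits per Duration on average
	Burst:     10,    // but never more than 10 hits in a single burst
	Duration:  60000, // one minute, in milliseconds
}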
+ if !ok || ctx.CacheItem.IsExpired() { + rl, err := initLeakyBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - - b, ok := item.Value.(*LeakyBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - c.Remove(hashKey) - - if s != nil { - s.Remove(ctx, hashKey) - } + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Item found in cache or store. + t, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - b.Remaining = float64(r.Burst) + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initLeakyBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return leakyBucket(ctx) } + return rl, err + } - // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) - } - b.Burst = r.Burst - } + defer ctx.CacheItem.mutex.Unlock() - b.Limit = r.Limit - b.Duration = r.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = float64(ctx.Request.Burst) + } - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + // Update burst, limit and duration if they changed + if t.Burst != ctx.Request.Burst { + if ctx.Request.Burst > int64(t.Remaining) { + t.Remaining = float64(ctx.Request.Burst) + } + t.Burst = ctx.Request.Burst + } - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) - if err != nil { - return nil, err - } + t.Limit = ctx.Request.Limit + t.Duration = ctx.Request.Duration - // Calculate the rate using the entire duration of the gregorian interval - // IE: Minute = 60,000 milliseconds, etc.. etc.. - rate = float64(d) / float64(r.Limit) - // Update the duration to be the end of the gregorian interval - duration = expire - (n.UnixNano() / 1000000) - } + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err + } + n := clock.Now() + expire, err := GregorianExpiration(n, ctx.Request.Duration) + if err != nil { + return nil, err } - // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt - leak := float64(elapsed) / rate + // Calculate the rate using the entire duration of the gregorian interval + // IE: Minute = 60,000 milliseconds, etc.. etc.. 
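// Sketch (illustrative, not part of this change): what the Gregorian branch below
// computes, assuming the GregorianDuration/GregorianExpiration helpers and the
// GregorianMinutes constant already in this package. The leak rate spans the whole
// calendar interval, and the bucket expires at the interval boundary rather than
// Duration milliseconds from now.
func gregorianWindow(limit int64) (rateMs float64, expireAt int64, err error) {
	n := clock.Now()
	d, err := GregorianDuration(n, GregorianMinutes) // interval length, e.g. 60,000ms for a minute
	if err != nil {
		return 0, 0, err
	}
	expireAt, err = GregorianExpiration(n, GregorianMinutes) // end of the current calendar minute
	if err != nil {
		return 0, 0, err
	}
	return float64(d) / float64(limit), expireAt, nil
}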
+ rate = float64(d) / float64(ctx.Request.Limit) + // Update the duration to be the end of the gregorian interval + duration = expire - (n.UnixNano() / 1000000) + } - if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt - } + createdAt := *ctx.Request.CreatedAt + if ctx.Request.Hits != 0 { + ctx.CacheItem.ExpireAt = createdAt + duration + } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) - } + // Calculate how much leaked out of the bucket since the last time we leaked a hit + elapsed := createdAt - t.UpdatedAt + leak := float64(elapsed) / rate - rl := &RateLimitResponse{ - Limit: b.Limit, - Remaining: int64(b.Remaining), - Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), - } + if int64(leak) > 0 { + t.Remaining += leak + t.UpdatedAt = createdAt + } - // TODO: Feature missing: check for Duration change between item/request. + if int64(t.Remaining) > t.Burst { + t.Remaining = float64(t.Burst) + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() - } + rl := &RateLimitResponse{ + Limit: t.Limit, + Remaining: int64(t.Remaining), + Status: Status_UNDER_LIMIT, + ResetTime: createdAt + (t.Limit-int64(t.Remaining))*int64(rate), + } - // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - return rl, nil - } + // TODO: Feature missing: check for Duration change between item/request. - // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) - return rl, nil - } + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } - // If requested is more than available, then return over the limit - // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT + // If we are already at the limit + if int64(t.Remaining) == 0 && ctx.Request.Hits > 0 { + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + return rl, nil + } - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 - rl.Remaining = 0 - } + // If requested hits takes the remainder + if int64(t.Remaining) == ctx.Request.Hits { + t.Remaining = 0 + rl.Remaining = int64(t.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } - return rl, nil + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. + if ctx.Request.Hits > int64(t.Remaining) { + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT - // Client is only interested in retrieving the current status - if r.Hits == 0 { - return rl, nil + // DRAIN_OVER_LIMIT behavior drains the remaining counter. 
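// Sketch (illustrative, not part of this change): worked example of the leak arithmetic
// above, with hypothetical numbers. Limit=100 per 60,000ms gives rate=600ms per hit;
// ResetTime answers "when will the bucket be full again?".
func leakExample(createdAt int64) (leaked float64, resetTime int64) {
	const rate = 600.0     // ms per hit: 60,000ms window / 100 hit limit
	const elapsed = 3000.0 // ms since the bucket last leaked
	const limit, remaining = 100, 40
	leaked = elapsed / rate // 5 hits of capacity have leaked back in
	// 60 hits must leak back at 600ms each, so the bucket is full 36s after createdAt.
	resetTime = createdAt + (limit-remaining)*int64(rate)
	return leaked, resetTime
}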
+ if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + t.Remaining = 0 + rl.Remaining = 0 } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Client is only interested in retrieving the current status + if ctx.Request.Hits == 0 { + return rl, nil + } + + t.Remaining -= float64(ctx.Request.Hits) + rl.Remaining = int64(t.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } // Called by leakyBucket() when adding a new item in the store. -func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitRequest, reqState RateLimitContext) (resp *RateLimitResponse, err error) { - createdAt := *r.CreatedAt - duration := r.Duration - rate := float64(duration) / float64(r.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { +func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResponse, err error) { + createdAt := *ctx.Request.CreatedAt + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } @@ -444,23 +475,23 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReque // Create a new leaky bucket b := LeakyBucketItem{ - Remaining: float64(r.Burst - r.Hits), - Limit: r.Limit, + Remaining: float64(ctx.Request.Burst - ctx.Request.Hits), + Limit: ctx.Request.Limit, Duration: duration, UpdatedAt: createdAt, - Burst: r.Burst, + Burst: ctx.Request.Burst, } rl := RateLimitResponse{ Status: Status_UNDER_LIMIT, Limit: b.Limit, - Remaining: r.Burst - r.Hits, - ResetTime: createdAt + (b.Limit-(r.Burst-r.Hits))*int64(rate), + Remaining: ctx.Request.Burst - ctx.Request.Hits, + ResetTime: createdAt + (b.Limit-(ctx.Request.Burst-ctx.Request.Hits))*int64(rate), } // Client could be requesting that we start with the bucket OVER_LIMIT - if r.Hits > r.Burst { - if reqState.IsOwner { + if ctx.Request.Hits > ctx.Request.Burst { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -469,17 +500,33 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReque b.Remaining = 0 } - item := &CacheItem{ - ExpireAt: createdAt + duration, - Algorithm: r.Algorithm, - Key: r.HashKey(), - Value: &b, + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = createdAt + duration + in, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initLeakyBucketItem(ctx) + } + *in = b + ctx.CacheItem.mutex.Unlock() + } else { + ctx.CacheItem = &CacheItem{ + ExpireAt: createdAt + duration, + Algorithm: ctx.Request.Algorithm, + Key: ctx.Request.HashKey(), + Value: &b, + } + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + return nil, errAlreadyExistsInCache + } } - c.Add(item) - - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return &rl, nil diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index 644fd22..aece7ef 100644 --- 
a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -1,160 +1,182 @@ package gubernator_test import ( - "strconv" + "math/rand" "sync" "testing" "time" "github.com/gubernator-io/gubernator/v3" "github.com/mailgun/holster/v4/clock" + "github.com/stretchr/testify/require" ) func BenchmarkCache(b *testing.B) { + const defaultNumKeys = 8192 testCases := []struct { Name string - NewTestCache func() gubernator.Cache + NewTestCache func() (gubernator.Cache, error) LockRequired bool }{ { Name: "LRUCache", - NewTestCache: func() gubernator.Cache { - return gubernator.NewLRUCache(0) + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewLRUCache(0), nil }, LockRequired: true, }, + { + Name: "OtterCache", + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewOtterCache(0) + }, + LockRequired: false, + }, } for _, testCase := range testCases { b.Run(testCase.Name, func(b *testing.B) { b.Run("Sequential reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := &gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) - _, _ = cache.GetItem(key) + index := int(rand.Uint32() & uint32(mask)) + _, _ = cache.GetItem(keys[index&mask]) } }) b.Run("Sequential writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { + index := int(rand.Uint32() & uint32(mask)) item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: keys[index&mask], + Value: "value:" + keys[index&mask], ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } }) b.Run("Concurrent reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := &gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() - key := strconv.Itoa(i) _, _ = cache.GetItem(key) - wg.Done() } } else { - task = func(i int) { - key := strconv.Itoa(i) + task = func(key string) { _, _ = cache.GetItem(key) - wg.Done() } } b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } + mask := len(keys) - 1 + + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) - wg.Wait() }) b.Run("Concurrent writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := 
GenerateRandomKeys(defaultNumKeys) - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) - wg.Done() + cache.AddIfNotPresent(item) } } else { - task = func(i int) { + task = func(key string) { item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) - wg.Done() + cache.AddIfNotPresent(item) } } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } - - wg.Wait() + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) }) }) } } + +func GenerateRandomKeys(size int) []string { + keys := make([]string, 0, size) + for i := 0; i < size; i++ { + keys = append(keys, gubernator.RandomString(20)) + } + return keys +} diff --git a/benchmark_test.go b/benchmark_test.go index 6280349..4cac6dc 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,12 +18,13 @@ package gubernator_test import ( "context" + "fmt" + "runtime" "testing" guber "github.com/gubernator-io/gubernator/v3" "github.com/gubernator-io/gubernator/v3/cluster" "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" "github.com/stretchr/testify/require" ) @@ -57,9 +58,9 @@ func BenchmarkServer(b *testing.B) { require.NoError(b, err, "Error in conf.SetDefaults") createdAt := epochMillis(clock.Now()) d := cluster.GetRandomDaemon(cluster.DataCenterNone) - client := d.MustClient().(guber.PeerClient) b.Run("GetPeerRateLimit", func(b *testing.B) { + client := d.MustClient().(guber.PeerClient) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -146,15 +147,27 @@ func BenchmarkServer(b *testing.B) { }) b.Run("Thundering herd", func(b *testing.B) { - client := cluster.GetRandomDaemon(cluster.DataCenterNone).MustClient() require.NoError(b, err, "Error in guber.DialV1Server") + var clients []guber.Client + + // Create a client for each CPU on the system. This should allow us to simulate the + // maximum contention possible for this system. + for i := 0; i < runtime.NumCPU(); i++ { + client, err := guber.NewClient(guber.WithNoTLS(d.Listener.Addr().String())) + require.NoError(b, err) + clients = append(clients, client) + } b.ResetTimer() - fan := syncutil.NewFanOut(100) + mask := len(clients) - 1 - for n := 0; n < b.N; n++ { - fan.Run(func(o interface{}) error { + var idx int + b.RunParallel(func(pb *testing.PB) { + client := clients[idx&mask] + idx++ + + for pb.Next() { var resp guber.CheckRateLimitsResponse - err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + err = client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ Requests: []*guber.RateLimitRequest{ { Name: b.Name(), @@ -166,10 +179,9 @@ func BenchmarkServer(b *testing.B) { }, }, &resp) if err != nil { - b.Errorf("Error in client.GetRateLimits: %s", err) + fmt.Printf("%s\n", err.Error()) } - return nil - }, nil) - } + } + }) }) } diff --git a/cache.go b/cache.go index 0fd431a..3d95dfb 100644 --- a/cache.go +++ b/cache.go @@ -16,31 +16,50 @@ limitations under the License. 
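// Sketch (illustrative, not part of this change): usage of the two cache constructors the
// benchmarks above compare; both satisfy the Cache interface defined below. The helper
// name is hypothetical, and the shape follows the new (Cache, error) factory signature:
func newCacheByProvider(provider string, maxSize int) (gubernator.Cache, error) {
	if provider == "otter" {
		return gubernator.NewOtterCache(maxSize) // may return an error
	}
	return gubernator.NewLRUCache(maxSize), nil // the LRU constructor cannot fail
}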
 package gubernator
 
+import (
+	"sync"
+
+	"github.com/prometheus/client_golang/prometheus"
+
+	"github.com/mailgun/holster/v4/clock"
+)
+
 type Cache interface {
-	Add(item *CacheItem) bool
-	UpdateExpiration(key string, expireAt int64) bool
+	// AddIfNotPresent adds the item to the cache if it doesn't already exist.
+	// Returns true if the item was added, false if the item already exists.
+	AddIfNotPresent(item *CacheItem) bool
 	GetItem(key string) (value *CacheItem, ok bool)
 	Each() chan *CacheItem
 	Remove(key string)
 	Size() int64
+	Stats() CacheStats
 	Close() error
 }
 
+// CacheItem is 64 bytes in size once aligned.
+// Since TokenBucketItem and LeakyBucketItem are both 40 bytes in size, a CacheItem with
+// the Value attached takes up 64 + 40 = 104 bytes of space, not counting the size of the key.
 type CacheItem struct {
-	Algorithm Algorithm
-	Key       string
-	Value     interface{}
+	mutex sync.Mutex  // 8 bytes
+	Key   string      // 16 bytes
+	Value interface{} // 16 bytes
 	// Timestamp when rate limit expires in epoch milliseconds.
-	ExpireAt int64
+	ExpireAt int64 // 8 bytes
 	// Timestamp when the cache should invalidate this rate limit. This is useful when used in conjunction with
 	// a persistent store to ensure our node has the most up to date info from the store. Ignored if set to `0`
 	// It is set by the persistent store implementation to indicate when the node should query the persistent store
 	// for the latest rate limit data.
-	InvalidAt int64
+	InvalidAt int64     // 8 bytes
+	Algorithm Algorithm // 4 bytes
+	// 4 bytes of padding
 }
 
 func (item *CacheItem) IsExpired() bool {
+	// TODO(thrawn01): Eliminate the need for this mutex lock
+	item.mutex.Lock()
+	defer item.mutex.Unlock()
+
 	now := MillisecondNow()
 
 	// If the entry is invalidated
@@ -55,3 +74,91 @@ func (item *CacheItem) IsExpired() bool {
 
 	return false
 }
+
+func (item *CacheItem) Copy(from *CacheItem) {
+	item.mutex.Lock()
+	defer item.mutex.Unlock()
+
+	item.InvalidAt = from.InvalidAt
+	item.Algorithm = from.Algorithm
+	item.ExpireAt = from.ExpireAt
+	item.Value = from.Value
+	item.Key = from.Key
+}
+
+// MillisecondNow returns unix epoch in milliseconds
+func MillisecondNow() int64 {
+	return clock.Now().UnixNano() / 1000000
+}
+
+type CacheStats struct {
+	Size               int64
+	Hit                int64
+	Miss               int64
+	UnexpiredEvictions int64
+}
+
+// CacheCollector provides a prometheus metrics collector for Cache implementations.
+// Register only one collector; add one or more caches to this collector.
+type CacheCollector struct {
+	caches                   []Cache
+	metricSize               prometheus.Gauge
+	metricAccess             *prometheus.CounterVec
+	metricUnexpiredEvictions prometheus.Counter
+}
+
+func NewCacheCollector() *CacheCollector {
+	return &CacheCollector{
+		caches: []Cache{},
+		metricSize: prometheus.NewGauge(prometheus.GaugeOpts{
+			Name: "gubernator_cache_size",
+			Help: "The number of items in the cache which holds the rate limits.",
+		}),
+		metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{
+			Name: "gubernator_cache_access_count",
+			Help: "Cache access counts. Label \"type\" = hit|miss.",
+		}, []string{"type"}),
+		metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{
+			Name: "gubernator_unexpired_evictions_count",
+			Help: "Count the number of cache items which were evicted while unexpired.",
+		}),
+	}
+}
+
+var _ prometheus.Collector = &CacheCollector{}
+
+// AddCache adds a Cache object to be tracked by the collector.
+func (c *CacheCollector) AddCache(cache Cache) { + c.caches = append(c.caches, cache) +} + +// Describe fetches prometheus metrics to be registered +func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { + c.metricSize.Describe(ch) + c.metricAccess.Describe(ch) + c.metricUnexpiredEvictions.Describe(ch) +} + +// Collect fetches metric counts and gauges from the cache +func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { + stats := c.getStats() + c.metricSize.Set(float64(stats.Size)) + c.metricSize.Collect(ch) + c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) + c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) + c.metricAccess.Collect(ch) + c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) + c.metricUnexpiredEvictions.Collect(ch) +} + +func (c *CacheCollector) getStats() CacheStats { + var total CacheStats + for _, cache := range c.caches { + stats := cache.Stats() + total.Hit += stats.Hit + total.Miss += stats.Miss + total.Size += stats.Size + total.UnexpiredEvictions += stats.UnexpiredEvictions + } + return total +} diff --git a/cache_manager.go b/cache_manager.go new file mode 100644 index 0000000..d24e252 --- /dev/null +++ b/cache_manager.go @@ -0,0 +1,172 @@ +/* +Copyright 2024 Derrick J. Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package gubernator
+
+import (
+	"context"
+	"sync"
+
+	"github.com/pkg/errors"
+)
+
+type CacheManager interface {
+	CheckRateLimit(context.Context, *RateLimitRequest, RateLimitContext) (*RateLimitResponse, error)
+	GetCacheItem(context.Context, string) (*CacheItem, bool, error)
+	AddCacheItem(context.Context, string, *CacheItem) error
+	Store(ctx context.Context) error
+	Load(context.Context) error
+	Close() error
+}
+
+type cacheManager struct {
+	conf  Config
+	cache Cache
+}
+
+// NewCacheManager creates a new instance of the CacheManager interface using
+// the cache returned by Config.CacheFactory
+func NewCacheManager(conf Config) (CacheManager, error) {
+	cache, err := conf.CacheFactory(conf.CacheSize)
+	if err != nil {
+		return nil, err
+	}
+	return &cacheManager{
+		cache: cache,
+		conf:  conf,
+	}, nil
+}
+
+// CheckRateLimit fetches the item from the cache if it exists, and performs the appropriate rate limit calculation
+func (m *cacheManager) CheckRateLimit(ctx context.Context, req *RateLimitRequest, state RateLimitContext) (*RateLimitResponse, error) {
+	var rlResponse *RateLimitResponse
+	var err error
+
+	switch req.Algorithm {
+	case Algorithm_TOKEN_BUCKET:
+		rlResponse, err = tokenBucket(rateContext{
+			Store:    m.conf.Store,
+			Cache:    m.cache,
+			ReqState: state,
+			Request:  req,
+			Context:  ctx,
+		})
+		if err != nil {
+			msg := "Error in tokenBucket"
+			countError(err, msg)
+		}
+
+	case Algorithm_LEAKY_BUCKET:
+		rlResponse, err = leakyBucket(rateContext{
+			Store:    m.conf.Store,
+			Cache:    m.cache,
+			ReqState: state,
+			Request:  req,
+			Context:  ctx,
+		})
+		if err != nil {
+			msg := "Error in leakyBucket"
+			countError(err, msg)
+		}
+
+	default:
+		err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm)
+	}
+
+	return rlResponse, err
+}
+
+// Store saves every cache item into persistent storage provided via Config.Loader
+func (m *cacheManager) Store(ctx context.Context) error {
+	out := make(chan *CacheItem, 500)
+	var wg sync.WaitGroup
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for item := range m.cache.Each() {
+			select {
+			case out <- item:
+
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	go func() {
+		wg.Wait()
+		close(out)
+	}()
+
+	if ctx.Err() != nil {
+		return ctx.Err()
+	}
+
+	if err := m.conf.Loader.Save(out); err != nil {
+		return errors.Wrap(err, "while calling m.conf.Loader.Save()")
+	}
+	return nil
+}
+
+// Close closes the cache manager
+func (m *cacheManager) Close() error {
+	return m.cache.Close()
+}
+
+// Load cache items from persistent storage provided via Config.Loader
+func (m *cacheManager) Load(ctx context.Context) error {
+	ch, err := m.conf.Loader.Load()
+	if err != nil {
+		return errors.Wrap(err, "Error in loader.Load")
+	}
+
+	for {
+		var item *CacheItem
+		var ok bool
+
+		select {
+		case item, ok = <-ch:
+			if !ok {
+				return nil
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	retry:
+		if !m.cache.AddIfNotPresent(item) {
+			cItem, ok := m.cache.GetItem(item.Key)
+			if !ok {
+				goto retry
+			}
+			cItem.Copy(item)
+		}
+	}
+}
+
+// GetCacheItem returns an item from the cache
+func (m *cacheManager) GetCacheItem(_ context.Context, key string) (*CacheItem, bool, error) {
+	item, ok := m.cache.GetItem(key)
+	return item, ok, nil
+}
+
+// AddCacheItem adds an item to the cache. The CacheItem.Key should be set correctly, else the item
+// will not be added to the cache correctly.
+func (m *cacheManager) AddCacheItem(_ context.Context, _ string, item *CacheItem) error { + _ = m.cache.AddIfNotPresent(item) + return nil +} diff --git a/workers_test.go b/cache_manager_test.go similarity index 81% rename from workers_test.go rename to cache_manager_test.go index b3c6b9f..7994b08 100644 --- a/workers_test.go +++ b/cache_manager_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestGubernatorPool(t *testing.T) { +func TestCacheManager(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -43,7 +43,7 @@ func TestGubernatorPool(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { // Setup mock data. const NumCacheItems = 100 - cacheItems := []*guber.CacheItem{} + var cacheItems []*guber.CacheItem for i := 0; i < NumCacheItems; i++ { cacheItems = append(cacheItems, &guber.CacheItem{ Key: fmt.Sprintf("Foobar%04d", i), @@ -55,15 +55,16 @@ func TestGubernatorPool(t *testing.T) { t.Run("Load()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } assert.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + manager, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. fakeLoadCh := make(chan *guber.CacheItem, NumCacheItems) @@ -75,35 +76,36 @@ func TestGubernatorPool(t *testing.T) { // Mock Cache. for _, item := range cacheItems { - mockCache.On("Add", item).Once().Return(false) + mockCache.On("AddIfNotPresent", item).Once().Return(true) } // Call code. - err := chp.Load(ctx) + err = manager.Load(ctx) // Verify. - require.NoError(t, err, "Error in chp.Load") + require.NoError(t, err, "Error in manager.Load") }) t.Run("Store()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } require.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + chp, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. mockLoader.On("Save", mock.Anything).Once().Return(nil). Run(func(args mock.Arguments) { // Verify items sent over the channel passed to Save(). saveCh := args.Get(0).(chan *guber.CacheItem) - savedItems := []*guber.CacheItem{} + var savedItems []*guber.CacheItem for item := range saveCh { savedItems = append(savedItems, item) } @@ -124,7 +126,7 @@ func TestGubernatorPool(t *testing.T) { mockCache.On("Each").Times(testCase.workers).Return(eachCh) // Call code. - err := chp.Store(ctx) + err = chp.Store(ctx) // Verify. 
require.NoError(t, err, "Error in chp.Store") diff --git a/client.go b/client.go index ca58a8d..c62fe86 100644 --- a/client.go +++ b/client.go @@ -93,13 +93,13 @@ func NewPeerClient(opts ClientOptions) PeerClient { func (c *client) CheckRateLimits(ctx context.Context, req *CheckRateLimitsRequest, resp *CheckRateLimitsResponse) error { payload, err := proto.Marshal(req) if err != nil { - return duh.NewClientError(fmt.Errorf("while marshaling request payload: %w", err), nil) + return duh.NewClientError("while marshaling request payload: %w", err, nil) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.opts.Endpoint, RPCRateLimitCheck), bytes.NewReader(payload)) if err != nil { - return duh.NewClientError(err, nil) + return duh.NewClientError("", err, nil) } r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) @@ -109,13 +109,13 @@ func (c *client) CheckRateLimits(ctx context.Context, req *CheckRateLimitsReques func (c *client) HealthCheck(ctx context.Context, resp *HealthCheckResponse) error { payload, err := proto.Marshal(&HealthCheckRequest{}) if err != nil { - return duh.NewClientError(fmt.Errorf("while marshaling request payload: %w", err), nil) + return duh.NewClientError("while marshaling request payload: %w", err, nil) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.opts.Endpoint, RPCHealthCheck), bytes.NewReader(payload)) if err != nil { - return duh.NewClientError(err, nil) + return duh.NewClientError("", err, nil) } r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) @@ -125,13 +125,13 @@ func (c *client) HealthCheck(ctx context.Context, resp *HealthCheckResponse) err func (c *client) Forward(ctx context.Context, req *ForwardRequest, resp *ForwardResponse) error { payload, err := proto.Marshal(req) if err != nil { - return duh.NewClientError(fmt.Errorf("while marshaling request payload: %w", err), nil) + return duh.NewClientError("while marshaling request payload: %w", err, nil) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.opts.Endpoint, RPCPeerForward), bytes.NewReader(payload)) if err != nil { - return duh.NewClientError(err, nil) + return duh.NewClientError("", err, nil) } c.prop.Inject(ctx, propagation.HeaderCarrier(r.Header)) @@ -142,12 +142,12 @@ func (c *client) Forward(ctx context.Context, req *ForwardRequest, resp *Forward func (c *client) Update(ctx context.Context, req *UpdateRequest) error { payload, err := proto.Marshal(req) if err != nil { - return duh.NewClientError(fmt.Errorf("while marshaling request payload: %w", err), nil) + return duh.NewClientError("while marshaling request payload: %w", err, nil) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.opts.Endpoint, RPCPeerUpdate), bytes.NewReader(payload)) if err != nil { - return duh.NewClientError(err, nil) + return duh.NewClientError("", err, nil) } r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) @@ -227,10 +227,10 @@ func RandomPeer(peers []PeerInfo) PeerInfo { // RandomString returns a random alpha string of 'n' length func RandomString(n int) string { const alphanumeric = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - var bytes = make([]byte, n) - _, _ = crand.Read(bytes) - for i, b := range bytes { - bytes[i] = alphanumeric[b%byte(len(alphanumeric))] + var buf = make([]byte, n) + _, _ = crand.Read(buf) + for i, b := range buf { + buf[i] = alphanumeric[b%byte(len(alphanumeric))] } - return string(bytes) + return string(buf) } diff --git 
a/cluster/cluster.go b/cluster/cluster.go index a5cf998..028c3db 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -70,8 +70,9 @@ func GetRandomDaemon(dc string) *gubernator.Daemon { } if len(local) == 0 { - panic(fmt.Sprintf("failed to find random peer for dc '%s'", dc)) + panic(fmt.Sprintf("failed to find random daemon for dc '%s'", dc)) } + return local[rand.Intn(len(local))] } @@ -180,6 +181,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { HTTPListenAddress: peer.HTTPAddress, AdvertiseAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, + CacheProvider: "otter", Behaviors: gubernator.BehaviorConfig{ // Suitable for testing but not production GlobalSyncWait: clock.Millisecond * 50, diff --git a/cmd/gubernator/main_test.go b/cmd/gubernator/main_test.go index 65337da..77411ce 100644 --- a/cmd/gubernator/main_test.go +++ b/cmd/gubernator/main_test.go @@ -77,9 +77,9 @@ func TestCLI(t *testing.T) { time.Sleep(time.Second * 1) err = c.Process.Signal(syscall.SIGTERM) - require.NoError(t, err) <-waitCh + require.NoError(t, err, out.String()) assert.Contains(t, out.String(), tt.contains) }) } diff --git a/config.go b/config.go index 13e1baa..1d6044c 100644 --- a/config.go +++ b/config.go @@ -79,7 +79,7 @@ type Config struct { Behaviors BehaviorConfig // (Optional) The cache implementation - CacheFactory func(maxSize int) Cache + CacheFactory func(maxSize int) (Cache, error) // (Optional) A persistent store implementation. Allows the implementor the ability to store the rate limits this // instance of gubernator owns. It's up to the implementor to decide what rate limits to persist. @@ -134,8 +134,8 @@ func (c *Config) SetDefaults() error { setter.SetDefault(&c.Logger, logrus.New().WithField("category", "gubernator")) if c.CacheFactory == nil { - c.CacheFactory = func(maxSize int) Cache { - return NewLRUCache(maxSize) + c.CacheFactory = func(maxSize int) (Cache, error) { + return NewOtterCache(maxSize) } } @@ -260,6 +260,9 @@ type DaemonConfig struct { // (Optional) TraceLevel sets the tracing level, this controls the number of spans included in a single trace. // Valid options are (tracing.InfoLevel, tracing.DebugLevel) Defaults to tracing.InfoLevel TraceLevel tracing.Level + + // (Optional) CacheProvider specifies which cache implementation to store rate limits in + CacheProvider string } func (d *DaemonConfig) ClientTLS() *tls.Config { @@ -427,7 +430,10 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi setter.SetDefault(&conf.DNSPoolConf.ResolvConf, os.Getenv("GUBER_RESOLV_CONF"), "/etc/resolv.conf") setter.SetDefault(&conf.DNSPoolConf.OwnAddress, conf.AdvertiseAddress) + setter.SetDefault(&conf.CacheProvider, os.Getenv("GUBER_CACHE_PROVIDER"), "default-lru") + // PeerPicker Config + // TODO: Deprecated: Remove in GUBER_PEER_PICKER in v3 if pp := os.Getenv("GUBER_PEER_PICKER"); pp != "" { var replicas int var hash string diff --git a/daemon.go b/daemon.go index 59c9e63..321f8b0 100644 --- a/daemon.go +++ b/daemon.go @@ -72,11 +72,31 @@ func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { func (d *Daemon) Start(ctx context.Context) error { var err error + // The cache for storing rate limits. registry := prometheus.NewRegistry() + cacheCollector := NewCacheCollector() + if err := registry.Register(cacheCollector); err != nil { + return errors.Wrap(err, "during call to promRegister.Register()") + } - // The LRU cache for storing rate limits. 
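// Sketch (illustrative, not part of this change): the programmatic equivalent of the
// GUBER_CACHE_PROVIDER environment variable for embedded daemons, using only fields
// visible in this change; the listen address is hypothetical.
var exampleDaemonConf = DaemonConfig{
	HTTPListenAddress: "127.0.0.1:1050",
	CacheProvider:     "otter", // or "default-lru"; anything else fails when Start() builds the cache
}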
-	cacheCollector := NewLRUCacheCollector()
-	registry.MustRegister(cacheCollector)
+	cacheFactory := func(maxSize int) (Cache, error) {
+		switch d.conf.CacheProvider {
+		case "default-lru":
+			cache := NewLRUCache(maxSize)
+			cacheCollector.AddCache(cache)
+			return cache, nil
+		case "otter", "":
+			cache, err := NewOtterCache(maxSize)
+			if err != nil {
+				return nil, err
+			}
+			cacheCollector.AddCache(cache)
+			return cache, nil
+		default:
+			return nil, errors.Errorf("'GUBER_CACHE_PROVIDER=%s' is invalid; "+
+				"choices are ['otter', 'default-lru']", d.conf.CacheProvider)
+		}
+	}
 
 	if err := SetupTLS(d.conf.TLS); err != nil {
 		return err
@@ -86,20 +106,16 @@ func (d *Daemon) Start(ctx context.Context) error {
 		PeerClientFactory: func(info PeerInfo) PeerClient {
 			return NewPeerClient(WithPeerInfo(info))
 		},
-		CacheFactory: func(maxSize int) Cache {
-			cache := NewLRUCache(maxSize)
-			cacheCollector.AddCache(cache)
-			return cache
-		},
-		DataCenter:  d.conf.DataCenter,
-		InstanceID:  d.conf.InstanceID,
-		CacheSize:   d.conf.CacheSize,
-		Behaviors:   d.conf.Behaviors,
-		Workers:     d.conf.Workers,
-		LocalPicker: d.conf.Picker,
-		Loader:      d.conf.Loader,
-		Store:       d.conf.Store,
-		Logger:      d.log,
+		CacheFactory: cacheFactory,
+		DataCenter:   d.conf.DataCenter,
+		InstanceID:   d.conf.InstanceID,
+		CacheSize:    d.conf.CacheSize,
+		Behaviors:    d.conf.Behaviors,
+		Workers:      d.conf.Workers,
+		LocalPicker:  d.conf.Picker,
+		Loader:       d.conf.Loader,
+		Store:        d.conf.Store,
+		Logger:       d.log,
 	})
 	if err != nil {
 		return errors.Wrap(err, "while creating new gubernator service")
diff --git a/example.conf b/example.conf
index 3396a40..254db8d 100644
--- a/example.conf
+++ b/example.conf
@@ -264,4 +264,13 @@ GUBER_INSTANCE_ID=
 ############################
 # OTEL_EXPORTER_OTLP_PROTOCOL=otlp
 # OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io
-# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team=
\ No newline at end of file
+# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team=
+
+############################
+# Cache Providers
+############################
+#
+# Select the cache provider; available options are 'default-lru' and 'otter'
+# default-lru - A built-in LRU implementation which uses a mutex
+# otter       - A lock-less cache implementation based on the S3-FIFO algorithm (https://maypok86.github.io/otter/)
+# GUBER_CACHE_PROVIDER=default-lru
diff --git a/functional_test.go b/functional_test.go
index 6cc8ec7..5234ba5 100644
--- a/functional_test.go
+++ b/functional_test.go
@@ -750,7 +750,12 @@ func TestLeakyBucketGregorian(t *testing.T) {
 		},
 	}
 
+	// Move the test clock back to the top of the current hour.
now := clock.Now() + trunc := now.Truncate(time.Hour) + trunc = now.Add(now.Sub(trunc)) + clock.Advance(now.Sub(trunc)) + for _, test := range tests { t.Run(test.Name, func(t *testing.T) { var resp guber.CheckRateLimitsResponse @@ -2280,8 +2285,14 @@ func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { if err != nil { return err } - ggql := metrics["gubernator_global_queue_length"] - gsql := metrics["gubernator_global_send_queue_length"] + ggql, ok := metrics["gubernator_global_queue_length"] + if !ok { + return errors.New("gubernator_global_queue_length not found") + } + gsql, ok := metrics["gubernator_global_send_queue_length"] + if !ok { + return errors.New("gubernator_global_send_queue_length not found") + } if ggql.Value == 0 && gsql.Value == 0 { return nil @@ -2322,6 +2333,8 @@ func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[strin } func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitRequest, expectStatus guber.Status, expectRemaining int64) { + t.Helper() + if req.Hits != 0 { t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID) } diff --git a/global.go b/global.go index 3329200..174f068 100644 --- a/global.go +++ b/global.go @@ -54,7 +54,7 @@ func newGlobalManager(conf BehaviorConfig, instance *Service) *globalManager { }), metricGlobalSendQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gubernator_global_send_queue_length", - Help: "The count of requests queued up for global broadcast. This is only used for GetRateLimit requests using global behavior.", + Help: "The count of requests queued up for global broadcast. This is only used for CheckRateLimit requests using global behavior.", }), metricBroadcastDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_broadcast_duration", @@ -63,7 +63,7 @@ func newGlobalManager(conf BehaviorConfig, instance *Service) *globalManager { }), metricGlobalQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gubernator_global_queue_length", - Help: "The count of requests queued up for global broadcast. This is only used for GetRateLimit requests using global behavior.", + Help: "The count of requests queued up for global broadcast. This is only used for CheckRateLimit requests using global behavior.", }), } gm.runAsyncHits() @@ -242,9 +242,7 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] for _, update := range updates { grlReq := proto.Clone(update).(*RateLimitRequest) grlReq.Hits = 0 - - // Get current rate limit state. 
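// Sketch (illustrative, not part of this change): the Hits=0 convention the broadcast
// code here relies on. A cloned request with zero hits reports current status without
// consuming quota. The helper name is hypothetical; written as if it lived in package
// gubernator.
func peekStatus(ctx context.Context, cm CacheManager, req *RateLimitRequest) (*RateLimitResponse, error) {
	peek := proto.Clone(req).(*RateLimitRequest)
	peek.Hits = 0 // both algorithms report current state without consuming hits when Hits == 0
	return cm.CheckRateLimit(ctx, peek, RateLimitContext{IsOwner: true})
}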
- state, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, reqState) + state, err := gm.instance.cache.CheckRateLimit(ctx, grlReq, reqState) if err != nil { gm.log.WithError(err).Error("while retrieving rate limit status") continue diff --git a/go.mod b/go.mod index 28b40da..c179407 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/gubernator-io/gubernator/v3 go 1.22.0 require ( - github.com/OneOfOne/xxhash v1.2.8 github.com/davecgh/go-spew v1.1.1 - github.com/duh-rpc/duh-go v0.0.2-0.20230929155108-5d641b0c008a + github.com/duh-rpc/duh-go v0.1.0 github.com/hashicorp/memberlist v0.5.0 github.com/mailgun/errors v0.1.5 github.com/mailgun/holster/v4 v4.19.0 + github.com/maypok86/otter v1.2.1 github.com/miekg/dns v1.1.50 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 @@ -40,7 +40,9 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/dolthub/maphash v0.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/gammazero/deque v0.2.1 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index e680f0f..4f94803 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= -github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= -github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= @@ -106,8 +104,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= -github.com/duh-rpc/duh-go v0.0.2-0.20230929155108-5d641b0c008a h1:v/NQEfHHOY/huFECKxKZnEkY5jVD8Yix8TPa0FjgKbg= -github.com/duh-rpc/duh-go v0.0.2-0.20230929155108-5d641b0c008a/go.mod h1:OoCoGsZkeED84v8TAE86m2NM5ZfNLNlqUUm7tYO+h+k= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= +github.com/duh-rpc/duh-go v0.1.0 h1:Ym7XvNhl1CD6dgy+YWiPfhkOQGNzFsBsIc5uvYdF08c= +github.com/duh-rpc/duh-go v0.1.0/go.mod h1:OoCoGsZkeED84v8TAE86m2NM5ZfNLNlqUUm7tYO+h+k= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod 
h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -130,6 +130,8 @@ github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoD github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/getkin/kin-openapi v0.76.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -308,6 +310,8 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/maypok86/otter v1.2.1 h1:xyvMW+t0vE1sKt/++GTkznLitEl7D/msqXkAbLwiC1M= +github.com/maypok86/otter v1.2.1/go.mod h1:mKLfoI7v1HOmQMwFgX4QkRk23mX6ge3RDvjdHOWG4R4= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= diff --git a/gubernator.go b/gubernator.go index ad0b523..f26fa2b 100644 --- a/gubernator.go +++ b/gubernator.go @@ -46,7 +46,7 @@ type Service struct { propagator propagation.TraceContext global *globalManager peerMutex sync.RWMutex - workerPool *WorkerPool + cache CacheManager log FieldLogger conf Config isClosed bool @@ -83,14 +83,6 @@ var ( Name: "gubernator_check_error_counter", Help: "The number of errors while checking rate limits.", }, []string{"error"}) - metricCommandCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_command_counter", - Help: "The count of commands processed by each worker in WorkerPool.", - }, []string{"worker", "method"}) - metricWorkerQueue = prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: "gubernator_worker_queue_length", - Help: "The count of requests queued up in WorkerPool.", - }, []string{"method", "worker"}) // Batch behavior. metricBatchSendRetries = prometheus.NewCounterVec(prometheus.CounterOpts{ @@ -123,7 +115,10 @@ func NewService(conf Config) (s *Service, err error) { conf: conf, } - s.workerPool = NewWorkerPool(&conf) + s.cache, err = NewCacheManager(conf) + if err != nil { + return nil, fmt.Errorf("during NewCacheManager(): %w", err) + } s.global = newGlobalManager(conf.Behaviors, s) if s.conf.Loader == nil { @@ -131,9 +126,9 @@ func NewService(conf Config) (s *Service, err error) { } // Load the cache. - err = s.workerPool.Load(ctx) + err = s.cache.Load(ctx) if err != nil { - return nil, fmt.Errorf("error in workerPool.Load: %w", err) + return nil, fmt.Errorf("error in CacheManager.Load: %w", err) } return s, nil @@ -147,19 +142,19 @@ func (s *Service) Close(ctx context.Context) (err error) { s.global.Close() if s.conf.Loader != nil { - err = s.workerPool.Store(ctx) + err = s.cache.Store(ctx) if err != nil { s.log.WithError(err). 
Error("Error in workerPool.Store") - return fmt.Errorf("error in workerPool.Store: %w", err) + return fmt.Errorf("error in CacheManager.Store: %w", err) } } - err = s.workerPool.Close() + err = s.cache.Close() if err != nil { s.log.WithError(err). Error("Error in workerPool.Close") - return fmt.Errorf("error in workerPool.Close: %w", err) + return fmt.Errorf("error in CacheManager.Close: %w", err) } // Close all the peer clients @@ -181,12 +176,12 @@ func (s *Service) CheckRateLimits(ctx context.Context, req *CheckRateLimitsReque if len(req.Requests) > maxBatchSize { metricCheckErrorCounter.WithLabelValues("Request too large").Add(1) return duh.NewServiceError(duh.CodeBadRequest, - fmt.Errorf("CheckRateLimitsRequest.RateLimits list too large; max size is '%d'", maxBatchSize), nil) + fmt.Sprintf("CheckRateLimitsRequest.RateLimits list too large; max size is '%d'", maxBatchSize), nil, nil) } if len(req.Requests) == 0 { return duh.NewServiceError(duh.CodeBadRequest, - errors.New("CheckRateLimitsRequest.RateLimits list is empty; provide at least one rate limit"), nil) + "CheckRateLimitsRequest.RateLimits list is empty; provide at least one rate limit", nil, nil) } resp.Responses = make([]*RateLimitResponse, len(req.Requests)) @@ -321,12 +316,13 @@ func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { for { if attempts > 5 { - err = fmt.Errorf("GetPeer() keeps returning peers that are not connected for '%s': %w", req.Key, err) + err = fmt.Errorf("attempts exhausted while communicating with '%s' for '%s': %w", + req.Peer.Info().HTTPAddress, req.Key, err) s.log.WithContext(ctx). WithError(err). WithField("key", req.Key). - Error("GetPeer() returned peer that is not connected") - countError(err, "Peer not connected") + Error("attempts exhausted while communicating with peer") + countError(err, "peer communication failed") resp.Resp = &RateLimitResponse{Error: err.Error()} break } @@ -336,11 +332,11 @@ func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { if reqState.IsOwner { resp.Resp, err = s.checkLocalRateLimit(ctx, req.Req, reqState) if err != nil { - err = fmt.Errorf("error in checkLocalRateLimit for '%s': %w", req.Key, err) + err = fmt.Errorf("during checkLocalRateLimit() for '%s': %w", req.Key, err) s.log.WithContext(ctx). WithError(err). WithField("key", req.Key). 
- Error("Error applying rate limit") + Error("while applying rate limit") resp.Resp = &RateLimitResponse{Error: err.Error()} } break @@ -348,7 +344,8 @@ func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { } // Make an RPC call to the peer that owns this rate limit - r, err := req.Peer.Forward(ctx, req.Req) + var r *RateLimitResponse + r, err = req.Peer.Forward(ctx, req.Req) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, ErrPeerShutdown) { @@ -357,9 +354,9 @@ func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { metricBatchSendRetries.WithLabelValues(req.Req.Name).Inc() req.Peer, err = s.GetPeer(ctx, req.Key) if err != nil { - err := fmt.Errorf("error finding peer that owns rate limit '%s': %w", req.Key, err) + err = fmt.Errorf("while finding peer that owns rate limit '%s': %w", req.Key, err) s.log.WithContext(ctx).WithError(err).WithField("key", req.Key).Error(err) - countError(err, "Error in GetPeer") + countError(err, "during GetPeer()") resp.Resp = &RateLimitResponse{Error: err.Error()} break } @@ -367,7 +364,7 @@ func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { } // Not calling `countError()` because we expect the remote end to report this error. - err = fmt.Errorf("error while fetching rate limit '%s' from peer: %w", req.Key, err) + err = fmt.Errorf("while fetching rate limit '%s' from peer: %w", req.Key, err) resp.Resp = &RateLimitResponse{Error: err.Error()} break } @@ -418,18 +415,31 @@ func (s *Service) checkGlobalRateLimit(ctx context.Context, req *RateLimitReques // Update updates the local cache with a list of rate limit state from a peer // This method should only be called by a peer. -func (s *Service) Update(ctx context.Context, r *UpdateRequest, resp *v1.Reply) (err error) { +func (s *Service) Update(ctx context.Context, r *UpdateRequest, _ *v1.Reply) (err error) { ctx = tracing.StartNamedScopeDebug(ctx, "Service.Update") defer func() { tracing.EndScope(ctx, err) }() now := MillisecondNow() + for _, g := range r.Globals { - item := &CacheItem{ - ExpireAt: g.State.ResetTime, - Algorithm: g.Algorithm, - Key: g.Key, + item, _, err := s.cache.GetCacheItem(ctx, g.Key) + if err != nil { + return err + } + + if item == nil { + item = &CacheItem{ + ExpireAt: g.State.ResetTime, + Algorithm: g.Algorithm, + Key: g.Key, + } + err := s.cache.AddCacheItem(ctx, g.Key, item) + if err != nil { + return fmt.Errorf("during CacheManager.AddCacheItem(): %w", err) + } } + item.mutex.Lock() switch g.Algorithm { case Algorithm_LEAKY_BUCKET: item.Value = &LeakyBucketItem{ @@ -448,13 +458,8 @@ func (s *Service) Update(ctx context.Context, r *UpdateRequest, resp *v1.Reply) CreatedAt: now, } } - err := s.workerPool.AddCacheItem(ctx, g.Key, item) - if err != nil { - return fmt.Errorf("error in workerPool.AddCacheItem: %w", err) - } + item.mutex.Unlock() } - - resp.Code = duh.CodeOK return nil } @@ -466,7 +471,7 @@ func (s *Service) Forward(ctx context.Context, req *ForwardRequest, resp *Forwar if len(req.Requests) > maxBatchSize { metricCheckErrorCounter.WithLabelValues("Request too large").Add(1) return duh.NewServiceError(duh.CodeBadRequest, - fmt.Errorf("'Forward.requests' list too large; max size is '%d'", maxBatchSize), nil) + fmt.Sprintf("'Forward.requests' list too large; max size is '%d'", maxBatchSize), nil, nil) } // Invoke each rate limit request. 
@@ -582,9 +587,9 @@ func (s *Service) checkLocalRateLimit(ctx context.Context, r *RateLimitRequest, defer func() { tracing.EndScope(ctx, err) }() defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Service.checkLocalRateLimit")).ObserveDuration() - resp, err := s.workerPool.GetRateLimit(ctx, r, reqState) + resp, err := s.cache.CheckRateLimit(ctx, r, reqState) if err != nil { - return nil, fmt.Errorf("during workerPool.GetRateLimit: %w", err) + return nil, fmt.Errorf("during CacheManager.CheckRateLimit: %w", err) } // If global behavior, then broadcast update to all peers. @@ -725,12 +730,10 @@ func (s *Service) Describe(ch chan<- *prometheus.Desc) { metricBatchSendDuration.Describe(ch) metricBatchSendRetries.Describe(ch) metricCheckErrorCounter.Describe(ch) - metricCommandCounter.Describe(ch) metricConcurrentChecks.Describe(ch) metricFuncTimeDuration.Describe(ch) metricGetRateLimitCounter.Describe(ch) metricOverLimitCounter.Describe(ch) - metricWorkerQueue.Describe(ch) s.global.metricBroadcastDuration.Describe(ch) s.global.metricGlobalQueueLength.Describe(ch) s.global.metricGlobalSendDuration.Describe(ch) @@ -743,12 +746,10 @@ func (s *Service) Collect(ch chan<- prometheus.Metric) { metricBatchSendDuration.Collect(ch) metricBatchSendRetries.Collect(ch) metricCheckErrorCounter.Collect(ch) - metricCommandCounter.Collect(ch) metricConcurrentChecks.Collect(ch) metricFuncTimeDuration.Collect(ch) metricGetRateLimitCounter.Collect(ch) metricOverLimitCounter.Collect(ch) - metricWorkerQueue.Collect(ch) s.global.metricBroadcastDuration.Collect(ch) s.global.metricGlobalQueueLength.Collect(ch) s.global.metricGlobalSendDuration.Collect(ch) diff --git a/lru_cache.go b/lru_cache.go new file mode 100644 index 0000000..b1333a1 --- /dev/null +++ b/lru_cache.go @@ -0,0 +1,158 @@ +/* +Modifications Copyright 2024 Derrick Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +This work is derived from github.com/golang/groupcache/lru +*/ + +package gubernator + +import ( + "container/list" + "sync" + "sync/atomic" + + "github.com/mailgun/holster/v4/setter" +) + +// LRUCache is a mutex protected LRU cache that supports expiration and is thread-safe +type LRUCache struct { + cache map[string]*list.Element + ll *list.List + mu sync.Mutex + stats CacheStats + cacheSize int + cacheLen int64 +} + +var _ Cache = &LRUCache{} + +// NewLRUCache creates a new Cache with a maximum size. +func NewLRUCache(maxSize int) *LRUCache { + setter.SetDefault(&maxSize, 50_000) + + return &LRUCache{ + cache: make(map[string]*list.Element), + ll: list.New(), + cacheSize: maxSize, + } +} + +// Each maintains a goroutine that iterates over every item in the cache. +// Other go routines operating on this cache will block until all items +// are read from the returned channel. 
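Before the implementation, it is worth spelling out what the blocking contract above demands of callers: the producer goroutine takes c.mu and holds it until every item has been sent, so a consumer that abandons the channel early wedges the cache permanently. A hedged consumer sketch, assuming the LRUCache in this file (the helper name is illustrative):

```go
// snapshotKeys drains Each() to completion; breaking out of the range
// early would leave the producer goroutine blocked on a send while it
// still holds the cache mutex, deadlocking every other cache operation.
func snapshotKeys(c *LRUCache) []string {
	var keys []string
	for item := range c.Each() {
		keys = append(keys, item.Key) // collect, never return early
	}
	return keys
}
```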
+func (c *LRUCache) Each() chan *CacheItem { + out := make(chan *CacheItem) + go func() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, ele := range c.cache { + out <- ele.Value.(*CacheItem) + } + close(out) + }() + return out +} + +// AddIfNotPresent adds the item to the cache if it doesn't already exist. +// Returns true if the item was added, false if the item already exists. +func (c *LRUCache) AddIfNotPresent(item *CacheItem) bool { + c.mu.Lock() + defer c.mu.Unlock() + + // If the key already exists, do nothing + if _, ok := c.cache[item.Key]; ok { + return false + } + + ele := c.ll.PushFront(item) + c.cache[item.Key] = ele + if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { + c.removeOldest() + } + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) + return true +} + +// GetItem returns the item stored in the cache +func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ele, hit := c.cache[key]; hit { + entry := ele.Value.(*CacheItem) + + c.stats.Hit++ + c.ll.MoveToFront(ele) + return entry, true + } + + c.stats.Miss++ + return +} + +// Remove removes the provided key from the cache. +func (c *LRUCache) Remove(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if ele, hit := c.cache[key]; hit { + c.removeElement(ele) + } +} + +// removeOldest removes the oldest item from the cache. +func (c *LRUCache) removeOldest() { + ele := c.ll.Back() + if ele != nil { + entry := ele.Value.(*CacheItem) + + if MillisecondNow() < entry.ExpireAt { + c.stats.UnexpiredEvictions++ + } + + c.removeElement(ele) + } +} + +func (c *LRUCache) removeElement(e *list.Element) { + c.ll.Remove(e) + kv := e.Value.(*CacheItem) + delete(c.cache, kv.Key) + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) +} + +// Size returns the number of items in the cache. +func (c *LRUCache) Size() int64 { + return atomic.LoadInt64(&c.cacheLen) +} + +func (c *LRUCache) Close() error { + c.cache = nil + c.ll = nil + c.cacheLen = 0 + return nil +} + +// Stats returns the current stats for the cache +func (c *LRUCache) Stats() CacheStats { + c.mu.Lock() + defer func() { + c.stats = CacheStats{} + c.mu.Unlock() + }() + + c.stats.Size = atomic.LoadInt64(&c.cacheLen) + return c.stats +} diff --git a/lrucache_test.go b/lru_cache_test.go similarity index 87% rename from lrucache_test.go rename to lru_cache_test.go index d2a6622..427d22e 100644 --- a/lrucache_test.go +++ b/lru_cache_test.go @@ -27,17 +27,15 @@ import ( "github.com/gubernator-io/gubernator/v3" "github.com/mailgun/holster/v4/clock" "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - dto "github.com/prometheus/client_model/go" ) -func TestLRUCache(t *testing.T) { +func TestLRUMutexCache(t *testing.T) { const iterations = 1000 const concurrency = 100 expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() - var mutex sync.Mutex t.Run("Happy path", func(t *testing.T) { cache := gubernator.NewLRUCache(0) @@ -50,10 +48,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - exists := cache.Add(item) - mutex.Unlock() - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } // Validate cache.
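Note the inverted convention these tests migrate to: the old Add() returned true when it replaced an existing entry, while AddIfNotPresent returns true only when the insert actually happened. Updates therefore become fetch-and-mutate, as in the "Update same key" test above; a hedged sketch of that idiom in isolation (upsert is an illustrative helper, not part of the patch):

```go
// upsert inserts item, or mutates the already-cached entry when another
// writer got there first. With AddIfNotPresent there is no in-place
// replace, so losing the race means adopting the winner's item.
func upsert(cache *gubernator.LRUCache, item *gubernator.CacheItem) *gubernator.CacheItem {
	if cache.AddIfNotPresent(item) {
		return item // we won; our item is the cached one
	}
	existing, ok := cache.GetItem(item.Key)
	if !ok {
		return item // evicted in between; treat our copy as authoritative
	}
	existing.Value = item.Value // mutate the shared entry instead
	return existing
}
```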
@@ -61,9 +56,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() require.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -72,9 +65,7 @@ func TestLRUCache(t *testing.T) { // Clear cache. for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() cache.Remove(key) - mutex.Unlock() } assert.Zero(t, cache.Size()) @@ -90,8 +81,7 @@ func TestLRUCache(t *testing.T) { Value: "initial value", ExpireAt: expireAt, } - exists1 := cache.Add(item1) - require.False(t, exists1) + require.True(t, cache.AddIfNotPresent(item1)) // Update same key. item2 := &gubernator.CacheItem{ @@ -99,8 +89,11 @@ func TestLRUCache(t *testing.T) { Value: "new value", ExpireAt: expireAt, } - exists2 := cache.Add(item2) - require.True(t, exists2) + require.False(t, cache.AddIfNotPresent(item2)) + + updateItem, ok := cache.GetItem(item1.Key) + require.True(t, ok) + updateItem.Value = "new value" // Verify. verifyItem, ok := cache.GetItem(key) @@ -119,8 +112,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } assert.Equal(t, int64(iterations), cache.Size()) @@ -136,9 +128,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() assert.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -171,9 +161,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }() } @@ -194,10 +182,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - exists := cache.Add(item) - mutex.Unlock() - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } assert.Equal(t, int64(iterations), cache.Size()) @@ -213,9 +198,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() assert.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -233,9 +216,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }() } @@ -256,9 +237,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -275,15 +254,11 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { // Get, cache hit. key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() // Get, cache miss. key2 := strconv.Itoa(rand.Intn(1000) + 10000) - mutex.Lock() _, _ = cache.GetItem(key2) - mutex.Unlock() } }() @@ -299,9 +274,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) // Add new. 
key2 := strconv.Itoa(rand.Intn(1000) + 20000) @@ -310,13 +283,11 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item2) - mutex.Unlock() + cache.AddIfNotPresent(item2) } }() - collector := gubernator.NewLRUCacheCollector() + collector := gubernator.NewCacheCollector() collector.AddCache(cache) go func() { @@ -342,7 +313,7 @@ func TestLRUCache(t *testing.T) { promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. - cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) @@ -351,7 +322,7 @@ func TestLRUCache(t *testing.T) { // fill cache with short duration cache items for i := 0; i < 10; i++ { - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: fmt.Sprintf("short-expiry-%d", i), Value: "bar", @@ -363,7 +334,7 @@ func TestLRUCache(t *testing.T) { clock.Advance(6 * time.Minute) // add a new cache item to force eviction - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: "evict1", Value: "bar", @@ -389,7 +360,7 @@ func TestLRUCache(t *testing.T) { promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. - cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) @@ -398,7 +369,7 @@ func TestLRUCache(t *testing.T) { // fill cache with long duration cache items for i := 0; i < 10; i++ { - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: fmt.Sprintf("long-expiry-%d", i), Value: "bar", @@ -407,7 +378,7 @@ func TestLRUCache(t *testing.T) { } // add a new cache item to force eviction - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: "evict2", Value: "bar", @@ -428,8 +399,7 @@ func TestLRUCache(t *testing.T) { }) } -func BenchmarkLRUCache(b *testing.B) { - var mutex sync.Mutex +func BenchmarkLRUMutexCache(b *testing.B) { b.Run("Sequential reads", func(b *testing.B) { cache := gubernator.NewLRUCache(b.N) @@ -443,8 +413,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } b.ReportAllocs() @@ -452,9 +421,7 @@ func BenchmarkLRUCache(b *testing.B) { for i := 0; i < b.N; i++ { key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() } }) @@ -472,9 +439,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }) @@ -490,8 +455,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } var launchWg, doneWg sync.WaitGroup @@ -505,9 +469,7 @@ func BenchmarkLRUCache(b *testing.B) { defer doneWg.Done() launchWg.Wait() - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }() } @@ -536,9 +498,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) }(i) } @@ -562,8 +522,7 @@ func BenchmarkLRUCache(b 
*testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } for i := 0; i < b.N; i++ { @@ -574,9 +533,7 @@ func BenchmarkLRUCache(b *testing.B) { defer doneWg.Done() launchWg.Wait() - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }() go func(i int) { @@ -588,9 +545,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) }(i) } @@ -614,9 +569,7 @@ func BenchmarkLRUCache(b *testing.B) { launchWg.Wait() key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }(i) go func(i int) { @@ -629,9 +582,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + _ = cache.AddIfNotPresent(item) }(i) } diff --git a/lrucache.go b/lrucache.go deleted file mode 100644 index 0386720..0000000 --- a/lrucache.go +++ /dev/null @@ -1,214 +0,0 @@ -/* -Modifications Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -This work is derived from github.com/golang/groupcache/lru -*/ - -package gubernator - -import ( - "container/list" - "sync/atomic" - - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" -) - -// LRUCache is an LRU cache that supports expiration and is not thread-safe -// Be sure to use a mutex to prevent concurrent method calls. -type LRUCache struct { - cache map[string]*list.Element - ll *list.List - cacheSize int - cacheLen int64 -} - -// LRUCacheCollector provides prometheus metrics collector for LRUCache. -// Register only one collector, add one or more caches to this collector. -type LRUCacheCollector struct { - caches []Cache -} - -var _ Cache = &LRUCache{} -var _ prometheus.Collector = &LRUCacheCollector{} - -var metricCacheSize = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gubernator_cache_size", - Help: "The number of items in LRU Cache which holds the rate limits.", -}) -var metricCacheAccess = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_cache_access_count", - Help: "Cache access counts. Label \"type\" = hit|miss.", -}, []string{"type"}) -var metricCacheUnexpiredEvictions = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gubernator_unexpired_evictions_count", - Help: "Count the number of cache items which were evicted while unexpired.", -}) - -// NewLRUCache creates a new Cache with a maximum size. -func NewLRUCache(maxSize int) *LRUCache { - setter.SetDefault(&maxSize, 50_000) - - return &LRUCache{ - cache: make(map[string]*list.Element), - ll: list.New(), - cacheSize: maxSize, - } -} - -// Each is not thread-safe. Each() maintains a goroutine that iterates. -// Other go routines cannot safely access the Cache while iterating. -// It would be safer if this were done using an iterator or delegate pattern -// that doesn't require a goroutine. 
May need to reassess functional requirements. -func (c *LRUCache) Each() chan *CacheItem { - out := make(chan *CacheItem) - go func() { - for _, ele := range c.cache { - out <- ele.Value.(*CacheItem) - } - close(out) - }() - return out -} - -// Add adds a value to the cache. -func (c *LRUCache) Add(item *CacheItem) bool { - // If the key already exist, set the new value - if ee, ok := c.cache[item.Key]; ok { - c.ll.MoveToFront(ee) - ee.Value = item - return true - } - - ele := c.ll.PushFront(item) - c.cache[item.Key] = ele - if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { - c.removeOldest() - } - atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) - return false -} - -// MillisecondNow returns unix epoch in milliseconds -func MillisecondNow() int64 { - return clock.Now().UnixNano() / 1000000 -} - -// GetItem returns the item stored in the cache -func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - - if entry.IsExpired() { - c.removeElement(ele) - metricCacheAccess.WithLabelValues("miss").Add(1) - return - } - - metricCacheAccess.WithLabelValues("hit").Add(1) - c.ll.MoveToFront(ele) - return entry, true - } - - metricCacheAccess.WithLabelValues("miss").Add(1) - return -} - -// Remove removes the provided key from the cache. -func (c *LRUCache) Remove(key string) { - if ele, hit := c.cache[key]; hit { - c.removeElement(ele) - } -} - -// RemoveOldest removes the oldest item from the cache. -func (c *LRUCache) removeOldest() { - ele := c.ll.Back() - if ele != nil { - entry := ele.Value.(*CacheItem) - - if MillisecondNow() < entry.ExpireAt { - metricCacheUnexpiredEvictions.Add(1) - } - - c.removeElement(ele) - } -} - -func (c *LRUCache) removeElement(e *list.Element) { - c.ll.Remove(e) - kv := e.Value.(*CacheItem) - delete(c.cache, kv.Key) - atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) -} - -// Size returns the number of items in the cache. -func (c *LRUCache) Size() int64 { - return atomic.LoadInt64(&c.cacheLen) -} - -// UpdateExpiration updates the expiration time for the key -func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - entry.ExpireAt = expireAt - return true - } - return false -} - -func (c *LRUCache) Close() error { - c.cache = nil - c.ll = nil - c.cacheLen = 0 - return nil -} - -func NewLRUCacheCollector() *LRUCacheCollector { - return &LRUCacheCollector{ - caches: []Cache{}, - } -} - -// AddCache adds a Cache object to be tracked by the collector. 
-func (collector *LRUCacheCollector) AddCache(cache Cache) { - collector.caches = append(collector.caches, cache) -} - -// Describe fetches prometheus metrics to be registered -func (collector *LRUCacheCollector) Describe(ch chan<- *prometheus.Desc) { - metricCacheSize.Describe(ch) - metricCacheAccess.Describe(ch) - metricCacheUnexpiredEvictions.Describe(ch) -} - -// Collect fetches metric counts and gauges from the cache -func (collector *LRUCacheCollector) Collect(ch chan<- prometheus.Metric) { - metricCacheSize.Set(collector.getSize()) - metricCacheSize.Collect(ch) - metricCacheAccess.Collect(ch) - metricCacheUnexpiredEvictions.Collect(ch) -} - -func (collector *LRUCacheCollector) getSize() float64 { - var size float64 - - for _, cache := range collector.caches { - size += float64(cache.Size()) - } - - return size -} diff --git a/mock_cache_test.go b/mock_cache_test.go index 96e9942..a2d8b5e 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -29,16 +29,11 @@ type MockCache struct { var _ guber.Cache = &MockCache{} -func (m *MockCache) Add(item *guber.CacheItem) bool { +func (m *MockCache) AddIfNotPresent(item *guber.CacheItem) bool { args := m.Called(item) return args.Bool(0) } -func (m *MockCache) UpdateExpiration(key string, expireAt int64) bool { - args := m.Called(key, expireAt) - return args.Bool(0) -} - func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) @@ -60,6 +55,10 @@ func (m *MockCache) Size() int64 { return int64(args.Int(0)) } +func (m *MockCache) Stats() guber.CacheStats { + return guber.CacheStats{} +} + func (m *MockCache) Close() error { args := m.Called() return args.Error(0) diff --git a/otter.go b/otter.go new file mode 100644 index 0000000..0a191dc --- /dev/null +++ b/otter.go @@ -0,0 +1,108 @@ +package gubernator + +import ( + "fmt" + "sync/atomic" + + "github.com/mailgun/holster/v4/setter" + "github.com/maypok86/otter" +) + +type OtterCache struct { + cache otter.Cache[string, *CacheItem] + stats CacheStats +} + +// NewOtterCache returns a new cache backed by otter. If size is 0, then +// the cache is created with a default cache size. +func NewOtterCache(size int) (*OtterCache, error) { + // Default is 500k bytes in size + setter.SetDefault(&size, 500_000) + b, err := otter.NewBuilder[string, *CacheItem](size) + if err != nil { + return nil, fmt.Errorf("during otter.NewBuilder(): %w", err) + } + + o := &OtterCache{} + + b.DeletionListener(func(key string, value *CacheItem, cause otter.DeletionCause) { + if cause == otter.Size { + atomic.AddInt64(&o.stats.UnexpiredEvictions, 1) + } + }) + + b.Cost(func(key string, value *CacheItem) uint32 { + // The total size of the CacheItem and Bucket item is 104 bytes. + // See cache.go:CacheItem definition for details. + return uint32(104 + len(value.Key)) + }) + + o.cache, err = b.Build() + if err != nil { + return nil, fmt.Errorf("during otter.Builder.Build(): %w", err) + } + return o, nil +} + +// AddIfNotPresent adds a new CacheItem to the cache. The key must be provided via CacheItem.Key +// returns true if the item was added to the cache; false if the item was too large +// for the cache or if the key already exists in the cache. 
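One caveat before the method body: as the doc comment above says, otter folds two outcomes into a single false, "key already present" and "item rejected by the cost budget". Callers that read false as "someone else added it, fetch theirs" should allow for the item never appearing; a hedged caller sketch (addOrFetch is illustrative, not part of the patch):

```go
// addOrFetch resolves AddIfNotPresent's ambiguous false: either another
// writer won the race (fetch their item) or otter refused the item by
// cost, in which case GetItem also misses and we keep our local copy.
func addOrFetch(cache *gubernator.OtterCache, item *gubernator.CacheItem) *gubernator.CacheItem {
	if cache.AddIfNotPresent(item) {
		return item
	}
	if existing, ok := cache.GetItem(item.Key); ok {
		return existing // lost the insert race; use the cached entry
	}
	return item // rejected by the cost budget; operate on the local copy
}
```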
+func (o *OtterCache) AddIfNotPresent(item *CacheItem) bool { + return o.cache.SetIfAbsent(item.Key, item) +} + +// GetItem returns an item in the cache that corresponds to the provided key +func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { + item, ok := o.cache.Get(key) + if !ok { + atomic.AddInt64(&o.stats.Miss, 1) + return nil, false + } + + atomic.AddInt64(&o.stats.Hit, 1) + return item, true +} + +// Each returns a channel which the caller can use to iterate through +// all the items in the cache. +func (o *OtterCache) Each() chan *CacheItem { + ch := make(chan *CacheItem) + + go func() { + o.cache.Range(func(_ string, v *CacheItem) bool { + ch <- v + return true + }) + close(ch) + }() + return ch +} + +// Remove explicitly removes an item from the cache. +// NOTE: A deletion call to otter requires a mutex to perform; +// if possible, avoid performing explicit removal from the cache. +// Instead, prefer the item to be evicted naturally. +func (o *OtterCache) Remove(key string) { + o.cache.Delete(key) +} + +// Size returns the current number of items in the cache +func (o *OtterCache) Size() int64 { + return int64(o.cache.Size()) +} + +// Stats returns the current cache stats and resets the values to zero +func (o *OtterCache) Stats() CacheStats { + var result CacheStats + result.UnexpiredEvictions = atomic.SwapInt64(&o.stats.UnexpiredEvictions, 0) + result.Miss = atomic.SwapInt64(&o.stats.Miss, 0) + result.Hit = atomic.SwapInt64(&o.stats.Hit, 0) + result.Size = int64(o.cache.Size()) + return result +} + +// Close closes the cache and all associated background processes +func (o *OtterCache) Close() error { + o.cache.Close() + return nil +} diff --git a/otter_test.go b/otter_test.go new file mode 100644 index 0000000..c426a2d --- /dev/null +++ b/otter_test.go @@ -0,0 +1,410 @@ +package gubernator_test + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/gubernator-io/gubernator/v3" + "github.com/mailgun/holster/v4/clock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOtterCache(t *testing.T) { + const iterations = 1000 + const concurrency = 100 + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + t.Run("Happy path", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + // Validate cache. + assert.Equal(t, int64(iterations), cache.Size()) + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + require.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + + // Clear cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + cache.Remove(key) + } + + assert.Zero(t, cache.Size()) + }) + + t.Run("Update an existing key", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + const key = "foobar" + + // Add key. + item1 := &gubernator.CacheItem{ + Key: key, + Value: "initial value", + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item1) + + // Updating the same key is refused + item2 := &gubernator.CacheItem{ + Key: key, + Value: "new value", + ExpireAt: expireAt, + } + assert.False(t, cache.AddIfNotPresent(item2)) + + // Fetch and update the CacheItem + update, ok := cache.GetItem(key) + assert.True(t, ok) + update.Value = "new value" + + // Verify.
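Both Stats() implementations in this change are read-and-reset: LRUCache zeroes its counters under the mutex, and OtterCache swaps each counter back to zero atomically. That contract only works with a single reader; two concurrent scrapers would each see a slice of the deltas. A hedged sketch of a single-owner scrape loop (pollStats is illustrative, not part of the patch):

```go
// pollStats is a single-reader scrape loop: because Stats() resets the
// counters on every call, deltas must accumulate in exactly one place.
func pollStats(cache *gubernator.OtterCache, interval time.Duration, done <-chan struct{}) {
	var total gubernator.CacheStats
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s := cache.Stats() // zeroes Hit/Miss/UnexpiredEvictions
			total.Hit += s.Hit
			total.Miss += s.Miss
			total.UnexpiredEvictions += s.UnexpiredEvictions
		case <-done:
			return
		}
	}
}
```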
+ verifyItem, ok := cache.GetItem(key) + require.True(t, ok) + assert.Equal(t, item2, verifyItem) + }) + + t.Run("Concurrent reads", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent writes", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent reads and writes", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + // Write different keys than the keys we are reading to avoid race on Add() / GetItem() + for i := iterations; i < iterations*2; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) +} + +func BenchmarkOtterCache(b *testing.B) { + + b.Run("Sequential reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. 
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + } + }) + + b.Run("Sequential writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }) + + b.Run("Concurrent reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of existing keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + // Populate cache. 
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of non-existent keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + doneWg.Add(2) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + }(i) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := "z" + strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) +} diff --git a/peer.go b/peer.go index b61c382..b3e46e2 100644 --- a/peer.go +++ b/peer.go @@ -51,10 +51,9 @@ type response struct { } type request struct { - reqState RateLimitContext - request *RateLimitRequest - ctx context.Context - resp chan *response + request *RateLimitRequest + ctx context.Context + resp chan *response } type PeerConfig struct { diff --git a/store.go b/store.go index 9feea2b..e8c18a9 100644 --- a/store.go +++ b/store.go @@ -16,7 +16,9 @@ limitations under the License. package gubernator -import "context" +import ( + "context" +) // PERSISTENT STORE DETAILS @@ -26,20 +28,23 @@ import "context" // and `Get()` to keep the in memory cache and persistent store up to date with the latest ratelimit data. // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage. +// LeakyBucketItem is 40 bytes aligned in size type LeakyBucketItem struct { - Limit int64 - Duration int64 - Remaining float64 - UpdatedAt int64 - Burst int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining float64 // 8 bytes + UpdatedAt int64 // 8 bytes + Burst int64 // 8 bytes } +// TokenBucketItem is 40 bytes aligned in size type TokenBucketItem struct { - Status Status - Limit int64 - Duration int64 - Remaining int64 - CreatedAt int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining int64 // 8 bytes + CreatedAt int64 // 8 bytes + Status Status // 4 bytes + // 4 bytes of padding } // Store interface allows implementors to off load storage of all or a subset of ratelimits to @@ -47,18 +52,18 @@ type TokenBucketItem struct { // to maximize performance of gubernator. // Implementations MUST be threadsafe. type Store interface { - // Called by gubernator *after* a rate limit item is updated. It's up to the store to + // OnChange is called by gubernator *after* a rate limit item is updated. It's up to the store to // decide if this rate limit item should be persisted in the store. It's up to the // store to expire old rate limit items. The CacheItem represents the current state of // the rate limit item *after* the RateLimitRequest has been applied. 
OnChange(ctx context.Context, r *RateLimitRequest, item *CacheItem) - // Called by gubernator when a rate limit is missing from the cache. It's up to the store + // Get is called by gubernator when a rate limit is missing from the cache. It's up to the store // to decide if this request is fulfilled. Should return true if the request is fulfilled // and false if the request is not fulfilled or doesn't exist in the store. Get(ctx context.Context, r *RateLimitRequest) (*CacheItem, bool) - // Called by gubernator when an existing rate limit should be removed from the store. + // Remove is called by gubernator when an existing rate limit should be removed from the store. // NOTE: This is NOT called when an rate limit expires from the cache, store implementors // must expire rate limits in the store. Remove(ctx context.Context, key string) diff --git a/store_test.go b/store_test.go index 6fa3921..9685f54 100644 --- a/store_test.go +++ b/store_test.go @@ -18,6 +18,8 @@ package gubernator_test import ( "context" + "fmt" + "sync" "testing" "github.com/gubernator-io/gubernator/v3" @@ -83,6 +85,108 @@ func TestLoader(t *testing.T) { assert.Equal(t, gubernator.Status_UNDER_LIMIT, item.Status) } +type NoOpStore struct{} + +func (ms *NoOpStore) Remove(ctx context.Context, key string) {} +func (ms *NoOpStore) OnChange(ctx context.Context, r *gubernator.RateLimitRequest, item *gubernator.CacheItem) { +} + +func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitRequest) (*gubernator.CacheItem, bool) { + return &gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Key: r.HashKey(), + Value: gubernator.TokenBucketItem{ + CreatedAt: gubernator.MillisecondNow(), + Duration: gubernator.Minute * 60, + Limit: 1_000, + Remaining: 1_000, + Status: 0, + }, + ExpireAt: 0, + }, true +} + +// The goal of this test is to generate race conditions where multiple routines load from the store and/or +// add items to the cache in parallel, creating contention the code must then handle. +func TestHighContentionFromStore(t *testing.T) { + const ( + // Increase these numbers to improve the chance of contention, but at the cost of test speed.
numGoroutines = 150 + numKeys = 100 + ) + store := &NoOpStore{} + d, err := gubernator.SpawnDaemon(context.Background(), gubernator.DaemonConfig{ + HTTPListenAddress: "localhost:0", + Behaviors: gubernator.BehaviorConfig{ + // Suitable for testing but not production + GlobalSyncWait: clock.Millisecond * 50, + GlobalTimeout: clock.Second, + }, + Store: store, + }) + require.NoError(t, err) + d.SetPeers([]gubernator.PeerInfo{{HTTPAddress: d.Config().HTTPListenAddress, IsOwner: true}}) + + keys := GenerateRandomKeys(numKeys) + + var wg sync.WaitGroup + var ready sync.WaitGroup + wg.Add(numGoroutines) + ready.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + // Create a client for each concurrent request to avoid contention in the client + client, err := gubernator.NewClient(gubernator.WithNoTLS(d.Listener.Addr().String())) + require.NoError(t, err) + ready.Wait() + for idx := 0; idx < numKeys; idx++ { + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 1, + }, + }, + }, &resp) + if err != nil { + // NOTE: you may see `connection reset by peer` if the server is overloaded + // and needs to forcibly drop some connections due to running out of open file handles, etc... + fmt.Printf("%s\n", err) + } + } + wg.Done() + }() + ready.Done() + } + wg.Wait() + + for idx := 0; idx < numKeys; idx++ { + var resp gubernator.CheckRateLimitsResponse + err := d.MustClient().CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 0, + }, + }, + }, &resp) + require.NoError(t, err) + assert.Equal(t, int64(0), resp.Responses[0].Remaining) + } + + assert.NoError(t, d.Close(context.Background())) +} + func TestStore(t *testing.T) { ctx := context.Background() setup := func() (*MockStore2, *gubernator.Daemon, gubernator.Client) { diff --git a/tls.go b/tls.go index 46cd2d7..25f951e 100644 --- a/tls.go +++ b/tls.go @@ -259,17 +259,20 @@ func SetupTLS(conf *TLSConfig) error { // If user asked for client auth if conf.ClientAuth != tls.NoClientCert { clientPool := x509.NewCertPool() + var certProvided bool if conf.ClientAuthCaPEM != nil { // If client auth CA was provided clientPool.AppendCertsFromPEM(conf.ClientAuthCaPEM.Bytes()) + certProvided = true } else if conf.CaPEM != nil { // else use the servers CA clientPool.AppendCertsFromPEM(conf.CaPEM.Bytes()) + certProvided = true } - // error if neither was provided - if len(clientPool.Subjects()) == 0 { //nolint:all + // error if neither cert was provided + if !certProvided { return errors.New("client auth enabled, but no CA's provided") } diff --git a/workers.go b/workers.go deleted file mode 100644 index 4a2b381..0000000 --- a/workers.go +++ /dev/null @@ -1,626 +0,0 @@ -/* -Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -// Thread-safe worker pool for handling concurrent Gubernator requests. -// Ensures requests are synchronized to avoid caching conflicts. -// Handle concurrent requests by sharding cache key space across multiple -// workers. -// Uses hash ring design pattern to distribute requests to an assigned worker. -// No mutex locking necessary because each worker has its own data space and -// processes requests sequentially. -// -// Request workflow: -// - A 63-bit hash is generated from an incoming request by its Key/Name -// values. (Actually 64 bit, but we toss out one bit to properly calculate -// the next step.) -// - Workers are assigned equal size hash ranges. The worker is selected by -// choosing the worker index associated with that linear hash value range. -// - The worker has command channels for each method call. The request is -// enqueued to the appropriate channel. -// - The worker pulls the request from the appropriate channel and executes the -// business logic for that method. Then, it sends a response back using the -// requester's provided response channel. - -import ( - "context" - "io" - "strconv" - "sync" - "sync/atomic" - - "github.com/OneOfOne/xxhash" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/trace" -) - -type WorkerPool struct { - hasher workerHasher - workers []*Worker - workerCacheSize int - hashRingStep uint64 - conf *Config - done chan struct{} -} - -type Worker struct { - name string - conf *Config - cache Cache - getRateLimitRequestuest chan request - storeRequest chan workerStoreRequest - loadRequest chan workerLoadRequest - addCacheItemRequest chan workerAddCacheItemRequest - getCacheItemRequest chan workerGetCacheItemRequest -} - -type workerHasher interface { - // ComputeHash63 returns a 63-bit hash derived from input. - ComputeHash63(input string) uint64 -} - -// hasher is the default implementation of workerHasher. -type hasher struct{} - -// Method request/response structs. -type workerStoreRequest struct { - ctx context.Context - response chan workerStoreResponse - out chan<- *CacheItem -} - -type workerStoreResponse struct{} - -type workerLoadRequest struct { - ctx context.Context - response chan workerLoadResponse - in <-chan *CacheItem -} - -type workerLoadResponse struct{} - -type workerAddCacheItemRequest struct { - ctx context.Context - response chan workerAddCacheItemResponse - item *CacheItem -} - -type workerAddCacheItemResponse struct { - exists bool -} - -type workerGetCacheItemRequest struct { - ctx context.Context - response chan workerGetCacheItemResponse - key string -} - -type workerGetCacheItemResponse struct { - item *CacheItem - ok bool -} - -var _ io.Closer = &WorkerPool{} -var _ workerHasher = &hasher{} - -var workerCounter int64 - -func NewWorkerPool(conf *Config) *WorkerPool { - setter.SetDefault(&conf.CacheSize, 50_000) - - // Compute hashRingStep as interval between workers' 63-bit hash ranges. 
- // 64th bit is used here as a max value that is just out of range of 63-bit space to calculate the step. - chp := &WorkerPool{ - workers: make([]*Worker, conf.Workers), - workerCacheSize: conf.CacheSize / conf.Workers, - hasher: newHasher(), - hashRingStep: uint64(1<<63) / uint64(conf.Workers), - conf: conf, - done: make(chan struct{}), - } - - // Create workers. - conf.Logger.Debugf("Starting %d Gubernator workers...", conf.Workers) - for i := 0; i < conf.Workers; i++ { - chp.workers[i] = chp.newWorker() - go chp.dispatch(chp.workers[i]) - } - - return chp -} - -func newHasher() *hasher { - return &hasher{} -} - -func (ph *hasher) ComputeHash63(input string) uint64 { - return xxhash.ChecksumString64S(input, 0) >> 1 -} - -func (p *WorkerPool) Close() error { - close(p.done) - return nil -} - -// Create a new pool worker instance. -func (p *WorkerPool) newWorker() *Worker { - worker := &Worker{ - conf: p.conf, - cache: p.conf.CacheFactory(p.workerCacheSize), - getRateLimitRequestuest: make(chan request), - storeRequest: make(chan workerStoreRequest), - loadRequest: make(chan workerLoadRequest), - addCacheItemRequest: make(chan workerAddCacheItemRequest), - getCacheItemRequest: make(chan workerGetCacheItemRequest), - } - workerNumber := atomic.AddInt64(&workerCounter, 1) - 1 - worker.name = strconv.FormatInt(workerNumber, 10) - return worker -} - -// getWorker Returns the request channel associated with the key. -// Hash the key, then lookup hash ring to find the worker. -func (p *WorkerPool) getWorker(key string) *Worker { - hash := p.hasher.ComputeHash63(key) - idx := hash / p.hashRingStep - return p.workers[idx] -} - -// Pool worker for processing Gubernator requests. -// Each worker maintains its own state. -// A hash ring will distribute requests to an assigned worker by key. -// See: getWorker() -func (p *WorkerPool) dispatch(worker *Worker) { - for { - // Dispatch requests from each channel. - select { - case req, ok := <-worker.getRateLimitRequestuest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - resp := new(response) - resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.reqState, worker.cache) - select { - case req.resp <- resp: - // Success. - - case <-req.ctx.Done(): - // Context canceled. - trace.SpanFromContext(req.ctx).RecordError(resp.err) - } - metricCommandCounter.WithLabelValues(worker.name, "GetRateLimit").Inc() - - case req, ok := <-worker.storeRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleStore(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Store").Inc() - - case req, ok := <-worker.loadRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleLoad(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Load").Inc() - - case req, ok := <-worker.addCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleAddCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "AddCacheItem").Inc() - - case req, ok := <-worker.getCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. 
- logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleGetCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "GetCacheItem").Inc() - - case <-p.done: - // Clean up. - return - } - } -} - -// GetRateLimit sends a GetRateLimit request to worker pool. -func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitRequest, reqState RateLimitContext) (*RateLimitResponse, error) { - // Delegate request to assigned channel based on request key. - worker := p.getWorker(rlRequest.HashKey()) - queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - handlerRequest := request{ - ctx: ctx, - resp: make(chan *response, 1), - request: rlRequest, - reqState: reqState, - } - - // Send request. - select { - case worker.getRateLimitRequestuest <- handlerRequest: - // Successfully sent request. - case <-ctx.Done(): - return nil, ctx.Err() - } - - // Wait for response. - select { - case handlerResponse := <-handlerRequest.resp: - // Successfully read response. - return handlerResponse.rl, handlerResponse.err - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Handle request received by worker. -func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitRequest, reqState RateLimitContext, cache Cache) (*RateLimitResponse, error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration() - var rlResponse *RateLimitResponse - var err error - - switch req.Algorithm { - case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in tokenBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in leakyBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - default: - err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm) - trace.SpanFromContext(ctx).RecordError(err) - metricCheckErrorCounter.WithLabelValues("Invalid algorithm").Add(1) - } - - return rlResponse, err -} - -// Load atomically loads cache from persistent storage. -// Read from persistent storage. Load into each appropriate worker's cache. -// Workers are locked during this load operation to prevent race conditions. -func (p *WorkerPool) Load(ctx context.Context) (err error) { - queueGauge := metricWorkerQueue.WithLabelValues("Load", "") - queueGauge.Inc() - defer queueGauge.Dec() - ch, err := p.conf.Loader.Load() - if err != nil { - return errors.Wrap(err, "Error in loader.Load") - } - - type loadChannel struct { - ch chan *CacheItem - worker *Worker - respChan chan workerLoadResponse - } - - // Map request channel hash to load channel. - loadChMap := map[*Worker]loadChannel{} - - // Send each item to the assigned channel's cache. -MAIN: - for { - var item *CacheItem - var ok bool - - select { - case item, ok = <-ch: - if !ok { - break MAIN - } - // Successfully received item. - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - worker := p.getWorker(item.Key) - - // Initiate a load channel with each worker. 
-        loadCh, exist := loadChMap[worker]
-        if !exist {
-            loadCh = loadChannel{
-                ch:       make(chan *CacheItem),
-                worker:   worker,
-                respChan: make(chan workerLoadResponse),
-            }
-            loadChMap[worker] = loadCh
-
-            // Tie up the worker while loading.
-            worker.loadRequest <- workerLoadRequest{
-                ctx:      ctx,
-                response: loadCh.respChan,
-                in:       loadCh.ch,
-            }
-        }
-
-        // Send item to worker's load channel.
-        select {
-        case loadCh.ch <- item:
-            // Successfully sent item.
-
-        case <-ctx.Done():
-            // Context canceled.
-            return ctx.Err()
-        }
-    }
-
-    // Clean up.
-    for _, loadCh := range loadChMap {
-        close(loadCh.ch)
-
-        // The load response confirms all items have been loaded and the worker
-        // resumes normal operation.
-        select {
-        case <-loadCh.respChan:
-            // Successfully received response.
-
-        case <-ctx.Done():
-            // Context canceled.
-            return ctx.Err()
-        }
-    }
-
-    return nil
-}
-
-func (worker *Worker) handleLoad(request workerLoadRequest, cache Cache) {
-MAIN:
-    for {
-        var item *CacheItem
-        var ok bool
-
-        select {
-        case item, ok = <-request.in:
-            if !ok {
-                break MAIN
-            }
-            // Successfully received item.
-
-        case <-request.ctx.Done():
-            // Context canceled.
-            return
-        }
-
-        cache.Add(item)
-    }
-
-    response := workerLoadResponse{}
-
-    select {
-    case request.response <- response:
-        // Successfully sent response.
-
-    case <-request.ctx.Done():
-        // Context canceled.
-        trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err())
-    }
-}
-
-// Store atomically saves all workers' caches to persistent storage.
-// Workers are locked during this store operation to prevent race conditions.
-func (p *WorkerPool) Store(ctx context.Context) (err error) {
-    queueGauge := metricWorkerQueue.WithLabelValues("Store", "")
-    queueGauge.Inc()
-    defer queueGauge.Dec()
-    var wg sync.WaitGroup
-    out := make(chan *CacheItem, 500)
-
-    // Stream each worker's cache into the `out` channel.
-    for _, worker := range p.workers {
-        wg.Add(1)
-
-        go func(ctx context.Context, worker *Worker) {
-            defer wg.Done()
-
-            respChan := make(chan workerStoreResponse)
-            req := workerStoreRequest{
-                ctx:      ctx,
-                response: respChan,
-                out:      out,
-            }
-
-            select {
-            case worker.storeRequest <- req:
-                // Successfully sent request.
-                select {
-                case <-respChan:
-                    // Successfully received response.
-                    return
-
-                case <-ctx.Done():
-                    // Context canceled.
-                    trace.SpanFromContext(ctx).RecordError(ctx.Err())
-                    return
-                }
-
-            case <-ctx.Done():
-                // Context canceled.
-                trace.SpanFromContext(ctx).RecordError(ctx.Err())
-                return
-            }
-        }(ctx, worker)
-    }
-
-    // Once all workers have been drained, close the `out` channel.
-    go func() {
-        wg.Wait()
-        close(out)
-    }()
-
-    if ctx.Err() != nil {
-        return ctx.Err()
-    }
-
-    if err = p.conf.Loader.Save(out); err != nil {
-        return errors.Wrap(err, "while calling p.conf.Loader.Save()")
-    }
-
-    return nil
-}
-
-func (worker *Worker) handleStore(request workerStoreRequest, cache Cache) {
-    for item := range cache.Each() {
-        select {
-        case request.out <- item:
-            // Successfully sent item.
-
-        case <-request.ctx.Done():
-            // Context canceled.
-            trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err())
-            return
-        }
-    }
-
-    response := workerStoreResponse{}
-
-    select {
-    case request.response <- response:
-        // Successfully sent response.
-
-    case <-request.ctx.Done():
-        // Context canceled.
-        trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err())
-    }
-}
-
-// AddCacheItem adds an item to the worker's cache.
-func (p *WorkerPool) AddCacheItem(ctx context.Context, key string, item *CacheItem) (err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("AddCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerAddCacheItemResponse) - req := workerAddCacheItemRequest{ - ctx: ctx, - response: respChan, - item: item, - } - - select { - case worker.addCacheItemRequest <- req: - // Successfully sent request. - select { - case <-respChan: - // Successfully received response. - return nil - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } -} - -func (worker *Worker) handleAddCacheItem(request workerAddCacheItemRequest, cache Cache) { - exists := cache.Add(request.item) - response := workerAddCacheItemResponse{exists} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// GetCacheItem gets item from worker's cache. -func (p *WorkerPool) GetCacheItem(ctx context.Context, key string) (item *CacheItem, found bool, err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("GetCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerGetCacheItemResponse) - req := workerGetCacheItemRequest{ - ctx: ctx, - response: respChan, - key: key, - } - - select { - case worker.getCacheItemRequest <- req: - // Successfully sent request. - select { - case resp := <-respChan: - // Successfully received response. - return resp.item, resp.ok, nil - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } -} - -func (worker *Worker) handleGetCacheItem(request workerGetCacheItemRequest, cache Cache) { - item, ok := cache.GetItem(request.key) - response := workerGetCacheItemResponse{item, ok} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} diff --git a/workers_internal_test.go b/workers_internal_test.go deleted file mode 100644 index 291971a..0000000 --- a/workers_internal_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2024 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package gubernator - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockHasher struct { - mock.Mock -} - -func (m *MockHasher) ComputeHash63(input string) uint64 { - args := m.Called(input) - retval, _ := args.Get(0).(uint64) - return retval -} - -func TestWorkersInternal(t *testing.T) { - t.Run("getWorker()", func(t *testing.T) { - const concurrency = 32 - conf := &Config{ - Workers: concurrency, - } - require.NoError(t, conf.SetDefaults()) - - // Test that getWorker() interpolates the hash to find the expected worker. - testCases := []struct { - Name string - Hash uint64 - ExpectedIdx int - }{ - {"Hash 0%", 0, 0}, - {"Hash 50%", 0x3fff_ffff_ffff_ffff, (concurrency / 2) - 1}, - {"Hash 50% + 1", 0x4000_0000_0000_0000, concurrency / 2}, - {"Hash 100%", 0x7fff_ffff_ffff_ffff, concurrency - 1}, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - pool := NewWorkerPool(conf) - defer pool.Close() - mockHasher := &MockHasher{} - pool.hasher = mockHasher - - // Setup mocks. - mockHasher.On("ComputeHash63", mock.Anything).Once().Return(testCase.Hash) - - // Call code. - worker := pool.getWorker("Foobar") - - // Verify - require.NotNil(t, worker) - - var actualIdx int - for ; actualIdx < len(pool.workers); actualIdx++ { - if pool.workers[actualIdx] == worker { - break - } - } - assert.Equal(t, testCase.ExpectedIdx, actualIdx) - mockHasher.AssertExpectations(t) - }) - } - }) -}
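
Note on the removed hash-ring dispatch: for readers following this removal, below is a minimal, self-contained Go sketch of the key-to-worker mapping that the deleted `getWorker()`/`ComputeHash63()` pair implemented. The `pool` type, `newPool`, and `main` are illustrative stand-ins, not code from this repository; only the xxhash call and the step arithmetic mirror the removed lines.

```go
package main

import (
	"fmt"

	"github.com/OneOfOne/xxhash"
)

// pool is an illustrative stand-in for the removed WorkerPool, keeping only
// what is needed to show how a key maps onto a worker index.
type pool struct {
	workers      []string // stand-ins for *Worker
	hashRingStep uint64
}

func newPool(n int) *pool {
	p := &pool{hashRingStep: uint64(1<<63) / uint64(n)}
	for i := 0; i < n; i++ {
		p.workers = append(p.workers, fmt.Sprintf("worker-%d", i))
	}
	return p
}

// computeHash63 mirrors hasher.ComputeHash63: a 64-bit xxhash shifted right
// once, yielding a value in the 63-bit range [0, 1<<63).
func computeHash63(key string) uint64 {
	return xxhash.ChecksumString64S(key, 0) >> 1
}

// getWorker divides the 63-bit hash by the ring step, so the same key always
// lands on the same worker and keys spread evenly across the workers.
func (p *pool) getWorker(key string) string {
	idx := computeHash63(key) / p.hashRingStep
	// Defensive clamp, an addition in this sketch: when the worker count does
	// not evenly divide 1<<63, integer division can yield idx == len(workers)
	// for hashes at the very top of the range (power-of-two counts are safe).
	if idx >= uint64(len(p.workers)) {
		idx = uint64(len(p.workers)) - 1
	}
	return p.workers[idx]
}

func main() {
	p := newPool(32)
	// Deterministic per key: repeated calls return the same worker.
	fmt.Println(p.getWorker("account_1:send_email"))
}
```

The point of this layout is that each worker owns a contiguous slice of the hash space, so every operation on a given key is serialized through one worker's channels and the per-worker cache needs no locking.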
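The deleted `GetRateLimit()`/`dispatch()` pair used the classic request-with-reply-channel pattern: the caller sends a request that carries its own buffered reply channel, then waits on either that channel or its context. The sketch below is a generic illustration of that pattern under assumed names (`request`, `serve`); it is not the removed implementation, which additionally carried metrics, tracing, and multiple request kinds.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// request carries the caller's context and a reply channel, mirroring the
// shape of the removed pool's request struct.
type request struct {
	ctx  context.Context
	key  string
	resp chan string
}

// serve is a stand-in for Worker.dispatch: one goroutine owns its state and
// answers requests one at a time, so no locks are needed.
func serve(in <-chan request, done <-chan struct{}) {
	for {
		select {
		case req := <-in:
			answer := "handled:" + req.key
			select {
			case req.resp <- answer: // reply delivered
			case <-req.ctx.Done(): // caller gave up; drop the reply
			}
		case <-done:
			return
		}
	}
}

func main() {
	in := make(chan request)
	done := make(chan struct{})
	defer close(done)
	go serve(in, done)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// A buffer of 1 lets the worker reply without blocking even if the
	// caller has already timed out.
	req := request{ctx: ctx, key: "account_1", resp: make(chan string, 1)}
	select {
	case in <- req:
	case <-ctx.Done():
		fmt.Println(ctx.Err())
		return
	}
	select {
	case out := <-req.resp:
		fmt.Println(out)
	case <-ctx.Done():
		fmt.Println(ctx.Err())
	}
}
```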