Skip to content

Commit

Permalink
Merge pull request #205 from denistakeda/issue-63-rate-limiter Introd…
Browse files Browse the repository at this point in the history
…uce rate limiter
  • Loading branch information
rekby authored Apr 3, 2023
2 parents 0130084 + b8213c1 commit 5f23f42
Show file tree
Hide file tree
Showing 34 changed files with 3,238 additions and 2 deletions.
14 changes: 14 additions & 0 deletions cmd/static/default-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ HTTPSBackend = false
# Ignore backend https certificate validations if HTTPSBackend is true
HTTPSBackendIgnoreCert = true

# Maximum amount of requests per host in a unit of time defined at "RateLimitTimeWindow".
# 0 means no rate limit
RateLimit = 0

# Value in milliseconds
RateLimitTimeWindowMs = 1000

# The number of requests per host that can be handled in paralles. The default value 0
# means the value will be the same as RateLimit
RateLimitBurst = 0

# The size of LRU cache for the rate limiting information
RateLimitCacheSize = 100000

[CheckDomains]

# Allow domain if it resolver for one of public IPs of this server.
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ require (
)

require (
github.com/hashicorp/golang-lru/v2 v2.0.1
github.com/jonboulle/clockwork v0.4.0
github.com/letsencrypt/pebble/v2 v2.4.0
github.com/rekby/fastuuid v0.9.0
golang.org/x/time v0.3.0
)

require (
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,16 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4=
github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hexdigest/gowrap v1.1.7/go.mod h1:Z+nBFUDLa01iaNM+/jzoOA1JJ7sm51rnYFauKFUB5fs=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
Expand Down Expand Up @@ -417,6 +421,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
Expand Down
18 changes: 16 additions & 2 deletions internal/proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type Config struct {
HTTPSBackend bool
HTTPSBackendIgnoreCert bool
EnableAccessLog bool
RateLimit int
RateLimitTimeWindowMs int
RateLimitBurst int
RateLimitCacheSize int
}

func (c *Config) Apply(ctx context.Context, p *HTTPProxy) error {
Expand All @@ -42,11 +46,21 @@ func (c *Config) Apply(ctx context.Context, p *HTTPProxy) error {
chain = append(chain, director)
}

rateLimiter, resErr := NewRateLimiter(RateLimitParams{
RateLimit: c.RateLimit,
TimeWindow: time.Duration(c.RateLimitTimeWindowMs) * time.Millisecond,
Burst: c.RateLimitBurst,
CacheSize: c.RateLimitCacheSize,
})

appendDirector(c.getDefaultTargetDirector)
appendDirector(c.getMapDirector)
appendDirector(c.getHeadersDirector)
appendDirector(c.getSchemaDirector)
p.HTTPTransport = Transport{c.HTTPSBackendIgnoreCert}
p.HTTPTransport = Transport{
IgnoreHTTPSCertificate: c.HTTPSBackendIgnoreCert,
RateLimiter: rateLimiter,
}
p.EnableAccessLog = c.EnableAccessLog

if resErr != nil {
Expand Down Expand Up @@ -90,7 +104,7 @@ func (c *Config) getDefaultTargetDirector(ctx context.Context) (Director, error)
return NewDirectorHost(defaultTarget.String()), nil
}

//can return nil,nil
// can return nil,nil
func (c *Config) getHeadersDirector(ctx context.Context) (Director, error) {
logger := zc.L(ctx)

Expand Down
93 changes: 93 additions & 0 deletions internal/proxy/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package proxy

import (
"net/http"
"sync"
"time"

"github.com/hashicorp/golang-lru/v2"
"github.com/jonboulle/clockwork"
"golang.org/x/time/rate"
)

type RateLimiter struct {
rateLimit int
timeWindow time.Duration
burst int

clock clockwork.Clock
mx sync.RWMutex
cache *lru.Cache[string, *rate.Limiter]
}

type RateLimitParams struct {
RateLimit int
TimeWindow time.Duration
Burst int
CacheSize int
Clock clockwork.Clock
}

func NewRateLimiter(params RateLimitParams) (*RateLimiter, error) {
if params.RateLimit == 0 {
return &RateLimiter{}, nil
}

cache, err := lru.New[string, *rate.Limiter](params.CacheSize)
if err != nil {
return nil, err
}

self := &RateLimiter{
rateLimit: params.RateLimit,
timeWindow: params.TimeWindow,
burst: params.Burst,
cache: cache,
clock: params.Clock,
}

if self.clock == nil {
self.clock = clockwork.NewRealClock()
}

return self, nil
}

func (rl *RateLimiter) Allow(r *http.Request) bool {
if rl.rateLimit == 0 {
return true
}
if r.Context().Err() != nil {
return false
}

return rl.getLimiter(r).AllowN(rl.clock.Now(), 1)
}

func (rl *RateLimiter) getLimiter(r *http.Request) *rate.Limiter {
rl.mx.RLock()
ip := getIP(r)

limiter, ok := rl.cache.Get(ip)
if ok {
rl.mx.RUnlock()
return limiter
}

rl.mx.RUnlock()
rl.mx.Lock()
defer rl.mx.Unlock()

// we need to check cache again to avoid data race
limiter, ok = rl.cache.Get(ip)
if !ok {
limiter = rate.NewLimiter(rate.Limit(float64(rl.rateLimit)/rl.timeWindow.Seconds()), rl.burst)
rl.cache.Add(ip, limiter)
}

return limiter
}

func getIP(r *http.Request) string {
return r.RemoteAddr
}
161 changes: 161 additions & 0 deletions internal/proxy/rate_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package proxy

import (
"context"
"net/http"
"testing"
"time"

"github.com/jonboulle/clockwork"
"github.com/maxatome/go-testdeep"
)

func TestRateLimiter_Allow(t *testing.T) {
t.Run("should trigger an error if cache size is less than zero", func(t *testing.T) {
_, err := NewRateLimiter(RateLimitParams{RateLimit: 1, CacheSize: -1})
testdeep.CmpError(t, err)
})

t.Run("should always allow if rate limit is zero", func(t *testing.T) {
td := testdeep.NewT(t)
rateLimiter, err := NewRateLimiter(RateLimitParams{RateLimit: 0})
req, _ := http.NewRequest(http.MethodGet, "http://url.com", nil)

td.CmpNoError(err)
td.True(rateLimiter.Allow(req))
})
}

func TestMaxRequestsPerSec(t *testing.T) {
req1, _ := http.NewRequest("GET", "http://url1.com", nil)
req1.RemoteAddr = "ip1"

req2, _ := http.NewRequest("GET", "http://url2.com", nil)
req2.RemoteAddr = "ip2"

ctx, cancel := context.WithCancel(context.Background())
cancel()
canceledReq, _ := http.NewRequestWithContext(ctx, "GET", "http://canceled.com", nil)
canceledReq.RemoteAddr = "canceled"

type reqSpec struct {
req *http.Request
wantAllowedRequestAround int // expected the limiter to allow +-1
}

tests := []struct {
name string

rateLimit int
timeWindow time.Duration
testTime time.Duration

reqSpecs []reqSpec
}{
{
name: "should limit the amount of requests per second",

rateLimit: 10,
timeWindow: time.Second,
testTime: time.Second,

reqSpecs: []reqSpec{
{
req: req1,
wantAllowedRequestAround: 10,
},
},
},
{
name: "should restart the timer for the next time window",

rateLimit: 10,
timeWindow: 500 * time.Millisecond,
testTime: time.Second,

reqSpecs: []reqSpec{
{
req: req1,
wantAllowedRequestAround: 20,
},
},
},
{
name: "requests from different IPs should NOT influence each other",

rateLimit: 10,
timeWindow: time.Second,
testTime: time.Second,

reqSpecs: []reqSpec{
{
req: req1,
wantAllowedRequestAround: 10,
},
{
req: req2,
wantAllowedRequestAround: 10,
},
},
},
{
name: "canceled request should always fail",

rateLimit: 10,
timeWindow: time.Second,
testTime: time.Second,

reqSpecs: []reqSpec{
{
req: canceledReq,
wantAllowedRequestAround: 0,
},
},
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Preparations
startTime := time.Now()
clock := clockwork.NewFakeClockAt(startTime)
limiter, err := NewRateLimiter(RateLimitParams{
RateLimit: tt.rateLimit,
TimeWindow: tt.timeWindow,
Burst: 1,
CacheSize: 100,
Clock: clock,
})
testdeep.CmpNoError(t, err)

successCounters := make([]int, len(tt.reqSpecs))
reqCounters := make([]int, len(tt.reqSpecs))

// The test itself
for clock.Since(startTime) < tt.testTime {
for idx, spec := range tt.reqSpecs {
reqCounters[idx]++
if limiter.Allow(spec.req) {
successCounters[idx]++
}
}
clock.Advance(time.Millisecond)
}

// Check the expectations
for idx, spec := range tt.reqSpecs {
testdeep.CmpBetween(
t,
successCounters[idx],
spec.wantAllowedRequestAround-1,
spec.wantAllowedRequestAround+1,
testdeep.BoundsInIn,
)
testdeep.CmpGt(t, reqCounters[idx], successCounters[idx])
}
})
}
}
13 changes: 13 additions & 0 deletions internal/proxy/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,22 @@ var defaultHTTPTransport = defaultTransport()

type Transport struct {
IgnoreHTTPSCertificate bool
RateLimiter *RateLimiter
}

func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if !t.RateLimiter.Allow(req) {
return &http.Response{
Status: "429 Too Many Requests",
StatusCode: http.StatusTooManyRequests,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Request: req,
Header: make(http.Header, 0),
}, nil
}

return t.getTransport(req).RoundTrip(req)
}

Expand Down
Loading

0 comments on commit 5f23f42

Please sign in to comment.