Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat ratelimiter #46

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
hooks:
- id: yamlfmt
- repo: https://github.com/crate-ci/typos
rev: v1.23.6
rev: v1.24.5
hooks:
- id: typos
- repo: local
Expand Down
111 changes: 111 additions & 0 deletions pkg/middleware/ratelimiter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# ratelimiter

Rate limiter for any resource type (not just http requests), inspired by Cloudflare's approach: [How we built rate limiting capable of scaling to millions of domains.](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/)

## Usage

```go
package main

import (
"fmt"
"log"
"time"

"github.com/theopenlane/core/pkg/middleware/ratelimiter"
)

func main() {
limitedKey := "key"
windowSize := 1 * time.Minute
// create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second
dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second)

var maxLimit int64 = 5
// allow 5 requests per windowSize (1 minute)
rateLimiter := ratelimiter.New(dataStore, maxLimit, windowSize)

for i := 0; i < 10; i++ {
limitStatus, err := rateLimiter.Check(limitedKey)
if err != nil {
log.Fatal(err)
}

if limitStatus.IsLimited {
fmt.Printf("too high rate for key: %s: rate: %f, limit: %d\nsleep: %s", limitedKey, limitStatus.CurrentRate, maxLimit, *limitStatus.LimitDuration)
time.Sleep(*limitStatus.LimitDuration)
} else {
err := rateLimiter.Inc(limitedKey)
if err != nil {
log.Fatal(err)
}
}
}
}
```

### Rate-limit IP requests in http middleware

```go
func rateLimitMiddleware(rateLimiter *ratelimiter.RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteIP := GetRemoteIP([]string{"X-Forwarded-For", "RemoteAddr", "X-Real-IP"}, 0, r)
key := fmt.Sprintf("%s_%s_%s", remoteIP, r.URL.String(), r.Method)

limitStatus, err := rateLimiter.Check(key)
if err != nil {
// if rate limit error then pass the request
next.ServeHTTP(w, r)
}

if limitStatus.IsLimited {
w.WriteHeader(http.StatusTooManyRequests)
return
}

rateLimiter.Inc(key)
next.ServeHTTP(w, r)
})
}
}

func hello(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
}

func main() {
windowSize := 1 * time.Minute
// create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second
dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second)
// allow 5 requests per windowSize (1 minute)
rateLimiter := ratelimiter.New(dataStore, 5, windowSize)

rateLimiterHandler := rateLimitMiddleware(rateLimiter)
helloHandler := http.HandlerFunc(hello)
http.Handle("/", rateLimiterHandler(helloHandler))

log.Fatal(http.ListenAndServe(":8080", nil))

}
```
See full [example](./examples/http_middleware/http_middleware.go)

### Implement your own limit data store (or external persistence method)

To use custom data store (memcached, Redis, MySQL etc.) you just need to implement the `LimitStore` interface, for example:

```go
type FakeDataStore struct{}

func (f FakeDataStore) Inc(key string, window time.Time) error {
return nil
}

func (f FakeDataStore) Get(key string, previousWindow, currentWindow time.Time) (prevValue int64, currValue int64, err error) {
return 0, 0, nil
}

rateLimiter := ratelimiter.New(FakeDataStore{}, maxLimit, windowSize)

```
2 changes: 2 additions & 0 deletions pkg/middleware/ratelimiter/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package ratelimiter is a ratelimiter based on cloudflare's approach
package ratelimiter
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package main
matoszz marked this conversation as resolved.
Show resolved Hide resolved

import (
"fmt"
"html"
"log"
"net"
"net/http"
"strings"
"time"

"github.com/theopenlane/core/pkg/middleware/ratelimiter"
)

// GetRemoteIP returns the remote IP address of the request
func GetRemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")

for _, lookup := range ipLookups {
if lookup == "RemoteAddr" {
// 1. Cover the basic use cases for both ipv4 and ipv6
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// 2. Upon error, just return the remote addr.
return r.RemoteAddr
}
return ip
}
if lookup == "X-Forwarded-For" && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
parts := strings.Split(forwardedFor, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}

partIndex := len(parts) - 1 - forwardedForIndexFromBehind
if partIndex < 0 {
partIndex = 0
}

return parts[partIndex]
}
if lookup == "X-Real-IP" && realIP != "" {
return realIP
}
}

return ""
}

// rateLimitMiddleware is a middleware that rate limits the requests
func rateLimitMiddleware(rateLimiter *ratelimiter.RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteIP := GetRemoteIP([]string{"X-Forwarded-For", "RemoteAddr", "X-Real-IP"}, 0, r)
key := fmt.Sprintf("%s_%s_%s", remoteIP, r.URL.String(), r.Method)

limitStatus, err := rateLimiter.Check(key)
if err != nil {
// if rate limit error then pass the request
next.ServeHTTP(w, r)
}
if limitStatus.IsLimited {
w.WriteHeader(http.StatusTooManyRequests)
return
}

if err := rateLimiter.Inc(key); err != nil {
log.Printf("could not increment key: %s", key)
}

next.ServeHTTP(w, r)
})
}
}

func hello(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
}

func main() {
windowSize := 1 * time.Minute
dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second)
rateLimiter := ratelimiter.New(dataStore, 5, windowSize)
rateLimiterHandler := rateLimitMiddleware(rateLimiter)
helloHandler := http.HandlerFunc(hello)
http.Handle("/", rateLimiterHandler(helloHandler))
log.Fatal(http.ListenAndServe(":8080", nil))
}
36 changes: 36 additions & 0 deletions pkg/middleware/ratelimiter/examples/simple/simple.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

import (
"fmt"
"log"
"time"

"github.com/theopenlane/core/pkg/middleware/ratelimiter"
)

func main() {
limitedKey := "key"
windowSize := 1 * time.Minute
// create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second
dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second)

var maxLimit int64 = 5
// allow 5 requests per windowSize (1 minute)
rateLimiter := ratelimiter.New(dataStore, maxLimit, windowSize)

for i := 0; i < 10; i++ {
limitStatus, err := rateLimiter.Check(limitedKey)
if err != nil {
log.Fatal(err)
}
if limitStatus.IsLimited {
fmt.Printf("too high rate for key: %s: rate: %f, limit: %d\nsleep: %s", limitedKey, limitStatus.CurrentRate, maxLimit, *limitStatus.LimitDuration)
time.Sleep(*limitStatus.LimitDuration)
} else {
err := rateLimiter.Inc(limitedKey)
if err != nil {
log.Fatal(err)
}
}
}
}
87 changes: 87 additions & 0 deletions pkg/middleware/ratelimiter/map_limit_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package ratelimiter

import (
"fmt"
"sync"
"time"
)

// MapLimitStore represents a data structure for in-memory storage of ratelimiter information
type MapLimitStore struct {
// data is a map of key to limitValue
data map[string]limitValue
// mutex is a mutex for data map
mutex sync.RWMutex
// expirationTime is the time after which the data is considered expired
expirationTime time.Duration
}

// limitValue represents value of the limit counter
type limitValue struct {
val int64
lastUpdate time.Time
}

// NewMapLimitStore creates new in-memory data store for internal limiter data
func NewMapLimitStore(expirationTime time.Duration, flushInterval time.Duration) (m *MapLimitStore) {
m = &MapLimitStore{
data: make(map[string]limitValue),
expirationTime: expirationTime,
}
matoszz marked this conversation as resolved.
Show resolved Hide resolved

go func() {
ticker := time.NewTicker(flushInterval)

for range ticker.C {
m.mutex.Lock()
for key, val := range m.data {
if val.lastUpdate.Before(time.Now().UTC().Add(-m.expirationTime)) {
delete(m.data, key)
}
}

m.mutex.Unlock()
}
}()

return m
}

// Inc increments current window limit counter
func (m *MapLimitStore) Inc(key string, window time.Time) error {
m.mutex.Lock()

defer m.mutex.Unlock()

data := m.data[mapKey(key, window)]
data.val++
data.lastUpdate = time.Now().UTC()
m.data[mapKey(key, window)] = data

return nil
}

// Get gets value of previous window counter and current window counter
func (m *MapLimitStore) Get(key string, previousWindow, currentWindow time.Time) (prevValue int64, currValue int64, err error) {
m.mutex.RLock()

defer m.mutex.RUnlock()

prevValue = m.data[mapKey(key, previousWindow)].val
currValue = m.data[mapKey(key, currentWindow)].val

return
}

// Size returns current length of data map
func (m *MapLimitStore) Size() int {
m.mutex.RLock()
defer m.mutex.RUnlock()

return len(m.data)
}

// mapKey creates a key for the map
func mapKey(key string, window time.Time) string {
return fmt.Sprintf("%s_%s", key, window.Format(time.RFC3339))
}
Loading