From 659c61e6d6c41d314946d3a37d7f5680ac97cc26 Mon Sep 17 00:00:00 2001
From: Shiming Zhang
Date: Tue, 11 Jun 2024 14:42:29 +0800
Subject: [PATCH] Support speed limit

---
 cmd/crproxy/main.go   |  48 +++++++++++-
 crproxy.go            | 170 +++++++++++++++++++++++++++++++++++++-----
 internal/maps/maps.go | 104 ++++++++++++++++++++++++++
 3 files changed, 300 insertions(+), 22 deletions(-)
 create mode 100644 internal/maps/maps.go

diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go
index bf17fe2..c30e565 100644
--- a/cmd/crproxy/main.go
+++ b/cmd/crproxy/main.go
@@ -30,7 +30,9 @@ var (
     address string
     userpass []string
     disableKeepAlives []string
+    limitDelay bool
     blobsSpeedLimit string
+    ipsSpeedLimit string
     totalBlobsSpeedLimit string
     blockImageList []string
     retry int
@@ -46,7 +48,9 @@ func init() {
     pflag.StringSliceVarP(&userpass, "user", "u", nil, "host and username and password -u user:pwd@host")
     pflag.StringVarP(&address, "address", "a", ":8080", "listen on the address")
     pflag.StringSliceVar(&disableKeepAlives, "disable-keep-alives", nil, "disable keep alives for the host")
+    pflag.BoolVar(&limitDelay, "limit-delay", false, "limit with delay")
     pflag.StringVar(&blobsSpeedLimit, "blobs-speed-limit", "", "blobs speed limit per second (default unlimited)")
+    pflag.StringVar(&ipsSpeedLimit, "ips-speed-limit", "", "ips speed limit per second (default unlimited)")
     pflag.StringVar(&totalBlobsSpeedLimit, "total-blobs-speed-limit", "", "total blobs speed limit per second (default unlimited)")
     pflag.StringSliceVar(&blockImageList, "block-image-list", nil, "block image list")
     pflag.IntVar(&retry, "retry", 0, "retry times")
@@ -162,13 +166,22 @@ func main() {
         opts = append(opts, crproxy.WithUserAndPass(bc))
     }

+    if ipsSpeedLimit != "" {
+        b, d, err := getLimit(ipsSpeedLimit)
+        if err != nil {
+            logger.Println("failed to FromHumanSize:", err)
+            os.Exit(1)
+        }
+        opts = append(opts, crproxy.WithIPsSpeedLimit(b, d))
+    }
+
     if blobsSpeedLimit != "" {
-        b, err := geario.FromHumanSize(blobsSpeedLimit)
+        b, d, err := getLimit(blobsSpeedLimit)
         if err != nil {
             logger.Println("failed to FromHumanSize:", err)
             os.Exit(1)
         }
-        opts = append(opts, crproxy.WithBlobsSpeedLimit(b))
+        opts = append(opts, crproxy.WithBlobsSpeedLimit(b, d))
     }

     if totalBlobsSpeedLimit != "" {
@@ -183,6 +196,9 @@ func main() {
     if retry > 0 {
         opts = append(opts, crproxy.WithRetry(retry, retryInterval))
     }
+    if limitDelay {
+        opts = append(opts, crproxy.WithLimitDelay(true))
+    }

     crp, err := crproxy.NewCRProxy(opts...)
     if err != nil {
@@ -211,3 +227,31 @@ func main() {
         os.Exit(1)
     }
 }
+
+func getLimit(s string) (geario.B, time.Duration, error) {
+    i := strings.Index(s, "/")
+    if i == -1 {
+        b, err := geario.FromHumanSize(s)
+        if err != nil {
+            return 0, 0, err
+        }
+        return b, time.Second, nil
+    }
+
+    b, err := geario.FromHumanSize(s[:i])
+    if err != nil {
+        return 0, 0, err
+    }
+
+    dur := s[i+1:]
+    if dur[0] < '0' || dur[0] > '9' {
+        dur = "1" + dur
+    }
+
+    d, err := time.ParseDuration(dur)
+    if err != nil {
+        return 0, 0, err
+    }
+
+    return b, d, nil
+}
diff --git a/crproxy.go b/crproxy.go
index 61b1aa9..5e9b0fd 100644
--- a/crproxy.go
+++ b/crproxy.go
@@ -14,6 +14,7 @@ import (
     "strings"
     "sync"
     "time"
+    "strconv"

     "github.com/distribution/distribution/v3/registry/api/errcode"
     "github.com/distribution/distribution/v3/registry/client/auth"
@@ -23,12 +24,12 @@ import (
     "github.com/wzshiming/geario"
     "github.com/wzshiming/httpseek"
     "github.com/wzshiming/lru"
+    "github.com/wzshiming/crproxy/internal/maps"
 )

 var (
-    prefix = "/v2/"
-    catalog = prefix + "_catalog"
-    speedLimitDuration = time.Second
+    prefix = "/v2/"
+    catalog = prefix + "_catalog"
 )

 type Logger interface {
@@ -50,7 +51,11 @@ type CRProxy struct {
     bytesPool sync.Pool
     logger Logger
     totalBlobsSpeedLimit *geario.Gear
+    speedLimitRecord maps.SyncMap[string, *geario.BPS]
     blobsSpeedLimit *geario.B
+    blobsSpeedLimitDuration time.Duration
+    ipsSpeedLimit *geario.B
+    ipsSpeedLimitDuration time.Duration
     blockFunc func(*PathInfo) bool
     retry int
     retryInterval time.Duration
@@ -58,10 +63,17 @@ type CRProxy struct {
     linkExpires time.Duration
     mutCache sync.Map
     redirectLinks *url.URL
+    limitDelay bool
 }

 type Option func(c *CRProxy)

+func WithLimitDelay(b bool) Option {
+    return func(c *CRProxy) {
+        c.limitDelay = b
+    }
+}
+
 func WithLinkExpires(d time.Duration) Option {
     return func(c *CRProxy) {
         c.linkExpires = d
@@ -80,15 +92,23 @@ func WithStorageDriver(storageDriver storagedriver.StorageDriver) Option {
     }
 }

-func WithBlobsSpeedLimit(limit geario.B) Option {
+func WithBlobsSpeedLimit(limit geario.B, duration time.Duration) Option {
     return func(c *CRProxy) {
         c.blobsSpeedLimit = &limit
+        c.blobsSpeedLimitDuration = duration
+    }
+}
+
+func WithIPsSpeedLimit(limit geario.B, duration time.Duration) Option {
+    return func(c *CRProxy) {
+        c.ipsSpeedLimit = &limit
+        c.ipsSpeedLimitDuration = duration
     }
 }

 func WithTotalBlobsSpeedLimit(limit geario.B) Option {
     return func(c *CRProxy) {
-        c.totalBlobsSpeedLimit = geario.NewGear(speedLimitDuration, limit)
+        c.totalBlobsSpeedLimit = geario.NewGear(time.Second, limit)
     }
 }

@@ -401,6 +421,10 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
     r.URL.Scheme = c.getScheme(info.Host)
     r.URL.Path = path

+    if !c.checkLimit(rw, r, info) {
+        return
+    }
+
     if c.storageDriver != nil && info.Blobs != "" {
         c.cacheBlobResponse(rw, r, info)
         return
@@ -437,6 +461,8 @@ func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *
     rw.WriteHeader(resp.StatusCode)

     if r.Method != http.MethodHead {
+        c.accumulativeLimit(rw, r, info, resp.ContentLength)
+
         buf := c.bytesPool.Get().([]byte)
         defer c.bytesPool.Put(buf)
         var body io.Reader = resp.Body
@@ -446,7 +472,7 @@ func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *
         }

         if c.blobsSpeedLimit != nil && info.Blobs != "" {
-            body = geario.NewGear(speedLimitDuration, *c.blobsSpeedLimit).Reader(body)
+            body = geario.NewGear(c.blobsSpeedLimitDuration, *c.blobsSpeedLimit).Reader(body)
         }

         io.CopyBuffer(rw, body, buf)
@@ -481,8 +507,9 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
         close(closeCh)
     }

-    _, err := c.storageDriver.Stat(ctx, blobPath)
+    stat, err := c.storageDriver.Stat(ctx, blobPath)
     if err == nil {
+        c.accumulativeLimit(rw, r, info, stat.Size())
         err = c.redirect(rw, r, blobPath)
         if err == nil {
             doneCache()
@@ -496,23 +523,33 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
         c.logger.Println("Cache miss", blobPath)
     }

-    errCh := make(chan error, 1)
+    type repo struct {
+        err error
+        size int64
+    }
+    signalCh := make(chan repo, 1)
     go func() {
         defer doneCache()
-        err = c.cacheBlobContent(r, blobPath, info)
-        errCh <- err
+        size, err := c.cacheBlobContent(r, blobPath, info)
+        signalCh <- repo{
+            err: err,
+            size: size,
+        }
     }()

     select {
     case <-ctx.Done():
         c.errorResponse(rw, r, ctx.Err())
         return
-    case err := <-errCh:
-        if err != nil {
-            c.errorResponse(rw, r, err)
+    case signal := <-signalCh:
+        if signal.err != nil {
+            c.errorResponse(rw, r, signal.err)
             return
         }
+
+        c.accumulativeLimit(rw, r, info, signal.size)
+
         err = c.redirect(rw, r, blobPath)
         if err != nil {
             if c.logger != nil {
@@ -523,11 +560,11 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
     }
 }

-func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) error {
+func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) (int64, error) {
     cli := c.getClientset(info.Host, info.Image)
     resp, err := c.doWithAuth(cli, r, info.Host)
     if err != nil {
-        return err
+        return 0, err
     }
     defer func() {
         resp.Body.Close()
@@ -538,28 +575,32 @@ func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathI

     fw, err := c.storageDriver.Writer(context.Background(), blobPath, false)
     if err != nil {
-        return err
+        return 0, err
     }

     h := sha256.New()
     n, err := io.CopyBuffer(fw, io.TeeReader(resp.Body, h), buf)
     if err != nil {
         fw.Cancel()
-        return err
+        return 0, err
     }

     if n != resp.ContentLength {
         fw.Cancel()
-        return fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
+        return 0, fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
     }

     hash := hex.EncodeToString(h.Sum(nil)[:])
     if info.Blobs[7:] != hash {
         fw.Cancel()
-        return fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
+        return 0, fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
     }

-    return fw.Commit()
+    err = fw.Commit()
+    if err != nil {
+        return 0, err
+    }
+    return n, nil
 }

 func (c *CRProxy) errorResponse(rw http.ResponseWriter, r *http.Request, err error) {
@@ -576,6 +617,95 @@ func (c *CRProxy) notFoundResponse(rw http.ResponseWriter, r *http.Request) {
     http.NotFound(rw, r)
 }

+var (
+    ErrorCodeTooManyRequests = errcode.ErrorCodeTooManyRequests
+
+    ErrorCodeTooManyBandwidthsByBlob = errcode.Register("errcode", errcode.ErrorDescriptor{
+        Value: "TOOMANYBANDWIDTHS",
+        Message: "blob too many bandwidths",
+        Description: `Blobs are accessed too much`,
+        HTTPStatusCode: http.StatusTooManyRequests,
+    })
+)
+
+func addr(str string) string {
+    i := strings.LastIndex(str, ":")
+    if i <= 0 {
+        return ""
+    }
+    return str[:i]
+}
+
+func (c *CRProxy) checkLimit(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool {
+    if c.blobsSpeedLimit != nil && info.Blobs != "" {
+        bps, _ := c.speedLimitRecord.LoadOrStore(info.Blobs, geario.NewBPSAver(c.blobsSpeedLimitDuration))
+        aver := bps.Aver()
+        if aver > *c.blobsSpeedLimit {
+            if c.logger != nil {
+                c.logger.Println("exceed limit", info.Blobs, aver, *c.blobsSpeedLimit)
+            }
+            if c.limitDelay {
+                select {
+                case <-r.Context().Done():
+                    return false
+                case <-time.After(bps.Next().Sub(time.Now())):
+                }
+
+            } else {
+                err := ErrorCodeTooManyBandwidthsByBlob
+                rw.Header().Set("X-Retry-After", strconv.FormatInt(bps.Next().Unix(), 10))
+                errcode.ServeJSON(rw, err)
+                return false
+            }
+        }
+    }
+
+    if c.ipsSpeedLimit != nil && info.Blobs != "" {
+        address := addr(r.RemoteAddr)
+        bps, _ := c.speedLimitRecord.LoadOrStore(address, geario.NewBPSAver(c.ipsSpeedLimitDuration))
+        aver := bps.Aver()
+        if aver > *c.ipsSpeedLimit {
+            if c.logger != nil {
+                c.logger.Println("exceed limit", address, aver, *c.ipsSpeedLimit)
+            }
+            if c.limitDelay {
+                select {
+                case <-r.Context().Done():
+                    return false
+                case <-time.After(bps.Next().Sub(time.Now())):
+                }
+            } else {
+                err := ErrorCodeTooManyRequests
+                rw.Header().Set("X-Retry-After", strconv.FormatInt(bps.Next().Unix(), 10))
+                errcode.ServeJSON(rw, err)
+                return false
+            }
+        }
+    }
+
+    return true
+}
+
+func (c *CRProxy) accumulativeLimit(rw http.ResponseWriter, r *http.Request, info *PathInfo, size int64) {
+    if r.Method != http.MethodGet {
+        return
+    }
+
+    if c.blobsSpeedLimit != nil && info.Blobs != "" {
+        bps, ok := c.speedLimitRecord.Load(info.Blobs)
+        if ok {
+            bps.Add(geario.B(size))
+        }
+    }
+
+    if c.ipsSpeedLimit != nil && info.Blobs != "" {
+        bps, ok := c.speedLimitRecord.Load(addr(r.RemoteAddr))
+        if ok {
+            bps.Add(geario.B(size))
+        }
+    }
+}
+
 func (c *CRProxy) redirect(rw http.ResponseWriter, r *http.Request, blobPath string) error {
     options := map[string]interface{}{
         "method": r.Method,
diff --git a/internal/maps/maps.go b/internal/maps/maps.go
new file mode 100644
index 0000000..f806f5f
--- /dev/null
+++ b/internal/maps/maps.go
@@ -0,0 +1,104 @@
+package maps
+
+import (
+    "sync"
+)
+
+// SyncMap is a wrapper around sync.Map that provides a few additional methods.
+type SyncMap[K comparable, V any] struct {
+    m sync.Map
+}
+
+// Load returns the value stored in the map for a key,
+// or nil if no value is present.
+func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) {
+    v, ok := m.m.Load(key)
+    if !ok {
+        return value, false
+    }
+    return v.(V), true
+}
+
+// Store sets the value for a key.
+func (m *SyncMap[K, V]) Store(key K, value V) {
+    m.m.Store(key, value)
+}
+
+// Delete deletes the value for a key.
+func (m *SyncMap[K, V]) Delete(key K) {
+    m.m.Delete(key)
+}
+
+// Range calls f sequentially for each key and value present in the map.
+func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) {
+    m.m.Range(func(key, value interface{}) bool {
+        return f(key.(K), value.(V))
+    })
+}
+
+// LoadAndDelete deletes the value for a key, returning the previous value if any.
+func (m *SyncMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
+    v, loaded := m.m.LoadAndDelete(key)
+    if !loaded {
+        return value, loaded
+    }
+    return v.(V), loaded
+}
+
+// LoadOrStore returns the existing value for the key if present.
+func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (V, bool) {
+    v, loaded := m.m.LoadOrStore(key, value)
+    if !loaded {
+        return value, loaded
+    }
+    return v.(V), loaded
+}
+
+// Swap stores value for key and returns the previous value for that key.
+func (m *SyncMap[K, V]) Swap(key K, value V) (V, bool) {
+    v, loaded := m.m.Swap(key, value)
+    if !loaded {
+        return value, loaded
+    }
+    return v.(V), loaded
+}
+
+// Size returns the number of items in the map.
+func (m *SyncMap[K, V]) Size() int {
+    size := 0
+    m.m.Range(func(key, value interface{}) bool {
+        size++
+        return true
+    })
+    return size
+}
+
+// Keys returns all the keys in the map.
+func (m *SyncMap[K, V]) Keys() []K {
+    keys := []K{}
+    m.m.Range(func(key, value interface{}) bool {
+        keys = append(keys, key.(K))
+        return true
+    })
+    return keys
+}
+
+// Values returns all the values in the map.
+func (m *SyncMap[K, V]) Values() []V {
+    values := []V{}
+    m.m.Range(func(key, value interface{}) bool {
+        values = append(values, value.(V))
+        return true
+    })
+    return values
+}
+
+// IsEmpty returns true if the map is empty.
+func (m *SyncMap[K, V]) IsEmpty() bool {
+    empty := true
+    m.m.Range(func(key, value interface{}) bool {
+        empty = false
+        return false
+    })
+    return empty
+}
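
For illustration only, a minimal sketch of how a caller might wire up the new speed-limit options programmatically. The option functions (WithBlobsSpeedLimit, WithIPsSpeedLimit, WithLimitDelay), NewCRProxy, and geario.FromHumanSize all come from the patch above; the concrete "10M" size and the 10-second window are made-up example values, not defaults.

package main

import (
    "log"
    "time"

    "github.com/wzshiming/crproxy"
    "github.com/wzshiming/geario"
)

func main() {
    // Parse a human-readable size the same way the CLI does; "10M" is
    // an arbitrary example value.
    limit, err := geario.FromHumanSize("10M")
    if err != nil {
        log.Fatal(err)
    }

    crp, err := crproxy.NewCRProxy(
        // At most `limit` bytes per 10-second window for a single blob.
        crproxy.WithBlobsSpeedLimit(limit, 10*time.Second),
        // The same budget per client IP.
        crproxy.WithIPsSpeedLimit(limit, 10*time.Second),
        // Delay requests that exceed a limit instead of answering 429.
        crproxy.WithLimitDelay(true),
    )
    if err != nil {
        log.Fatal(err)
    }
    _ = crp
}

On the command line the same settings correspond to the new flags, e.g. --blobs-speed-limit 10M/10s --ips-speed-limit 10M/10s --limit-delay; per getLimit above, a value without a /duration suffix defaults to a one-second window, and a duration without a leading digit (such as /s) is read as 1s.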
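
Likewise, a small usage sketch for the generic SyncMap wrapper added in internal/maps. The method names and signatures are taken from the file above; the keys and values are made up, and since the package lives under internal/ it is only importable from within the crproxy module itself.

package main

import (
    "fmt"

    "github.com/wzshiming/crproxy/internal/maps"
)

func main() {
    var m maps.SyncMap[string, int]

    m.Store("a", 1)

    // LoadOrStore stores 2 under "b" because the key is absent,
    // and reports loaded == false.
    v, loaded := m.LoadOrStore("b", 2)
    fmt.Println(v, loaded) // 2 false

    if v, ok := m.Load("a"); ok {
        fmt.Println("a =", v) // a = 1
    }

    fmt.Println(m.Size(), m.Keys()) // 2 and the two keys in unspecified order
}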