Skip to content

Commit

Permalink
Support spped limit
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jun 11, 2024
1 parent 0628026 commit 659c61e
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 22 deletions.
48 changes: 46 additions & 2 deletions cmd/crproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ var (
address string
userpass []string
disableKeepAlives []string
limitDelay bool
blobsSpeedLimit string
ipsSpeedLimit string
totalBlobsSpeedLimit string
blockImageList []string
retry int
Expand All @@ -46,7 +48,9 @@ func init() {
pflag.StringSliceVarP(&userpass, "user", "u", nil, "host and username and password -u user:pwd@host")
pflag.StringVarP(&address, "address", "a", ":8080", "listen on the address")
pflag.StringSliceVar(&disableKeepAlives, "disable-keep-alives", nil, "disable keep alives for the host")
pflag.BoolVar(&limitDelay, "limit-delay", false, "limit with delay")
pflag.StringVar(&blobsSpeedLimit, "blobs-speed-limit", "", "blobs speed limit per second (default unlimited)")
pflag.StringVar(&ipsSpeedLimit, "ips-speed-limit", "", "ips speed limit per second (default unlimited)")
pflag.StringVar(&totalBlobsSpeedLimit, "total-blobs-speed-limit", "", "total blobs speed limit per second (default unlimited)")
pflag.StringSliceVar(&blockImageList, "block-image-list", nil, "block image list")
pflag.IntVar(&retry, "retry", 0, "retry times")
Expand Down Expand Up @@ -162,13 +166,22 @@ func main() {
opts = append(opts, crproxy.WithUserAndPass(bc))
}

if ipsSpeedLimit != "" {
b, d, err := getLimit(ipsSpeedLimit)
if err != nil {
logger.Println("failed to FromHumanSize:", err)
os.Exit(1)
}
opts = append(opts, crproxy.WithIPsSpeedLimit(b, d))
}

if blobsSpeedLimit != "" {
b, err := geario.FromHumanSize(blobsSpeedLimit)
b, d, err := getLimit(blobsSpeedLimit)
if err != nil {
logger.Println("failed to FromHumanSize:", err)
os.Exit(1)
}
opts = append(opts, crproxy.WithBlobsSpeedLimit(b))
opts = append(opts, crproxy.WithBlobsSpeedLimit(b, d))
}

if totalBlobsSpeedLimit != "" {
Expand All @@ -183,6 +196,9 @@ func main() {
if retry > 0 {
opts = append(opts, crproxy.WithRetry(retry, retryInterval))
}
if limitDelay {
opts = append(opts, crproxy.WithLimitDelay(true))
}

crp, err := crproxy.NewCRProxy(opts...)
if err != nil {
Expand Down Expand Up @@ -211,3 +227,31 @@ func main() {
os.Exit(1)
}
}

func getLimit(s string) (geario.B, time.Duration, error) {
i := strings.Index(s, "/")
if i == -1 {
b, err := geario.FromHumanSize(s)
if err != nil {
return 0, 0, err
}
return b, time.Second, nil
}

b, err := geario.FromHumanSize(s[:i])
if err != nil {
return 0, 0, err
}

dur := s[i+1:]
if dur[0] < '0' || dur[0] > '9' {
dur = "1" + dur
}

d, err := time.ParseDuration(dur)
if err != nil {
return 0, 0, err
}

return b, d, nil
}
170 changes: 150 additions & 20 deletions crproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strings"
"sync"
"time"
"strconv"

"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/client/auth"
Expand All @@ -23,12 +24,12 @@ import (
"github.com/wzshiming/geario"
"github.com/wzshiming/httpseek"
"github.com/wzshiming/lru"
"github.com/wzshiming/crproxy/internal/maps"
)

var (
prefix = "/v2/"
catalog = prefix + "_catalog"
speedLimitDuration = time.Second
prefix = "/v2/"
catalog = prefix + "_catalog"
)

type Logger interface {
Expand All @@ -50,18 +51,29 @@ type CRProxy struct {
bytesPool sync.Pool
logger Logger
totalBlobsSpeedLimit *geario.Gear
speedLimitRecord maps.SyncMap[string, *geario.BPS]
blobsSpeedLimit *geario.B
blobsSpeedLimitDuration time.Duration
ipsSpeedLimit *geario.B
ipsSpeedLimitDuration time.Duration
blockFunc func(*PathInfo) bool
retry int
retryInterval time.Duration
storageDriver storagedriver.StorageDriver
linkExpires time.Duration
mutCache sync.Map
redirectLinks *url.URL
limitDelay bool
}

type Option func(c *CRProxy)

func WithLimitDelay(b bool) Option {
return func(c *CRProxy) {
c.limitDelay = b
}
}

func WithLinkExpires(d time.Duration) Option {
return func(c *CRProxy) {
c.linkExpires = d
Expand All @@ -80,15 +92,23 @@ func WithStorageDriver(storageDriver storagedriver.StorageDriver) Option {
}
}

func WithBlobsSpeedLimit(limit geario.B) Option {
func WithBlobsSpeedLimit(limit geario.B, duration time.Duration) Option {
return func(c *CRProxy) {
c.blobsSpeedLimit = &limit
c.blobsSpeedLimitDuration = duration
}
}

func WithIPsSpeedLimit(limit geario.B, duration time.Duration) Option {
return func(c *CRProxy) {
c.ipsSpeedLimit = &limit
c.ipsSpeedLimitDuration = duration
}
}

func WithTotalBlobsSpeedLimit(limit geario.B) Option {
return func(c *CRProxy) {
c.totalBlobsSpeedLimit = geario.NewGear(speedLimitDuration, limit)
c.totalBlobsSpeedLimit = geario.NewGear(time.Second, limit)
}
}

Expand Down Expand Up @@ -401,6 +421,10 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
r.URL.Scheme = c.getScheme(info.Host)
r.URL.Path = path

if !c.checkLimit(rw, r, info) {
return
}

if c.storageDriver != nil && info.Blobs != "" {
c.cacheBlobResponse(rw, r, info)
return
Expand Down Expand Up @@ -437,6 +461,8 @@ func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *
rw.WriteHeader(resp.StatusCode)

if r.Method != http.MethodHead {
c.accumulativeLimit(rw, r, info, resp.ContentLength)

buf := c.bytesPool.Get().([]byte)
defer c.bytesPool.Put(buf)
var body io.Reader = resp.Body
Expand All @@ -446,7 +472,7 @@ func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *
}

if c.blobsSpeedLimit != nil && info.Blobs != "" {
body = geario.NewGear(speedLimitDuration, *c.blobsSpeedLimit).Reader(body)
body = geario.NewGear(c.blobsSpeedLimitDuration, *c.blobsSpeedLimit).Reader(body)
}

io.CopyBuffer(rw, body, buf)
Expand Down Expand Up @@ -481,8 +507,9 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
close(closeCh)
}

_, err := c.storageDriver.Stat(ctx, blobPath)
stat, err := c.storageDriver.Stat(ctx, blobPath)
if err == nil {
c.accumulativeLimit(rw, r, info, stat.Size())
err = c.redirect(rw, r, blobPath)
if err == nil {
doneCache()
Expand All @@ -496,23 +523,33 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
c.logger.Println("Cache miss", blobPath)
}

errCh := make(chan error, 1)
type repo struct {
err error
size int64
}
signalCh := make(chan repo, 1)

go func() {
defer doneCache()
err = c.cacheBlobContent(r, blobPath, info)
errCh <- err
size, err := c.cacheBlobContent(r, blobPath, info)
signalCh <- repo{
err: err,
size: size,
}
}()

select {
case <-ctx.Done():
c.errorResponse(rw, r, ctx.Err())
return
case err := <-errCh:
if err != nil {
c.errorResponse(rw, r, err)
case signal := <-signalCh:
if signal.err != nil {
c.errorResponse(rw, r, signal.err)
return
}

c.accumulativeLimit(rw, r, info, signal.size)

err = c.redirect(rw, r, blobPath)
if err != nil {
if c.logger != nil {
Expand All @@ -523,11 +560,11 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf
}
}

func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) error {
func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) (int64, error) {
cli := c.getClientset(info.Host, info.Image)
resp, err := c.doWithAuth(cli, r, info.Host)
if err != nil {
return err
return 0, err
}
defer func() {
resp.Body.Close()
Expand All @@ -538,28 +575,32 @@ func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathI

fw, err := c.storageDriver.Writer(context.Background(), blobPath, false)
if err != nil {
return err
return 0, err
}

h := sha256.New()
n, err := io.CopyBuffer(fw, io.TeeReader(resp.Body, h), buf)
if err != nil {
fw.Cancel()
return err
return 0, err
}

if n != resp.ContentLength {
fw.Cancel()
return fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
return 0, fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
}

hash := hex.EncodeToString(h.Sum(nil)[:])
if info.Blobs[7:] != hash {
fw.Cancel()
return fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
return 0, fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
}

return fw.Commit()
err = fw.Commit()
if err != nil {
return 0, err
}
return n, nil
}

func (c *CRProxy) errorResponse(rw http.ResponseWriter, r *http.Request, err error) {
Expand All @@ -576,6 +617,95 @@ func (c *CRProxy) notFoundResponse(rw http.ResponseWriter, r *http.Request) {
http.NotFound(rw, r)
}

var (
ErrorCodeTooManyRequests = errcode.ErrorCodeTooManyRequests

ErrorCodeTooManyBandwidthsByBlob = errcode.Register("errcode", errcode.ErrorDescriptor{
Value: "TOOMANYBANDWIDTHS",
Message: "blob too many bandwidths",
Description: `Blobs are accessed too much`,
HTTPStatusCode: http.StatusTooManyRequests,
})
)

func addr(str string) string {
i := strings.LastIndex(str, ":")
if i <= 0 {
return ""
}
return str[:i]
}

func (c *CRProxy) checkLimit(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool {
if c.blobsSpeedLimit != nil && info.Blobs != "" {
bps, _ := c.speedLimitRecord.LoadOrStore(info.Blobs, geario.NewBPSAver(c.blobsSpeedLimitDuration))
aver := bps.Aver()
if aver > *c.blobsSpeedLimit {
if c.logger != nil {
c.logger.Println("exceed limit", info.Blobs, aver, *c.blobsSpeedLimit)
}
if c.limitDelay {
select {
case <-r.Context().Done():
return false
case <-time.After(bps.Next().Sub(time.Now())):
}

} else {
err := ErrorCodeTooManyBandwidthsByBlob
rw.Header().Set("X-Retry-After", strconv.FormatInt(bps.Next().Unix(), 10))
errcode.ServeJSON(rw, err)
return false
}
}
}

if c.ipsSpeedLimit != nil && info.Blobs != "" {
address := addr(r.RemoteAddr)
bps, _ := c.speedLimitRecord.LoadOrStore(address, geario.NewBPSAver(c.ipsSpeedLimitDuration))
aver := bps.Aver()
if aver > *c.ipsSpeedLimit {
if c.logger != nil {
c.logger.Println("exceed limit", address, aver, *c.ipsSpeedLimit)
}
if c.limitDelay {
select {
case <-r.Context().Done():
return false
case <-time.After(bps.Next().Sub(time.Now())):
}
} else {
err := ErrorCodeTooManyRequests
rw.Header().Set("X-Retry-After", strconv.FormatInt(bps.Next().Unix(), 10))
errcode.ServeJSON(rw, err)
return false
}
}
}

return true
}

func (c *CRProxy) accumulativeLimit(rw http.ResponseWriter, r *http.Request, info *PathInfo, size int64) {
if r.Method != http.MethodGet {
return
}

if c.blobsSpeedLimit != nil && info.Blobs != "" {
bps, ok := c.speedLimitRecord.Load(info.Blobs)
if ok {
bps.Add(geario.B(size))
}
}

if c.ipsSpeedLimit != nil && info.Blobs != "" {
bps, ok := c.speedLimitRecord.Load(addr(r.RemoteAddr))
if ok {
bps.Add(geario.B(size))
}
}
}

func (c *CRProxy) redirect(rw http.ResponseWriter, r *http.Request, blobPath string) error {
options := map[string]interface{}{
"method": r.Method,
Expand Down
Loading

0 comments on commit 659c61e

Please sign in to comment.