Skip to content

Commit

Permalink
Return 500 on internal or upstream errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilsol committed Jan 7, 2022
1 parent 710a382 commit 447822f
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 29 deletions.
4 changes: 2 additions & 2 deletions cache/hashmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (c *HashMapCache) Store(path []byte, host []byte, instance *commonInstance)
}
}

func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int) {
func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int, bool) {
if !c.hosts {
if instance, ok := c.data.Get(path); ok {
return instance.(*commonInstance).Get(instance.(*commonInstance), nil)
Expand All @@ -115,7 +115,7 @@ func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int) {
}
}

return "", nil, 0
return "", nil, 0, false
}

func (c *HashMapCache) Source() source.Source {
Expand Down
14 changes: 7 additions & 7 deletions cache/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ func index(source source.Source, f source.IndexFunc) (int64, error) {
return totalSize, nil
}

func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int) {
return func(instance *commonInstance, host []byte) (string, io.Reader, int) {
hijacker := c.Source().Get(instance.Instance.AbsolutePath, host)
func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int, bool) {
return func(instance *commonInstance, host []byte) (string, io.Reader, int, bool) {
hijacker, failed := c.Source().Get(instance.Instance.AbsolutePath, host)

if hijacker == nil {
return "", nil, 0
return "", nil, 0, failed
}

log.Debug().Msgf("Loaded file [%d][%s]: %s", hijacker.Size, hijacker.FileType(), instance.Instance.AbsolutePath)
Expand All @@ -71,12 +71,12 @@ func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int) {
instance.Instance.LoadTime = time.Now()
instance.Instance.Data = hijacker.Buffer
instance.Instance.ContentType = hijacker.FileType()
instance.Get = func(cache *commonInstance, _ []byte) (string, io.Reader, int) {
return cache.Instance.ContentType, bytes.NewReader(cache.Instance.Data), len(cache.Instance.Data)
instance.Get = func(cache *commonInstance, _ []byte) (string, io.Reader, int, bool) {
return cache.Instance.ContentType, bytes.NewReader(cache.Instance.Data), len(cache.Instance.Data), false
}
}

return hijacker.FileType(), hijacker, hijacker.Size
return hijacker.FileType(), hijacker, hijacker.Size, false
}
}

Expand Down
4 changes: 2 additions & 2 deletions cache/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type CachedInstance struct {

type commonInstance struct {
Instance *CachedInstance
Get func(instance *commonInstance, host []byte) (string, io.Reader, int)
Get func(instance *commonInstance, host []byte) (string, io.Reader, int, bool)
}

type KeyValue struct {
Expand All @@ -26,7 +26,7 @@ type KeyValue struct {

type Cache interface {
Index() (int64, error)
Get(path []byte, host []byte) (string, io.Reader, int)
Get(path []byte, host []byte) (string, io.Reader, int, bool)
Source() source.Source
Iter() <-chan KeyValue
Store(path []byte, host []byte, instance *commonInstance)
Expand Down
9 changes: 7 additions & 2 deletions server/webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ type Webserver struct {
}

func (h *Webserver) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
if fileType, stream, size := h.Cache.Get(ctx.Path(), ctx.Host()); size > 0 {
fileType, stream, size, failed := h.Cache.Get(ctx.Path(), ctx.Host())
if size > 0 {
ctx.SetContentType(fileType)
ctx.SetBodyStream(stream, size)
} else {
ctx.SetStatusCode(404)
if failed {
ctx.SetStatusCode(500)
} else {
ctx.SetStatusCode(404)
}
}
}

Expand Down
8 changes: 4 additions & 4 deletions source/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ var _ Source = (*Local)(nil)
type Local struct {
}

func (l Local) Get(path string, host []byte) *utils.StreamHijacker {
func (l Local) Get(path string, host []byte) (*utils.StreamHijacker, bool) {
file, err := os.OpenFile(path, os.O_RDONLY, 0664)
if err != nil {
log.Err(err).Msg("error reading file")
return nil
return nil, true
}

stat, err := file.Stat()
if err != nil {
log.Err(err).Msg("error reading file")
return nil
return nil, true
}

fileType := mime.TypeByExtension(filepath.Ext(filepath.Base(path)))

return utils.NewStreamHijacker(int(stat.Size()), fileType, file)
return utils.NewStreamHijacker(int(stat.Size()), fileType, file), false
}

func (l Local) IndexPath(dir string, f IndexFunc) (int64, int64, error) {
Expand Down
2 changes: 1 addition & 1 deletion source/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewS3(bucket string, key string, secret string, endpoint string, region str
}, nil
}

func (s S3) Get(path string, _ []byte) *utils.StreamHijacker {
func (s S3) Get(path string, _ []byte) (*utils.StreamHijacker, bool) {
return GetS3(s.S3Client, s.Bucket, path)
}

Expand Down
14 changes: 7 additions & 7 deletions source/s3_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ func NewS3Redis(network string, address string, username string, password string
}, nil
}

func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker {
func (s S3Redis) Get(path string, host []byte) (*utils.StreamHijacker, bool) {
var s3Wrapper *S3Wrapper

if host == nil {
return nil
return nil, true
}

if instance, ok := s.CredentialCache.Get(utils.ByteSliceToString(host)); ok {
Expand All @@ -60,11 +60,11 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker {
if get.Err() != nil {
if errors.Is(get.Err(), redis.Nil) {
log.Warn().Str("host", utils.ByteSliceToString(host)).Msg("no credentials found")
return nil
return nil, true
}

log.Error().Err(get.Err()).Msg("failed to get credentials")
return nil
return nil, true
}

s3Flat := yeet.GetRootAsS3(utils.UnsafeGetBytes(get.Val()), 0)
Expand All @@ -79,13 +79,13 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker {

if err != nil {
log.Err(err).Msg("failed to create new S3 session")
return nil
return nil, true
}

cf, err := cuckoo.Decode(s3Flat.Filter())
if err != nil {
log.Err(err).Msg("failed to decode filter")
return nil
return nil, true
}

s3Wrapper = &S3Wrapper{
Expand All @@ -100,7 +100,7 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker {
return GetS3(s3Wrapper.S3Client, s3Wrapper.Bucket, path)
}

return nil
return nil, false
}

func (s S3Redis) IndexPath(_ string, _ IndexFunc) (int64, int64, error) {
Expand Down
6 changes: 3 additions & 3 deletions source/s3_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
)

func GetS3(client *s3.S3, bucket string, path string) *utils.StreamHijacker {
func GetS3(client *s3.S3, bucket string, path string) (*utils.StreamHijacker, bool) {
cleanedKey := strings.TrimPrefix(path, "/")

object, err := client.GetObject(&s3.GetObjectInput{
Expand All @@ -20,10 +20,10 @@ func GetS3(client *s3.S3, bucket string, path string) *utils.StreamHijacker {

if err != nil {
log.Err(err).Msg("failed to get object")
return nil
return nil, true
}

fileType := mime.TypeByExtension(filepath.Ext(filepath.Base(path)))

return utils.NewStreamHijacker(int(*object.ContentLength), fileType, object.Body)
return utils.NewStreamHijacker(int(*object.ContentLength), fileType, object.Body), false
}
2 changes: 1 addition & 1 deletion source/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
type IndexFunc = func(absolutePath string, cleanedPath string) int64

type Source interface {
Get(path string, host []byte) *utils.StreamHijacker
Get(path string, host []byte) (*utils.StreamHijacker, bool)
IndexPath(dir string, f IndexFunc) (int64, int64, error)
Watch() (<-chan WatchEvent, error)
}
Expand Down

0 comments on commit 447822f

Please sign in to comment.