Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out gzipwriter interface #106

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 22 additions & 46 deletions gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"bufio"
"compress/gzip"
"fmt"
"github.com/NYTimes/gziphandler/writer"
"github.com/NYTimes/gziphandler/writer/stdlib"
"io"
"mime"
"net"
"net/http"
"strconv"
"strings"
"sync"
)

const (
Expand All @@ -36,48 +37,15 @@ const (
DefaultMinSize = 1400
)

// gzipWriterPools stores a sync.Pool for each compression level for reuse of
// gzip.Writers. Use poolIndex to convert a compression level to an index into
// gzipWriterPools.
var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool

// init populates gzipWriterPools with one pool per supported compression
// level: every level from gzip.BestSpeed through gzip.BestCompression, plus
// gzip.DefaultCompression.
func init() {
	level := gzip.BestSpeed
	for level <= gzip.BestCompression {
		addLevelPool(level)
		level++
	}
	addLevelPool(gzip.DefaultCompression)
}

// poolIndex translates a gzip compression level into the slot of
// gzipWriterPools that holds the pool for that level. The caller must pass a
// valid gzip compression level; no range checking is done here.
func poolIndex(level int) int {
	switch level {
	case gzip.DefaultCompression:
		// gzip.DefaultCompression is -1, so it cannot be offset like the
		// other levels; it gets the slot after gzip.BestCompression.
		return gzip.BestCompression - gzip.BestSpeed + 1
	default:
		return level - gzip.BestSpeed
	}
}

// addLevelPool installs a fresh sync.Pool for level in gzipWriterPools,
// replacing whatever pool was stored in that slot before. The pool creates
// gzip.Writers configured for level that are not yet attached to an output.
func addLevelPool(level int) {
	newWriter := func() interface{} {
		// gzip.NewWriterLevel fails only for an invalid level. Callers are
		// expected to pass a valid one, so the error is deliberately dropped.
		gzw, _ := gzip.NewWriterLevel(nil, level)
		return gzw
	}
	gzipWriterPools[poolIndex(level)] = &sync.Pool{New: newWriter}
}

// GzipResponseWriter provides an http.ResponseWriter interface, which gzips
// bytes before writing them to the underlying response. This doesn't close the
// writers, so don't forget to do that.
// It can be configured to skip response smaller than minSize.
type GzipResponseWriter struct {
http.ResponseWriter
index int // Index for gzipWriterPools.
gw *gzip.Writer
level int
gwFactory writer.GzipWriterFactory
gw writer.GzipWriter

code int // Saves the WriteHeader value.

Expand Down Expand Up @@ -217,9 +185,7 @@ func (w *GzipResponseWriter) WriteHeader(code int) {
func (w *GzipResponseWriter) init() {
// Bytes written during ServeHTTP are redirected to this gzip writer
// before being written to the underlying response.
gzw := gzipWriterPools[w.index].Get().(*gzip.Writer)
gzw.Reset(w.ResponseWriter)
w.gw = gzw
w.gw = w.gwFactory(w.ResponseWriter, w.level)
}

// Close will close the gzip.Writer and will put it back in the gzipWriterPool.
Expand All @@ -239,7 +205,6 @@ func (w *GzipResponseWriter) Close() error {
}

err := w.gw.Close()
gzipWriterPools[w.index].Put(w.gw)
w.gw = nil
return err
}
Expand Down Expand Up @@ -305,8 +270,9 @@ func NewGzipLevelAndMinSize(level, minSize int) (func(http.Handler) http.Handler

func GzipHandlerWithOpts(opts ...option) (func(http.Handler) http.Handler, error) {
c := &config{
level: gzip.DefaultCompression,
minSize: DefaultMinSize,
level: gzip.DefaultCompression,
minSize: DefaultMinSize,
newWriter: stdlib.NewWriter,
}

for _, o := range opts {
Expand All @@ -318,14 +284,13 @@ func GzipHandlerWithOpts(opts ...option) (func(http.Handler) http.Handler, error
}

return func(h http.Handler) http.Handler {
index := poolIndex(c.level)

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(vary, acceptEncoding)
if acceptsGzip(r) {
gw := &GzipResponseWriter{
ResponseWriter: w,
index: index,
gwFactory: c.newWriter,
level: c.level,
minSize: c.minSize,
contentTypes: c.contentTypes,
}
Expand Down Expand Up @@ -378,6 +343,7 @@ func (pct parsedContentType) equals(mediaType string, params map[string]string)
// config collects the settings that the option functions passed to
// GzipHandlerWithOpts can adjust before the handler is built.
type config struct {
// minSize is initialised to DefaultMinSize by GzipHandlerWithOpts.
minSize int
// level is the compression level; GzipHandlerWithOpts initialises it to
// gzip.DefaultCompression.
level int
// newWriter is the factory used to create the gzip writer. It is
// stdlib.NewWriter unless replaced through the Implementation option.
newWriter writer.GzipWriterFactory
// contentTypes is copied into every GzipResponseWriter the handler creates.
contentTypes []parsedContentType
}

Expand Down Expand Up @@ -407,6 +373,16 @@ func CompressionLevel(level int) option {
}
}

// Implementation changes the implementation of GzipWriter
//
// The default implementation is writer/stdlib/NewWriter
// which is backed by standard library's compress/zlib
func Implementation(writer writer.GzipWriterFactory) option {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation above.
Add a link to the default implementation.

return func(c *config) {
c.newWriter = writer
}
}

// ContentTypes specifies a list of content types to compare
// the Content-Type header to before compressing. If none
// match, the response will be returned as-is.
Expand Down
25 changes: 0 additions & 25 deletions gzip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,31 +321,6 @@ func TestGzipHandlerMinSize(t *testing.T) {
}
}

// TestGzipDoubleClose verifies that closing the response writer inside the
// handler, followed by the handler wrapper's own Close, does not return the
// same gzip.Writer to the pool twice.
func TestGzipDoubleClose(t *testing.T) {
	// Start from a fresh pool for the default compression level so that
	// anything found in it afterwards was put there by this test.
	addLevelPool(gzip.DefaultCompression)

	inner := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		// Close explicitly here; the wrapping handler closes the writer
		// again internally (in its deferred Close) once this returns.
		rw.Write([]byte("test"))
		rw.(io.Closer).Close()
	})
	handler := GzipHandler(inner)

	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("Accept-Encoding", "gzip")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Had the second Close returned the writer to the pool again, two
	// consecutive Gets would hand back the same writer.
	pool := gzipWriterPools[poolIndex(gzip.DefaultCompression)]
	first := pool.Get()
	second := pool.Get()
	// assert.NotEqual compares values rather than addresses, hence plain ==.
	assert.False(t, first == second)
}

type panicOnSecondWriteHeaderWriter struct {
http.ResponseWriter
headerWritten bool
Expand Down
11 changes: 11 additions & 0 deletions writer/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package writer

import "io"

// GzipWriter is the set of operations a gzip writer implementation has to
// provide: writing data, flushing, and closing.
type GzipWriter interface {
Close() error
Flush() error
Write(p []byte) (int, error)
}

// GzipWriterFactory is the signature of a function that builds a GzipWriter
// for the given io.Writer and compression level.
type GzipWriterFactory = func(writer io.Writer, level int) GzipWriter
68 changes: 68 additions & 0 deletions writer/stdlib/stdlib.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package stdlib
whs marked this conversation as resolved.
Show resolved Hide resolved

import (
"compress/gzip"
"github.com/NYTimes/gziphandler/writer"
"io"
"sync"
)

// gzipWriterPools stores a sync.Pool for each compression level for reuse of
// gzip.Writers. Use poolIndex to convert a compression level to an index into
// gzipWriterPools.
var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool

// init populates gzipWriterPools with one pool per supported compression
// level: every level from gzip.BestSpeed through gzip.BestCompression, plus
// gzip.DefaultCompression.
func init() {
	level := gzip.BestSpeed
	for level <= gzip.BestCompression {
		addLevelPool(level)
		level++
	}
	addLevelPool(gzip.DefaultCompression)
}

// poolIndex translates a gzip compression level into the slot of
// gzipWriterPools that holds the pool for that level. The caller must pass a
// valid gzip compression level; no range checking is done here.
func poolIndex(level int) int {
	switch level {
	case gzip.DefaultCompression:
		// gzip.DefaultCompression is -1, so it cannot be offset like the
		// other levels; it gets the slot after gzip.BestCompression.
		return gzip.BestCompression - gzip.BestSpeed + 1
	default:
		return level - gzip.BestSpeed
	}
}

// addLevelPool installs a fresh sync.Pool for level in gzipWriterPools,
// replacing whatever pool was stored in that slot before. The pool creates
// gzip.Writers configured for level that are not yet attached to an output.
func addLevelPool(level int) {
	newWriter := func() interface{} {
		// gzip.NewWriterLevel fails only for an invalid level. Callers are
		// expected to pass a valid one, so the error is deliberately dropped.
		gzw, _ := gzip.NewWriterLevel(nil, level)
		return gzw
	}
	gzipWriterPools[poolIndex(level)] = &sync.Pool{New: newWriter}
}

// pooledWriter pairs a *gzip.Writer taken from gzipWriterPools with the index
// of the pool it came from, so that Close can return it to that same pool.
type pooledWriter struct {
*gzip.Writer
// index is the position in gzipWriterPools of the pool that owns Writer.
index int
}

func (pw *pooledWriter) Close() error {
err := pw.Writer.Close()
gzipWriterPools[pw.index].Put(pw.Writer)
pw.Writer = nil
return err
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing (to match the original gzip.go line 243):

pw.Writer = nil

This will allow to be future-proof against double-close or use-after-close.

}

// NewWriter is the standard-library-backed GzipWriterFactory. It takes a
// gzip.Writer for level from the matching pool, resets it to write to w, and
// wraps it so that closing the result returns the gzip.Writer to that pool.
//
// level is handed to poolIndex unchecked, so it has to be a level that has a
// pool: gzip.DefaultCompression, or gzip.BestSpeed through
// gzip.BestCompression.
func NewWriter(w io.Writer, level int) writer.GzipWriter {
	pw := &pooledWriter{index: poolIndex(level)}
	pw.Writer = gzipWriterPools[pw.index].Get().(*gzip.Writer)
	pw.Writer.Reset(w)
	return pw
}

// ImplementationInfo names the package that backs this writer
// implementation. The writers handed out by this package are created with
// compress/gzip.
func ImplementationInfo() string {
	return "compress/gzip"
}
25 changes: 25 additions & 0 deletions writer/stdlib/stdlib_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package stdlib

import (
"bytes"
"compress/gzip"
"github.com/stretchr/testify/assert"
"testing"
)

// TestGzipDoubleClose checks that closing a pooled writer returns it to its
// pool only once: two consecutive Gets from the pool must not yield the same
// gzip.Writer.
func TestGzipDoubleClose(t *testing.T) {
	// Start from a fresh pool for the default compression level so that
	// anything found in it afterwards was put there by this test.
	addLevelPool(gzip.DefaultCompression)

	var buf bytes.Buffer
	gzw := NewWriter(&buf, gzip.DefaultCompression)
	// Closing into an in-memory buffer must succeed; do not drop the error.
	assert.NoError(t, gzw.Close())

	// NOTE(review): despite the test name, Close is called only once here; a
	// second Close on the same writer is not exercised by this test.

	// If the writer had been put back more than once, two consecutive Gets
	// would hand back the same writer.
	pool := gzipWriterPools[poolIndex(gzip.DefaultCompression)]
	first := pool.Get()
	second := pool.Get()
	// assert.NotEqual compares values rather than addresses, hence plain ==.
	assert.False(t, first == second)
}