diff --git a/flexhttp.go b/flexhttp.go index 528a97c..f49bd8a 100644 --- a/flexhttp.go +++ b/flexhttp.go @@ -2,10 +2,8 @@ package flexhttp import ( "context" - "log" "net" "net/http" - "os" "time" ) @@ -17,21 +15,34 @@ const ( DefaultReadHeaderTimeout = 1 * time.Second ) -var ( - // DefaultHTTPServer provides a default http server. - DefaultHTTPServer = &http.Server{ - WriteTimeout: DefaultWriteTimeout, - ReadTimeout: DefaultReadTimeout, - IdleTimeout: DefaultIdleTimeout, - ReadHeaderTimeout: DefaultReadHeaderTimeout, - } +// DefaultHTTPServer provides a default http server. +var DefaultHTTPServer = &http.Server{ + WriteTimeout: DefaultWriteTimeout, + ReadTimeout: DefaultReadTimeout, + IdleTimeout: DefaultIdleTimeout, + ReadHeaderTimeout: DefaultReadHeaderTimeout, +} - // logger defines a logger with a prefix. - logger = log.New(os.Stderr, "flexhttp: ", 0) -) +// Logger defines any logger able to call Printf. +type Logger interface { + Printf(format string, v ...interface{}) +} + +// Option is a type of func that allows you change defaults of the *Server. +type Option func(s *Server) + +// WithLogger allows you to set a logger for the server. +func WithLogger(logger Logger) Option { + return func(s *Server) { + s.logger = logger + } +} // Server defines the flexhttp Server. -type Server struct{ *http.Server } +type Server struct { + logger Logger + *http.Server +} // New returns a new flexhttp server, using the provided http.Server. // NOTE: If no values are provided, defaults will be used the following fields: @@ -39,23 +50,30 @@ type Server struct{ *http.Server } // - ReadHeaderTimeout // - WriteTimeout // - IdleTimeout -func New(server *http.Server) *Server { - if server == nil { - server = DefaultHTTPServer +func New(httpServer *http.Server, options ...Option) *Server { + if httpServer == nil { + httpServer = DefaultHTTPServer + } + if httpServer.ReadTimeout == 0 { + httpServer.ReadTimeout = DefaultHTTPServer.ReadTimeout } - if server.ReadTimeout == 0 { - server.ReadTimeout = DefaultHTTPServer.ReadTimeout + if httpServer.ReadHeaderTimeout == 0 { + httpServer.ReadHeaderTimeout = DefaultHTTPServer.ReadHeaderTimeout } - if server.ReadHeaderTimeout == 0 { - server.ReadHeaderTimeout = DefaultHTTPServer.ReadHeaderTimeout + if httpServer.WriteTimeout == 0 { + httpServer.WriteTimeout = DefaultHTTPServer.WriteTimeout } - if server.WriteTimeout == 0 { - server.WriteTimeout = DefaultHTTPServer.WriteTimeout + if httpServer.IdleTimeout == 0 { + httpServer.IdleTimeout = DefaultHTTPServer.IdleTimeout } - if server.IdleTimeout == 0 { - server.IdleTimeout = DefaultHTTPServer.IdleTimeout + + server := &Server{Server: httpServer} + + for _, opt := range options { + opt(server) } - return &Server{Server: server} + + return server } // Run satisfies the flex Runner interface. @@ -65,13 +83,13 @@ func (s *Server) Run(_ context.Context) error { return err } if address, ok := listener.Addr().(*net.TCPAddr); ok { - logger.Printf("serving on http://%s", address) + s.logger.Printf("serving on http://%s", address) } return s.Serve(listener) } // Halt satisfies the flex Halter interface. func (s *Server) Halt(ctx context.Context) error { - logger.Printf("shutting down http server...") + s.logger.Printf("shutting down http server...") return s.Shutdown(ctx) } diff --git a/flexhttp_test.go b/flexhttp_test.go index c74c255..13119ed 100644 --- a/flexhttp_test.go +++ b/flexhttp_test.go @@ -1,10 +1,14 @@ package flexhttp_test import ( + "bytes" "context" + "io" + "log" "net/http" "os" "reflect" + "strings" "testing" "time" @@ -12,7 +16,7 @@ import ( ) var ( - handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) ctx context.Context @@ -38,8 +42,8 @@ func Example() { func TestNewHTTPServer(t *testing.T) { testcases := []struct { - name string srv *http.Server + name string expectedTimeout time.Duration }{ { @@ -81,12 +85,40 @@ func TestNewHTTPServer(t *testing.T) { } } +func TestOption_WithLogger(t *testing.T) { + var buf bytes.Buffer + + w := io.MultiWriter(&buf, os.Stderr) // so we get console output. + logger := log.New(w, "TEST_LOGGER: ", 0) // so we get consistent output. + + metrics := flexhttp.New( + &http.Server{}, + flexhttp.WithLogger(logger), + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + go func() { + _ = metrics.Run(ctx) + }() + _ = metrics.Halt(ctx) + + t.Log(buf.String()) + + // ugly? yes, but, it will do. + if !strings.Contains(buf.String(), "TEST_LOGGER: ") { + t.Fatal("expected log message to contain prefix") + } +} + func equal(t *testing.T, got, want interface{}) { t.Helper() if !reflect.DeepEqual(got, want) { t.Fatalf("got: %#[1]v (%[1]T), but wanted: %#[2]v (%[2]T)", got, want) } } + func notEqual(t *testing.T, got, want interface{}) { t.Helper() if reflect.DeepEqual(got, want) {