Skip to content

Commit

Permalink
Merge pull request #2 from go-flexible/feature/custom-logger-support
Browse files Browse the repository at this point in the history
feature/custom logger support
  • Loading branch information
ladydascalie authored Apr 16, 2022
2 parents fcd7cc5 + 898fec5 commit 27165e1
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 30 deletions.
74 changes: 46 additions & 28 deletions flexhttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package flexhttp

import (
"context"
"log"
"net"
"net/http"
"os"
"time"
)

Expand All @@ -17,45 +15,65 @@ const (
DefaultReadHeaderTimeout = 1 * time.Second
)

var (
// DefaultHTTPServer provides a default http server.
DefaultHTTPServer = &http.Server{
WriteTimeout: DefaultWriteTimeout,
ReadTimeout: DefaultReadTimeout,
IdleTimeout: DefaultIdleTimeout,
ReadHeaderTimeout: DefaultReadHeaderTimeout,
}
// DefaultHTTPServer provides a default http server.
var DefaultHTTPServer = &http.Server{
WriteTimeout: DefaultWriteTimeout,
ReadTimeout: DefaultReadTimeout,
IdleTimeout: DefaultIdleTimeout,
ReadHeaderTimeout: DefaultReadHeaderTimeout,
}

// logger defines a logger with a prefix.
logger = log.New(os.Stderr, "flexhttp: ", 0)
)
// Logger defines any logger able to call Printf.
type Logger interface {
Printf(format string, v ...interface{})
}

// Option is a type of func that allows you change defaults of the *Server.
type Option func(s *Server)

// WithLogger allows you to set a logger for the server.
func WithLogger(logger Logger) Option {
return func(s *Server) {
s.logger = logger
}
}

// Server defines the flexhttp Server.
type Server struct{ *http.Server }
type Server struct {
logger Logger
*http.Server
}

// New returns a new flexhttp server, using the provided http.Server.
// NOTE: If no values are provided, defaults will be used the following fields:
// - ReadTimeout
// - ReadHeaderTimeout
// - WriteTimeout
// - IdleTimeout
func New(server *http.Server) *Server {
if server == nil {
server = DefaultHTTPServer
func New(httpServer *http.Server, options ...Option) *Server {
if httpServer == nil {
httpServer = DefaultHTTPServer
}
if httpServer.ReadTimeout == 0 {
httpServer.ReadTimeout = DefaultHTTPServer.ReadTimeout
}
if server.ReadTimeout == 0 {
server.ReadTimeout = DefaultHTTPServer.ReadTimeout
if httpServer.ReadHeaderTimeout == 0 {
httpServer.ReadHeaderTimeout = DefaultHTTPServer.ReadHeaderTimeout
}
if server.ReadHeaderTimeout == 0 {
server.ReadHeaderTimeout = DefaultHTTPServer.ReadHeaderTimeout
if httpServer.WriteTimeout == 0 {
httpServer.WriteTimeout = DefaultHTTPServer.WriteTimeout
}
if server.WriteTimeout == 0 {
server.WriteTimeout = DefaultHTTPServer.WriteTimeout
if httpServer.IdleTimeout == 0 {
httpServer.IdleTimeout = DefaultHTTPServer.IdleTimeout
}
if server.IdleTimeout == 0 {
server.IdleTimeout = DefaultHTTPServer.IdleTimeout

server := &Server{Server: httpServer}

for _, opt := range options {
opt(server)
}
return &Server{Server: server}

return server
}

// Run satisfies the flex Runner interface.
Expand All @@ -65,13 +83,13 @@ func (s *Server) Run(_ context.Context) error {
return err
}
if address, ok := listener.Addr().(*net.TCPAddr); ok {
logger.Printf("serving on http://%s", address)
s.logger.Printf("serving on http://%s", address)
}
return s.Serve(listener)
}

// Halt satisfies the flex Halter interface.
func (s *Server) Halt(ctx context.Context) error {
logger.Printf("shutting down http server...")
s.logger.Printf("shutting down http server...")
return s.Shutdown(ctx)
}
36 changes: 34 additions & 2 deletions flexhttp_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package flexhttp_test

import (
"bytes"
"context"
"io"
"log"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"

"github.com/go-flexible/flexhttp"
)

var (
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
ctx context.Context
Expand All @@ -38,8 +42,8 @@ func Example() {

func TestNewHTTPServer(t *testing.T) {
testcases := []struct {
name string
srv *http.Server
name string
expectedTimeout time.Duration
}{
{
Expand Down Expand Up @@ -81,12 +85,40 @@ func TestNewHTTPServer(t *testing.T) {
}
}

func TestOption_WithLogger(t *testing.T) {
var buf bytes.Buffer

w := io.MultiWriter(&buf, os.Stderr) // so we get console output.
logger := log.New(w, "TEST_LOGGER: ", 0) // so we get consistent output.

metrics := flexhttp.New(
&http.Server{},
flexhttp.WithLogger(logger),
)

ctx, cancel := context.WithCancel(context.Background())
cancel()

go func() {
_ = metrics.Run(ctx)
}()
_ = metrics.Halt(ctx)

t.Log(buf.String())

// ugly? yes, but, it will do.
if !strings.Contains(buf.String(), "TEST_LOGGER: ") {
t.Fatal("expected log message to contain prefix")
}
}

func equal(t *testing.T, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Fatalf("got: %#[1]v (%[1]T), but wanted: %#[2]v (%[2]T)", got, want)
}
}

func notEqual(t *testing.T, got, want interface{}) {
t.Helper()
if reflect.DeepEqual(got, want) {
Expand Down

0 comments on commit 27165e1

Please sign in to comment.