Skip to content

Commit

Permalink
refactor: move back to net/http
Browse files Browse the repository at this point in the history
Since Go1.22, net/http includes easy inpath variable getter and easy
HTTP verb management.

There is no reason for us to keep httprouter anymore.
  • Loading branch information
kpetremann committed Oct 23, 2024
1 parent 7ebd3c2 commit 51a6702
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 53 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ require (
github.com/go-ldap/ldap/v3 v3.4.8
github.com/go-playground/validator/v10 v10.22.0
github.com/google/go-cmp v0.6.0
github.com/julienschmidt/httprouter v1.3.0
github.com/openconfig/goyang v1.6.0
github.com/openconfig/ygot v0.29.20
github.com/prometheus/client_golang v1.20.1
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
Expand Down
17 changes: 8 additions & 9 deletions internal/api/auth/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http"

"github.com/criteo/data-aggregation-api/internal/config"
"github.com/julienschmidt/httprouter"
"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -74,24 +73,24 @@ func (b *BasicAuth) configureLdap(ldap *LDAPAuth) error {
return nil
}

func (b *BasicAuth) Wrap(next httprouter.Handle) httprouter.Handle {
func (b *BasicAuth) Wrap(next http.HandlerFunc) http.HandlerFunc {
switch b.mode {
case noAuth:
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { next(w, r, ps) }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next(w, r) })
case ldapMode:
return BasicAuthLDAP(b.ldapAuth, next)
default:
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
log.Error().Str("auth-method", string(b.mode)).Str("authentication issue", "bad server configuration").Send()
http.Error(w, "authentication issue: bad server configuration", http.StatusInternalServerError)
}
})
}
}

// BasicAuthLDAP is a middleware wrapping the target HTTP HandlerFunc.
// It retrieves BasicAuth credentials and authenticate against LDAP.
func BasicAuthLDAP(ldapAuth *LDAPAuth, next httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func BasicAuthLDAP(ldapAuth *LDAPAuth, next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.Header().Set(wwwAuthenticate, realm)
Expand All @@ -103,6 +102,6 @@ func BasicAuthLDAP(ldapAuth *LDAPAuth, next httprouter.Handle) httprouter.Handle
http.Error(w, unauthorizedResponse, http.StatusUnauthorized)
return
}
next(w, r, ps)
}
next(w, r)
})
}
39 changes: 19 additions & 20 deletions internal/api/router/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,34 @@ import (

"github.com/criteo/data-aggregation-api/internal/app"
"github.com/criteo/data-aggregation-api/internal/convertor/device"
"github.com/julienschmidt/httprouter"
)

const contentType = "Content-Type"
const applicationJSON = "application/json"
const hostnameKey = "hostname"
const wildcard = "*"

func healthCheck(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func healthCheck(w http.ResponseWriter, _ *http.Request) {
w.Header().Set(contentType, applicationJSON)
_, _ = fmt.Fprintf(w, `{"status": "ok"}`)
}

func getVersion(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func getVersion(w http.ResponseWriter, _ *http.Request) {
w.Header().Set(contentType, applicationJSON)
_, _ = fmt.Fprintf(w, `{"version": "%s", "build_time": "%s", "build_user": "%s"}`, app.Info.Version, app.Info.BuildTime, app.Info.BuildUser)
}

func prometheusMetrics(h http.Handler) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func prometheusMetrics(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
}
}

// getAFKEnabled endpoint returns all AFK enabled devices.
// They are supposed to be managed by AFK, meaning the configuration should be applied periodically.
func (m *Manager) getAFKEnabled(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
func (m *Manager) getAFKEnabled(w http.ResponseWriter, r *http.Request) {
w.Header().Set(contentType, applicationJSON)
hostname := ps.ByName(hostnameKey)
hostname := r.PathValue(hostnameKey)

if hostname == wildcard {
out, err := m.devices.ListAFKEnabledDevicesJSON()
Expand Down Expand Up @@ -68,10 +67,10 @@ func (m *Manager) getAFKEnabled(w http.ResponseWriter, _ *http.Request, ps httpr
}

// getDeviceOpenConfig endpoint returns OpenConfig JSON for one or all devices.
func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, r *http.Request) {
w.Header().Set(contentType, applicationJSON)
hostname := ps.ByName(hostnameKey)
if ps.ByName(hostnameKey) == wildcard {
hostname := r.PathValue(hostnameKey)
if hostname == wildcard {
cfg, err := m.devices.GetAllDevicesOpenConfigJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -93,10 +92,10 @@ func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, _ *http.Request, ps
}

// getDeviceIETFConfig endpoint returns Ietf JSON for one or all devices.
func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, r *http.Request) {
w.Header().Set(contentType, applicationJSON)
hostname := ps.ByName(hostnameKey)
if ps.ByName(hostnameKey) == wildcard {
hostname := r.PathValue(hostnameKey)
if hostname == wildcard {
cfg, err := m.devices.GetAllDevicesIETFConfigJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -118,10 +117,10 @@ func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, _ *http.Request, ps
}

// getDeviceConfig endpoint returns Ietf & openconfig JSON for one or all devices.
func (m *Manager) getDeviceConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
func (m *Manager) getDeviceConfig(w http.ResponseWriter, r *http.Request) {
w.Header().Set(contentType, applicationJSON)
hostname := ps.ByName(hostnameKey)
if ps.ByName(hostnameKey) == wildcard {
hostname := r.PathValue(hostnameKey)
if hostname == wildcard {
cfg, err := m.devices.GetAllDevicesConfigJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -143,7 +142,7 @@ func (m *Manager) getDeviceConfig(w http.ResponseWriter, _ *http.Request, ps htt
}

// getLastReport returns the last or current report.
func (m *Manager) getLastReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func (m *Manager) getLastReport(w http.ResponseWriter, r *http.Request) {
out, err := m.reports.GetLastJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -156,7 +155,7 @@ func (m *Manager) getLastReport(w http.ResponseWriter, _ *http.Request, _ httpro
}

// getLastCompleteReport returns the previous build report.
func (m *Manager) getLastCompleteReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func (m *Manager) getLastCompleteReport(w http.ResponseWriter, r *http.Request) {
out, err := m.reports.GetLastCompleteJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -169,7 +168,7 @@ func (m *Manager) getLastCompleteReport(w http.ResponseWriter, _ *http.Request,
}

// getLastSuccessfulReport returns the previous successful build report.
func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, r *http.Request) {
out, err := m.reports.GetLastSuccessfulJSON()
if err != nil {
log.Error().Err(err).Send()
Expand All @@ -184,7 +183,7 @@ func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, _ *http.Request
// triggerBuild enables the user to trigger a new build.
//
// It only accepts one build request at a time.
func (m *Manager) triggerBuild(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func (m *Manager) triggerBuild(w http.ResponseWriter, r *http.Request) {
w.Header().Set(contentType, applicationJSON)
select {
case m.newBuildRequest <- struct{}{}:
Expand Down
41 changes: 20 additions & 21 deletions internal/api/router/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/criteo/data-aggregation-api/internal/config"
"github.com/criteo/data-aggregation-api/internal/convertor/device"
"github.com/criteo/data-aggregation-api/internal/report"
"github.com/julienschmidt/httprouter"
)

const shutdownTimeout = 5 * time.Second
Expand Down Expand Up @@ -54,34 +53,34 @@ func (m *Manager) ListenAndServe(ctx context.Context, address string, port int,
return err
}

router := httprouter.New()
mux := http.NewServeMux()

router.GET("/metrics", prometheusMetrics(promhttp.Handler()))
router.GET("/api/version", getVersion)
router.GET("/api/health", healthCheck)
router.GET("/v1/devices/:hostname/afk_enabled", withAuth.Wrap(m.getAFKEnabled))
router.GET("/v1/devices/:hostname/openconfig", withAuth.Wrap(m.getDeviceOpenConfig))
router.GET("/v1/devices/:hostname/ietfconfig", withAuth.Wrap(m.getDeviceIETFConfig))
router.GET("/v1/devices/:hostname/config", withAuth.Wrap(m.getDeviceConfig))
router.GET("/v1/report/last", withAuth.Wrap(m.getLastReport))
router.GET("/v1/report/last/complete", withAuth.Wrap(m.getLastCompleteReport))
router.GET("/v1/report/last/successful", withAuth.Wrap(m.getLastSuccessfulReport))
router.POST("/v1/build/trigger", withAuth.Wrap(m.triggerBuild))
mux.HandleFunc("GET /metrics", prometheusMetrics(promhttp.Handler()))
mux.HandleFunc("GET /api/version", getVersion)
mux.HandleFunc("GET /api/health", healthCheck)
mux.HandleFunc("GET /v1/devices/{hostname}/afk_enabled", withAuth.Wrap(m.getAFKEnabled))
mux.HandleFunc("GET /v1/devices/{hostname}/openconfig", withAuth.Wrap(m.getDeviceOpenConfig))
mux.HandleFunc("GET /v1/devices/{hostname}/ietfconfig", withAuth.Wrap(m.getDeviceIETFConfig))
mux.HandleFunc("GET /v1/devices/{hostname}/config", withAuth.Wrap(m.getDeviceConfig))
mux.HandleFunc("GET /v1/report/last", withAuth.Wrap(m.getLastReport))
mux.HandleFunc("GET /v1/report/last/complete", withAuth.Wrap(m.getLastCompleteReport))
mux.HandleFunc("GET /v1/report/last/successful", withAuth.Wrap(m.getLastSuccessfulReport))
mux.HandleFunc("POST /v1/build/trigger", withAuth.Wrap(m.triggerBuild))

if enablepprof {
router.HandlerFunc(http.MethodGet, "/debug/pprof/", pprof.Index)
router.HandlerFunc(http.MethodGet, "/debug/pprof/allocs", pprof.Index)
router.HandlerFunc(http.MethodGet, "/debug/pprof/goroutine", pprof.Index)
router.HandlerFunc(http.MethodGet, "/debug/pprof/heap", pprof.Index)
router.HandlerFunc(http.MethodGet, "/debug/pprof/profile", pprof.Profile)
router.HandlerFunc(http.MethodGet, "/debug/pprof/trace", pprof.Trace)
router.HandlerFunc(http.MethodGet, "/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("GET /debug/pprof/", pprof.Index)
mux.HandleFunc("GET /debug/pprof/allocs", pprof.Index)
mux.HandleFunc("GET /debug/pprof/goroutine", pprof.Index)
mux.HandleFunc("GET /debug/pprof/heap", pprof.Index)
mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
}

listenSocket := fmt.Sprint(address, ":", port)
log.Info().Msgf("Start webserver - listening on %s", listenSocket)

httpServer := http.Server{Addr: listenSocket, Handler: router}
httpServer := http.Server{Addr: listenSocket, Handler: mux}

// TODO: handle http failure! with a channel
go func() {
Expand Down

0 comments on commit 51a6702

Please sign in to comment.