From 51a6702b7926a0491200a896b78e7ef0d0253ee9 Mon Sep 17 00:00:00 2001 From: "k.petremann" Date: Mon, 21 Oct 2024 17:47:35 +0200 Subject: [PATCH] refactor: move back to net/http Since Go1.22, net/http includes easy inpath variable getter and easy HTTP verb management. There is no reason for us to keep httprouter anymore. --- go.mod | 1 - go.sum | 2 -- internal/api/auth/basic_auth.go | 17 +++++++------ internal/api/router/endpoints.go | 39 +++++++++++++++--------------- internal/api/router/manager.go | 41 ++++++++++++++++---------------- 5 files changed, 47 insertions(+), 53 deletions(-) diff --git a/go.mod b/go.mod index 0747bcc..b504b76 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/go-ldap/ldap/v3 v3.4.8 github.com/go-playground/validator/v10 v10.22.0 github.com/google/go-cmp v0.6.0 - github.com/julienschmidt/httprouter v1.3.0 github.com/openconfig/goyang v1.6.0 github.com/openconfig/ygot v0.29.20 github.com/prometheus/client_golang v1.20.1 diff --git a/go.sum b/go.sum index 73f6bde..ae53518 100644 --- a/go.sum +++ b/go.sum @@ -99,8 +99,6 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= -github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/api/auth/basic_auth.go b/internal/api/auth/basic_auth.go index 8d44023..eb69f61 100644 --- a/internal/api/auth/basic_auth.go +++ b/internal/api/auth/basic_auth.go @@ -8,7 +8,6 @@ import ( "net/http" "github.com/criteo/data-aggregation-api/internal/config" - "github.com/julienschmidt/httprouter" "github.com/rs/zerolog/log" ) @@ -74,24 +73,24 @@ func (b *BasicAuth) configureLdap(ldap *LDAPAuth) error { return nil } -func (b *BasicAuth) Wrap(next httprouter.Handle) httprouter.Handle { +func (b *BasicAuth) Wrap(next http.HandlerFunc) http.HandlerFunc { switch b.mode { case noAuth: - return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { next(w, r, ps) } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next(w, r) }) case ldapMode: return BasicAuthLDAP(b.ldapAuth, next) default: - return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { log.Error().Str("auth-method", string(b.mode)).Str("authentication issue", "bad server configuration").Send() http.Error(w, "authentication issue: bad server configuration", http.StatusInternalServerError) - } + }) } } // BasicAuthLDAP is a middleware wrapping the target HTTP HandlerFunc. // It retrieves BasicAuth credentials and authenticate against LDAP. -func BasicAuthLDAP(ldapAuth *LDAPAuth, next httprouter.Handle) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func BasicAuthLDAP(ldapAuth *LDAPAuth, next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if !ok { w.Header().Set(wwwAuthenticate, realm) @@ -103,6 +102,6 @@ func BasicAuthLDAP(ldapAuth *LDAPAuth, next httprouter.Handle) httprouter.Handle http.Error(w, unauthorizedResponse, http.StatusUnauthorized) return } - next(w, r, ps) - } + next(w, r) + }) } diff --git a/internal/api/router/endpoints.go b/internal/api/router/endpoints.go index dcbc69f..fc8ec8e 100644 --- a/internal/api/router/endpoints.go +++ b/internal/api/router/endpoints.go @@ -9,7 +9,6 @@ import ( "github.com/criteo/data-aggregation-api/internal/app" "github.com/criteo/data-aggregation-api/internal/convertor/device" - "github.com/julienschmidt/httprouter" ) const contentType = "Content-Type" @@ -17,27 +16,27 @@ const applicationJSON = "application/json" const hostnameKey = "hostname" const wildcard = "*" -func healthCheck(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func healthCheck(w http.ResponseWriter, _ *http.Request) { w.Header().Set(contentType, applicationJSON) _, _ = fmt.Fprintf(w, `{"status": "ok"}`) } -func getVersion(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func getVersion(w http.ResponseWriter, _ *http.Request) { w.Header().Set(contentType, applicationJSON) _, _ = fmt.Fprintf(w, `{"version": "%s", "build_time": "%s", "build_user": "%s"}`, app.Info.Version, app.Info.BuildTime, app.Info.BuildUser) } -func prometheusMetrics(h http.Handler) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func prometheusMetrics(h http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { h.ServeHTTP(w, r) } } // getAFKEnabled endpoint returns all AFK enabled devices. // They are supposed to be managed by AFK, meaning the configuration should be applied periodically. -func (m *Manager) getAFKEnabled(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) { +func (m *Manager) getAFKEnabled(w http.ResponseWriter, r *http.Request) { w.Header().Set(contentType, applicationJSON) - hostname := ps.ByName(hostnameKey) + hostname := r.PathValue(hostnameKey) if hostname == wildcard { out, err := m.devices.ListAFKEnabledDevicesJSON() @@ -68,10 +67,10 @@ func (m *Manager) getAFKEnabled(w http.ResponseWriter, _ *http.Request, ps httpr } // getDeviceOpenConfig endpoint returns OpenConfig JSON for one or all devices. -func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) { +func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set(contentType, applicationJSON) - hostname := ps.ByName(hostnameKey) - if ps.ByName(hostnameKey) == wildcard { + hostname := r.PathValue(hostnameKey) + if hostname == wildcard { cfg, err := m.devices.GetAllDevicesOpenConfigJSON() if err != nil { log.Error().Err(err).Send() @@ -93,10 +92,10 @@ func (m *Manager) getDeviceOpenConfig(w http.ResponseWriter, _ *http.Request, ps } // getDeviceIETFConfig endpoint returns Ietf JSON for one or all devices. -func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) { +func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set(contentType, applicationJSON) - hostname := ps.ByName(hostnameKey) - if ps.ByName(hostnameKey) == wildcard { + hostname := r.PathValue(hostnameKey) + if hostname == wildcard { cfg, err := m.devices.GetAllDevicesIETFConfigJSON() if err != nil { log.Error().Err(err).Send() @@ -118,10 +117,10 @@ func (m *Manager) getDeviceIETFConfig(w http.ResponseWriter, _ *http.Request, ps } // getDeviceConfig endpoint returns Ietf & openconfig JSON for one or all devices. -func (m *Manager) getDeviceConfig(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) { +func (m *Manager) getDeviceConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set(contentType, applicationJSON) - hostname := ps.ByName(hostnameKey) - if ps.ByName(hostnameKey) == wildcard { + hostname := r.PathValue(hostnameKey) + if hostname == wildcard { cfg, err := m.devices.GetAllDevicesConfigJSON() if err != nil { log.Error().Err(err).Send() @@ -143,7 +142,7 @@ func (m *Manager) getDeviceConfig(w http.ResponseWriter, _ *http.Request, ps htt } // getLastReport returns the last or current report. -func (m *Manager) getLastReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func (m *Manager) getLastReport(w http.ResponseWriter, r *http.Request) { out, err := m.reports.GetLastJSON() if err != nil { log.Error().Err(err).Send() @@ -156,7 +155,7 @@ func (m *Manager) getLastReport(w http.ResponseWriter, _ *http.Request, _ httpro } // getLastCompleteReport returns the previous build report. -func (m *Manager) getLastCompleteReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func (m *Manager) getLastCompleteReport(w http.ResponseWriter, r *http.Request) { out, err := m.reports.GetLastCompleteJSON() if err != nil { log.Error().Err(err).Send() @@ -169,7 +168,7 @@ func (m *Manager) getLastCompleteReport(w http.ResponseWriter, _ *http.Request, } // getLastSuccessfulReport returns the previous successful build report. -func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, r *http.Request) { out, err := m.reports.GetLastSuccessfulJSON() if err != nil { log.Error().Err(err).Send() @@ -184,7 +183,7 @@ func (m *Manager) getLastSuccessfulReport(w http.ResponseWriter, _ *http.Request // triggerBuild enables the user to trigger a new build. // // It only accepts one build request at a time. -func (m *Manager) triggerBuild(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func (m *Manager) triggerBuild(w http.ResponseWriter, r *http.Request) { w.Header().Set(contentType, applicationJSON) select { case m.newBuildRequest <- struct{}{}: diff --git a/internal/api/router/manager.go b/internal/api/router/manager.go index 70c4777..cd29911 100644 --- a/internal/api/router/manager.go +++ b/internal/api/router/manager.go @@ -14,7 +14,6 @@ import ( "github.com/criteo/data-aggregation-api/internal/config" "github.com/criteo/data-aggregation-api/internal/convertor/device" "github.com/criteo/data-aggregation-api/internal/report" - "github.com/julienschmidt/httprouter" ) const shutdownTimeout = 5 * time.Second @@ -54,34 +53,34 @@ func (m *Manager) ListenAndServe(ctx context.Context, address string, port int, return err } - router := httprouter.New() + mux := http.NewServeMux() - router.GET("/metrics", prometheusMetrics(promhttp.Handler())) - router.GET("/api/version", getVersion) - router.GET("/api/health", healthCheck) - router.GET("/v1/devices/:hostname/afk_enabled", withAuth.Wrap(m.getAFKEnabled)) - router.GET("/v1/devices/:hostname/openconfig", withAuth.Wrap(m.getDeviceOpenConfig)) - router.GET("/v1/devices/:hostname/ietfconfig", withAuth.Wrap(m.getDeviceIETFConfig)) - router.GET("/v1/devices/:hostname/config", withAuth.Wrap(m.getDeviceConfig)) - router.GET("/v1/report/last", withAuth.Wrap(m.getLastReport)) - router.GET("/v1/report/last/complete", withAuth.Wrap(m.getLastCompleteReport)) - router.GET("/v1/report/last/successful", withAuth.Wrap(m.getLastSuccessfulReport)) - router.POST("/v1/build/trigger", withAuth.Wrap(m.triggerBuild)) + mux.HandleFunc("GET /metrics", prometheusMetrics(promhttp.Handler())) + mux.HandleFunc("GET /api/version", getVersion) + mux.HandleFunc("GET /api/health", healthCheck) + mux.HandleFunc("GET /v1/devices/{hostname}/afk_enabled", withAuth.Wrap(m.getAFKEnabled)) + mux.HandleFunc("GET /v1/devices/{hostname}/openconfig", withAuth.Wrap(m.getDeviceOpenConfig)) + mux.HandleFunc("GET /v1/devices/{hostname}/ietfconfig", withAuth.Wrap(m.getDeviceIETFConfig)) + mux.HandleFunc("GET /v1/devices/{hostname}/config", withAuth.Wrap(m.getDeviceConfig)) + mux.HandleFunc("GET /v1/report/last", withAuth.Wrap(m.getLastReport)) + mux.HandleFunc("GET /v1/report/last/complete", withAuth.Wrap(m.getLastCompleteReport)) + mux.HandleFunc("GET /v1/report/last/successful", withAuth.Wrap(m.getLastSuccessfulReport)) + mux.HandleFunc("POST /v1/build/trigger", withAuth.Wrap(m.triggerBuild)) if enablepprof { - router.HandlerFunc(http.MethodGet, "/debug/pprof/", pprof.Index) - router.HandlerFunc(http.MethodGet, "/debug/pprof/allocs", pprof.Index) - router.HandlerFunc(http.MethodGet, "/debug/pprof/goroutine", pprof.Index) - router.HandlerFunc(http.MethodGet, "/debug/pprof/heap", pprof.Index) - router.HandlerFunc(http.MethodGet, "/debug/pprof/profile", pprof.Profile) - router.HandlerFunc(http.MethodGet, "/debug/pprof/trace", pprof.Trace) - router.HandlerFunc(http.MethodGet, "/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("GET /debug/pprof/", pprof.Index) + mux.HandleFunc("GET /debug/pprof/allocs", pprof.Index) + mux.HandleFunc("GET /debug/pprof/goroutine", pprof.Index) + mux.HandleFunc("GET /debug/pprof/heap", pprof.Index) + mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile) + mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace) + mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol) } listenSocket := fmt.Sprint(address, ":", port) log.Info().Msgf("Start webserver - listening on %s", listenSocket) - httpServer := http.Server{Addr: listenSocket, Handler: router} + httpServer := http.Server{Addr: listenSocket, Handler: mux} // TODO: handle http failure! with a channel go func() {