From 55f2c5fda4df3a23838607b7c9f781270b1f7bc9 Mon Sep 17 00:00:00 2001 From: evan slack Date: Sat, 5 Aug 2023 17:59:40 -0400 Subject: [PATCH 1/2] fix nil pointer deref with incorrect logger call --- backend/postgres/scrape.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/postgres/scrape.go b/backend/postgres/scrape.go index 93350ea..f365997 100644 --- a/backend/postgres/scrape.go +++ b/backend/postgres/scrape.go @@ -5,15 +5,13 @@ import ( "database/sql" "github.com/evanofslack/analogdb" - "github.com/evanofslack/analogdb/logger" ) // ensure interface is implemented var _ analogdb.ScrapeService = (*ScrapeService)(nil) type ScrapeService struct { - db *DB - logger *logger.Logger + db *DB } func NewScrapeService(db *DB) *ScrapeService { @@ -22,8 +20,8 @@ func NewScrapeService(db *DB) *ScrapeService { func (s *ScrapeService) KeywordUpdatedPostIDs(ctx context.Context) ([]int, error) { - s.logger.Debug().Ctx(ctx).Msg("Starting get keyword updated post ids") - defer s.logger.Debug().Ctx(ctx).Msg("Finished keyword updated post ids") + s.db.logger.Debug().Ctx(ctx).Msg("Starting get keyword updated post ids") + defer s.db.logger.Debug().Ctx(ctx).Msg("Finished keyword updated post ids") tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { From ea6212297168a191155745b94d8969a48aca847e Mon Sep 17 00:00:00 2001 From: evan slack Date: Sat, 5 Aug 2023 19:27:41 -0400 Subject: [PATCH 2/2] move around logging and rate limit middleware --- backend/cmd/analogdb/main.go | 2 +- backend/logger/middleware.go | 51 ------------------------------ backend/server/auth.go | 14 +++++---- backend/server/log.go | 61 ++++++++++++++++++++++++++++++++++++ backend/server/metrics.go | 3 -- backend/server/middleware.go | 41 ++---------------------- backend/server/ratelimit.go | 41 ++++++++++++++++++++++++ backend/server/scrape.go | 1 + backend/server/status.go | 12 +++++-- scraper/api.py | 6 ++-- 10 files changed, 128 insertions(+), 104 deletions(-) delete mode 100644 backend/logger/middleware.go create mode 100644 backend/server/log.go create mode 100644 backend/server/ratelimit.go diff --git a/backend/cmd/analogdb/main.go b/backend/cmd/analogdb/main.go index 3533da6..c94e0f5 100644 --- a/backend/cmd/analogdb/main.go +++ b/backend/cmd/analogdb/main.go @@ -155,8 +155,8 @@ func main() { } server.PostService = postService - server.ReadyService = readyService server.AuthorService = authorService + server.ReadyService = readyService server.ScrapeService = scrapeService server.KeywordService = keywordService server.SimilarityService = similarityService diff --git a/backend/logger/middleware.go b/backend/logger/middleware.go deleted file mode 100644 index 3475452..0000000 --- a/backend/logger/middleware.go +++ /dev/null @@ -1,51 +0,0 @@ -package logger - -import ( - "net/http" - "time" - - "github.com/go-chi/chi/v5/middleware" -) - -func Middleware(logger *Logger) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - - start := time.Now() - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - ctx := r.Context() - - defer func() { - if rec := recover(); rec != nil { - http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - - // don't log healthcheck requests - if p := r.URL.Path; p == "/healthz" || p == "/readyz" { - return - } - - // log end request - logger.Info(). - Ctx(ctx). - Timestamp(). - Fields(map[string]interface{}{ - "remote_ip": r.RemoteAddr, - "path": r.URL.Path, - "proto": r.Proto, - "method": r.Method, - "user_agent": r.Header.Get("User-Agent"), - "status": ww.Status(), - "latency_ms": float64(time.Since(start).Nanoseconds()) / 1000000.0, - "bytes_in": r.Header.Get("Content-Length"), - "bytes_out": ww.BytesWritten(), - }). - Msg("Incoming request") - }() - - next.ServeHTTP(ww, r) - } - return http.HandlerFunc(fn) - } -} diff --git a/backend/server/auth.go b/backend/server/auth.go index 42e05fb..afa6de5 100644 --- a/backend/server/auth.go +++ b/backend/server/auth.go @@ -1,26 +1,28 @@ package server import ( + "context" "crypto/sha256" "crypto/subtle" "net/http" ) +type contextKey string + +const authKey contextKey = "authorized" + func (s *Server) auth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - username := s.config.Auth.Username password := s.config.Auth.Password - authenticated := s.passBasicAuth(username, password, r) + if authenticated { - s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Authorized with basic auth") - next.ServeHTTP(w, r) + ctx := context.WithValue(r.Context(), authKey, true) + next.ServeHTTP(w, r.WithContext(ctx)) return } - s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Unauthorized with basic auth") w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) http.Error(w, "Unauthorized", http.StatusUnauthorized) }) diff --git a/backend/server/log.go b/backend/server/log.go new file mode 100644 index 0000000..51bcb02 --- /dev/null +++ b/backend/server/log.go @@ -0,0 +1,61 @@ +package server + +import ( + "net/http" + "runtime/debug" + "time" + + "github.com/go-chi/chi/v5/middleware" +) + +func (server *Server) logRequests(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + start := time.Now() + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + ctx := r.Context() + + next.ServeHTTP(ww, r) + + defer func() { + if rec := recover(); rec != nil { + err := rec.(error) + server.logger.Log(). + Stack(). + Err(err). + Ctx(ctx). + Bytes("debug_stack", debug.Stack()). + Msg("Caught error with recoverer") + http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + + // don't log healthcheck requests + if p := r.URL.Path; p == healthRoute || p == readyRoute { + return + } + + authorized := false + if a := r.Context().Value(authKey); a != nil { + authorized = true + } + + // log end request + server.logger.Info(). + Ctx(ctx). + Fields(map[string]interface{}{ + "remote_ip": r.RemoteAddr, + "path": r.URL.Path, + "proto": r.Proto, + "method": r.Method, + "user_agent": r.Header.Get("User-Agent"), + "status": ww.Status(), + "latency_ms": float64(time.Since(start).Nanoseconds()) / 1000000.0, + "bytes_in": r.Header.Get("Content-Length"), + "bytes_out": ww.BytesWritten(), + "authorized": authorized, + }). + Msg("Incoming request") + }() + + }) +} diff --git a/backend/server/metrics.go b/backend/server/metrics.go index 1d7869c..aa19e1b 100644 --- a/backend/server/metrics.go +++ b/backend/server/metrics.go @@ -83,10 +83,7 @@ func (stats *httpStats) register(registerer prometheus.Registerer) error { func (server *Server) collectStats(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // start timing start := time.Now() - - // wrap and serve ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) next.ServeHTTP(ww, r) diff --git a/backend/server/middleware.go b/backend/server/middleware.go index d23d0e1..2668ee7 100644 --- a/backend/server/middleware.go +++ b/backend/server/middleware.go @@ -1,13 +1,10 @@ package server import ( - "net/http" "time" - "github.com/evanofslack/analogdb/logger" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" - "github.com/go-chi/httprate" "github.com/riandyrn/otelchi" ) @@ -32,24 +29,10 @@ func (s *Server) mountMiddleware() { } // log all requests - s.router.Use(logger.Middleware(s.logger)) + s.router.Use(s.logRequests) - // is rate limiting enabled? - if s.config.App.RateLimitEnabled { - - // rate limit by IP with json response - rateLimiter := httprate.Limit(rateLimit, rateLimitPeriod, - httprate.WithKeyFuncs(httprate.KeyByIP), - httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte(`{"error": "Too many requests"}`)) - })) - - // bypass rate limit if authenticated - s.router.Use(middleware.Maybe(rateLimiter, s.applyRateLimit)) - s.logger.Info().Int("limit", rateLimit).Str("period", rateLimitPeriod.String()).Msg("Added rate limiting middleware") - } + // apply rate limit + s.addRatelimiter() corsHandler := cors.Handler(cors.Options{ AllowedOrigins: []string{"https://*", "http://*", "http://localhost"}, @@ -63,21 +46,3 @@ func (s *Server) mountMiddleware() { // CORS s.router.Use(corsHandler) } - -// apply rate limit only if user is not authenticated -func (s *Server) applyRateLimit(r *http.Request) bool { - - ctx := r.Context() - - rl_username := s.config.Auth.RateLimitUsername - rl_password := s.config.Auth.RateLimitPassword - - authenticated := s.passBasicAuth(rl_username, rl_password, r) - if authenticated { - s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Bypassing rate limit") - return false - } - - s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Applying rate limit") - return true -} diff --git a/backend/server/ratelimit.go b/backend/server/ratelimit.go new file mode 100644 index 0000000..35ced70 --- /dev/null +++ b/backend/server/ratelimit.go @@ -0,0 +1,41 @@ +package server + +import ( + "net/http" + + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/httprate" +) + +func (server *Server) addRatelimiter() { + + if !server.config.App.RateLimitEnabled { + return + } + + // rate limit by IP with json response + rateLimiter := httprate.Limit(rateLimit, rateLimitPeriod, + httprate.WithKeyFuncs(httprate.KeyByIP), + httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error": "Too many requests"}`)) + })) + + server.router.Use(middleware.Maybe(rateLimiter, server.applyRateLimit)) + server.logger.Info().Msg("Added rate limiting middleware") +} + +// apply rate limit only if user is not authenticated +func (server *Server) applyRateLimit(r *http.Request) bool { + + rl_username := server.config.Auth.RateLimitUsername + rl_password := server.config.Auth.RateLimitPassword + + authenticated := server.passBasicAuth(rl_username, rl_password, r) + if authenticated { + return false + } + return true + +} diff --git a/backend/server/scrape.go b/backend/server/scrape.go index 36ab805..c2cde66 100644 --- a/backend/server/scrape.go +++ b/backend/server/scrape.go @@ -22,6 +22,7 @@ func (s *Server) mountScrapeHandlers() { } func (s *Server) getKeywordUpdatedPosts(w http.ResponseWriter, r *http.Request) { + ids, err := s.ScrapeService.KeywordUpdatedPostIDs(r.Context()) if err != nil { s.writeError(w, r, err) diff --git a/backend/server/status.go b/backend/server/status.go index 591a6c8..b370693 100644 --- a/backend/server/status.go +++ b/backend/server/status.go @@ -7,11 +7,17 @@ import ( "github.com/go-chi/chi/v5" ) +const ( + pingRoute = "/ping" + healthRoute = "/healthz" + readyRoute = "/readyz" +) + func (s *Server) mountStatusHandlers() { - s.router.Route("/ping", func(r chi.Router) { r.Get("/", s.ping) }) - s.router.Route("/healthz", func(r chi.Router) { r.Get("/", s.healthz) }) - s.router.Route("/readyz", func(r chi.Router) { r.Get("/", s.readyz) }) + s.router.Route(pingRoute, func(r chi.Router) { r.Get("/", s.ping) }) + s.router.Route(healthRoute, func(r chi.Router) { r.Get("/", s.healthz) }) + s.router.Route(readyRoute, func(r chi.Router) { r.Get("/", s.readyz) }) } func (s *Server) ping(w http.ResponseWriter, r *http.Request) { diff --git a/scraper/api.py b/scraper/api.py index 7d5edf7..a89b53a 100644 --- a/scraper/api.py +++ b/scraper/api.py @@ -107,14 +107,16 @@ def get_all_post_ids() -> List[int]: def get_keyword_updated_post_ids(username: str, password: str) -> List[int]: - url = f"{base_url}/scrape/keywords/updated" + path = "scrape/keywords/updated" + + url = f"{base_url}/{path}" r = requests.get( url=url, auth=HTTPBasicAuth(username=username, password=password), ) if r.status_code != 200: raise Exception( - f"failed to fetch scrape/keyword/updated with response: {r.json()}" + f"failed to fetch {path} with response: {r.json()}" ) try: data = r.json()