Skip to content

Commit

Permalink
Merge pull request #113 from evanofslack/backend
Browse files Browse the repository at this point in the history
Change up middlewares and fix nil pointer deref
  • Loading branch information
evanofslack committed Aug 5, 2023
2 parents be0d4d2 + ea62122 commit bb8ea11
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 109 deletions.
2 changes: 1 addition & 1 deletion backend/cmd/analogdb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ func main() {
}

server.PostService = postService
server.ReadyService = readyService
server.AuthorService = authorService
server.ReadyService = readyService
server.ScrapeService = scrapeService
server.KeywordService = keywordService
server.SimilarityService = similarityService
Expand Down
51 changes: 0 additions & 51 deletions backend/logger/middleware.go

This file was deleted.

8 changes: 3 additions & 5 deletions backend/postgres/scrape.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ import (
"database/sql"

"github.com/evanofslack/analogdb"
"github.com/evanofslack/analogdb/logger"
)

// ensure interface is implemented
var _ analogdb.ScrapeService = (*ScrapeService)(nil)

type ScrapeService struct {
db *DB
logger *logger.Logger
db *DB
}

func NewScrapeService(db *DB) *ScrapeService {
Expand All @@ -22,8 +20,8 @@ func NewScrapeService(db *DB) *ScrapeService {

func (s *ScrapeService) KeywordUpdatedPostIDs(ctx context.Context) ([]int, error) {

s.logger.Debug().Ctx(ctx).Msg("Starting get keyword updated post ids")
defer s.logger.Debug().Ctx(ctx).Msg("Finished keyword updated post ids")
s.db.logger.Debug().Ctx(ctx).Msg("Starting get keyword updated post ids")
defer s.db.logger.Debug().Ctx(ctx).Msg("Finished keyword updated post ids")

tx, err := s.db.db.BeginTx(ctx, nil)
if err != nil {
Expand Down
14 changes: 8 additions & 6 deletions backend/server/auth.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
package server

import (
"context"
"crypto/sha256"
"crypto/subtle"
"net/http"
)

type contextKey string

const authKey contextKey = "authorized"

func (s *Server) auth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

ctx := r.Context()

username := s.config.Auth.Username
password := s.config.Auth.Password

authenticated := s.passBasicAuth(username, password, r)

if authenticated {
s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Authorized with basic auth")
next.ServeHTTP(w, r)
ctx := context.WithValue(r.Context(), authKey, true)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Unauthorized with basic auth")
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
Expand Down
61 changes: 61 additions & 0 deletions backend/server/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package server

import (
"net/http"
"runtime/debug"
"time"

"github.com/go-chi/chi/v5/middleware"
)

func (server *Server) logRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

start := time.Now()
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
ctx := r.Context()

next.ServeHTTP(ww, r)

defer func() {
if rec := recover(); rec != nil {
err := rec.(error)
server.logger.Log().
Stack().
Err(err).
Ctx(ctx).
Bytes("debug_stack", debug.Stack()).
Msg("Caught error with recoverer")
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

// don't log healthcheck requests
if p := r.URL.Path; p == healthRoute || p == readyRoute {
return
}

authorized := false
if a := r.Context().Value(authKey); a != nil {
authorized = true
}

// log end request
server.logger.Info().
Ctx(ctx).
Fields(map[string]interface{}{
"remote_ip": r.RemoteAddr,
"path": r.URL.Path,
"proto": r.Proto,
"method": r.Method,
"user_agent": r.Header.Get("User-Agent"),
"status": ww.Status(),
"latency_ms": float64(time.Since(start).Nanoseconds()) / 1000000.0,
"bytes_in": r.Header.Get("Content-Length"),
"bytes_out": ww.BytesWritten(),
"authorized": authorized,
}).
Msg("Incoming request")
}()

})
}
3 changes: 0 additions & 3 deletions backend/server/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ func (stats *httpStats) register(registerer prometheus.Registerer) error {
func (server *Server) collectStats(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// start timing
start := time.Now()

// wrap and serve
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

next.ServeHTTP(ww, r)
Expand Down
41 changes: 3 additions & 38 deletions backend/server/middleware.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package server

import (
"net/http"
"time"

"github.com/evanofslack/analogdb/logger"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/go-chi/httprate"
"github.com/riandyrn/otelchi"
)

Expand All @@ -32,24 +29,10 @@ func (s *Server) mountMiddleware() {
}

// log all requests
s.router.Use(logger.Middleware(s.logger))
s.router.Use(s.logRequests)

// is rate limiting enabled?
if s.config.App.RateLimitEnabled {

// rate limit by IP with json response
rateLimiter := httprate.Limit(rateLimit, rateLimitPeriod,
httprate.WithKeyFuncs(httprate.KeyByIP),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error": "Too many requests"}`))
}))

// bypass rate limit if authenticated
s.router.Use(middleware.Maybe(rateLimiter, s.applyRateLimit))
s.logger.Info().Int("limit", rateLimit).Str("period", rateLimitPeriod.String()).Msg("Added rate limiting middleware")
}
// apply rate limit
s.addRatelimiter()

corsHandler := cors.Handler(cors.Options{
AllowedOrigins: []string{"https://*", "http://*", "http://localhost"},
Expand All @@ -63,21 +46,3 @@ func (s *Server) mountMiddleware() {
// CORS
s.router.Use(corsHandler)
}

// apply rate limit only if user is not authenticated
func (s *Server) applyRateLimit(r *http.Request) bool {

ctx := r.Context()

rl_username := s.config.Auth.RateLimitUsername
rl_password := s.config.Auth.RateLimitPassword

authenticated := s.passBasicAuth(rl_username, rl_password, r)
if authenticated {
s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Bypassing rate limit")
return false
}

s.logger.Debug().Ctx(ctx).Bool("authenticated", authenticated).Msg("Applying rate limit")
return true
}
41 changes: 41 additions & 0 deletions backend/server/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package server

import (
"net/http"

"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httprate"
)

func (server *Server) addRatelimiter() {

if !server.config.App.RateLimitEnabled {
return
}

// rate limit by IP with json response
rateLimiter := httprate.Limit(rateLimit, rateLimitPeriod,
httprate.WithKeyFuncs(httprate.KeyByIP),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error": "Too many requests"}`))
}))

server.router.Use(middleware.Maybe(rateLimiter, server.applyRateLimit))
server.logger.Info().Msg("Added rate limiting middleware")
}

// apply rate limit only if user is not authenticated
func (server *Server) applyRateLimit(r *http.Request) bool {

rl_username := server.config.Auth.RateLimitUsername
rl_password := server.config.Auth.RateLimitPassword

authenticated := server.passBasicAuth(rl_username, rl_password, r)
if authenticated {
return false
}
return true

}
1 change: 1 addition & 0 deletions backend/server/scrape.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func (s *Server) mountScrapeHandlers() {
}

func (s *Server) getKeywordUpdatedPosts(w http.ResponseWriter, r *http.Request) {

ids, err := s.ScrapeService.KeywordUpdatedPostIDs(r.Context())
if err != nil {
s.writeError(w, r, err)
Expand Down
12 changes: 9 additions & 3 deletions backend/server/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ import (
"github.com/go-chi/chi/v5"
)

const (
pingRoute = "/ping"
healthRoute = "/healthz"
readyRoute = "/readyz"
)

func (s *Server) mountStatusHandlers() {

s.router.Route("/ping", func(r chi.Router) { r.Get("/", s.ping) })
s.router.Route("/healthz", func(r chi.Router) { r.Get("/", s.healthz) })
s.router.Route("/readyz", func(r chi.Router) { r.Get("/", s.readyz) })
s.router.Route(pingRoute, func(r chi.Router) { r.Get("/", s.ping) })
s.router.Route(healthRoute, func(r chi.Router) { r.Get("/", s.healthz) })
s.router.Route(readyRoute, func(r chi.Router) { r.Get("/", s.readyz) })
}

func (s *Server) ping(w http.ResponseWriter, r *http.Request) {
Expand Down
6 changes: 4 additions & 2 deletions scraper/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,16 @@ def get_all_post_ids() -> List[int]:

def get_keyword_updated_post_ids(username: str, password: str) -> List[int]:

url = f"{base_url}/scrape/keywords/updated"
path = "scrape/keywords/updated"

url = f"{base_url}/{path}"
r = requests.get(
url=url,
auth=HTTPBasicAuth(username=username, password=password),
)
if r.status_code != 200:
raise Exception(
f"failed to fetch scrape/keyword/updated with response: {r.json()}"
f"failed to fetch {path} with response: {r.json()}"
)
try:
data = r.json()
Expand Down

0 comments on commit bb8ea11

Please sign in to comment.