Skip to content

Commit

Permalink
handler refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
glaslos committed Oct 26, 2023
1 parent f86ee58 commit 2c5b454
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
19 changes: 10 additions & 9 deletions backend/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"

"github.com/honeynet/ochi/backend/entities"

"github.com/julienschmidt/httprouter"
"google.golang.org/api/idtoken"
)
Expand Down Expand Up @@ -86,7 +87,7 @@ type response struct {

// sessionHandler creates a new token for the user
func (cs *server) sessionHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
user, err := cs.uRepo.Get(userID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -156,7 +157,7 @@ func (cs *server) loginHandler(w http.ResponseWriter, r *http.Request, _ httprou

// getQueriesHandler returns a list of queries belonging to ther user.
func (cs *server) getQueriesHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
queries, err := cs.queryRepo.FindByOwnerId(userID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -172,7 +173,7 @@ func (cs *server) getQueriesHandler(w http.ResponseWriter, r *http.Request, _ ht

// createQueryHandler creates a new query.
func (cs *server) createQueryHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
decoder := json.NewDecoder(r.Body)
defer r.Body.Close()
var t entities.Query
Expand All @@ -196,7 +197,7 @@ func (cs *server) createQueryHandler(w http.ResponseWriter, r *http.Request, _ h

// udpateQueryHandler updates an existing query making sure the user owns the query.
func (cs *server) updateQueryHandler(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
id := p.ByName("id")
q, err := cs.queryRepo.GetByID(id)
if err != nil {
Expand All @@ -219,7 +220,7 @@ func (cs *server) updateQueryHandler(w http.ResponseWriter, r *http.Request, p h
return
}
if id != q.ID {
http.Error(w, "Ids dont match", http.StatusBadRequest)
http.Error(w, "Ids don't match", http.StatusBadRequest)
return
}
err = cs.queryRepo.Update(q.ID, q.Content, q.Description, q.Active)
Expand All @@ -232,7 +233,7 @@ func (cs *server) updateQueryHandler(w http.ResponseWriter, r *http.Request, p h

// deleteQueryHandler deletes a query making sure the user owns the query.
func (cs *server) deleteQueryHandler(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
id := p.ByName("id")
q, err := cs.queryRepo.GetByID(id)
if err != nil {
Expand Down Expand Up @@ -261,7 +262,7 @@ func (cs *server) deleteQueryHandler(w http.ResponseWriter, r *http.Request, p h

// createEventHandler creates a new event
func (cs *server) createEventHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
decoder := json.NewDecoder(r.Body)
defer r.Body.Close()
var event entities.Event
Expand All @@ -285,7 +286,7 @@ func (cs *server) createEventHandler(w http.ResponseWriter, r *http.Request, _ h

// deleteEventHandler deletes an event making sure the user owns the query.
func (cs *server) deleteEventHandler(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
id := p.ByName("id")
event, err := cs.eventRepo.GetByID(id)
if err != nil {
Expand All @@ -310,7 +311,7 @@ func (cs *server) deleteEventHandler(w http.ResponseWriter, r *http.Request, p h

// getEventsHandler returns a list of events belonging to ther user.
func (cs *server) getEventsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
userID := userIDFromCtx(r.Context())
events, err := cs.eventRepo.FindByOwnerId(userID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
10 changes: 5 additions & 5 deletions backend/middleware.go → backend/handlers/auth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package backend
package handlers

import (
"context"
Expand All @@ -9,7 +9,7 @@ import (
"github.com/julienschmidt/httprouter"
)

func tokenMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
func TokenMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
token, ok := r.URL.Query()["token"]
if !ok || len(token) == 0 || token[0] != secret {
Expand All @@ -21,9 +21,9 @@ func tokenMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
}
}

type userID string
type UserID string

func bearerMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
func BearerMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
authHeader := r.Header.Get("Authorization")
authFields := strings.Fields(authHeader)
Expand All @@ -43,7 +43,7 @@ func bearerMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
return
}

r = r.WithContext(context.WithValue(r.Context(), userID("userID"), claims.UserID))
r = r.WithContext(context.WithValue(r.Context(), UserID("userID"), claims.UserID))

h(w, r, ps)
}
Expand Down
1 change: 0 additions & 1 deletion backend/handlers/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

// CorsOptionsHandler defines a global handler for HTTP OPTIONS requests.
func CorsOptionsHandler(w http.ResponseWriter, r *http.Request) {

if r.Header.Get("Access-Control-Request-Method") != "" {
// Set CORS headers
header := w.Header()
Expand Down
20 changes: 10 additions & 10 deletions backend/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ func newRouter(cs *server) (*httprouter.Router, error) {

// websocket
r.GET("/subscribe", cs.subscribeHandler)
r.POST("/publish", tokenMiddleware(cs.publishHandler, os.Args[2]))
r.POST("/publish", handlers.TokenMiddleware(cs.publishHandler, os.Args[2]))

// user
r.POST("/login", cs.loginHandler)
r.GET("/session", handlers.CorsMiddleware(bearerMiddleware(cs.sessionHandler, os.Args[3])))
r.GET("/session", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.sessionHandler, os.Args[3])))

// query
// TODO: make CorsMiddleware more generic instead of specifying it on every handler.
r.GET("/queries", handlers.CorsMiddleware(bearerMiddleware(cs.getQueriesHandler, os.Args[3])))
r.POST("/queries", handlers.CorsMiddleware(bearerMiddleware(cs.createQueryHandler, os.Args[3])))
r.PATCH("/queries/:id", handlers.CorsMiddleware(bearerMiddleware(cs.updateQueryHandler, os.Args[3])))
r.DELETE("/queries/:id", handlers.CorsMiddleware(bearerMiddleware(cs.deleteQueryHandler, os.Args[3])))
r.GET("/queries", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.getQueriesHandler, os.Args[3])))
r.POST("/queries", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.createQueryHandler, os.Args[3])))
r.PATCH("/queries/:id", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.updateQueryHandler, os.Args[3])))
r.DELETE("/queries/:id", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.deleteQueryHandler, os.Args[3])))

// event
r.POST("/api/events", handlers.CorsMiddleware(bearerMiddleware(cs.createEventHandler, os.Args[3])))
r.DELETE("/api/events/:id", handlers.CorsMiddleware(bearerMiddleware(cs.deleteEventHandler, os.Args[3])))
r.GET("/api/events", handlers.CorsMiddleware(bearerMiddleware(cs.getEventsHandler, os.Args[3])))
r.GET("/api/events/:id", handlers.CorsMiddleware(bearerMiddleware(cs.getEventByIDHandler, os.Args[3])))
r.POST("/api/events", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.createEventHandler, os.Args[3])))
r.DELETE("/api/events/:id", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.deleteEventHandler, os.Args[3])))
r.GET("/api/events", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.getEventsHandler, os.Args[3])))
r.GET("/api/events/:id", handlers.CorsMiddleware(handlers.BearerMiddleware(cs.getEventByIDHandler, os.Args[3])))

return r, nil
}
11 changes: 10 additions & 1 deletion backend/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package backend

import "strings"
import (
"context"
"strings"

"github.com/honeynet/ochi/backend/handlers"
)

func isNotFoundError(e error) bool {
return strings.Contains(e.Error(), "no rows in result set")
}

func userIDFromCtx(ctx context.Context) string {
return ctx.Value(handlers.UserID("userID")).(string)
}

0 comments on commit 2c5b454

Please sign in to comment.