Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collapse VC into Auth Middleware #62

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,25 @@ package auth

import (
"context"
"lybbrio/internal/ent/schema/ksuid"
"lybbrio/internal/ent/schema/permissions"
"lybbrio/internal/viewer"
"net/http"

"github.com/go-chi/render"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

type claimCtxKeyType string

const claimCtxKey claimCtxKeyType = "claims"

func withClaims(ctx context.Context, claims *Claims) context.Context {
return context.WithValue(ctx, claimCtxKey, claims)
}

func ClaimsFromCtx(ctx context.Context) *Claims {
claims, ok := ctx.Value(claimCtxKey).(*Claims)
if !ok {
return nil
func viewerCtxFromClaims(ctx context.Context, claims *Claims) context.Context {
perms := permissions.NewPermissions()
for _, perm := range claims.Permissions {
perms.Add(permissions.FromString(perm))
}
return claims
return viewer.NewContext(ctx, ksuid.ID(claims.UserID), perms)
}

func Middleware(prov *JWTProvider) func(http.Handler) http.Handler {
func ViewerContextMiddleware(prov *JWTProvider) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand Down Expand Up @@ -53,7 +48,7 @@ func Middleware(prov *JWTProvider) func(http.Handler) http.Handler {
return
}

ctx = withClaims(ctx, claims)
ctx = viewerCtxFromClaims(ctx, claims)
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("user_id", claims.UserID)
})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package auth

import (
"lybbrio/internal/ent/schema/permissions"
"lybbrio/internal/viewer"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -19,16 +21,19 @@ func Test_Middleware(t *testing.T) {
)
require.NoError(err)

handler := Middleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := ClaimsFromCtx(r.Context())
require.Equal("some_user_id", claims.UserID)
require.Equal("some_user_name", claims.UserName)
handler := ViewerContextMiddleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
view := viewer.FromContext(r.Context())
uid, ok := view.UserID()
require.True(ok)
require.Equal("some_user_id", uid.String())
require.True(view.Has(permissions.Admin))

}))

token, err := provider.CreateToken(
"some_user_id",
"some_user_name",
[]string{"some_permission"},
[]string{"Admin"},
)
require.NoError(err)

Expand Down Expand Up @@ -59,11 +64,11 @@ func Test_Middleware_BadToken(t *testing.T) {
token, err := wrong_provider.CreateToken(
"some_user_id",
"some_user_name",
[]string{"some_permission"},
[]string{"Admin"},
)
require.NoError(err)

handler := Middleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := ViewerContextMiddleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

Expand Down Expand Up @@ -96,7 +101,7 @@ func Test_Middleware_EmptyToken(t *testing.T) {
)
require.NoError(err)

handler := Middleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := ViewerContextMiddleware(provider)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

Expand Down
3 changes: 1 addition & 2 deletions internal/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ func rootRun(cmd *cobra.Command, args []string) {
r.Mount("/auth", auth.Routes(client, jwtProvider))
r.Route("/graphql", func(r chi.Router) {
r.With(
auth.Middleware(jwtProvider),
middleware.ViewerContextMiddleware(client),
auth.ViewerContextMiddleware(jwtProvider),
middleware.SuperRead,
).Handle("/", graphqlHandler)
r.Handle("/playground", playground.Handler("Lybbrio GraphQL playground", "/graphql"))
Expand Down
19 changes: 18 additions & 1 deletion internal/ent/schema/permissions/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
type Permission int

const (
Admin = iota + 1
Admin Permission = iota + 1
CanCreatePublic
CanEdit

Expand All @@ -28,6 +28,19 @@ func (p Permission) String() string {
return ""
}

func FromString(s string) Permission {
switch s {
case Admin.String():
return Admin
case CanCreatePublic.String():
return CanCreatePublic
case CanEdit.String():
return CanEdit
default:
return 0
}
}

type Permissions map[Permission]struct{}

func NewPermissions(permissions ...Permission) Permissions {
Expand All @@ -43,6 +56,10 @@ func (p Permissions) Has(perm Permission) bool {
return ok
}

func (p Permissions) Add(perm Permission) {
p[perm] = struct{}{}
}

func (p Permissions) StringSlice() []string {
ret := make([]string, 0, len(p))
for k := range p {
Expand Down
45 changes: 0 additions & 45 deletions internal/middleware/viewer.go

This file was deleted.

4 changes: 1 addition & 3 deletions internal/tests/viewer_context_integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"lybbrio/internal/auth"
"lybbrio/internal/db"
"lybbrio/internal/ent/schema/permissions"
"lybbrio/internal/middleware"
"lybbrio/internal/viewer"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -47,8 +46,7 @@ func Test_ViewerContextGetsSet(t *testing.T) {
require.NoError(err)

r := chi.NewRouter()
r.Use(auth.Middleware(jwtProvider))
r.Use(middleware.ViewerContextMiddleware(client))
r.Use(auth.ViewerContextMiddleware(jwtProvider))

r.Get("/", func(w http.ResponseWriter, r *http.Request) {
viewerCtx := viewer.FromContext(r.Context())
Expand Down
Loading