Skip to content

Commit

Permalink
feat: add token claims webhook
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Palesandro committed Nov 9, 2023
1 parent 0770dbb commit 2446616
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 1 deletion.
69 changes: 69 additions & 0 deletions pkg/webhook/claims/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package claims

import (
"encoding/json"
"fmt"

"github.com/dexidp/dex/pkg/webhook/config"
"github.com/dexidp/dex/pkg/webhook/helpers"
)

func NewClaimMutatingHook(hook *config.ClaimsMutatingHook) (*ClaimsMutatingHook, error) {
var hookInvoker ClaimsHookCaller
switch hook.Type {
case config.External:
h, err := helpers.NewWebhookHTTPHelpers(hook.Config)
if err != nil {
return nil, fmt.Errorf("could not create webhook http helpers: %v", err)
}
hookInvoker = NewWebhookCaller(h, hook.AcceptedClaims)
default:
return nil, fmt.Errorf("unknown type: %s", hook.Type)
}
return &ClaimsMutatingHook{
Name: hook.Name,
Type: hook.Type,
HookInvoker: hookInvoker,
}, nil
}

func NewWebhookCaller(h helpers.WebhookHTTPHelpers, acceptedClaims []string) *WebhookCallerImpl {
return &WebhookCallerImpl{
AcceptedClaims: acceptedClaims,
transportHelper: h,
}
}

type ClaimsWebhookPayload struct {
ConnID string `json:"connID"`
Claims map[string]interface{} `json:"claims"`
}

func (w WebhookCallerImpl) callHook(claims map[string]interface{}, connID string) (map[string]interface{}, error) {
toMarshal := ClaimsWebhookPayload{
ConnID: connID,
Claims: claims,
}

payload, err := json.Marshal(toMarshal)
if err != nil {
return nil, fmt.Errorf("could not serialize claims: %v", err)
}

body, err := w.transportHelper.CallWebhook(payload)
if err != nil {
return nil, fmt.Errorf("could not call webhook: %v", err)
}
var res map[string]interface{}

if err := json.Unmarshal(body, &res); err != nil {
return nil, fmt.Errorf("could not unmarshal response: %v", err)
}

return res, nil
}

func (w WebhookCallerImpl) CallHook(claims map[string]interface{}, connID string) (map[string]interface{}, error) {
filteredClaims := constrainScope(claims, w.AcceptedClaims)
return w.callHook(filteredClaims, connID)
}
77 changes: 77 additions & 0 deletions pkg/webhook/claims/claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package claims

import (
"testing"

"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/dexidp/dex/pkg/webhook/config"
"github.com/dexidp/dex/pkg/webhook/helpers"
)

func TestNewClaimsMutating(t *testing.T) {
hook, err := NewClaimMutatingHook(&config.ClaimsMutatingHook{
Name: "test",
Type: config.External,
AcceptedClaims: []string{"claim1", "claim2"},
Config: &config.WebhookConfig{
URL: "https://test.com",
InsecureSkipVerify: true,
},
})
assert.NoError(t, err)
assert.NotNil(t, hook)
assert.Equal(t, hook.Name, "test")
assert.Equal(t, hook.Type, config.External)
assert.IsType(t, hook.HookInvoker, &WebhookCallerImpl{})
}

func TestNewClaimsMutating_UnknownType(t *testing.T) {
hook, err := NewClaimMutatingHook(&config.ClaimsMutatingHook{
Name: "test",
Type: "Unknown",
AcceptedClaims: []string{"claim1", "claim2"},
Config: &config.WebhookConfig{
URL: "https://test.com",
InsecureSkipVerify: true,
},
})
assert.Error(t, err)
assert.Nil(t, hook)
}

func TestNewWebhookCaller(t *testing.T) {
h, err := helpers.NewWebhookHTTPHelpers(&config.WebhookConfig{
URL: "https://test.com",
InsecureSkipVerify: true,
})
assert.NoError(t, err)
d := NewWebhookCaller(h, []string{"claim1", "claim2"})
assert.NotNil(t, d)
assert.Equal(t, d.AcceptedClaims, []string{"claim1", "claim2"})
assert.IsType(t, d.transportHelper, h)
}

func TestCallHook_Logic_Error(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
h := helpers.NewMockWebhookHTTPHelpers(ctrl)
h.EXPECT().CallWebhook(gomock.Any()).Return(nil, assert.AnError)
d := NewWebhookCaller(h, []string{"claim1", "claim2"})
hook, err := d.CallHook(map[string]interface{}{"claim1": "value1", "claim2": "value2"}, "test")
assert.Error(t, err)
assert.Nil(t, hook)
}

func TestCallHook_Logic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
h := helpers.NewMockWebhookHTTPHelpers(ctrl)
h.EXPECT().CallWebhook([]byte(`{"connID":"test","claims":{"claim1":"value1"}}`)).Return([]byte(
`{"connID": "test", "claims": { "claim1" : "value1" } }`), nil)
d := NewWebhookCaller(h, []string{"claim1", "claim3"})
hook, err := d.CallHook(map[string]interface{}{"claim1": "value1", "claim2": "value2"}, "test")
assert.NoError(t, err)
assert.NotNil(t, hook)
}
13 changes: 13 additions & 0 deletions pkg/webhook/claims/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package claims

import "golang.org/x/exp/slices"

func constrainScope(claims map[string]interface{}, acceptedClaims []string) map[string]interface{} {
scopedClaims := make(map[string]interface{})
for k, v := range claims {
if slices.Contains(acceptedClaims, k) {
scopedClaims[k] = v
}
}
return scopedClaims
}
17 changes: 17 additions & 0 deletions pkg/webhook/claims/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package claims

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestConstrainScope(t *testing.T) {
res := constrainScope(map[string]interface{}{
"claim1": "value1",
"claim2": "value2",
}, []string{"claim1", "claim3"})
assert.Equal(t, res, map[string]interface{}{
"claim1": "value1",
})
}
36 changes: 36 additions & 0 deletions pkg/webhook/claims/idtoken.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package claims

import (
"encoding/json"
"fmt"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

func getProtectedClaims() []string {
return []string{"iss", "sub", "aud", "exp", "iat", "azp", "nonce", "at_hash", "c_hash"}
}

func generateIDClaims(baseIDClaims map[string]interface{}, customClaims map[string]interface{}) map[string]interface{} {
finalClaims := map[string]interface{}{}
maps.Copy(finalClaims, baseIDClaims)
// Adding the immutable claims to the token
protectedClaims := getProtectedClaims()
for claim := range customClaims {
if !slices.Contains(protectedClaims, claim) {
finalClaims[claim] = customClaims[claim]
}
}
return finalClaims
}

func GenerateTokenFromTemplate(baseIDClaims map[string]interface{}, customClaims map[string]interface{}) ([]byte,
error,
) {
payload, err := json.Marshal(generateIDClaims(baseIDClaims, customClaims))
if err != nil {
return []byte{}, fmt.Errorf("could not serialize claims: %v", err)
}
return payload, nil
}
54 changes: 54 additions & 0 deletions pkg/webhook/claims/idtoken_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package claims

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_GroupTranslation(t *testing.T) {
baseIDToken := map[string]interface{}{
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"groups": []string{"test2:group1", "test2:group2"},
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"groups":["test2:group1","test2:group2"]}`))
}

func Test_EmptyInput(t *testing.T) {
res, err := GenerateTokenFromTemplate(map[string]interface{}{}, map[string]interface{}{})
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{}`))
}

func Test_ProtectedClaims(t *testing.T) {
baseIDToken := map[string]interface{}{
"iss": "iss",
"sub": "sub",
"aud": "aud",
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"iss": "iss2",
"groups": []string{"test2:group1", "test2:group2"},
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"aud":"aud","groups":["test2:group1","test2:group2"],"iss":"iss","sub":"sub"}`))
}

func Test_NotStandardClaims(t *testing.T) {
baseIDToken := map[string]interface{}{
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"groups": []string{"test2:group1", "test2:group2"},
"custom": "custom",
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"custom":"custom","groups":["test2:group1","test2:group2"]}`))
}
29 changes: 29 additions & 0 deletions pkg/webhook/claims/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package claims

import (
"github.com/dexidp/dex/pkg/webhook/config"
"github.com/dexidp/dex/pkg/webhook/helpers"
)

type MutateClaimsRequest interface {
MutateClaims(claims map[string]interface{}, connID string) (map[string]interface{}, error)
}

type ClaimsHookCaller interface {
CallHook(claims map[string]interface{}, connID string) (map[string]interface{}, error)
}

type ClaimsMutatingHook struct {
// Name is the name of the webhook
Name string `json:"name"`
// To be modified to enum?
Type config.HookType `json:"type"`
HookInvoker ClaimsHookCaller
}

var _ ClaimsHookCaller = &WebhookCallerImpl{}

type WebhookCallerImpl struct {
AcceptedClaims []string
transportHelper helpers.WebhookHTTPHelpers
}
42 changes: 41 additions & 1 deletion server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
jose "gopkg.in/square/go-jose.v2"

"github.com/dexidp/dex/connector"
claimsWebhook "github.com/dexidp/dex/pkg/webhook/claims"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
Expand Down Expand Up @@ -411,7 +412,18 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
tok.AuthorizingParty = clientID
}

payload, err := json.Marshal(tok)
// Pre-filter the claims before passing them to the webhook.
// Pass them to the mutating webhook.
var res map[string]interface{}
res = forgeMap(tok)
for _, v := range s.claimsWebhookFilter {
res, err = v.HookInvoker.CallHook(res, connID)
if err != nil {
return "", time.Time{}, err
}
}

payload, err := claimsWebhook.GenerateTokenFromTemplate(convertToMap(tok), res)
if err != nil {
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}
Expand Down Expand Up @@ -700,3 +712,31 @@ func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload

return nil, errors.New("failed to verify id token signature")
}

func forgeMap(tok idTokenClaims) map[string]interface{} {
return map[string]interface{}{
"groups": tok.Groups,
"preferred_username": tok.PreferredUsername,
"email": tok.Email,
}
}

func convertToMap(baseIDClaims idTokenClaims) map[string]interface{} {
return map[string]interface{}{
"iss": baseIDClaims.Issuer,
"sub": baseIDClaims.Subject,
"aud": baseIDClaims.Audience,
"exp": baseIDClaims.Expiry,
"iat": baseIDClaims.IssuedAt,
"azp": baseIDClaims.AuthorizingParty,
"nonce": baseIDClaims.Nonce,
"at_hash": baseIDClaims.AccessTokenHash,
"c_hash": baseIDClaims.CodeHash,
"email": baseIDClaims.Email,
"email_verified": baseIDClaims.EmailVerified,
"groups": baseIDClaims.Groups,
"name": baseIDClaims.Name,
"preferred_username": baseIDClaims.PreferredUsername,
"federated_claims": baseIDClaims.FederatedIDClaims,
}
}
3 changes: 3 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/dexidp/dex/connector/openshift"
"github.com/dexidp/dex/connector/saml"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/pkg/webhook/claims"
"github.com/dexidp/dex/pkg/webhook/config"
"github.com/dexidp/dex/pkg/webhook/connectors"
"github.com/dexidp/dex/storage"
Expand Down Expand Up @@ -192,6 +193,8 @@ type Server struct {
logger log.Logger

connectorWebhookFilter []*connectors.ConnectorFilterHook

claimsWebhookFilter []*claims.ClaimsMutatingHook
}

// NewServer constructs a server from the provided config.
Expand Down

0 comments on commit 2446616

Please sign in to comment.