Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and merge Callbacks #3151

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,63 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ
}

func (r Wrapper) Callback(ctx context.Context, request CallbackRequestObject) (CallbackResponseObject, error) {
// check id in path
_, err := r.toOwnedDID(ctx, request.Did)
// validate request
// check did in path
ownDID, err := r.toOwnedDID(ctx, request.Did)
if err != nil {
// this is an OAuthError already, will be rendered as 400 but that's fine (for now) for an illegal id
return nil, err
}
// check if state is present and resolves to a client state
if request.Params.State == nil || *request.Params.State == "" {
// without state it is an invalid request, but try to provide as much useful information as possible
if request.Params.Error != nil && *request.Params.Error != "" {
callbackError := callbackRequestToError(request, nil)
callbackError.InternalError = errors.New("missing state parameter")
return nil, callbackError
}
return nil, oauthError(oauth.InvalidRequest, "missing state parameter")
}
oauthSession := new(OAuthSession)
if err = r.oauthClientStateStore().Get(*request.Params.State, oauthSession); err != nil {
return nil, oauthError(oauth.InvalidRequest, "invalid or expired state", err)
}
if !ownDID.Equals(*oauthSession.OwnDID) {
// TODO: this is a manipulated request, add error logging?
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "session DID does not match request"), oauthSession.redirectURI())
}

// if error is present, redirect error back to application initiating the flow
if request.Params.Error != nil && *request.Params.Error != "" {
return nil, callbackRequestToError(request, oauthSession.redirectURI())
}

// if error is present, delegate call to error handler
if request.Params.Error != nil {
return r.handleCallbackError(request)
// check if code is present
if request.Params.Code == nil || *request.Params.Code == "" {
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "missing code parameter"), oauthSession.redirectURI())
}

return r.handleCallback(ctx, request)
// continue flow
switch oauthSession.ClientFlow {
case credentialRequestClientFlow:
return r.handleOpenID4VCICallback(ctx, *request.Params.Code, oauthSession)
case accessTokenRequestClientFlow:
return r.handleCallback(ctx, *request.Params.Code, oauthSession)
default:
// programming error, should never happen
return nil, withCallbackURI(oauthError(oauth.ServerError, "unknown client flow for callback: '"+oauthSession.ClientFlow+"'"), oauthSession.redirectURI())
}
}

// callbackRequestToError should only be used if request.params.Error is present
func callbackRequestToError(request CallbackRequestObject, redirectURI *url.URL) oauth.OAuth2Error {
requestErr := oauth.OAuth2Error{
Code: oauth.ErrorCode(*request.Params.Error),
RedirectURI: redirectURI,
}
if request.Params.ErrorDescription != nil {
requestErr.Description = *request.Params.ErrorDescription
}
return requestErr
}

func (r Wrapper) RetrieveAccessToken(_ context.Context, request RetrieveAccessTokenRequestObject) (RetrieveAccessTokenResponseObject, error) {
Expand Down
144 changes: 122 additions & 22 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,11 @@ func TestWrapper_HandleAuthorizeRequest(t *testing.T) {
// handleAuthorizeRequestFromVerifier
_ = ctx.client.storageEngine.GetSessionDatabase().GetStore(oAuthFlowTimeout, oauthClientStateKey...).Put("state", OAuthSession{
// this is the state from the holder that was stored at the creation of the first authorization request to the verifier
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
ResponseType: "code",
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
})
_ = ctx.client.userSessionStore().Put("session-id", UserSession{
TenantDID: holderDID,
Expand Down Expand Up @@ -461,31 +460,40 @@ func TestWrapper_Callback(t *testing.T) {
errorDescription := "error description"
state := "state"
token := "token"
redirectURI, parseErr := url.Parse("https://example.com/iam/holder/cb")
require.NoError(t, parseErr)

session := OAuthSession{
ClientFlow: "access_token_request",
SessionID: "token",
OwnDID: &holderDID,
RedirectURI: "https://example.com/iam/holder/cb",
VerifierDID: &verifierDID,
RedirectURI: redirectURI.String(),
OtherDID: &verifierDID,
TokenEndpoint: "https://example.com/token",
}

t.Run("ok - error flow", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
Error: &errorCode,
ErrorDescription: &errorDescription,
},
})

require.NoError(t, err)
assert.Equal(t, "https://example.com/iam/holder/cb?error=error&error_description=error+description", res.(Callback302Response).Headers.Location)
var oauthErr oauth.OAuth2Error
require.ErrorAs(t, err, &oauthErr)
assert.Equal(t, oauth.OAuth2Error{
Code: oauth.ErrorCode(errorCode),
Description: errorDescription,
RedirectURI: redirectURI,
}, err)
assert.Nil(t, res)
})
t.Run("ok - success flow", func(t *testing.T) {
ctx := newTestClient(t)
Expand All @@ -494,11 +502,11 @@ func TestWrapper_Callback(t *testing.T) {
putState(ctx, "state", withDPoP)
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -518,21 +526,22 @@ func TestWrapper_Callback(t *testing.T) {
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "access_token_request",
OwnDID: &holderDID,
PKCEParams: generatePKCEParams(),
RedirectURI: "https://example.com/iam/holder/cb",
SessionID: "token",
UseDPoP: false,
VerifierDID: &verifierDID,
OtherDID: &verifierDID,
TokenEndpoint: session.TokenEndpoint,
})
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -542,17 +551,108 @@ func TestWrapper_Callback(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, res)
})
t.Run("unknown did", func(t *testing.T) {
t.Run("err - unknown did", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(false, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(false, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
})

assert.EqualError(t, err, "DID document not managed by this node")
assert.Nil(t, res)
})
t.Run("err - did mismatch", func(t *testing.T) {
ctx := newTestClient(t)
putState(ctx, "state", session)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

assert.Nil(t, res)
requireOAuthError(t, err, oauth.InvalidRequest, "session DID does not match request")

})
t.Run("err - missing state", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing state parameter")
})
t.Run("err - error flow but missing state", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Error: &errorCode,
ErrorDescription: &errorDescription,
},
})

requireOAuthError(t, err, oauth.ErrorCode(errorCode), errorDescription)
assert.EqualError(t, err, "error - missing state parameter - error description")
})
t.Run("err - expired state/session", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "invalid or expired state")
})
t.Run("err - missing code", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing code parameter")
})
t.Run("err - unknown flow", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "",
OwnDID: &holderDID,
})
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.ServerError, "unknown client flow for callback: ''")
})
}

func TestWrapper_RetrieveAccessToken(t *testing.T) {
Expand Down
Loading
Loading