Skip to content

Commit

Permalink
feat: ack issuance (#1534)
Browse files Browse the repository at this point in the history
* feat: basic ack

* feat: more work on ack store

* feat: more work on ack service

* feat: more tests

* fix: ack

* feat: improve cov

* fix: test

* feat: add expired event

* fix: merge

* fix: lint

* fix: mocks check

* fix: test

* fix: more tests

* feat: support both formats

* feat: add ack data ttl
  • Loading branch information
skynet2 authored Nov 23, 2023
1 parent db3e315 commit 2cd24a4
Show file tree
Hide file tree
Showing 30 changed files with 1,671 additions and 175 deletions.
306 changes: 155 additions & 151 deletions api/spec/openapi.gen.go

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions cmd/vc-rest/startcmd/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ const (
oidc4ciTransactionDataTTLFlagUsage = "OIDC4CI transaction data TTL. Defaults to 15m. " +
commonEnvVarUsageText + oidc4ciTransactionDataTTLEnvKey

oidc4ciAckDataTTLFlagName = "vc-oidc4ci-ack-data-ttl"
oidc4ciAckDataTTLEnvKey = "VC_OIDC4CI_ACK_DATA_TTL"
oidc4ciAckDataTTLFlagUsage = "OIDC4CI ack data TTL. Defaults to 24h. " +
commonEnvVarUsageText + oidc4ciAckDataTTLEnvKey

oidc4ciAuthStateTTLFlagName = "vc-oidc4ci-auth-state-ttl"
oidc4ciAuthStateTTLEnvKey = "VC_OIDC4CI_AUTH_STATE_TTL"
oidc4ciAuthStateTTLFlagUsage = "OIDC4CI auth state data TTL. Defaults to 15m. " +
Expand Down Expand Up @@ -379,6 +384,7 @@ const (
defaultOIDC4VPTransactionDataTTL = time.Hour
defaultOIDC4VPNonceDataTTL = 15 * time.Minute
defaultOIDC4CITransactionDataTTL = 15 * time.Minute
defaultOIDC4CIAckDataTTL = 24 * time.Hour
defaultOIDC4CIAuthStateTTL = 15 * time.Minute
defaultDataEncryptionKeyLength = 256
)
Expand Down Expand Up @@ -432,6 +438,7 @@ type transientDataParams struct {
storeType string
claimDataTTL int32
oidc4ciTransactionDataTTL int32
oidc4ciAckDataTTL int32
oidc4ciAuthStateTTL int32
oidc4vpNonceStoreDataTTL int32
oidc4vpTransactionDataTTL int32
Expand Down Expand Up @@ -793,6 +800,13 @@ func getTransientDataParams(cmd *cobra.Command) (*transientDataParams, error) {
if err != nil {
return nil, err
}

oidc4ciAckDataTTL, err := getDuration(
cmd, oidc4ciAckDataTTLFlagName, oidc4ciAckDataTTLEnvKey, defaultOIDC4CIAckDataTTL)
if err != nil {
return nil, err
}

oidc4ciAuthStateTTL, err := getDuration(
cmd, oidc4ciAuthStateTTLFlagName, oidc4ciAuthStateTTLEnvKey, defaultOIDC4CIAuthStateTTL)
if err != nil {
Expand All @@ -803,6 +817,7 @@ func getTransientDataParams(cmd *cobra.Command) (*transientDataParams, error) {
storeType: transientDataStoreType,
claimDataTTL: int32(claimDataTTL.Seconds()),
oidc4ciTransactionDataTTL: int32(oidc4ciTransactionDataTTL.Seconds()),
oidc4ciAckDataTTL: int32(oidc4ciAckDataTTL.Seconds()),
oidc4ciAuthStateTTL: int32(oidc4ciAuthStateTTL.Seconds()),
oidc4vpReceivedClaimsDataTTL: int32(oidc4vpReceivedClaimsDataTTL.Seconds()),
oidc4vpNonceStoreDataTTL: int32(oidc4vpNonceStoreDataTTL.Seconds()),
Expand Down Expand Up @@ -1163,6 +1178,7 @@ func createFlags(startCmd *cobra.Command) {
startCmd.Flags().StringP(oidc4vpTransactionDataTTLFlagName, "", "", oidc4vpTransactionDataTTLFlagUsage)
startCmd.Flags().StringP(oidc4vpNonceTTLFlagName, "", "", oidc4vpNonceTTLFlagUsage)
startCmd.Flags().StringP(oidc4ciTransactionDataTTLFlagName, "", "", oidc4ciTransactionDataTTLFlagUsage)
startCmd.Flags().StringP(oidc4ciAckDataTTLFlagName, "", "", oidc4ciAckDataTTLFlagUsage)
startCmd.Flags().StringP(oidc4ciAuthStateTTLFlagName, "", "", oidc4ciAuthStateTTLFlagUsage)

startCmd.Flags().StringP(otelServiceNameFlagName, "", "", otelServiceNameFlagUsage)
Expand Down
23 changes: 23 additions & 0 deletions cmd/vc-rest/startcmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ import (
"github.com/trustbloc/vcs/pkg/storage/mongodb/vcstatusstore"
"github.com/trustbloc/vcs/pkg/storage/redis"
redisclient "github.com/trustbloc/vcs/pkg/storage/redis"
"github.com/trustbloc/vcs/pkg/storage/redis/ackstore"
oidc4ciclaimdatastoreredis "github.com/trustbloc/vcs/pkg/storage/redis/oidc4ciclaimdatastore"
oidc4cinoncestoreredis "github.com/trustbloc/vcs/pkg/storage/redis/oidc4cinoncestore"
oidc4cistatestoreredis "github.com/trustbloc/vcs/pkg/storage/redis/oidc4cistatestore"
Expand Down Expand Up @@ -642,6 +643,8 @@ func buildEchoHandler(
return nil, fmt.Errorf("failed to instantiate oidc4ci transaction store: %w", err)
}

ackStore := getAckStore(redisClient, conf.StartupParameters.transientDataParams.oidc4ciAckDataTTL)

oidc4ciClaimDataStore, err := getOIDC4CIClaimDataStore(
conf.StartupParameters.transientDataParams.storeType,
redisClientNoTracing,
Expand Down Expand Up @@ -693,6 +696,12 @@ func buildEchoHandler(
},
)

ackService := oidc4ci.NewAckService(&oidc4ci.AckServiceConfig{
EventSvc: eventSvc,
EventTopic: conf.StartupParameters.issuerEventTopic,
AckStore: ackStore,
ProfileSvc: issuerProfileSvc,
})
oidc4ciService, err = oidc4ci.NewService(&oidc4ci.Config{
TransactionStore: oidc4ciTransactionStore,
ClaimDataStore: oidc4ciClaimDataStore,
Expand All @@ -710,6 +719,7 @@ func buildEchoHandler(
CryptoJWTSigner: vcCrypto,
JSONSchemaValidator: jsonSchemaValidator,
ClientAttestationService: clientAttestationService,
AckService: ackService,
})
if err != nil {
return nil, fmt.Errorf("failed to instantiate new oidc4ci service: %w", err)
Expand Down Expand Up @@ -813,6 +823,7 @@ func buildEchoHandler(
ClientManager: clientManagerService,
ClientIDSchemeService: clientIDSchemeSvc,
Tracer: conf.Tracer,
AckService: ackService,
}))

oidc4vpv1.RegisterHandlers(e, oidc4vpv1.NewController(&oidc4vpv1.Config{
Expand Down Expand Up @@ -1146,6 +1157,18 @@ func getOIDC4CITransactionStore(
return store, nil
}

func getAckStore(
redisClient *redis.Client,
oidc4ciTransactionDataTTL int32,
) *ackstore.Store {
if redisClient == nil {
logger.Warn("Redis client is not configured. Acknowledgement store will not be used")
return nil
}

return ackstore.New(redisClient, oidc4ciTransactionDataTTL)
}

func createRequestObjectStore(
repoType string,
s3Region string,
Expand Down
1 change: 1 addition & 0 deletions component/wallet-cli/pkg/walletrunner/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,5 @@ type CredentialResponse struct {
CNonceExpiresIn int `json:"c_nonce_expires_in,omitempty"`
Credential interface{} `json:"credential"`
Format verifiable2.OIDCFormat `json:"format"`
AckID *string `json:"ack_id"`
}
1 change: 1 addition & 0 deletions component/wallet-cli/pkg/walletrunner/wallet_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ type PerfInfo struct {
GetIssuerCredentialsOIDCConfig time.Duration `json:"vci_get_issuer_credentials_oidc_config"`
GetAccessToken time.Duration `json:"vci_get_access_token"`
GetCredential time.Duration `json:"vci_get_credential"`
CredentialsAck time.Duration `json:"vci_credentials_ack"`
FetchRequestObject time.Duration `json:"vp_fetch_request_object"`
VerifyAuthorizationRequest time.Duration `json:"vp_verify_authorization_request"`
QueryCredentialFromWallet time.Duration `json:"vp_query_credential_from_wallet"`
Expand Down
20 changes: 15 additions & 5 deletions component/wallet-cli/pkg/walletrunner/wallet_runner_oidc4vci.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (s *Service) RunOIDC4VCI(config *OIDC4VCIConfig, hooks *Hooks) error {

s.print("Getting credential")

vc, _, err := s.getCredential(
credResponse, _, err := s.getCredential(
oidcIssuerCredentialConfig.CredentialEndpoint,
config.CredentialType,
config.CredentialFormat,
Expand All @@ -243,6 +243,7 @@ func (s *Service) RunOIDC4VCI(config *OIDC4VCIConfig, hooks *Hooks) error {
return fmt.Errorf("get credential: %w", err)
}

vc := credResponse.Credential
b, err = json.Marshal(vc)
if err != nil {
return fmt.Errorf("marshal vc: %w", err)
Expand Down Expand Up @@ -271,6 +272,10 @@ func (s *Service) RunOIDC4VCI(config *OIDC4VCIConfig, hooks *Hooks) error {
s.wallet.Close()
}

if err = s.handleIssuanceAck(oidcIssuerCredentialConfig, credResponse); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -392,7 +397,7 @@ func (s *Service) RunOIDC4CIWalletInitiated(config *OIDC4VCIConfig, hooks *Hooks
s.token = token

s.print("Getting credential")
vc, _, err := s.getCredential(
credResponse, _, err := s.getCredential(
oidcIssuerCredentialConfig.CredentialEndpoint,
config.CredentialType,
config.CredentialFormat,
Expand All @@ -403,6 +408,7 @@ func (s *Service) RunOIDC4CIWalletInitiated(config *OIDC4VCIConfig, hooks *Hooks
return fmt.Errorf("get credential: %w", err)
}

vc := credResponse.Credential
b, err = json.Marshal(vc)
if err != nil {
return fmt.Errorf("marshal vc: %w", err)
Expand Down Expand Up @@ -431,6 +437,10 @@ func (s *Service) RunOIDC4CIWalletInitiated(config *OIDC4VCIConfig, hooks *Hooks
s.wallet.Close()
}

if err = s.handleIssuanceAck(oidcIssuerCredentialConfig, credResponse); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -520,7 +530,7 @@ func (s *Service) getCredential(
credentialFormat,
issuerURI string,
beforeCredentialRequestOpts ...CredentialRequestOpt,
) (interface{}, time.Duration, error) {
) (*CredentialResponse, time.Duration, error) {
credentialsRequestParamsOverride := &credentialRequestOpts{}
for _, f := range beforeCredentialRequestOpts {
f(credentialsRequestParamsOverride)
Expand Down Expand Up @@ -559,7 +569,7 @@ func (s *Service) getCredential(
} else if strings.Contains(didKeyID, "did:jwk") {
res, err := jwk.New().Read(strings.Split(didKeyID, "#")[0])
if err != nil {
return "", 0, err
return nil, 0, err
}

signerKeyID = res.DIDDocument.VerificationMethod[0].ID
Expand Down Expand Up @@ -628,7 +638,7 @@ func (s *Service) getCredential(
return nil, finalDuration, fmt.Errorf("decode credential response: %w", err)
}

return credentialResp.Credential, finalDuration, nil
return &credentialResp, finalDuration, nil
}

func (s *Service) print(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package walletrunner

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -22,6 +24,7 @@ import (
"golang.org/x/oauth2"

"github.com/trustbloc/vcs/component/wallet-cli/pkg/credentialoffer"
issuerv1 "github.com/trustbloc/vcs/pkg/restapi/v1/issuer"
"github.com/trustbloc/vcs/pkg/restapi/v1/oidc4ci"
)

Expand Down Expand Up @@ -113,7 +116,7 @@ func (s *Service) RunOIDC4CIPreAuth(config *OIDC4VCIConfig, hooks *Hooks) (*veri

s.print("Getting credential")
startTime = time.Now()
vc, vcsDuration, err := s.getCredential(
credResponse, vcsDuration, err := s.getCredential(
oidcIssuerCredentialConfig.CredentialEndpoint,
config.CredentialType,
config.CredentialFormat,
Expand All @@ -126,6 +129,7 @@ func (s *Service) RunOIDC4CIPreAuth(config *OIDC4VCIConfig, hooks *Hooks) (*veri
s.perfInfo.VcsCIFlowDuration += vcsDuration
s.perfInfo.GetCredential = time.Since(startTime)

vc := credResponse.Credential
b, err := json.Marshal(vc)
if err != nil {
return nil, fmt.Errorf("marshal vc: %w", err)
Expand All @@ -152,5 +156,58 @@ func (s *Service) RunOIDC4CIPreAuth(config *OIDC4VCIConfig, hooks *Hooks) (*veri
s.wallet.Close()
}

startTime = time.Now()
if err = s.handleIssuanceAck(oidcIssuerCredentialConfig, credResponse); err != nil {
return nil, err
}
s.perfInfo.CredentialsAck = time.Since(startTime)

return vcParsed, nil
}

func (s *Service) handleIssuanceAck(
wellKnown *issuerv1.WellKnownOpenIDIssuerConfiguration,
credResponse *CredentialResponse,
) error {
if wellKnown == nil || credResponse == nil {
return nil
}

if wellKnown.CredentialAckEndpoint == "" || lo.FromPtr(credResponse.AckID) == "" {
return nil
}

s.print("Sending wallet ACK")

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, s.httpClient)
httpClient := s.oauthClient.Client(ctx, s.token)

b, err := json.Marshal(oidc4ci.AckRequest{
Credentials: []oidc4ci.AcpRequestItem{
{
AckId: *credResponse.AckID,
ErrorDescription: nil,
Status: "success",
IssuerIdentifier: &wellKnown.CredentialIssuer,
},
},
})
if err != nil {
return err
}

resp, err := httpClient.Post(wellKnown.CredentialAckEndpoint, "application/json", bytes.NewBuffer(b))
if err != nil {
return err
}

s.print(fmt.Sprintf("Wallet ACK sent with status code %v", resp.StatusCode))

b, _ = io.ReadAll(resp.Body) // nolint
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("expected to receive status code %d but got status code %d with response body %s",
http.StatusNoContent, resp.StatusCode, string(b))
}

return nil
}
Loading

0 comments on commit 2cd24a4

Please sign in to comment.