Skip to content

Commit

Permalink
Move ExtractProtectedHeaders function to the crypto package
Browse files Browse the repository at this point in the history
This commit moves the ExtractProtectedHeaders to the crypto package where it located together with other JWS/JWT related methods.
  • Loading branch information
rolandgroen committed Nov 4, 2024
1 parent 6b1208f commit 0f9723c
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 194 deletions.
3 changes: 3 additions & 0 deletions crypto/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import (
// ErrPrivateKeyNotFound is returned when the private key doesn't exist
var ErrPrivateKeyNotFound = errors.New("private key not found")

// ErrorInvalidNumberOfSignatures indicates that the number of signatures present in the JWT is invalid.
var ErrorInvalidNumberOfSignatures = errors.New("invalid number of signatures")

// KIDNamingFunc is a function passed to New() which generates the kid for the pub/priv key
type KIDNamingFunc func(key crypto.PublicKey) (string, error)

Expand Down
22 changes: 22 additions & 0 deletions crypto/jwx.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,28 @@ func EncryptJWE(payload []byte, protectedHeaders map[string]interface{}, publicK
return string(encoded), err
}

// ExtractProtectedHeaders extracts the protected headers from a JWT string.
// The function takes a JWT string as input and returns a map of the protected headers.
// Note that:
// - This method ignores any parsing errors and returns an empty map instead of an error.
func ExtractProtectedHeaders(jwt string) (map[string]interface{}, error) {
headers := make(map[string]interface{})
if jwt != "" {
message, _ := jws.ParseString(jwt)
if message != nil {
if len(message.Signatures()) != 1 {
return nil, ErrorInvalidNumberOfSignatures
}
var err error
headers, err = message.Signatures()[0].ProtectedHeaders().AsMap(context.Background())
if err != nil {
return nil, err
}
}
}
return headers, nil
}

func (client *Crypto) getPrivateKey(ctx context.Context, kid string) (crypto.Signer, string, error) {
keyRef, err := client.findKeyReferenceByKid(ctx, kid)
if err != nil {
Expand Down
114 changes: 114 additions & 0 deletions crypto/jwx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,117 @@ func Test_signingAlg(t *testing.T) {
assert.EqualError(t, err, "could not determine signature algorithm for key type '<nil>'")
})
}

func TestExtractProtectedHeaders(t *testing.T) {

var normalJws = func(claims map[string]interface{}) (string, error) {
jwk, err := GenerateJWK()
if err != nil {
return "", err
}
marshal, err := json.Marshal(claims)
if err != nil {
return "", err
}
sign, err := jws.Sign(marshal, jws.WithKey(jwa.ES256, jwk))
if err != nil {
return "", err
}
return string(sign), err
}
var doubleSignedJws = func(claims map[string]interface{}) (string, error) {
jwk, err := GenerateJWK()
if err != nil {
return "", err
}
marshal, err := json.Marshal(claims)
if err != nil {
return "", err
}
sign, err := jws.Sign(marshal, jws.WithKey(jwa.ES256, jwk), jws.WithKey(jwa.ES256, jwk), jws.WithJSON())
if err != nil {
return "", err
}
return string(sign), err
}
var noSignedJws = func(claims map[string]interface{}) (string, error) {
marshal, err := json.Marshal(claims)
if err != nil {
return "", err
}
sign, err := jws.Sign(marshal, jws.WithInsecureNoSignature())
if err != nil {
return "", err
}
return string(sign), err
}

jwt, err := normalJws(map[string]interface{}{"iss": "test"})
if err != nil {
t.Error(err)
}
double, err := doubleSignedJws(map[string]interface{}{"iss": "test"})
if err != nil {
t.Error(err)
}
none, err := noSignedJws(map[string]interface{}{"iss": "test"})
if err != nil {
t.Error(err)
}
testCases := []struct {
name string
jwt string
expectResults bool
expectError error
}{
{
name: "ValidJWT",
jwt: jwt,
expectResults: true,
},
{
name: "too many signatures",
jwt: double,
expectResults: false,
expectError: ErrorInvalidNumberOfSignatures,
},
{
name: "no signatures",
jwt: none,
expectResults: true,
},
{
name: "InvalidJWTHeader",
jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsIng1YyI6dHJ1ZX0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.fyenaNFjX705H02aOrpHayRVHa1uVxpQRUxWCl91rB4",
},
{
name: "InvalidJWT",
jwt: "invalidToken",
},
{
name: "EmptyJWT",
jwt: "",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
headers, err := ExtractProtectedHeaders(tc.jwt)
if err != nil {
if tc.expectError == nil {
t.Errorf("ExtractProtectedHeaders() error = %v", err)
} else if err.Error() != tc.expectError.Error() {
t.Errorf("ExtractProtectedHeaders() error = %v, expected: %v", err, tc.expectError)
}
} else {
if !tc.expectResults && len(headers) > 0 {
t.Errorf("ExtractProtectedHeaders() = %v, expected an empty header map", headers)
} else if tc.expectResults {
if _, ok := headers["alg"]; ok == false {
t.Errorf("ExtractProtectedHeaders() = %v, expected a valid header map", headers)
}
}
}
})
}
}
3 changes: 1 addition & 2 deletions vcr/verifier/signature_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
crypt "crypto"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/vdr/didx509"
"strings"
"time"

Expand All @@ -42,7 +41,7 @@ type signatureVerifier struct {
jsonldManager jsonld.JSONLD
}

var ExtractProtectedHeaders = didx509.ExtractProtectedHeaders
var ExtractProtectedHeaders = crypto.ExtractProtectedHeaders

// VerifySignature checks if the signature on a VP is valid at a given time
func (sv *signatureVerifier) VerifySignature(credentialToVerify vc.VerifiableCredential, validateAt *time.Time) error {
Expand Down
52 changes: 0 additions & 52 deletions vdr/didx509/jwt_utils.go

This file was deleted.

140 changes: 0 additions & 140 deletions vdr/didx509/jwt_utils_test.go

This file was deleted.

0 comments on commit 0f9723c

Please sign in to comment.