From 191c1a9d7a2f9ef670ef28d2fd9ec818d84b7bc9 Mon Sep 17 00:00:00 2001 From: Leo Antunes Date: Sat, 18 Feb 2023 16:39:26 +0100 Subject: [PATCH] feat!: add support for type parameter --- cmd/jwt/main.go | 2 +- example_test.go | 26 ++--- hmac_example_test.go | 11 +- http_example_test.go | 36 +++--- parser.go | 72 ++++++------ parser_option.go | 20 ++-- parser_test.go | 239 ++++++++++++++++++++++------------------ request/request.go | 39 ++----- request/request_test.go | 4 +- token.go | 22 ++-- token_test.go | 6 +- validator.go | 2 +- 12 files changed, 244 insertions(+), 235 deletions(-) diff --git a/cmd/jwt/main.go b/cmd/jwt/main.go index f1e49a90..ac31ce44 100644 --- a/cmd/jwt/main.go +++ b/cmd/jwt/main.go @@ -128,7 +128,7 @@ func verifyToken() error { } // Parse the token. Load the key from command line option - token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) { if isNone() { return jwt.UnsafeAllowNoneSignatureType, nil } diff --git a/example_test.go b/example_test.go index 58fdea43..0781eb79 100644 --- a/example_test.go +++ b/example_test.go @@ -25,7 +25,7 @@ func ExampleNewWithClaims_registeredClaims() { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, err := token.SignedString(mySigningKey) fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 } // Example creating a token using a custom claims type. The RegisteredClaims is embedded @@ -67,7 +67,7 @@ func ExampleNewWithClaims_customClaimsType() { ss, err := token.SignedString(mySigningKey) fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM } // Example creating a token using a custom claims type. The RegisteredClaims is embedded @@ -80,12 +80,12 @@ func ExampleParseWithClaims_customClaimsType() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }) - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if token.Valid { + fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer) } else { fmt.Println(err) } @@ -103,12 +103,12 @@ func ExampleParseWithClaims_validationOptions() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if token.Valid { + fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer) } else { fmt.Println(err) } @@ -138,12 +138,12 @@ func (m MyCustomClaims) Validate() error { func ExampleParseWithClaims_customValidation() { tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if token.Valid { + fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer) } else { fmt.Println(err) } @@ -154,9 +154,9 @@ func ExampleParseWithClaims_customValidation() { // An example of parsing the error types using errors.Is. func ExampleParse_errorChecking() { // Token from another example. This token is expired - var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }) diff --git a/hmac_example_test.go b/hmac_example_test.go index 4b2ff08a..256c21a8 100644 --- a/hmac_example_test.go +++ b/hmac_example_test.go @@ -47,7 +47,7 @@ func ExampleParse_hmac() { // useful if you use multiple keys for your application. The standard is to use 'kid' in the // head of the token to identify which key to use, but the parsed token (head and claims) is provided // to the callback, providing flexibility. - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) @@ -56,12 +56,11 @@ func ExampleParse_hmac() { // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") return hmacSampleSecret, nil }) - - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - fmt.Println(claims["foo"], claims["nbf"]) - } else { - fmt.Println(err) + if err != nil { + panic(err) } + fmt.Println(token.Claims["foo"], token.Claims["nbf"]) + // Output: bar 1.4444784e+09 } diff --git a/http_example_test.go b/http_example_test.go index 090aa4f7..550482f5 100644 --- a/http_example_test.go +++ b/http_example_test.go @@ -99,21 +99,20 @@ func Example_getTokenViaHTTP() { tokenString := strings.TrimSpace(buf.String()) // Parse the token - token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) { // since we only use the one private key to sign the tokens, // we also only use its public counter part to verify return verifyKey, nil }) fatal(err) - claims := token.Claims.(*CustomClaimsExample) + claims := token.Claims fmt.Println(claims.CustomerInfo.Name) - //Output: test + // Output: test } func Example_useTokenViaHTTP() { - // Make a sample token // In a real world situation, this token will have been acquired from // some other API call (see Example_getTokenViaHTTP) @@ -138,18 +137,18 @@ func Example_useTokenViaHTTP() { func createToken(user string) (string, error) { // create a signer for rsa 256 - t := jwt.New(jwt.GetSigningMethod("RS256")) - - // set our claims - t.Claims = &CustomClaimsExample{ - jwt.RegisteredClaims{ - // set the expire time - // see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)), + t := jwt.NewWithClaims( + jwt.GetSigningMethod("RS256"), + &CustomClaimsExample{ + jwt.RegisteredClaims{ + // set the expire time + // see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)), + }, + "level1", + CustomerInfo{user, "human"}, }, - "level1", - CustomerInfo{user, "human"}, - } + ) // Creat token string return t.SignedString(signKey) @@ -192,12 +191,11 @@ func authHandler(w http.ResponseWriter, r *http.Request) { // only accessible with a valid token func restrictedHandler(w http.ResponseWriter, r *http.Request) { // Get token from request - token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) { + token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) { // since we only use the one private key to sign the tokens, // we also only use its public counter part to verify return verifyKey, nil - }, request.WithClaims(&CustomClaimsExample{})) - + }) // If the token is missing or invalid, return error if err != nil { w.WriteHeader(http.StatusUnauthorized) @@ -206,5 +204,5 @@ func restrictedHandler(w http.ResponseWriter, r *http.Request) { } // Token is valid - fmt.Fprintln(w, "Welcome,", token.Claims.(*CustomClaimsExample).Name) + fmt.Fprintln(w, "Welcome,", token.Claims.Name) } diff --git a/parser.go b/parser.go index 46b67931..37768cc2 100644 --- a/parser.go +++ b/parser.go @@ -7,7 +7,7 @@ import ( "strings" ) -type Parser struct { +type parserOpts struct { // If populated, only these methods will be considered valid. validMethods []string @@ -20,44 +20,54 @@ type Parser struct { validator *validator } +type Parser[T Claims] struct { + opts parserOpts +} + // NewParser creates a new Parser with the specified options -func NewParser(options ...ParserOption) *Parser { - p := &Parser{ - validator: &validator{}, +func NewParser(options ...ParserOption) *Parser[MapClaims] { + p := &Parser[MapClaims]{ + opts: parserOpts{validator: &validator{}}, } // Loop through our parsing options and apply them for _, option := range options { - option(p) + option(&p.opts) } return p } -// Parse parses, validates, verifies the signature and returns the parsed token. -// keyFunc will receive the parsed token and should return the key for validating. -func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc) +func NewParserFor[T Claims](options ...ParserOption) *Parser[T] { + p := &Parser[T]{ + opts: parserOpts{validator: &validator{}}, + } + + // Loop through our parsing options and apply them + for _, option := range options { + option(&p.opts) + } + + return p } -// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims -// interface. This provides default values which can be overridden and allows a caller to use their own type, rather -// than the default MapClaims implementation of Claims. +// Parse parses, validates, verifies the signature and returns the parsed token. +// keyFunc will receive the parsed token and should return the key for validating. // // Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims), // make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the // proper memory for it before passing in the overall claims, otherwise you might run into a panic. -func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { - token, parts, err := p.ParseUnverified(tokenString, claims) +func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) { + token, parts, err := p.ParseUnverified(tokenString) if err != nil { return token, err } // Verify signing method is in the required set - if p.validMethods != nil { - var signingMethodValid = false - var alg = token.Method.Alg() - for _, m := range p.validMethods { + if p.opts.validMethods != nil { + signingMethodValid := false + alg := token.Method.Alg() + for _, m := range p.opts.validMethods { if m == alg { signingMethodValid = true break @@ -86,13 +96,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } // Validate Claims - if !p.skipClaimsValidation { + if !p.opts.skipClaimsValidation { // Make sure we have at least a default validator - if p.validator == nil { - p.validator = newValidator() + if p.opts.validator == nil { + p.opts.validator = newValidator() } - if err := p.validator.Validate(claims); err != nil { + if err := p.opts.validator.Validate(token.Claims); err != nil { return token, newError("", ErrTokenInvalidClaims, err) } } @@ -109,13 +119,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // // It's only ever useful in cases where you know the signature is valid (because it has // been checked previously in the stack) and you want to extract values from it. -func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { +func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed) } - token = &Token{Raw: tokenString} + token = &Token[T]{Raw: tokenString} // parse Header var headerBytes []byte @@ -131,23 +141,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // parse Claims var claimBytes []byte - token.Claims = claims - if claimBytes, err = DecodeSegment(parts[1]); err != nil { return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - if p.useJSONNumber { + if p.opts.useJSONNumber { dec.UseNumber() } - // JSON Decode. Special case for map type to avoid weird pointer behavior - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) - } else { - err = dec.Decode(&claims) - } + // Handle decode error - if err != nil { + if err = dec.Decode(&token.Claims); err != nil { return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) } diff --git a/parser_option.go b/parser_option.go index 8d5917e9..ffb902bc 100644 --- a/parser_option.go +++ b/parser_option.go @@ -6,14 +6,14 @@ import "time" // behavior of the parser. To add new options, just create a function (ideally // beginning with With or Without) that returns an anonymous function that takes // a *Parser type as input and manipulates its configuration accordingly. -type ParserOption func(*Parser) +type ParserOption func(*parserOpts) // WithValidMethods is an option to supply algorithm methods that the parser // will check. Only those methods will be considered valid. It is heavily // encouraged to use this option in order to prevent attacks such as // https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/. func WithValidMethods(methods []string) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validMethods = methods } } @@ -21,7 +21,7 @@ func WithValidMethods(methods []string) ParserOption { // WithJSONNumber is an option to configure the underlying JSON parser with // UseNumber. func WithJSONNumber() ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.useJSONNumber = true } } @@ -29,14 +29,14 @@ func WithJSONNumber() ParserOption { // WithoutClaimsValidation is an option to disable claims validation. This // option should only be used if you exactly know what you are doing. func WithoutClaimsValidation() ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.skipClaimsValidation = true } } // WithLeeway returns the ParserOption for specifying the leeway window. func WithLeeway(leeway time.Duration) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.leeway = leeway } } @@ -45,7 +45,7 @@ func WithLeeway(leeway time.Duration) ParserOption { // primary use-case for this is testing. If you are looking for a way to account // for clock-skew, WithLeeway should be used instead. func WithTimeFunc(f func() time.Time) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.timeFunc = f } } @@ -53,7 +53,7 @@ func WithTimeFunc(f func() time.Time) ParserOption { // WithIssuedAt returns the ParserOption to enable verification // of issued-at. func WithIssuedAt() ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.verifyIat = true } } @@ -67,7 +67,7 @@ func WithIssuedAt() ParserOption { // writing secure application, we decided to REQUIRE the existence of the claim, // if an audience is expected. func WithAudience(aud string) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.expectedAud = aud } } @@ -81,7 +81,7 @@ func WithAudience(aud string) ParserOption { // writing secure application, we decided to REQUIRE the existence of the claim, // if an issuer is expected. func WithIssuer(iss string) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.expectedIss = iss } } @@ -95,7 +95,7 @@ func WithIssuer(iss string) ParserOption { // writing secure application, we decided to REQUIRE the existence of the claim, // if a subject is expected. func WithSubject(sub string) ParserOption { - return func(p *Parser) { + return func(p *parserOpts) { p.validator.expectedSub = sub } } diff --git a/parser_test.go b/parser_test.go index fdb5eef3..888dee81 100644 --- a/parser_test.go +++ b/parser_test.go @@ -22,14 +22,38 @@ var ( jwtTestEC256PublicKey crypto.PublicKey jwtTestEC256PrivateKey crypto.PrivateKey paddedKey crypto.PublicKey - defaultKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return jwtTestDefaultKey, nil } - ecdsaKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return jwtTestEC256PublicKey, nil } - paddedKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return paddedKey, nil } - emptyKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, nil } - errorKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, errKeyFuncError } - nilKeyFunc jwt.Keyfunc = nil ) +type keyFuncKind int + +const ( + keyFuncDefault keyFuncKind = iota + keyFuncECDSA + keyFuncPadded + keyFuncEmpty + keyFuncError + keyFuncNil +) + +func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.Keyfunc[T] { + switch kind { + case keyFuncDefault: + return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestDefaultKey, nil } + case keyFuncECDSA: + return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil } + case keyFuncPadded: + return func(t *jwt.Token[T]) (interface{}, error) { return paddedKey, nil } + case keyFuncEmpty: + return func(t *jwt.Token[T]) (interface{}, error) { return nil, nil } + case keyFuncError: + return func(t *jwt.Token[T]) (interface{}, error) { return nil, errKeyFuncError } + case keyFuncNil: + return nil + default: + panic("unknown keyfunc kind") + } +} + func init() { // Load public keys jwtTestDefaultKey = test.LoadRSAPublicKeyFromDisk("test/sample_key.pub") @@ -42,23 +66,22 @@ func init() { // Load private keys jwtTestRSAPrivateKey = test.LoadRSAPrivateKeyFromDisk("test/sample_key") jwtTestEC256PrivateKey = test.LoadECPrivateKeyFromDisk("test/ec256-private.pem") - } var jwtTestData = []struct { name string tokenString string - keyfunc jwt.Keyfunc + keyfuncKind keyFuncKind claims jwt.Claims valid bool err []error - parser *jwt.Parser + parserOpts []jwt.ParserOption signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose }{ { "invalid JWT", "thisisnotreallyajwt", - defaultKeyFunc, + keyFuncDefault, nil, false, []error{jwt.ErrTokenMalformed}, @@ -68,7 +91,7 @@ var jwtTestData = []struct { { "invalid JSON claim", "eyJhbGciOiJSUzI1NiIsInppcCI6IkRFRiJ9.eNqqVkqtKFCyMjQ1s7Q0sbA0MtFRyk3NTUot8kxRslIKLbZQggn4JeamAoUcfRz99HxcXRWeze172tr4bFq7Ui0AAAD__w.jBXD4LT4aq4oXTgDoPkiV6n4QdSZPZI1Z4J8MWQC42aHK0oXwcovEU06dVbtB81TF-2byuu0-qi8J0GUttODT67k6gCl6DV_iuCOV7gczwTcvKslotUvXzoJ2wa0QuujnjxLEE50r0p6k0tsv_9OIFSUZzDksJFYNPlJH2eFG55DROx4TsOz98az37SujZi9GGbTc9SLgzFHPrHMrovRZ5qLC_w4JrdtsLzBBI11OQJgRYwV8fQf4O8IsMkHtetjkN7dKgUkJtRarNWOk76rpTPppLypiLU4_J0-wrElLMh1TzUVZW6Fz2cDHDDBACJgMmKQ2pOFEDK_vYZN74dLCF5GiTZV6DbXhNxO7lqT7JUN4a3p2z96G7WNRjblf2qZeuYdQvkIsiK-rCbSIE836XeY5gaBgkOzuEvzl_tMrpRmb5Oox1ibOfVT2KBh9Lvqsb1XbQjCio2CLE2ViCLqoe0AaRqlUyrk3n8BIG-r0IW4dcw96CEryEMIjsjVp9mtPXamJzf391kt8Rf3iRBqwv3zP7Plg1ResXbmsFUgOflAUPcYmfLug4W3W52ntcUlTHAKXrNfaJL9QQiYAaDukG-ZHDytsOWTuuXw7lVxjt-XYi1VbRAIjh1aIYSELEmEpE4Ny74htQtywYXMQNfJpB0nNn8IiWakgcYYMJ0TmKM", - defaultKeyFunc, + keyFuncDefault, nil, false, []error{jwt.ErrTokenMalformed}, @@ -78,7 +101,7 @@ var jwtTestData = []struct { { "bearer in JWT", "bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - defaultKeyFunc, + keyFuncDefault, nil, false, []error{jwt.ErrTokenMalformed}, @@ -88,7 +111,7 @@ var jwtTestData = []struct { { "basic", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar"}, true, nil, @@ -98,7 +121,7 @@ var jwtTestData = []struct { { "basic expired", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, []error{jwt.ErrTokenExpired}, @@ -108,7 +131,7 @@ var jwtTestData = []struct { { "basic nbf", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, []error{jwt.ErrTokenNotValidYet}, @@ -118,7 +141,7 @@ var jwtTestData = []struct { { "expired and nbf", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, @@ -128,7 +151,7 @@ var jwtTestData = []struct { { "basic invalid", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, @@ -138,7 +161,7 @@ var jwtTestData = []struct { { "basic nokeyfunc", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - nilKeyFunc, + keyFuncNil, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenUnverifiable}, @@ -148,7 +171,7 @@ var jwtTestData = []struct { { "basic nokey", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - emptyKeyFunc, + keyFuncEmpty, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenSignatureInvalid}, @@ -158,7 +181,7 @@ var jwtTestData = []struct { { "basic errorkey", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", - errorKeyFunc, + keyFuncError, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, @@ -168,171 +191,171 @@ var jwtTestData = []struct { { "invalid signing method", "", - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenSignatureInvalid}, - jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), + []jwt.ParserOption{jwt.WithValidMethods([]string{"HS256"})}, jwt.SigningMethodRS256, }, { "valid RSA signing method", "", - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar"}, true, nil, - jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), + []jwt.ParserOption{jwt.WithValidMethods([]string{"RS256", "HS256"})}, jwt.SigningMethodRS256, }, { "ECDSA signing method not accepted", "", - ecdsaKeyFunc, + keyFuncECDSA, jwt.MapClaims{"foo": "bar"}, false, []error{jwt.ErrTokenSignatureInvalid}, - jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), + []jwt.ParserOption{jwt.WithValidMethods([]string{"RS256", "HS256"})}, jwt.SigningMethodES256, }, { "valid ECDSA signing method", "", - ecdsaKeyFunc, + keyFuncECDSA, jwt.MapClaims{"foo": "bar"}, true, nil, - jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), + []jwt.ParserOption{jwt.WithValidMethods([]string{"HS256", "ES256"})}, jwt.SigningMethodES256, }, { "JSON Number", "", - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": json.Number("123.4")}, true, nil, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "JSON Number - basic expired", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, []error{jwt.ErrTokenExpired}, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "JSON Number - basic nbf", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, false, []error{jwt.ErrTokenNotValidYet}, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "JSON Number - expired and nbf", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "SkipClaimsValidation during token parsing", "", // autogen - defaultKeyFunc, + keyFuncDefault, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, true, nil, - jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), + []jwt.ParserOption{jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims", "", - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)), }, true, nil, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - single aud", "", - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{ Audience: jwt.ClaimStrings{"test"}, }, true, nil, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - multiple aud", "", - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{ Audience: jwt.ClaimStrings{"test", "test"}, }, true, nil, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - single aud with wrong type", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 } - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{ Audience: nil, // because of the unmarshal error, this will be empty }, false, []error{jwt.ErrTokenMalformed}, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - multiple aud with wrong types", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] } - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{ Audience: nil, // because of the unmarshal error, this will be empty }, false, []error{jwt.ErrTokenMalformed}, - jwt.NewParser(jwt.WithJSONNumber()), + []jwt.ParserOption{jwt.WithJSONNumber()}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - nbf with 60s skew", "", // autogen - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, false, []error{jwt.ErrTokenNotValidYet}, - jwt.NewParser(jwt.WithLeeway(time.Minute)), + []jwt.ParserOption{jwt.WithLeeway(time.Minute)}, jwt.SigningMethodRS256, }, { "RFC7519 Claims - nbf with 120s skew", "", // autogen - defaultKeyFunc, + keyFuncDefault, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, true, nil, - jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), + []jwt.ParserOption{jwt.WithLeeway(2 * time.Minute)}, jwt.SigningMethodRS256, }, } @@ -351,32 +374,42 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string { return test.MakeSampleToken(claims, signingMethod, privateKey) } -func TestParser_Parse(t *testing.T) { +// cloneToken is necesssary to "forget" the type information back to a generic jwt.Claims. +// Assignment of parameterized types is currently (1.20) not supported. +func cloneToken[T jwt.Claims](tin *jwt.Token[T]) *jwt.Token[jwt.Claims] { + tout := &jwt.Token[jwt.Claims]{} + tout.Claims = tin.Claims + tout.Header = tin.Header + tout.Method = tin.Method + tout.Raw = tin.Raw + tout.Signature = tin.Signature + tout.Valid = tin.Valid + return tout +} +func TestParser_Parse(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { t.Run(data.name, func(t *testing.T) { - // If the token string is blank, use helper function to generate string if data.tokenString == "" { data.tokenString = signToken(data.claims, data.signingMethod) } // Parse the token - var token *jwt.Token + var token *jwt.Token[jwt.Claims] var err error - var parser = data.parser - if parser == nil { - parser = jwt.NewParser() - } - // Figure out correct claims type switch data.claims.(type) { case jwt.MapClaims: - token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) + parser := jwt.NewParser(data.parserOpts...) + t, e := parser.Parse(data.tokenString, getKeyFunc[jwt.MapClaims](data.keyfuncKind)) + err = e + token = cloneToken(t) case *jwt.RegisteredClaims: - token, err = parser.ParseWithClaims(data.tokenString, &jwt.RegisteredClaims{}, data.keyfunc) - case nil: - token, err = parser.ParseWithClaims(data.tokenString, nil, data.keyfunc) + parser := jwt.NewParserFor[*jwt.RegisteredClaims](data.parserOpts...) + t, e := parser.Parse(data.tokenString, getKeyFunc[*jwt.RegisteredClaims](data.keyfuncKind)) + err = e + token = cloneToken(t) } // Verify result matches expectation @@ -389,7 +422,7 @@ func TestParser_Parse(t *testing.T) { } if !data.valid && err == nil { - t.Errorf("[%v] Invalid token passed validation", data.name) + t.Fatalf("[%v] Invalid token passed validation", data.name) } // Since the returned token is nil in the ErrTokenMalformed, we @@ -403,7 +436,7 @@ func TestParser_Parse(t *testing.T) { if err == nil { t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name) } else { - var all = false + all := false for _, e := range data.err { all = errors.Is(err, e) } @@ -428,7 +461,6 @@ func TestParser_Parse(t *testing.T) { } func TestParser_ParseUnverified(t *testing.T) { - // Iterate over test data set and run tests for _, data := range jwtTestData { // Skip test data, that intentionally contains malformed tokens, as they would lead to an error @@ -443,22 +475,19 @@ func TestParser_ParseUnverified(t *testing.T) { } // Parse the token - var token *jwt.Token + var token *jwt.Token[jwt.Claims] var err error - var parser = data.parser - if parser == nil { - parser = new(jwt.Parser) - } - // Figure out correct claims type switch data.claims.(type) { case jwt.MapClaims: - token, _, err = parser.ParseUnverified(data.tokenString, jwt.MapClaims{}) + parser := jwt.NewParser(data.parserOpts...) + t, _, e := parser.ParseUnverified(data.tokenString) + err = e + token = cloneToken(t) case *jwt.RegisteredClaims: - token, _, err = parser.ParseUnverified(data.tokenString, &jwt.RegisteredClaims{}) - } - - if err != nil { - t.Errorf("[%v] Invalid token", data.name) + parser := jwt.NewParserFor[*jwt.RegisteredClaims](data.parserOpts...) + t, _, e := parser.ParseUnverified(data.tokenString) + err = e + token = cloneToken(t) } // Verify result matches expectation @@ -488,7 +517,7 @@ var setPaddingTestData = []struct { paddedDecode bool strictDecode bool signingMethod jwt.SigningMethod - keyfunc jwt.Keyfunc + keyFuncKind keyFuncKind valid bool }{ { @@ -497,7 +526,7 @@ var setPaddingTestData = []struct { claims: jwt.MapClaims{"foo": "paddedbar"}, paddedDecode: false, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -506,7 +535,7 @@ var setPaddingTestData = []struct { claims: jwt.MapClaims{"foo": "paddedbar"}, paddedDecode: true, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -515,7 +544,7 @@ var setPaddingTestData = []struct { claims: jwt.MapClaims{"foo": "paddedbar"}, paddedDecode: false, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: false, }, { @@ -524,7 +553,7 @@ var setPaddingTestData = []struct { claims: jwt.MapClaims{"foo": "paddedbar"}, paddedDecode: true, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -533,7 +562,7 @@ var setPaddingTestData = []struct { claims: nil, paddedDecode: false, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: false, }, { @@ -542,7 +571,7 @@ var setPaddingTestData = []struct { claims: nil, paddedDecode: true, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: true, }, // DecodeStrict tests, DecodePaddingAllowed=false @@ -554,7 +583,7 @@ var setPaddingTestData = []struct { paddedDecode: false, strictDecode: false, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -565,7 +594,7 @@ var setPaddingTestData = []struct { paddedDecode: false, strictDecode: false, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -576,7 +605,7 @@ var setPaddingTestData = []struct { paddedDecode: false, strictDecode: true, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: true, }, { @@ -587,7 +616,7 @@ var setPaddingTestData = []struct { paddedDecode: false, strictDecode: true, signingMethod: jwt.SigningMethodRS256, - keyfunc: defaultKeyFunc, + keyFuncKind: keyFuncDefault, valid: false, }, // DecodeStrict tests, DecodePaddingAllowed=true @@ -599,7 +628,7 @@ var setPaddingTestData = []struct { paddedDecode: true, strictDecode: false, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: true, }, { @@ -610,7 +639,7 @@ var setPaddingTestData = []struct { paddedDecode: true, strictDecode: false, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: true, }, { @@ -621,7 +650,7 @@ var setPaddingTestData = []struct { paddedDecode: true, strictDecode: true, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: true, }, { @@ -632,7 +661,7 @@ var setPaddingTestData = []struct { paddedDecode: true, strictDecode: true, signingMethod: jwt.SigningMethodES256, - keyfunc: paddedKeyFunc, + keyFuncKind: keyFuncPadded, valid: false, }, } @@ -650,12 +679,10 @@ func TestSetPadding(t *testing.T) { } // Parse the token - var token *jwt.Token - var err error parser := jwt.NewParser(jwt.WithoutClaimsValidation()) // Figure out correct claims type - token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) + token, err := parser.Parse(data.tokenString, getKeyFunc[jwt.MapClaims](data.keyFuncKind)) if (err == nil) != data.valid || token.Valid != data.valid { t.Errorf("[%v] Error Parsing Token with decoding padding set to %v: %v", @@ -664,7 +691,6 @@ func TestSetPadding(t *testing.T) { err, ) } - }) jwt.DecodePaddingAllowed = false jwt.DecodeStrict = false @@ -672,7 +698,6 @@ func TestSetPadding(t *testing.T) { } func BenchmarkParseUnverified(b *testing.B) { - // Iterate over test data set and run tests for _, data := range jwtTestData { // If the token string is blank, use helper function to generate string @@ -680,33 +705,31 @@ func BenchmarkParseUnverified(b *testing.B) { data.tokenString = signToken(data.claims, data.signingMethod) } - // Parse the token - var parser = data.parser - if parser == nil { - parser = new(jwt.Parser) - } // Figure out correct claims type switch data.claims.(type) { case jwt.MapClaims: + parser := jwt.NewParser(data.parserOpts...) b.Run("map_claims", func(b *testing.B) { - benchmarkParsing(b, parser, data.tokenString, jwt.MapClaims{}) + benchmarkParsing(b, parser, data.tokenString) }) case *jwt.RegisteredClaims: + parser := jwt.NewParser(data.parserOpts...) b.Run("registered_claims", func(b *testing.B) { - benchmarkParsing(b, parser, data.tokenString, &jwt.RegisteredClaims{}) + benchmarkParsing(b, parser, data.tokenString) }) } + } } // Helper method for benchmarking various parsing methods -func benchmarkParsing(b *testing.B, parser *jwt.Parser, tokenString string, claims jwt.Claims) { +func benchmarkParsing[T jwt.Claims](b *testing.B, parser *jwt.Parser[T], tokenString string) { b.Helper() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{}) + _, _, err := parser.ParseUnverified(tokenString) if err != nil { b.Fatal(err) } diff --git a/request/request.go b/request/request.go index 5723c809..e906902a 100644 --- a/request/request.go +++ b/request/request.go @@ -12,9 +12,12 @@ import ( // the logic for extracting a token. Several useful implementations are provided. // // You can provide options to modify parsing behavior -func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...ParseFromRequestOption) (token *jwt.Token, err error) { +func ParseFromRequest[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) { // Create basic parser struct - p := &fromRequestParser{req, extractor, nil, nil} + p := &fromRequestParser[T]{ + req: req, + extractor: extractor, + } // Handle options for _, option := range options { @@ -22,11 +25,8 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun } // Set defaults - if p.claims == nil { - p.claims = jwt.MapClaims{} - } if p.parser == nil { - p.parser = &jwt.Parser{} + p.parser = &jwt.Parser[T]{} } // perform extract @@ -36,35 +36,20 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun } // perform parse - return p.parser.ParseWithClaims(tokenString, p.claims, keyFunc) -} - -// ParseFromRequestWithClaims is an alias for ParseFromRequest but with custom Claims type. -// -// Deprecated: use ParseFromRequest and the WithClaims option -func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { - return ParseFromRequest(req, extractor, keyFunc, WithClaims(claims)) + return p.parser.Parse(tokenString, keyFunc) } -type fromRequestParser struct { +type fromRequestParser[T jwt.Claims] struct { req *http.Request extractor Extractor - claims jwt.Claims - parser *jwt.Parser + parser *jwt.Parser[T] } -type ParseFromRequestOption func(*fromRequestParser) - -// WithClaims parses with custom claims -func WithClaims(claims jwt.Claims) ParseFromRequestOption { - return func(p *fromRequestParser) { - p.claims = claims - } -} +type ParseFromRequestOption[T jwt.Claims] func(*fromRequestParser[T]) // WithParser parses using a custom parser -func WithParser(parser *jwt.Parser) ParseFromRequestOption { - return func(p *fromRequestParser) { +func WithParser[T jwt.Claims](parser *jwt.Parser[T]) ParseFromRequestOption[T] { + return func(p *fromRequestParser[T]) { p.parser = parser } } diff --git a/request/request_test.go b/request/request_test.go index 0906d1cf..dd19bb01 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -58,7 +58,7 @@ func TestParseRequest(t *testing.T) { // load keys from disk privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key") publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub") - keyfunc := func(*jwt.Token) (interface{}, error) { + keyfunc := func(*jwt.Token[jwt.MapClaims]) (interface{}, error) { return publicKey, nil } @@ -85,7 +85,7 @@ func TestParseRequest(t *testing.T) { r.Header.Set(k, tokenString) } } - token, err := ParseFromRequestWithClaims(r, data.extractor, jwt.MapClaims{}, keyfunc) + token, err := ParseFromRequest(r, data.extractor, keyfunc) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err) diff --git a/token.go b/token.go index 85350b1e..0b7b93be 100644 --- a/token.go +++ b/token.go @@ -28,29 +28,29 @@ var DecodeStrict bool // the key for verification. The function receives the parsed, but unverified // Token. This allows you to use properties in the Header of the token (such as // `kid`) to identify which key to use. -type Keyfunc func(*Token) (interface{}, error) +type Keyfunc[T Claims] func(*Token[T]) (interface{}, error) // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. -type Token struct { +type Token[T Claims] struct { Raw string // Raw contains the raw token. Populated when you [Parse] a token Method SigningMethod // Method is the signing method used or to be used Header map[string]interface{} // Header is the first segment of the token - Claims Claims // Claims is the second segment of the token + Claims T // Claims is the second segment of the token Signature string // Signature is the third segment of the token. Populated when you Parse a token Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token } // New creates a new [Token] with the specified signing method and an empty map of // claims. -func New(method SigningMethod) *Token { +func New(method SigningMethod) *Token[MapClaims] { return NewWithClaims(method, MapClaims{}) } // NewWithClaims creates a new [Token] with the specified signing method and // claims. -func NewWithClaims(method SigningMethod, claims Claims) *Token { - return &Token{ +func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] { + return &Token[T]{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -62,7 +62,7 @@ func NewWithClaims(method SigningMethod, claims Claims) *Token { // SignedString creates and returns a complete, signed JWT. The token is signed // using the SigningMethod specified in the token. -func (t *Token) SignedString(key interface{}) (string, error) { +func (t *Token[T]) SignedString(key interface{}) (string, error) { sstr, err := t.SigningString() if err != nil { return "", err @@ -79,7 +79,7 @@ func (t *Token) SignedString(key interface{}) (string, error) { // SigningString generates the signing string. This is the most expensive part // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. -func (t *Token) SigningString() (string, error) { +func (t *Token[T]) SigningString() (string, error) { h, err := json.Marshal(t.Header) if err != nil { return "", err @@ -100,7 +100,7 @@ func (t *Token) SigningString() (string, error) { // expected algorithm. For more details about the importance of validating the // 'alg' claim, see // https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ -func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { +func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOption) (*Token[MapClaims], error) { return NewParser(options...).Parse(tokenString, keyFunc) } @@ -111,8 +111,8 @@ func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token // embed a non-pointer version of the claims or b) if you are using a pointer, // allocate the proper memory for it before passing in the overall claims, // otherwise you might run into a panic. -func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { - return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc) +func ParseWithClaims[T Claims](tokenString string, keyFunc Keyfunc[T], options ...ParserOption) (*Token[T], error) { + return NewParserFor[T](options...).Parse(tokenString, keyFunc) } // EncodeSegment encodes a JWT specific base64url encoding with padding stripped diff --git a/token_test.go b/token_test.go index 52a00212..f228d895 100644 --- a/token_test.go +++ b/token_test.go @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) { } for _, tt := range tests { t1.Run(tt.name, func(t1 *testing.T) { - t := &jwt.Token{ + t := &jwt.Token[jwt.Claims]{ Raw: tt.fields.Raw, Method: tt.fields.Method, Header: tt.fields.Header, @@ -61,7 +61,7 @@ func TestToken_SigningString(t1 *testing.T) { } func BenchmarkToken_SigningString(b *testing.B) { - t := &jwt.Token{ + t := &jwt.Token[jwt.Claims]{ Method: jwt.SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT", @@ -73,7 +73,7 @@ func BenchmarkToken_SigningString(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - t.SigningString() + _, _ = t.SigningString() } }) } diff --git a/validator.go b/validator.go index 38504389..aaa2bfdd 100644 --- a/validator.go +++ b/validator.go @@ -66,7 +66,7 @@ type validator struct { // options. This validator can then be used to validate already parsed claims. func newValidator(opts ...ParserOption) *validator { p := NewParser(opts...) - return p.validator + return p.opts.validator } // Validate validates the given claims. It will also perform any custom