From d3f78a0c2d0c5eaca57681f9264762e5d2c5a1e4 Mon Sep 17 00:00:00 2001 From: Leo Antunes Date: Wed, 22 Feb 2023 23:15:56 +0100 Subject: [PATCH] feat: add backward-compat alias types --- cmd/jwt/main.go | 2 +- example_test.go | 8 ++++---- hmac_example_test.go | 2 +- http_example_test.go | 4 ++-- parser.go | 6 +++--- parser_test.go | 20 ++++++++++---------- request/request.go | 12 +++++++++--- request/request_test.go | 2 +- token.go | 24 +++++++++++++++--------- token_test.go | 4 ++-- 10 files changed, 48 insertions(+), 36 deletions(-) diff --git a/cmd/jwt/main.go b/cmd/jwt/main.go index ac31ce44..5fba30cb 100644 --- a/cmd/jwt/main.go +++ b/cmd/jwt/main.go @@ -128,7 +128,7 @@ func verifyToken() error { } // Parse the token. Load the key from command line option - token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) { + token, err := jwt.Parse(string(tokData), func(t *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) { if isNone() { return jwt.UnsafeAllowNoneSignatureType, nil } diff --git a/example_test.go b/example_test.go index 0781eb79..0e84fda3 100644 --- a/example_test.go +++ b/example_test.go @@ -80,7 +80,7 @@ func ExampleParseWithClaims_customClaimsType() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }) @@ -103,7 +103,7 @@ func ExampleParseWithClaims_validationOptions() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) @@ -138,7 +138,7 @@ func (m MyCustomClaims) Validate() error { func ExampleParseWithClaims_customValidation() { tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" - token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) @@ -156,7 +156,7 @@ func ExampleParse_errorChecking() { // Token from another example. This token is expired tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" - token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) { return []byte("AllYourBase"), nil }) diff --git a/hmac_example_test.go b/hmac_example_test.go index 256c21a8..0c811072 100644 --- a/hmac_example_test.go +++ b/hmac_example_test.go @@ -47,7 +47,7 @@ func ExampleParse_hmac() { // useful if you use multiple keys for your application. The standard is to use 'kid' in the // head of the token to identify which key to use, but the parsed token (head and claims) is provided // to the callback, providing flexibility. - token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) diff --git a/http_example_test.go b/http_example_test.go index 550482f5..62fc9151 100644 --- a/http_example_test.go +++ b/http_example_test.go @@ -99,7 +99,7 @@ func Example_getTokenViaHTTP() { tokenString := strings.TrimSpace(buf.String()) // Parse the token - token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) { // since we only use the one private key to sign the tokens, // we also only use its public counter part to verify return verifyKey, nil @@ -191,7 +191,7 @@ func authHandler(w http.ResponseWriter, r *http.Request) { // only accessible with a valid token func restrictedHandler(w http.ResponseWriter, r *http.Request) { // Get token from request - token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) { + token, err := request.ParseFromRequestWithClaims(r, request.OAuth2Extractor, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) { // since we only use the one private key to sign the tokens, // we also only use its public counter part to verify return verifyKey, nil diff --git a/parser.go b/parser.go index 37768cc2..fac9697f 100644 --- a/parser.go +++ b/parser.go @@ -57,7 +57,7 @@ func NewParserFor[T Claims](options ...ParserOption) *Parser[T] { // Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims), // make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the // proper memory for it before passing in the overall claims, otherwise you might run into a panic. -func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) { +func (p *Parser[T]) Parse(tokenString string, keyFunc KeyfuncFor[T]) (*TokenFor[T], error) { token, parts, err := p.ParseUnverified(tokenString) if err != nil { return token, err @@ -119,13 +119,13 @@ func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], er // // It's only ever useful in cases where you know the signature is valid (because it has // been checked previously in the stack) and you want to extract values from it. -func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) { +func (p *Parser[T]) ParseUnverified(tokenString string) (token *TokenFor[T], parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed) } - token = &Token[T]{Raw: tokenString} + token = &TokenFor[T]{Raw: tokenString} // parse Header var headerBytes []byte diff --git a/parser_test.go b/parser_test.go index 888dee81..f47424d9 100644 --- a/parser_test.go +++ b/parser_test.go @@ -35,18 +35,18 @@ const ( keyFuncNil ) -func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.Keyfunc[T] { +func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.KeyfuncFor[T] { switch kind { case keyFuncDefault: - return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestDefaultKey, nil } + return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestDefaultKey, nil } case keyFuncECDSA: - return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil } + return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil } case keyFuncPadded: - return func(t *jwt.Token[T]) (interface{}, error) { return paddedKey, nil } + return func(t *jwt.TokenFor[T]) (interface{}, error) { return paddedKey, nil } case keyFuncEmpty: - return func(t *jwt.Token[T]) (interface{}, error) { return nil, nil } + return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, nil } case keyFuncError: - return func(t *jwt.Token[T]) (interface{}, error) { return nil, errKeyFuncError } + return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, errKeyFuncError } case keyFuncNil: return nil default: @@ -376,8 +376,8 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string { // cloneToken is necesssary to "forget" the type information back to a generic jwt.Claims. // Assignment of parameterized types is currently (1.20) not supported. -func cloneToken[T jwt.Claims](tin *jwt.Token[T]) *jwt.Token[jwt.Claims] { - tout := &jwt.Token[jwt.Claims]{} +func cloneToken[T jwt.Claims](tin *jwt.TokenFor[T]) *jwt.TokenFor[jwt.Claims] { + tout := &jwt.TokenFor[jwt.Claims]{} tout.Claims = tin.Claims tout.Header = tin.Header tout.Method = tin.Method @@ -397,7 +397,7 @@ func TestParser_Parse(t *testing.T) { } // Parse the token - var token *jwt.Token[jwt.Claims] + var token *jwt.TokenFor[jwt.Claims] var err error switch data.claims.(type) { case jwt.MapClaims: @@ -475,7 +475,7 @@ func TestParser_ParseUnverified(t *testing.T) { } // Parse the token - var token *jwt.Token[jwt.Claims] + var token *jwt.TokenFor[jwt.Claims] var err error switch data.claims.(type) { case jwt.MapClaims: diff --git a/request/request.go b/request/request.go index e906902a..83e3c9ae 100644 --- a/request/request.go +++ b/request/request.go @@ -12,7 +12,11 @@ import ( // the logic for extracting a token. Several useful implementations are provided. // // You can provide options to modify parsing behavior -func ParseFromRequest[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) { +func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...Option) (token *jwt.Token, err error) { + return ParseFromRequestWithClaims(req, extractor, keyFunc, options...) +} + +func ParseFromRequestWithClaims[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.KeyfuncFor[T], options ...OptionFor[T]) (token *jwt.TokenFor[T], err error) { // Create basic parser struct p := &fromRequestParser[T]{ req: req, @@ -45,10 +49,12 @@ type fromRequestParser[T jwt.Claims] struct { parser *jwt.Parser[T] } -type ParseFromRequestOption[T jwt.Claims] func(*fromRequestParser[T]) +type OptionFor[T jwt.Claims] func(*fromRequestParser[T]) + +type Option = OptionFor[jwt.MapClaims] // WithParser parses using a custom parser -func WithParser[T jwt.Claims](parser *jwt.Parser[T]) ParseFromRequestOption[T] { +func WithParser[T jwt.Claims](parser *jwt.Parser[T]) OptionFor[T] { return func(p *fromRequestParser[T]) { p.parser = parser } diff --git a/request/request_test.go b/request/request_test.go index dd19bb01..eb55e42d 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -58,7 +58,7 @@ func TestParseRequest(t *testing.T) { // load keys from disk privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key") publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub") - keyfunc := func(*jwt.Token[jwt.MapClaims]) (interface{}, error) { + keyfunc := func(*jwt.TokenFor[jwt.MapClaims]) (interface{}, error) { return publicKey, nil } diff --git a/token.go b/token.go index 0b7b93be..9a7b4a83 100644 --- a/token.go +++ b/token.go @@ -28,11 +28,14 @@ var DecodeStrict bool // the key for verification. The function receives the parsed, but unverified // Token. This allows you to use properties in the Header of the token (such as // `kid`) to identify which key to use. -type Keyfunc[T Claims] func(*Token[T]) (interface{}, error) +type KeyfuncFor[T Claims] func(*TokenFor[T]) (interface{}, error) + +// Keyfunc is an alias for KeyfuncFor[Claims], for backward compatibility. +type Keyfunc = KeyfuncFor[MapClaims] // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. -type Token[T Claims] struct { +type TokenFor[T Claims] struct { Raw string // Raw contains the raw token. Populated when you [Parse] a token Method SigningMethod // Method is the signing method used or to be used Header map[string]interface{} // Header is the first segment of the token @@ -41,16 +44,19 @@ type Token[T Claims] struct { Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token } +// Token is an alias for TokenFor[Claims], for backward compatibility. +type Token = TokenFor[MapClaims] + // New creates a new [Token] with the specified signing method and an empty map of // claims. -func New(method SigningMethod) *Token[MapClaims] { +func New(method SigningMethod) *Token { return NewWithClaims(method, MapClaims{}) } // NewWithClaims creates a new [Token] with the specified signing method and // claims. -func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] { - return &Token[T]{ +func NewWithClaims[T Claims](method SigningMethod, claims T) *TokenFor[T] { + return &TokenFor[T]{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -62,7 +68,7 @@ func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] { // SignedString creates and returns a complete, signed JWT. The token is signed // using the SigningMethod specified in the token. -func (t *Token[T]) SignedString(key interface{}) (string, error) { +func (t *TokenFor[T]) SignedString(key interface{}) (string, error) { sstr, err := t.SigningString() if err != nil { return "", err @@ -79,7 +85,7 @@ func (t *Token[T]) SignedString(key interface{}) (string, error) { // SigningString generates the signing string. This is the most expensive part // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. -func (t *Token[T]) SigningString() (string, error) { +func (t *TokenFor[T]) SigningString() (string, error) { h, err := json.Marshal(t.Header) if err != nil { return "", err @@ -100,7 +106,7 @@ func (t *Token[T]) SigningString() (string, error) { // expected algorithm. For more details about the importance of validating the // 'alg' claim, see // https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ -func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOption) (*Token[MapClaims], error) { +func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { return NewParser(options...).Parse(tokenString, keyFunc) } @@ -111,7 +117,7 @@ func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOpti // embed a non-pointer version of the claims or b) if you are using a pointer, // allocate the proper memory for it before passing in the overall claims, // otherwise you might run into a panic. -func ParseWithClaims[T Claims](tokenString string, keyFunc Keyfunc[T], options ...ParserOption) (*Token[T], error) { +func ParseWithClaims[T Claims](tokenString string, keyFunc KeyfuncFor[T], options ...ParserOption) (*TokenFor[T], error) { return NewParserFor[T](options...).Parse(tokenString, keyFunc) } diff --git a/token_test.go b/token_test.go index f228d895..c7fddc4a 100644 --- a/token_test.go +++ b/token_test.go @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) { } for _, tt := range tests { t1.Run(tt.name, func(t1 *testing.T) { - t := &jwt.Token[jwt.Claims]{ + t := &jwt.TokenFor[jwt.Claims]{ Raw: tt.fields.Raw, Method: tt.fields.Method, Header: tt.fields.Header, @@ -61,7 +61,7 @@ func TestToken_SigningString(t1 *testing.T) { } func BenchmarkToken_SigningString(b *testing.B) { - t := &jwt.Token[jwt.Claims]{ + t := &jwt.TokenFor[jwt.Claims]{ Method: jwt.SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT",