Skip to content

Commit

Permalink
feat: add backward-compat alias types
Browse files Browse the repository at this point in the history
  • Loading branch information
costela committed Feb 25, 2023
1 parent 191c1a9 commit d3f78a0
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 36 deletions.
2 changes: 1 addition & 1 deletion cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func verifyToken() error {
}

// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(string(tokData), func(t *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
Expand Down
8 changes: 4 additions & 4 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand All @@ -103,7 +103,7 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

Expand Down Expand Up @@ -138,7 +138,7 @@ func (m MyCustomClaims) Validate() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

Expand All @@ -156,7 +156,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
2 changes: 1 addition & 1 deletion hmac_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func ExampleParse_hmac() {
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
// to the callback, providing flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand Down
4 changes: 2 additions & 2 deletions http_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func Example_getTokenViaHTTP() {
tokenString := strings.TrimSpace(buf.String())

// Parse the token
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
Expand Down Expand Up @@ -191,7 +191,7 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
// only accessible with a valid token
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
// Get token from request
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
token, err := request.ParseFromRequestWithClaims(r, request.OAuth2Extractor, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
Expand Down
6 changes: 3 additions & 3 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewParserFor[T Claims](options ...ParserOption) *Parser[T] {
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
func (p *Parser[T]) Parse(tokenString string, keyFunc KeyfuncFor[T]) (*TokenFor[T], error) {
token, parts, err := p.ParseUnverified(tokenString)
if err != nil {
return token, err
Expand Down Expand Up @@ -119,13 +119,13 @@ func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], er
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string) (token *TokenFor[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
}

token = &Token[T]{Raw: tokenString}
token = &TokenFor[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand Down
20 changes: 10 additions & 10 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ const (
keyFuncNil
)

func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.Keyfunc[T] {
func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.KeyfuncFor[T] {
switch kind {
case keyFuncDefault:
return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestDefaultKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestDefaultKey, nil }
case keyFuncECDSA:
return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil }
case keyFuncPadded:
return func(t *jwt.Token[T]) (interface{}, error) { return paddedKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return paddedKey, nil }
case keyFuncEmpty:
return func(t *jwt.Token[T]) (interface{}, error) { return nil, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, nil }
case keyFuncError:
return func(t *jwt.Token[T]) (interface{}, error) { return nil, errKeyFuncError }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, errKeyFuncError }
case keyFuncNil:
return nil
default:
Expand Down Expand Up @@ -376,8 +376,8 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {

// cloneToken is necesssary to "forget" the type information back to a generic jwt.Claims.
// Assignment of parameterized types is currently (1.20) not supported.
func cloneToken[T jwt.Claims](tin *jwt.Token[T]) *jwt.Token[jwt.Claims] {
tout := &jwt.Token[jwt.Claims]{}
func cloneToken[T jwt.Claims](tin *jwt.TokenFor[T]) *jwt.TokenFor[jwt.Claims] {
tout := &jwt.TokenFor[jwt.Claims]{}
tout.Claims = tin.Claims
tout.Header = tin.Header
tout.Method = tin.Method
Expand All @@ -397,7 +397,7 @@ func TestParser_Parse(t *testing.T) {
}

// Parse the token
var token *jwt.Token[jwt.Claims]
var token *jwt.TokenFor[jwt.Claims]
var err error
switch data.claims.(type) {
case jwt.MapClaims:
Expand Down Expand Up @@ -475,7 +475,7 @@ func TestParser_ParseUnverified(t *testing.T) {
}

// Parse the token
var token *jwt.Token[jwt.Claims]
var token *jwt.TokenFor[jwt.Claims]
var err error
switch data.claims.(type) {
case jwt.MapClaims:
Expand Down
12 changes: 9 additions & 3 deletions request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import (
// the logic for extracting a token. Several useful implementations are provided.
//
// You can provide options to modify parsing behavior
func ParseFromRequest[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) {
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...Option) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, extractor, keyFunc, options...)
}

func ParseFromRequestWithClaims[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.KeyfuncFor[T], options ...OptionFor[T]) (token *jwt.TokenFor[T], err error) {
// Create basic parser struct
p := &fromRequestParser[T]{
req: req,
Expand Down Expand Up @@ -45,10 +49,12 @@ type fromRequestParser[T jwt.Claims] struct {
parser *jwt.Parser[T]
}

type ParseFromRequestOption[T jwt.Claims] func(*fromRequestParser[T])
type OptionFor[T jwt.Claims] func(*fromRequestParser[T])

type Option = OptionFor[jwt.MapClaims]

// WithParser parses using a custom parser
func WithParser[T jwt.Claims](parser *jwt.Parser[T]) ParseFromRequestOption[T] {
func WithParser[T jwt.Claims](parser *jwt.Parser[T]) OptionFor[T] {
return func(p *fromRequestParser[T]) {
p.parser = parser
}
Expand Down
2 changes: 1 addition & 1 deletion request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestParseRequest(t *testing.T) {
// load keys from disk
privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key")
publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub")
keyfunc := func(*jwt.Token[jwt.MapClaims]) (interface{}, error) {
keyfunc := func(*jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
return publicKey, nil
}

Expand Down
24 changes: 15 additions & 9 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ var DecodeStrict bool
// the key for verification. The function receives the parsed, but unverified
// Token. This allows you to use properties in the Header of the token (such as
// `kid`) to identify which key to use.
type Keyfunc[T Claims] func(*Token[T]) (interface{}, error)
type KeyfuncFor[T Claims] func(*TokenFor[T]) (interface{}, error)

// Keyfunc is an alias for KeyfuncFor[Claims], for backward compatibility.
type Keyfunc = KeyfuncFor[MapClaims]

// Token represents a JWT Token. Different fields will be used depending on
// whether you're creating or parsing/verifying a token.
type Token[T Claims] struct {
type TokenFor[T Claims] struct {
Raw string // Raw contains the raw token. Populated when you [Parse] a token
Method SigningMethod // Method is the signing method used or to be used
Header map[string]interface{} // Header is the first segment of the token
Expand All @@ -41,16 +44,19 @@ type Token[T Claims] struct {
Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token
}

// Token is an alias for TokenFor[Claims], for backward compatibility.
type Token = TokenFor[MapClaims]

// New creates a new [Token] with the specified signing method and an empty map of
// claims.
func New(method SigningMethod) *Token[MapClaims] {
func New(method SigningMethod) *Token {
return NewWithClaims(method, MapClaims{})
}

// NewWithClaims creates a new [Token] with the specified signing method and
// claims.
func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] {
return &Token[T]{
func NewWithClaims[T Claims](method SigningMethod, claims T) *TokenFor[T] {
return &TokenFor[T]{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
Expand All @@ -62,7 +68,7 @@ func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] {

// SignedString creates and returns a complete, signed JWT. The token is signed
// using the SigningMethod specified in the token.
func (t *Token[T]) SignedString(key interface{}) (string, error) {
func (t *TokenFor[T]) SignedString(key interface{}) (string, error) {
sstr, err := t.SigningString()
if err != nil {
return "", err
Expand All @@ -79,7 +85,7 @@ func (t *Token[T]) SignedString(key interface{}) (string, error) {
// SigningString generates the signing string. This is the most expensive part
// of the whole deal. Unless you need this for something special, just go
// straight for the SignedString.
func (t *Token[T]) SigningString() (string, error) {
func (t *TokenFor[T]) SigningString() (string, error) {
h, err := json.Marshal(t.Header)
if err != nil {
return "", err
Expand All @@ -100,7 +106,7 @@ func (t *Token[T]) SigningString() (string, error) {
// expected algorithm. For more details about the importance of validating the
// 'alg' claim, see
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOption) (*Token[MapClaims], error) {
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
return NewParser(options...).Parse(tokenString, keyFunc)
}

Expand All @@ -111,7 +117,7 @@ func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOpti
// embed a non-pointer version of the claims or b) if you are using a pointer,
// allocate the proper memory for it before passing in the overall claims,
// otherwise you might run into a panic.
func ParseWithClaims[T Claims](tokenString string, keyFunc Keyfunc[T], options ...ParserOption) (*Token[T], error) {
func ParseWithClaims[T Claims](tokenString string, keyFunc KeyfuncFor[T], options ...ParserOption) (*TokenFor[T], error) {
return NewParserFor[T](options...).Parse(tokenString, keyFunc)
}

Expand Down
4 changes: 2 additions & 2 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) {
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &jwt.Token[jwt.Claims]{
t := &jwt.TokenFor[jwt.Claims]{
Raw: tt.fields.Raw,
Method: tt.fields.Method,
Header: tt.fields.Header,
Expand All @@ -61,7 +61,7 @@ func TestToken_SigningString(t1 *testing.T) {
}

func BenchmarkToken_SigningString(b *testing.B) {
t := &jwt.Token[jwt.Claims]{
t := &jwt.TokenFor[jwt.Claims]{
Method: jwt.SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
Expand Down

0 comments on commit d3f78a0

Please sign in to comment.