Skip to content

Commit

Permalink
feat!: add support for type parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
costela committed Feb 25, 2023
1 parent 0d2f0d4 commit 191c1a9
Show file tree
Hide file tree
Showing 12 changed files with 244 additions and 235 deletions.
2 changes: 1 addition & 1 deletion cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func verifyToken() error {
}

// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
Expand Down
26 changes: 13 additions & 13 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func ExampleNewWithClaims_registeredClaims() {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
}

// Example creating a token using a custom claims type. The RegisteredClaims is embedded
Expand Down Expand Up @@ -67,7 +67,7 @@ func ExampleNewWithClaims_customClaimsType() {
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)

//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
}

// Example creating a token using a custom claims type. The RegisteredClaims is embedded
Expand All @@ -80,12 +80,12 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand All @@ -103,12 +103,12 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand Down Expand Up @@ -138,12 +138,12 @@ func (m MyCustomClaims) Validate() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand All @@ -154,9 +154,9 @@ func ExampleParseWithClaims_customValidation() {
// An example of parsing the error types using errors.Is.
func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
11 changes: 5 additions & 6 deletions hmac_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func ExampleParse_hmac() {
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
// to the callback, providing flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand All @@ -56,12 +56,11 @@ func ExampleParse_hmac() {
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return hmacSampleSecret, nil
})

if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
fmt.Println(claims["foo"], claims["nbf"])
} else {
fmt.Println(err)
if err != nil {
panic(err)
}

fmt.Println(token.Claims["foo"], token.Claims["nbf"])

// Output: bar 1.4444784e+09
}
36 changes: 17 additions & 19 deletions http_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,20 @@ func Example_getTokenViaHTTP() {
tokenString := strings.TrimSpace(buf.String())

// Parse the token
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
})
fatal(err)

claims := token.Claims.(*CustomClaimsExample)
claims := token.Claims
fmt.Println(claims.CustomerInfo.Name)

//Output: test
// Output: test
}

func Example_useTokenViaHTTP() {

// Make a sample token
// In a real world situation, this token will have been acquired from
// some other API call (see Example_getTokenViaHTTP)
Expand All @@ -138,18 +137,18 @@ func Example_useTokenViaHTTP() {

func createToken(user string) (string, error) {
// create a signer for rsa 256
t := jwt.New(jwt.GetSigningMethod("RS256"))

// set our claims
t.Claims = &CustomClaimsExample{
jwt.RegisteredClaims{
// set the expire time
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
t := jwt.NewWithClaims(
jwt.GetSigningMethod("RS256"),
&CustomClaimsExample{
jwt.RegisteredClaims{
// set the expire time
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
},
"level1",
CustomerInfo{user, "human"},
},
"level1",
CustomerInfo{user, "human"},
}
)

// Creat token string
return t.SignedString(signKey)
Expand Down Expand Up @@ -192,12 +191,11 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
// only accessible with a valid token
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
// Get token from request
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) {
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
}, request.WithClaims(&CustomClaimsExample{}))

})
// If the token is missing or invalid, return error
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
Expand All @@ -206,5 +204,5 @@ func restrictedHandler(w http.ResponseWriter, r *http.Request) {
}

// Token is valid
fmt.Fprintln(w, "Welcome,", token.Claims.(*CustomClaimsExample).Name)
fmt.Fprintln(w, "Welcome,", token.Claims.Name)
}
72 changes: 38 additions & 34 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

type Parser struct {
type parserOpts struct {
// If populated, only these methods will be considered valid.
validMethods []string

Expand All @@ -20,44 +20,54 @@ type Parser struct {
validator *validator
}

type Parser[T Claims] struct {
opts parserOpts
}

// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
validator: &validator{},
func NewParser(options ...ParserOption) *Parser[MapClaims] {
p := &Parser[MapClaims]{
opts: parserOpts{validator: &validator{}},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(p)
option(&p.opts)
}

return p
}

// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
func NewParserFor[T Claims](options ...ParserOption) *Parser[T] {
p := &Parser[T]{
opts: parserOpts{validator: &validator{}},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(&p.opts)
}

return p
}

// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
// than the default MapClaims implementation of Claims.
// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
//
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
token, parts, err := p.ParseUnverified(tokenString)
if err != nil {
return token, err
}

// Verify signing method is in the required set
if p.validMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.validMethods {
if p.opts.validMethods != nil {
signingMethodValid := false
alg := token.Method.Alg()
for _, m := range p.opts.validMethods {
if m == alg {
signingMethodValid = true
break
Expand Down Expand Up @@ -86,13 +96,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
}

// Validate Claims
if !p.skipClaimsValidation {
if !p.opts.skipClaimsValidation {
// Make sure we have at least a default validator
if p.validator == nil {
p.validator = newValidator()
if p.opts.validator == nil {
p.opts.validator = newValidator()
}

if err := p.validator.Validate(claims); err != nil {
if err := p.opts.validator.Validate(token.Claims); err != nil {
return token, newError("", ErrTokenInvalidClaims, err)
}
}
Expand All @@ -109,13 +119,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
}

token = &Token{Raw: tokenString}
token = &Token[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand All @@ -131,23 +141,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke

// parse Claims
var claimBytes []byte
token.Claims = claims

if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}

dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.useJSONNumber {
if p.opts.useJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}

// Handle decode error
if err != nil {
if err = dec.Decode(&token.Claims); err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
}

Expand Down
Loading

0 comments on commit 191c1a9

Please sign in to comment.