From cd88b49665cbd89f850b11fcba1b0407c749d59f Mon Sep 17 00:00:00 2001 From: "zhouyiheng.go" Date: Sun, 2 Apr 2023 23:48:41 +0800 Subject: [PATCH] feat: custom json and base64 encoders for Token and Parser --- encoder.go | 13 ++++++++++++ encoder_test.go | 26 ++++++++++++++++++++++++ go.sum | 34 +++++++++++++++++++++++++++++++ parser.go | 46 ++++++++++++++++++++++++++++++++---------- parser_option.go | 13 ++++++++++++ parser_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++- token.go | 34 +++++++++++++++++++++---------- token_option.go | 12 +++++++++++ token_test.go | 21 +++++++++++++++++++ 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 encoder.go create mode 100644 encoder_test.go diff --git a/encoder.go b/encoder.go new file mode 100644 index 00000000..d36d701a --- /dev/null +++ b/encoder.go @@ -0,0 +1,13 @@ +package jwt + +// Base64Encoder is an interface that allows to implement custom Base64 encoding/decoding algorithms. +type Base64Encoder interface { + EncodeToString(src []byte) string + DecodeString(s string) ([]byte, error) +} + +// JSONEncoder is an interface that allows to implement custom JSON encoding/decoding algorithms. +type JSONEncoder interface { + Marshal(v any) ([]byte, error) + Unmarshal(data []byte, v any) error +} diff --git a/encoder_test.go b/encoder_test.go new file mode 100644 index 00000000..add53354 --- /dev/null +++ b/encoder_test.go @@ -0,0 +1,26 @@ +package jwt_test + +import ( + "encoding/base64" + "encoding/json" +) + +type customJSONEncoder struct{} + +func (s *customJSONEncoder) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (s *customJSONEncoder) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +type customBase64Encoder struct{} + +func (s *customBase64Encoder) EncodeToString(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func (s *customBase64Encoder) DecodeString(data string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(data) +} diff --git a/go.sum b/go.sum index e69de29b..31c17b11 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,34 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.6 h1:aUgO9S8gvdN6SyW2EhIpAw5E4ChworywIEndZCkCVXk= +github.com/bytedance/sonic v1.8.6/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/parser.go b/parser.go index f4386fba..d80f0f99 100644 --- a/parser.go +++ b/parser.go @@ -12,7 +12,7 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. + // Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder. useJSONNumber bool // Skip claims validation during token parsing. @@ -20,9 +20,17 @@ type Parser struct { validator *validator + // This field is disabled when using a custom base64 encoder. decodeStrict bool + // This field is disabled when using a custom base64 encoder. decodePaddingAllowed bool + + // Custom base64 encoder. + base64Encoder Base64Encoder + + // Custom json encoder. + jsonEncoder JSONEncoder } // NewParser creates a new Parser with the specified options @@ -135,7 +143,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + if p.jsonEncoder != nil { + err = p.jsonEncoder.Unmarshal(headerBytes, &token.Header) + } else { + err = json.Unmarshal(headerBytes, &token.Header) + } + if err != nil { return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) } @@ -146,21 +159,30 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke if claimBytes, err = p.DecodeSegment(parts[1]); err != nil { return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } - dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - if p.useJSONNumber { - dec.UseNumber() - } + // JSON Decode. Special case for map type to avoid weird pointer behavior - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) + mapClaims, isMapClaims := token.Claims.(MapClaims) + if p.jsonEncoder != nil { + if isMapClaims { + err = p.jsonEncoder.Unmarshal(claimBytes, &mapClaims) + } else { + err = p.jsonEncoder.Unmarshal(claimBytes, &claims) + } } else { - err = dec.Decode(&claims) + decoder := json.NewDecoder(bytes.NewBuffer(claimBytes)) + if p.useJSONNumber { + decoder.UseNumber() + } + if isMapClaims { + err = decoder.Decode(&mapClaims) + } else { + err = decoder.Decode(&claims) + } } // Handle decode error if err != nil { return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) } - // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { @@ -177,6 +199,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { + if p.base64Encoder != nil { + return p.base64Encoder.DecodeString(seg) + } + encoding := base64.RawURLEncoding if p.decodePaddingAllowed { diff --git a/parser_option.go b/parser_option.go index 3ad17bc6..c7d748aa 100644 --- a/parser_option.go +++ b/parser_option.go @@ -125,3 +125,16 @@ func WithStrictDecoding() ParserOption { p.decodeStrict = true } } + +// WithJSONEncoder supports +func WithJSONEncoder(enc JSONEncoder) ParserOption { + return func(p *Parser) { + p.jsonEncoder = enc + } +} + +func WithBase64Encoder(enc Base64Encoder) ParserOption { + return func(p *Parser) { + p.base64Encoder = enc + } +} diff --git a/parser_test.go b/parser_test.go index 5b912b15..8bd69941 100644 --- a/parser_test.go +++ b/parser_test.go @@ -54,6 +54,7 @@ var jwtTestData = []struct { err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose + options []jwt.ParserOption }{ { "invalid JWT", @@ -64,6 +65,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid JSON claim", @@ -74,6 +76,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "bearer in JWT", @@ -84,6 +87,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "basic", @@ -94,6 +98,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "basic expired", @@ -104,6 +109,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nbf", @@ -114,6 +120,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, + nil, }, { "expired and nbf", @@ -124,6 +131,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic invalid", @@ -134,6 +142,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokeyfunc", @@ -144,6 +153,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokey", @@ -154,6 +164,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "basic errorkey", @@ -164,6 +175,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid signing method", @@ -174,6 +186,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, + nil, }, { "valid RSA signing method", @@ -184,6 +197,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, + nil, }, { "ECDSA signing method not accepted", @@ -194,6 +208,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, + nil, }, { "valid ECDSA signing method", @@ -204,6 +219,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, + nil, }, { "JSON Number", @@ -214,6 +230,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic expired", @@ -224,6 +241,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic nbf", @@ -234,6 +252,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - expired and nbf", @@ -244,6 +263,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "SkipClaimsValidation during token parsing", @@ -254,6 +274,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims", @@ -266,6 +287,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud", @@ -278,6 +300,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud", @@ -290,6 +313,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud with wrong type", @@ -302,6 +326,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud with wrong types", @@ -314,6 +339,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 60s skew", @@ -324,6 +350,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 120s skew", @@ -334,6 +361,29 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, + nil, + }, + { + "custom json encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithJSONEncoder(&customJSONEncoder{})}, + }, + { + "custom base64 encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithBase64Encoder(&customBase64Encoder{})}, }, } @@ -367,7 +417,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = jwt.NewParser() + parser = jwt.NewParser(data.options...) } // Figure out correct claims type switch data.claims.(type) { diff --git a/token.go b/token.go index c8ad7c78..31119d24 100644 --- a/token.go +++ b/token.go @@ -14,12 +14,14 @@ type Keyfunc func(*Token) (interface{}, error) // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. type Token struct { - Raw string // Raw contains the raw token. Populated when you [Parse] a token - Method SigningMethod // Method is the signing method used or to be used - Header map[string]interface{} // Header is the first segment of the token in decoded form - Claims Claims // Claims is the second segment of the token in decoded form - Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token - Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + Raw string // Raw contains the raw token. Populated when you [Parse] a token + Method SigningMethod // Method is the signing method used or to be used + Header map[string]interface{} // Header is the first segment of the token in decoded form + Claims Claims // Claims is the second segment of the token in decoded form + Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token + Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + jsonEncoder JSONEncoder // jsonEncoder is the custom json encoder/decoder + base64Encoder Base64Encoder // base64Encoder is the custom base64 encoder/decoder } // New creates a new [Token] with the specified signing method and an empty map @@ -31,7 +33,7 @@ func New(method SigningMethod, opts ...TokenOption) *Token { // NewWithClaims creates a new [Token] with the specified signing method and // claims. Additional options can be specified, but are currently unused. func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token { - return &Token{ + t := &Token{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -39,6 +41,10 @@ func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *To Claims: claims, Method: method, } + for _, opt := range opts { + opt(t) + } + return t } // SignedString creates and returns a complete, signed JWT. The token is signed @@ -64,12 +70,17 @@ func (t *Token) SignedString(key interface{}) (string, error) { // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. func (t *Token) SigningString() (string, error) { - h, err := json.Marshal(t.Header) + marshal := json.Marshal + if t.jsonEncoder != nil { + marshal = t.jsonEncoder.Marshal + } + + h, err := marshal(t.Header) if err != nil { return "", err } - c, err := json.Marshal(t.Claims) + c, err := marshal(t.Claims) if err != nil { return "", err } @@ -81,6 +92,9 @@ func (t *Token) SigningString() (string, error) { // stripped. In the future, this function might take into account a // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. -func (*Token) EncodeSegment(seg []byte) string { +func (t *Token) EncodeSegment(seg []byte) string { + if t.base64Encoder != nil { + return t.base64Encoder.EncodeToString(seg) + } return base64.RawURLEncoding.EncodeToString(seg) } diff --git a/token_option.go b/token_option.go index b4ae3bad..7a6accb2 100644 --- a/token_option.go +++ b/token_option.go @@ -3,3 +3,15 @@ package jwt // TokenOption is a reserved type, which provides some forward compatibility, // if we ever want to introduce token creation-related options. type TokenOption func(*Token) + +func WithTokenJSONEncoder(enc JSONEncoder) TokenOption { + return func(token *Token) { + token.jsonEncoder = enc + } +} + +func WithTokenBase64Encoder(enc Base64Encoder) TokenOption { + return func(token *Token) { + token.base64Encoder = enc + } +} diff --git a/token_test.go b/token_test.go index f18329e0..27c50010 100644 --- a/token_test.go +++ b/token_test.go @@ -14,6 +14,7 @@ func TestToken_SigningString(t1 *testing.T) { Claims jwt.Claims Signature []byte Valid bool + Options []jwt.TokenOption } tests := []struct { name string @@ -21,6 +22,22 @@ func TestToken_SigningString(t1 *testing.T) { want string wantErr bool }{ + { + name: "", + fields: fields{ + Raw: "", + Method: jwt.SigningMethodHS256, + Header: map[string]interface{}{ + "typ": "JWT", + "alg": jwt.SigningMethodHS256.Alg(), + }, + Claims: jwt.RegisteredClaims{}, + Valid: false, + Options: nil, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", + wantErr: false, + }, { name: "", fields: fields{ @@ -32,6 +49,10 @@ func TestToken_SigningString(t1 *testing.T) { }, Claims: jwt.RegisteredClaims{}, Valid: false, + Options: []jwt.TokenOption{ + jwt.WithTokenJSONEncoder(&customJSONEncoder{}), + jwt.WithTokenBase64Encoder(&customBase64Encoder{}), + }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", wantErr: false,