From e754bd7696312dd0134e9addac9ce474f2a9e904 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 25 Dec 2023 16:08:15 +1100 Subject: [PATCH] feat: error interface (#10) --- clientcredentials/clientcredentials.go | 2 +- deviceauth.go | 6 ++- error.go | 69 ++++++++++++++++++++++++++ google/error_test.go | 12 +++-- internal/jwt/jwt.go | 6 ++- internal/jwt/jwt_test.go | 2 +- par.go | 6 ++- revocation.go | 25 ++-------- token.go | 28 +---------- 9 files changed, 97 insertions(+), 59 deletions(-) create mode 100644 error.go diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 2671163..5a45623 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -110,7 +110,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get()) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { - return nil, (*oauth2.RetrieveError)(rErr) + return nil, &oauth2.RetrieveError{BaseError: (*oauth2.BaseError)(rErr)} } return nil, err } diff --git a/deviceauth.go b/deviceauth.go index 5c9de49..f71d6fd 100644 --- a/deviceauth.go +++ b/deviceauth.go @@ -127,8 +127,10 @@ func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAu if code := r.StatusCode; code < 200 || code > 299 { return nil, &RetrieveError{ - Response: r, - Body: body, + &BaseError{ + Response: r, + Body: body, + }, } } diff --git a/error.go b/error.go new file mode 100644 index 0000000..ec964b6 --- /dev/null +++ b/error.go @@ -0,0 +1,69 @@ +package oauth2 + +import ( + "fmt" + "net/http" +) + +// Error interface for most error types, particularly new ones. +type Error interface { + Error() string + + GetErrorCode() string + GetErrorDescription() string + GetErrorURI() string + GetResponse() *http.Response + GetBody() []byte +} + +type BaseError struct { + Response *http.Response + + Body []byte + + // ErrorCode is RFC 6749's 'error' parameter. + ErrorCode string + // ErrorDescription is RFC 6749's 'error_description' parameter. + ErrorDescription string + // ErrorURI is RFC 6749's 'error_uri' parameter. + ErrorURI string +} + +func (r *BaseError) GetErrorCode() string { + return r.ErrorCode +} + +func (r *BaseError) GetErrorDescription() string { + return r.ErrorDescription +} + +func (r *BaseError) GetErrorURI() string { + return r.ErrorURI +} + +func (r *BaseError) GetResponse() *http.Response { + return r.Response +} + +func (r *BaseError) GetBody() []byte { + return r.Body +} + +func (r *BaseError) Error() string { + if r.ErrorCode != "" { + s := fmt.Sprintf("oauth2: %q", r.ErrorCode) + if r.ErrorDescription != "" { + s += fmt.Sprintf(" %q", r.ErrorDescription) + } + if r.ErrorURI != "" { + s += fmt.Sprintf(" %q", r.ErrorURI) + } + return s + } + + if r.Response == nil { + return fmt.Sprintf("oauth2: request failed") + } + + return fmt.Sprintf("oauth2: request failed: %v\nResponse: %s", r.Response.Status, r.Body) +} diff --git a/google/error_test.go b/google/error_test.go index 405577c..de75a1d 100644 --- a/google/error_test.go +++ b/google/error_test.go @@ -47,8 +47,10 @@ func TestAuthenticationError_Temporary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ae := &AuthenticationError{ err: &oauth2.RetrieveError{ - Response: &http.Response{ - StatusCode: tt.code, + &oauth2.BaseError{ + Response: &http.Response{ + StatusCode: tt.code, + }, }, }, } @@ -83,8 +85,10 @@ func (s *errTokenSource) Token() (*oauth2.Token, error) { func TestErrWrappingTokenSource_TokenError(t *testing.T) { re := &oauth2.RetrieveError{ - Response: &http.Response{ - StatusCode: 500, + BaseError: &oauth2.BaseError{ + Response: &http.Response{ + StatusCode: 500, + }, }, } ts := errWrappingTokenSource{ diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go index 6f294b0..aa60f5b 100644 --- a/internal/jwt/jwt.go +++ b/internal/jwt/jwt.go @@ -141,8 +141,10 @@ func (js jwtSource) Token() (token *oauth2.Token, err error) { } if c := resp.StatusCode; c < 200 || c > 299 { return nil, &oauth2.RetrieveError{ - Response: resp, - Body: body, + BaseError: &oauth2.BaseError{ + Response: resp, + Body: body, + }, } } // tokenRes is the JSON response body. diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go index 4745adc..1938cb9 100644 --- a/internal/jwt/jwt_test.go +++ b/internal/jwt/jwt_test.go @@ -311,7 +311,7 @@ func TestTokenRetrieveError(t *testing.T) { t.Fatalf("got %T error, expected *RetrieveError", err) } // Test error string for backwards compatibility - expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) + expected := fmt.Sprintf("oauth2: request failed: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) if errStr := err.Error(); errStr != expected { t.Fatalf("got %#v, expected %#v", errStr, expected) } diff --git a/par.go b/par.go index 322171a..6d9f19c 100644 --- a/par.go +++ b/par.go @@ -62,8 +62,10 @@ func (c *Config) PushedAuth(ctx context.Context, state string, opts ...AuthCodeO } if code := r.StatusCode; code < 200 || code > 299 { return nil, nil, &RetrieveError{ - Response: r, - Body: body, + BaseError: &BaseError{ + Response: r, + Body: body, + }, } } diff --git a/revocation.go b/revocation.go index 78a457c..af4ffec 100644 --- a/revocation.go +++ b/revocation.go @@ -3,7 +3,6 @@ package oauth2 import ( "context" "fmt" - "net/http" "net/url" "authelia.com/client/oauth2/internal" @@ -63,7 +62,9 @@ func (c *Config) RevokeToken(ctx context.Context, token *Token, opts ...Revocati for _, v := range vals { if err = internal.RevokeToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.RevocationURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()); err != nil { if rErr, ok := err.(*internal.RevokeError); ok { - return (*RevokeError)(rErr) + xErr := (*BaseError)(rErr) + + return &RevokeError{xErr} } return err } @@ -73,25 +74,7 @@ func (c *Config) RevokeToken(ctx context.Context, token *Token, opts ...Revocati } type RevokeError struct { - Response *http.Response - Body []byte - ErrorCode string - ErrorDescription string - ErrorURI string -} - -func (r *RevokeError) Error() string { - if r.ErrorCode != "" { - s := fmt.Sprintf("oauth2: %q", r.ErrorCode) - if r.ErrorDescription != "" { - s += fmt.Sprintf(" %q", r.ErrorDescription) - } - if r.ErrorURI != "" { - s += fmt.Sprintf(" %q", r.ErrorURI) - } - return s - } - return fmt.Sprintf("oauth2: cannot revoke token: %v\nResponse: %s", r.Response.Status, r.Body) + *BaseError } type RevocationOption interface { diff --git a/token.go b/token.go index fb63257..d6beabc 100644 --- a/token.go +++ b/token.go @@ -6,7 +6,6 @@ package oauth2 import ( "context" - "fmt" "net/http" "net/url" "strconv" @@ -167,7 +166,7 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { - return nil, (*RetrieveError)(rErr) + return nil, &RetrieveError{BaseError: (*BaseError)(rErr)} } return nil, err } @@ -178,28 +177,5 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) // non-2XX HTTP status code or populates RFC 6749's 'error' parameter. // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 type RetrieveError struct { - Response *http.Response - // Body is the body that was consumed by reading Response.Body. - // It may be truncated. - Body []byte - // ErrorCode is RFC 6749's 'error' parameter. - ErrorCode string - // ErrorDescription is RFC 6749's 'error_description' parameter. - ErrorDescription string - // ErrorURI is RFC 6749's 'error_uri' parameter. - ErrorURI string -} - -func (r *RetrieveError) Error() string { - if r.ErrorCode != "" { - s := fmt.Sprintf("oauth2: %q", r.ErrorCode) - if r.ErrorDescription != "" { - s += fmt.Sprintf(" %q", r.ErrorDescription) - } - if r.ErrorURI != "" { - s += fmt.Sprintf(" %q", r.ErrorURI) - } - return s - } - return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) + *BaseError }