Skip to content

Commit

Permalink
feat: error interface (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott committed Dec 25, 2023
1 parent 7c5946e commit e754bd7
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 59 deletions.
2 changes: 1 addition & 1 deletion clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr)
return nil, &oauth2.RetrieveError{BaseError: (*oauth2.BaseError)(rErr)}
}
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions deviceauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAu

if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
&BaseError{
Response: r,
Body: body,
},
}
}

Expand Down
69 changes: 69 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package oauth2

import (
"fmt"
"net/http"
)

// Error interface for most error types, particularly new ones.
type Error interface {
Error() string

GetErrorCode() string
GetErrorDescription() string
GetErrorURI() string
GetResponse() *http.Response
GetBody() []byte
}

type BaseError struct {
Response *http.Response

Body []byte

// ErrorCode is RFC 6749's 'error' parameter.
ErrorCode string
// ErrorDescription is RFC 6749's 'error_description' parameter.
ErrorDescription string
// ErrorURI is RFC 6749's 'error_uri' parameter.
ErrorURI string
}

func (r *BaseError) GetErrorCode() string {
return r.ErrorCode
}

func (r *BaseError) GetErrorDescription() string {
return r.ErrorDescription
}

func (r *BaseError) GetErrorURI() string {
return r.ErrorURI
}

func (r *BaseError) GetResponse() *http.Response {
return r.Response
}

func (r *BaseError) GetBody() []byte {
return r.Body
}

func (r *BaseError) Error() string {
if r.ErrorCode != "" {
s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
if r.ErrorDescription != "" {
s += fmt.Sprintf(" %q", r.ErrorDescription)
}
if r.ErrorURI != "" {
s += fmt.Sprintf(" %q", r.ErrorURI)
}
return s
}

if r.Response == nil {
return fmt.Sprintf("oauth2: request failed")
}

return fmt.Sprintf("oauth2: request failed: %v\nResponse: %s", r.Response.Status, r.Body)
}
12 changes: 8 additions & 4 deletions google/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ func TestAuthenticationError_Temporary(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ae := &AuthenticationError{
err: &oauth2.RetrieveError{
Response: &http.Response{
StatusCode: tt.code,
&oauth2.BaseError{
Response: &http.Response{
StatusCode: tt.code,
},
},
},
}
Expand Down Expand Up @@ -83,8 +85,10 @@ func (s *errTokenSource) Token() (*oauth2.Token, error) {

func TestErrWrappingTokenSource_TokenError(t *testing.T) {
re := &oauth2.RetrieveError{
Response: &http.Response{
StatusCode: 500,
BaseError: &oauth2.BaseError{
Response: &http.Response{
StatusCode: 500,
},
},
}
ts := errWrappingTokenSource{
Expand Down
6 changes: 4 additions & 2 deletions internal/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ func (js jwtSource) Token() (token *oauth2.Token, err error) {
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
BaseError: &oauth2.BaseError{
Response: resp,
Body: body,
},
}
}
// tokenRes is the JSON response body.
Expand Down
2 changes: 1 addition & 1 deletion internal/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func TestTokenRetrieveError(t *testing.T) {
t.Fatalf("got %T error, expected *RetrieveError", err)
}
// Test error string for backwards compatibility
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
expected := fmt.Sprintf("oauth2: request failed: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
Expand Down
6 changes: 4 additions & 2 deletions par.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ func (c *Config) PushedAuth(ctx context.Context, state string, opts ...AuthCodeO
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, nil, &RetrieveError{
Response: r,
Body: body,
BaseError: &BaseError{
Response: r,
Body: body,
},
}
}

Expand Down
25 changes: 4 additions & 21 deletions revocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package oauth2
import (
"context"
"fmt"
"net/http"
"net/url"

"authelia.com/client/oauth2/internal"
Expand Down Expand Up @@ -63,7 +62,9 @@ func (c *Config) RevokeToken(ctx context.Context, token *Token, opts ...Revocati
for _, v := range vals {
if err = internal.RevokeToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.RevocationURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()); err != nil {
if rErr, ok := err.(*internal.RevokeError); ok {
return (*RevokeError)(rErr)
xErr := (*BaseError)(rErr)

return &RevokeError{xErr}
}
return err
}
Expand All @@ -73,25 +74,7 @@ func (c *Config) RevokeToken(ctx context.Context, token *Token, opts ...Revocati
}

type RevokeError struct {
Response *http.Response
Body []byte
ErrorCode string
ErrorDescription string
ErrorURI string
}

func (r *RevokeError) Error() string {
if r.ErrorCode != "" {
s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
if r.ErrorDescription != "" {
s += fmt.Sprintf(" %q", r.ErrorDescription)
}
if r.ErrorURI != "" {
s += fmt.Sprintf(" %q", r.ErrorURI)
}
return s
}
return fmt.Sprintf("oauth2: cannot revoke token: %v\nResponse: %s", r.Response.Status, r.Body)
*BaseError
}

type RevocationOption interface {
Expand Down
28 changes: 2 additions & 26 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package oauth2

import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -167,7 +166,7 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
return nil, &RetrieveError{BaseError: (*BaseError)(rErr)}
}
return nil, err
}
Expand All @@ -178,28 +177,5 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
// non-2XX HTTP status code or populates RFC 6749's 'error' parameter.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
type RetrieveError struct {
Response *http.Response
// Body is the body that was consumed by reading Response.Body.
// It may be truncated.
Body []byte
// ErrorCode is RFC 6749's 'error' parameter.
ErrorCode string
// ErrorDescription is RFC 6749's 'error_description' parameter.
ErrorDescription string
// ErrorURI is RFC 6749's 'error_uri' parameter.
ErrorURI string
}

func (r *RetrieveError) Error() string {
if r.ErrorCode != "" {
s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
if r.ErrorDescription != "" {
s += fmt.Sprintf(" %q", r.ErrorDescription)
}
if r.ErrorURI != "" {
s += fmt.Sprintf(" %q", r.ErrorURI)
}
return s
}
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
*BaseError
}

0 comments on commit e754bd7

Please sign in to comment.