diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index 93e1b81..93d64cf 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -187,7 +187,7 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (finalRespo reason, cerberusExtraHeaders = a.TestAccess(request, wsvcCacheEntry) extraHeaders = toExtraHeaders(cerberusExtraHeaders) - if reason == CerberusReasonOK && hasUpstreamAuth(wsvcCacheEntry) { + if reason == "" && hasUpstreamAuth(wsvcCacheEntry) { request.Context[HasUpstreamAuth] = "true" reason = a.checkServiceUpstreamAuth(wsvcCacheEntry, request, &extraHeaders, ctx) } @@ -336,10 +336,16 @@ func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, attribute.String("upstream-http-request-start", reqStart.Format(tracing.TimeFormat)), attribute.String("upstream-http-request-end", time.Now().Format(tracing.TimeFormat)), attribute.Float64("upstream-http-request-rtt-seconds", time.Since(reqStart).Seconds()), - attribute.Int("upstream-auth-status-code", resp.StatusCode), ) - labels := AddWithDownstreamDeadlineLabel(AddStatusLabel(nil, resp.StatusCode), hasDownstreamDeadline) - upstreamAuthRequestDuration.With(labels).Observe(reqDuration.Seconds()) + + if resp != nil { + span.SetAttributes(attribute.Int("upstream-auth-status-code", resp.StatusCode)) + labels := AddWithDownstreamDeadlineLabel(AddStatusLabel(nil, resp.StatusCode), hasDownstreamDeadline) + upstreamAuthRequestDuration.With(labels).Observe(reqDuration.Seconds()) + } else { + labels := AddWithDownstreamDeadlineLabel(nil, hasDownstreamDeadline) + upstreamAuthFailedRequests.With(labels).Inc() + } if reason := processResponseError(err); reason != "" { span.RecordError(err) diff --git a/pkg/auth/authenticator_test.go b/pkg/auth/authenticator_test.go index 5eb7c09..ef6fb0e 100644 --- a/pkg/auth/authenticator_test.go +++ b/pkg/auth/authenticator_test.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "io" "net/http" "net/url" + "strings" "testing" "time" @@ -978,3 +980,362 @@ func TestSetupUpstreamAuthRequest(t *testing.T) { assert.Nil(t, actualReq, "Request should be nil when there is an error") assert.Error(t, actualErr, "Error should occur when service is empty") } + +func TestCheck_SuccessfulAuthentication(t *testing.T) { + mockHTTPClient := &http.Client{ + Transport: &MockTransport{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }, + }, + } + + authenticator := &Authenticator{ + httpClient: mockHTTPClient, + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + tokens := prepareAccessTokens(1) + services := prepareWebservices(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{ + "default/webservice-1": {}, + }, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + authenticator.webservicesCache = &WebservicesCache{ + webserviceKey: WebservicesCacheEntry{WebService: services[0]}, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Expected no error for successful authentication") + assert.NotNil(t, finalResponse, "Expected a non-nil response for successful authentication") + assert.True(t, finalResponse.Allow, "Expected the request to be allowed for valid token and service") +} + +func TestCheck_TokenNotFound(t *testing.T) { + authenticator := &Authenticator{ + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + + services := prepareWebservices(1) + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "nonexistent-token") + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + authenticator.webservicesCache = &WebservicesCache{ + webserviceKey: WebservicesCacheEntry{WebService: services[0]}, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Expected no error from Check function itself") + assert.NotNil(t, finalResponse, "Expected a non-nil response even for token not found scenario") + assert.False(t, finalResponse.Allow, "Expected the request to not be allowed due to token not found") + assert.Contains(t, finalResponse.Response.Header.Get("X-Cerberus-Reason"), "token-not-found", "Expected X-Cerberus-Reason header to indicate token not found") +} + +func TestCheck_ServiceNotFound(t *testing.T) { + authenticator := &Authenticator{ + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + + tokens := prepareAccessTokens(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{}, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": "nonexistent-service", + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Expected no error even if service is not found") + assert.NotNil(t, finalResponse, "Expected a non-nil response even if service is not found") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to service not found") + assert.Contains(t, finalResponse.Response.Header.Get("X-Cerberus-Reason"), "webservice-notfound", "Expected webservice-notfound reason") +} + +func TestCheck_EmptyToken(t *testing.T) { + authenticator := &Authenticator{ + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + services := prepareWebservices(1) + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + (*authenticator.webservicesCache)[webserviceKey] = WebservicesCacheEntry{WebService: services[0]} + + headers := http.Header{} + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Expected no error for empty token scenario") + assert.NotNil(t, finalResponse, "Expected a non-nil response for empty token scenario") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to empty token") + assert.Equal(t, http.StatusUnauthorized, finalResponse.Response.StatusCode, "Expected a 401 Unauthorized status code") + assert.Contains(t, finalResponse.Response.Header.Get("X-Cerberus-Reason"), "token-empty", "Expected reason to indicate empty token") +} + +func TestCheck_InvalidServiceName(t *testing.T) { + authenticator := &Authenticator{ + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + tokens := prepareAccessTokens(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{ + "default/valid-service": {}, + }, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": "invalid-service", + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Expected no error for invalid service name scenario") + assert.NotNil(t, finalResponse, "Expected a non-nil response for invalid service name scenario") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to invalid service name") + assert.Equal(t, http.StatusUnauthorized, finalResponse.Response.StatusCode, "Expected a 401 Unauthorized status code") + assert.Contains(t, finalResponse.Response.Header.Get("X-Cerberus-Reason"), "webservice-notfound", "Expected reason to indicate service not found") +} + +func TestCheck_UpstreamAuthUnauthorized(t *testing.T) { + mockHTTPClient := &http.Client{ + Transport: &MockTransport{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(strings.NewReader("Unauthorized")), + }, nil + }, + }, + } + + authenticator := &Authenticator{ + httpClient: mockHTTPClient, + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + + services := prepareWebservices(1) + tokens := prepareAccessTokens(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{ + "default/" + services[0].Name: {}, + }, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + (*authenticator.webservicesCache)[webserviceKey] = WebservicesCacheEntry{WebService: services[0]} + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.NoError(t, err, "Did not expect an error from Check function") + assert.NotNil(t, finalResponse, "Expected a non-nil response") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to unauthorized upstream authentication") + assert.Equal(t, "unauthorized", finalResponse.Response.Header.Get("X-Cerberus-Reason"), "Expected reason to indicate unauthorized") +} + +func TestCheck_UpstreamAuthFailed(t *testing.T) { + mockHTTPClient := &http.Client{ + Transport: &MockTransport{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusRequestTimeout, + Body: io.NopCloser(strings.NewReader("Internal Server Error")), + Header: make(http.Header), + }, &url.Error{ + Op: "Get", + URL: "http://fake-upstream-service/authenticate", + Err: errors.New("Internal Server Error"), + } + }, + }, + } + + authenticator := &Authenticator{ + httpClient: mockHTTPClient, + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + + services := prepareWebservices(1) + tokens := prepareAccessTokens(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{ + "default/" + services[0].Name: {}, + }, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + (*authenticator.webservicesCache)[webserviceKey] = WebservicesCacheEntry{WebService: services[0]} + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.Error(t, err, "Error should occur") + assert.NotNil(t, finalResponse, "Expected a non-nil response") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to upstream authentication failed") + assert.Equal(t, "upstream-auth-failed", finalResponse.Response.Header.Get("X-Cerberus-Reason"), "Expected reason to indicate upstream authentication failed") +} + +func TestCheck_UpstreamAuthTimeout(t *testing.T) { + mockHTTPClient := &http.Client{ + Transport: &MockTransport{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusRequestTimeout, + Body: io.NopCloser(strings.NewReader("Request Timeout")), + Header: make(http.Header), + }, &url.Error{ + Op: "Get", + URL: "http://fake-upstream-service/authenticate", + Err: &innerError{timeout: true}, + } + }, + }, + } + + authenticator := &Authenticator{ + httpClient: mockHTTPClient, + accessTokensCache: &AccessTokensCache{}, + webservicesCache: &WebservicesCache{}, + } + + services := prepareWebservices(1) + tokens := prepareAccessTokens(1) + + tokenEntry := AccessTokensCacheEntry{ + AccessToken: tokens[0], + allowedWebservicesCache: map[string]struct{}{ + "default/" + services[0].Name: {}, + }, + } + (*authenticator.accessTokensCache)["valid-token"] = tokenEntry + + webserviceKey := fmt.Sprintf("%s/%s", "default", services[0].Name) + (*authenticator.webservicesCache)[webserviceKey] = WebservicesCacheEntry{WebService: services[0]} + + headers := http.Header{} + headers.Set(string(CerberusHeaderAccessToken), "valid-token") + + request := &Request{ + Context: map[string]string{ + "webservice": services[0].Name, + "namespace": "default", + }, + Request: http.Request{ + Header: headers, + }, + } + + finalResponse, err := authenticator.Check(context.Background(), request) + + assert.Error(t, err, "Error should occur") + assert.NotNil(t, finalResponse, "Expected a non-nil response") + assert.False(t, finalResponse.Allow, "Expected the request to be denied due to upstream authentication timeout") + assert.Equal(t, "upstream-auth-timeout", finalResponse.Response.Header.Get("X-Cerberus-Reason"), "Expected reason to indicate upstream authentication timeout") +} diff --git a/pkg/auth/metrics.go b/pkg/auth/metrics.go index f26c39e..10c8d74 100644 --- a/pkg/auth/metrics.go +++ b/pkg/auth/metrics.go @@ -122,6 +122,14 @@ var ( }, []string{StatusCode, WithDownstreamDeadlineLabel}, ) + + upstreamAuthFailedRequests = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "upstream_auth_failed_requests_total", + Help: "Total number of failed UpstreamAuth requests", + }, + []string{"with_downstream_deadline"}, + ) ) func init() { @@ -138,6 +146,7 @@ func init() { fetchObjectListLatency, serviceUpstreamAuthCalls, upstreamAuthRequestDuration, + upstreamAuthFailedRequests, ) }