Skip to content

Commit

Permalink
Endpoint customizer refresh (trufflesecurity#3308)
Browse files Browse the repository at this point in the history
* Refresh EndpointCustomizer for more explicit configuration

Also add CloudProvider interface.

* WIP: Update EndpointSetter

* Updated detectors with new endpoint customizer

* Fixed linter

* Added check for appending cloud endpoints

---------

Co-authored-by: Miccah Castorina <[email protected]>
  • Loading branch information
kashifkhan0771 and mcastorina authored Sep 24, 2024
1 parent b2311b4 commit 4b6957d
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 42 deletions.
31 changes: 19 additions & 12 deletions pkg/detectors/artifactory/artifactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ import (
type Scanner struct {
client *http.Client
detectors.DefaultMultiPartCredentialProvider
detectors.EndpointSetter
}

var (
// Ensure the Scanner satisfies the interface at compile time.
_ detectors.Detector = (*Scanner)(nil)
_ detectors.Detector = (*Scanner)(nil)
_ detectors.EndpointCustomizer = (*Scanner)(nil)

defaultClient = detectors.DetectorHttpClientWithNoLocalAddresses

Expand Down Expand Up @@ -52,6 +54,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
if len(URLmatch) != 2 {
continue
}

resURLMatch = strings.TrimSpace(URLmatch[1])
}

Expand All @@ -61,20 +64,24 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
}
resMatch := strings.TrimSpace(match[1])

s1 := detectors.Result{
DetectorType: detectorspb.DetectorType_ArtifactoryAccessToken,
Raw: []byte(resMatch),
RawV2: []byte(resMatch + resURLMatch),
}
client := s.getClient()

for _, URL := range s.Endpoints(resURLMatch) {
s1 := detectors.Result{
DetectorType: detectorspb.DetectorType_ArtifactoryAccessToken,
Raw: []byte(resMatch),
RawV2: []byte(resMatch + URL),
}

if verify {
isVerified, verificationErr := verifyArtifactory(ctx, client, URL, resMatch)
s1.Verified = isVerified
s1.SetVerificationError(verificationErr, resMatch)
}

if verify {
client := s.getClient()
isVerified, verificationErr := verifyArtifactory(ctx, client, resURLMatch, resMatch)
s1.Verified = isVerified
s1.SetVerificationError(verificationErr, resMatch)
results = append(results, s1)
}

results = append(results, s1)
}

return results, nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/detectors/datadogtoken/datadogtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ type Scanner struct {
// Ensure the Scanner satisfies the interface at compile time.
var _ detectors.Detector = (*Scanner)(nil)
var _ detectors.EndpointCustomizer = (*Scanner)(nil)
var _ detectors.CloudProvider = (*Scanner)(nil)

func (Scanner) DefaultEndpoint() string { return "https://api.datadoghq.com" }
func (Scanner) CloudEndpoint() string { return "https://api.datadoghq.com" }

var (
client = common.SaneHttpClient()
Expand Down Expand Up @@ -126,7 +127,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
}

if verify {
for _, baseURL := range s.Endpoints(s.DefaultEndpoint()) {
for _, baseURL := range s.Endpoints() {
req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/v2/users", nil)
if err != nil {
continue
Expand Down Expand Up @@ -169,8 +170,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
}

if verify {

for _, baseURL := range s.Endpoints(s.DefaultEndpoint()) {
for _, baseURL := range s.Endpoints() {
req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/v1/validate", nil)
if err != nil {
continue
Expand Down
10 changes: 8 additions & 2 deletions pkg/detectors/detectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ type MultiPartCredentialProvider interface {
// EndpointCustomizer is an optional interface that a detector can implement to
// support verifying against user-supplied endpoints.
type EndpointCustomizer interface {
SetEndpoints(...string) error
DefaultEndpoint() string
SetConfiguredEndpoints(...string) error
SetCloudEndpoint(string)
UseCloudEndpoint(bool)
UseFoundEndpoints(bool)
}

type CloudProvider interface {
CloudEndpoint() string
}

type Result struct {
Expand Down
42 changes: 29 additions & 13 deletions pkg/detectors/endpoint_customizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,43 @@ import (
// of the EndpointCustomizer interface. A detector can embed this struct to
// gain the functionality.
type EndpointSetter struct {
endpoints []string
configuredEndpoints []string
cloudEndpoint string
useCloudEndpoint bool
useFoundEndpoints bool
}

func (e *EndpointSetter) SetEndpoints(endpoints ...string) error {
if len(endpoints) == 0 {
func (e *EndpointSetter) SetConfiguredEndpoints(userConfiguredEndpoints ...string) error {
if len(userConfiguredEndpoints) == 0 {
return fmt.Errorf("at least one endpoint required")
}
deduped := make([]string, 0, len(endpoints))
for _, endpoint := range endpoints {
deduped := make([]string, 0, len(userConfiguredEndpoints))
for _, endpoint := range userConfiguredEndpoints {
common.AddStringSliceItem(endpoint, &deduped)
}
e.endpoints = deduped
e.configuredEndpoints = deduped
return nil
}

func (e *EndpointSetter) Endpoints(defaultEndpoint string) []string {
// The only valid time len(e.endpoints) == 0 is when EndpointSetter is
// initializetd to its default state. That means SetEndpoints was never
// called and we should use the default.
if len(e.endpoints) == 0 {
return []string{defaultEndpoint}
func (e *EndpointSetter) SetCloudEndpoint(url string) {
e.cloudEndpoint = url
}

func (e *EndpointSetter) UseCloudEndpoint(enabled bool) {
e.useCloudEndpoint = enabled
}

func (e *EndpointSetter) UseFoundEndpoints(enabled bool) {
e.useFoundEndpoints = enabled
}

func (e *EndpointSetter) Endpoints(foundEndpoints ...string) []string {
endpoints := e.configuredEndpoints
if e.useCloudEndpoint && e.cloudEndpoint != "" {
endpoints = append(endpoints, e.cloudEndpoint)
}
if e.useFoundEndpoints {
endpoints = append(endpoints, foundEndpoints...)
}
return e.endpoints
return endpoints
}
4 changes: 2 additions & 2 deletions pkg/detectors/endpoint_customizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestEmbeddedEndpointSetter(t *testing.T) {
type Scanner struct{ EndpointSetter }
var s Scanner
assert.Equal(t, []string{"baz"}, s.Endpoints("baz"))
assert.NoError(t, s.SetEndpoints("foo", "bar"))
assert.Error(t, s.SetEndpoints())
assert.NoError(t, s.SetConfiguredEndpoints("foo", "bar"))
assert.Error(t, s.SetConfiguredEndpoints())
assert.Equal(t, []string{"foo", "bar"}, s.Endpoints("baz"))
}
7 changes: 4 additions & 3 deletions pkg/detectors/github/v1/github_old.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ type Scanner struct{ detectors.EndpointSetter }
var _ detectors.Detector = (*Scanner)(nil)
var _ detectors.Versioner = (*Scanner)(nil)
var _ detectors.EndpointCustomizer = (*Scanner)(nil)
var _ detectors.CloudProvider = (*Scanner)(nil)

func (Scanner) Version() int { return 1 }
func (Scanner) DefaultEndpoint() string { return "https://api.github.com" }
func (Scanner) Version() int { return 1 }
func (Scanner) CloudEndpoint() string { return "https://api.github.com" }

var (
// Oauth token
Expand Down Expand Up @@ -112,7 +113,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
func (s Scanner) VerifyGithub(ctx context.Context, client *http.Client, token string) (bool, *UserRes, *HeaderInfo, error) {
// https://developer.github.com/v3/users/#get-the-authenticated-user
var requestErr error
for _, url := range s.Endpoints(s.DefaultEndpoint()) {
for _, url := range s.Endpoints() {
requestErr = nil

req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/user", url), nil)
Expand Down
3 changes: 2 additions & 1 deletion pkg/detectors/github/v2/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ type Scanner struct {
var _ detectors.Detector = (*Scanner)(nil)
var _ detectors.Versioner = (*Scanner)(nil)
var _ detectors.EndpointCustomizer = (*Scanner)(nil)
var _ detectors.CloudProvider = (*Scanner)(nil)

func (s Scanner) Version() int {
return 2
}
func (Scanner) DefaultEndpoint() string { return "https://api.github.com" }
func (Scanner) CloudEndpoint() string { return "https://api.github.com" }

var (
// Oauth token
Expand Down
7 changes: 4 additions & 3 deletions pkg/detectors/gitlab/v1/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ var (
_ detectors.Detector = (*Scanner)(nil)
_ detectors.EndpointCustomizer = (*Scanner)(nil)
_ detectors.Versioner = (*Scanner)(nil)
_ detectors.CloudProvider = (*Scanner)(nil)
)

func (Scanner) Version() int { return 1 }
func (Scanner) DefaultEndpoint() string { return "https://gitlab.com" }
func (Scanner) Version() int { return 1 }
func (Scanner) CloudEndpoint() string { return "https://gitlab.com" }

var (
defaultClient = common.SaneHttpClient()
Expand Down Expand Up @@ -87,7 +88,7 @@ func (s Scanner) verifyGitlab(ctx context.Context, resMatch string) (bool, error
if client == nil {
client = defaultClient
}
for _, baseURL := range s.Endpoints(s.DefaultEndpoint()) {
for _, baseURL := range s.Endpoints() {
// test `read_user` scope
req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/v4/user", nil)
if err != nil {
Expand Down
12 changes: 10 additions & 2 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,19 @@ func NewEngine(ctx context.Context, cfg *Config) (*Engine, error) {
}

if !cfg.CustomVerifiersOnly || len(urls) == 0 {
urls = append(urls, customizer.DefaultEndpoint())
customizer.UseFoundEndpoints(true)
customizer.UseCloudEndpoint(true)
}
if err := customizer.SetEndpoints(urls...); err != nil {

if err := customizer.SetConfiguredEndpoints(urls...); err != nil {
return false
}

cloudProvider, ok := d.(detectors.CloudProvider)
if ok {
customizer.SetCloudEndpoint(cloudProvider.CloudEndpoint())
}

return true
})
}
Expand Down

0 comments on commit 4b6957d

Please sign in to comment.