Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tri-state verification to mongodb detector #1575

Merged
merged 10 commits into from
Jul 31, 2023
Merged
60 changes: 42 additions & 18 deletions pkg/detectors/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mongodb

import (
"context"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
"regexp"
"strings"
"time"
Expand All @@ -14,12 +16,15 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
)

type Scanner struct{}
type Scanner struct {
timeout time.Duration // Zero value means "default timeout"
}

// Ensure the Scanner satisfies the interface at compile time.
var _ detectors.Detector = (*Scanner)(nil)

var (
defaultTimeout = 2 * time.Second
// Make sure that your group is surrounded in boundary characters such as below to reduce false positives.
keyPat = regexp.MustCompile(`\b(mongodb(\+srv)?://[\S]{3,50}:([\S]{3,88})@[-.%\w\/:]+)\b`)
// TODO: Add support for sharded cluster, replica set and Atlas Deployment.
Expand All @@ -46,30 +51,49 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
}

if verify {
func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
client, err := mongo.Connect(ctx, options.Client().ApplyURI(resMatch))
if err != nil {
return
}
defer func() {
if err := client.Disconnect(ctx); err != nil {
return
}
}()
Comment on lines -56 to -60
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahrav i think this was just discarding the result of client.Disconnect so i simplified it a bit when i moved it, but give a shout if the new version looks wrong

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good!

if err := client.Ping(ctx, readpref.Primary()); err != nil {
return
}
s1.Verified = true
}()
timeout := s.timeout
if timeout == 0 {
timeout = defaultTimeout
}
err := verifyUri(resMatch, timeout)
s1.Verified = err == nil
if !isErrDeterminate(err) {
s1.VerificationError = err
}
}
results = append(results, s1)
}

return results, nil
}

func isErrDeterminate(err error) bool {
Copy link
Collaborator

@ahrav ahrav Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: I think you can simplify this with:

if e, ok := err.(topology.ConnectionError); ok {
    _, ok = e.Unwrap().(*auth.Error)
    return ok
}
return false

or i guess if we want to use the errors pkg:

func isErrDeterminate(err error) bool {
    var connErr topology.ConnectionError
    if errors.As(err, &connErr) {
        var authErr *auth.Error
	return errors.As(connErr.Unwrap(), &authErr)
    }
    return false
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i actually specifically wrote this not to be as terse as possible, but to make it easy to extend with other error cases in the future. do you find it hard to read as it is?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh okay, i think that makes sense. I sorta had the feeling, but wasn't sure. If we leave the switches might be worth adding the default case to each, no mandatory though. One of the the go linters happens to yell at me when i have a switch with no default. 🤷

switch e := err.(type) {
case topology.ConnectionError:
switch e.Unwrap().(type) {
case *auth.Error:
return true
default:
return false
}
default:
return false
}
}

func verifyUri(uri string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
if err != nil {
return err
}
defer func() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since we are ignoring the err here anyways we can inline it:
defer client.Disconnect(ctx)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh, i could have sworn the linter didn't let me do that...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might've complained that you are ignoring the err i suppose.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah golangci-lint is not about this

_ = client.Disconnect(ctx)
}()
return client.Ping(ctx, readpref.Primary())
}

func (s Scanner) Type() detectorspb.DetectorType {
return detectorspb.DetectorType_MongoDB
}
54 changes: 46 additions & 8 deletions pkg/detectors/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mongodb
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -34,11 +35,12 @@ func TestMongoDB_FromChunk(t *testing.T) {
verify bool
}
tests := []struct {
name string
s Scanner
args args
want []detectors.Result
wantErr bool
name string
s Scanner
args args
want []detectors.Result
wantErr bool
wantVerificationErr bool
}{
{
name: "found, verified",
Expand Down Expand Up @@ -72,6 +74,40 @@ func TestMongoDB_FromChunk(t *testing.T) {
},
wantErr: false,
},
{
name: "found, would be verified but for connection timeout",
s: Scanner{timeout: 1 * time.Microsecond},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a mongodb secret %s within", secret)),
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_MongoDB,
Verified: false,
},
},
wantErr: false,
wantVerificationErr: true,
},
{
name: "found, bad host",
s: Scanner{},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a mongodb secret %s within", strings.ReplaceAll(secret, ".mongodb.net", ".mongodb.net.bad"))),
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_MongoDB,
Verified: false,
},
},
wantErr: false,
wantVerificationErr: true,
},
{
name: "not found",
s: Scanner{},
Expand All @@ -86,8 +122,7 @@ func TestMongoDB_FromChunk(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := Scanner{}
got, err := s.FromData(tt.args.ctx, tt.args.verify, tt.args.data)
got, err := tt.s.FromData(tt.args.ctx, tt.args.verify, tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("MongoDB.FromData() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -97,8 +132,11 @@ func TestMongoDB_FromChunk(t *testing.T) {
t.Fatalf("no raw secret present: \n %+v", got[i])
}
got[i].Raw = nil
if (got[i].VerificationError != nil) != tt.wantVerificationErr {
t.Fatalf("wantVerificationErr = %v, verification error = %v", tt.wantVerificationErr, got[i].VerificationError)
}
}
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "RawV2")
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "RawV2", "VerificationError")
if diff := cmp.Diff(tt.want, got, ignoreOpts); diff != "" {
t.Errorf("MongoDB.FromData() %s diff: (-got +want)\n%s", tt.name, diff)
}
Expand Down
Loading