diff --git a/pkg/detectors/postgres/postgres.go b/pkg/detectors/postgres/postgres.go index 8465bc073a93..f22542e341b2 100644 --- a/pkg/detectors/postgres/postgres.go +++ b/pkg/detectors/postgres/postgres.go @@ -37,7 +37,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete var results []detectors.Result var pgURLs []url.URL pgURLs = append(pgURLs, findUriMatches(string(data))) - pgURLs = append(pgURLs, findComponentMatches(string(data))...) + pgURLs = append(pgURLs, findComponentMatches(verify, string(data))...) for _, pgURL := range pgURLs { if pgURL.User == nil { @@ -81,7 +81,18 @@ func findUriMatches(dataStr string) url.URL { return pgURL } -func findComponentMatches(dataStr string) []url.URL { +// check if postgres is running +func postgresRunning(hostname, port string) bool { + connStr := fmt.Sprintf("host=%s port=%s sslmode=disable", hostname, port) + db, err := sql.Open("postgres", connStr) + if err != nil { + return false + } + defer db.Close() + return true +} + +func findComponentMatches(verify bool, dataStr string) []url.URL { usernameMatches := common.UsernameRegexCheck("").Matches([]byte(dataStr)) passwordMatches := common.PasswordRegexCheck("").Matches([]byte(dataStr)) hostnameMatches := hostnamePattern.FindAllStringSubmatch(dataStr, -1) @@ -89,33 +100,19 @@ func findComponentMatches(dataStr string) []url.URL { var pgURLs []url.URL - for _, username := range usernameMatches { - if len(username) < 2 { - continue - } - for _, password := range passwordMatches { - if len(password) < 2 { - continue - } - for _, hostname := range hostnameMatches { - if len(hostname) < 2 { - continue - } - port := "" - for _, ports := range portMatches { - // this will only grab the last one if there are multiple - // TODO @0x1: enumerate found ports first - if len(ports) > 1 { - port = ports[1] - } - } - if combinedLength := len(username) + len(password) + len(hostname[1]); combinedLength > 255 { + hosts := findHosts(verify, hostnameMatches, portMatches) + + for _, username := range dedupMatches(usernameMatches) { + for _, password := range dedupMatches(passwordMatches) { + for _, host := range hosts { + hostname, port := strings.Split(host, ":")[0], strings.Split(host, ":")[1] + if combinedLength := len(username) + len(password) + len(hostname); combinedLength > 255 { continue } postgresURL := url.URL{ Scheme: "postgresql", User: url.UserPassword(username, password), - Host: fmt.Sprintf("%s:%s", hostname[1], port), + Host: fmt.Sprintf("%s:%s", hostname, port), } pgURLs = append(pgURLs, postgresURL) } @@ -124,6 +121,69 @@ func findComponentMatches(dataStr string) []url.URL { return pgURLs } +// if verification is turned on, and we are able to verify at least one host, return only verified hosts +// otherwise return all hosts +func findHosts(verify bool, hostnameMatches, portMatches [][]string) []string { + hostnames := dedupMatches2(hostnameMatches) + ports := dedupMatches2(portMatches) + var hosts []string + + if len(hostnames) < 1 { + hostnames = append(hostnames, defaultHost) + } + + if len(ports) < 1 { + ports = append(ports, defaultPort) + } + + for _, hostname := range hostnames { + for _, port := range ports { + hosts = append(hosts, fmt.Sprintf("%s:%s", hostname, port)) + } + } + + if verify { + var verifiedHosts []string + for _, host := range hosts { + hostname, port := strings.Split(host, ":")[0], strings.Split(host, ":")[1] + if postgresRunning(hostname, port) { + verifiedHosts = append(verifiedHosts, host) + } + } + if len(verifiedHosts) > 0 { + return verifiedHosts + } + } + + return hosts +} + +func dedupMatches(matches []string) []string { + setOfMatches := make(map[string]struct{}) + for _, match := range matches { + setOfMatches[match] = struct{}{} + } + var results []string + for match := range setOfMatches { + results = append(results, match) + } + return results +} + +func dedupMatches2(matches [][]string) []string { + setOfMatches := make(map[string]struct{}) + for _, match := range matches { + if len(match) > 1 { + setOfMatches[match[1]] = struct{}{} + } + } + var results []string + for match := range setOfMatches { + results = append(results, match) + } + return results +} + func verifyPostgres(pgURL *url.URL) (bool, error) { if pgURL.User == nil { return false, nil @@ -144,6 +204,10 @@ func verifyPostgres(pgURL *url.URL) (bool, error) { connStr := fmt.Sprintf("user=%s password=%s host=%s port=%s sslmode=%s", username, password, hostname, port, sslmode) db, err := sql.Open("postgres", connStr) if err != nil { + if strings.Contains(err.Error(), "connection refused") { + // inactive host + return false, nil + } return false, err } defer db.Close() @@ -154,8 +218,8 @@ func verifyPostgres(pgURL *url.URL) (bool, error) { err = db.PingContext(ctx) if err == nil { return true, nil - } else if strings.Contains(err.Error(), "password authentication failed") { - // incorrect username or password + } else if strings.Contains(err.Error(), "password authentication failed") || // incorrect username or password + strings.Contains(err.Error(), "connection refused") { // inactive host return false, nil } diff --git a/pkg/detectors/postgres/postgres_test.go b/pkg/detectors/postgres/postgres_test.go index ec170527ce17..c5be452473ca 100644 --- a/pkg/detectors/postgres/postgres_test.go +++ b/pkg/detectors/postgres/postgres_test.go @@ -136,6 +136,26 @@ func TestPostgres_FromChunk(t *testing.T) { }, wantErr: false, }, + { + name: "found with seperated credentials - no port, unverified", + s: Scanner{}, + args: args{ + ctx: context.Background(), + data: []byte(fmt.Sprintf(` + POSTGRES_USER=%s + POSTGRES_PASSWORD=%s + POSTGRES_ADDRESS=%s + `, postgresUser, inactivePass, postgresHost)), + verify: true, + }, + want: []detectors.Result{ + { + DetectorType: detectorspb.DetectorType_Postgres, + Verified: false, + }, + }, + wantErr: false, + }, { name: "found with single line credentials, unverified", s: Scanner{}, @@ -199,12 +219,11 @@ func TestPostgres_FromChunk(t *testing.T) { DetectorType: detectorspb.DetectorType_Postgres, Verified: false, } - r.SetVerificationError(errors.New("connection refused")) return []detectors.Result{r} }(), wantErr: false, }, - // TODO: This test seems take a long time to run (70s+) even with the timeout set to 1s. It's not clear why. + // This test seems take a long time to run (70s+) even with the timeout set to 1s. It's not clear why. { name: "found, unverified due to error - inactive host", s: Scanner{},