diff --git a/pkg/detectors/postgres/postgres.go b/pkg/detectors/postgres/postgres.go index cb22536cd396..26e60c15d093 100644 --- a/pkg/detectors/postgres/postgres.go +++ b/pkg/detectors/postgres/postgres.go @@ -4,58 +4,81 @@ import ( "context" "database/sql" "fmt" - "net/url" "regexp" + "strconv" "strings" "time" - _ "github.com/lib/pq" // PostgreSQL driver + "github.com/lib/pq" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" ) const ( defaultPort = "5432" - defaultHost = "localhost" ) +// This detector currently only finds Postgres connection string URIs +// (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS). When it finds one, it uses +// pq.ParseURL to normalize it into a space-separated key-value pair (kvp) Postgres connection string, and then uses a regular +// expression to transform this connection string into a parameters map. This parameters map is manipulated prior to +// verification, which operates by transforming the map back into a space-separated kvp connection string. This is kind +// of clunky overall, but it has the benefit of preserving the connection string as a map when it needs to be modified, +// which is much nicer than having to patch a space-separated string of kvps. + +// Multi-host connection string URIs are currently not supported because pq.ParseURL doesn't parse them correctly. If we +// happen to run into a case where this matters, we can address it then. 
var ( - _ detectors.Detector = (*Scanner)(nil) - uriPattern = regexp.MustCompile(`\b(?i)postgresql://[\S]+\b`) - hostnamePattern = regexp.MustCompile(`(?i)(?:host|server|address).{0,40}?(\b[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*\b)`) - portPattern = regexp.MustCompile(`(?i)(?:port|p).{0,40}?(\b[0-9]{1,5}\b)`) - usernamePattern = regexp.MustCompile(`(?im)(?:user|usr)\S{0,40}?[:=\s]{1,3}[ '"=]{0,1}([^:'"\s]{4,40})`) - passwordPattern = regexp.MustCompile(`(?im)(?:pass)\S{0,40}?[:=\s]{1,3}[ '"=]{0,1}([^:'"\s]{4,40})`) + _ detectors.Detector = (*Scanner)(nil) + uriPattern = regexp.MustCompile(`\b(?i)postgres(?:ql)?://\S+\b`) + connStrPartPattern = regexp.MustCompile(`([[:alpha:]]+)='(.+?)' ?`) ) type Scanner struct{} func (s Scanner) Keywords() []string { - return []string{"postgres", "psql", "pghost"} + return []string{"postgres"} } func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]detectors.Result, error) { var results []detectors.Result - var pgURLs []url.URL - pgURLs = append(pgURLs, findUriMatches(string(data))) - pgURLs = append(pgURLs, findComponentMatches(verify, string(data))...) 
+ candidateParamSets := findUriMatches(data) - for _, pgURL := range pgURLs { - if pgURL.User == nil { + for _, params := range candidateParamSets { + user, ok := params["user"] + if !ok { continue } - username := pgURL.User.Username() - password, _ := pgURL.User.Password() - hostport := pgURL.Host + + password, ok := params["password"] + if !ok { + continue + } + + hostport, ok := params["host"] + if !ok { + continue + } + + if port, ok := params["port"]; ok { + hostport = hostport + ":" + port + } + result := detectors.Result{ DetectorType: detectorspb.DetectorType_Postgres, - Raw: []byte(hostport + username + password), - RawV2: []byte(hostport + username + password), + Raw: []byte(hostport + user + password), + RawV2: []byte(hostport + user + password), } if verify { - timeoutInSeconds := getDeadlineInSeconds(ctx) - isVerified, verificationErr := verifyPostgres(&pgURL, timeoutInSeconds) + if timeout := getDeadlineInSeconds(ctx); timeout != 0 { + params["connect_timeout"] = strconv.Itoa(timeout) + } + + // We'd like the 'allow' mode but pq doesn't support it (https://github.com/lib/pq/issues/776) + // To kludge it we first try with 'require' and then fall back to 'disable' if there's an SSL error + params["sslmode"] = "require" + isVerified, verificationErr := verifyPostgres(params) result.Verified = isVerified result.SetVerificationError(verificationErr, password) } @@ -69,186 +92,61 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete return results, nil } -func getDeadlineInSeconds(ctx context.Context) int { - deadline, ok := ctx.Deadline() - if !ok { - // Context does not have a deadline - return 0 - } - - duration := time.Until(deadline) - return int(duration.Seconds()) -} - -func findUriMatches(dataStr string) url.URL { - var pgURL url.URL - for _, uri := range uriPattern.FindAllString(dataStr, -1) { - pgURL, err := url.Parse(uri) +func findUriMatches(data []byte) []map[string]string { + var matches []map[string]string + for 
_, uri := range uriPattern.FindAll(data, -1) { + connStr, err := pq.ParseURL(string(uri)) if err != nil { continue } - if pgURL.User != nil { - return *pgURL - } - } - return pgURL -} - -// check if postgres is running -func postgresRunning(hostname, port string) bool { - connStr := fmt.Sprintf("host=%s port=%s sslmode=disable", hostname, port) - db, err := sql.Open("postgres", connStr) - if err != nil { - return false - } - defer db.Close() - return true -} - -func findComponentMatches(verify bool, dataStr string) []url.URL { - usernameMatches := usernamePattern.FindAllStringSubmatch(dataStr, -1) - passwordMatches := passwordPattern.FindAllStringSubmatch(dataStr, -1) - hostnameMatches := hostnamePattern.FindAllStringSubmatch(dataStr, -1) - portMatches := portPattern.FindAllStringSubmatch(dataStr, -1) - - var pgURLs []url.URL - - hosts := findHosts(verify, hostnameMatches, portMatches) - - for _, username := range dedupMatches(usernameMatches) { - for _, password := range dedupMatches(passwordMatches) { - for _, host := range hosts { - hostname, port := strings.Split(host, ":")[0], strings.Split(host, ":")[1] - if combinedLength := len(username) + len(password) + len(hostname); combinedLength > 255 { - continue - } - postgresURL := url.URL{ - Scheme: "postgresql", - User: url.UserPassword(username, password), - Host: fmt.Sprintf("%s:%s", hostname, port), - } - pgURLs = append(pgURLs, postgresURL) - } - } - } - return pgURLs -} - -// if verification is turned on, and we can confirm that postgres is running on at least one host, -// return only hosts where it's running. otherwise return all hosts. 
-func findHosts(verify bool, hostnameMatches, portMatches [][]string) []string { - hostnames := dedupMatches(hostnameMatches) - ports := dedupMatches(portMatches) - var hosts []string - - if len(hostnames) < 1 { - hostnames = append(hostnames, defaultHost) - } - - if len(ports) < 1 { - ports = append(ports, defaultPort) - } - - for _, hostname := range hostnames { - for _, port := range ports { - hosts = append(hosts, fmt.Sprintf("%s:%s", hostname, port)) - } - } - if verify { - var verifiedHosts []string - for _, host := range hosts { - parts := strings.Split(host, ":") - hostname, port := parts[0], parts[1] - if postgresRunning(hostname, port) { - verifiedHosts = append(verifiedHosts, host) - } - } - if len(verifiedHosts) > 0 { - return verifiedHosts + params := make(map[string]string) + parts := connStrPartPattern.FindAllStringSubmatch(connStr, -1) + for _, part := range parts { + params[part[1]] = part[2] } - } - return hosts -} - -// deduplicate matches in order to reduce the number of verification requests -func dedupMatches(matches [][]string) []string { - setOfMatches := make(map[string]struct{}) - for _, match := range matches { - if len(match) > 1 { - setOfMatches[match[1]] = struct{}{} - } + matches = append(matches, params) } - var results []string - for match := range setOfMatches { - results = append(results, match) - } - return results + return matches } -func verifyPostgres(pgURL *url.URL, timeoutInSeconds int) (bool, error) { - if pgURL.User == nil { - return false, nil - } - username := pgURL.User.Username() - password, _ := pgURL.User.Password() - - hostname, port := pgURL.Hostname(), pgURL.Port() - if hostname == "" { - hostname = defaultHost - } - if port == "" { - port = defaultPort +func getDeadlineInSeconds(ctx context.Context) int { + deadline, ok := ctx.Deadline() + if !ok { + // Context does not have a deadline + return 0 } - sslmode := determineSSLMode(pgURL) + duration := time.Until(deadline) + return int(duration.Seconds()) +} - 
connStr := fmt.Sprintf("user=%s password=%s host=%s port=%s sslmode=%s", username, password, hostname, port, sslmode) - if timeoutInSeconds > 0 { - connStr = fmt.Sprintf("%s connect_timeout=%d", connStr, timeoutInSeconds) +func verifyPostgres(params map[string]string) (bool, error) { + var connStr string + for key, value := range params { + connStr += fmt.Sprintf("%s='%s'", key, value) } db, err := sql.Open("postgres", connStr) if err != nil { - if strings.Contains(err.Error(), "connection refused") { - // inactive host - return false, nil - } return false, err } defer db.Close() - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - err = db.PingContext(ctx) + err = db.Ping() if err == nil { return true, nil - } else if strings.Contains(err.Error(), "password authentication failed") || // incorrect username or password - strings.Contains(err.Error(), "connection refused") { // inactive host + } else if strings.Contains(err.Error(), "password authentication failed") { return false, nil + } else if strings.Contains(err.Error(), "SSL is not enabled on the server") { + params["sslmode"] = "disable" + return verifyPostgres(params) } - // if ssl is not enabled, manually fall-back to sslmode=disable - if strings.Contains(err.Error(), "SSL is not enabled on the server") { - pgURL.RawQuery = fmt.Sprintf("sslmode=%s", "disable") - return verifyPostgres(pgURL, timeoutInSeconds) - } return false, err } -func determineSSLMode(pgURL *url.URL) string { - // default ssl mode is "prefer" per https://www.postgresql.org/docs/current/libpq-ssl.html - // but is currently not implemented in the driver per https://github.com/lib/pq/issues/1006 - // default for the driver is "require". ideally we would use "allow" but that is also not supported by the driver. 
- sslmode := "require" - if sslQuery, ok := pgURL.Query()["sslmode"]; ok && len(sslQuery) > 0 { - sslmode = sslQuery[0] - } - return sslmode -} - func (s Scanner) Type() detectorspb.DetectorType { return detectorspb.DetectorType_Postgres } diff --git a/pkg/detectors/postgres/postgres_test.go b/pkg/detectors/postgres/postgres_test.go index 08da6dd2fabc..33b3fa659a22 100644 --- a/pkg/detectors/postgres/postgres_test.go +++ b/pkg/detectors/postgres/postgres_test.go @@ -26,7 +26,7 @@ const ( postgresUser = "postgres" postgresPass = "23201dabb56ca236f3dc6736c0f9afad" postgresHost = "localhost" - postgresPort = "5433" + postgresPort = "5434" // Do not use 5433, as local dev environments can use it for other things inactiveUser = "inactive" inactivePass = "inactive" @@ -62,28 +62,7 @@ func TestPostgres_FromChunk(t *testing.T) { wantErr: false, }, { - name: "found with seperated credentials, verified", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf(` - POSTGRES_USER=%s - POSTGRES_PASSWORD=%s - POSTGRES_ADDRESS=%s - POSTGRES_PORT=%s - `, postgresUser, postgresPass, postgresHost, postgresPort)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: detectorspb.DetectorType_Postgres, - Verified: true, - }, - }, - wantErr: false, - }, - { - name: "found with single line credentials, verified", + name: "found connection URI, verified", s: Scanner{}, args: args{ ctx: context.Background(), @@ -99,65 +78,7 @@ func TestPostgres_FromChunk(t *testing.T) { wantErr: false, }, { - name: "found with json credentials, verified", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf( - `DB_CONFIG={"user": "%s", "password": "%s", "host": "%s", "port": "%s", "database": "postgres"}`, postgresUser, postgresPass, postgresHost, postgresPort)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: detectorspb.DetectorType_Postgres, - Verified: true, - }, - }, - wantErr: false, - }, - { - name: 
"found with seperated credentials, unverified", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf(` - POSTGRES_USER=%s - POSTGRES_PASSWORD=%s - POSTGRES_ADDRESS=%s - POSTGRES_PORT=%s - `, postgresUser, inactivePass, postgresHost, postgresPort)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: detectorspb.DetectorType_Postgres, - Verified: false, - }, - }, - wantErr: false, - }, - { - name: "found with seperated credentials - no port, unverified", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf(` - POSTGRES_USER=%s - POSTGRES_PASSWORD=%s - POSTGRES_ADDRESS=%s - `, postgresUser, inactivePass, postgresHost)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: detectorspb.DetectorType_Postgres, - Verified: false, - }, - }, - wantErr: false, - }, - { - name: "found with single line credentials, unverified", + name: "found connection URI, unverified", s: Scanner{}, args: args{ ctx: context.Background(), @@ -173,59 +94,7 @@ func TestPostgres_FromChunk(t *testing.T) { wantErr: false, }, { - name: "found with json credentials, unverified - inactive password", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf( - `DB_CONFIG={"user": "%s", "password": "%s", "host": "%s", "port": "%s", "database": "postgres"}`, postgresUser, inactivePass, postgresHost, postgresPort)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: detectorspb.DetectorType_Postgres, - Verified: false, - }, - }, - wantErr: false, - }, - { - name: "found with json credentials, unverified - inactive user", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf( - `DB_CONFIG={"user": "%s", "password": "%s", "host": "%s", "port": "%s", "database": "postgres"}`, inactiveUser, postgresPass, postgresHost, postgresPort)), - verify: true, - }, - want: []detectors.Result{ - { - DetectorType: 
detectorspb.DetectorType_Postgres, - Verified: false, - }, - }, - wantErr: false, - }, - { - name: "found, unverified due to error - inactive port", - s: Scanner{}, - args: args{ - ctx: context.Background(), - data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres`, postgresUser, postgresPass, postgresHost, inactivePort)), - verify: true, - }, - want: func() []detectors.Result { - r := detectors.Result{ - DetectorType: detectorspb.DetectorType_Postgres, - Verified: false, - } - return []detectors.Result{r} - }(), - wantErr: false, - }, - // This test seems take a long time to run (70s+) even with the timeout set to 1s. It's not clear why. - { - name: "found, unverified due to error - inactive host", + name: "found connection URI, unverified due to error - inactive host", s: Scanner{}, args: func() args { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)