From 20b77938285b82bc80531ba176989b7f8bae8c4b Mon Sep 17 00:00:00 2001 From: Cody Rose Date: Wed, 19 Jul 2023 16:57:57 -0400 Subject: [PATCH] JDBC indeterminacy (#1507) This PR adds an indeterminacy check to the JDBC verifiers. --- pkg/detectors/jdbc/jdbc.go | 27 +++++++-- pkg/detectors/jdbc/mysql.go | 23 ++++++-- pkg/detectors/jdbc/mysql_integration_test.go | 47 +++++++++------- pkg/detectors/jdbc/postgres.go | 37 ++++++++++--- .../jdbc/postgres_integration_test.go | 55 ++++++++++++------- pkg/detectors/jdbc/sqlite.go | 12 +++- pkg/detectors/jdbc/sqlite_test.go | 2 +- pkg/detectors/jdbc/sqlserver.go | 23 ++++++-- .../jdbc/sqlserver_integration_test.go | 43 ++++++++++----- 9 files changed, 190 insertions(+), 79 deletions(-) diff --git a/pkg/detectors/jdbc/jdbc.go b/pkg/detectors/jdbc/jdbc.go index c45df67eef1f..32fb96055f25 100644 --- a/pkg/detectors/jdbc/jdbc.go +++ b/pkg/detectors/jdbc/jdbc.go @@ -86,7 +86,14 @@ matchLoop: } ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - s.Verified = j.ping(ctx) + pingRes := j.ping(ctx) + s.Verified = pingRes.err == nil + // If there's a ping error that is marked as "determinate" we throw it away. We do this because this was the + // behavior before tri-state verification was introduced and preserving it allows us to gradually migrate + // detectors to use tri-state verification. + if pingRes.err != nil && !pingRes.determinate { + s.VerificationError = pingRes.err + } // TODO: specialized redaction } @@ -198,8 +205,13 @@ var supportedSubprotocols = map[string]func(string) (jdbc, error){ "sqlserver": parseSqlServer, } +type pingResult struct { + err error + determinate bool +} + type jdbc interface { - ping(context.Context) bool + ping(context.Context) pingResult } func newJDBC(conn string) (jdbc, error) { @@ -220,13 +232,16 @@ func newJDBC(conn string) (jdbc, error) { return parser(subname) } -func ping(ctx context.Context, driverName string, candidateConns ...string) bool { +func ping(ctx context.Context, driverName string, isDeterminate func(error) bool, candidateConns ...string) pingResult { + var indeterminateErrors []error for _, c := range candidateConns { - if err := pingErr(ctx, driverName, c); err == nil { - return true + err := pingErr(ctx, driverName, c) + if err == nil || isDeterminate(err) { + return pingResult{err, true} } + indeterminateErrors = append(indeterminateErrors, err) } - return false + return pingResult{errors.Join(indeterminateErrors...), false} } func pingErr(ctx context.Context, driverName, conn string) error { diff --git a/pkg/detectors/jdbc/mysql.go b/pkg/detectors/jdbc/mysql.go index 2e98a045b0ca..f21662ec99d3 100644 --- a/pkg/detectors/jdbc/mysql.go +++ b/pkg/detectors/jdbc/mysql.go @@ -3,9 +3,8 @@ package jdbc import ( "context" "errors" + "github.com/go-sql-driver/mysql" "strings" - - _ "github.com/go-sql-driver/mysql" ) type mysqlJDBC struct { @@ -16,8 +15,8 @@ type mysqlJDBC struct { params string } -func (s *mysqlJDBC) ping(ctx context.Context) bool { - return ping(ctx, "mysql", +func (s *mysqlJDBC) ping(ctx context.Context) pingResult { + return ping(ctx, "mysql", isMySQLErrorDeterminate, s.conn, buildMySQLConnectionString(s.host, s.database, s.userPass, s.params), buildMySQLConnectionString(s.host, "", s.userPass, s.params)) @@ -34,6 +33,22 @@ func buildMySQLConnectionString(host, database, userPass, params string) string return conn } +func isMySQLErrorDeterminate(err error) bool { + // MySQL error numbers from https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html + if mySQLErr, isMySQLErr := err.(*mysql.MySQLError); isMySQLErr { + switch mySQLErr.Number { + case 1044: + // User access denied to a particular database + return false // "Indeterminate" so that other connection variations will be tried + case 1045: + // User access denied + return true + } + } + + return false +} + func parseMySQL(subname string) (jdbc, error) { // expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]] hostAndDB, params, _ := strings.Cut(subname, "?") diff --git a/pkg/detectors/jdbc/mysql_integration_test.go b/pkg/detectors/jdbc/mysql_integration_test.go index b5d503c7a203..cbc64284127f 100644 --- a/pkg/detectors/jdbc/mysql_integration_test.go +++ b/pkg/detectors/jdbc/mysql_integration_test.go @@ -21,45 +21,54 @@ const ( ) func TestMySQL(t *testing.T) { + type result struct { + parseErr bool + pingOk bool + pingDeterminate bool + } tests := []struct { - input string - wantErr bool - wantPing bool + input string + want result }{ { - input: "", - wantErr: true, + input: "", + want: result{parseErr: true}, }, { - input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase, - wantPing: true, + input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase, + want: result{pingOk: true, pingDeterminate: true}, }, { - input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, - wantPing: false, + input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, + want: result{pingOk: false, pingDeterminate: true}, }, { - input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, - wantPing: false, + input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, + want: result{pingOk: false, pingDeterminate: true}, }, { - input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/", - wantPing: true, + input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/", + want: result{pingOk: true, pingDeterminate: true}, }, { - input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB", - wantPing: true, + input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB", + want: result{pingOk: true, pingDeterminate: true}, }, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { j, err := parseMySQL(tt.input) - if tt.wantErr { - assert.Error(t, err) + + if err != nil { + got := result{parseErr: true} + assert.Equal(t, tt.want, got) return } - assert.NoError(t, err) - assert.Equal(t, tt.wantPing, j.ping(context.Background())) + + pr := j.ping(context.Background()) + + got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate} + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/detectors/jdbc/postgres.go b/pkg/detectors/jdbc/postgres.go index 5d806855e45c..a79efb081cb5 100644 --- a/pkg/detectors/jdbc/postgres.go +++ b/pkg/detectors/jdbc/postgres.go @@ -4,9 +4,8 @@ import ( "context" "errors" "fmt" + "github.com/lib/pq" "strings" - - _ "github.com/lib/pq" ) type postgresJDBC struct { @@ -14,12 +13,36 @@ type postgresJDBC struct { params map[string]string } -func (s *postgresJDBC) ping(ctx context.Context) bool { - return ping(ctx, "postgres", - s.conn, - "postgres://"+s.conn, +func (s *postgresJDBC) ping(ctx context.Context) pingResult { + // It is crucial that we try to build a connection string ourselves before using the one we found. This is because + // if the found connection string doesn't include a username, the driver will attempt to connect using the current + // user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In + // contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which + // actually has a chance of working. + return ping(ctx, "postgres", isPostgresErrorDeterminate, buildPostgresConnectionString(s.params, true), - buildPostgresConnectionString(s.params, false)) + buildPostgresConnectionString(s.params, false), + s.conn, + "postgres://"+s.conn) +} + +func isPostgresErrorDeterminate(err error) bool { + // Postgres codes from https://www.postgresql.org/docs/current/errcodes-appendix.html + if pqErr, isPostgresError := err.(*pq.Error); isPostgresError { + switch pqErr.Code { + case "28P01": + // Invalid username/password + return true + case "3D000": + // Unknown database + return false // "Indeterminate" so that other connection variations will be tried + case "3F000": + // Unknown schema + return false // "Indeterminate" so that other connection variations will be tried + } + } + + return false } func joinKeyValues(m map[string]string, sep string) string { diff --git a/pkg/detectors/jdbc/postgres_integration_test.go b/pkg/detectors/jdbc/postgres_integration_test.go index fb4fe93b68af..e667a0aa8a96 100644 --- a/pkg/detectors/jdbc/postgres_integration_test.go +++ b/pkg/detectors/jdbc/postgres_integration_test.go @@ -20,49 +20,62 @@ const ( ) func TestPostgres(t *testing.T) { + type result struct { + parseErr bool + pingOk bool + pingDeterminate bool + } tests := []struct { - input string - wantErr bool - wantPing bool + input string + want result }{ { - input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass, - wantPing: true, + input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass, + want: result{pingOk: true, pingDeterminate: true}, + }, + { + input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass, + want: result{pingOk: true, pingDeterminate: true}, }, { - input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass, - wantPing: true, + input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass, + want: result{pingOk: true, pingDeterminate: true}, }, { - input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass, - wantPing: true, + input: "//localhost:5432/foo?password=" + postgresPass, + want: result{pingOk: false, pingDeterminate: false}, }, { - input: "//localhost:5432/foo?password=" + postgresPass, - wantPing: false, + input: "//localhost:5432/foo?sslmode=disable&password=foo", + want: result{pingOk: false, pingDeterminate: true}, }, { - input: "//localhost:5432/foo?sslmode=disable&password=foo", - wantPing: false, + input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass, + want: result{pingOk: false, pingDeterminate: true}, }, { - input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass, - wantPing: false, + input: "//badhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass, + want: result{pingOk: false, pingDeterminate: false}, }, { - input: "invalid", - wantErr: true, + input: "invalid", + want: result{parseErr: true}, }, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { j, err := parsePostgres(tt.input) - if tt.wantErr { - assert.Error(t, err) + + if err != nil { + got := result{parseErr: true} + assert.Equal(t, tt.want, got) return } - assert.NoError(t, err) - assert.Equal(t, tt.wantPing, j.ping(context.Background())) + + pr := j.ping(context.Background()) + + got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate} + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/detectors/jdbc/sqlite.go b/pkg/detectors/jdbc/sqlite.go index edf8aa8672a6..7268888ad96f 100644 --- a/pkg/detectors/jdbc/sqlite.go +++ b/pkg/detectors/jdbc/sqlite.go @@ -14,12 +14,18 @@ type sqliteJDBC struct { testing bool } -func (s *sqliteJDBC) ping(ctx context.Context) bool { +var cannotVerifySqliteError error = errors.New("sqlite credentials cannot be verified") + +func (s *sqliteJDBC) ping(ctx context.Context) pingResult { if !s.testing { // sqlite is not a networked database, so we cannot verify - return false + return pingResult{cannotVerifySqliteError, true} } - return ping(ctx, "sqlite3", s.filename) + return ping(ctx, "sqlite3", isSqliteErrorDeterminate, s.filename) +} + +func isSqliteErrorDeterminate(err error) bool { + return true } func parseSqlite(subname string) (jdbc, error) { diff --git a/pkg/detectors/jdbc/sqlite_test.go b/pkg/detectors/jdbc/sqlite_test.go index 4e2578bb323a..377808386722 100644 --- a/pkg/detectors/jdbc/sqlite_test.go +++ b/pkg/detectors/jdbc/sqlite_test.go @@ -39,7 +39,7 @@ func TestParseSqlite(t *testing.T) { assert.Error(t, err) } else { assert.NoError(t, err) - assert.True(t, j.ping(context.Background())) + assert.True(t, j.ping(context.Background()).err == nil) } }) } diff --git a/pkg/detectors/jdbc/sqlserver.go b/pkg/detectors/jdbc/sqlserver.go index ac781002fc44..3a092e5a13c3 100644 --- a/pkg/detectors/jdbc/sqlserver.go +++ b/pkg/detectors/jdbc/sqlserver.go @@ -3,9 +3,8 @@ package jdbc import ( "context" "errors" + mssql "github.com/denisenkom/go-mssqldb" "strings" - - _ "github.com/denisenkom/go-mssqldb" ) type sqlServerJDBC struct { @@ -13,13 +12,27 @@ type sqlServerJDBC struct { params map[string]string } -func (s *sqlServerJDBC) ping(ctx context.Context) bool { - return ping(ctx, "mssql", - s.conn, +func (s *sqlServerJDBC) ping(ctx context.Context) pingResult { + return ping(ctx, "mssql", isSqlServerErrorDeterminate, joinKeyValues(s.params, ";"), + s.conn, "sqlserver://"+s.conn) } +func isSqlServerErrorDeterminate(err error) bool { + // Error numbers from https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors?view=sql-server-ver16 + if mssqlError, isMssqlError := err.(mssql.Error); isMssqlError { + switch mssqlError.Number { + case 18456: + // Login failed + // This is a determinate failure iff we tried to use a real user + return mssqlError.Message != "login error: Login failed for user ''." + } + } + + return false +} + func parseSqlServer(subname string) (jdbc, error) { if !strings.HasPrefix(subname, "//") { return nil, errors.New("expected connection to start with //") diff --git a/pkg/detectors/jdbc/sqlserver_integration_test.go b/pkg/detectors/jdbc/sqlserver_integration_test.go index e76e6abc81fe..0258b5adef25 100644 --- a/pkg/detectors/jdbc/sqlserver_integration_test.go +++ b/pkg/detectors/jdbc/sqlserver_integration_test.go @@ -21,33 +21,50 @@ const ( ) func TestSqlServer(t *testing.T) { + type result struct { + parseErr bool + pingOk bool + pingDeterminate bool + } tests := []struct { - input string - wantErr bool - wantPing bool + input string + want result }{ { - input: "", - wantErr: true, + input: "", + want: result{parseErr: true}, + }, + { + input: "//server=localhost;user id=sa;database=master;password=" + sqlServerPass, + want: result{pingOk: true, pingDeterminate: true}, + }, + { + input: "//server=badhost;user id=sa;database=master;password=" + sqlServerPass, + want: result{pingOk: false, pingDeterminate: false}, }, { - input: "//odbc:server=localhost;user id=sa;database=master;password=" + sqlServerPass, - wantPing: true, + input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass, + want: result{pingOk: true, pingDeterminate: true}, }, { - input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass, - wantPing: true, + input: "//localhost;database=master;spring.datasource.password=badpassword", + want: result{pingOk: false, pingDeterminate: true}, }, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { j, err := parseSqlServer(tt.input) - if tt.wantErr { - assert.Error(t, err) + + if err != nil { + got := result{parseErr: true} + assert.Equal(t, tt.want, got) return } - assert.NoError(t, err) - assert.Equal(t, tt.wantPing, j.ping(context.Background())) + + pr := j.ping(context.Background()) + + got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate} + assert.Equal(t, tt.want, got) }) } }