Skip to content

Commit

Permalink
Move SQL escape functions to the Postgres package
Browse files Browse the repository at this point in the history
These are specific to Postgres. Add tests and remove unused functions.
  • Loading branch information
cbandy committed Nov 4, 2024
1 parent c7f5e99 commit 561c650
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 68 deletions.
4 changes: 2 additions & 2 deletions internal/pgbouncer/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ func sqlAuthenticationQuery(sqlFunctionName string) string {
// No replicators.
`NOT pg_authid.rolreplication`,
// Not the PgBouncer role itself.
`pg_authid.rolname <> ` + util.SQLQuoteLiteral(postgresqlUser),
`pg_authid.rolname <> ` + postgres.QuoteLiteral(postgresqlUser),
// Those without a password expiration or an expiration in the future.
`(pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)`,
}, "\n AND ")

return strings.TrimSpace(`
CREATE OR REPLACE FUNCTION ` + sqlFunctionName + `(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS ` + util.SQLQuoteLiteral(`
RETURNS TABLE(username TEXT, password TEXT) AS ` + postgres.QuoteLiteral(`
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
Expand Down
8 changes: 4 additions & 4 deletions internal/pgbouncer/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import (
func TestSQLAuthenticationQuery(t *testing.T) {
assert.Equal(t, sqlAuthenticationQuery("some.fn_name"),
`CREATE OR REPLACE FUNCTION some.fn_name(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS '
RETURNS TABLE(username TEXT, password TEXT) AS E'
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
AND pg_authid.rolcanlogin
AND NOT pg_authid.rolsuper
AND NOT pg_authid.rolreplication
AND pg_authid.rolname <> ''_crunchypgbouncer''
AND pg_authid.rolname <> E''_crunchypgbouncer''
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
LANGUAGE SQL STABLE SECURITY DEFINER;`)
}
Expand Down Expand Up @@ -150,14 +150,14 @@ REVOKE ALL PRIVILEGES
GRANT USAGE
ON SCHEMA :"namespace" TO :"username";
CREATE OR REPLACE FUNCTION :"namespace".get_auth(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS '
RETURNS TABLE(username TEXT, password TEXT) AS E'
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
AND pg_authid.rolcanlogin
AND NOT pg_authid.rolsuper
AND NOT pg_authid.rolreplication
AND pg_authid.rolname <> ''_crunchypgbouncer''
AND pg_authid.rolname <> E''_crunchypgbouncer''
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
LANGUAGE SQL STABLE SECURITY DEFINER;
REVOKE ALL PRIVILEGES
Expand Down
22 changes: 22 additions & 0 deletions internal/postgres/sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
//
// SPDX-License-Identifier: Apache-2.0

package postgres

import "strings"

// escapeLiteral is called by QuoteLiteral to add backslashes before special
// characters of the "escape" string syntax. Double quote marks to escape them
// regardless of the "backslash_quote" parameter.
var escapeLiteral = strings.NewReplacer(`'`, `''`, `\`, `\\`).Replace

// QuoteLiteral escapes v so it can be safely used as a literal (or constant)
// in an SQL statement.
func QuoteLiteral(v string) string {
// Use the "escape" syntax to ensure that backslashes behave consistently regardless
// of the "standard_conforming_strings" parameter. Include a space before so
// the "E" cannot change the meaning of an adjacent SQL keyword or identifier.
// - https://www.postgresql.org/docs/current/sql-syntax-lexical.html
return ` E'` + escapeLiteral(v) + `'`
}
16 changes: 16 additions & 0 deletions internal/postgres/sql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
//
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"testing"

"gotest.tools/v3/assert"
)

func TestQuoteLiteral(t *testing.T) {
assert.Equal(t, QuoteLiteral(``), ` E''`)
assert.Equal(t, QuoteLiteral(`ab"cd\ef'gh`), ` E'ab"cd\\ef''gh'`)
}
62 changes: 0 additions & 62 deletions internal/util/util.go

This file was deleted.

0 comments on commit 561c650

Please sign in to comment.