From 760948a533106a8e8e4769781700dc67e9687016 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 12 Jul 2023 13:50:03 -0700 Subject: [PATCH 1/6] Local storage binding: disable access to other system folders for security reasons (#2947) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/localstorage/localstorage.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bindings/localstorage/localstorage.go b/bindings/localstorage/localstorage.go index a038fd4ce4..654bc54346 100644 --- a/bindings/localstorage/localstorage.go +++ b/bindings/localstorage/localstorage.go @@ -40,6 +40,9 @@ const ( // List of root paths that are disallowed var disallowedRootPaths = []string{ + filepath.Clean("/proc"), + filepath.Clean("/sys"), + filepath.Clean("/boot"), // See: https://github.com/dapr/components-contrib/issues/2444 filepath.Clean("/var/run/secrets"), } From fd8e3a208674713572dd94fa0282ed1fe27be064 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 12 Jul 2023 13:51:58 -0700 Subject: [PATCH 2/6] Fixes in HTTP binding (#2981) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Bernd Verst --- bindings/http/http.go | 47 +++++++++++++------------------------- bindings/http/http_test.go | 16 ------------- 2 files changed, 16 insertions(+), 47 deletions(-) diff --git a/bindings/http/http.go b/bindings/http/http.go index 3acc9dff2d..63fe4ee4f6 100644 --- a/bindings/http/http.go +++ b/bindings/http/http.go @@ -29,7 +29,6 @@ import ( "strconv" "strings" "time" - "unicode" "github.com/dapr/components-contrib/bindings" "github.com/dapr/components-contrib/internal/utils" @@ -102,11 +101,11 @@ func (h *HTTPSource) Init(_ context.Context, meta bindings.Metadata) error { // See guidance on proper HTTP client settings here: // https://medium.com/@nate510/don-t-use-go-s-default-http-client-4804cb19f779 dialer := &net.Dialer{ - Timeout: 5 * time.Second, + Timeout: 15 * time.Second, } netTransport := &http.Transport{ Dial: dialer.Dial, - TLSHandshakeTimeout: 5 * time.Second, + TLSHandshakeTimeout: 15 * time.Second, TLSClientConfig: tlsConfig, } @@ -150,17 +149,11 @@ func (h *HTTPSource) readMTLSClientCertificates(tlsConfig *tls.Config) error { func (h *HTTPSource) setTLSRenegotiation(tlsConfig *tls.Config) error { switch h.metadata.MTLSRenegotiation { case "RenegotiateNever": - { - tlsConfig.Renegotiation = tls.RenegotiateNever - } + tlsConfig.Renegotiation = tls.RenegotiateNever case "RenegotiateOnceAsClient": - { - tlsConfig.Renegotiation = tls.RenegotiateOnceAsClient - } + tlsConfig.Renegotiation = tls.RenegotiateOnceAsClient case "RenegotiateFreelyAsClient": - { - tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient - } + tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient default: return fmt.Errorf("invalid renegotiation value: %s", h.metadata.MTLSRenegotiation) } @@ -231,23 +224,18 @@ func (h *HTTPSource) Invoke(parentCtx context.Context, req *bindings.InvokeReque errorIfNot2XX := h.errorIfNot2XX // Default to the component config (default is true) - if req.Metadata != nil { - if path, ok := req.Metadata["path"]; ok { - // Simplicity and no "../../.." type exploits. - u = fmt.Sprintf("%s/%s", strings.TrimRight(u, "/"), strings.TrimLeft(path, "/")) - if strings.Contains(u, "..") { - return nil, fmt.Errorf("invalid path: %s", path) - } - } - - if _, ok := req.Metadata["errorIfNot2XX"]; ok { - errorIfNot2XX = utils.IsTruthy(req.Metadata["errorIfNot2XX"]) - } - } else { + if req.Metadata == nil { // Prevent things below from failing if req.Metadata is nil. req.Metadata = make(map[string]string) } + if req.Metadata["path"] != "" { + u = strings.TrimRight(u, "/") + "/" + strings.TrimLeft(req.Metadata["path"], "/") + } + if req.Metadata["errorIfNot2XX"] != "" { + errorIfNot2XX = utils.IsTruthy(req.Metadata["errorIfNot2XX"]) + } + var body io.Reader method := strings.ToUpper(string(req.Operation)) // For backward compatibility @@ -262,10 +250,8 @@ func (h *HTTPSource) Invoke(parentCtx context.Context, req *bindings.InvokeReque return nil, fmt.Errorf("invalid operation: %s", req.Operation) } - var ctx context.Context - if h.metadata.ResponseTimeout == nil { - ctx = parentCtx - } else { + ctx := parentCtx + if h.metadata.ResponseTimeout != nil { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(parentCtx, *h.metadata.ResponseTimeout) defer cancel() @@ -294,8 +280,7 @@ func (h *HTTPSource) Invoke(parentCtx context.Context, req *bindings.InvokeReque // Any metadata keys that start with a capital letter // are treated as request headers for mdKey, mdValue := range req.Metadata { - keyAsRunes := []rune(mdKey) - if len(keyAsRunes) > 0 && unicode.IsUpper(keyAsRunes[0]) { + if len(mdKey) > 0 && (mdKey[0] >= 'A' && mdKey[0] <= 'Z') { request.Header.Set(mdKey, mdValue) } } diff --git a/bindings/http/http_test.go b/bindings/http/http_test.go index 04e53cdfd0..164b1ef328 100644 --- a/bindings/http/http_test.go +++ b/bindings/http/http_test.go @@ -553,14 +553,6 @@ func verifyDefaultBehaviors(t *testing.T, hs bindings.OutputBinding, handler *HT err: "", statusCode: 200, }, - "invalid path": { - input: "expected", - operation: "POST", - metadata: map[string]string{"path": "/../test"}, - path: "", - err: "invalid path: /../test", - statusCode: 400, - }, "invalid operation": { input: "notvalid", operation: "notvalid", @@ -665,14 +657,6 @@ func verifyNon2XXErrorsSuppressed(t *testing.T, hs bindings.OutputBinding, handl err: "", statusCode: 200, }, - "invalid path": { - input: "expected", - operation: "POST", - metadata: map[string]string{"path": "/../test"}, - path: "", - err: "invalid path: /../test", - statusCode: 400, - }, } for name, tc := range tests { From 1349fca858369cc067a93576be0a19d0c05df58f Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 12 Jul 2023 14:16:35 -0700 Subject: [PATCH 3/6] MySQL binding: allow passing parameters for queries (#2975) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/mysql/metadata.yaml | 13 +- bindings/mysql/mysql.go | 129 +++++++++++------ bindings/mysql/mysql_integration_test.go | 176 +++++++++++++---------- bindings/mysql/mysql_test.go | 8 +- 4 files changed, 192 insertions(+), 134 deletions(-) diff --git a/bindings/mysql/metadata.yaml b/bindings/mysql/metadata.yaml index e305696d11..005b2e7637 100644 --- a/bindings/mysql/metadata.yaml +++ b/bindings/mysql/metadata.yaml @@ -17,17 +17,17 @@ binding: - name: query description: "The query operation is used for SELECT statements, which returns the metadata along with data in a form of an array of row values." - name: close - description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesn’t have any response." + description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesn't have any response." metadata: - name: url required: true - description: "Represent a DB connection in Data Source Name (DNS) format." - example: "user:password@tcp(localhost:3306)/dbname" + description: "Represent a DB connection in Data Source Name (DNS) format" + example: '"user:password@tcp(localhost:3306)/dbname"' type: string - name: pemPath required: false description: "Path to the PEM file. Used with SSL connection" - example: "path/to/pem/file" + example: '"path/to/pem/file"' type: string - name: maxIdleConns required: false @@ -49,8 +49,3 @@ metadata: description: "The max connection idel time." example: "12s" type: duration - - name: maxRetries - required: false - description: "MaxRetries is the maximum number of retries for a query." - example: "5" - type: number diff --git a/bindings/mysql/mysql.go b/bindings/mysql/mysql.go index 1c53dc8df4..2127caf0f3 100644 --- a/bindings/mysql/mysql.go +++ b/bindings/mysql/mysql.go @@ -25,6 +25,7 @@ import ( "os" "reflect" "strconv" + "sync/atomic" "time" "github.com/go-sql-driver/mysql" @@ -52,7 +53,8 @@ const ( // "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'. // keys from request's metadata. - commandSQLKey = "sql" + commandSQLKey = "sql" + commandParamsKey = "params" // keys from response's metadata. respOpKey = "operation" @@ -67,6 +69,7 @@ const ( type Mysql struct { db *sql.DB logger logger.Logger + closed atomic.Bool } type mysqlMetadata struct { @@ -87,21 +90,22 @@ type mysqlMetadata struct { // ConnMaxIdleTime is the maximum amount of time a connection may be idle. ConnMaxIdleTime time.Duration `mapstructure:"connMaxIdleTime"` - - // MaxRetries is the maximum number of retries for a query. - MaxRetries int `mapstructure:"maxRetries"` } // NewMysql returns a new MySQL output binding. func NewMysql(logger logger.Logger) bindings.OutputBinding { - return &Mysql{logger: logger} + return &Mysql{ + logger: logger, + } } // Init initializes the MySQL binding. func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error { - m.logger.Debug("Initializing MySql binding") + if m.closed.Load() { + return errors.New("cannot initialize a previously-closed component") + } - // parse metadata + // Parse metadata meta := mysqlMetadata{} err := metadata.DecodeMetadata(md.Properties, &meta) if err != nil { @@ -112,23 +116,29 @@ func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error { return fmt.Errorf("missing MySql connection string") } - db, err := initDB(meta.URL, meta.PemPath) + m.db, err = initDB(meta.URL, meta.PemPath) if err != nil { return err } - db.SetMaxIdleConns(meta.MaxIdleConns) - db.SetMaxOpenConns(meta.MaxOpenConns) - db.SetConnMaxIdleTime(meta.ConnMaxIdleTime) - db.SetConnMaxLifetime(meta.ConnMaxLifetime) + if meta.MaxIdleConns > 0 { + m.db.SetMaxIdleConns(meta.MaxIdleConns) + } + if meta.MaxOpenConns > 0 { + m.db.SetMaxOpenConns(meta.MaxOpenConns) + } + if meta.ConnMaxIdleTime > 0 { + m.db.SetConnMaxIdleTime(meta.ConnMaxIdleTime) + } + if meta.ConnMaxLifetime > 0 { + m.db.SetConnMaxLifetime(meta.ConnMaxLifetime) + } - err = db.PingContext(ctx) + err = m.db.PingContext(ctx) if err != nil { return fmt.Errorf("unable to ping the DB: %w", err) } - m.db = db - return nil } @@ -138,22 +148,38 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi return nil, errors.New("invoke request required") } + // We let the "close" operation here succeed even if the component has been closed already if req.Operation == closeOperation { - return nil, m.db.Close() + return nil, m.Close() + } + + if m.closed.Load() { + return nil, errors.New("component is closed") } if req.Metadata == nil { return nil, errors.New("metadata required") } - m.logger.Debugf("operation: %v", req.Operation) - s, ok := req.Metadata[commandSQLKey] - if !ok || s == "" { + s := req.Metadata[commandSQLKey] + if s == "" { return nil, fmt.Errorf("required metadata not set: %s", commandSQLKey) } - startTime := time.Now() + // Metadata property "params" contains JSON-encoded parameters, and it's optional + // If present, it must be unserializable into a []any object + var ( + params []any + err error + ) + if paramsStr := req.Metadata[commandParamsKey]; paramsStr != "" { + err = json.Unmarshal([]byte(paramsStr), ¶ms) + if err != nil { + return nil, fmt.Errorf("invalid metadata property %s: failed to unserialize into an array: %w", commandParamsKey, err) + } + } + startTime := time.Now().UTC() resp := &bindings.InvokeResponse{ Metadata: map[string]string{ respOpKey: string(req.Operation), @@ -162,16 +188,16 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi }, } - switch req.Operation { //nolint:exhaustive + switch req.Operation { case execOperation: - r, err := m.exec(ctx, s) + r, err := m.exec(ctx, s, params...) if err != nil { return nil, err } resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10) case queryOperation: - d, err := m.query(ctx, s) + d, err := m.query(ctx, s, params...) if err != nil { return nil, err } @@ -182,7 +208,7 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi req.Operation, execOperation, queryOperation, closeOperation) } - endTime := time.Now() + endTime := time.Now().UTC() resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano) resp.Metadata[respDurationKey] = endTime.Sub(startTime).String() @@ -200,23 +226,26 @@ func (m *Mysql) Operations() []bindings.OperationKind { // Close will close the DB. func (m *Mysql) Close() error { + if !m.closed.CompareAndSwap(false, true) { + // If this failed, the component has already been closed + // We allow multiple calls to close + return nil + } + if m.db != nil { - return m.db.Close() + m.db.Close() + m.db = nil } return nil } -func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) { - rows, err := m.db.QueryContext(ctx, sql) +func (m *Mysql) query(ctx context.Context, sql string, params ...any) ([]byte, error) { + rows, err := m.db.QueryContext(ctx, sql, params...) if err != nil { return nil, fmt.Errorf("error executing query: %w", err) } - - defer func() { - _ = rows.Close() - _ = rows.Err() - }() + defer rows.Close() result, err := m.jsonify(rows) if err != nil { @@ -226,10 +255,8 @@ func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) { return result, nil } -func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) { - m.logger.Debugf("exec: %s", sql) - - res, err := m.db.ExecContext(ctx, sql) +func (m *Mysql) exec(ctx context.Context, sql string, params ...any) (int64, error) { + res, err := m.db.ExecContext(ctx, sql, params...) if err != nil { return 0, fmt.Errorf("error executing query: %w", err) } @@ -238,13 +265,15 @@ func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) { } func initDB(url, pemPath string) (*sql.DB, error) { - if _, err := mysql.ParseDSN(url); err != nil { + conf, err := mysql.ParseDSN(url) + if err != nil { return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey) } if pemPath != "" { + var pem []byte rootCertPool := x509.NewCertPool() - pem, err := os.ReadFile(pemPath) + pem, err = os.ReadFile(pemPath) if err != nil { return nil, fmt.Errorf("error reading PEM file from %s: %w", pemPath, err) } @@ -254,17 +283,25 @@ func initDB(url, pemPath string) (*sql.DB, error) { return nil, fmt.Errorf("failed to append PEM") } - err = mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool, MinVersion: tls.VersionTLS12}) + err = mysql.RegisterTLSConfig("custom", &tls.Config{ + RootCAs: rootCertPool, + MinVersion: tls.VersionTLS12, + }) if err != nil { return nil, fmt.Errorf("error register TLS config: %w", err) } } - db, err := sql.Open("mysql", url) + // Required to correctly parse time columns + // See: https://stackoverflow.com/a/45040724 + conf.ParseTime = true + + connector, err := mysql.NewConnector(conf) if err != nil { return nil, fmt.Errorf("error opening DB connection: %w", err) } + db := sql.OpenDB(connector) return db, nil } @@ -274,7 +311,7 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) { return nil, err } - var ret []interface{} + var ret []any for rows.Next() { values := prepareValues(columnTypes) err := rows.Scan(values...) @@ -289,13 +326,13 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) { return json.Marshal(ret) } -func prepareValues(columnTypes []*sql.ColumnType) []interface{} { +func prepareValues(columnTypes []*sql.ColumnType) []any { types := make([]reflect.Type, len(columnTypes)) for i, tp := range columnTypes { types[i] = tp.ScanType() } - values := make([]interface{}, len(columnTypes)) + values := make([]any, len(columnTypes)) for i := range values { values[i] = reflect.New(types[i]).Interface() } @@ -303,8 +340,8 @@ func prepareValues(columnTypes []*sql.ColumnType) []interface{} { return values } -func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} { - r := map[string]interface{}{} +func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []any) map[string]any { + r := map[string]any{} for i, ct := range columnTypes { value := values[i] @@ -312,7 +349,7 @@ func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map switch v := values[i].(type) { case driver.Valuer: if vv, err := v.Value(); err == nil { - value = interface{}(vv) + value = any(vv) } else { m.logger.Warnf("error to convert value: %v", err) } diff --git a/bindings/mysql/mysql_integration_test.go b/bindings/mysql/mysql_integration_test.go index fcc173526f..d5df16dda3 100644 --- a/bindings/mysql/mysql_integration_test.go +++ b/bindings/mysql/mysql_integration_test.go @@ -22,36 +22,20 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/bindings" "github.com/dapr/components-contrib/metadata" "github.com/dapr/kit/logger" ) -const ( - // MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator. - mySQLDateTimeFormat = "2006-01-02 15:04:05" - - testCreateTable = `CREATE TABLE IF NOT EXISTS foo ( - id bigint NOT NULL, - v1 character varying(50) NOT NULL, - b BOOLEAN, - ts TIMESTAMP, - data LONGTEXT)` - testDropTable = `DROP TABLE foo` - testInsert = "INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')" - testDelete = "DELETE FROM foo" - testUpdate = "UPDATE foo SET ts = '%v' WHERE id = %d" - testSelect = "SELECT * FROM foo WHERE id < 3" - testSelectJSONExtract = "SELECT JSON_EXTRACT(data, '$.key') AS `key` FROM foo WHERE id < 3" -) +// MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator. +const mySQLDateTimeFormat = "2006-01-02 15:04:05" func TestOperations(t *testing.T) { - t.Parallel() t.Run("Get operation list", func(t *testing.T) { - t.Parallel() - b := NewMysql(nil) - assert.NotNil(t, b) + b := NewMysql(logger.NewLogger("test")) + require.NotNil(t, b) l := b.Operations() assert.Equal(t, 3, len(l)) assert.Contains(t, l, execOperation) @@ -70,123 +54,165 @@ func TestOperations(t *testing.T) { func TestMysqlIntegration(t *testing.T) { url := os.Getenv("MYSQL_TEST_CONN_URL") if url == "" { - t.SkipNow() + t.Skip("Skipping because env var MYSQL_TEST_CONN_URL is empty") } b := NewMysql(logger.NewLogger("test")).(*Mysql) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} - if err := b.Init(context.Background(), m); err != nil { - t.Fatal(err) - } - defer b.Close() + err := b.Init(context.Background(), m) + require.NoError(t, err) - req := &bindings.InvokeRequest{Metadata: map[string]string{}} + defer b.Close() t.Run("Invoke create table", func(t *testing.T) { - req.Operation = execOperation - req.Metadata[commandSQLKey] = testCreateTable - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: `CREATE TABLE IF NOT EXISTS foo ( + id bigint NOT NULL, + v1 character varying(50) NOT NULL, + b BOOLEAN, + ts TIMESTAMP, + data LONGTEXT)`, + }, + }) assertResponse(t, res, err) }) t.Run("Invoke delete", func(t *testing.T) { - req.Operation = execOperation - req.Metadata[commandSQLKey] = testDelete - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: "DELETE FROM foo", + }, + }) assertResponse(t, res, err) }) t.Run("Invoke insert", func(t *testing.T) { - req.Operation = execOperation for i := 0; i < 10; i++ { - req.Metadata[commandSQLKey] = fmt.Sprintf(testInsert, i, i, true, time.Now().Format(mySQLDateTimeFormat), "{\"key\":\"val\"}") - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: fmt.Sprintf( + "INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')", + i, i, true, time.Now().Format(mySQLDateTimeFormat), `{"key":"val"}`), + }, + }) assertResponse(t, res, err) } }) t.Run("Invoke update", func(t *testing.T) { - req.Operation = execOperation + date := time.Now().Add(time.Hour) + for i := 0; i < 10; i++ { + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: fmt.Sprintf( + "UPDATE foo SET ts = '%v' WHERE id = %d", + date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i), + }, + }) + assertResponse(t, res, err) + assert.Equal(t, "1", res.Metadata[respRowsAffectedKey]) + } + }) + + t.Run("Invoke update with parameters", func(t *testing.T) { + date := time.Now().Add(2 * time.Hour) for i := 0; i < 10; i++ { - req.Metadata[commandSQLKey] = fmt.Sprintf(testUpdate, time.Now().Format(mySQLDateTimeFormat), i) - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: "UPDATE foo SET ts = ? WHERE id = ?", + commandParamsKey: fmt.Sprintf(`[%q,%d]`, date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i), + }, + }) assertResponse(t, res, err) + assert.Equal(t, "1", res.Metadata[respRowsAffectedKey]) } }) t.Run("Invoke select", func(t *testing.T) { - req.Operation = queryOperation - req.Metadata[commandSQLKey] = testSelect - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: queryOperation, + Metadata: map[string]string{ + commandSQLKey: "SELECT * FROM foo WHERE id < 3", + }, + }) assertResponse(t, res, err) t.Logf("received result: %s", res.Data) // verify number, boolean and string - assert.Contains(t, string(res.Data), "\"id\":1") - assert.Contains(t, string(res.Data), "\"b\":1") - assert.Contains(t, string(res.Data), "\"v1\":\"test-1\"") - assert.Contains(t, string(res.Data), "\"data\":\"{\\\"key\\\":\\\"val\\\"}\"") + assert.Contains(t, string(res.Data), `"id":1`) + assert.Contains(t, string(res.Data), `"b":1`) + assert.Contains(t, string(res.Data), `"v1":"test-1"`) + assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`) - result := make([]interface{}, 0) + result := make([]any, 0) err = json.Unmarshal(res.Data, &result) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 3, len(result)) // verify timestamp - ts, ok := result[0].(map[string]interface{})["ts"].(string) + ts, ok := result[0].(map[string]any)["ts"].(string) assert.True(t, ok) // have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615 var tt time.Time tt, err = time.Parse("2006-01-02T15:04:05Z", ts) - assert.Nil(t, err) + require.NoError(t, err) t.Logf("time stamp is: %v", tt) }) - t.Run("Invoke select JSON_EXTRACT", func(t *testing.T) { - req.Operation = queryOperation - req.Metadata[commandSQLKey] = testSelectJSONExtract - res, err := b.Invoke(context.TODO(), req) + t.Run("Invoke select with parameters", func(t *testing.T) { + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: queryOperation, + Metadata: map[string]string{ + commandSQLKey: "SELECT * FROM foo WHERE id = ?", + commandParamsKey: `[1]`, + }, + }) assertResponse(t, res, err) t.Logf("received result: %s", res.Data) - // verify json extract number - assert.Contains(t, string(res.Data), "{\"key\":\"\\\"val\\\"\"}") + // verify number, boolean and string + assert.Contains(t, string(res.Data), `"id":1`) + assert.Contains(t, string(res.Data), `"b":1`) + assert.Contains(t, string(res.Data), `"v1":"test-1"`) + assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`) - result := make([]interface{}, 0) + result := make([]any, 0) err = json.Unmarshal(res.Data, &result) - assert.Nil(t, err) - assert.Equal(t, 3, len(result)) - }) - - t.Run("Invoke delete", func(t *testing.T) { - req.Operation = execOperation - req.Metadata[commandSQLKey] = testDelete - req.Data = nil - res, err := b.Invoke(context.TODO(), req) - assertResponse(t, res, err) + require.NoError(t, err) + assert.Equal(t, 1, len(result)) }) t.Run("Invoke drop", func(t *testing.T) { - req.Operation = execOperation - req.Metadata[commandSQLKey] = testDropTable - res, err := b.Invoke(context.TODO(), req) + res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{ + commandSQLKey: "DROP TABLE foo", + }, + }) assertResponse(t, res, err) }) t.Run("Invoke close", func(t *testing.T) { - req.Operation = closeOperation - req.Metadata = nil - req.Data = nil - _, err := b.Invoke(context.TODO(), req) + _, err := b.Invoke(context.Background(), &bindings.InvokeRequest{ + Operation: closeOperation, + }) assert.NoError(t, err) }) } func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { + t.Helper() + assert.NoError(t, err) assert.NotNil(t, res) if res != nil { - assert.NotNil(t, res.Metadata) + assert.NotEmpty(t, res.Metadata) } } diff --git a/bindings/mysql/mysql_test.go b/bindings/mysql/mysql_test.go index a17c151b12..37c8f8ce5e 100644 --- a/bindings/mysql/mysql_test.go +++ b/bindings/mysql/mysql_test.go @@ -42,7 +42,7 @@ func TestQuery(t *testing.T) { assert.Nil(t, err) t.Logf("query result: %s", ret) assert.Contains(t, string(ret), "\"id\":1") - var result []interface{} + var result []any err = json.Unmarshal(ret, &result) assert.Nil(t, err) assert.Equal(t, 3, len(result)) @@ -65,13 +65,13 @@ func TestQuery(t *testing.T) { assert.Contains(t, string(ret), "\"id\":1") assert.Contains(t, string(ret), "\"value\":2.2") - var result []interface{} + var result []any err = json.Unmarshal(ret, &result) assert.Nil(t, err) assert.Equal(t, 3, len(result)) // verify timestamp - ts, ok := result[0].(map[string]interface{})["timestamp"].(string) + ts, ok := result[0].(map[string]any)["timestamp"].(string) assert.True(t, ok) var tt time.Time tt, err = time.Parse(time.RFC3339, ts) @@ -134,7 +134,7 @@ func TestInvoke(t *testing.T) { } resp, err := m.Invoke(context.Background(), req) assert.Nil(t, err) - var data []interface{} + var data []any err = json.Unmarshal(resp.Data, &data) assert.Nil(t, err) assert.Equal(t, 1, len(data)) From a4012953ea92db109c8721203c095856f86c98af Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 12 Jul 2023 15:03:18 -0700 Subject: [PATCH 4/6] Add Azure AD support to Postgres configuration store and bindings (#2971) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> --- .../azure/setup-azure-conf-test.sh | 12 ++ .../docker-compose-postgresql.yml | 2 +- .github/scripts/test-info.mjs | 56 ++++++++- bindings/postgres/metadata.go | 52 ++++++++ bindings/postgres/metadata.yaml | 54 ++++++++- bindings/postgres/postgres.go | 69 ++++++----- bindings/postgres/postgres_test.go | 4 +- configuration/postgres/metadata.go | 51 +++++++- configuration/postgres/metadata.yaml | 59 +++++++-- configuration/postgres/postgres.go | 57 ++------- configuration/postgres/postgres_test.go | 18 +-- .../authentication/postgresql/metadata.go | 113 ++++++++++++++++++ internal/component/postgresql/metadata.go | 113 +++--------------- .../component/postgresql/metadata_test.go | 28 ++--- .../component/postgresql/postgresdbaccess.go | 23 ++-- state/postgresql/metadata.yaml | 2 +- .../components/standard/postgres.yaml | 6 +- .../bindings/postgres/docker-compose.yml | 4 +- .../bindings/postgres/postgres_test.go | 38 +++--- .../configuration/postgres/docker-compose.yml | 2 +- tests/certification/go.mod | 1 - tests/certification/go.sum | 2 - .../state/postgresql/docker-compose.yml | 2 +- tests/config/bindings/postgres/bindings.yml | 10 -- .../bindings/postgresql/azure/bindings.yml | 18 +++ .../bindings/postgresql/docker/bindings.yml | 11 ++ tests/config/bindings/tests.yml | 4 +- .../postgresql/azure/configstore.yml | 20 ++++ .../docker}/configstore.yml | 3 +- tests/config/configuration/tests.yml | 4 +- .../state/postgresql/azure/statestore.yml | 6 + tests/conformance/common.go | 11 +- .../configuration/configuration.go | 17 ++- .../utils/configupdater/postgres/postgres.go | 32 +++-- 34 files changed, 599 insertions(+), 305 deletions(-) create mode 100644 bindings/postgres/metadata.go create mode 100644 internal/authentication/postgresql/metadata.go delete mode 100644 tests/config/bindings/postgres/bindings.yml create mode 100644 tests/config/bindings/postgresql/azure/bindings.yml create mode 100644 tests/config/bindings/postgresql/docker/bindings.yml create mode 100644 tests/config/configuration/postgresql/azure/configstore.yml rename tests/config/configuration/{postgres => postgresql/docker}/configstore.yml (85%) diff --git a/.github/infrastructure/conformance/azure/setup-azure-conf-test.sh b/.github/infrastructure/conformance/azure/setup-azure-conf-test.sh index 8220f57fcd..ca553026bd 100755 --- a/.github/infrastructure/conformance/azure/setup-azure-conf-test.sh +++ b/.github/infrastructure/conformance/azure/setup-azure-conf-test.sh @@ -230,6 +230,9 @@ SQL_SERVER_DB_NAME_VAR_NAME="AzureSqlServerDbName" SQL_SERVER_CONNECTION_STRING_VAR_NAME="AzureSqlServerConnectionString" AZURE_DB_POSTGRES_CONNSTRING_VAR_NAME="AzureDBPostgresConnectionString" +AZURE_DB_POSTGRES_CLIENT_ID_VAR_NAME="AzureDBPostgresClientId" +AZURE_DB_POSTGRES_CLIENT_SECRET_VAR_NAME="AzureDBPostgresClientSecret" +AZURE_DB_POSTGRES_TENANT_ID_VAR_NAME="AzureDBPostgresTenantId" STORAGE_ACCESS_KEY_VAR_NAME="AzureBlobStorageAccessKey" STORAGE_ACCOUNT_VAR_NAME="AzureBlobStorageAccount" @@ -693,6 +696,15 @@ AZURE_DB_POSTGRES_CONNSTRING="host=${PREFIX}-conf-test-pg.postgres.database.azur echo export ${AZURE_DB_POSTGRES_CONNSTRING_VAR_NAME}=\"${AZURE_DB_POSTGRES_CONNSTRING}\" >> "${ENV_CONFIG_FILENAME}" az keyvault secret set --name "${AZURE_DB_POSTGRES_CONNSTRING_VAR_NAME}" --vault-name "${KEYVAULT_NAME}" --value "${AZURE_DB_POSTGRES_CONNSTRING}" +echo export ${AZURE_DB_POSTGRES_CLIENT_ID_VAR_NAME}=\"${SDK_AUTH_SP_APPID}\" >> "${ENV_CONFIG_FILENAME}" +az keyvault secret set --name "${AZURE_DB_POSTGRES_CLIENT_ID_VAR_NAME}" --vault-name "${KEYVAULT_NAME}" --value "${SDK_AUTH_SP_APPID}" + +echo export ${AZURE_DB_POSTGRES_CLIENT_SECRET_VAR_NAME}=\"${SDK_AUTH_SP_CLIENT_SECRET}\" >> "${ENV_CONFIG_FILENAME}" +az keyvault secret set --name "${AZURE_DB_POSTGRES_CLIENT_SECRET_VAR_NAME}" --vault-name "${KEYVAULT_NAME}" --value "${SDK_AUTH_SP_CLIENT_SECRET}" + +echo export ${AZURE_DB_POSTGRES_TENANT_ID_VAR_NAME}=\"${TENANT_ID}\" >> "${ENV_CONFIG_FILENAME}" +az keyvault secret set --name "${AZURE_DB_POSTGRES_TENANT_ID_VAR_NAME}" --vault-name "${KEYVAULT_NAME}" --value "${TENANT_ID}" + # ---------------------------------- # Populate Event Hubs test settings # ---------------------------------- diff --git a/.github/infrastructure/docker-compose-postgresql.yml b/.github/infrastructure/docker-compose-postgresql.yml index 6819464d45..d11b2d537f 100644 --- a/.github/infrastructure/docker-compose-postgresql.yml +++ b/.github/infrastructure/docker-compose-postgresql.yml @@ -1,7 +1,7 @@ version: '2' services: db: - image: postgres:15 + image: postgres:15-alpine restart: always ports: - "5432:5432" diff --git a/.github/scripts/test-info.mjs b/.github/scripts/test-info.mjs index 8bfeb0187f..1c4468af26 100644 --- a/.github/scripts/test-info.mjs +++ b/.github/scripts/test-info.mjs @@ -167,9 +167,28 @@ const components = { sourcePkg: ['bindings/mqtt3'], }, 'bindings.postgres': { - conformance: true, certification: true, + }, + 'bindings.postgresql.docker': { + conformance: true, conformanceSetup: 'docker-compose.sh postgresql', + sourcePkg: [ + 'bindings/postgresql', + 'internal/authentication/postgresql', + ], + }, + 'bindings.postgresql.azure': { + conformance: true, + requiredSecrets: [ + 'AzureDBPostgresConnectionString', + 'AzureDBPostgresClientId', + 'AzureDBPostgresClientSecret', + 'AzureDBPostgresTenantId', + ], + sourcePkg: [ + 'bindings/postgresql', + 'internal/authentication/postgresql', + ], }, 'bindings.rabbitmq': { conformance: true, @@ -191,9 +210,32 @@ const components = { sourcePkg: ['bindings/redis', 'internal/component/redis'], }, 'configuration.postgres': { - conformance: true, certification: true, + sourcePkg: [ + 'configuration/postgresql', + 'internal/authentication/postgresql', + ], + }, + 'configuration.postgresql.docker': { + conformance: true, conformanceSetup: 'docker-compose.sh postgresql', + sourcePkg: [ + 'configuration/postgresql', + 'internal/authentication/postgresql', + ], + }, + 'configuration.postgresql.azure': { + conformance: true, + requiredSecrets: [ + 'AzureDBPostgresConnectionString', + 'AzureDBPostgresClientId', + 'AzureDBPostgresClientSecret', + 'AzureDBPostgresTenantId', + ], + sourcePkg: [ + 'configuration/postgresql', + 'internal/authentication/postgresql', + ], }, 'configuration.redis.v6': { conformance: true, @@ -585,6 +627,7 @@ const components = { certification: true, sourcePkg: [ 'state/postgresql', + 'internal/authentication/postgresql', 'internal/component/postgresql', 'internal/component/sql', ], @@ -594,15 +637,22 @@ const components = { conformanceSetup: 'docker-compose.sh postgresql', sourcePkg: [ 'state/postgresql', + 'internal/authentication/postgresql', 'internal/component/postgresql', 'internal/component/sql', ], }, 'state.postgresql.azure': { conformance: true, - requiredSecrets: ['AzureDBPostgresConnectionString'], + requiredSecrets: [ + 'AzureDBPostgresConnectionString', + 'AzureDBPostgresClientId', + 'AzureDBPostgresClientSecret', + 'AzureDBPostgresTenantId', + ], sourcePkg: [ 'state/postgresql', + 'internal/authentication/postgresql', 'internal/component/postgresql', 'internal/component/sql', ], diff --git a/bindings/postgres/metadata.go b/bindings/postgres/metadata.go new file mode 100644 index 0000000000..281b2c593d --- /dev/null +++ b/bindings/postgres/metadata.go @@ -0,0 +1,52 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgres + +import ( + pgauth "github.com/dapr/components-contrib/internal/authentication/postgresql" + contribMetadata "github.com/dapr/components-contrib/metadata" +) + +type psqlMetadata struct { + pgauth.PostgresAuthMetadata `mapstructure:",squash"` + + // URL is the connection string to connect to the database. + // Deprecated alias: use connectionString instead. + URL string `mapstructure:"url"` +} + +func (m *psqlMetadata) InitWithMetadata(meta map[string]string) error { + // Reset the object + m.PostgresAuthMetadata.Reset() + m.URL = "" + + err := contribMetadata.DecodeMetadata(meta, &m) + if err != nil { + return err + } + + // Legacy options + if m.ConnectionString == "" && m.URL != "" { + m.ConnectionString = m.URL + } + + // Validate and sanitize input + // Azure AD auth is supported for this component + err = m.PostgresAuthMetadata.InitWithMetadata(meta, true) + if err != nil { + return err + } + + return nil +} diff --git a/bindings/postgres/metadata.yaml b/bindings/postgres/metadata.yaml index 1d771523a3..01341e690b 100644 --- a/bindings/postgres/metadata.yaml +++ b/bindings/postgres/metadata.yaml @@ -19,17 +19,33 @@ binding: description: "The query operation is used for SELECT statements, which return both the metadata and the retrieved data in a form of an array of row values." - name: close description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesn't have any response." +builtinAuthenticationProfiles: + - name: "azuread" + metadata: + - name: useAzureAD + required: true + type: bool + example: '"true"' + description: | + Must be set to `true` to enable the component to retrieve access tokens from Azure AD. + This authentication method only works with Azure Database for PostgreSQL databases. + - name: connectionString + required: true + sensitive: true + description: | + The connection string for the PostgreSQL database + This must contain the user, which corresponds to the name of the user created inside PostgreSQL that maps to the Azure AD identity; this is often the name of the corresponding principal (e.g. the name of the Azure AD application). This connection string should not contain any password. + example: | + "host=mydb.postgres.database.azure.com user=myapplication port=5432 database=dapr_test sslmode=require" + type: string authenticationProfiles: - title: "Connection string" - description: "Authenticate using a Connection String." + description: "Authenticate using a Connection String" metadata: - - name: url + - name: connectionString required: true sensitive: true - binding: - input: false - output: true - description: "Connection string for PostgreSQL." + description: "The connection string for the PostgreSQL database" url: title: More details url: https://docs.dapr.io/reference/components-reference/supported-bindings/postgres/#url-format @@ -37,3 +53,29 @@ authenticationProfiles: "user=dapr password=secret host=dapr.example.com port=5432 dbname=dapr sslmode=verify-ca" or "postgres://dapr:secret@dapr.example.com:5432/dapr?sslmode=verify-ca" type: string +metadata: + - name: maxConns + required: false + description: | + Maximum number of connections pooled by this component. + Set to 0 or lower to use the default value, which is the greater of 4 or the number of CPUs. + example: "4" + default: "0" + type: number + - name: connectionMaxIdleTime + required: false + description: | + Max idle time before unused connections are automatically closed in the + connection pool. By default, there's no value and this is left to the + database driver to choose. + example: "5m" + type: duration + - name: url + deprecated: true + required: false + description: | + Deprecated alias for "connectionString" + type: string + sensitive: true + example: | + "user=dapr password=secret host=dapr.example.com port=5432 dbname=dapr sslmode=verify-ca" diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index 7c03e3bb7a..9919f9b1d0 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -20,6 +20,7 @@ import ( "fmt" "reflect" "strconv" + "sync/atomic" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -35,38 +36,36 @@ const ( queryOperation bindings.OperationKind = "query" closeOperation bindings.OperationKind = "close" - connectionURLKey = "url" - commandSQLKey = "sql" + commandSQLKey = "sql" ) // Postgres represents PostgreSQL output binding. type Postgres struct { logger logger.Logger db *pgxpool.Pool -} - -type psqlMetadata struct { - // ConnectionURL is the connection string to connect to the database. - ConnectionURL string `mapstructure:"url"` + closed atomic.Bool } // NewPostgres returns a new PostgreSQL output binding. func NewPostgres(logger logger.Logger) bindings.OutputBinding { - return &Postgres{logger: logger} + return &Postgres{ + logger: logger, + } } // Init initializes the PostgreSql binding. func (p *Postgres) Init(ctx context.Context, meta bindings.Metadata) error { + if p.closed.Load() { + return errors.New("cannot initialize a previously-closed component") + } + m := psqlMetadata{} - err := metadata.DecodeMetadata(meta.Properties, &m) + err := m.InitWithMetadata(meta.Properties) if err != nil { return err } - if m.ConnectionURL == "" { - return fmt.Errorf("required metadata not set: %s", connectionURLKey) - } - poolConfig, err := pgxpool.ParseConfig(m.ConnectionURL) + poolConfig, err := m.GetPgxPoolConfig() if err != nil { return fmt.Errorf("error opening DB connection: %w", err) } @@ -75,7 +74,7 @@ func (p *Postgres) Init(ctx context.Context, meta bindings.Metadata) error { // only scoped to postgres creating resources at init. p.db, err = pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { - return fmt.Errorf("unable to ping the DB: %w", err) + return fmt.Errorf("unable to connect to the DB: %w", err) } return nil @@ -96,16 +95,19 @@ func (p *Postgres) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res return nil, errors.New("invoke request required") } + // We let the "close" operation here succeed even if the component has been closed already if req.Operation == closeOperation { - p.db.Close() + err = p.Close() + return nil, err + } - return nil, nil + if p.closed.Load() { + return nil, errors.New("component is closed") } if req.Metadata == nil { return nil, errors.New("metadata required") } - p.logger.Debugf("operation: %v", req.Operation) sql, ok := req.Metadata[commandSQLKey] if !ok || sql == "" { @@ -125,14 +127,14 @@ func (p *Postgres) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res case execOperation: r, err := p.exec(ctx, sql) if err != nil { - return nil, fmt.Errorf("error executing %s: %w", sql, err) + return nil, err } resp.Metadata["rows-affected"] = strconv.FormatInt(r, 10) // 0 if error case queryOperation: d, err := p.query(ctx, sql) if err != nil { - return nil, fmt.Errorf("error executing %s: %w", sql, err) + return nil, err } resp.Data = d @@ -152,17 +154,21 @@ func (p *Postgres) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res // Close close PostgreSql instance. func (p *Postgres) Close() error { - if p.db == nil { + if !p.closed.CompareAndSwap(false, true) { + // If this failed, the component has already been closed + // We allow multiple calls to close return nil } - p.db.Close() + + if p.db != nil { + p.db.Close() + } + p.db = nil return nil } func (p *Postgres) query(ctx context.Context, sql string) (result []byte, err error) { - p.logger.Debugf("query: %s", sql) - rows, err := p.db.Query(ctx, sql) if err != nil { return nil, fmt.Errorf("error executing query: %w", err) @@ -172,29 +178,26 @@ func (p *Postgres) query(ctx context.Context, sql string) (result []byte, err er for rows.Next() { val, rowErr := rows.Values() if rowErr != nil { - return nil, fmt.Errorf("error parsing result '%v': %w", rows.Err(), rowErr) + return nil, fmt.Errorf("error reading result '%v': %w", rows.Err(), rowErr) } rs = append(rs, val) //nolint:asasalint } - if result, err = json.Marshal(rs); err != nil { - err = fmt.Errorf("error serializing results: %w", err) + result, err = json.Marshal(rs) + if err != nil { + return nil, fmt.Errorf("error serializing results: %w", err) } - return + return result, nil } func (p *Postgres) exec(ctx context.Context, sql string) (result int64, err error) { - p.logger.Debugf("exec: %s", sql) - res, err := p.db.Exec(ctx, sql) if err != nil { return 0, fmt.Errorf("error executing query: %w", err) } - result = res.RowsAffected() - - return + return res.RowsAffected(), nil } // GetComponentMetadata returns the metadata of the component. diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go index 78adb43ab0..392fe2a6ad 100644 --- a/bindings/postgres/postgres_test.go +++ b/bindings/postgres/postgres_test.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -63,7 +63,7 @@ func TestPostgresIntegration(t *testing.T) { // live DB test b := NewPostgres(logger.NewLogger("test")).(*Postgres) - m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} + m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"connectionString": url}}} if err := b.Init(context.Background(), m); err != nil { t.Fatal(err) } diff --git a/configuration/postgres/metadata.go b/configuration/postgres/metadata.go index 0bf10cb742..b53a5b2d3c 100644 --- a/configuration/postgres/metadata.go +++ b/configuration/postgres/metadata.go @@ -13,10 +13,53 @@ limitations under the License. package postgres -import "time" +import ( + "fmt" + "time" + + pgauth "github.com/dapr/components-contrib/internal/authentication/postgresql" + contribMetadata "github.com/dapr/components-contrib/metadata" +) type metadata struct { - MaxIdleTimeout time.Duration `mapstructure:"connMaxIdleTime"` - ConnectionString string `mapstructure:"connectionString"` - ConfigTable string `mapstructure:"table"` + pgauth.PostgresAuthMetadata `mapstructure:",squash"` + + ConfigTable string `mapstructure:"table"` + MaxIdleTimeoutOld time.Duration `mapstructure:"connMaxIdleTime"` // Deprecated alias for "connectionMaxIdleTime" +} + +func (m *metadata) InitWithMetadata(meta map[string]string) error { + // Reset the object + m.PostgresAuthMetadata.Reset() + m.ConfigTable = "" + m.MaxIdleTimeoutOld = 0 + + err := contribMetadata.DecodeMetadata(meta, &m) + if err != nil { + return err + } + + // Legacy options + if m.ConnectionMaxIdleTime == 0 && m.MaxIdleTimeoutOld > 0 { + m.ConnectionMaxIdleTime = m.MaxIdleTimeoutOld + } + + // Validate and sanitize input + if m.ConfigTable == "" { + return fmt.Errorf("missing postgreSQL configuration table name") + } + if len(m.ConfigTable) > maxIdentifierLength { + return fmt.Errorf("table name is too long - tableName : '%s'. max allowed field length is %d", m.ConfigTable, maxIdentifierLength) + } + if !allowedTableNameChars.MatchString(m.ConfigTable) { + return fmt.Errorf("invalid table name '%s'. non-alphanumerics or upper cased table names are not supported", m.ConfigTable) + } + + // Azure AD auth is supported for this component + err = m.PostgresAuthMetadata.InitWithMetadata(meta, true) + if err != nil { + return err + } + + return nil } diff --git a/configuration/postgres/metadata.yaml b/configuration/postgres/metadata.yaml index 16e411c8af..47fba4a0f4 100644 --- a/configuration/postgres/metadata.yaml +++ b/configuration/postgres/metadata.yaml @@ -1,14 +1,33 @@ # yaml-language-server: $schema=../../component-metadata-schema.json schemaVersion: v1 type: configuration -name: postgres +name: postgresql version: v1 -status: alpha -title: "Postgres" +status: stable +title: "PostgreSQL" urls: - title: Reference - url: https://docs.dapr.io/reference/components-reference/supported-configuration-stores/postgres-configuration-store/ + url: https://docs.dapr.io/reference/components-reference/supported-configuration-stores/postgresql-configuration-store/ capabilities: [] +builtinAuthenticationProfiles: + - name: "azuread" + metadata: + - name: useAzureAD + required: true + type: bool + example: '"true"' + description: | + Must be set to `true` to enable the component to retrieve access tokens from Azure AD. + This authentication method only works with Azure Database for PostgreSQL databases. + - name: connectionString + required: true + sensitive: true + description: | + The connection string for the PostgreSQL database + This must contain the user, which corresponds to the name of the user created inside PostgreSQL that maps to the Azure AD identity; this is often the name of the corresponding principal (e.g. the name of the Azure AD application). This connection string should not contain any password. + example: | + "host=mydb.postgres.database.azure.com user=myapplication port=5432 database=dapr_test sslmode=require" + type: string authenticationProfiles: - title: "Connection string" description: "Authenticate using a Connection String." @@ -16,10 +35,9 @@ authenticationProfiles: - name: connectionString required: true sensitive: true - description: | - The connection string for PostgreSQL, as a URL or DSN. - Note: the default value for `pool_max_conns` is 5. - example: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test pool_max_conns=10" + description: The connection string for the PostgreSQL database + example: | + "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" type: string metadata: - name: table @@ -27,9 +45,26 @@ metadata: description: The table name for configuration information. example: "configTable" type: string - - name: connMaxIdleTime + - name: connectionMaxIdleTime required: false - description: The maximum amount of time a connection may be idle. - example: "15s" - default: "30s" + description: | + Max idle time before unused connections are automatically closed in the + connection pool. By default, there's no value and this is left to the + database driver to choose. + example: "5m" type: duration + - name: maxConns + required: false + description: | + Maximum number of connections pooled by this component. + Set to 0 or lower to use the default value, which is the greater of 4 or the number of CPUs. + example: "4" + default: "0" + type: number + - name: connMaxIdleTime + deprecated: true + required: false + description: | + Deprecated alias for 'connectionMaxIdleTime'. + example: "5m" + type: duration \ No newline at end of file diff --git a/configuration/postgres/postgres.go b/configuration/postgres/postgres.go index f948ac0ed7..3a072da851 100644 --- a/configuration/postgres/postgres.go +++ b/configuration/postgres/postgres.go @@ -23,13 +23,12 @@ import ( "strconv" "strings" "sync" - "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" - "k8s.io/utils/strings/slices" + "golang.org/x/exp/slices" "github.com/dapr/components-contrib/configuration" contribMetadata "github.com/dapr/components-contrib/metadata" @@ -62,34 +61,29 @@ const ( ) var ( - allowedChars = regexp.MustCompile(`^[a-zA-Z0-9./_]*$`) - allowedTableNameChars = regexp.MustCompile(`^[a-z0-9./_]*$`) - defaultMaxConnIdleTime = time.Second * 30 + allowedChars = regexp.MustCompile(`^[a-zA-Z0-9./_]*$`) + allowedTableNameChars = regexp.MustCompile(`^[a-z0-9./_]*$`) ) func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store { - logger.Debug("Instantiating PostgreSQL configuration store") return &ConfigurationStore{ logger: logger, subscribeStopChanMap: make(map[string]chan struct{}), } } -func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error { - if m, err := parseMetadata(metadata); err != nil { +func (p *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error { + err := p.metadata.InitWithMetadata(metadata.Properties) + if err != nil { p.logger.Error(err) return err - } else { - p.metadata = m } + p.ActiveSubscriptions = make(map[string]*subscription) - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.MaxIdleTimeout) - defer cancel() - client, err := Connect(ctx, p.metadata.ConnectionString, p.metadata.MaxIdleTimeout) + p.client, err = p.connectDB(ctx) if err != nil { return fmt.Errorf("error connecting to configuration store: '%w'", err) } - p.client = client err = p.client.Ping(ctx) if err != nil { return fmt.Errorf("unable to connect to configuration store: '%w'", err) @@ -180,7 +174,7 @@ func (p *ConfigurationStore) Subscribe(ctx context.Context, req *configuration.S } } if pgNotifyChannel == "" { - return "", fmt.Errorf("unable to subscribe to '%s'.pgNotifyChannel attribute cannot be empty", p.metadata.ConfigTable) + return "", fmt.Errorf("unable to subscribe to '%s'. pgNotifyChannel attribute cannot be empty", p.metadata.ConfigTable) } return p.subscribeToChannel(ctx, pgNotifyChannel, req, handler) } @@ -290,37 +284,8 @@ func (p *ConfigurationStore) handleSubscribedChange(ctx context.Context, handler } } -func parseMetadata(cmetadata configuration.Metadata) (metadata, error) { - m := metadata{ - MaxIdleTimeout: defaultMaxConnIdleTime, - } - decodeErr := contribMetadata.DecodeMetadata(cmetadata.Properties, &m) - if decodeErr != nil { - return m, decodeErr - } - - if m.ConnectionString == "" { - return m, fmt.Errorf("missing postgreSQL connection string") - } - - if m.ConfigTable != "" { - if !allowedTableNameChars.MatchString(m.ConfigTable) { - return m, fmt.Errorf("invalid table name '%s'. non-alphanumerics or upper cased table names are not supported", m.ConfigTable) - } - if len(m.ConfigTable) > maxIdentifierLength { - return m, fmt.Errorf("table name is too long - tableName : '%s'. max allowed field length is %d", m.ConfigTable, maxIdentifierLength) - } - } else { - return m, fmt.Errorf("missing postgreSQL configuration table name") - } - if m.MaxIdleTimeout <= 0 { - m.MaxIdleTimeout = defaultMaxConnIdleTime - } - return m, nil -} - -func Connect(ctx context.Context, conn string, maxTimeout time.Duration) (*pgxpool.Pool, error) { - config, err := pgxpool.ParseConfig(conn) +func (p *ConfigurationStore) connectDB(ctx context.Context) (*pgxpool.Pool, error) { + config, err := p.metadata.GetPgxPoolConfig() if err != nil { return nil, fmt.Errorf("postgres configuration store connection error : %w", err) } diff --git a/configuration/postgres/postgres_test.go b/configuration/postgres/postgres_test.go index 4a9666582e..74fb36a787 100644 --- a/configuration/postgres/postgres_test.go +++ b/configuration/postgres/postgres_test.go @@ -20,8 +20,10 @@ import ( "github.com/pashagolub/pgxmock/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/configuration" + pgauth "github.com/dapr/components-contrib/internal/authentication/postgresql" ) func TestSelectAllQuery(t *testing.T) { @@ -43,7 +45,7 @@ func TestSelectAllQuery(t *testing.T) { if err != nil { t.Errorf("Error building query: %v ", err) } - assert.Nil(t, err, "Error building query: %v ", err) + assert.NoError(t, err, "Error building query: %v ", err) assert.Equal(t, expected, query, "did not get expected result. Got: '%v' , Expected: '%v'", query, expected) } @@ -57,7 +59,7 @@ func TestPostgresbuildQuery(t *testing.T) { query, params, err := buildQuery(g, "cfgtbl") _ = params - assert.Nil(t, err, "Error building query: %v ", err) + assert.NoError(t, err, "Error building query: %v ", err) expected := "SELECT * FROM cfgtbl WHERE KEY IN ($1) AND $2 = $3" assert.Equal(t, expected, query, "did not get expected result. Got: '%v' , Expected: '%v'", query, expected) i := 0 @@ -80,12 +82,14 @@ func TestPostgresbuildQuery(t *testing.T) { func TestConnectAndQuery(t *testing.T) { m := metadata{ - ConnectionString: "mockConnectionString", - ConfigTable: "mockConfigTable", + PostgresAuthMetadata: pgauth.PostgresAuthMetadata{ + ConnectionString: "mockConnectionString", + }, + ConfigTable: "mockConfigTable", } mock, err := pgxmock.NewPool() - assert.Nil(t, err) + require.NoError(t, err) defer mock.Close() query := "SELECT EXISTS (SELECT FROM pg_tables where tablename = '" + m.ConfigTable + "'" @@ -97,9 +101,9 @@ func TestConnectAndQuery(t *testing.T) { rows := mock.QueryRow(context.Background(), query) var id string err = rows.Scan(&id) - assert.Nil(t, err, "error in scan") + assert.NoError(t, err, "error in scan") err = mock.ExpectationsWereMet() - assert.Nil(t, err, "pgxmock error in expectations were met") + assert.NoError(t, err, "pgxmock error in expectations were met") } func TestValidateInput(t *testing.T) { diff --git a/internal/authentication/postgresql/metadata.go b/internal/authentication/postgresql/metadata.go new file mode 100644 index 0000000000..66f67fef1f --- /dev/null +++ b/internal/authentication/postgresql/metadata.go @@ -0,0 +1,113 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgresql + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/dapr/components-contrib/internal/authentication/azure" +) + +// PostgresAuthMetadata contains authentication metadata for PostgreSQL components. +type PostgresAuthMetadata struct { + ConnectionString string `mapstructure:"connectionString"` + ConnectionMaxIdleTime time.Duration `mapstructure:"connectionMaxIdleTime"` + MaxConns int `mapstructure:"maxConns"` + UseAzureAD bool `mapstructure:"useAzureAD"` + + azureEnv azure.EnvironmentSettings +} + +// Reset the object. +func (m *PostgresAuthMetadata) Reset() { + m.ConnectionString = "" + m.ConnectionMaxIdleTime = 0 + m.MaxConns = 0 + m.UseAzureAD = false +} + +// InitWithMetadata inits the object with metadata from the user. +// Set azureADEnabled to true if the component can support authentication with Azure AD. +// This is different from the "useAzureAD" property from the user, which is provided by the user and instructs the component to authenticate using Azure AD. +func (m *PostgresAuthMetadata) InitWithMetadata(meta map[string]string, azureADEnabled bool) (err error) { + // Validate input + if m.ConnectionString == "" { + return errors.New("missing connection string") + } + + // Populate the Azure environment if using Azure AD + if azureADEnabled && m.UseAzureAD { + m.azureEnv, err = azure.NewEnvironmentSettings(meta) + if err != nil { + return err + } + } else { + // Make sure this is false + m.UseAzureAD = false + } + + return nil +} + +// GetPgxPoolConfig returns the pgxpool.Config object that contains the credentials for connecting to PostgreSQL. +func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) { + // Get the config from the connection string + config, err := pgxpool.ParseConfig(m.ConnectionString) + if err != nil { + return nil, fmt.Errorf("failed to parse connection string: %w", err) + } + if m.ConnectionMaxIdleTime > 0 { + config.MaxConnIdleTime = m.ConnectionMaxIdleTime + } + if m.MaxConns > 1 { + config.MaxConns = int32(m.MaxConns) + } + + // Check if we should use Azure AD + if m.UseAzureAD { + tokenCred, errToken := m.azureEnv.GetTokenCredential() + if errToken != nil { + return nil, errToken + } + + // Reset the password + config.ConnConfig.Password = "" + + // We need to retrieve the token every time we attempt a new connection + // This is because tokens expire, and connections can drop and need to be re-established at any time + // Fortunately, we can do this with the "BeforeConnect" hook + config.BeforeConnect = func(ctx context.Context, cc *pgx.ConnConfig) error { + at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{ + m.azureEnv.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default", + }, + }) + if err != nil { + return err + } + + cc.Password = at.Token + return nil + } + } + + return config, nil +} diff --git a/internal/component/postgresql/metadata.go b/internal/component/postgresql/metadata.go index fb10bed6a0..a4dd06f076 100644 --- a/internal/component/postgresql/metadata.go +++ b/internal/component/postgresql/metadata.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -14,15 +14,10 @@ limitations under the License. package postgresql import ( - "context" "fmt" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/dapr/components-contrib/internal/authentication/azure" + pgauth "github.com/dapr/components-contrib/internal/authentication/postgresql" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/ptr" @@ -39,24 +34,17 @@ const ( ) type postgresMetadataStruct struct { - ConnectionString string `mapstructure:"connectionString"` - ConnectionMaxIdleTime time.Duration `mapstructure:"connectionMaxIdleTime"` - TableName string `mapstructure:"tableName"` // Could be in the format "schema.table" or just "table" - MetadataTableName string `mapstructure:"metadataTableName"` // Could be in the format "schema.table" or just "table" - Timeout time.Duration `mapstructure:"timeoutInSeconds"` - CleanupInterval *time.Duration `mapstructure:"cleanupIntervalInSeconds"` - MaxConns int `mapstructure:"maxConns"` - UseAzureAD bool `mapstructure:"useAzureAD"` + pgauth.PostgresAuthMetadata `mapstructure:",squash"` - // Set to true if the component can support authentication with Azure AD. - // This is different from the "useAzureAD" property above, which is provided by the user and instructs the component to authenticate using Azure AD. - azureADEnabled bool - azureEnv azure.EnvironmentSettings + TableName string `mapstructure:"tableName"` // Could be in the format "schema.table" or just "table" + MetadataTableName string `mapstructure:"metadataTableName"` // Could be in the format "schema.table" or just "table" + Timeout time.Duration `mapstructure:"timeoutInSeconds"` + CleanupInterval *time.Duration `mapstructure:"cleanupIntervalInSeconds"` } -func (m *postgresMetadataStruct) InitWithMetadata(meta state.Metadata) error { +func (m *postgresMetadataStruct) InitWithMetadata(meta state.Metadata, azureADEnabled bool) error { // Reset the object - m.ConnectionString = "" + m.PostgresAuthMetadata.Reset() m.TableName = defaultTableName m.MetadataTableName = defaultMetadataTableName m.CleanupInterval = ptr.Of(defaultCleanupInternal * time.Second) @@ -69,8 +57,9 @@ func (m *postgresMetadataStruct) InitWithMetadata(meta state.Metadata) error { } // Validate and sanitize input - if m.ConnectionString == "" { - return errMissingConnectionString + err = m.PostgresAuthMetadata.InitWithMetadata(meta.Properties, azureADEnabled) + if err != nil { + return err } // Timeout @@ -79,79 +68,15 @@ func (m *postgresMetadataStruct) InitWithMetadata(meta state.Metadata) error { } // Cleanup interval - if m.CleanupInterval != nil { - // Non-positive value from meta means disable auto cleanup. - if *m.CleanupInterval <= 0 { - if meta.Properties[cleanupIntervalKey] == "" { - // unfortunately the mapstructure decoder decodes an empty string to 0, a missing key would be nil however - m.CleanupInterval = ptr.Of(defaultCleanupInternal * time.Second) - } else { - m.CleanupInterval = nil - } - } - } - - // Populate the Azure environment if using Azure AD - if m.azureADEnabled && m.UseAzureAD { - m.azureEnv, err = azure.NewEnvironmentSettings(meta.Properties) - if err != nil { - return err + // Non-positive value from meta means disable auto cleanup. + if m.CleanupInterval != nil && *m.CleanupInterval <= 0 { + if meta.Properties[cleanupIntervalKey] == "" { + // Unfortunately the mapstructure decoder decodes an empty string to 0, a missing key would be nil however + m.CleanupInterval = ptr.Of(defaultCleanupInternal * time.Second) + } else { + m.CleanupInterval = nil } } return nil } - -// GetPgxPoolConfig returns the pgxpool.Config object that contains the credentials for connecting to Postgres. -func (m *postgresMetadataStruct) GetPgxPoolConfig() (*pgxpool.Config, error) { - // Get the config from the connection string - config, err := pgxpool.ParseConfig(m.ConnectionString) - if err != nil { - return nil, fmt.Errorf("failed to parse connection string: %w", err) - } - if m.ConnectionMaxIdleTime > 0 { - config.MaxConnIdleTime = m.ConnectionMaxIdleTime - } - if m.MaxConns > 1 { - config.MaxConns = int32(m.MaxConns) - } - - // Check if we should use Azure AD - if m.azureADEnabled && m.UseAzureAD { - tokenCred, errToken := m.azureEnv.GetTokenCredential() - if errToken != nil { - return nil, errToken - } - - // Reset the password - config.ConnConfig.Password = "" - - /*// For Azure AD, using SSL is required - // If not already enabled, configure TLS without certificate validation - if config.ConnConfig.TLSConfig == nil { - config.ConnConfig.TLSConfig = &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, - } - }*/ - - // We need to retrieve the token every time we attempt a new connection - // This is because tokens expire, and connections can drop and need to be re-established at any time - // Fortunately, we can do this with the "BeforeConnect" hook - config.BeforeConnect = func(ctx context.Context, cc *pgx.ConnConfig) error { - at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{ - Scopes: []string{ - m.azureEnv.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default", - }, - }) - if err != nil { - return err - } - - cc.Password = at.Token - return nil - } - } - - return config, nil -} diff --git a/internal/component/postgresql/metadata_test.go b/internal/component/postgresql/metadata_test.go index 9d56f0ab34..bb4b612f18 100644 --- a/internal/component/postgresql/metadata_test.go +++ b/internal/component/postgresql/metadata_test.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -28,9 +28,9 @@ func TestMetadata(t *testing.T) { m := postgresMetadataStruct{} props := map[string]string{} - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.Error(t, err) - assert.ErrorIs(t, err, errMissingConnectionString) + assert.ErrorContains(t, err, "connection string") }) t.Run("has connection string", func(t *testing.T) { @@ -39,7 +39,7 @@ func TestMetadata(t *testing.T) { "connectionString": "foo", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) }) @@ -49,7 +49,7 @@ func TestMetadata(t *testing.T) { "connectionString": "foo", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) assert.Equal(t, m.TableName, defaultTableName) }) @@ -61,7 +61,7 @@ func TestMetadata(t *testing.T) { "tableName": "mytable", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) assert.Equal(t, m.TableName, "mytable") }) @@ -72,7 +72,7 @@ func TestMetadata(t *testing.T) { "connectionString": "foo", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) assert.Equal(t, defaultTimeout*time.Second, m.Timeout) }) @@ -84,7 +84,7 @@ func TestMetadata(t *testing.T) { "timeoutInSeconds": "NaN", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.Error(t, err) }) @@ -95,7 +95,7 @@ func TestMetadata(t *testing.T) { "timeoutInSeconds": "42", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) assert.Equal(t, 42*time.Second, m.Timeout) }) @@ -107,7 +107,7 @@ func TestMetadata(t *testing.T) { "timeoutInSeconds": "0", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.Error(t, err) }) @@ -117,7 +117,7 @@ func TestMetadata(t *testing.T) { "connectionString": "foo", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) _ = assert.NotNil(t, m.CleanupInterval) && assert.Equal(t, defaultCleanupInternal*time.Second, *m.CleanupInterval) @@ -130,7 +130,7 @@ func TestMetadata(t *testing.T) { "cleanupIntervalInSeconds": "NaN", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.Error(t, err) }) @@ -141,7 +141,7 @@ func TestMetadata(t *testing.T) { "cleanupIntervalInSeconds": "42", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) _ = assert.NotNil(t, m.CleanupInterval) && assert.Equal(t, 42*time.Second, *m.CleanupInterval) @@ -154,7 +154,7 @@ func TestMetadata(t *testing.T) { "cleanupIntervalInSeconds": "0", } - err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) assert.NoError(t, err) assert.Nil(t, m.CleanupInterval) }) diff --git a/internal/component/postgresql/postgresdbaccess.go b/internal/component/postgresql/postgresdbaccess.go index 5a5ec16e1b..431279569c 100644 --- a/internal/component/postgresql/postgresdbaccess.go +++ b/internal/component/postgresql/postgresdbaccess.go @@ -35,8 +35,6 @@ import ( "github.com/dapr/kit/ptr" ) -var errMissingConnectionString = errors.New("missing connection string") - // Interface that applies to *pgxpool.Pool. // We need this to be able to mock the connection in tests. type PGXPoolConn interface { @@ -57,9 +55,10 @@ type PostgresDBAccess struct { gc internalsql.GarbageCollector - migrateFn func(context.Context, PGXPoolConn, MigrateOptions) error - setQueryFn func(*state.SetRequest, SetQueryOptions) string - etagColumn string + migrateFn func(context.Context, PGXPoolConn, MigrateOptions) error + setQueryFn func(*state.SetRequest, SetQueryOptions) string + etagColumn string + enableAzureAD bool } // newPostgresDBAccess creates a new instance of postgresAccess. @@ -67,13 +66,11 @@ func newPostgresDBAccess(logger logger.Logger, opts Options) *PostgresDBAccess { logger.Debug("Instantiating new Postgres state store") return &PostgresDBAccess{ - logger: logger, - metadata: postgresMetadataStruct{ - azureADEnabled: opts.EnableAzureAD, - }, - migrateFn: opts.MigrateFn, - setQueryFn: opts.SetQueryFn, - etagColumn: opts.ETagColumn, + logger: logger, + migrateFn: opts.MigrateFn, + setQueryFn: opts.SetQueryFn, + etagColumn: opts.ETagColumn, + enableAzureAD: opts.EnableAzureAD, } } @@ -81,7 +78,7 @@ func newPostgresDBAccess(logger logger.Logger, opts Options) *PostgresDBAccess { func (p *PostgresDBAccess) Init(ctx context.Context, meta state.Metadata) error { p.logger.Debug("Initializing Postgres state store") - err := p.metadata.InitWithMetadata(meta) + err := p.metadata.InitWithMetadata(meta, p.enableAzureAD) if err != nil { p.logger.Errorf("Failed to parse metadata: %v", err) return err diff --git a/state/postgresql/metadata.yaml b/state/postgresql/metadata.yaml index 2d88cf6b6b..e8eee6fbb1 100644 --- a/state/postgresql/metadata.yaml +++ b/state/postgresql/metadata.yaml @@ -37,7 +37,7 @@ builtinAuthenticationProfiles: type: string authenticationProfiles: - title: "Connection string" - description: "Authenticate using a Connection String." + description: "Authenticate using a Connection String" metadata: - name: connectionString required: true diff --git a/tests/certification/bindings/postgres/components/standard/postgres.yaml b/tests/certification/bindings/postgres/components/standard/postgres.yaml index 0c9f368f74..ebfc129b32 100644 --- a/tests/certification/bindings/postgres/components/standard/postgres.yaml +++ b/tests/certification/bindings/postgres/components/standard/postgres.yaml @@ -3,8 +3,8 @@ kind: Component metadata: name: standard-binding spec: - type: bindings.postgres + type: bindings.postgresql version: v1 metadata: - - name: url - value: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" \ No newline at end of file + - name: connectionString + value: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" diff --git a/tests/certification/bindings/postgres/docker-compose.yml b/tests/certification/bindings/postgres/docker-compose.yml index 48a388d3f9..2d62bd1b27 100644 --- a/tests/certification/bindings/postgres/docker-compose.yml +++ b/tests/certification/bindings/postgres/docker-compose.yml @@ -1,11 +1,11 @@ version: '2' services: db: - image: postgres + image: postgres:15-alpine restart: always ports: - "5432:5432" environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: example - POSTGRES_DB: dapr_test \ No newline at end of file + POSTGRES_DB: dapr_test diff --git a/tests/certification/bindings/postgres/postgres_test.go b/tests/certification/bindings/postgres/postgres_test.go index ac825be467..6cc0052f98 100644 --- a/tests/certification/bindings/postgres/postgres_test.go +++ b/tests/certification/bindings/postgres/postgres_test.go @@ -19,7 +19,8 @@ import ( "testing" "time" - _ "github.com/lib/pq" + // PGX driver for database/sql + _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,6 +47,7 @@ const ( ) func TestPostgres(t *testing.T) { + const tableName = "dapr_test_table" ports, _ := dapr_testing.GetFreePorts(3) grpcPort := ports[0] @@ -59,7 +61,7 @@ func TestPostgres(t *testing.T) { ctx.Log("Invoking output binding for exec operation!") req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "exec", Metadata: metadata} - req.Metadata["sql"] = "INSERT INTO dapr_test_table (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" + req.Metadata["sql"] = "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" errBinding := client.InvokeOutputBinding(ctx, req) require.NoError(ctx, errBinding, "error in output binding - exec") @@ -74,7 +76,7 @@ func TestPostgres(t *testing.T) { ctx.Log("Invoking output binding for query operation!") req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "query", Metadata: metadata} - req.Metadata["sql"] = "SELECT * FROM dapr_test_table WHERE id = 1;" + req.Metadata["sql"] = "SELECT * FROM " + tableName + " WHERE id = 1;" resp, errBinding := client.InvokeBinding(ctx, req) assert.Contains(t, string(resp.Data), "1,\"demo\",\"2020-09-24T11:45:05Z07:00\"") require.NoError(ctx, errBinding, "error in output binding - query") @@ -84,17 +86,18 @@ func TestPostgres(t *testing.T) { testClose := func(ctx flow.Context) error { client, err := daprClient.NewClientWithPort(fmt.Sprintf("%d", grpcPort)) - require.NoError(t, err, "Could not initialize dapr client.") + require.NoError(ctx, err, "Could not initialize dapr client.") metadata := make(map[string]string) - ctx.Log("Invoking output binding for query operation!") + ctx.Log("Invoking output binding for close operation!") req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "close", Metadata: metadata} errBinding := client.InvokeOutputBinding(ctx, req) require.NoError(ctx, errBinding, "error in output binding - close") + ctx.Log("Invoking output binding for query operation!") req = &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "query", Metadata: metadata} - req.Metadata["sql"] = "SELECT * FROM dapr_test_table WHERE id = 1;" + req.Metadata["sql"] = "SELECT * FROM " + tableName + " WHERE id = 1;" errBinding = client.InvokeOutputBinding(ctx, req) require.Error(ctx, errBinding, "error in output binding - query") @@ -102,9 +105,9 @@ func TestPostgres(t *testing.T) { } createTable := func(ctx flow.Context) error { - db, err := sql.Open("postgres", dockerConnectionString) + db, err := sql.Open("pgx", dockerConnectionString) assert.NoError(t, err) - _, err = db.Exec("CREATE TABLE dapr_test_table(id INT, c1 TEXT, ts TEXT);") + _, err = db.Exec("CREATE TABLE " + tableName + " (id INT, c1 TEXT, ts TEXT);") assert.NoError(t, err) db.Close() return nil @@ -114,7 +117,6 @@ func TestPostgres(t *testing.T) { Step(dockercompose.Run("db", dockerComposeYAML)). Step("wait for component to start", flow.Sleep(10*time.Second)). Step("Creating table", createTable). - Step("wait for component to start", flow.Sleep(10*time.Second)). Step(sidecar.Run("standardSidecar", embedded.WithoutApp(), embedded.WithDaprGRPCPort(grpcPort), @@ -124,14 +126,13 @@ func TestPostgres(t *testing.T) { )). Step("Run exec test", testExec). Step("Run query test", testQuery). - Step("wait for DB operations to complete", flow.Sleep(10*time.Second)). Step("Run close test", testClose). Step("stop postgresql", dockercompose.Stop("db", dockerComposeYAML, "db")). - Step("wait for component to start", flow.Sleep(10*time.Second)). Run() } func TestPostgresNetworkError(t *testing.T) { + const tableName = "dapr_test_table_network" ports, _ := dapr_testing.GetFreePorts(3) grpcPort := ports[0] @@ -145,7 +146,7 @@ func TestPostgresNetworkError(t *testing.T) { ctx.Log("Invoking output binding for exec operation!") req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "exec", Metadata: metadata} - req.Metadata["sql"] = "INSERT INTO dapr_test_table (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" + req.Metadata["sql"] = "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" errBinding := client.InvokeOutputBinding(ctx, req) require.NoError(ctx, errBinding, "error in output binding - exec") @@ -160,7 +161,7 @@ func TestPostgresNetworkError(t *testing.T) { ctx.Log("Invoking output binding for query operation!") req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "query", Metadata: metadata} - req.Metadata["sql"] = "SELECT * FROM dapr_test_table WHERE id = 1;" + req.Metadata["sql"] = "SELECT * FROM " + tableName + " WHERE id = 1;" errBinding := client.InvokeOutputBinding(ctx, req) require.NoError(ctx, errBinding, "error in output binding - query") @@ -168,9 +169,9 @@ func TestPostgresNetworkError(t *testing.T) { } createTable := func(ctx flow.Context) error { - db, err := sql.Open("postgres", dockerConnectionString) + db, err := sql.Open("pgx", dockerConnectionString) assert.NoError(t, err) - _, err = db.Exec("CREATE TABLE dapr_test_table(id INT, c1 TEXT, ts TEXT);") + _, err = db.Exec("CREATE TABLE " + tableName + "(id INT, c1 TEXT, ts TEXT);") assert.NoError(t, err) db.Close() return nil @@ -180,7 +181,6 @@ func TestPostgresNetworkError(t *testing.T) { Step(dockercompose.Run("db", dockerComposeYAML)). Step("wait for component to start", flow.Sleep(10*time.Second)). Step("Creating table", createTable). - Step("wait for component to start", flow.Sleep(10*time.Second)). Step(sidecar.Run("standardSidecar", embedded.WithoutApp(), embedded.WithDaprGRPCPort(grpcPort), @@ -190,8 +190,8 @@ func TestPostgresNetworkError(t *testing.T) { )). Step("Run exec test", testExec). Step("Run query test", testQuery). - Step("wait for DB operations to complete", flow.Sleep(10*time.Second)). - Step("interrupt network", network.InterruptNetwork(30*time.Second, nil, nil, "5432")). + Step("wait for DB operations to complete", flow.Sleep(5*time.Second)). + Step("interrupt network", network.InterruptNetwork(20*time.Second, nil, nil, "5432")). Step("wait for component to recover", flow.Sleep(10*time.Second)). Step("Run query test", testQuery). Run() @@ -204,7 +204,7 @@ func componentRuntimeOptions() []runtime.Option { bindingsRegistry.Logger = log bindingsRegistry.RegisterOutputBinding(func(l logger.Logger) bindings.OutputBinding { return binding_postgres.NewPostgres(l) - }, "postgres") + }, "postgresql") return []runtime.Option{ runtime.WithBindings(bindingsRegistry), diff --git a/tests/certification/configuration/postgres/docker-compose.yml b/tests/certification/configuration/postgres/docker-compose.yml index 560a5d111c..dd46ec9693 100644 --- a/tests/certification/configuration/postgres/docker-compose.yml +++ b/tests/certification/configuration/postgres/docker-compose.yml @@ -1,7 +1,7 @@ version: '2' services: db: - image: postgres:15 + image: postgres:15-alpine restart: always ports: - "5432:5432" diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 8896bf8d2f..6c16649bd7 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -28,7 +28,6 @@ require ( github.com/google/uuid v1.3.0 github.com/jackc/pgx/v5 v5.3.1 github.com/lestrrat-go/jwx/v2 v2.0.11 - github.com/lib/pq v1.10.7 github.com/nacos-group/nacos-sdk-go/v2 v2.1.3 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/rabbitmq/amqp091-go v1.8.1 diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 3bdfe28575..9f6d9eb0f8 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -893,8 +893,6 @@ github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmt github.com/lestrrat/go-envload v0.0.0-20180220120943-6ed08b54a570/go.mod h1:BLt8L9ld7wVsvEWQbuLrUZnCMnUmLZ+CGDzKtclrTlE= github.com/lestrrat/go-file-rotatelogs v0.0.0-20180223000712-d3151e2a480f/go.mod h1:UGmTpUd3rjbtfIpwAPrcfmGf/Z1HS95TATB+m57TPB8= github.com/lestrrat/go-strftime v0.0.0-20180220042222-ba3bf9c1d042/go.mod h1:TPpsiPUEh0zFL1Snz4crhMlBe60PYxRHr5oFF3rRYg0= -github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= -github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/linkedin/goavro/v2 v2.9.8 h1:jN50elxBsGBDGVDEKqUlDuU1cFwJ11K/yrJCBMe/7Wg= diff --git a/tests/certification/state/postgresql/docker-compose.yml b/tests/certification/state/postgresql/docker-compose.yml index 48a388d3f9..dd46ec9693 100644 --- a/tests/certification/state/postgresql/docker-compose.yml +++ b/tests/certification/state/postgresql/docker-compose.yml @@ -1,7 +1,7 @@ version: '2' services: db: - image: postgres + image: postgres:15-alpine restart: always ports: - "5432:5432" diff --git a/tests/config/bindings/postgres/bindings.yml b/tests/config/bindings/postgres/bindings.yml deleted file mode 100644 index be68503baa..0000000000 --- a/tests/config/bindings/postgres/bindings.yml +++ /dev/null @@ -1,10 +0,0 @@ -apiVersion: dapr.io/v1alpha1 -kind: Component -metadata: - name: postgres-binding -spec: - type: bindings.postgres - version: v1 - metadata: - - name: url # Required - value: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" diff --git a/tests/config/bindings/postgresql/azure/bindings.yml b/tests/config/bindings/postgresql/azure/bindings.yml new file mode 100644 index 0000000000..7d66ed5021 --- /dev/null +++ b/tests/config/bindings/postgresql/azure/bindings.yml @@ -0,0 +1,18 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: postgres-binding +spec: + type: bindings.postgresql + version: v1 + metadata: + - name: connectionString + value: "${{AzureDBPostgresConnectionString}}" + - name: azureClientId + value: "${{AzureDBPostgresClientId}}" + - name: azureClientSecret + value: "${{AzureDBPostgresClientSecret}}" + - name: azureTenantId + value: "${{AzureDBPostgresTenantId}}" + - name: useAzureAD + value: "true" \ No newline at end of file diff --git a/tests/config/bindings/postgresql/docker/bindings.yml b/tests/config/bindings/postgresql/docker/bindings.yml new file mode 100644 index 0000000000..f6b9679791 --- /dev/null +++ b/tests/config/bindings/postgresql/docker/bindings.yml @@ -0,0 +1,11 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: postgres-binding +spec: + type: bindings.postgresql + version: v1 + metadata: + # "url" is the old name for "connectionString" and is kept here to test for backwards-compatibility + - name: url + value: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" diff --git a/tests/config/bindings/tests.yml b/tests/config/bindings/tests.yml index b8efdcf7a8..62c83c0c81 100644 --- a/tests/config/bindings/tests.yml +++ b/tests/config/bindings/tests.yml @@ -73,7 +73,9 @@ components: checkInOrderProcessing: false - component: kubemq operations: [ "create", "operations", "read" ] - - component: postgres + - component: postgresql.docker + operations: [ "exec", "query", "close", "operations" ] + - component: postgresql.azure operations: [ "exec", "query", "close", "operations" ] - component: aws.s3.docker operations: ["create", "operations", "get", "list"] diff --git a/tests/config/configuration/postgresql/azure/configstore.yml b/tests/config/configuration/postgresql/azure/configstore.yml new file mode 100644 index 0000000000..375f4f5391 --- /dev/null +++ b/tests/config/configuration/postgresql/azure/configstore.yml @@ -0,0 +1,20 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: configstore +spec: + type: configuration.postgresql + version: v1 + metadata: + - name: connectionString + value: "${{AzureDBPostgresConnectionString}}" + - name: azureClientId + value: "${{AzureDBPostgresClientId}}" + - name: azureClientSecret + value: "${{AzureDBPostgresClientSecret}}" + - name: azureTenantId + value: "${{AzureDBPostgresTenantId}}" + - name: useAzureAD + value: "true" + - name: table + value: configtable \ No newline at end of file diff --git a/tests/config/configuration/postgres/configstore.yml b/tests/config/configuration/postgresql/docker/configstore.yml similarity index 85% rename from tests/config/configuration/postgres/configstore.yml rename to tests/config/configuration/postgresql/docker/configstore.yml index b2d313707e..5f408c6f4d 100644 --- a/tests/config/configuration/postgres/configstore.yml +++ b/tests/config/configuration/postgresql/docker/configstore.yml @@ -3,7 +3,8 @@ kind: Component metadata: name: configstore spec: - type: configuration.postgres + type: configuration.postgresql + version: v1 metadata: - name: connectionString value: "host=localhost user=postgres password=example port=5432 connect_timeout=10 database=dapr_test" diff --git a/tests/config/configuration/tests.yml b/tests/config/configuration/tests.yml index 17dd6354a0..6713453afa 100644 --- a/tests/config/configuration/tests.yml +++ b/tests/config/configuration/tests.yml @@ -5,5 +5,7 @@ components: operations: [] - component: redis.v7 operations: [] - - component: postgres + - component: postgresql.azure + operations: [] + - component: postgresql.docker operations: [] diff --git a/tests/config/state/postgresql/azure/statestore.yml b/tests/config/state/postgresql/azure/statestore.yml index 1788537e8e..b58a92c49c 100644 --- a/tests/config/state/postgresql/azure/statestore.yml +++ b/tests/config/state/postgresql/azure/statestore.yml @@ -8,5 +8,11 @@ spec: metadata: - name: connectionString value: "${{AzureDBPostgresConnectionString}}" + - name: azureClientId + value: "${{AzureDBPostgresClientId}}" + - name: azureClientSecret + value: "${{AzureDBPostgresClientSecret}}" + - name: azureTenantId + value: "${{AzureDBPostgresTenantId}}" - name: useAzureAD value: "true" \ No newline at end of file diff --git a/tests/conformance/common.go b/tests/conformance/common.go index 7ba46cd7e5..bfd4850992 100644 --- a/tests/conformance/common.go +++ b/tests/conformance/common.go @@ -431,13 +431,10 @@ func loadConfigurationStore(tc TestComponent) (configuration.Store, configupdate var store configuration.Store var updater configupdater.Updater switch tc.Component { - case redisv6: - store = c_redis.NewRedisConfigurationStore(testLogger) - updater = cu_redis.NewRedisConfigUpdater(testLogger) - case redisv7: + case redisv6, redisv7: store = c_redis.NewRedisConfigurationStore(testLogger) updater = cu_redis.NewRedisConfigUpdater(testLogger) - case "postgres": + case "postgresql.docker", "postgresql.azure": store = c_postgres.NewPostgresConfigurationStore(testLogger) updater = cu_postgres.NewPostgresConfigUpdater(testLogger) default: @@ -624,7 +621,9 @@ func loadOutputBindings(tc TestComponent) bindings.OutputBinding { binding = b_rabbitmq.NewRabbitMQ(testLogger) case "kubemq": binding = b_kubemq.NewKubeMQ(testLogger) - case "postgres": + case "postgresql.docker": + binding = b_postgres.NewPostgres(testLogger) + case "postgresql.azure": binding = b_postgres.NewPostgres(testLogger) case "aws.s3.docker": binding = b_aws_s3.NewAWSS3(testLogger) diff --git a/tests/conformance/configuration/configuration.go b/tests/conformance/configuration/configuration.go index 4dd76bac70..988e6192c5 100644 --- a/tests/conformance/configuration/configuration.go +++ b/tests/conformance/configuration/configuration.go @@ -37,7 +37,7 @@ const ( v1 = "1.0.0" defaultMaxReadDuration = 30 * time.Second defaultWaitDuration = 5 * time.Second - postgresComponent = "postgres" + postgresComponent = "postgresql" pgNotifyChannelKey = "pgNotifyChannel" pgNotifyChannel = "config" ) @@ -152,7 +152,7 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration require.NoError(t, err) // Creating trigger for postgres config updater - if component == postgresComponent { + if strings.HasPrefix(component, postgresComponent) { err = updater.(*postgres_updater.ConfigUpdater).CreateTrigger(pgNotifyChannel) require.NoError(t, err) } @@ -223,7 +223,7 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration t.Run("subscribe", func(t *testing.T) { subscribeMetadata := make(map[string]string) - if component == postgresComponent { + if strings.HasPrefix(component, postgresComponent) { subscribeMetadata[pgNotifyChannelKey] = pgNotifyChannel } t.Run("subscriber 1 with non-empty key list", func(t *testing.T) { @@ -307,7 +307,7 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration // Delete initValues2 errDelete := updater.DeleteKey(getKeys(initValues2)) assert.NoError(t, errDelete, "expected no error on updating keys") - if component != postgresComponent { + if !strings.HasPrefix(component, postgresComponent) { for k := range initValues2 { initValues2[k] = &configuration.Item{} } @@ -323,10 +323,9 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration t.Run("unsubscribe", func(t *testing.T) { t.Run("unsubscribe subscriber 1", func(t *testing.T) { - ID1 := subscribeIDs[0] err := store.Unsubscribe(context.Background(), &configuration.UnsubscribeRequest{ - ID: ID1, + ID: subscribeIDs[0], }, ) assert.NoError(t, err, "expected no error in unsubscribe") @@ -346,10 +345,9 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration }) t.Run("unsubscribe subscriber 2", func(t *testing.T) { - ID2 := subscribeIDs[1] err := store.Unsubscribe(context.Background(), &configuration.UnsubscribeRequest{ - ID: ID2, + ID: subscribeIDs[1], }, ) assert.NoError(t, err, "expected no error in unsubscribe") @@ -367,10 +365,9 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration }) t.Run("unsubscribe subscriber 3", func(t *testing.T) { - ID3 := subscribeIDs[2] err := store.Unsubscribe(context.Background(), &configuration.UnsubscribeRequest{ - ID: ID3, + ID: subscribeIDs[2], }, ) assert.NoError(t, err, "expected no error in unsubscribe") diff --git a/tests/utils/configupdater/postgres/postgres.go b/tests/utils/configupdater/postgres/postgres.go index 609bdd469b..ce5b02bb2b 100644 --- a/tests/utils/configupdater/postgres/postgres.go +++ b/tests/utils/configupdater/postgres/postgres.go @@ -5,10 +5,13 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/dapr/components-contrib/configuration" + pgauth "github.com/dapr/components-contrib/internal/authentication/postgresql" + "github.com/dapr/components-contrib/internal/utils" "github.com/dapr/components-contrib/tests/utils/configupdater" "github.com/dapr/kit/logger" ) @@ -72,7 +75,7 @@ func (r *ConfigUpdater) CreateTrigger(channel string) error { return fmt.Errorf("error creating config event function : %w", err) } triggerName := "configTrigger_" + channel - createTriggerSQL := "CREATE TRIGGER " + triggerName + " AFTER INSERT OR UPDATE OR DELETE ON " + r.configTable + " FOR EACH ROW EXECUTE PROCEDURE " + procedureName + "();" + createTriggerSQL := "CREATE OR REPLACE TRIGGER " + triggerName + " AFTER INSERT OR UPDATE OR DELETE ON " + r.configTable + " FOR EACH ROW EXECUTE PROCEDURE " + procedureName + "();" _, err = r.client.Exec(ctx, createTriggerSQL) if err != nil { return fmt.Errorf("error creating config trigger : %w", err) @@ -81,31 +84,38 @@ func (r *ConfigUpdater) CreateTrigger(channel string) error { } func (r *ConfigUpdater) Init(props map[string]string) error { - var conn string - ctx := context.Background() - if val, ok := props["connectionString"]; ok && val != "" { - conn = val - } else { - return fmt.Errorf("missing postgreSQL connection string") + md := pgauth.PostgresAuthMetadata{ + ConnectionString: props["connectionString"], + UseAzureAD: utils.IsTruthy(props["useAzureAD"]), + } + err := md.InitWithMetadata(props, true) + if err != nil { + return err } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + if tbl, ok := props["table"]; ok && tbl != "" { r.configTable = tbl } else { return fmt.Errorf("missing postgreSQL configuration table name") } - config, err := pgxpool.ParseConfig(conn) + + config, err := md.GetPgxPoolConfig() if err != nil { return fmt.Errorf("postgres configuration store connection error : %w", err) } - pool, err := pgxpool.NewWithConfig(ctx, config) + + r.client, err = pgxpool.NewWithConfig(ctx, config) if err != nil { return fmt.Errorf("postgres configuration store connection error : %w", err) } - err = pool.Ping(ctx) + err = r.client.Ping(ctx) if err != nil { return fmt.Errorf("postgres configuration store ping error : %w", err) } - r.client = pool + err = createAndSetTable(ctx, r.client, r.configTable) if err != nil { return fmt.Errorf("postgres configuration store table creation error : %w", err) From 1ab15ef04ba571ff7c7c51b299d77afaa4565071 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 12 Jul 2023 15:34:14 -0700 Subject: [PATCH 5/6] Postgres binding: support parametrized queries (#2972) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/postgres/postgres.go | 30 ++++-- .../certification/bindings/postgres/README.md | 1 + .../bindings/postgres/postgres_test.go | 99 ++++++++++++------- 3 files changed, 87 insertions(+), 43 deletions(-) diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index 9919f9b1d0..637b88d344 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -36,7 +36,8 @@ const ( queryOperation bindings.OperationKind = "query" closeOperation bindings.OperationKind = "close" - commandSQLKey = "sql" + commandSQLKey = "sql" + commandArgsKey = "params" ) // Postgres represents PostgreSQL output binding. @@ -109,11 +110,22 @@ func (p *Postgres) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res return nil, errors.New("metadata required") } - sql, ok := req.Metadata[commandSQLKey] - if !ok || sql == "" { + // Metadata property "sql" contains the query to execute + sql := req.Metadata[commandSQLKey] + if sql == "" { return nil, fmt.Errorf("required metadata not set: %s", commandSQLKey) } + // Metadata property "params" contains JSON-encoded parameters, and it's optional + // If present, it must be unserializable into a []any object + var args []any + if argsStr := req.Metadata[commandArgsKey]; argsStr != "" { + err = json.Unmarshal([]byte(argsStr), &args) + if err != nil { + return nil, fmt.Errorf("invalid metadata property %s: failed to unserialize into an array: %w", commandArgsKey, err) + } + } + startTime := time.Now().UTC() resp = &bindings.InvokeResponse{ Metadata: map[string]string{ @@ -125,14 +137,14 @@ func (p *Postgres) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res switch req.Operation { //nolint:exhaustive case execOperation: - r, err := p.exec(ctx, sql) + r, err := p.exec(ctx, sql, args...) if err != nil { return nil, err } resp.Metadata["rows-affected"] = strconv.FormatInt(r, 10) // 0 if error case queryOperation: - d, err := p.query(ctx, sql) + d, err := p.query(ctx, sql, args...) if err != nil { return nil, err } @@ -168,8 +180,8 @@ func (p *Postgres) Close() error { return nil } -func (p *Postgres) query(ctx context.Context, sql string) (result []byte, err error) { - rows, err := p.db.Query(ctx, sql) +func (p *Postgres) query(ctx context.Context, sql string, args ...any) (result []byte, err error) { + rows, err := p.db.Query(ctx, sql, args...) if err != nil { return nil, fmt.Errorf("error executing query: %w", err) } @@ -191,8 +203,8 @@ func (p *Postgres) query(ctx context.Context, sql string) (result []byte, err er return result, nil } -func (p *Postgres) exec(ctx context.Context, sql string) (result int64, err error) { - res, err := p.db.Exec(ctx, sql) +func (p *Postgres) exec(ctx context.Context, sql string, args ...any) (result int64, err error) { + res, err := p.db.Exec(ctx, sql, args...) if err != nil { return 0, fmt.Errorf("error executing query: %w", err) } diff --git a/tests/certification/bindings/postgres/README.md b/tests/certification/bindings/postgres/README.md index 1759d74259..f30909c863 100644 --- a/tests/certification/bindings/postgres/README.md +++ b/tests/certification/bindings/postgres/README.md @@ -17,6 +17,7 @@ The purpose of this module is to provide tests that certify the PostgreSQL Outpu * Run dapr application with component to store data in postgres as output binding. * Read stored data from postgres. * Ensure that read data is same as the data that was stored. + * Verify the ability to use named paramters in queries. * Verify reconnection to postgres for output binding. * Simulate a network error before sending any messages. * Run dapr application with the component. diff --git a/tests/certification/bindings/postgres/postgres_test.go b/tests/certification/bindings/postgres/postgres_test.go index 6cc0052f98..deddaaccb2 100644 --- a/tests/certification/bindings/postgres/postgres_test.go +++ b/tests/certification/bindings/postgres/postgres_test.go @@ -55,31 +55,58 @@ func TestPostgres(t *testing.T) { testExec := func(ctx flow.Context) error { client, err := daprClient.NewClientWithPort(fmt.Sprintf("%d", grpcPort)) - require.NoError(t, err, "Could not initialize dapr client.") - - metadata := make(map[string]string) - - ctx.Log("Invoking output binding for exec operation!") - req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "exec", Metadata: metadata} - req.Metadata["sql"] = "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" - errBinding := client.InvokeOutputBinding(ctx, req) - require.NoError(ctx, errBinding, "error in output binding - exec") + require.NoError(t, err, "Could not initialize dapr client") + + ctx.Log("Invoking output binding for exec operation") + err = client.InvokeOutputBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "exec", + Metadata: map[string]string{ + "sql": "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05+07:00');", + }, + }) + require.NoError(ctx, err, "error in output binding - exec") + + ctx.Log("Invoking output binding for exec operation with parameters") + err = client.InvokeOutputBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "exec", + Metadata: map[string]string{ + "sql": "INSERT INTO " + tableName + " (id, c1, ts) VALUES ($1, $2, $3);", + "params": `[2, "demo2", "2021-03-19T11:45:05+07:00"]`, + }, + }) + require.NoError(ctx, err, "error in output binding - exec") return nil } testQuery := func(ctx flow.Context) error { client, err := daprClient.NewClientWithPort(fmt.Sprintf("%d", grpcPort)) - require.NoError(t, err, "Could not initialize dapr client.") - - metadata := make(map[string]string) - - ctx.Log("Invoking output binding for query operation!") - req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "query", Metadata: metadata} - req.Metadata["sql"] = "SELECT * FROM " + tableName + " WHERE id = 1;" - resp, errBinding := client.InvokeBinding(ctx, req) - assert.Contains(t, string(resp.Data), "1,\"demo\",\"2020-09-24T11:45:05Z07:00\"") - require.NoError(ctx, errBinding, "error in output binding - query") + require.NoError(t, err, "Could not initialize dapr client") + + ctx.Log("Invoking output binding for query operation") + resp, err := client.InvokeBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "query", + Metadata: map[string]string{ + "sql": "SELECT * FROM " + tableName + " WHERE id = 1;", + }, + }) + assert.Equal(t, `[[1,"demo","2020-09-24T11:45:05Z"]]`, string(resp.Data)) + require.NoError(ctx, err, "error in output binding - query") + + ctx.Log("Invoking output binding for query operation with parameters") + resp, err = client.InvokeBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "query", + Metadata: map[string]string{ + "sql": "SELECT * FROM " + tableName + " WHERE id = ANY($1);", + "params": `[[1, 2]]`, + }, + }) + assert.Equal(t, `[[1,"demo","2020-09-24T11:45:05Z"],[2,"demo2","2021-03-19T11:45:05Z"]]`, string(resp.Data)) + require.NoError(ctx, err, "error in output binding - query") return nil } @@ -107,8 +134,8 @@ func TestPostgres(t *testing.T) { createTable := func(ctx flow.Context) error { db, err := sql.Open("pgx", dockerConnectionString) assert.NoError(t, err) - _, err = db.Exec("CREATE TABLE " + tableName + " (id INT, c1 TEXT, ts TEXT);") - assert.NoError(t, err) + _, err = db.Exec("CREATE TABLE " + tableName + " (id INT, c1 TEXT, ts TIMESTAMP);") + require.NoError(t, err) db.Close() return nil } @@ -140,14 +167,16 @@ func TestPostgresNetworkError(t *testing.T) { testExec := func(ctx flow.Context) error { client, err := daprClient.NewClientWithPort(fmt.Sprintf("%d", grpcPort)) - require.NoError(t, err, "Could not initialize dapr client.") - - metadata := make(map[string]string) + require.NoError(t, err, "Could not initialize dapr client") ctx.Log("Invoking output binding for exec operation!") - req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "exec", Metadata: metadata} - req.Metadata["sql"] = "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05Z07:00');" - errBinding := client.InvokeOutputBinding(ctx, req) + errBinding := client.InvokeOutputBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "exec", + Metadata: map[string]string{ + "sql": "INSERT INTO " + tableName + " (id, c1, ts) VALUES (1, 'demo', '2020-09-24T11:45:05+07:00');", + }, + }) require.NoError(ctx, errBinding, "error in output binding - exec") return nil @@ -155,14 +184,16 @@ func TestPostgresNetworkError(t *testing.T) { testQuery := func(ctx flow.Context) error { client, err := daprClient.NewClientWithPort(fmt.Sprintf("%d", grpcPort)) - require.NoError(t, err, "Could not initialize dapr client.") - - metadata := make(map[string]string) + require.NoError(t, err, "Could not initialize dapr client") ctx.Log("Invoking output binding for query operation!") - req := &daprClient.InvokeBindingRequest{Name: "standard-binding", Operation: "query", Metadata: metadata} - req.Metadata["sql"] = "SELECT * FROM " + tableName + " WHERE id = 1;" - errBinding := client.InvokeOutputBinding(ctx, req) + errBinding := client.InvokeOutputBinding(ctx, &daprClient.InvokeBindingRequest{ + Name: "standard-binding", + Operation: "query", + Metadata: map[string]string{ + "sql": "SELECT * FROM " + tableName + " WHERE id = 1;", + }, + }) require.NoError(ctx, errBinding, "error in output binding - query") return nil @@ -171,7 +202,7 @@ func TestPostgresNetworkError(t *testing.T) { createTable := func(ctx flow.Context) error { db, err := sql.Open("pgx", dockerConnectionString) assert.NoError(t, err) - _, err = db.Exec("CREATE TABLE " + tableName + "(id INT, c1 TEXT, ts TEXT);") + _, err = db.Exec("CREATE TABLE " + tableName + " (id INT, c1 TEXT, ts TIMESTAMP);") assert.NoError(t, err) db.Close() return nil From 95045c4dfeeff85a9792955b87004a1e1a401c63 Mon Sep 17 00:00:00 2001 From: Shivam Kumar Singh Date: Fri, 14 Jul 2023 02:57:42 +0530 Subject: [PATCH 6/6] Add output binding for OpenAI (#2965) Signed-off-by: Shivam Kumar Singh Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/azure/openai/metadata.yaml | 38 ++++ bindings/azure/openai/openai.go | 326 ++++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 6 +- tests/certification/go.mod | 2 +- tests/certification/go.sum | 4 +- 6 files changed, 373 insertions(+), 6 deletions(-) create mode 100644 bindings/azure/openai/metadata.yaml create mode 100644 bindings/azure/openai/openai.go diff --git a/bindings/azure/openai/metadata.yaml b/bindings/azure/openai/metadata.yaml new file mode 100644 index 0000000000..f579916940 --- /dev/null +++ b/bindings/azure/openai/metadata.yaml @@ -0,0 +1,38 @@ +# yaml-language-server: $schema=../../../component-metadata-schema.json +schemaVersion: v1 +type: bindings +name: azure.openai +version: v1 +status: alpha +title: "Azure OpenAI" +urls: + - title: Reference + url: https://docs.dapr.io/reference/components-reference/supported-bindings/azure-openai/ +binding: + output: true + input: false + operations: + - name: completion + description: "Text completion" + - name: chat-completion + description: "Chat completion" +builtinAuthenticationProfiles: + - name: "azuread" +authenticationProfiles: + - title: "API Key" + description: "Authenticate using an API key" + metadata: + - name: apiKey + required: true + sensitive: true + description: "API Key" + example: '"1234567890abcdef"' +metadata: + - name: endpoint + required: true + description: "Endpoint of the Azure OpenAI service" + example: '"https://myopenai.openai.azure.com"' + - name: deploymentID + required: true + description: "ID of the model deployment in the Azure OpenAI service" + example: '"my-model"' diff --git a/bindings/azure/openai/openai.go b/bindings/azure/openai/openai.go new file mode 100644 index 0000000000..83d189203a --- /dev/null +++ b/bindings/azure/openai/openai.go @@ -0,0 +1,326 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package openai + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + + "github.com/dapr/components-contrib/bindings" + azauth "github.com/dapr/components-contrib/internal/authentication/azure" + "github.com/dapr/components-contrib/metadata" + "github.com/dapr/kit/config" + "github.com/dapr/kit/logger" +) + +// List of operations. +const ( + CompletionOperation bindings.OperationKind = "completion" + ChatCompletionOperation bindings.OperationKind = "chat-completion" + + APIKey = "apiKey" + DeploymentID = "deploymentID" + Endpoint = "endpoint" + MessagesKey = "messages" + Temperature = "temperature" + MaxTokens = "maxTokens" + TopP = "topP" + N = "n" + Stop = "stop" + FrequencyPenalty = "frequencyPenalty" + LogitBias = "logitBias" + User = "user" +) + +// AzOpenAI represents OpenAI output binding. +type AzOpenAI struct { + logger logger.Logger + client *azopenai.Client +} + +type openAIMetadata struct { + // APIKey is the API key for the Azure OpenAI API. + APIKey string `mapstructure:"apiKey"` + // DeploymentID is the deployment ID for the Azure OpenAI API. + DeploymentID string `mapstructure:"deploymentID"` + // Endpoint is the endpoint for the Azure OpenAI API. + Endpoint string `mapstructure:"endpoint"` +} + +type ChatSettings struct { + Temperature float32 `mapstructure:"temperature"` + MaxTokens int32 `mapstructure:"maxTokens"` + TopP float32 `mapstructure:"topP"` + N int32 `mapstructure:"n"` + PresencePenalty float32 `mapstructure:"presencePenalty"` + FrequencyPenalty float32 `mapstructure:"frequencyPenalty"` +} + +// ChatMessages type for chat completion API. +type ChatMessages struct { + Messages []Message `json:"messages"` + Temperature float32 `json:"temperature"` + MaxTokens int32 `json:"maxTokens"` + TopP float32 `json:"topP"` + N int32 `json:"n"` + PresencePenalty float32 `json:"presencePenalty"` + FrequencyPenalty float32 `json:"frequencyPenalty"` +} + +// Message type stores the messages for bot conversation. +type Message struct { + Role string + Message string +} + +// Prompt type for completion API. +type Prompt struct { + Prompt string `json:"prompt"` + Temperature float32 `json:"temperature"` + MaxTokens int32 `json:"maxTokens"` + TopP float32 `json:"topP"` + N int32 `json:"n"` + PresencePenalty float32 `json:"presencePenalty"` + FrequencyPenalty float32 `json:"frequencyPenalty"` +} + +// NewOpenAI returns a new OpenAI output binding. +func NewOpenAI(logger logger.Logger) bindings.OutputBinding { + return &AzOpenAI{ + logger: logger, + } +} + +// Init initializes the OpenAI binding. +func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error { + m := openAIMetadata{} + err := metadata.DecodeMetadata(meta.Properties, &m) + if err != nil { + return fmt.Errorf("error decoding metadata: %w", err) + } + if m.Endpoint == "" { + return fmt.Errorf("required metadata not set: %s", Endpoint) + } + if m.DeploymentID == "" { + return fmt.Errorf("required metadata not set: %s", DeploymentID) + } + + if m.APIKey != "" { + // use API key authentication + var keyCredential azopenai.KeyCredential + keyCredential, err = azopenai.NewKeyCredential(m.APIKey) + if err != nil { + return fmt.Errorf("error getting credentials object: %w", err) + } + + p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil) + if err != nil { + return fmt.Errorf("error creating Azure OpenAI client: %w", err) + } + } else { + // fallback to Azure AD authentication + settings, innerErr := azauth.NewEnvironmentSettings(meta.Properties) + if innerErr != nil { + return fmt.Errorf("error creating environment settings: %w", innerErr) + } + + token, innerErr := settings.GetTokenCredential() + if innerErr != nil { + return fmt.Errorf("error getting token credential: %w", innerErr) + } + + p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil) + if err != nil { + return fmt.Errorf("error creating Azure OpenAI client: %w", err) + } + } + + return nil +} + +// Operations returns list of operations supported by OpenAI binding. +func (p *AzOpenAI) Operations() []bindings.OperationKind { + return []bindings.OperationKind{ + ChatCompletionOperation, + CompletionOperation, + } +} + +// Invoke handles all invoke operations. +func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) { + if req == nil || len(req.Metadata) == 0 { + return nil, fmt.Errorf("invalid request: metadata is required") + } + + startTime := time.Now().UTC() + resp = &bindings.InvokeResponse{ + Metadata: map[string]string{ + "operation": string(req.Operation), + "start-time": startTime.Format(time.RFC3339Nano), + }, + } + + switch req.Operation { //nolint:exhaustive + case CompletionOperation: + response, err := p.completion(ctx, req.Data, req.Metadata) + if err != nil { + return nil, fmt.Errorf("error performing completion: %w", err) + } + responseAsBytes, _ := json.Marshal(response) + resp.Data = responseAsBytes + + case ChatCompletionOperation: + response, err := p.chatCompletion(ctx, req.Data, req.Metadata) + if err != nil { + return nil, fmt.Errorf("error performing chat completion: %w", err) + } + responseAsBytes, _ := json.Marshal(response) + resp.Data = responseAsBytes + + default: + return nil, fmt.Errorf( + "invalid operation type: %s. Expected %s, %s", + req.Operation, CompletionOperation, ChatCompletionOperation, + ) + } + + endTime := time.Now().UTC() + resp.Metadata["end-time"] = endTime.Format(time.RFC3339Nano) + resp.Metadata["duration"] = endTime.Sub(startTime).String() + + return resp, nil +} + +func (s *ChatSettings) Decode(in any) error { + return config.Decode(in, s) +} + +func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) { + prompt := Prompt{ + Temperature: 1.0, + TopP: 1.0, + MaxTokens: 16, + N: 1, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + } + err = json.Unmarshal(message, &prompt) + if err != nil { + return nil, fmt.Errorf("error unmarshalling the input object: %w", err) + } + + if prompt.Prompt == "" { + return nil, fmt.Errorf("prompt is required for completion operation") + } + + resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{ + Prompt: []*string{&prompt.Prompt}, + MaxTokens: &prompt.MaxTokens, + Temperature: &prompt.Temperature, + TopP: &prompt.TopP, + N: &prompt.N, + }, nil) + if err != nil { + return nil, fmt.Errorf("error getting completion api: %w", err) + } + + // No choices returned + if len(resp.Completions.Choices) == 0 { + return []azopenai.Choice{}, nil + } + + choices := resp.Completions.Choices + response = make([]azopenai.Choice, len(choices)) + for i, c := range choices { + response[i] = *c + } + + return response, nil +} + +func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []azopenai.ChatChoice, err error) { + messages := ChatMessages{ + Temperature: 1.0, + TopP: 1.0, + N: 1, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + } + err = json.Unmarshal(messageRequest, &messages) + if err != nil { + return nil, fmt.Errorf("error unmarshalling the input object: %w", err) + } + + if len(messages.Messages) == 0 { + return nil, fmt.Errorf("messages are required for chat-completion operation") + } + + messageReq := make([]*azopenai.ChatMessage, len(messages.Messages)) + for i, m := range messages.Messages { + messageReq[i] = &azopenai.ChatMessage{ + Role: to.Ptr(azopenai.ChatRole(m.Role)), + Content: to.Ptr(m.Message), + } + } + + var maxTokens *int32 + if messages.MaxTokens != 0 { + maxTokens = &messages.MaxTokens + } + + res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{ + MaxTokens: maxTokens, + Temperature: &messages.Temperature, + TopP: &messages.TopP, + N: &messages.N, + Messages: messageReq, + }, nil) + if err != nil { + return nil, fmt.Errorf("error getting chat completion api: %w", err) + } + + // No choices returned. + if len(res.ChatCompletions.Choices) == 0 { + return []azopenai.ChatChoice{}, nil + } + + choices := res.ChatCompletions.Choices + response = make([]azopenai.ChatChoice, len(choices)) + for i, c := range choices { + response[i] = *c + } + + return response, nil +} + +// Close Az OpenAI instance. +func (p *AzOpenAI) Close() error { + p.client = nil + + return nil +} + +// GetComponentMetadata returns the metadata of the component. +func (p *AzOpenAI) GetComponentMetadata() map[string]string { + metadataStruct := openAIMetadata{} + metadataInfo := map[string]string{} + metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType) + return metadataInfo +} diff --git a/go.mod b/go.mod index 937f84ebff..98002f8790 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,9 @@ require ( cloud.google.com/go/secretmanager v1.10.0 cloud.google.com/go/storage v1.30.1 dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3 - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1 diff --git a/go.sum b/go.sum index 33ff6e88ca..d952c1ce42 100644 --- a/go.sum +++ b/go.sum @@ -420,11 +420,13 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0 github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8u3fcIHyqkLjcFpNRHQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 h1:DQCZXtoCPuwBMlAa2aC+B3CfpE6xz2xe1jqdqt8nIJY= +github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5/go.mod h1:GQSjs1n073tbMa3e76+STZkyFb+NcEA4N7OB5vNvB3E= github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0= github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U= github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0= diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 6c16649bd7..37efbf1384 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -52,7 +52,7 @@ require ( github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a // indirect github.com/AthenZ/athenz v1.10.39 // indirect github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 // indirect github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 9f6d9eb0f8..39ae5032fe 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -70,8 +70,8 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0 github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8u3fcIHyqkLjcFpNRHQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=