From c133955e3b26ffc0c8b748f11f8b48619a36871e Mon Sep 17 00:00:00 2001 From: pigletfly Date: Fri, 29 Apr 2022 11:22:36 +0800 Subject: [PATCH] Add ctx in state Signed-off-by: pigletfly --- state/aerospike/aerospike.go | 9 ++- state/alicloud/tablestore/tablestore.go | 15 ++-- state/alicloud/tablestore/tablestore_test.go | 17 +++-- state/aws/dynamodb/dynamodb.go | 15 ++-- state/aws/dynamodb/dynamodb_test.go | 41 +++++----- state/azure/blobstorage/blobstorage.go | 28 +++---- state/azure/cosmosdb/cosmosdb.go | 9 ++- state/azure/tablestorage/tablestorage.go | 9 ++- state/cassandra/cassandra.go | 9 ++- state/cockroachdb/cockroachdb.go | 32 ++++---- state/cockroachdb/cockroachdb_access.go | 31 ++++---- state/cockroachdb/cockroachdb_access_test.go | 29 +++---- .../cockroachdb_integration_test.go | 31 ++++---- state/cockroachdb/cockroachdb_test.go | 15 ++-- state/cockroachdb/dbaccess.go | 20 +++-- state/couchbase/couchbase.go | 9 ++- state/gcp/firestore/firestore.go | 18 ++--- state/hashicorp/consul/consul.go | 9 ++- state/hazelcast/hazelcast.go | 9 ++- state/jetstream/jetstream.go | 9 ++- state/jetstream/jetstream_test.go | 11 +-- state/memcached/memcached.go | 13 ++-- state/mongodb/mongodb.go | 16 ++-- state/mysql/mysql.go | 33 ++++---- state/mysql/mysql_integration_test.go | 31 ++++---- state/mysql/mysql_test.go | 63 ++++++++-------- state/oci/objectstorage/objectstorage.go | 23 +++--- .../objectstorage_integration_test.go | 75 ++++++++++--------- state/oci/objectstorage/objectstorage_test.go | 43 ++++++----- state/oracledatabase/dbaccess.go | 12 +-- state/oracledatabase/oracledatabase.go | 31 ++++---- .../oracledatabase_integration_test.go | 55 +++++++------- state/oracledatabase/oracledatabase_test.go | 27 +++---- state/oracledatabase/oracledatabaseaccess.go | 23 +++--- state/postgresql/dbaccess.go | 14 ++-- state/postgresql/postgresdbaccess.go | 29 +++---- state/postgresql/postgresdbaccess_test.go | 29 +++---- state/postgresql/postgresql.go | 30 ++++---- .../postgresql/postgresql_integration_test.go | 31 ++++---- state/postgresql/postgresql_test.go | 13 ++-- state/redis/redis.go | 28 +++---- state/redis/redis_test.go | 22 +++--- state/request_options.go | 9 ++- state/request_options_test.go | 9 ++- state/rethinkdb/rethinkdb.go | 25 ++++--- state/rethinkdb/rethinkdb_test.go | 37 ++++----- state/sqlserver/sqlserver.go | 15 ++-- state/sqlserver/sqlserver_integration_test.go | 49 ++++++------ state/store.go | 26 ++++--- state/store_test.go | 65 ++++++++-------- state/zookeeper/zk.go | 23 +++--- state/zookeeper/zk_test.go | 37 ++++----- tests/conformance/state/state.go | 67 +++++++++-------- 53 files changed, 717 insertions(+), 661 deletions(-) diff --git a/state/aerospike/aerospike.go b/state/aerospike/aerospike.go index ceeb143dea..367c87b8a5 100644 --- a/state/aerospike/aerospike.go +++ b/state/aerospike/aerospike.go @@ -14,6 +14,7 @@ limitations under the License. package aerospike import ( + "context" "encoding/json" "errors" "fmt" @@ -110,7 +111,7 @@ func (aspike *Aerospike) Features() []state.Feature { } // Set stores value for a key to Aerospike. It honors ETag (for concurrency) and consistency settings. -func (aspike *Aerospike) Set(req *state.SetRequest) error { +func (aspike *Aerospike) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -162,7 +163,7 @@ func (aspike *Aerospike) Set(req *state.SetRequest) error { } // Get retrieves state from Aerospike with a key. -func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (aspike *Aerospike) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { asKey, err := as.NewKey(aspike.namespace, aspike.set, req.Key) if err != nil { return nil, err @@ -196,7 +197,7 @@ func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) } // Delete performs a delete operation. -func (aspike *Aerospike) Delete(req *state.DeleteRequest) error { +func (aspike *Aerospike) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -238,7 +239,7 @@ func (aspike *Aerospike) Delete(req *state.DeleteRequest) error { return nil } -func (aspike *Aerospike) Ping() error { +func (aspike *Aerospike) Ping(ctx context.Context) error { return nil } diff --git a/state/alicloud/tablestore/tablestore.go b/state/alicloud/tablestore/tablestore.go index ef97327434..0acc46b713 100644 --- a/state/alicloud/tablestore/tablestore.go +++ b/state/alicloud/tablestore/tablestore.go @@ -14,6 +14,7 @@ limitations under the License. package tablestore import ( + "context" "encoding/json" "github.com/agrea/ptr" @@ -68,7 +69,7 @@ func (s *AliCloudTableStore) Features() []state.Feature { return s.features } -func (s *AliCloudTableStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *AliCloudTableStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { criteria := &tablestore.SingleRowQueryCriteria{ PrimaryKey: s.primaryKey(req.Key), TableName: s.metadata.TableName, @@ -103,7 +104,7 @@ func (s *AliCloudTableStore) getResp(columns []*tablestore.AttributeColumn) *sta return getResp } -func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *AliCloudTableStore) BulkGet(ctx context.Context, reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) { // "len == 0": empty request, directly return empty response if len(reqs) == 0 { return true, []state.BulkGetResponse{}, nil @@ -139,7 +140,7 @@ func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.Bul return true, responseList, nil } -func (s *AliCloudTableStore) Set(req *state.SetRequest) error { +func (s *AliCloudTableStore) Set(ctx context.Context, req *state.SetRequest) error { change := s.updateRowChange(req) request := &tablestore.UpdateRowRequest{ @@ -183,7 +184,7 @@ func unmarshal(val interface{}) []byte { return []byte(output) } -func (s *AliCloudTableStore) Delete(req *state.DeleteRequest) error { +func (s *AliCloudTableStore) Delete(ctx context.Context, req *state.DeleteRequest) error { change := s.deleteRowChange(req) deleteRowReq := &tablestore.DeleteRowRequest{ @@ -205,11 +206,11 @@ func (s *AliCloudTableStore) deleteRowChange(req *state.DeleteRequest) *tablesto return change } -func (s *AliCloudTableStore) BulkSet(reqs []state.SetRequest) error { +func (s *AliCloudTableStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error { return s.batchWrite(reqs, nil) } -func (s *AliCloudTableStore) BulkDelete(reqs []state.DeleteRequest) error { +func (s *AliCloudTableStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error { return s.batchWrite(nil, reqs) } @@ -234,7 +235,7 @@ func (s *AliCloudTableStore) batchWrite(setReqs []state.SetRequest, deleteReqs [ return nil } -func (s *AliCloudTableStore) Ping() error { +func (s *AliCloudTableStore) Ping(ctx context.Context) error { return nil } diff --git a/state/alicloud/tablestore/tablestore_test.go b/state/alicloud/tablestore/tablestore_test.go index 7ded912b66..c7e6835272 100644 --- a/state/alicloud/tablestore/tablestore_test.go +++ b/state/alicloud/tablestore/tablestore_test.go @@ -14,6 +14,7 @@ limitations under the License. package tablestore import ( + "context" "testing" "github.com/agrea/ptr" @@ -63,7 +64,7 @@ func TestReadAndWrite(t *testing.T) { Value: "value of key", ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) }) @@ -71,7 +72,7 @@ func TestReadAndWrite(t *testing.T) { getReq := &state.GetRequest{ Key: "theFirstKey", } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, "value of key", string(resp.Data)) @@ -83,7 +84,7 @@ func TestReadAndWrite(t *testing.T) { Value: "1234", ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) }) @@ -91,14 +92,14 @@ func TestReadAndWrite(t *testing.T) { getReq := &state.GetRequest{ Key: "theSecondKey", } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, "1234", string(resp.Data)) }) t.Run("test BulkSet", func(t *testing.T) { - err := store.BulkSet([]state.SetRequest{{ + err := store.BulkSet(context.TODO(), []state.SetRequest{{ Key: "theFirstKey", Value: "666", }, { @@ -110,7 +111,7 @@ func TestReadAndWrite(t *testing.T) { }) t.Run("test BulkGet", func(t *testing.T) { - _, resp, err := store.BulkGet([]state.GetRequest{{ + _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{ Key: "theFirstKey", }, { Key: "theSecondKey", @@ -126,12 +127,12 @@ func TestReadAndWrite(t *testing.T) { req := &state.DeleteRequest{ Key: "theFirstKey", } - err := store.Delete(req) + err := store.Delete(context.TODO(), req) assert.Nil(t, err) }) t.Run("test BulkGet2", func(t *testing.T) { - _, resp, err := store.BulkGet([]state.GetRequest{{ + _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{ Key: "theFirstKey", }, { Key: "theSecondKey", diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index f680954223..bb7d8c7db3 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -14,6 +14,7 @@ limitations under the License. package dynamodb import ( + "context" "encoding/json" "fmt" "strconv" @@ -70,7 +71,7 @@ func (d *StateStore) Init(metadata state.Metadata) error { return nil } -func (d *StateStore) Ping() error { +func (d *StateStore) Ping(ctx context.Context) error { return nil } @@ -80,7 +81,7 @@ func (d *StateStore) Features() []state.Feature { } // Get retrieves a dynamoDB item. -func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { input := &dynamodb.GetItemInput{ ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong), TableName: aws.String(d.table), @@ -124,13 +125,13 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulk get operations. -func (d *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (d *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with dynamodb.BatchGetItem for performance return false, nil, nil } // Set saves a dynamoDB item. -func (d *StateStore) Set(req *state.SetRequest) error { +func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { value, err := d.marshalToString(req.Value) if err != nil { return fmt.Errorf("dynamodb error: failed to set key %s: %s", req.Key, err) @@ -176,7 +177,7 @@ func (d *StateStore) Set(req *state.SetRequest) error { } // BulkSet performs a bulk set operation. -func (d *StateStore) BulkSet(req []state.SetRequest) error { +func (d *StateStore) BulkSet(ctx context.Context, req []state.SetRequest) error { writeRequests := []*dynamodb.WriteRequest{} for _, r := range req { @@ -234,7 +235,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error { } // Delete performs a delete operation. -func (d *StateStore) Delete(req *state.DeleteRequest) error { +func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { input := &dynamodb.DeleteItemInput{ Key: map[string]*dynamodb.AttributeValue{ "key": { @@ -249,7 +250,7 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error { } // BulkDelete performs a bulk delete operation. -func (d *StateStore) BulkDelete(req []state.DeleteRequest) error { +func (d *StateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { writeRequests := []*dynamodb.WriteRequest{} for _, r := range req { diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index abcf01a555..85e8e0258b 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -13,6 +13,7 @@ limitations under the License. package dynamodb import ( + "context" "fmt" "strconv" "testing" @@ -117,7 +118,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) }) @@ -149,7 +150,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) }) @@ -181,7 +182,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Nil(t, out.Data) }) @@ -200,7 +201,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.NotNil(t, err) assert.Nil(t, out) }) @@ -221,7 +222,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Nil(t, out.Data) }) @@ -246,7 +247,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Empty(t, out.Data) }) @@ -287,7 +288,7 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) @@ -323,7 +324,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { @@ -358,7 +359,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) @@ -376,7 +377,7 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { @@ -412,7 +413,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { @@ -451,7 +452,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -517,7 +518,7 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set items with ttl = -1", func(t *testing.T) { @@ -582,7 +583,7 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set items with ttl", func(t *testing.T) { @@ -649,7 +650,7 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set items", func(t *testing.T) { @@ -668,7 +669,7 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.NotNil(t, err) }) } @@ -692,7 +693,7 @@ func TestDelete(t *testing.T) { }, }, } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.Nil(t, err) }) @@ -707,7 +708,7 @@ func TestDelete(t *testing.T) { req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.NotNil(t, err) }) } @@ -756,7 +757,7 @@ func TestBulkDelete(t *testing.T) { Key: "key2", }, } - err := ss.BulkDelete(req) + err := ss.BulkDelete(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully delete items", func(t *testing.T) { @@ -772,7 +773,7 @@ func TestBulkDelete(t *testing.T) { Key: "key", }, } - err := ss.BulkDelete(req) + err := ss.BulkDelete(context.TODO(), req) assert.NotNil(t, err) }) } diff --git a/state/azure/blobstorage/blobstorage.go b/state/azure/blobstorage/blobstorage.go index 14b3ef69d7..1f1ce7271a 100644 --- a/state/azure/blobstorage/blobstorage.go +++ b/state/azure/blobstorage/blobstorage.go @@ -132,16 +132,16 @@ func (r *StateStore) Features() []state.Feature { } // Delete the state. -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("delete %s", req.Key) - return r.deleteFile(req) + return r.deleteFile(ctx, req) } // Get the state. -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("fetching %s", req.Key) - data, etag, contentType, err := r.readFile(req) + data, etag, contentType, err := r.readFile(ctx, req) if err != nil { r.logger.Debugf("error %s", err) @@ -160,16 +160,16 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set the state. -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s", req.Key) - return r.writeFile(req) + return r.writeFile(ctx, req) } -func (r *StateStore) Ping() error { +func (r *StateStore) Ping(ctx context.Context) error { accessConditions := azblob.BlobAccessConditions{} - if _, err := r.containerURL.GetProperties(context.Background(), accessConditions.LeaseAccessConditions); err != nil { + if _, err := r.containerURL.GetProperties(ctx, accessConditions.LeaseAccessConditions); err != nil { return fmt.Errorf("blob storage: error connecting to Blob storage at %s: %s", r.containerURL.URL().Host, err) } @@ -206,10 +206,10 @@ func getBlobStorageMetadata(metadata map[string]string) (*blobStorageMetadata, e return &meta, nil } -func (r *StateStore) readFile(req *state.GetRequest) ([]byte, string, *string, error) { +func (r *StateStore) readFile(ctx context.Context, req *state.GetRequest) ([]byte, string, *string, error) { blobURL := r.containerURL.NewBlockBlobURL(getFileName(req.Key)) - resp, err := blobURL.Download(context.Background(), 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false) + resp, err := blobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false) if err != nil { r.logger.Debugf("download file %s, err %s", req.Key, err) @@ -230,7 +230,7 @@ func (r *StateStore) readFile(req *state.GetRequest) ([]byte, string, *string, e return data.Bytes(), string(resp.ETag()), &contentType, nil } -func (r *StateStore) writeFile(req *state.SetRequest) error { +func (r *StateStore) writeFile(ctx context.Context, req *state.SetRequest) error { accessConditions := azblob.BlobAccessConditions{} if req.Options.Concurrency == state.FirstWrite && req.ETag != nil { @@ -247,7 +247,7 @@ func (r *StateStore) writeFile(req *state.SetRequest) error { if err != nil { return err } - _, err = azblob.UploadBufferToBlockBlob(context.Background(), r.marshal(req), blobURL, azblob.UploadToBlockBlobOptions{ + _, err = azblob.UploadBufferToBlockBlob(ctx, r.marshal(req), blobURL, azblob.UploadToBlockBlobOptions{ Parallelism: 16, Metadata: req.Metadata, AccessConditions: accessConditions, @@ -307,7 +307,7 @@ func (r *StateStore) createBlobHTTPHeadersFromRequest(req *state.SetRequest) (az return blobHTTPHeaders, nil } -func (r *StateStore) deleteFile(req *state.DeleteRequest) error { +func (r *StateStore) deleteFile(ctx context.Context, req *state.DeleteRequest) error { blobURL := r.containerURL.NewBlockBlobURL(getFileName(req.Key)) accessConditions := azblob.BlobAccessConditions{} @@ -319,7 +319,7 @@ func (r *StateStore) deleteFile(req *state.DeleteRequest) error { accessConditions.IfMatch = azblob.ETag(etag) } - _, err := blobURL.Delete(context.Background(), azblob.DeleteSnapshotsOptionNone, accessConditions) + _, err := blobURL.Delete(ctx, azblob.DeleteSnapshotsOptionNone, accessConditions) if err != nil { r.logger.Debugf("delete file %s, err %s", req.Key, err) diff --git a/state/azure/cosmosdb/cosmosdb.go b/state/azure/cosmosdb/cosmosdb.go index d74c3ecaef..1a6c47de3c 100644 --- a/state/azure/cosmosdb/cosmosdb.go +++ b/state/azure/cosmosdb/cosmosdb.go @@ -15,6 +15,7 @@ package cosmosdb import ( // For go:embed. + "context" _ "embed" "encoding/base64" "encoding/json" @@ -212,7 +213,7 @@ func (c *StateStore) Features() []state.Feature { } // Get retrieves a CosmosDB item. -func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { key := req.Key partitionKey := populatePartitionMetadata(req.Key, req.Metadata) @@ -269,7 +270,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves a CosmosDB item. -func (c *StateStore) Set(req *state.SetRequest) error { +func (c *StateStore) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -321,7 +322,7 @@ func (c *StateStore) Set(req *state.SetRequest) error { } // Delete performs a delete operation. -func (c *StateStore) Delete(req *state.DeleteRequest) error { +func (c *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -463,7 +464,7 @@ func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error }, nil } -func (c *StateStore) Ping() error { +func (c *StateStore) Ping(ctx context.Context) error { return retryOperation(func() error { _, innerErr := c.findCollection() if innerErr != nil { diff --git a/state/azure/tablestorage/tablestorage.go b/state/azure/tablestorage/tablestorage.go index c42fdbff4e..e65a848bbc 100644 --- a/state/azure/tablestorage/tablestorage.go +++ b/state/azure/tablestorage/tablestorage.go @@ -38,6 +38,7 @@ Concurrency is supported with ETags according to https://docs.microsoft.com/en-u package tablestorage import ( + "context" "fmt" "strings" @@ -110,7 +111,7 @@ func (r *StateStore) Features() []state.Feature { return r.features } -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("delete %s", req.Key) err := r.deleteRow(req) @@ -126,7 +127,7 @@ func (r *StateStore) Delete(req *state.DeleteRequest) error { return err } -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("fetching %s", req.Key) pk, rk := getPartitionAndRowKey(req.Key) entity := r.table.GetEntityReference(pk, rk) @@ -147,7 +148,7 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, err } -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s", req.Key) err := r.writeRow(req) @@ -275,7 +276,7 @@ func (r *StateStore) deleteRow(req *state.DeleteRequest) error { return entity.Delete(true, nil) } -func (r *StateStore) Ping() error { +func (r *StateStore) Ping(ctx context.Context) error { return nil } diff --git a/state/cassandra/cassandra.go b/state/cassandra/cassandra.go index 0b9c3d81b9..e1b51b18d1 100644 --- a/state/cassandra/cassandra.go +++ b/state/cassandra/cassandra.go @@ -14,6 +14,7 @@ limitations under the License. package cassandra import ( + "context" "errors" "fmt" "strconv" @@ -230,12 +231,12 @@ func getCassandraMetadata(metadata state.Metadata) (*cassandraMetadata, error) { } // Delete performs a delete operation. -func (c *Cassandra) Delete(req *state.DeleteRequest) error { +func (c *Cassandra) Delete(ctx context.Context, req *state.DeleteRequest) error { return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).Exec() } // Get retrieves state from cassandra with a key. -func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *Cassandra) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { session := c.session if req.Options.Consistency == state.Strong { @@ -269,7 +270,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves state into cassandra. -func (c *Cassandra) Set(req *state.SetRequest) error { +func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error { var bt []byte b, ok := req.Value.([]byte) if ok { @@ -308,7 +309,7 @@ func (c *Cassandra) Set(req *state.SetRequest) error { return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).Exec() } -func (c *Cassandra) Ping() error { +func (c *Cassandra) Ping(ctx context.Context) error { return nil } diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index 189a8e17da..f0acdc0dc0 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -14,6 +14,8 @@ limitations under the License. package cockroachdb import ( + "context" + "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) @@ -53,44 +55,44 @@ func (c *CockroachDB) Features() []state.Feature { } // Delete removes an entity from the store. -func (c *CockroachDB) Delete(req *state.DeleteRequest) error { - return c.dbaccess.Delete(req) +func (c *CockroachDB) Delete(ctx context.Context, req *state.DeleteRequest) error { + return c.dbaccess.Delete(ctx, req) } // Get returns an entity from store. -func (c *CockroachDB) Get(req *state.GetRequest) (*state.GetResponse, error) { - return c.dbaccess.Get(req) +func (c *CockroachDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return c.dbaccess.Get(ctx, req) } // Set adds/updates an entity on store. -func (c *CockroachDB) Set(req *state.SetRequest) error { - return c.dbaccess.Set(req) +func (c *CockroachDB) Set(ctx context.Context, req *state.SetRequest) error { + return c.dbaccess.Set(ctx, req) } // Ping checks if database is available. -func (c *CockroachDB) Ping() error { - return c.dbaccess.Ping() +func (c *CockroachDB) Ping(ctx context.Context) error { + return c.dbaccess.Ping(ctx) } // BulkDelete removes multiple entries from the store. -func (c *CockroachDB) BulkDelete(req []state.DeleteRequest) error { - return c.dbaccess.BulkDelete(req) +func (c *CockroachDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return c.dbaccess.BulkDelete(ctx, req) } // BulkGet performs a bulks get operations. -func (c *CockroachDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (c *CockroachDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance. return false, nil, nil } // BulkSet adds/updates multiple entities on store. -func (c *CockroachDB) BulkSet(req []state.SetRequest) error { - return c.dbaccess.BulkSet(req) +func (c *CockroachDB) BulkSet(ctx context.Context, req []state.SetRequest) error { + return c.dbaccess.BulkSet(ctx, req) } // Multi handles multiple transactions. Implements TransactionalStore. -func (c *CockroachDB) Multi(request *state.TransactionalStateRequest) error { - return c.dbaccess.ExecuteMulti(request) +func (c *CockroachDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { + return c.dbaccess.ExecuteMulti(ctx, request) } // Query executes a query against store. diff --git a/state/cockroachdb/cockroachdb_access.go b/state/cockroachdb/cockroachdb_access.go index 36d1e0f9fc..9943ad2b78 100644 --- a/state/cockroachdb/cockroachdb_access.go +++ b/state/cockroachdb/cockroachdb_access.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/agrea/ptr" + "golang.org/x/net/context" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state/query" @@ -95,12 +96,12 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (p *cockroachDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) +func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(p.setValue, ctx, req) } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { +func (p *cockroachDBAccess) setValue(ctx context.Context, req *state.SetRequest) error { p.logger.Debug("Setting state value in CockroachDB") value, isBinary, err := validateAndReturnValue(req) @@ -148,7 +149,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { return nil } -func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { +func (p *cockroachDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error { p.logger.Debug("Executing BulkSet request") tx, err := p.db.Begin() if err != nil { @@ -158,7 +159,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for _, s := range req { sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Set(&sa) + err = p.Set(ctx, &sa) if err != nil { tx.Rollback() @@ -173,7 +174,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { p.logger.Debug("Getting state value from CockroachDB") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -221,12 +222,12 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro } // Delete removes an item from the state store. -func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) +func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(p.deleteValue, ctx, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { +func (p *cockroachDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { p.logger.Debug("Deleting state value from CockroachDB") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -264,7 +265,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { +func (p *cockroachDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { p.logger.Debug("Executing BulkDelete request") tx, err := p.db.Begin() if err != nil { @@ -274,7 +275,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Delete(&da) + err = p.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -288,7 +289,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { return err } -func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { +func (p *cockroachDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error { p.logger.Debug("Executing PostgreSQL transaction") tx, err := p.db.Begin() @@ -307,7 +308,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques return err } - err = p.Set(&setReq) + err = p.Set(ctx, &setReq) if err != nil { tx.Rollback() return err @@ -322,7 +323,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques return err } - err = p.Delete(&delReq) + err = p.Delete(ctx, &delReq) if err != nil { tx.Rollback() return err @@ -377,7 +378,7 @@ func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse } // Ping implements database ping. -func (p *cockroachDBAccess) Ping() error { +func (p *cockroachDBAccess) Ping(ctx context.Context) error { return p.db.Ping() } diff --git a/state/cockroachdb/cockroachdb_access_test.go b/state/cockroachdb/cockroachdb_access_test.go index e59fb8c861..c6eb831073 100644 --- a/state/cockroachdb/cockroachdb_access_test.go +++ b/state/cockroachdb/cockroachdb_access_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "testing" @@ -109,7 +110,7 @@ func TestMultiWithNoRequests(t *testing.T) { var operations []state.TransactionalStateOperation // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -133,7 +134,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -158,7 +159,7 @@ func TestValidSetRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -182,7 +183,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -206,7 +207,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -231,7 +232,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -255,7 +256,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -279,7 +280,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -311,7 +312,7 @@ func TestMultiOperationOrder(t *testing.T) { ) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -334,7 +335,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -356,7 +357,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -379,7 +380,7 @@ func TestValidBulkSet(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.Nil(t, err) @@ -400,7 +401,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) { }) // Act - err := m.roachDba.BulkDelete(deletes) + err := m.roachDba.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err) @@ -422,7 +423,7 @@ func TestValidBulkDelete(t *testing.T) { }) // Act - err := m.roachDba.BulkDelete(deletes) + err := m.roachDba.BulkDelete(context.TODO(), deletes) // Assert assert.Nil(t, err) diff --git a/state/cockroachdb/cockroachdb_integration_test.go b/state/cockroachdb/cockroachdb_integration_test.go index b0cd2fa69b..49513d149a 100644 --- a/state/cockroachdb/cockroachdb_integration_test.go +++ b/state/cockroachdb/cockroachdb_integration_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "encoding/json" "fmt" @@ -210,7 +211,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -238,7 +239,7 @@ func multiWithSetOnly(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -279,7 +280,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -340,7 +341,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -375,7 +376,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -391,7 +392,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -414,7 +415,7 @@ func newItemWithEtagFails(t *testing.T, pgs *CockroachDB) { ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -448,7 +449,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *CockroachDB) { }, ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -500,7 +501,7 @@ func getItemWithNoKey(t *testing.T, pgs *CockroachDB) { }, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -543,7 +544,7 @@ func setItemWithNoKey(t *testing.T, pgs *CockroachDB) { ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -562,7 +563,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) { }, } - err := pgs.BulkSet(setReq) + err := pgs.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -576,7 +577,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) { }, } - err = pgs.BulkDelete(deleteReq) + err = pgs.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -643,7 +644,7 @@ func setItem(t *testing.T, pgs *CockroachDB, key string, value interface{}, etag ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists) @@ -660,7 +661,7 @@ func getItem(t *testing.T, pgs *CockroachDB, key string) (*state.GetResponse, *f Metadata: map[string]string{}, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{ @@ -684,7 +685,7 @@ func deleteItem(t *testing.T, pgs *CockroachDB, key string, etag *string) { Metadata: map[string]string{}, } - deleteErr := pgs.Delete(deleteReq) + deleteErr := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key)) } diff --git a/state/cockroachdb/cockroachdb_test.go b/state/cockroachdb/cockroachdb_test.go index 173cae166f..64006f5c5d 100644 --- a/state/cockroachdb/cockroachdb_test.go +++ b/state/cockroachdb/cockroachdb_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -41,33 +42,33 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { m.deleteExecuted = true return nil } -func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { +func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error { return nil } -func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { +func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error { return nil } @@ -79,7 +80,7 @@ func (m *fakeDBaccess) Close() error { return nil } -func (m *fakeDBaccess) Ping() error { +func (m *fakeDBaccess) Ping(ctx context.Context) error { return nil } diff --git a/state/cockroachdb/dbaccess.go b/state/cockroachdb/dbaccess.go index d89af8a545..129ea32946 100644 --- a/state/cockroachdb/dbaccess.go +++ b/state/cockroachdb/dbaccess.go @@ -13,18 +13,22 @@ limitations under the License. package cockroachdb -import "github.com/dapr/components-contrib/state" +import ( + "context" + + "github.com/dapr/components-contrib/state" +) // dbAccess is a private interface which enables unit testing of CockroachDB. type dbAccess interface { Init(metadata state.Metadata) error - Set(req *state.SetRequest) error - BulkSet(req []state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - BulkDelete(req []state.DeleteRequest) error - ExecuteMulti(req *state.TransactionalStateRequest) error + Set(ctx context.Context, req *state.SetRequest) error + BulkSet(ctx context.Context, req []state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + BulkDelete(ctx context.Context, req []state.DeleteRequest) error + ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error Query(req *state.QueryRequest) (*state.QueryResponse, error) - Ping() error + Ping(ctx context.Context) error Close() error } diff --git a/state/couchbase/couchbase.go b/state/couchbase/couchbase.go index 86d72d2b8e..89bf64b00a 100644 --- a/state/couchbase/couchbase.go +++ b/state/couchbase/couchbase.go @@ -14,6 +14,7 @@ limitations under the License. package couchbase import ( + "context" "errors" "fmt" "strconv" @@ -144,7 +145,7 @@ func (cbs *Couchbase) Features() []state.Feature { } // Set stores value for a key to couchbase. It honors ETag (for concurrency) and consistency settings. -func (cbs *Couchbase) Set(req *state.SetRequest) error { +func (cbs *Couchbase) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -188,7 +189,7 @@ func (cbs *Couchbase) Set(req *state.SetRequest) error { } // Get retrieves state from couchbase with a key. -func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (cbs *Couchbase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { var data interface{} cas, err := cbs.bucket.Get(req.Key, &data) if err != nil { @@ -206,7 +207,7 @@ func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Delete performs a delete operation. -func (cbs *Couchbase) Delete(req *state.DeleteRequest) error { +func (cbs *Couchbase) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -236,7 +237,7 @@ func (cbs *Couchbase) Delete(req *state.DeleteRequest) error { return nil } -func (cbs *Couchbase) Ping() error { +func (cbs *Couchbase) Ping(ctx context.Context) error { return nil } diff --git a/state/gcp/firestore/firestore.go b/state/gcp/firestore/firestore.go index ae5041ae23..79c5bb583c 100644 --- a/state/gcp/firestore/firestore.go +++ b/state/gcp/firestore/firestore.go @@ -93,7 +93,7 @@ func (f *Firestore) Features() []state.Feature { } // Get retrieves state from Firestore with a key (Always strong consistency). -func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (f *Firestore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { key := req.Key entityKey := datastore.NameKey(f.entityKind, key, nil) @@ -111,7 +111,7 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (f *Firestore) setValue(req *state.SetRequest) error { +func (f *Firestore) setValue(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -128,7 +128,6 @@ func (f *Firestore) setValue(req *state.SetRequest) error { entity := &StateEntity{ Value: v, } - ctx := context.Background() key := datastore.NameKey(f.entityKind, req.Key, nil) _, err = f.client.Put(ctx, key, entity) @@ -141,16 +140,15 @@ func (f *Firestore) setValue(req *state.SetRequest) error { } // Set saves state into Firestore with retry. -func (f *Firestore) Set(req *state.SetRequest) error { - return state.SetWithOptions(f.setValue, req) +func (f *Firestore) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(f.setValue, ctx, req) } -func (f *Firestore) Ping() error { +func (f *Firestore) Ping(ctx context.Context) error { return nil } -func (f *Firestore) deleteValue(req *state.DeleteRequest) error { - ctx := context.Background() +func (f *Firestore) deleteValue(ctx context.Context, req *state.DeleteRequest) error { key := datastore.NameKey(f.entityKind, req.Key, nil) err := f.client.Delete(ctx, key) @@ -162,8 +160,8 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error { } // Delete performs a delete operation. -func (f *Firestore) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(f.deleteValue, req) +func (f *Firestore) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(f.deleteValue, ctx, req) } func getFirestoreMetadata(metadata state.Metadata) (*firestoreMetadata, error) { diff --git a/state/hashicorp/consul/consul.go b/state/hashicorp/consul/consul.go index 464d4206f9..6458e195a1 100644 --- a/state/hashicorp/consul/consul.go +++ b/state/hashicorp/consul/consul.go @@ -14,6 +14,7 @@ limitations under the License. package consul import ( + "context" "encoding/json" "fmt" @@ -102,7 +103,7 @@ func metadataToConfig(connInfo map[string]string) (*consulConfig, error) { } // Get retrieves a Consul KV item. -func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *Consul) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { queryOpts := &api.QueryOptions{} if req.Options.Consistency == state.Strong { queryOpts.RequireConsistent = true @@ -124,7 +125,7 @@ func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves a Consul KV item. -func (c *Consul) Set(req *state.SetRequest) error { +func (c *Consul) Set(ctx context.Context, req *state.SetRequest) error { var reqValByte []byte b, ok := req.Value.([]byte) if ok { @@ -146,12 +147,12 @@ func (c *Consul) Set(req *state.SetRequest) error { return nil } -func (c *Consul) Ping() error { +func (c *Consul) Ping(ctx context.Context) error { return nil } // Delete performes a Consul KV delete operation. -func (c *Consul) Delete(req *state.DeleteRequest) error { +func (c *Consul) Delete(ctx context.Context, req *state.DeleteRequest) error { keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key) _, err := c.client.KV().Delete(keyWithPath, nil) if err != nil { diff --git a/state/hazelcast/hazelcast.go b/state/hazelcast/hazelcast.go index 9ccd16c3e1..21776f6013 100644 --- a/state/hazelcast/hazelcast.go +++ b/state/hazelcast/hazelcast.go @@ -14,6 +14,7 @@ limitations under the License. package hazelcast import ( + "context" "errors" "fmt" "strings" @@ -91,7 +92,7 @@ func (store *Hazelcast) Features() []state.Feature { } // Set stores value for a key to Hazelcast. -func (store *Hazelcast) Set(req *state.SetRequest) error { +func (store *Hazelcast) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req) if err != nil { return err @@ -117,7 +118,7 @@ func (store *Hazelcast) Set(req *state.SetRequest) error { } // Get retrieves state from Hazelcast with a key. -func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (store *Hazelcast) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { resp, err := store.hzMap.Get(req.Key) if err != nil { return nil, fmt.Errorf("hazelcast error: failed to get value for %s: %s", req.Key, err) @@ -137,12 +138,12 @@ func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (store *Hazelcast) Ping() error { +func (store *Hazelcast) Ping(ctx context.Context) error { return nil } // Delete performs a delete operation. -func (store *Hazelcast) Delete(req *state.DeleteRequest) error { +func (store *Hazelcast) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err diff --git a/state/jetstream/jetstream.go b/state/jetstream/jetstream.go index 2ea83a1048..e995418d7d 100644 --- a/state/jetstream/jetstream.go +++ b/state/jetstream/jetstream.go @@ -14,6 +14,7 @@ limitations under the License. package jetstream import ( + "context" "fmt" "strings" @@ -92,7 +93,7 @@ func (js *StateStore) Init(metadata state.Metadata) error { return nil } -func (js *StateStore) Ping() error { +func (js *StateStore) Ping(ctx context.Context) error { return nil } @@ -101,7 +102,7 @@ func (js *StateStore) Features() []state.Feature { } // Get retrieves state with a key. -func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (js *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { entry, err := js.bucket.Get(escape(req.Key)) if err != nil { return nil, err @@ -113,14 +114,14 @@ func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set stores value for a key. -func (js *StateStore) Set(req *state.SetRequest) error { +func (js *StateStore) Set(ctx context.Context, req *state.SetRequest) error { bt, _ := utils.Marshal(req.Value, js.json.Marshal) _, err := js.bucket.Put(escape(req.Key), bt) return err } // Delete performs a delete operation. -func (js *StateStore) Delete(req *state.DeleteRequest) error { +func (js *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { return js.bucket.Delete(escape(req.Key)) } diff --git a/state/jetstream/jetstream_test.go b/state/jetstream/jetstream_test.go index 23a0bb7cc9..7f9b699940 100644 --- a/state/jetstream/jetstream_test.go +++ b/state/jetstream/jetstream_test.go @@ -14,6 +14,7 @@ limitations under the License. package jetstream import ( + "context" "encoding/json" "fmt" "reflect" @@ -110,8 +111,8 @@ func TestSetGetAndDelete(t *testing.T) { tData := map[string]string{ "dkey": "dvalue", } - - err = store.Set(&state.SetRequest{ + ctx := context.TODO() + err = store.Set(ctx, &state.SetRequest{ Key: tkey, Value: tData, }) @@ -120,7 +121,7 @@ func TestSetGetAndDelete(t *testing.T) { return } - resp, err := store.Get(&state.GetRequest{ + resp, err := store.Get(ctx, &state.GetRequest{ Key: tkey, }) if err != nil { @@ -133,7 +134,7 @@ func TestSetGetAndDelete(t *testing.T) { t.Fatal("Response data does not match written data\n") } - err = store.Delete(&state.DeleteRequest{ + err = store.Delete(ctx, &state.DeleteRequest{ Key: tkey, }) if err != nil { @@ -141,7 +142,7 @@ func TestSetGetAndDelete(t *testing.T) { return } - _, err = store.Get(&state.GetRequest{ + _, err = store.Get(ctx, &state.GetRequest{ Key: tkey, }) if err == nil { diff --git a/state/memcached/memcached.go b/state/memcached/memcached.go index fc33c70002..ef521eb914 100644 --- a/state/memcached/memcached.go +++ b/state/memcached/memcached.go @@ -14,6 +14,7 @@ limitations under the License. package memcached import ( + "context" "errors" "fmt" "strconv" @@ -131,7 +132,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) { return nil, nil } -func (m *Memcached) setValue(req *state.SetRequest) error { +func (m *Memcached) setValue(ctx context.Context, req *state.SetRequest) error { var bt []byte ttl, err := m.parseTTL(req) if err != nil { @@ -151,7 +152,7 @@ func (m *Memcached) setValue(req *state.SetRequest) error { return nil } -func (m *Memcached) Delete(req *state.DeleteRequest) error { +func (m *Memcached) Delete(ctx context.Context, req *state.DeleteRequest) error { err := m.client.Delete(req.Key) if err != nil { return err @@ -160,7 +161,7 @@ func (m *Memcached) Delete(req *state.DeleteRequest) error { return nil } -func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *Memcached) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { item, err := m.client.Get(req.Key) if err != nil { // Return nil for status 204 @@ -176,10 +177,10 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (m *Memcached) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) +func (m *Memcached) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(m.setValue, ctx, req) } -func (m *Memcached) Ping() error { +func (m *Memcached) Ping(ctx context.Context) error { return nil } diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 9774614bf9..fbd8ea5df5 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -156,8 +156,8 @@ func (m *MongoDB) Features() []state.Feature { } // Set saves state into MongoDB. -func (m *MongoDB) Set(req *state.SetRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) +func (m *MongoDB) Set(ctx context.Context, req *state.SetRequest) error { + ctx, cancel := context.WithTimeout(ctx, m.operationTimeout) defer cancel() err := m.setInternal(ctx, req) @@ -168,8 +168,8 @@ func (m *MongoDB) Set(req *state.SetRequest) error { return nil } -func (m *MongoDB) Ping() error { - if err := m.client.Ping(context.Background(), nil); err != nil { +func (m *MongoDB) Ping(ctx context.Context) error { + if err := m.client.Ping(ctx, nil); err != nil { return fmt.Errorf("mongoDB store: error connecting to mongoDB at %s: %s", m.metadata.host, err) } @@ -202,10 +202,10 @@ func (m *MongoDB) setInternal(ctx context.Context, req *state.SetRequest) error } // Get retrieves state from MongoDB with a key. -func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *MongoDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { var result Item - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) + ctx, cancel := context.WithTimeout(ctx, m.operationTimeout) defer cancel() filter := bson.M{id: req.Key} @@ -261,8 +261,8 @@ func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Delete performs a delete operation. -func (m *MongoDB) Delete(req *state.DeleteRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) +func (m *MongoDB) Delete(ctx context.Context, req *state.DeleteRequest) error { + ctx, cancel := context.WithTimeout(ctx, m.operationTimeout) defer cancel() err := m.deleteInternal(ctx, req) diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 408306f913..aef51d9bf9 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -14,6 +14,7 @@ limitations under the License. package mysql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -158,7 +159,7 @@ func (m *MySQL) Init(metadata state.Metadata) error { return m.finishInit(db, err) } -func (m *MySQL) Ping() error { +func (m *MySQL) Ping(ctx context.Context) error { return nil } @@ -296,13 +297,13 @@ func tableExists(db *sql.DB, tableName string) (bool, error) { // Delete removes an entity from the store // Store Interface. -func (m *MySQL) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(m.deleteValue, req) +func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(m.deleteValue, ctx, req) } // deleteValue is an internal implementation of delete to enable passing the // logic to state.DeleteWithRetries as a func. -func (m *MySQL) deleteValue(req *state.DeleteRequest) error { +func (m *MySQL) deleteValue(ctx context.Context, req *state.DeleteRequest) error { m.logger.Debug("Deleting state value from MySql") if req.Key == "" { @@ -340,7 +341,7 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error { // BulkDelete removes multiple entries from the store // Store Interface. -func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { +func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { m.logger.Debug("Executing BulkDelete request") tx, err := m.db.Begin() @@ -351,7 +352,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Delete(&da) + err = m.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -367,7 +368,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { // Get returns an entity from store // Store Interface. -func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *MySQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.logger.Debug("Getting state value from MySql") if req.Key == "" { @@ -417,13 +418,13 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { // Set adds/updates an entity on store // Store Interface. -func (m *MySQL) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) +func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(m.setValue, ctx, req) } // setValue is an internal implementation of set to enable passing the logic // to state.SetWithRetries as a func. -func (m *MySQL) setValue(req *state.SetRequest) error { +func (m *MySQL) setValue(ctx context.Context, req *state.SetRequest) error { m.logger.Debug("Setting state value in MySql") err := state.CheckRequestOptions(req.Options) @@ -502,7 +503,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error { // BulkSet adds/updates multiple entities on store // Store Interface. -func (m *MySQL) BulkSet(req []state.SetRequest) error { +func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error { m.logger.Debug("Executing BulkSet request") tx, err := m.db.Begin() @@ -513,7 +514,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for _, s := range req { sa := s // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Set(&sa) + err = m.Set(ctx, &sa) if err != nil { tx.Rollback() @@ -529,7 +530,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { // Multi handles multiple transactions. // TransactionalStore Interface. -func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { +func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { m.logger.Debug("Executing Multi request") tx, err := m.db.Begin() @@ -546,7 +547,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { return err } - err = m.Set(&setReq) + err = m.Set(ctx, &setReq) if err != nil { tx.Rollback() return err @@ -559,7 +560,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { return err } - err = m.Delete(&delReq) + err = m.Delete(ctx, &delReq) if err != nil { tx.Rollback() return err @@ -602,7 +603,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR } // BulkGet performs a bulks get operations. -func (m *MySQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (m *MySQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // by default, the store doesn't support bulk get // return false so daprd will fallback to call get() method one by one return false, nil, nil diff --git a/state/mysql/mysql_integration_test.go b/state/mysql/mysql_integration_test.go index d5ee4cbd57..50ed4d9d3a 100644 --- a/state/mysql/mysql_integration_test.go +++ b/state/mysql/mysql_integration_test.go @@ -13,6 +13,7 @@ limitations under the License. package mysql import ( + "context" "crypto/tls" "crypto/x509" "database/sql" @@ -174,7 +175,7 @@ func multiWithSetOnly(t *testing.T, mys *MySQL) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -204,7 +205,7 @@ func multiWithDeleteOnly(t *testing.T, mys *MySQL) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -247,7 +248,7 @@ func multiWithDeleteAndSet(t *testing.T, mys *MySQL) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -268,7 +269,7 @@ func deleteItemThatDoesNotExist(t *testing.T, mys *MySQL) { Key: randomKey(), } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -277,7 +278,7 @@ func deleteWithNoKeyFails(t *testing.T, mys *MySQL) { Key: "", } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -295,7 +296,7 @@ func deleteWithInvalidEtagFails(t *testing.T, mys *MySQL) { ETag: &eTag, } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -311,7 +312,7 @@ func newItemWithEtagFails(t *testing.T, mys *MySQL) { Value: value, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -340,7 +341,7 @@ func updateWithOldETagFails(t *testing.T, mys *MySQL) { Value: newValue, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err, "Error was not thrown using old eTag") } @@ -379,7 +380,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { }, } - err := mys.BulkSet(setReq) + err := mys.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -393,7 +394,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { }, } - err = mys.BulkDelete(deleteReq) + err = mys.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -404,7 +405,7 @@ func setItemWithNoKey(t *testing.T, mys *MySQL) { Key: "", } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err, "Error was not nil when setting item with no key.") } @@ -438,7 +439,7 @@ func getItemWithNoKey(t *testing.T, mys *MySQL) { Key: "", } - response, getErr := mys.Get(getReq) + response, getErr := mys.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -566,7 +567,7 @@ func setItem(t *testing.T, mys *MySQL, key string, value interface{}, eTag *stri Value: value, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.Nil(t, err, "Error setting an item") itemExists := storeItemExists(t, key) assert.True(t, itemExists, "Item does not exist after being set") @@ -578,7 +579,7 @@ func getItem(t *testing.T, mys *MySQL, key string) (*state.GetResponse, *fakeIte Options: state.GetStateOption{}, } - response, getErr := mys.Get(getReq) + response, getErr := mys.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -594,7 +595,7 @@ func deleteItem(t *testing.T, mys *MySQL, key string, eTag *string) { Options: state.DeleteStateOption{}, } - deleteErr := mys.Delete(deleteReq) + deleteErr := mys.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr, "There was an error deleting a record") assert.False(t, storeItemExists(t, key), "Item still exists after delete") } diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 5ba5ca0e1f..38a2bf6d7c 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -13,6 +13,7 @@ limitations under the License. package mysql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -171,7 +172,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) { m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError")) // Act - err := m.mySQL.Multi(nil) + err := m.mySQL.Multi(context.TODO(), nil) // Assert assert.NotNil(t, err, "no error returned") @@ -190,7 +191,7 @@ func TestMySQLBulkDeleteRollbackDeletes(t *testing.T) { deletes := []state.DeleteRequest{createDeleteRequest()} // Act - err := m.mySQL.BulkDelete(deletes) + err := m.mySQL.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err, "no error returned") @@ -209,7 +210,7 @@ func TestMySQLBulkSetRollbackSets(t *testing.T) { sets := []state.SetRequest{createSetRequest()} // Act - err := m.mySQL.BulkSet(sets) + err := m.mySQL.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err, "no error returned") @@ -242,7 +243,7 @@ func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) { } // Act - err := m.mySQL.Multi(&request) + err := m.mySQL.Multi(context.TODO(), &request) // Assert assert.Nil(t, err, "error returned") @@ -258,7 +259,7 @@ func TestSetHandlesOptionsError(t *testing.T) { request.Options.Consistency = "Invalid" // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -273,7 +274,7 @@ func TestSetHandlesNoKey(t *testing.T) { request.Key = "" // Act - err := m.mySQL.Set(&request) + err := m.mySQL.Set(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -293,7 +294,7 @@ func TestSetHandlesUpdate(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -312,7 +313,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -325,7 +326,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -337,7 +338,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -348,7 +349,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -362,7 +363,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -379,7 +380,7 @@ func TestMySQLDeleteHandlesNoKey(t *testing.T) { request.Key = "" // Act - err := m.mySQL.Delete(&request) + err := m.mySQL.Delete(context.TODO(), &request) // Asset assert.NotNil(t, err) @@ -398,7 +399,7 @@ func TestDeleteWithETag(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -415,7 +416,7 @@ func TestDeleteWithErr(t *testing.T) { request := createDeleteRequest() // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -430,7 +431,7 @@ func TestDeleteWithErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -451,7 +452,7 @@ func TestGetHandlesNoRows(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err, "returned error") @@ -468,7 +469,7 @@ func TestGetHandlesNoKey(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.NotNil(t, err, "returned error") @@ -488,7 +489,7 @@ func TestGetHandlesGenericError(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.NotNil(t, err) @@ -509,7 +510,7 @@ func TestGetSucceeds(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err) @@ -527,7 +528,7 @@ func TestGetSucceeds(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err) @@ -689,7 +690,7 @@ func TestBulkGetReturnsNil(t *testing.T) { m, _ := mockDatabase(t) // Act - supported, response, err := m.mySQL.BulkGet(nil) + supported, response, err := m.mySQL.BulkGet(context.TODO(), nil) // Assert assert.Nil(t, err, `returned err`) @@ -708,7 +709,7 @@ func TestMultiWithNoRequestsDoesNothing(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -728,7 +729,7 @@ func TestInvalidMultiAction(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -767,7 +768,7 @@ func TestValidSetRequest(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -788,7 +789,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -812,7 +813,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -836,7 +837,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -857,7 +858,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -880,7 +881,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -919,7 +920,7 @@ func TestMultiOperationOrder(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) diff --git a/state/oci/objectstorage/objectstorage.go b/state/oci/objectstorage/objectstorage.go index 2da43caab6..c41e4d800e 100644 --- a/state/oci/objectstorage/objectstorage.go +++ b/state/oci/objectstorage/objectstorage.go @@ -126,15 +126,15 @@ func (r *StateStore) Features() []state.Feature { return r.features } -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("Delete entry from OCI Object Storage State Store with key ", req.Key) - err := r.deleteDocument(req) + err := r.deleteDocument(ctx, req) return err } -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("Get from OCI Object Storage State Store with key ", req.Key) - content, etag, err := r.readDocument((req)) + content, etag, err := r.readDocument(ctx, req) if err != nil { r.logger.Debugf("error %s", err) if err.Error() == "ObjectNotFound" { @@ -150,12 +150,12 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, err } -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s to OCI Object Storage State Store", req.Key) - return r.writeDocument(req) + return r.writeDocument(ctx, req) } -func (r *StateStore) Ping() error { +func (r *StateStore) Ping(ctx context.Context) error { return r.pingBucket() } @@ -262,7 +262,7 @@ func getIdentityAuthenticationDetails(metadata map[string]string, meta *Metadata } // functions that bridge from the Dapr State API to the OCI ObjectStorage Client. -func (r *StateStore) writeDocument(req *state.SetRequest) error { +func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) error { if len(req.Key) == 0 || req.Key == "" { return fmt.Errorf("key for value to set was missing from request") } @@ -281,7 +281,6 @@ func (r *StateStore) writeDocument(req *state.SetRequest) error { objectName := getFileName(req.Key) content := r.marshal(req) objectLength := int64(len(content)) - ctx := context.Background() etag := req.ETag if req.Options.Concurrency != state.FirstWrite { etag = nil @@ -310,12 +309,11 @@ func convertTTLtoExpiryTime(req *state.SetRequest, logger logger.Logger, metadat return nil } -func (r *StateStore) readDocument(req *state.GetRequest) ([]byte, *string, error) { +func (r *StateStore) readDocument(ctx context.Context, req *state.GetRequest) ([]byte, *string, error) { if len(req.Key) == 0 || req.Key == "" { return nil, nil, fmt.Errorf("key for value to get was missing from request") } objectName := getFileName(req.Key) - ctx := context.Background() content, etag, meta, err := r.client.getObject(ctx, objectName, r.logger) if err != nil { r.logger.Debugf("download file %s, err %s", req.Key, err) @@ -343,13 +341,12 @@ func (r *StateStore) pingBucket() error { return nil } -func (r *StateStore) deleteDocument(req *state.DeleteRequest) error { +func (r *StateStore) deleteDocument(ctx context.Context, req *state.DeleteRequest) error { if len(req.Key) == 0 || req.Key == "" { return fmt.Errorf("key for value to delete was missing from request") } objectName := getFileName(req.Key) - ctx := context.Background() etag := req.ETag if req.Options.Concurrency != state.FirstWrite { etag = nil diff --git a/state/oci/objectstorage/objectstorage_integration_test.go b/state/oci/objectstorage/objectstorage_integration_test.go index b2cac9f349..34177119e2 100644 --- a/state/oci/objectstorage/objectstorage_integration_test.go +++ b/state/oci/objectstorage/objectstorage_integration_test.go @@ -4,6 +4,7 @@ package objectstorage // go test -v github.com/dapr/components-contrib/state/oci/objectstorage. import ( + "context" "fmt" "os" "testing" @@ -84,20 +85,20 @@ func testGet(t *testing.T, ociProperties map[string]string) { statestore := NewOCIObjectStorageStore(logger.NewLogger("logger")) meta := state.Metadata{} meta.Properties = ociProperties - + ctx := context.TODO() t.Run("Get an non-existing key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "xyzq"}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: "xyzq"}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Non-existing key must not be treated as error") }) t.Run("Get an existing key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "test-key", Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Key: "test-key", Value: []byte("test-value")}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "test-key"}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: "test-key"}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") @@ -105,9 +106,9 @@ func testGet(t *testing.T, ociProperties map[string]string) { t.Run("Get an existing composed key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "test-app||test-key"}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: "test-app||test-key"}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") }) @@ -115,11 +116,11 @@ func testGet(t *testing.T, ociProperties map[string]string) { testKey := "unexpired-ttl-test-key" err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "100", })}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set despite TTL setting") }) @@ -127,23 +128,23 @@ func testGet(t *testing.T, ociProperties map[string]string) { testKey := "never-expiring-ttl-test-key" err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "-1", })}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)") }) t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "1", })}) assert.Nil(t, err) time.Sleep(time.Second * 2) - getResponse, err := statestore.Get(&state.GetRequest{Key: "ttl-test-key"}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: "ttl-test-key"}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Expired element must not be treated as error") }) @@ -153,10 +154,11 @@ func testSet(t *testing.T, ociProperties map[string]string) { meta := state.Metadata{} meta.Properties = ociProperties statestore := NewOCIObjectStorageStore(logger.NewLogger("logger")) + ctx := context.TODO() t.Run("Set without a key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Value: []byte("test-value")}) assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") }) t.Run("Regular Set Operation", func(t *testing.T) { @@ -164,9 +166,9 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") @@ -176,20 +178,20 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(ctx, &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") }) t.Run("Regular Set Operation with TTL", func(t *testing.T) { testKey := "test-key-with-ttl" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err := statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "500", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "XXX", })}) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") @@ -200,25 +202,25 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, _ := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, _ := statestore.Get(ctx, &state.GetRequest{Key: testKey}) etag := getResponse.ETag - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with proper etag should go fine") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with the old etag should be refused") // retrieve the latest etag - assigned by the previous set operation. - getResponse, _ = statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, _ = statestore.Get(ctx, &state.GetRequest{Key: testKey}) assert.NotNil(t, *getResponse.ETag, "ETag should be set") etag = getResponse.ETag - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with the latest etag should be accepted") @@ -229,10 +231,11 @@ func testDelete(t *testing.T, ociProperties map[string]string) { m := state.Metadata{} m.Properties = ociProperties s := NewOCIObjectStorageStore(logger.NewLogger("logger")) + ctx := context.TODO() t.Run("Delete without a key", func(t *testing.T) { err := s.Init(m) assert.Nil(t, err) - err = s.Delete(&state.DeleteRequest{}) + err = s.Delete(ctx, &state.DeleteRequest{}) assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") }) t.Run("Regular Delete Operation", func(t *testing.T) { @@ -240,9 +243,9 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - err = s.Delete(&state.DeleteRequest{Key: testKey}) + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") }) t.Run("Regular Delete Operation for composite key", func(t *testing.T) { @@ -250,13 +253,13 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") - err = s.Delete(&state.DeleteRequest{Key: testKey}) + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper composite key should be errorfree") }) t.Run("Delete with an unknown key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) + err := s.Delete(ctx, &state.DeleteRequest{Key: "unknownKey"}) assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") }) @@ -265,18 +268,18 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) // create document. - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, _ := s.Get(&state.GetRequest{Key: testKey}) + getResponse, _ := s.Get(ctx, &state.GetRequest{Key: testKey}) etag := getResponse.ETag incorrectETag := "someRandomETag" - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{ + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Deleting value with proper etag should go fine") @@ -290,7 +293,7 @@ func testPing(t *testing.T, ociProperties map[string]string) { t.Run("Ping", func(t *testing.T) { err := s.Init(m) assert.Nil(t, err) - err = s.Ping() + err = s.Ping(context.TODO()) assert.Nil(t, err, "Ping should be successful") }) } diff --git a/state/oci/objectstorage/objectstorage_test.go b/state/oci/objectstorage/objectstorage_test.go index c9f530e3df..b06fe3ae99 100644 --- a/state/oci/objectstorage/objectstorage_test.go +++ b/state/oci/objectstorage/objectstorage_test.go @@ -235,25 +235,26 @@ func TestGetWithMockClient(t *testing.T) { mockClient := &mockedObjectStoreClient{} s.client = mockClient t.Parallel() + ctx := context.TODO() t.Run("Test regular Get", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-key"}) + getResponse, err := s.Get(ctx, &state.GetRequest{Key: "test-key"}) assert.True(t, mockClient.getIsCalled, "function Get should be invoked on the mockClient") assert.Equal(t, "Hello World", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.Nil(t, err) }) t.Run("Test Get with composite key", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-app||test-key"}) + getResponse, err := s.Get(ctx, &state.GetRequest{Key: "test-app||test-key"}) assert.Equal(t, "Hello Continent", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Nil(t, err) }) t.Run("Test Get with an unknown key", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "unknownKey"}) + getResponse, err := s.Get(ctx, &state.GetRequest{Key: "unknownKey"}) assert.Nil(t, getResponse.Data, "No value should be retrieved for an unknown key") assert.Nil(t, err, "404", "Not finding an object because of unknown key should not result in an error") }) t.Run("Test expired element (because of TTL) ", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-expired-ttl-key"}) + getResponse, err := s.Get(ctx, &state.GetRequest{Key: "test-expired-ttl-key"}) assert.Nil(t, getResponse.Data, "No value should be retrieved for an expired state element") assert.Nil(t, err, "Not returning an object because of expiration should not result in an error") }) @@ -277,7 +278,7 @@ func TestPingWithMockClient(t *testing.T) { s.client = mockClient t.Run("Test Ping", func(t *testing.T) { - err := s.Ping() + err := s.Ping(context.TODO()) assert.Nil(t, err) assert.True(t, mockClient.pingBucketIsCalled, "function pingBucket should be invoked on the mockClient") }) @@ -288,29 +289,30 @@ func TestSetWithMockClient(t *testing.T) { statestore := NewOCIObjectStorageStore(logger.NewLogger("logger")) mockClient := &mockedObjectStoreClient{} statestore.client = mockClient + ctx := context.TODO() t.Run("Set without a key", func(t *testing.T) { - err := statestore.Set(&state.SetRequest{Value: []byte("test-value")}) + err := statestore.Set(ctx, &state.SetRequest{Value: []byte("test-value")}) assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") }) t.Run("Regular Set Operation", func(t *testing.T) { testKey := "test-key" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err := statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.True(t, mockClient.putIsCalled, "function put should be invoked on the mockClient") }) t.Run("Regular Set Operation with TTL", func(t *testing.T) { testKey := "test-key" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err := statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "5", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "XXX", })}) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "1", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") @@ -320,22 +322,22 @@ func TestSetWithMockClient(t *testing.T) { incorrectETag := "notTheCorrectETag" etag := "correctETag" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{ + err := statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with wrong etag should fail") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with proper etag should go fine") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ + err = statestore.Set(ctx, &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with concurrency policy at FirstWrite should fail when ETag is missing") @@ -347,35 +349,36 @@ func TestDeleteWithMockClient(t *testing.T) { s := NewOCIObjectStorageStore(logger.NewLogger("logger")) mockClient := &mockedObjectStoreClient{} s.client = mockClient + ctx := context.TODO() t.Run("Delete without a key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{}) + err := s.Delete(ctx, &state.DeleteRequest{}) assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") }) t.Run("Delete with an unknown key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) + err := s.Delete(ctx, &state.DeleteRequest{Key: "unknownKey"}) assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") }) t.Run("Regular Delete Operation", func(t *testing.T) { testKey := "test-key" - err := s.Delete(&state.DeleteRequest{Key: testKey}) + err := s.Delete(ctx, &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") assert.True(t, mockClient.deleteIsCalled, "function delete should be invoked on the mockClient") }) t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) { testKey := "etag-test-delete-key" incorrectETag := "notTheCorrectETag" - err := s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ + err := s.Delete(ctx, &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") etag := "correctETag" - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{ + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Deleting value with proper etag should go fine") - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{ + err = s.Delete(ctx, &state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") diff --git a/state/oracledatabase/dbaccess.go b/state/oracledatabase/dbaccess.go index 60e7ee3366..174a9c9ab2 100644 --- a/state/oracledatabase/dbaccess.go +++ b/state/oracledatabase/dbaccess.go @@ -14,16 +14,18 @@ limitations under the License. package oracledatabase import ( + "context" + "github.com/dapr/components-contrib/state" ) // dbAccess is a private interface which enables unit testing of Oracle Database. type dbAccess interface { Init(metadata state.Metadata) error - Ping() error - Set(req *state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error + Ping(ctx context.Context) error + Set(ctx context.Context, req *state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error Close() error // io.Closer. } diff --git a/state/oracledatabase/oracledatabase.go b/state/oracledatabase/oracledatabase.go index e10ad684a5..7a1eed101b 100644 --- a/state/oracledatabase/oracledatabase.go +++ b/state/oracledatabase/oracledatabase.go @@ -14,6 +14,7 @@ limitations under the License. package oracledatabase import ( + "context" "fmt" "github.com/dapr/components-contrib/state" @@ -49,8 +50,8 @@ func (o *OracleDatabase) Init(metadata state.Metadata) error { return o.dbaccess.Init(metadata) } -func (o *OracleDatabase) Ping() error { - return o.dbaccess.Ping() +func (o *OracleDatabase) Ping(ctx context.Context) error { + return o.dbaccess.Ping(ctx) } // Features returns the features available in this state store. @@ -59,38 +60,38 @@ func (o *OracleDatabase) Features() []state.Feature { } // Delete removes an entity from the store. -func (o *OracleDatabase) Delete(req *state.DeleteRequest) error { - return o.dbaccess.Delete(req) +func (o *OracleDatabase) Delete(ctx context.Context, req *state.DeleteRequest) error { + return o.dbaccess.Delete(ctx, req) } // BulkDelete removes multiple entries from the store. -func (o *OracleDatabase) BulkDelete(req []state.DeleteRequest) error { - return o.dbaccess.ExecuteMulti(nil, req) +func (o *OracleDatabase) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return o.dbaccess.ExecuteMulti(ctx, nil, req) } // Get returns an entity from store. -func (o *OracleDatabase) Get(req *state.GetRequest) (*state.GetResponse, error) { - return o.dbaccess.Get(req) +func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return o.dbaccess.Get(ctx, req) } // BulkGet performs a bulks get operations. -func (o *OracleDatabase) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance. return false, nil, nil } // Set adds/updates an entity on store. -func (o *OracleDatabase) Set(req *state.SetRequest) error { - return o.dbaccess.Set(req) +func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error { + return o.dbaccess.Set(ctx, req) } // BulkSet adds/updates multiple entities on store. -func (o *OracleDatabase) BulkSet(req []state.SetRequest) error { - return o.dbaccess.ExecuteMulti(req, nil) +func (o *OracleDatabase) BulkSet(ctx context.Context, req []state.SetRequest) error { + return o.dbaccess.ExecuteMulti(ctx, req, nil) } // Multi handles multiple transactions. Implements TransactionalStore. -func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error { +func (o *OracleDatabase) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { var deletes []state.DeleteRequest var sets []state.SetRequest for _, req := range request.Operations { @@ -115,7 +116,7 @@ func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error { } if len(sets) > 0 || len(deletes) > 0 { - return o.dbaccess.ExecuteMulti(sets, deletes) + return o.dbaccess.ExecuteMulti(ctx, sets, deletes) } return nil diff --git a/state/oracledatabase/oracledatabase_integration_test.go b/state/oracledatabase/oracledatabase_integration_test.go index eb6c6c9ec7..8b8b5082b8 100644 --- a/state/oracledatabase/oracledatabase_integration_test.go +++ b/state/oracledatabase/oracledatabase_integration_test.go @@ -13,6 +13,7 @@ limitations under the License. package oracledatabase import ( + "context" "database/sql" "encoding/json" "fmt" @@ -220,7 +221,7 @@ func deleteItemThatDoesNotExist(t *testing.T, ods *OracleDatabase) { deleteReq := &state.DeleteRequest{ Key: randomKey(), } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -239,7 +240,7 @@ func multiWithSetOnly(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -269,7 +270,7 @@ func multiWithDeleteOnly(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -312,7 +313,7 @@ func multiWithDeleteAndSet(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -342,7 +343,7 @@ func deleteWithInvalidEtagFails(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.NotNil(t, err, "Deleting an item with the wrong etag while enforcing FirstWrite policy should fail") } @@ -350,7 +351,7 @@ func deleteWithNoKeyFails(t *testing.T, ods *OracleDatabase) { deleteReq := &state.DeleteRequest{ Key: "", } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -368,7 +369,7 @@ func newItemWithEtagFails(t *testing.T, ods *OracleDatabase) { }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -398,7 +399,7 @@ func updateWithOldEtagFails(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -420,7 +421,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err, "Setting the item should be successful") updateResponse, updatedItem := getItem(t, ods, key) assert.Equal(t, value, updatedItem) @@ -436,7 +437,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err = ods.Delete(deleteReq) + err = ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err, "Deleting an item with the right etag while enforcing FirstWrite policy should succeed") // Item is not in the data store. @@ -462,7 +463,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac Concurrency: state.LastWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err, "Setting the item should be successful") _, updatedItem := getItem(t, ods, key) assert.Equal(t, value, updatedItem) @@ -475,7 +476,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac Concurrency: state.LastWrite, }, } - err = ods.Delete(deleteReq) + err = ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err, "Deleting an item with the wrong etag but not enforcing FirstWrite policy should succeed") // Item is not in the data store. @@ -497,7 +498,7 @@ func getItemWithNoKey(t *testing.T, ods *OracleDatabase) { Key: "", } - response, getErr := ods.Get(getReq) + response, getErr := ods.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -545,7 +546,7 @@ func setTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) { }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) connectionString := getConnectionString() if getWalletLocation() != "" { @@ -577,10 +578,10 @@ func setNoTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "1000", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) delete(setReq.Metadata, "ttlInSeconds") - err = ods.Set(setReq) + err = ods.Set(context.TODO(), setReq) assert.Nil(t, err) connectionString := getConnectionString() if getWalletLocation() != "" { @@ -611,11 +612,11 @@ func expiredStateCannotBeRead(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "1", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) time.Sleep(time.Second * time.Duration(2)) - getResponse, err := ods.Get(&state.GetRequest{Key: key}) + getResponse, err := ods.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Expired element must not be treated as error") @@ -636,7 +637,7 @@ func unexpiredStateCanBeRead(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "10000", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) _, getValue := getItem(t, ods, key) assert.Equal(t, value.Color, getValue.Color, "Response must be as set") @@ -650,7 +651,7 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) { Key: "", } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -699,7 +700,7 @@ func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "XX", }), } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") } @@ -711,7 +712,7 @@ func testSetItemWithNegativeTTL(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "-10", }), } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err, "Setting a value with a proper key and a negative (other than -1) TTL value should be produce an error") } @@ -728,7 +729,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) { }, } - err := ods.BulkSet(setReq) + err := ods.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -742,7 +743,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) { }, } - err = ods.BulkDelete(deleteReq) + err = ods.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -809,7 +810,7 @@ func setItem(t *testing.T, ods *OracleDatabase, key string, value interface{}, e Options: setOptions, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists, "Item should exist after set has been executed ") @@ -821,7 +822,7 @@ func getItem(t *testing.T, ods *OracleDatabase, key string) (*state.GetResponse, Options: state.GetStateOption{}, } - response, getErr := ods.Get(getReq) + response, getErr := ods.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -837,7 +838,7 @@ func deleteItem(t *testing.T, ods *OracleDatabase, key string, etag *string) { Options: state.DeleteStateOption{}, } - deleteErr := ods.Delete(deleteReq) + deleteErr := ods.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key), "item should no longer exist after delete has been performed") } diff --git a/state/oracledatabase/oracledatabase_test.go b/state/oracledatabase/oracledatabase_test.go index e81653f3b0..5ffb765ddb 100644 --- a/state/oracledatabase/oracledatabase_test.go +++ b/state/oracledatabase/oracledatabase_test.go @@ -13,6 +13,7 @@ limitations under the License. package oracledatabase import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -34,7 +35,7 @@ type fakeDBaccess struct { getExecuted bool } -func (m *fakeDBaccess) Ping() error { +func (m *fakeDBaccess) Ping(ctx context.Context) error { m.pingExecuted = true return nil } @@ -45,23 +46,23 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error { return nil } @@ -73,7 +74,7 @@ func (m *fakeDBaccess) Close() error { func TestInitRunsDBAccessInit(t *testing.T) { t.Parallel() ods, fake := createOracleDatabaseWithFake(t) - ods.Ping() + ods.Ping(context.TODO()) assert.True(t, fake.initExecuted) } @@ -81,7 +82,7 @@ func TestMultiWithNoRequestsReturnsNil(t *testing.T) { t.Parallel() var operations []state.TransactionalStateOperation ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -97,7 +98,7 @@ func TestInvalidMultiAction(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) @@ -113,7 +114,7 @@ func TestValidSetRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -129,7 +130,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) @@ -145,7 +146,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -161,7 +162,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) @@ -190,7 +191,7 @@ func createOracleDatabaseWithFake(t *testing.T) (*OracleDatabase, *fakeDBaccess) func TestPingRunsDBAccessPing(t *testing.T) { t.Parallel() odb, fake := createOracleDatabaseWithFake(t) - odb.Ping() + odb.Ping(context.TODO()) assert.True(t, fake.pingExecuted) } diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index 5fba997c35..a066a4cf7e 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -14,6 +14,7 @@ limitations under the License. package oracledatabase import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -57,7 +58,7 @@ func newOracleDatabaseAccess(logger logger.Logger) *oracleDatabaseAccess { } } -func (o *oracleDatabaseAccess) Ping() error { +func (o *oracleDatabaseAccess) Ping(ctx context.Context) error { return o.db.Ping() } @@ -96,8 +97,8 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(o.setValue, req) +func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(o.setValue, ctx, req) } func parseTTL(requestMetadata map[string]string) (*int, error) { @@ -115,7 +116,7 @@ func parseTTL(requestMetadata map[string]string) (*int, error) { } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { +func (o *oracleDatabaseAccess) setValue(ctx context.Context, req *state.SetRequest) error { o.logger.Debug("Setting state value in OracleDatabase") err := state.CheckRequestOptions(req.Options) if err != nil { @@ -214,7 +215,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { o.logger.Debug("Getting state value from OracleDatabase") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -253,12 +254,12 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e } // Delete removes an item from the state store. -func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(o.deleteValue, req) +func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(o.deleteValue, ctx, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { +func (o *oracleDatabaseAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { o.logger.Debug("Deleting state value from OracleDatabase") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -303,7 +304,7 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { +func (o *oracleDatabaseAccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error { o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction") tx, err := o.db.Begin() if err != nil { @@ -313,7 +314,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s if len(deletes) > 0 { for _, d := range deletes { da := d // Fix for gosec G601: Implicit memory aliasing in for looo. - err = o.Delete(&da) + err = o.Delete(ctx, &da) if err != nil { tx.Rollback() return err @@ -323,7 +324,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s if len(sets) > 0 { for _, s := range sets { sa := s // Fix for gosec G601: Implicit memory aliasing in for looo. - err = o.Set(&sa) + err = o.Set(ctx, &sa) if err != nil { tx.Rollback() return err diff --git a/state/postgresql/dbaccess.go b/state/postgresql/dbaccess.go index d4575be3f0..3bb9d20a8a 100644 --- a/state/postgresql/dbaccess.go +++ b/state/postgresql/dbaccess.go @@ -14,18 +14,20 @@ limitations under the License. package postgresql import ( + "context" + "github.com/dapr/components-contrib/state" ) // dbAccess is a private interface which enables unit testing of PostgreSQL. type dbAccess interface { Init(metadata state.Metadata) error - Set(req *state.SetRequest) error - BulkSet(req []state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - BulkDelete(req []state.DeleteRequest) error - ExecuteMulti(req *state.TransactionalStateRequest) error + Set(ctx context.Context, req *state.SetRequest) error + BulkSet(ctx context.Context, req []state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + BulkDelete(ctx context.Context, req []state.DeleteRequest) error + ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error Query(req *state.QueryRequest) (*state.QueryResponse, error) Close() error // io.Closer } diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 82cd70c840..6faacc78d8 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -14,6 +14,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -90,12 +91,12 @@ func (p *postgresDBAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (p *postgresDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) +func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(p.setValue, ctx, req) } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *postgresDBAccess) setValue(req *state.SetRequest) error { +func (p *postgresDBAccess) setValue(ctx context.Context, req *state.SetRequest) error { p.logger.Debug("Setting state value in PostgreSQL") err := state.CheckRequestOptions(req.Options) @@ -166,7 +167,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { return nil } -func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { +func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error { p.logger.Debug("Executing BulkSet request") tx, err := p.db.Begin() if err != nil { @@ -176,7 +177,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for _, s := range req { sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Set(&sa) + err = p.Set(ctx, &sa) if err != nil { tx.Rollback() @@ -191,7 +192,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { p.logger.Debug("Getting state value from PostgreSQL") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -237,12 +238,12 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error } // Delete removes an item from the state store. -func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) +func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(p.deleteValue, ctx, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { +func (p *postgresDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { p.logger.Debug("Deleting state value from PostgreSQL") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -281,7 +282,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { +func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { p.logger.Debug("Executing BulkDelete request") tx, err := p.db.Begin() if err != nil { @@ -291,7 +292,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Delete(&da) + err = p.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -305,7 +306,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { return err } -func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { +func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error { p.logger.Debug("Executing PostgreSQL transaction") tx, err := p.db.Begin() @@ -324,7 +325,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest return err } - err = p.Set(&setReq) + err = p.Set(ctx, &setReq) if err != nil { tx.Rollback() return err @@ -339,7 +340,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest return err } - err = p.Delete(&delReq) + err = p.Delete(ctx, &delReq) if err != nil { tx.Rollback() return err diff --git a/state/postgresql/postgresdbaccess_test.go b/state/postgresql/postgresdbaccess_test.go index 35681cf261..b3e5a57325 100644 --- a/state/postgresql/postgresdbaccess_test.go +++ b/state/postgresql/postgresdbaccess_test.go @@ -13,6 +13,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "testing" @@ -108,7 +109,7 @@ func TestMultiWithNoRequests(t *testing.T) { var operations []state.TransactionalStateOperation // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -132,7 +133,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -157,7 +158,7 @@ func TestValidSetRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -181,7 +182,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -205,7 +206,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -230,7 +231,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -254,7 +255,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -278,7 +279,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -310,7 +311,7 @@ func TestMultiOperationOrder(t *testing.T) { ) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -333,7 +334,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -355,7 +356,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -378,7 +379,7 @@ func TestValidBulkSet(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.Nil(t, err) @@ -399,7 +400,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) { }) // Act - err := m.pgDba.BulkDelete(deletes) + err := m.pgDba.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err) @@ -421,7 +422,7 @@ func TestValidBulkDelete(t *testing.T) { }) // Act - err := m.pgDba.BulkDelete(deletes) + err := m.pgDba.BulkDelete(context.TODO(), deletes) // Assert assert.Nil(t, err) diff --git a/state/postgresql/postgresql.go b/state/postgresql/postgresql.go index 5bad1f110a..5ac5600664 100644 --- a/state/postgresql/postgresql.go +++ b/state/postgresql/postgresql.go @@ -14,6 +14,8 @@ limitations under the License. package postgresql import ( + "context" + "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) @@ -47,7 +49,7 @@ func (p *PostgreSQL) Init(metadata state.Metadata) error { return p.dbaccess.Init(metadata) } -func (p *PostgreSQL) Ping() error { +func (p *PostgreSQL) Ping(ctx context.Context) error { return nil } @@ -57,39 +59,39 @@ func (p *PostgreSQL) Features() []state.Feature { } // Delete removes an entity from the store. -func (p *PostgreSQL) Delete(req *state.DeleteRequest) error { - return p.dbaccess.Delete(req) +func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) error { + return p.dbaccess.Delete(ctx, req) } // BulkDelete removes multiple entries from the store. -func (p *PostgreSQL) BulkDelete(req []state.DeleteRequest) error { - return p.dbaccess.BulkDelete(req) +func (p *PostgreSQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return p.dbaccess.BulkDelete(ctx, req) } // Get returns an entity from store. -func (p *PostgreSQL) Get(req *state.GetRequest) (*state.GetResponse, error) { - return p.dbaccess.Get(req) +func (p *PostgreSQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return p.dbaccess.Get(ctx, req) } // BulkGet performs a bulks get operations. -func (p *PostgreSQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (p *PostgreSQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance return false, nil, nil } // Set adds/updates an entity on store. -func (p *PostgreSQL) Set(req *state.SetRequest) error { - return p.dbaccess.Set(req) +func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error { + return p.dbaccess.Set(ctx, req) } // BulkSet adds/updates multiple entities on store. -func (p *PostgreSQL) BulkSet(req []state.SetRequest) error { - return p.dbaccess.BulkSet(req) +func (p *PostgreSQL) BulkSet(ctx context.Context, req []state.SetRequest) error { + return p.dbaccess.BulkSet(ctx, req) } // Multi handles multiple transactions. Implements TransactionalStore. -func (p *PostgreSQL) Multi(request *state.TransactionalStateRequest) error { - return p.dbaccess.ExecuteMulti(request) +func (p *PostgreSQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { + return p.dbaccess.ExecuteMulti(ctx, request) } // Query executes a query against store. diff --git a/state/postgresql/postgresql_integration_test.go b/state/postgresql/postgresql_integration_test.go index c2784658cd..4276825ce6 100644 --- a/state/postgresql/postgresql_integration_test.go +++ b/state/postgresql/postgresql_integration_test.go @@ -13,6 +13,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "encoding/json" "fmt" @@ -189,7 +190,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) { deleteReq := &state.DeleteRequest{ Key: randomKey(), } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -208,7 +209,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -238,7 +239,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -281,7 +282,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -308,7 +309,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *PostgreSQL) { Key: key, ETag: &etag, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -316,7 +317,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *PostgreSQL) { deleteReq := &state.DeleteRequest{ Key: "", } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -331,7 +332,7 @@ func newItemWithEtagFails(t *testing.T, pgs *PostgreSQL) { Value: value, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -357,7 +358,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *PostgreSQL) { ETag: originalEtag, Value: newValue, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -399,7 +400,7 @@ func getItemWithNoKey(t *testing.T, pgs *PostgreSQL) { Key: "", } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -430,7 +431,7 @@ func setItemWithNoKey(t *testing.T, pgs *PostgreSQL) { Key: "", } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -447,7 +448,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { }, } - err := pgs.BulkSet(setReq) + err := pgs.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -461,7 +462,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { }, } - err = pgs.BulkDelete(deleteReq) + err = pgs.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -518,7 +519,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag Value: value, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists) @@ -530,7 +531,7 @@ func getItem(t *testing.T, pgs *PostgreSQL, key string) (*state.GetResponse, *fa Options: state.GetStateOption{}, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -546,7 +547,7 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) { Options: state.DeleteStateOption{}, } - deleteErr := pgs.Delete(deleteReq) + deleteErr := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key)) } diff --git a/state/postgresql/postgresql_test.go b/state/postgresql/postgresql_test.go index 7deca686d2..f82e3ff76c 100644 --- a/state/postgresql/postgresql_test.go +++ b/state/postgresql/postgresql_test.go @@ -13,6 +13,7 @@ limitations under the License. package postgresql import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -40,33 +41,33 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { m.deleteExecuted = true return nil } -func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { +func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error { return nil } -func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { +func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error { return nil } diff --git a/state/redis/redis.go b/state/redis/redis.go index 405dcdffd3..815ec41316 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -117,8 +117,8 @@ func NewRedisStateStore(logger logger.Logger) *StateStore { return s } -func (r *StateStore) Ping() error { - if _, err := r.client.Ping(context.Background()).Result(); err != nil { +func (r *StateStore) Ping(ctx context.Context) error { + if _, err := r.client.Ping(ctx).Result(); err != nil { return fmt.Errorf("redis store: error connecting to redis at %s: %s", r.clientSettings.Host, err) } @@ -195,7 +195,7 @@ func (r *StateStore) parseConnectedSlaves(res string) int { return 0 } -func (r *StateStore) deleteValue(req *state.DeleteRequest) error { +func (r *StateStore) deleteValue(ctx context.Context, req *state.DeleteRequest) error { if req.ETag == nil { etag := "0" req.ETag = &etag @@ -207,7 +207,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { } else { delQuery = delDefaultQuery } - _, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() + _, err := r.client.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() if err != nil { return state.NewETagError(state.ETagMismatch, err) } @@ -216,13 +216,13 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { } // Delete performs a delete operation. -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err } - return state.DeleteWithOptions(r.deleteValue, req) + return state.DeleteWithOptions(r.deleteValue, ctx, req) } func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) { @@ -242,8 +242,8 @@ func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error }, nil } -func (r *StateStore) getDefault(req *state.GetRequest) (*state.GetResponse, error) { - res, err := r.client.Do(r.ctx, "HGETALL", req.Key).Result() // Prefer values with ETags +func (r *StateStore) getDefault(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + res, err := r.client.Do(ctx, "HGETALL", req.Key).Result() // Prefer values with ETags if err != nil { return r.directGet(req) // Falls back to original get for backward compats. } @@ -304,12 +304,12 @@ func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error) } // Get retrieves state from redis with a key. -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType { return r.getJSON(req) } - return r.getDefault(req) + return r.getDefault(ctx, req) } type jsonEntry struct { @@ -317,7 +317,7 @@ type jsonEntry struct { Version *int `json:"version,omitempty"` } -func (r *StateStore) setValue(req *state.SetRequest) error { +func (r *StateStore) setValue(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -350,7 +350,7 @@ func (r *StateStore) setValue(req *state.SetRequest) error { bt, _ = utils.Marshal(req.Value, r.json.Marshal) } - err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err() + err = r.client.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err() if err != nil { if req.ETag != nil { return state.NewETagError(state.ETagMismatch, err) @@ -384,8 +384,8 @@ func (r *StateStore) setValue(req *state.SetRequest) error { } // Set saves state into redis. -func (r *StateStore) Set(req *state.SetRequest) error { - return state.SetWithOptions(r.setValue, req) +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(r.setValue, ctx, req) } // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. diff --git a/state/redis/redis_test.go b/state/redis/redis_test.go index 56949cfdf4..a955a04d3d 100644 --- a/state/redis/redis_test.go +++ b/state/redis/redis_test.go @@ -273,7 +273,7 @@ func TestTransactionalDelete(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon", Value: "deathstar", }) @@ -307,12 +307,12 @@ func TestPing(t *testing.T) { clientSettings: &rediscomponent.Settings{}, } - err := ss.Ping() + err := ss.Ping(context.TODO()) assert.NoError(t, err) s.Close() - err = ss.Ping() + err = ss.Ping(context.TODO()) assert.Error(t, err) } @@ -331,7 +331,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) t.Run("TTL: Only global specified", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", }) @@ -342,7 +342,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { t.Run("TTL: Global and Request specified", func(t *testing.T) { requestTTL := 200 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", Metadata: map[string]string{ @@ -424,7 +424,7 @@ func TestSetRequestWithTTL(t *testing.T) { t.Run("TTL specified", func(t *testing.T) { ttlInSeconds := 100 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", Metadata: map[string]string{ @@ -438,7 +438,7 @@ func TestSetRequestWithTTL(t *testing.T) { }) t.Run("TTL not specified", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon200", Value: "deathstar200", }) @@ -449,7 +449,7 @@ func TestSetRequestWithTTL(t *testing.T) { }) t.Run("TTL Changed for Existing Key", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar300", }) @@ -458,7 +458,7 @@ func TestSetRequestWithTTL(t *testing.T) { // make the key no longer persistent ttlInSeconds := 123 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar300", Metadata: map[string]string{ @@ -469,7 +469,7 @@ func TestSetRequestWithTTL(t *testing.T) { assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl) // make the key persistent again - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar301", Metadata: map[string]string{ @@ -493,7 +493,7 @@ func TestTransactionalDeleteNoEtag(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", }) diff --git a/state/request_options.go b/state/request_options.go index 23ed85cb30..e2bc3e39f1 100644 --- a/state/request_options.go +++ b/state/request_options.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "fmt" ) @@ -68,11 +69,11 @@ func validateConsistencyOption(c string) error { } // SetWithOptions handles SetRequest with request options. -func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error { - return method(req) +func SetWithOptions(method func(ctx context.Context, req *SetRequest) error, ctx context.Context, req *SetRequest) error { + return method(ctx, req) } // DeleteWithOptions handles DeleteRequest with options. -func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error { - return method(req) +func DeleteWithOptions(method func(ctx context.Context, req *DeleteRequest) error, ctx context.Context, req *DeleteRequest) error { + return method(ctx, req) } diff --git a/state/request_options_test.go b/state/request_options_test.go index 2bf8e72f36..f5795c9b57 100644 --- a/state/request_options_test.go +++ b/state/request_options_test.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -23,21 +24,21 @@ import ( func TestSetRequestWithOptions(t *testing.T) { t.Run("set with default options", func(t *testing.T) { counter := 0 - SetWithOptions(func(req *SetRequest) error { + SetWithOptions(func(ctx context.Context, req *SetRequest) error { counter++ return nil - }, &SetRequest{}) + }, context.TODO(), &SetRequest{}) assert.Equal(t, 1, counter, "should execute only once") }) t.Run("set with no explicit options", func(t *testing.T) { counter := 0 - SetWithOptions(func(req *SetRequest) error { + SetWithOptions(func(ctx context.Context, req *SetRequest) error { counter++ return nil - }, &SetRequest{ + }, context.TODO(), &SetRequest{ Options: SetStateOption{}, }) assert.Equal(t, 1, counter, "should execute only once") diff --git a/state/rethinkdb/rethinkdb.go b/state/rethinkdb/rethinkdb.go index 13f7630efe..68a0140d25 100644 --- a/state/rethinkdb/rethinkdb.go +++ b/state/rethinkdb/rethinkdb.go @@ -14,6 +14,7 @@ limitations under the License. package rethinkdb import ( + "context" "encoding/json" "io/ioutil" "strconv" @@ -131,7 +132,7 @@ func (s *RethinkDB) Init(metadata state.Metadata) error { return nil } -func (s *RethinkDB) Ping() error { +func (s *RethinkDB) Ping(ctx context.Context) error { return nil } @@ -151,7 +152,7 @@ func tableExists(arr []string, table string) bool { } // Get retrieves a RethinkDB KV item. -func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *RethinkDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if req == nil || req.Key == "" { return nil, errors.New("invalid state request, missing key") } @@ -191,22 +192,22 @@ func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *RethinkDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *RethinkDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with bulk get for performance return false, nil, nil } // Set saves a state KV item. -func (s *RethinkDB) Set(req *state.SetRequest) error { +func (s *RethinkDB) Set(ctx context.Context, req *state.SetRequest) error { if req == nil || req.Key == "" || req.Value == nil { return errors.New("invalid state request, key and value required") } - return s.BulkSet([]state.SetRequest{*req}) + return s.BulkSet(ctx, []state.SetRequest{*req}) } // BulkSet performs a bulk save operation. -func (s *RethinkDB) BulkSet(req []state.SetRequest) error { +func (s *RethinkDB) BulkSet(ctx context.Context, req []state.SetRequest) error { docs := make([]*stateRecord, len(req)) for i, v := range req { var etag string @@ -261,16 +262,16 @@ func (s *RethinkDB) archive(changes []r.ChangeResponse) error { } // Delete performes a RethinkDB KV delete operation. -func (s *RethinkDB) Delete(req *state.DeleteRequest) error { +func (s *RethinkDB) Delete(ctx context.Context, req *state.DeleteRequest) error { if req == nil || req.Key == "" { return errors.New("invalid request, missing key") } - return s.BulkDelete([]state.DeleteRequest{*req}) + return s.BulkDelete(ctx, []state.DeleteRequest{*req}) } // BulkDelete performs a bulk delete operation. -func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error { +func (s *RethinkDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { list := make([]string, 0) for _, d := range req { list = append(list, d.Key) @@ -286,7 +287,7 @@ func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error { } // Multi performs multiple operations. -func (s *RethinkDB) Multi(req state.TransactionalStateRequest) error { +func (s *RethinkDB) Multi(ctx context.Context, req state.TransactionalStateRequest) error { upserts := make([]state.SetRequest, 0) deletes := make([]state.DeleteRequest, 0) @@ -310,11 +311,11 @@ func (s *RethinkDB) Multi(req state.TransactionalStateRequest) error { } // best effort, no transacts supported - if err := s.BulkSet(upserts); err != nil { + if err := s.BulkSet(ctx, upserts); err != nil { return errors.Wrap(err, "error saving records to the database") } - if err := s.BulkDelete(deletes); err != nil { + if err := s.BulkDelete(ctx, deletes); err != nil { return errors.Wrap(err, "error deleting records to the database") } diff --git a/state/rethinkdb/rethinkdb_test.go b/state/rethinkdb/rethinkdb_test.go index 497d415431..df92109f6b 100644 --- a/state/rethinkdb/rethinkdb_test.go +++ b/state/rethinkdb/rethinkdb_test.go @@ -14,6 +14,7 @@ limitations under the License. package rethinkdb import ( + "context" "encoding/json" "fmt" "os" @@ -79,18 +80,19 @@ func TestRethinkDBStateStore(t *testing.T) { } assert.Equal(t, "test", db.config.Table) }) + ctx := context.TODO() t.Run("With struct data", func(t *testing.T) { // create and set data d := &testObj{F1: "test", F2: 1, F3: time.Now().UTC()} k := fmt.Sprintf("ids-%d", time.Now().UnixNano()) - if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { + if err := db.Set(ctx, &state.SetRequest{Key: k, Value: d}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get set data and compare - resp, err := db.Get(&state.GetRequest{Key: k}) + resp, err := db.Get(ctx, &state.GetRequest{Key: k}) assert.Nil(t, err) d2 := testGetTestObj(t, resp) assert.NotNil(t, d2) @@ -102,12 +104,12 @@ func TestRethinkDBStateStore(t *testing.T) { d2.F2 = 2 d2.F3 = time.Now().UTC() tag := fmt.Sprintf("hash-%d", time.Now().UnixNano()) - if err = db.Set(&state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil { + if err = db.Set(ctx, &state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get updated data and compare - resp2, err := db.Get(&state.GetRequest{Key: k}) + resp2, err := db.Get(ctx, &state.GetRequest{Key: k}) assert.Nil(t, err) d3 := testGetTestObj(t, resp2) assert.NotNil(t, d3) @@ -116,7 +118,7 @@ func TestRethinkDBStateStore(t *testing.T) { assert.Equal(t, d2.F3.Format(time.RFC3339), d3.F3.Format(time.RFC3339)) // delete data - if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { + if err := db.Delete(ctx, &state.DeleteRequest{Key: k}); err != nil { t.Fatalf("error on data deletion: %v", err) } }) @@ -126,19 +128,19 @@ func TestRethinkDBStateStore(t *testing.T) { d := []byte("test") k := fmt.Sprintf("idb-%d", time.Now().UnixNano()) - if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { + if err := db.Set(ctx, &state.SetRequest{Key: k, Value: d}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get set data and compare - resp, err := db.Get(&state.GetRequest{Key: k}) + resp, err := db.Get(ctx, &state.GetRequest{Key: k}) assert.Nil(t, err) assert.NotNil(t, resp) assert.NotNil(t, resp.Data) assert.Equal(t, string(d), string(resp.Data)) // delete data - if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { + if err := db.Delete(ctx, &state.DeleteRequest{Key: k}); err != nil { t.Fatalf("error on data deletion: %v", err) } }) @@ -174,28 +176,28 @@ func testBulk(t *testing.T, db *RethinkDB, i int) { deleteList = append(deleteList, state.DeleteRequest{Key: k}) setList[i] = state.SetRequest{Key: k, Value: d} } - + ctx := context.TODO() // bulk set it - if err := db.BulkSet(setList); err != nil { + if err := db.BulkSet(ctx, setList); err != nil { t.Fatalf("error setting data to db: %v -- run %d", err, i) } // check for the data for _, v := range deleteList { - resp, err := db.Get(&state.GetRequest{Key: v.Key}) + resp, err := db.Get(ctx, &state.GetRequest{Key: v.Key}) assert.Nilf(t, err, " -- run %d", i) assert.NotNil(t, resp) assert.NotNil(t, resp.Data) } // delete data - if err := db.BulkDelete(deleteList); err != nil { + if err := db.BulkDelete(ctx, deleteList); err != nil { t.Fatalf("error on data deletion: %v -- run %d", err, i) } // check for the data NOT being there for _, v := range deleteList { - resp, err := db.Get(&state.GetRequest{Key: v.Key}) + resp, err := db.Get(ctx, &state.GetRequest{Key: v.Key}) assert.Nilf(t, err, " -- run %d", i) assert.NotNil(t, resp) assert.Nil(t, resp.Data) @@ -216,6 +218,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { numOfRecords := 4 recordIDFormat := "multi-%d" + ctx := context.TODO() t.Run("With multi", func(t *testing.T) { // create data list d := []byte("test") @@ -223,7 +226,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { for i := 0; i < numOfRecords; i++ { list[i] = state.SetRequest{Key: fmt.Sprintf(recordIDFormat, i), Value: d} } - if err := db.BulkSet(list); err != nil { + if err := db.BulkSet(ctx, list); err != nil { t.Fatalf("error setting multi to db: %v", err) } @@ -257,19 +260,19 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { } // execute multi - if err := db.Multi(req); err != nil { + if err := db.Multi(ctx, req); err != nil { t.Fatalf("error setting multi to db: %v", err) } // the one not deleted should be still there - m1, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)}) + m1, err := db.Get(ctx, &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)}) assert.Nil(t, err) assert.NotNil(t, m1) assert.NotNil(t, m1.Data) assert.Equal(t, string(d2), string(m1.Data)) // the one deleted should not - m2, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)}) + m2, err := db.Get(ctx, &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)}) assert.Nil(t, err) assert.NotNil(t, m2) assert.Nil(t, m2.Data) diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 4f35dfca94..da4f72c5c9 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -24,6 +24,7 @@ import ( "github.com/agrea/ptr" mssql "github.com/denisenkom/go-mssqldb" + "golang.org/x/net/context" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state/utils" @@ -332,7 +333,7 @@ func (s *SQLServer) getTable(metadata state.Metadata) error { return nil } -func (s *SQLServer) Ping() error { +func (s *SQLServer) Ping(ctx context.Context) error { return nil } @@ -414,7 +415,7 @@ func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.Del } // Delete removes an entity from the store. -func (s *SQLServer) Delete(req *state.DeleteRequest) error { +func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error { return s.executeDelete(s.db, req) } @@ -460,7 +461,7 @@ type TvpDeleteTableStringKey struct { } // BulkDelete removes multiple entries from the store. -func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error { +func (s *SQLServer) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { tx, err := s.db.Begin() if err != nil { return err @@ -517,7 +518,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) } // Get returns an entity from store. -func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key)) if err != nil { return nil, err @@ -549,12 +550,12 @@ func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *SQLServer) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { return false, nil, nil } // Set adds/updates an entity on store. -func (s *SQLServer) Set(req *state.SetRequest) error { +func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error { return s.executeSet(s.db, req) } @@ -608,7 +609,7 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error { } // BulkSet adds/updates multiple entities on store. -func (s *SQLServer) BulkSet(req []state.SetRequest) error { +func (s *SQLServer) BulkSet(ctx context.Context, req []state.SetRequest) error { tx, err := s.db.Begin() if err != nil { return err diff --git a/state/sqlserver/sqlserver_integration_test.go b/state/sqlserver/sqlserver_integration_test.go index 0a34c61f0a..a9ac3c30c0 100644 --- a/state/sqlserver/sqlserver_integration_test.go +++ b/state/sqlserver/sqlserver_integration_test.go @@ -13,6 +13,7 @@ limitations under the License. package sqlserver import ( + "context" "database/sql" "encoding/json" "fmt" @@ -127,7 +128,7 @@ func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string) } func assertUserExists(t *testing.T, store *SQLServer, key string) (user, string) { - getRes, err := store.Get(&state.GetRequest{Key: key}) + getRes, err := store.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Nil(t, err) assert.NotNil(t, getRes) assert.NotNil(t, getRes.Data, "No data was returned") @@ -150,7 +151,7 @@ func assertLoadedUserIsEqual(t *testing.T, store *SQLServer, key string, expecte } func assertUserDoesNotExist(t *testing.T, store *SQLServer, key string) { - _, err := store.Get(&state.GetRequest{Key: key}) + _, err := store.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Nil(t, err) } @@ -221,14 +222,14 @@ func testSingleOperations(t *testing.T) { assertUserDoesNotExist(t, store, john.ID) // Save and read - err := store.Set(&state.SetRequest{Key: john.ID, Value: john}) + err := store.Set(context.TODO(), &state.SetRequest{Key: john.ID, Value: john}) assert.Nil(t, err) johnV1, etagFromInsert := assertLoadedUserIsEqual(t, store, john.ID, john) // Update with ETAG waterJohn := johnV1 waterJohn.FavoriteBeverage = "Water" - err = store.Set(&state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert}) + err = store.Set(context.TODO(), &state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert}) assert.Nil(t, err) // Get updated @@ -237,7 +238,7 @@ func testSingleOperations(t *testing.T) { // Update without ETAG noEtagJohn := johnV2 noEtagJohn.FavoriteBeverage = "No Etag John" - err = store.Set(&state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn}) + err = store.Set(context.TODO(), &state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn}) assert.Nil(t, err) // 7. Get updated @@ -246,17 +247,17 @@ func testSingleOperations(t *testing.T) { // 8. Update with invalid ETAG should fail failedJohn := johnV3 failedJohn.FavoriteBeverage = "Will not work" - err = store.Set(&state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert}) + err = store.Set(context.TODO(), &state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert}) assert.NotNil(t, err) _, etag := assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) // 9. Delete with invalid ETAG should fail - err = store.Delete(&state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag}) + err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag}) assert.NotNil(t, err) assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) // 10. Delete with valid ETAG - err = store.Delete(&state.DeleteRequest{Key: johnV2.ID, ETag: &etag}) + err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV2.ID, ETag: &etag}) assert.Nil(t, err) assertUserDoesNotExist(t, store, johnV2.ID) @@ -270,7 +271,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) { u := user{uuid.New().String(), "John", "Coffee"} invEtag := invalidEtag - err := store.Set(&state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag}) assert.NotNil(t, err) } @@ -278,7 +279,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) { func testIndexedProperties(t *testing.T) { store := getTestStore(t, `[{ "column":"FavoriteBeverage", "property":"FavoriteBeverage", "type":"nvarchar(100)"}, { "column":"PetsCount", "property":"PetsCount", "type": "INTEGER"}]`) - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: "1", Value: userWithPets{user{"1", "John", "Coffee"}, 3}}, {Key: "2", Value: userWithPets{user{"2", "Laura", "Water"}, 1}}, {Key: "3", Value: userWithPets{user{"3", "Carl", "Beer"}, 0}}, @@ -340,7 +341,7 @@ func testMultiOperations(t *testing.T) { bulkSet[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(bulkSet) + err := store.BulkSet(context.TODO(), bulkSet) assert.Nil(t, err) assertUserCountIsEqualTo(t, store, len(initialUsers)) @@ -517,7 +518,7 @@ func testBulkSet(t *testing.T) { sets[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.Nil(t, err) totalUsers = len(sets) assertUserCountIsEqualTo(t, store, totalUsers) @@ -529,7 +530,7 @@ func testBulkSet(t *testing.T) { modified.FavoriteBeverage = beverageTea toInsert := user{keyGen.NextKey(), "Maria", "Wine"} - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: modified.ID, Value: modified, ETag: &toModifyETag}, {Key: toInsert.ID, Value: toInsert}, }) @@ -548,7 +549,7 @@ func testBulkSet(t *testing.T) { modified.FavoriteBeverage = beverageTea toInsert := user{keyGen.NextKey(), "Tony", "Milk"} - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: modified.ID, Value: modified}, {Key: toInsert.ID, Value: toInsert}, }) @@ -575,7 +576,7 @@ func testBulkSet(t *testing.T) { {Key: modified.ID, Value: modified, ETag: &invEtag}, } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.NotNil(t, err) assertUserCountIsEqualTo(t, store, totalUsers) assertUserDoesNotExist(t, store, toInsert1.ID) @@ -618,7 +619,7 @@ func testBulkDelete(t *testing.T) { for i, u := range initialUsers { sets[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.Nil(t, err) totalUsers := len(initialUsers) assertUserCountIsEqualTo(t, store, totalUsers) @@ -628,7 +629,7 @@ func testBulkDelete(t *testing.T) { t.Run("Delete 2 items without etag should work", func(t *testing.T) { deleted1 := initialUsers[userIndex].ID deleted2 := initialUsers[userIndex+1].ID - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1}, {Key: deleted2}, }) @@ -645,7 +646,7 @@ func testBulkDelete(t *testing.T) { deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted2, deleted2Etag := assertUserExists(t, store, initialUsers[userIndex+1].ID) - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID, ETag: &deleted2Etag}, }) @@ -662,7 +663,7 @@ func testBulkDelete(t *testing.T) { deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted2 := initialUsers[userIndex+1] - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID}, }) @@ -680,7 +681,7 @@ func testBulkDelete(t *testing.T) { deleted2 := initialUsers[userIndex+1] invEtag := invalidEtag - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID, ETag: &invEtag}, }) @@ -700,7 +701,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) { store := getTestStore(t, "") u := user{"1", "John", "Coffee"} - err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u}) assert.Nil(t, err) var originalInsertTime time.Time @@ -722,7 +723,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) { modified := u modified.FavoriteBeverage = beverageTea - err = store.Set(&state.SetRequest{Key: modified.ID, Value: modified}) + err = store.Set(context.TODO(), &state.SetRequest{Key: modified.ID, Value: modified}) assert.Nil(t, err) assertDBQuery(t, store, getUserTsql, func(t *testing.T, rows *sql.Rows) { assert.True(t, rows.Next()) @@ -746,7 +747,7 @@ func testConcurrentSets(t *testing.T) { store := getTestStore(t, "") u := user{"1", "John", "Coffee"} - err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u}) assert.Nil(t, err) _, etag := assertLoadedUserIsEqual(t, store, u.ID, u) @@ -763,7 +764,7 @@ func testConcurrentSets(t *testing.T) { defer wc.Done() modified := user{"1", "John", beverageTea} - err := store.Set(&state.SetRequest{Key: id, Value: modified, ETag: &etag}) + err := store.Set(context.TODO(), &state.SetRequest{Key: id, Value: modified, ETag: &etag}) if err != nil { atomic.AddInt32(&totalErrors, 1) } else { diff --git a/state/store.go b/state/store.go index 8eb378378a..079efaa8d2 100644 --- a/state/store.go +++ b/state/store.go @@ -13,22 +13,24 @@ limitations under the License. package state +import "context" + // Store is an interface to perform operations on store. type Store interface { BulkStore Init(metadata Metadata) error Features() []Feature - Delete(req *DeleteRequest) error - Get(req *GetRequest) (*GetResponse, error) - Set(req *SetRequest) error - Ping() error + Delete(ctx context.Context, req *DeleteRequest) error + Get(ctx context.Context, req *GetRequest) (*GetResponse, error) + Set(ctx context.Context, req *SetRequest) error + Ping(ctx context.Context) error } // BulkStore is an interface to perform bulk operations on store. type BulkStore interface { - BulkDelete(req []DeleteRequest) error - BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) - BulkSet(req []SetRequest) error + BulkDelete(ctx context.Context, req []DeleteRequest) error + BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) + BulkSet(ctx context.Context, req []SetRequest) error } // DefaultBulkStore is a default implementation of BulkStore. @@ -50,16 +52,16 @@ func (b *DefaultBulkStore) Features() []Feature { } // BulkGet performs a bulks get operations. -func (b *DefaultBulkStore) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { +func (b *DefaultBulkStore) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) { // by default, the store doesn't support bulk get // return false so daprd will fallback to call get() method one by one return false, nil, nil } // BulkSet performs a bulks save operation. -func (b *DefaultBulkStore) BulkSet(req []SetRequest) error { +func (b *DefaultBulkStore) BulkSet(ctx context.Context, req []SetRequest) error { for i := range req { - err := b.s.Set(&req[i]) + err := b.s.Set(ctx, &req[i]) if err != nil { return err } @@ -69,9 +71,9 @@ func (b *DefaultBulkStore) BulkSet(req []SetRequest) error { } // BulkDelete performs a bulk delete operation. -func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error { +func (b *DefaultBulkStore) BulkDelete(ctx context.Context, req []DeleteRequest) error { for i := range req { - err := b.s.Delete(&req[i]) + err := b.s.Delete(ctx, &req[i]) if err != nil { return err } diff --git a/state/store_test.go b/state/store_test.go index cf6a0fd608..a33d318d71 100644 --- a/state/store_test.go +++ b/state/store_test.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -25,23 +26,23 @@ func TestStore_withDefaultBulkImpl(t *testing.T) { var store Store = s require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + ctx := context.TODO() + store.Get(ctx, &GetRequest{}) + store.Set(ctx, &SetRequest{}) + store.Delete(ctx, &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, responses, err := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, responses, err := store.BulkGet(ctx, []GetRequest{{}, {}, {}}) require.Equal(t, false, bulkGet) require.Equal(t, 0, len(responses)) require.NoError(t, err) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(ctx, []SetRequest{{}, {}, {}, {}}) require.Equal(t, 3+4, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(ctx, []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 3+4+5, s.count) require.Equal(t, 0, s.bulkCount) } @@ -51,21 +52,21 @@ func TestStore_withCustomisedBulkImpl_notSupportBulkGet(t *testing.T) { var store Store = s require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + ctx := context.TODO() + store.Get(ctx, &GetRequest{}) + store.Set(ctx, &SetRequest{}) + store.Delete(ctx, &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, _, _ := store.BulkGet(ctx, []GetRequest{{}, {}, {}}) require.Equal(t, false, bulkGet) require.Equal(t, 6, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(ctx, []SetRequest{{}, {}, {}, {}}) require.Equal(t, 6, s.count) require.Equal(t, 1, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(ctx, []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 6, s.count) require.Equal(t, 2, s.bulkCount) } @@ -75,21 +76,21 @@ func TestStore_withCustomisedBulkImpl_supportBulkGet(t *testing.T) { var store Store = s require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + ctx := context.TODO() + store.Get(ctx, &GetRequest{}) + store.Set(ctx, &SetRequest{}) + store.Delete(ctx, &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, _, _ := store.BulkGet(ctx, []GetRequest{{}, {}, {}}) require.Equal(t, true, bulkGet) require.Equal(t, 3, s.count) require.Equal(t, 1, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(ctx, []SetRequest{{}, {}, {}, {}}) require.Equal(t, 3, s.count) require.Equal(t, 2, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(ctx, []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 3, s.count) require.Equal(t, 3, s.bulkCount) } @@ -110,25 +111,25 @@ func (s *Store1) Init(metadata Metadata) error { return nil } -func (s *Store1) Delete(req *DeleteRequest) error { +func (s *Store1) Delete(ctx context.Context, req *DeleteRequest) error { s.count++ return nil } -func (s *Store1) Get(req *GetRequest) (*GetResponse, error) { +func (s *Store1) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) { s.count++ return &GetResponse{}, nil } -func (s *Store1) Set(req *SetRequest) error { +func (s *Store1) Set(ctx context.Context, req *SetRequest) error { s.count++ return nil } -func (s *Store1) Ping() error { +func (s *Store1) Ping(ctx context.Context) error { return nil } @@ -149,29 +150,29 @@ func (s *Store2) Features() []Feature { return nil } -func (s *Store2) Delete(req *DeleteRequest) error { +func (s *Store2) Delete(ctx context.Context, req *DeleteRequest) error { s.count++ return nil } -func (s *Store2) Get(req *GetRequest) (*GetResponse, error) { +func (s *Store2) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) { s.count++ return &GetResponse{}, nil } -func (s *Store2) Set(req *SetRequest) error { +func (s *Store2) Set(ctx context.Context, req *SetRequest) error { s.count++ return nil } -func (s *Store2) Ping() error { +func (s *Store2) Ping(ctx context.Context) error { return nil } -func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { +func (s *Store2) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) { if s.supportBulkGet { s.bulkCount++ @@ -183,13 +184,13 @@ func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { return false, nil, nil } -func (s *Store2) BulkSet(req []SetRequest) error { +func (s *Store2) BulkSet(ctx context.Context, req []SetRequest) error { s.bulkCount++ return nil } -func (s *Store2) BulkDelete(req []DeleteRequest) error { +func (s *Store2) BulkDelete(ctx context.Context, req []DeleteRequest) error { s.bulkCount++ return nil diff --git a/state/zookeeper/zk.go b/state/zookeeper/zk.go index 2a33e922b1..399366a62a 100644 --- a/state/zookeeper/zk.go +++ b/state/zookeeper/zk.go @@ -14,6 +14,7 @@ limitations under the License. package zookeeper import ( + "context" "errors" "path" "strconv" @@ -161,7 +162,7 @@ func (s *StateStore) Features() []state.Feature { } // Get retrieves state from Zookeeper with a key. -func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { value, stat, err := s.conn.Get(s.prefixedKey(req.Key)) if err != nil { if errors.Is(err, zk.ErrNoNode) { @@ -178,19 +179,19 @@ func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with Multi for performance return false, nil, nil } // Delete performs a delete operation. -func (s *StateStore) Delete(req *state.DeleteRequest) error { +func (s *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r, err := s.newDeleteRequest(req) if err != nil { return err } - return state.DeleteWithOptions(func(req *state.DeleteRequest) error { + return state.DeleteWithOptions(func(ctx context.Context, req *state.DeleteRequest) error { err := s.conn.Delete(r.Path, r.Version) if errors.Is(err, zk.ErrNoNode) { return nil @@ -205,11 +206,11 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error { } return nil - }, req) + }, ctx, req) } // BulkDelete performs a bulk delete operation. -func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error { +func (s *StateStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error { ops := make([]interface{}, 0, len(reqs)) for i := range reqs { @@ -236,13 +237,13 @@ func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error { } // Set saves state into Zookeeper. -func (s *StateStore) Set(req *state.SetRequest) error { +func (s *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r, err := s.newSetDataRequest(req) if err != nil { return err } - return state.SetWithOptions(func(req *state.SetRequest) error { + return state.SetWithOptions(func(ctx context.Context, req *state.SetRequest) error { _, err = s.conn.Set(r.Path, r.Data, r.Version) if errors.Is(err, zk.ErrNoNode) { @@ -258,11 +259,11 @@ func (s *StateStore) Set(req *state.SetRequest) error { } return nil - }, req) + }, ctx, req) } // BulkSet performs a bulks save operation. -func (s *StateStore) BulkSet(reqs []state.SetRequest) error { +func (s *StateStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error { ops := make([]interface{}, 0, len(reqs)) for i := range reqs { @@ -303,7 +304,7 @@ func (s *StateStore) BulkSet(reqs []state.SetRequest) error { } } -func (s *StateStore) Ping() error { +func (s *StateStore) Ping(ctx context.Context) error { return nil } diff --git a/state/zookeeper/zk_test.go b/state/zookeeper/zk_test.go index d25226db77..ac1e228680 100644 --- a/state/zookeeper/zk_test.go +++ b/state/zookeeper/zk_test.go @@ -14,6 +14,7 @@ limitations under the License. package zookeeper import ( + "context" "fmt" "testing" "time" @@ -80,7 +81,7 @@ func TestGet(t *testing.T) { t.Run("With key exists", func(t *testing.T) { conn.EXPECT().Get("foo").Return([]byte("bar"), &zk.Stat{Version: 123}, nil).Times(1) - res, err := s.Get(&state.GetRequest{Key: "foo"}) + res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"}) assert.NotNil(t, res, "Key must be exists") assert.Equal(t, "bar", string(res.Data), "Value must be equals") assert.Equal(t, ptr.String("123"), res.ETag, "ETag must be equals") @@ -90,7 +91,7 @@ func TestGet(t *testing.T) { t.Run("With key non-exists", func(t *testing.T) { conn.EXPECT().Get("foo").Return(nil, nil, zk.ErrNoNode).Times(1) - res, err := s.Get(&state.GetRequest{Key: "foo"}) + res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"}) assert.Equal(t, &state.GetResponse{}, res, "Response must be empty") assert.NoError(t, err, "Non-existent key must not be treated as error") }) @@ -108,21 +109,21 @@ func TestDelete(t *testing.T) { t.Run("With key", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.NoError(t, err, "Key must be exists") }) t.Run("With key and version", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(123)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo", ETag: &etag}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo", ETag: &etag}) assert.NoError(t, err, "Key must be exists") }) t.Run("With key and concurrency", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{ + err := s.Delete(context.TODO(), &state.DeleteRequest{ Key: "foo", ETag: &etag, Options: state.DeleteStateOption{Concurrency: state.LastWrite}, @@ -133,14 +134,14 @@ func TestDelete(t *testing.T) { t.Run("With delete error", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrUnknown).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.EqualError(t, err, "zk: unknown error") }) t.Run("With delete and ignore NoNode error", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrNoNode).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.NoError(t, err, "Delete must be successful") }) } @@ -159,7 +160,7 @@ func TestBulkDelete(t *testing.T) { &zk.DeleteRequest{Path: "bar", Version: int32(anyVersion)}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.NoError(t, err, "Key must be exists") }) @@ -171,7 +172,7 @@ func TestBulkDelete(t *testing.T) { {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, }, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.Equal(t, err.(*multierror.Error).Errors, []error{zk.ErrUnknown, zk.ErrNoAuth}) }) t.Run("With keys and ignore NoNode error", func(t *testing.T) { @@ -182,7 +183,7 @@ func TestBulkDelete(t *testing.T) { {Error: zk.ErrNoNode}, {}, }, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.NoError(t, err, "Key must be exists") }) } @@ -201,19 +202,19 @@ func TestSet(t *testing.T) { t.Run("With key", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.NoError(t, err, "Key must be set") }) t.Run("With key and version", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(123)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar", ETag: &etag}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar", ETag: &etag}) assert.NoError(t, err, "Key must be set") }) t.Run("With key and concurrency", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{ + err := s.Set(context.TODO(), &state.SetRequest{ Key: "foo", Value: "bar", ETag: &etag, @@ -225,14 +226,14 @@ func TestSet(t *testing.T) { t.Run("With error", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrUnknown).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.EqualError(t, err, "zk: unknown error") }) t.Run("With NoNode error and retry", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrNoNode).Times(1) conn.EXPECT().Create("foo", []byte("\"bar\""), int32(0), nil).Return("/foo", nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.NoError(t, err, "Key must be create") }) } @@ -251,7 +252,7 @@ func TestBulkSet(t *testing.T) { &zk.SetDataRequest{Path: "bar", Data: []byte("\"foo\""), Version: int32(anyVersion)}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) @@ -266,7 +267,7 @@ func TestBulkSet(t *testing.T) { {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, }, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) @@ -283,7 +284,7 @@ func TestBulkSet(t *testing.T) { &zk.CreateRequest{Path: "foo", Data: []byte("\"bar\"")}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 8e004482d7..b0482b48ae 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "encoding/json" "fmt" "sort" @@ -227,7 +228,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) t.Run("ping", func(t *testing.T) { - err := statestore.Ping() + err := statestore.Ping(context.TODO()) assert.Nil(t, err) }) @@ -243,7 +244,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - err := statestore.Set(req) + err := statestore.Set(context.TODO(), req) assert.Nil(t, err) } } @@ -261,7 +262,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - res, err := statestore.Get(req) + res, err := statestore.Get(context.TODO(), req) assert.Nil(t, err) assertEquals(t, scenario.value, res) } @@ -310,11 +311,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - err := statestore.Delete(req) + err := statestore.Delete(context.TODO(), req) assert.Nil(t, err) t.Logf("Checking value absence for %s", scenario.key) - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, }) assert.Nil(t, err) @@ -336,14 +337,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) } } - err := statestore.BulkSet(bulk) + err := statestore.BulkSet(context.TODO(), bulk) assert.Nil(t, err) for _, scenario := range scenarios { if scenario.bulkOnly { t.Logf("Checking value presence for %s", scenario.key) // Data should have been inserted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, }) assert.Nil(t, err) @@ -364,12 +365,12 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) } } - err := statestore.BulkDelete(bulk) + err := statestore.BulkDelete(context.TODO(), bulk) assert.Nil(t, err) for _, req := range bulk { t.Logf("Checking value absence for %s", req.Key) - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: req.Key, }) assert.Nil(t, err) @@ -434,7 +435,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if scenario.transactionGroup == transactionGroup { t.Logf("Checking value presence for %s", scenario.key) // Data should have been inserted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, // For CosmosDB Metadata: map[string]string{ @@ -448,7 +449,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if scenario.toBeDeleted && (scenario.transactionGroup == transactionGroup-1) { t.Logf("Checking value absence for %s", scenario.key) // Data should have been deleted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, // For CosmosDB Metadata: map[string]string{ @@ -478,7 +479,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } // prerequisite: key1 should be present - err := statestore.Set(&state.SetRequest{ + err := statestore.Set(context.TODO(), &state.SetRequest{ Key: firstKey, Value: firstValue, Metadata: partitionMetadata, @@ -486,14 +487,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.NoError(t, err, "set request should be successful") // prerequisite: key2 should not be present - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: secondKey, Metadata: partitionMetadata, }) assert.NoError(t, err, "delete request should be successful") // prerequisite: key3 should not be present - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: thirdKey, Metadata: partitionMetadata, }) @@ -548,7 +549,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St // Assert for k, v := range expected { - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: k, Metadata: partitionMetadata, }) @@ -575,20 +576,20 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.True(t, state.FeatureETag.IsPresent(features)) // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) assert.Nil(t, err) // Set an object. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: firstValue, }) assert.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) @@ -597,7 +598,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St etag := res.ETag // Try and update with wrong ETag, expect failure. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: secondValue, ETag: &fakeEtag, @@ -605,7 +606,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.NotNil(t, err) // Try and update with corect ETag, expect success. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: secondValue, ETag: etag, @@ -613,7 +614,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.Nil(t, err) // Validate the set. - res, err = statestore.Get(&state.GetRequest{ + res, err = statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) assert.Nil(t, err) @@ -622,14 +623,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St etag = res.ETag // Try and delete with wrong ETag, expect failure. - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, ETag: &fakeEtag, }) assert.NotNil(t, err) // Try and delete with correct ETag, expect success. - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, ETag: etag, }) @@ -687,23 +688,23 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St for _, requestSet := range requestSets { // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) assert.Nil(t, err) - err = statestore.Set(requestSet[0]) + err = statestore.Set(context.TODO(), requestSet[0]) assert.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) assert.Nil(t, err) assertEquals(t, firstValue, res) // Second write expect fail - err = statestore.Set(requestSet[1]) + err = statestore.Set(context.TODO(), requestSet[1]) assert.NotNil(t, err) } }) @@ -719,16 +720,16 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) assert.Nil(t, err) - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) assert.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) assert.Nil(t, err) @@ -745,11 +746,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St Consistency: state.Strong, }, } - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) assert.Nil(t, err) // Validate the set. - res, err = statestore.Get(&state.GetRequest{ + res, err = statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) assert.Nil(t, err) @@ -759,7 +760,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St request.ETag = etag // Second write expect fail - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) assert.NotNil(t, err) }) }