Skip to content

Commit

Permalink
DSU-1860 Add context (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtopc authored Sep 13, 2023
1 parent 09d3d5d commit b517452
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 17 deletions.
83 changes: 66 additions & 17 deletions cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ import (
)

type Cassandra interface {
QueryCtx(ctx context.Context, consistency gocql.Consistency, queryString string, queryParams ...interface{}) *gocql.Query
Query(gocql.Consistency, string, ...interface{}) *gocql.Query
ExecuteQueryCtx(ctx context.Context, queryString string, queryParams ...interface{}) error
ExecuteQuery(string, ...interface{}) error
ExecuteBatchCtx(ctx context.Context, batchType gocql.BatchType, queries []string, params [][]interface{}) error
ExecuteBatch(gocql.BatchType, []string, [][]interface{}) error
ExecuteUnloggedBatchCtx(ctx context.Context, queries []string, params [][]interface{}) error
ExecuteUnloggedBatch([]string, [][]interface{}) error
ScanQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) error
ScanQuery(string, []interface{}, ...interface{}) error
ScanCASQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) (applied bool, err error)
ScanCASQuery(string, []interface{}, ...interface{}) (bool, error)
IterQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) func() (idx int, hasNext bool, err error)
IterQuery(string, []interface{}, ...interface{}) func() (int, bool, error)
Close() error
Config() CassandraConfig
Expand Down Expand Up @@ -133,34 +138,40 @@ func (c *cassandra) Session() *gocql.Session {
return c.session
}

// Query provides an access to the gocql.Query if a user of this library needs to tune some parameters for
// QueryCtx provides access to the gocql.Query if a user of this library needs to tune some parameters for
// a specific query without modifying the parameters the library was configured with, for example to use
// a consistency level that differs from the configured read/write consistency levels.
func (c *cassandra) QueryCtx(ctx context.Context, consistency gocql.Consistency, queryString string, queryParams ...interface{}) *gocql.Query {
return c.session.Query(queryString, queryParams...).WithContext(ctx).Consistency(consistency)
}

// Query same as QueryCtx, but without context.Context.
// Deprecated: Query is deprecated. Use QueryCtx instead.
func (c *cassandra) Query(consistency gocql.Consistency, queryString string, queryParams ...interface{}) *gocql.Query {
return c.session.Query(queryString, queryParams...).Consistency(consistency)
return c.QueryCtx(context.Background(), consistency, queryString, queryParams...)
}

// ExecuteQueryCtx executes a single DML/DDL statement at the configured write consistency level.
func (c *cassandra) ExecuteQueryCtx(ctx context.Context, queryString string, queryParams ...interface{}) error {
return c.Query(c.wcl, queryString, queryParams...).WithContext(ctx).Exec()
return c.QueryCtx(ctx, c.wcl, queryString, queryParams...).Exec()
}

// ExecuteQuery executes a single DML/DDL statement at the configured write consistency level.
// Deprecated: ExecuteQuery is deprecated. Switch to ExecuteQueryCtx.
// Deprecated: ExecuteQuery is deprecated. Use ExecuteQueryCtx instead.
func (c *cassandra) ExecuteQuery(queryString string, queryParams ...interface{}) error {
return c.ExecuteQueryCtx(context.Background(), queryString, queryParams...)
}

// ExecuteBatch executes a batch of DML/DDL statements at the configured write consistency level.
func (c *cassandra) ExecuteBatch(batchType gocql.BatchType, queries []string, params [][]interface{}) error {
// ExecuteBatchCtx executes a batch of DML/DDL statements at the configured write consistency level.
func (c *cassandra) ExecuteBatchCtx(ctx context.Context, batchType gocql.BatchType, queries []string, params [][]interface{}) error {
count := len(queries)

// quick sanity check
if count != len(params) {
return errors.New("Amount of queries and params does not match")
}

batch := c.session.NewBatch(batchType)
batch := c.session.NewBatch(batchType).WithContext(ctx)
batch.Cons = c.wcl
for idx := 0; idx < count; idx++ {
batch.Query(queries[idx], params[idx]...)
Expand All @@ -169,15 +180,27 @@ func (c *cassandra) ExecuteBatch(batchType gocql.BatchType, queries []string, pa
return c.session.ExecuteBatch(batch)
}

// ExecuteBatch executes a batch of DML/DDL statements at the configured write consistency level.
// Deprecated: ExecuteBatch is deprecated. Use ExecuteBatchCtx instead.
func (c *cassandra) ExecuteBatch(batchType gocql.BatchType, queries []string, params [][]interface{}) error {
return c.ExecuteBatchCtx(context.Background(), batchType, queries, params)
}

// ExecuteUnloggedBatchCtx executes a batch of DML/DDL statements in a non-atomic way at the configured
// write consistency level.
func (c *cassandra) ExecuteUnloggedBatchCtx(ctx context.Context, queries []string, params [][]interface{}) error {
return c.ExecuteBatchCtx(ctx, gocql.UnloggedBatch, queries, params)
}

// ExecuteUnloggedBatch executes a batch of DML/DDL statements in a non-atomic way at the configured
// write consistency level.
func (c *cassandra) ExecuteUnloggedBatch(queries []string, params [][]interface{}) error {
return c.ExecuteBatch(gocql.UnloggedBatch, queries, params)
return c.ExecuteUnloggedBatchCtx(context.Background(), queries, params)
}

// ScanQueryCtx executes a provided SELECT query at the configured read consistency level.
func (c *cassandra) ScanQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) error {
if err := c.Query(c.rcl, queryString, queryParams...).WithContext(ctx).Scan(outParams...); err != nil {
if err := c.QueryCtx(ctx, c.rcl, queryString, queryParams...).Scan(outParams...); err != nil {
if err == gocql.ErrNotFound {
return NotFound
}
Expand All @@ -187,19 +210,29 @@ func (c *cassandra) ScanQueryCtx(ctx context.Context, queryString string, queryP
}

// ScanQuery executes a provided SELECT query at the configured read consistency level.
// Deprecated: ScanQuery is deprecated. Use ScanQueryCtx instead.
func (c *cassandra) ScanQuery(queryString string, queryParams []interface{}, outParams ...interface{}) error {
return c.ScanQueryCtx(context.Background(), queryString, queryParams, outParams...)
}

// ScanCASQueryCtx executes a lightweight transaction (an UPDATE or INSERT statement containing an IF clause)
// at the configured write consistency level.
func (c *cassandra) ScanCASQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{},
) (applied bool, err error) {
return c.QueryCtx(ctx, c.wcl, queryString, queryParams...).ScanCAS(outParams...)
}

// ScanCASQuery executes a lightweight transaction (an UPDATE or INSERT statement containing an IF clause)
// at the configured write consistency level.
// Deprecated: ScanCASQuery is deprecated. Use ScanCASQueryCtx instead.
func (c *cassandra) ScanCASQuery(queryString string, queryParams []interface{}, outParams ...interface{}) (bool, error) {
return c.Query(c.wcl, queryString, queryParams...).ScanCAS(outParams...)
return c.ScanCASQueryCtx(context.Background(), queryString, queryParams, outParams...)
}

// IterQuery consumes row by row of the provided SELECT query executed at the configured read consistency level.
func (c *cassandra) IterQuery(queryString string, queryParams []interface{}, outParams ...interface{}) func() (int, bool, error) {
iter := c.Query(c.rcl, queryString, queryParams...).Iter()
// IterQueryCtx consumes row by row of the provided SELECT query executed at the configured read consistency level.
func (c *cassandra) IterQueryCtx(ctx context.Context, queryString string, queryParams []interface{}, outParams ...interface{},
) func() (idx int, hasNext bool, err error) {
iter := c.QueryCtx(ctx, c.rcl, queryString, queryParams...).Iter()
idx := -1
return func() (int, bool, error) {
idx++
Expand All @@ -213,10 +246,16 @@ func (c *cassandra) IterQuery(queryString string, queryParams []interface{}, out
}
}

func TableExists(db Cassandra, table string) (bool, error) {
// IterQuery consumes row by row of the provided SELECT query executed at the configured read consistency level.
// Deprecated: IterQuery is deprecated. Use IterQueryCtx instead.
func (c *cassandra) IterQuery(queryString string, queryParams []interface{}, outParams ...interface{}) func() (int, bool, error) {
return c.IterQueryCtx(context.Background(), queryString, queryParams, outParams...)
}

func TableExistsCtx(ctx context.Context, db Cassandra, table string) (bool, error) {
var tableName string
// Only tested with Cassandra 3.11.x
iter := db.IterQuery("SELECT table_name FROM system_schema.tables"+
iter := db.IterQueryCtx(ctx, "SELECT table_name FROM system_schema.tables"+
" WHERE keyspace_name = ? AND table_name = ?",
[]interface{}{db.Config().Keyspace, table}, &tableName)
_, _, err := iter()
Expand All @@ -232,7 +271,12 @@ func TableExists(db Cassandra, table string) (bool, error) {
return false, nil
}

func WaitForTables(db Cassandra, timeout time.Duration, tables ...string) error {
// Deprecated: TableExists is deprecated. Use TableExistsCtx instead.
func TableExists(db Cassandra, table string) (bool, error) {
return TableExistsCtx(context.Background(), db, table)
}

func WaitForTablesCtx(ctx context.Context, db Cassandra, timeout time.Duration, tables ...string) error {
quit := false
mutex := sync.Mutex{}
time.AfterFunc(timeout, func() {
Expand All @@ -244,7 +288,7 @@ func WaitForTables(db Cassandra, timeout time.Duration, tables ...string) error
for _, table := range tables {
tryAgain:
mutex.Lock()
exists, err := TableExists(db, table)
exists, err := TableExistsCtx(ctx, db, table)
if err != nil {
mutex.Unlock()
return err
Expand All @@ -266,6 +310,11 @@ func WaitForTables(db Cassandra, timeout time.Duration, tables ...string) error
return nil
}

// Deprecated: WaitForTables is deprecated. Use WaitForTablesCtx instead.
func WaitForTables(db Cassandra, timeout time.Duration, tables ...string) error {
return WaitForTablesCtx(context.Background(), db, timeout, tables...)
}

func translateDuration(k string, df time.Duration) (time.Duration, error) {
if k == "" {
return df, nil
Expand Down
6 changes: 6 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
version: "3.9"
services:
cassandra:
image: cassandra
ports:
- "9042:9042"
22 changes: 22 additions & 0 deletions testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (

type TestErrorCassandra struct{}

func (c *TestErrorCassandra) QueryCtx(_ context.Context, consistency gocql.Consistency, queryString string, queryParams ...interface{}) *gocql.Query {
return nil
}

func (c *TestErrorCassandra) Query(consistency gocql.Consistency, queryString string, queryParams ...interface{}) *gocql.Query {
return nil
}
Expand All @@ -29,10 +33,18 @@ func (c *TestErrorCassandra) ExecuteQuery(queryString string, queryParams ...int
return fmt.Errorf("Error during ExecuteQuery")
}

func (c *TestErrorCassandra) ExecuteBatchCtx(_ context.Context, batchType gocql.BatchType, queries []string, params [][]interface{}) error {
return fmt.Errorf("Error during ExecuteBatchCtx")
}

func (c *TestErrorCassandra) ExecuteBatch(batchType gocql.BatchType, queries []string, params [][]interface{}) error {
return fmt.Errorf("Error during ExecuteBatch")
}

func (c *TestErrorCassandra) ExecuteUnloggedBatchCtx(_ context.Context, queries []string, params [][]interface{}) error {
return fmt.Errorf("Error during ExecuteUnloggedBatchCtx")
}

func (c *TestErrorCassandra) ExecuteUnloggedBatch(queries []string, params [][]interface{}) error {
return fmt.Errorf("Error during ExecuteUnloggedBatch")
}
Expand All @@ -45,10 +57,20 @@ func (c *TestErrorCassandra) ScanQuery(queryString string, queryParams []interfa
return fmt.Errorf("Error during ScanQuery")
}

func (c *TestErrorCassandra) ScanCASQueryCtx(_ context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) (bool, error) {
return false, fmt.Errorf("Error during ScanCASQueryCtx")
}

func (c *TestErrorCassandra) ScanCASQuery(queryString string, queryParams []interface{}, outParams ...interface{}) (bool, error) {
return false, fmt.Errorf("Error during ScanCASQuery")
}

func (c *TestErrorCassandra) IterQueryCtx(_ context.Context, queryString string, queryParams []interface{}, outParams ...interface{}) func() (int, bool, error) {
return func() (int, bool, error) {
return 0, true, fmt.Errorf("Error during IterQueryCtx")
}
}

func (c *TestErrorCassandra) IterQuery(queryString string, queryParams []interface{}, outParams ...interface{}) func() (int, bool, error) {
return func() (int, bool, error) {
return 0, true, fmt.Errorf("Error during IterQuery")
Expand Down

0 comments on commit b517452

Please sign in to comment.