Skip to content

Commit

Permalink
session: prepare now prepares on all nodes
Browse files Browse the repository at this point in the history
Fixes #249
  • Loading branch information
Kulezi committed Aug 1, 2022
1 parent 4d3a132 commit 5bc350c
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 22 deletions.
47 changes: 31 additions & 16 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scylla
import (
"fmt"
"log"
"sync"

"github.com/mmatczuk/scylla-go-driver/frame"
"github.com/mmatczuk/scylla-go-driver/transport"
Expand Down Expand Up @@ -147,24 +148,38 @@ func (s *Session) Query(content string) Query {
}

func (s *Session) Prepare(content string) (Query, error) {
n := s.policy.Node(s.cluster.NewQueryInfo(), 0)
conn := n.LeastBusyConn()
if conn == nil {
return Query{}, errNoConnection
}

stmt := transport.Statement{Content: content, Consistency: frame.ALL}
res, err := conn.Prepare(stmt)

return Query{session: s,
stmt: res,
exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) {
return conn.Execute(stmt, pagingState)
},
asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) {
conn.AsyncExecute(stmt, pagingState, handler)
},
}, err
// Prepare on all nodes concurrently.
nodes := s.cluster.Topology().Nodes
resStmt := make([]transport.Statement, len(nodes))
resErr := make([]error, len(nodes))
var wg sync.WaitGroup
for i := range nodes {
wg.Add(1)
go func(idx int) {
defer wg.Done()
resStmt[idx], resErr[idx] = nodes[idx].Prepare(stmt)
}(i)
}
wg.Wait()

// Find first result that succeeded.
for i := range nodes {
if resErr[i] == nil {
return Query{session: s,
stmt: resStmt[i],
exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) {
return conn.Execute(stmt, pagingState)
},
asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) {
conn.AsyncExecute(stmt, pagingState, handler)
},
}, nil
}
}

return Query{}, fmt.Errorf("prepare failed on all nodes, details: %v", resErr)
}

func (s *Session) NewTokenAwarePolicy() transport.HostSelectionPolicy {
Expand Down
39 changes: 39 additions & 0 deletions session_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"io/ioutil"
"testing"
"time"

"go.uber.org/goleak"
)
Expand Down Expand Up @@ -338,3 +339,41 @@ func TestTLSIntegration(t *testing.T) {
})
}
}

func TestPrepare(t *testing.T) {
defer goleak.VerifyNone(t)

cfg := DefaultSessionConfig("ks", "192.168.100.100:9042")
session, err := NewSession(cfg)

if err != nil {
panic(err)
}

defer session.Close()
initStmts := []string{
"DROP KEYSPACE IF EXISTS testks",
"CREATE KEYSPACE IF NOT EXISTS testks WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1}",
"CREATE TABLE IF NOT EXISTS testks.doubles (pk bigint PRIMARY KEY, v bigint)",
}

for _, stmt := range initStmts {
q := session.Query(stmt)
if _, err := q.Exec(); err != nil {
t.Fatal(err)
}
time.Sleep(time.Second)
}

q, err := session.Prepare("INSERT INTO testks.doubles (pk, v) VALUES (?, ?)")
if err != nil {
t.Fatal(err)
}

for i := int64(0); i < 1000; i++ {
_, err := q.BindInt64(0, i).BindInt64(1, 2*i).Exec()
if err != nil {
t.Fatal(err)
}
}
}
6 changes: 3 additions & 3 deletions transport/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type topology struct {
localDC string
peers peerMap
dcRacks dcRacksMap
nodes []*Node
Nodes []*Node
policyInfo policyInfo
keyspaces ksMap
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func (c *Cluster) refreshTopology() error {
// Every encountered node becomes known host for future use.
c.knownHosts[n.addr] = struct{}{}
t.peers[n.addr] = n
t.nodes = append(t.nodes, n)
t.Nodes = append(t.Nodes, n)
u[uniqueRack{dc: n.datacenter, rack: n.rack}] = struct{}{}
if err := parseTokensFromRow(n, r, &t.policyInfo.ring); err != nil {
return err
Expand Down Expand Up @@ -251,7 +251,7 @@ func newTopology() *topology {
return &topology{
peers: make(peerMap),
dcRacks: make(dcRacksMap),
nodes: make([]*Node, 0),
Nodes: make([]*Node, 0),
policyInfo: policyInfo{
ring: make(Ring, 0),
},
Expand Down
4 changes: 4 additions & 0 deletions transport/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func (n *Node) Conn(token Token) *Conn {
return n.pool.Conn(token)
}

func (n *Node) Prepare(stmt Statement) (Statement, error) {
return n.LeastBusyConn().Prepare(stmt)
}

type RingEntry struct {
node *Node
token Token
Expand Down
6 changes: 3 additions & 3 deletions transport/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (pi *policyInfo) Preprocess(t *topology, ks keyspace) {
}

func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) {
pi.localNodes = t.nodes
pi.localNodes = t.Nodes
sort.Sort(pi.ring)
trie := trieRoot()
for i := range pi.ring {
Expand Down Expand Up @@ -122,14 +122,14 @@ func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) {
}

func (pi *policyInfo) preprocessRoundRobinStrategy(t *topology) {
pi.localNodes = t.nodes
pi.localNodes = t.Nodes
pi.remoteNodes = nil
}

func (pi *policyInfo) preprocessDCAwareRoundRobinStrategy(t *topology) {
pi.localNodes = make([]*Node, 0)
pi.remoteNodes = make([]*Node, 0)
for _, v := range t.nodes {
for _, v := range t.Nodes {
if v.datacenter == t.localDC {
pi.localNodes = append(pi.localNodes, v)
} else {
Expand Down

0 comments on commit 5bc350c

Please sign in to comment.