diff --git a/session.go b/session.go index 8ff5be4e..4c4c2314 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package scylla import ( "fmt" "log" + "sync" "github.com/mmatczuk/scylla-go-driver/frame" "github.com/mmatczuk/scylla-go-driver/transport" @@ -147,24 +148,38 @@ func (s *Session) Query(content string) Query { } func (s *Session) Prepare(content string) (Query, error) { - n := s.policy.Node(s.cluster.NewQueryInfo(), 0) - conn := n.LeastBusyConn() - if conn == nil { - return Query{}, errNoConnection - } - stmt := transport.Statement{Content: content, Consistency: frame.ALL} - res, err := conn.Prepare(stmt) - return Query{session: s, - stmt: res, - exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) { - return conn.Execute(stmt, pagingState) - }, - asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) { - conn.AsyncExecute(stmt, pagingState, handler) - }, - }, err + // Prepare on all nodes concurrently. + nodes := s.cluster.Topology().Nodes + resStmt := make([]transport.Statement, len(nodes)) + resErr := make([]error, len(nodes)) + var wg sync.WaitGroup + for i := range nodes { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resStmt[idx], resErr[idx] = nodes[idx].Prepare(stmt) + }(i) + } + wg.Wait() + + // Find first result that succeeded. + for i := range nodes { + if resErr[i] == nil { + return Query{session: s, + stmt: resStmt[i], + exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) { + return conn.Execute(stmt, pagingState) + }, + asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) { + conn.AsyncExecute(stmt, pagingState, handler) + }, + }, nil + } + } + + return Query{}, fmt.Errorf("prepare failed on all nodes, details: %v", resErr) } func (s *Session) NewTokenAwarePolicy() transport.HostSelectionPolicy { diff --git a/session_integration_test.go b/session_integration_test.go index 410e0cf0..9ff07fe0 100644 --- a/session_integration_test.go +++ b/session_integration_test.go @@ -8,6 +8,7 @@ import ( "errors" "io/ioutil" "testing" + "time" "go.uber.org/goleak" ) @@ -338,3 +339,41 @@ func TestTLSIntegration(t *testing.T) { }) } } + +func TestPrepare(t *testing.T) { + defer goleak.VerifyNone(t) + + cfg := DefaultSessionConfig("ks", "192.168.100.100:9042") + session, err := NewSession(cfg) + + if err != nil { + panic(err) + } + + defer session.Close() + initStmts := []string{ + "DROP KEYSPACE IF EXISTS testks", + "CREATE KEYSPACE IF NOT EXISTS testks WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1}", + "CREATE TABLE IF NOT EXISTS testks.doubles (pk bigint PRIMARY KEY, v bigint)", + } + + for _, stmt := range initStmts { + q := session.Query(stmt) + if _, err := q.Exec(); err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + } + + q, err := session.Prepare("INSERT INTO testks.doubles (pk, v) VALUES (?, ?)") + if err != nil { + t.Fatal(err) + } + + for i := int64(0); i < 1000; i++ { + _, err := q.BindInt64(0, i).BindInt64(1, 2*i).Exec() + if err != nil { + t.Fatal(err) + } + } +} diff --git a/transport/cluster.go b/transport/cluster.go index da3404e6..3b1fca69 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -41,7 +41,7 @@ type topology struct { localDC string peers peerMap dcRacks dcRacksMap - nodes []*Node + Nodes []*Node policyInfo policyInfo keyspaces ksMap } @@ -219,7 +219,7 @@ func (c *Cluster) refreshTopology() error { // Every encountered node becomes known host for future use. c.knownHosts[n.addr] = struct{}{} t.peers[n.addr] = n - t.nodes = append(t.nodes, n) + t.Nodes = append(t.Nodes, n) u[uniqueRack{dc: n.datacenter, rack: n.rack}] = struct{}{} if err := parseTokensFromRow(n, r, &t.policyInfo.ring); err != nil { return err @@ -251,7 +251,7 @@ func newTopology() *topology { return &topology{ peers: make(peerMap), dcRacks: make(dcRacksMap), - nodes: make([]*Node, 0), + Nodes: make([]*Node, 0), policyInfo: policyInfo{ ring: make(Ring, 0), }, diff --git a/transport/node.go b/transport/node.go index 81c97d5b..9b4acf76 100644 --- a/transport/node.go +++ b/transport/node.go @@ -37,6 +37,10 @@ func (n *Node) Conn(token Token) *Conn { return n.pool.Conn(token) } +func (n *Node) Prepare(stmt Statement) (Statement, error) { + return n.LeastBusyConn().Prepare(stmt) +} + type RingEntry struct { node *Node token Token diff --git a/transport/policy.go b/transport/policy.go index 19027ef0..937112a8 100644 --- a/transport/policy.go +++ b/transport/policy.go @@ -86,7 +86,7 @@ func (pi *policyInfo) Preprocess(t *topology, ks keyspace) { } func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) { - pi.localNodes = t.nodes + pi.localNodes = t.Nodes sort.Sort(pi.ring) trie := trieRoot() for i := range pi.ring { @@ -122,14 +122,14 @@ func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) { } func (pi *policyInfo) preprocessRoundRobinStrategy(t *topology) { - pi.localNodes = t.nodes + pi.localNodes = t.Nodes pi.remoteNodes = nil } func (pi *policyInfo) preprocessDCAwareRoundRobinStrategy(t *topology) { pi.localNodes = make([]*Node, 0) pi.remoteNodes = make([]*Node, 0) - for _, v := range t.nodes { + for _, v := range t.Nodes { if v.datacenter == t.localDC { pi.localNodes = append(pi.localNodes, v) } else {