Skip to content

Commit

Permalink
Separate test command line to a package
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev committed Jul 23, 2024
1 parent 5c7adac commit fe93f24
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 89 deletions.
11 changes: 6 additions & 5 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"github.com/gocql/gocql/internal/testcmdline"
"math"
"math/big"
"net"
Expand Down Expand Up @@ -2126,8 +2127,8 @@ func TestGetKeyspaceMetadata(t *testing.T) {
if err != nil {
t.Fatalf("Error converting string to int with err: %v", err)
}
if rfInt != *flagRF {
t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt)
if rfInt != *testcmdline.RF {
t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt)
}
}

Expand Down Expand Up @@ -2494,8 +2495,8 @@ func TestUnmarshallNestedTypes(t *testing.T) {
}

func TestSchemaReset(t *testing.T) {
if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) {
t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion)
if testcmdline.CassVersion.Major == 0 || testcmdline.CassVersion.Before(2, 1, 3) {
t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", testcmdline.CassVersion)
}

cluster := createCluster()
Expand Down Expand Up @@ -2560,7 +2561,7 @@ func TestCreateSession_DontSwallowError(t *testing.T) {
t.Fatal("expected to get an error for unsupported protocol")
}

if flagCassVersion.Major < 3 {
if testcmdline.CassVersion.Major < 3 {
// TODO: we should get a distinct error type here which include the underlying
// cassandra error about the protocol version, for now check this here.
if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
Expand Down
3 changes: 2 additions & 1 deletion cloud_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/gocql/gocql/internal/testcmdline"
"io"
"net"
"os"
Expand All @@ -22,7 +23,7 @@ import (
)

func TestCloudConnection(t *testing.T) {
if !*gocql.FlagRunSslTest {
if !*testcmdline.RunSslTest {
t.Skip("Skipping because SSL is not enabled on cluster")
}

Expand Down
59 changes: 20 additions & 39 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package gocql

import (
"flag"
"fmt"
"github.com/gocql/gocql/internal/testcmdline"
"log"
"net"
"reflect"
Expand All @@ -12,39 +12,20 @@ import (
"time"
)

var (
flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
flagProto = flag.Int("proto", 0, "protcol version")
flagCQL = flag.String("cql", "3.0.0", "CQL version")
flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
flagRetry = flag.Int("retries", 5, "number of times to retry queries")
flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
flagCompressTest = flag.String("compressor", "", "compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")

flagCassVersion cassVersion
)

func init() {
flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")

log.SetFlags(log.Lshortfile | log.LstdFlags)
}

func getClusterHosts() []string {
return strings.Split(*flagCluster, ",")
return strings.Split(*testcmdline.Cluster, ",")
}

func getMultiNodeClusterHosts() []string {
return strings.Split(*flagMultiNodeCluster, ",")
return strings.Split(*testcmdline.MultiNodeCluster, ",")
}

func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
if *flagRunSslTest {
if *testcmdline.RunSslTest {
cluster.Port = 9142
cluster.SslOpts = &SslOptions{
CertPath: "testdata/pki/gocql.crt",
Expand Down Expand Up @@ -81,21 +62,21 @@ func createTable(s *Session, table string) error {
func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
clusterHosts := getClusterHosts()
cluster := NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.ProtoVersion = *testcmdline.Proto
cluster.CQLVersion = *testcmdline.CQL
cluster.Timeout = *testcmdline.Timeout
cluster.Consistency = Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
if *testcmdline.Retry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry}
}

switch *flagCompressTest {
switch *testcmdline.CompressTest {
case "snappy":
cluster.Compressor = &SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
panic("invalid compressor: " + *testcmdline.CompressTest)
}

cluster = addSslOptions(cluster)
Expand All @@ -110,21 +91,21 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
func createMultiNodeCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
clusterHosts := getMultiNodeClusterHosts()
cluster := NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.ProtoVersion = *testcmdline.Proto
cluster.CQLVersion = *testcmdline.CQL
cluster.Timeout = *testcmdline.Timeout
cluster.Consistency = Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
if *testcmdline.Retry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry}
}

switch *flagCompressTest {
switch *testcmdline.CompressTest {
case "snappy":
cluster.Compressor = &SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
panic("invalid compressor: " + *testcmdline.CompressTest)
}

cluster = addSslOptions(cluster)
Expand Down Expand Up @@ -156,7 +137,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF))
}`, keyspace, *testcmdline.RF))

if err != nil {
panic(fmt.Sprintf("unable to create keyspace: %v", err))
Expand Down Expand Up @@ -232,7 +213,7 @@ func createViews(t *testing.T, session *Session) {
}

func createMaterializedViews(t *testing.T, session *Session) {
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
return
}
if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table (
Expand Down
1 change: 0 additions & 1 deletion export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

package gocql

var FlagRunSslTest = flagRunSslTest
var CreateCluster = createCluster
var TestLogger = &testLogger{}
var WaitUntilPoolsStopFilling = waitUntilPoolsStopFilling
Expand Down
2 changes: 1 addition & 1 deletion integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function scylla_down() {
}

function scylla_restart() {
scylla_down
# scylla_down
scylla_up
}

Expand Down
14 changes: 8 additions & 6 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import (
"reflect"
"testing"
"time"

"github.com/gocql/gocql/internal/testcmdline"
)

// TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections
func TestAuthentication(t *testing.T) {

if *flagProto < 2 {
if *testcmdline.Proto < 2 {
t.Skip("Authentication is not supported with protocol < 2")
}

if !*flagRunAuthTest {
if !*testcmdline.RunAuthTest {
t.Skip("Authentication is not configured in the target cluster")
}

Expand Down Expand Up @@ -60,21 +62,21 @@ func TestRingDiscovery(t *testing.T) {
session := createSessionFromCluster(cluster, t)
defer session.Close()

if *clusterSize > 1 {
if *testcmdline.ClusterSize > 1 {
// wait for autodiscovery to update the pool with the list of known hosts
time.Sleep(*flagAutoWait)
time.Sleep(*testcmdline.AutoWait)
}

session.pool.mu.RLock()
defer session.pool.mu.RUnlock()
size := len(session.pool.hostConnPools)

if *clusterSize != size {
if *testcmdline.ClusterSize != size {
for p, pool := range session.pool.hostConnPools {
t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String())

}
t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
t.Errorf("Expected a cluster size of %d, but actual size was %d", *testcmdline.ClusterSize, size)
}
}

Expand Down
42 changes: 16 additions & 26 deletions internal/testutils/flags.go → internal/testcmdline/flags.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
package testutils
package testcmdline

import (
"flag"
"fmt"
"log"
"strconv"
"strings"
"time"

"github.com/gocql/gocql"
)

var (
flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
flagProto = flag.Int("proto", 0, "protcol version")
flagCQL = flag.String("cql", "3.0.0", "CQL version")
flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
flagRetry = flag.Int("retries", 5, "number of times to retry queries")
flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
flagCompressTest = flag.String("compressor", "", "compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")

flagCassVersion cassVersion
Cluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
MultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
Proto = flag.Int("proto", 0, "protcol version")
CQL = flag.String("cql", "3.0.0", "CQL version")
RF = flag.Int("rf", 1, "replication factor for test keyspace")
ClusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
Retry = flag.Int("retries", 5, "number of times to retry queries")
AutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
RunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
RunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
CompressTest = flag.String("compressor", "", "compressor to use")
Timeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
CassVersion cassVersion
)

type cassVersion struct {
Expand All @@ -37,11 +33,7 @@ func (c *cassVersion) Set(v string) error {
return nil
}

return c.UnmarshalCQL(nil, []byte(v))
}

func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
return c.unmarshal(data)
return c.unmarshal([]byte(v))
}

func (c *cassVersion) unmarshal(data []byte) error {
Expand Down Expand Up @@ -108,7 +100,5 @@ func (c cassVersion) nodeUpDelay() time.Duration {
}

func init() {
flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")

log.SetFlags(log.Lshortfile | log.LstdFlags)
flag.Var(&CassVersion, "gocql.cversion", "the cassandra version being tested against")
}
21 changes: 11 additions & 10 deletions internal/testutils/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package testutils
import (
"context"
"fmt"
"github.com/gocql/gocql/internal/testcmdline"
"log"
"strings"
"sync"
Expand All @@ -22,21 +23,21 @@ func CreateSession(tb testing.TB, opts ...func(config *gocql.ClusterConfig)) *go
func CreateCluster(opts ...func(*gocql.ClusterConfig)) *gocql.ClusterConfig {
clusterHosts := getClusterHosts()
cluster := gocql.NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.ProtoVersion = *testcmdline.Proto
cluster.CQLVersion = *testcmdline.CQL
cluster.Timeout = *testcmdline.Timeout
cluster.Consistency = gocql.Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry}
if *testcmdline.Retry > 0 {
cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *testcmdline.Retry}
}

switch *flagCompressTest {
switch *testcmdline.CompressTest {
case "snappy":
cluster.Compressor = &gocql.SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
panic("invalid compressor: " + *testcmdline.CompressTest)
}

cluster = addSslOptions(cluster)
Expand Down Expand Up @@ -69,7 +70,7 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq
}

func getClusterHosts() []string {
return strings.Split(*flagCluster, ",")
return strings.Split(*testcmdline.Cluster, ",")
}

func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) {
Expand All @@ -92,7 +93,7 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF))
}`, keyspace, *testcmdline.RF))

if err != nil {
panic(fmt.Sprintf("unable to create keyspace: %v", err))
Expand Down Expand Up @@ -120,7 +121,7 @@ func CreateTable(s *gocql.Session, table string) error {
}

func addSslOptions(cluster *gocql.ClusterConfig) *gocql.ClusterConfig {
if *flagRunSslTest {
if *testcmdline.RunSslTest {
cluster.Port = 9142
cluster.SslOpts = &gocql.SslOptions{
CertPath: "testdata/pki/gocql.crt",
Expand Down

0 comments on commit fe93f24

Please sign in to comment.