From fe93f24afef3cda53358db70eae729b850a2c918 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 23 Jul 2024 10:42:29 -0400 Subject: [PATCH] Separate test command line to a package --- cassandra_test.go | 11 ++-- cloud_cluster_test.go | 3 +- common_test.go | 59 +++++++------------- export_test.go | 1 - integration.sh | 2 +- integration_test.go | 14 +++-- internal/{testutils => testcmdline}/flags.go | 42 ++++++-------- internal/testutils/cluster.go | 21 +++---- 8 files changed, 64 insertions(+), 89 deletions(-) rename internal/{testutils => testcmdline}/flags.go (53%) diff --git a/cassandra_test.go b/cassandra_test.go index f7539ded4..441d84807 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "github.com/gocql/gocql/internal/testcmdline" "math" "math/big" "net" @@ -2126,8 +2127,8 @@ func TestGetKeyspaceMetadata(t *testing.T) { if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } - if rfInt != *flagRF { - t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) + if rfInt != *testcmdline.RF { + t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt) } } @@ -2494,8 +2495,8 @@ func TestUnmarshallNestedTypes(t *testing.T) { } func TestSchemaReset(t *testing.T) { - if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) { - t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion) + if testcmdline.CassVersion.Major == 0 || testcmdline.CassVersion.Before(2, 1, 3) { + t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", testcmdline.CassVersion) } cluster := createCluster() @@ -2560,7 +2561,7 @@ func TestCreateSession_DontSwallowError(t *testing.T) { t.Fatal("expected to get an error for unsupported protocol") } - if flagCassVersion.Major < 3 { + if testcmdline.CassVersion.Major < 3 { // TODO: we should get a distinct error type here which include the underlying // cassandra error about the protocol version, for now check this here. if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") { diff --git a/cloud_cluster_test.go b/cloud_cluster_test.go index 4133ac56e..20f9a99b7 100644 --- a/cloud_cluster_test.go +++ b/cloud_cluster_test.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/gocql/gocql/internal/testcmdline" "io" "net" "os" @@ -22,7 +23,7 @@ import ( ) func TestCloudConnection(t *testing.T) { - if !*gocql.FlagRunSslTest { + if !*testcmdline.RunSslTest { t.Skip("Skipping because SSL is not enabled on cluster") } diff --git a/common_test.go b/common_test.go index abbe91cce..55b794618 100644 --- a/common_test.go +++ b/common_test.go @@ -1,8 +1,8 @@ package gocql import ( - "flag" "fmt" + "github.com/gocql/gocql/internal/testcmdline" "log" "net" "reflect" @@ -12,39 +12,20 @@ import ( "time" ) -var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion -) - func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - log.SetFlags(log.Lshortfile | log.LstdFlags) } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func getMultiNodeClusterHosts() []string { - return strings.Split(*flagMultiNodeCluster, ",") + return strings.Split(*testcmdline.MultiNodeCluster, ",") } func addSslOptions(cluster *ClusterConfig) *ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &SslOptions{ CertPath: "testdata/pki/gocql.crt", @@ -81,21 +62,21 @@ func createTable(s *Session, table string) error { func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getClusterHosts() cluster := NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -110,21 +91,21 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { func createMultiNodeCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getMultiNodeClusterHosts() cluster := NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -156,7 +137,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -232,7 +213,7 @@ func createViews(t *testing.T, session *Session) { } func createMaterializedViews(t *testing.T, session *Session) { - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { return } if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table ( diff --git a/export_test.go b/export_test.go index 830436303..3295697db 100644 --- a/export_test.go +++ b/export_test.go @@ -3,7 +3,6 @@ package gocql -var FlagRunSslTest = flagRunSslTest var CreateCluster = createCluster var TestLogger = &testLogger{} var WaitUntilPoolsStopFilling = waitUntilPoolsStopFilling diff --git a/integration.sh b/integration.sh index 07d67f64b..5fbb08a5f 100755 --- a/integration.sh +++ b/integration.sh @@ -21,7 +21,7 @@ function scylla_down() { } function scylla_restart() { - scylla_down +# scylla_down scylla_up } diff --git a/integration_test.go b/integration_test.go index f548a829f..11f22c445 100644 --- a/integration_test.go +++ b/integration_test.go @@ -9,16 +9,18 @@ import ( "reflect" "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections func TestAuthentication(t *testing.T) { - if *flagProto < 2 { + if *testcmdline.Proto < 2 { t.Skip("Authentication is not supported with protocol < 2") } - if !*flagRunAuthTest { + if !*testcmdline.RunAuthTest { t.Skip("Authentication is not configured in the target cluster") } @@ -60,21 +62,21 @@ func TestRingDiscovery(t *testing.T) { session := createSessionFromCluster(cluster, t) defer session.Close() - if *clusterSize > 1 { + if *testcmdline.ClusterSize > 1 { // wait for autodiscovery to update the pool with the list of known hosts - time.Sleep(*flagAutoWait) + time.Sleep(*testcmdline.AutoWait) } session.pool.mu.RLock() defer session.pool.mu.RUnlock() size := len(session.pool.hostConnPools) - if *clusterSize != size { + if *testcmdline.ClusterSize != size { for p, pool := range session.pool.hostConnPools { t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) } - t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) + t.Errorf("Expected a cluster size of %d, but actual size was %d", *testcmdline.ClusterSize, size) } } diff --git a/internal/testutils/flags.go b/internal/testcmdline/flags.go similarity index 53% rename from internal/testutils/flags.go rename to internal/testcmdline/flags.go index ce19c3080..938c346f5 100644 --- a/internal/testutils/flags.go +++ b/internal/testcmdline/flags.go @@ -1,31 +1,27 @@ -package testutils +package testcmdline import ( "flag" "fmt" - "log" "strconv" "strings" "time" - - "github.com/gocql/gocql" ) var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion + Cluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + MultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") + Proto = flag.Int("proto", 0, "protcol version") + CQL = flag.String("cql", "3.0.0", "CQL version") + RF = flag.Int("rf", 1, "replication factor for test keyspace") + ClusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") + Retry = flag.Int("retries", 5, "number of times to retry queries") + AutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") + RunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") + RunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") + CompressTest = flag.String("compressor", "", "compressor to use") + Timeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") + CassVersion cassVersion ) type cassVersion struct { @@ -37,11 +33,7 @@ func (c *cassVersion) Set(v string) error { return nil } - return c.UnmarshalCQL(nil, []byte(v)) -} - -func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { - return c.unmarshal(data) + return c.unmarshal([]byte(v)) } func (c *cassVersion) unmarshal(data []byte) error { @@ -108,7 +100,5 @@ func (c cassVersion) nodeUpDelay() time.Duration { } func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - - log.SetFlags(log.Lshortfile | log.LstdFlags) + flag.Var(&CassVersion, "gocql.cversion", "the cassandra version being tested against") } diff --git a/internal/testutils/cluster.go b/internal/testutils/cluster.go index 431614be4..f25829af9 100644 --- a/internal/testutils/cluster.go +++ b/internal/testutils/cluster.go @@ -3,6 +3,7 @@ package testutils import ( "context" "fmt" + "github.com/gocql/gocql/internal/testcmdline" "log" "strings" "sync" @@ -22,21 +23,21 @@ func CreateSession(tb testing.TB, opts ...func(config *gocql.ClusterConfig)) *go func CreateCluster(opts ...func(*gocql.ClusterConfig)) *gocql.ClusterConfig { clusterHosts := getClusterHosts() cluster := gocql.NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = gocql.Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &gocql.SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -69,7 +70,7 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) { @@ -92,7 +93,7 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -120,7 +121,7 @@ func CreateTable(s *gocql.Session, table string) error { } func addSslOptions(cluster *gocql.ClusterConfig) *gocql.ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &gocql.SslOptions{ CertPath: "testdata/pki/gocql.crt",