Skip to content

Commit

Permalink
transport: added TLS support.
Browse files Browse the repository at this point in the history
Fixes #262
  • Loading branch information
Kulezi committed Jul 22, 2022
1 parent 7677981 commit 4d3a132
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 4 deletions.
2 changes: 2 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func (cfg SessionConfig) Clone() SessionConfig {
v.Events = make([]EventType, len(cfg.Events))
copy(v.Events, cfg.Events)

v.TLSConfig = v.TLSConfig.Clone()

return v
}

Expand Down
108 changes: 108 additions & 0 deletions session_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
package scylla

import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"testing"

"go.uber.org/goleak"
Expand Down Expand Up @@ -230,3 +233,108 @@ func TestSessionIterIntegration(t *testing.T) { // nolint:paralleltest // Integr
}
}
}

var (
caPath = "testdata/tls/cadb.pem"
certPath = "testdata/tls/db.crt"
keyPath = "testdata/tls/db.key"
)

func newCertPoolFromFile(t *testing.T, path string) *x509.CertPool {
certPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(path)
if err != nil {
t.Fatal(err)
}

if !certPool.AppendCertsFromPEM(pem) {
t.Fatalf("failed parsing of CA certs")
}

return certPool
}

func makeCertificatesFromFiles(t *testing.T, certPath, keyPath string) []tls.Certificate {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
t.Fatal(err)
}

return []tls.Certificate{cert}
}

func TestTLSIntegration(t *testing.T) {
testCases := []struct {
name string
tlsConfig *tls.Config
}{
{
name: "no tls",
tlsConfig: nil,
},
{
name: "tls - no client verification",
tlsConfig: &tls.Config{
RootCAs: x509.NewCertPool(),
InsecureSkipVerify: true,
Certificates: makeCertificatesFromFiles(t, certPath, keyPath),
},
},
{
name: "tls - with client verification",
tlsConfig: &tls.Config{
RootCAs: newCertPoolFromFile(t, caPath),
InsecureSkipVerify: false,
ServerName: "192.168.100.100",
Certificates: makeCertificatesFromFiles(t, certPath, keyPath),
},
},
}

for i := 0; i < len(testCases); i++ {
tc := testCases[i]
t.Run(tc.name, func(t *testing.T) {
cfg := testingSessionConfig.Clone()
cfg.TLSConfig = tc.tlsConfig
cfg.Keyspace = ""
cfg.Hosts = []string{"192.168.100.100"}
if cfg.TLSConfig != nil {
cfg.DefaultPort = "9142"
}

session, err := NewSession(cfg)
if err != nil {
t.Fatal(err)
}

stmts := []string{
"CREATE KEYSPACE IF NOT EXISTS mykeyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1}",
"CREATE TABLE IF NOT EXISTS mykeyspace.users (user_id int, fname text, lname text, PRIMARY KEY((user_id)))",
"INSERT INTO mykeyspace.users(user_id, fname, lname) VALUES (1, 'rick', 'sanchez')",
"INSERT INTO mykeyspace.users(user_id, fname, lname) VALUES (4, 'rust', 'cohle')",
}

for _, stmt := range stmts {
q := session.Query(stmt)
if _, err := q.Exec(); err != nil {
t.Fatal(err)
}
}

q := session.Query("SELECT COUNT(*) FROM mykeyspace.users")
if r, err := q.Exec(); err != nil {
t.Fatal(err)
} else {
n, err := r.Rows[0][0].AsInt64()
t.Log(n)
if err != nil {
t.Fatal(err)
}

if n != 2 {
t.Fatalf("expected 2, got %d", n)
}
}
})
}
}
30 changes: 30 additions & 0 deletions transport/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package transport

import (
"bufio"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -297,6 +298,10 @@ type ConnConfig struct {
TCPNoDelay bool
Timeout time.Duration

// If not nil, all connections will use TLS according to TLSConfig,
// please note that the default port (9042) may not support TLS.
TLSConfig *tls.Config

DefaultConsistency frame.Consistency
DefaultPort string

Expand Down Expand Up @@ -378,9 +383,34 @@ func OpenConn(addr string, localAddr *net.TCPAddr, cfg ConnConfig) (*Conn, error
return nil, fmt.Errorf("set TCP no delay option: %w", err)
}

if cfg.TLSConfig != nil {
tConn, err := WrapTLS(tcpConn, cfg.TLSConfig)
if err != nil {
return nil, err
}

return WrapConn(tConn, cfg)
}

return WrapConn(tcpConn, cfg)
}

func WrapTLS(conn *net.TCPConn, cfg *tls.Config) (net.Conn, error) {
cfg = cfg.Clone()
tconn := tls.Client(conn, cfg)
if err := tconn.Handshake(); err != nil {
if err := tconn.Close(); err != nil {
log.Printf("%s failed to close: %s", tconn.RemoteAddr(), err)
} else {
log.Printf("%s closed", tconn.RemoteAddr())
}

return nil, err
}

return tconn, nil
}

// WrapConn transforms tcp connection to a working Scylla connection.
// If error and connection are returned the connection is not valid and must be closed by the caller.
func WrapConn(conn net.Conn, cfg ConnConfig) (*Conn, error) {
Expand Down
16 changes: 12 additions & 4 deletions transport/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,20 @@ func (r *PoolRefiller) init(host string) error {
conn.Close()
return fmt.Errorf("supported: %w", err)
}
ss := s.ScyllaSupported()

if v, ok := s.Options[ScyllaShardAwarePort]; ok {
r.addr = net.JoinHostPort(host, v[0])
ss := s.ScyllaSupported()
if r.cfg.TLSConfig != nil {
if v, ok := s.Options[ScyllaShardAwarePortSSL]; ok {
r.addr = net.JoinHostPort(host, v[0])
} else {
return fmt.Errorf("missing encrypted shard aware port information %v", s.Options)
}
} else {
return fmt.Errorf("missing shard aware port information %v", s.Options)
if v, ok := s.Options[ScyllaShardAwarePort]; ok {
r.addr = net.JoinHostPort(host, v[0])
} else {
return fmt.Errorf("missing shard aware port information %v", s.Options)
}
}

r.pool = ConnPool{
Expand Down

0 comments on commit 4d3a132

Please sign in to comment.