From 4d3a132e0b4e09c80a7bdc1b6563de493291a526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Putra?= Date: Wed, 20 Jul 2022 17:33:54 +0200 Subject: [PATCH] transport: added TLS support. Fixes #262 --- session.go | 2 + session_integration_test.go | 108 ++++++++++++++++++++++++++++++++++++ transport/conn.go | 30 ++++++++++ transport/pool.go | 16 ++++-- 4 files changed, 152 insertions(+), 4 deletions(-) diff --git a/session.go b/session.go index a449bbff..8ff5be4e 100644 --- a/session.go +++ b/session.go @@ -87,6 +87,8 @@ func (cfg SessionConfig) Clone() SessionConfig { v.Events = make([]EventType, len(cfg.Events)) copy(v.Events, cfg.Events) + v.TLSConfig = v.TLSConfig.Clone() + return v } diff --git a/session_integration_test.go b/session_integration_test.go index 8d8dc4bd..410e0cf0 100644 --- a/session_integration_test.go +++ b/session_integration_test.go @@ -3,7 +3,10 @@ package scylla import ( + "crypto/tls" + "crypto/x509" "errors" + "io/ioutil" "testing" "go.uber.org/goleak" @@ -230,3 +233,108 @@ func TestSessionIterIntegration(t *testing.T) { // nolint:paralleltest // Integr } } } + +var ( + caPath = "testdata/tls/cadb.pem" + certPath = "testdata/tls/db.crt" + keyPath = "testdata/tls/db.key" +) + +func newCertPoolFromFile(t *testing.T, path string) *x509.CertPool { + certPool := x509.NewCertPool() + pem, err := ioutil.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + if !certPool.AppendCertsFromPEM(pem) { + t.Fatalf("failed parsing of CA certs") + } + + return certPool +} + +func makeCertificatesFromFiles(t *testing.T, certPath, keyPath string) []tls.Certificate { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + t.Fatal(err) + } + + return []tls.Certificate{cert} +} + +func TestTLSIntegration(t *testing.T) { + testCases := []struct { + name string + tlsConfig *tls.Config + }{ + { + name: "no tls", + tlsConfig: nil, + }, + { + name: "tls - no client verification", + tlsConfig: &tls.Config{ + RootCAs: x509.NewCertPool(), + InsecureSkipVerify: true, + Certificates: makeCertificatesFromFiles(t, certPath, keyPath), + }, + }, + { + name: "tls - with client verification", + tlsConfig: &tls.Config{ + RootCAs: newCertPoolFromFile(t, caPath), + InsecureSkipVerify: false, + ServerName: "192.168.100.100", + Certificates: makeCertificatesFromFiles(t, certPath, keyPath), + }, + }, + } + + for i := 0; i < len(testCases); i++ { + tc := testCases[i] + t.Run(tc.name, func(t *testing.T) { + cfg := testingSessionConfig.Clone() + cfg.TLSConfig = tc.tlsConfig + cfg.Keyspace = "" + cfg.Hosts = []string{"192.168.100.100"} + if cfg.TLSConfig != nil { + cfg.DefaultPort = "9142" + } + + session, err := NewSession(cfg) + if err != nil { + t.Fatal(err) + } + + stmts := []string{ + "CREATE KEYSPACE IF NOT EXISTS mykeyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1}", + "CREATE TABLE IF NOT EXISTS mykeyspace.users (user_id int, fname text, lname text, PRIMARY KEY((user_id)))", + "INSERT INTO mykeyspace.users(user_id, fname, lname) VALUES (1, 'rick', 'sanchez')", + "INSERT INTO mykeyspace.users(user_id, fname, lname) VALUES (4, 'rust', 'cohle')", + } + + for _, stmt := range stmts { + q := session.Query(stmt) + if _, err := q.Exec(); err != nil { + t.Fatal(err) + } + } + + q := session.Query("SELECT COUNT(*) FROM mykeyspace.users") + if r, err := q.Exec(); err != nil { + t.Fatal(err) + } else { + n, err := r.Rows[0][0].AsInt64() + t.Log(n) + if err != nil { + t.Fatal(err) + } + + if n != 2 { + t.Fatalf("expected 2, got %d", n) + } + } + }) + } +} diff --git a/transport/conn.go b/transport/conn.go index fdf38f5f..22d8c6bf 100644 --- a/transport/conn.go +++ b/transport/conn.go @@ -2,6 +2,7 @@ package transport import ( "bufio" + "crypto/tls" "encoding/binary" "errors" "fmt" @@ -297,6 +298,10 @@ type ConnConfig struct { TCPNoDelay bool Timeout time.Duration + // If not nil, all connections will use TLS according to TLSConfig, + // please note that the default port (9042) may not support TLS. + TLSConfig *tls.Config + DefaultConsistency frame.Consistency DefaultPort string @@ -378,9 +383,34 @@ func OpenConn(addr string, localAddr *net.TCPAddr, cfg ConnConfig) (*Conn, error return nil, fmt.Errorf("set TCP no delay option: %w", err) } + if cfg.TLSConfig != nil { + tConn, err := WrapTLS(tcpConn, cfg.TLSConfig) + if err != nil { + return nil, err + } + + return WrapConn(tConn, cfg) + } + return WrapConn(tcpConn, cfg) } +func WrapTLS(conn *net.TCPConn, cfg *tls.Config) (net.Conn, error) { + cfg = cfg.Clone() + tconn := tls.Client(conn, cfg) + if err := tconn.Handshake(); err != nil { + if err := tconn.Close(); err != nil { + log.Printf("%s failed to close: %s", tconn.RemoteAddr(), err) + } else { + log.Printf("%s closed", tconn.RemoteAddr()) + } + + return nil, err + } + + return tconn, nil +} + // WrapConn transforms tcp connection to a working Scylla connection. // If error and connection are returned the connection is not valid and must be closed by the caller. func WrapConn(conn net.Conn, cfg ConnConfig) (*Conn, error) { diff --git a/transport/pool.go b/transport/pool.go index ac9910fc..fc576449 100644 --- a/transport/pool.go +++ b/transport/pool.go @@ -147,12 +147,20 @@ func (r *PoolRefiller) init(host string) error { conn.Close() return fmt.Errorf("supported: %w", err) } - ss := s.ScyllaSupported() - if v, ok := s.Options[ScyllaShardAwarePort]; ok { - r.addr = net.JoinHostPort(host, v[0]) + ss := s.ScyllaSupported() + if r.cfg.TLSConfig != nil { + if v, ok := s.Options[ScyllaShardAwarePortSSL]; ok { + r.addr = net.JoinHostPort(host, v[0]) + } else { + return fmt.Errorf("missing encrypted shard aware port information %v", s.Options) + } } else { - return fmt.Errorf("missing shard aware port information %v", s.Options) + if v, ok := s.Options[ScyllaShardAwarePort]; ok { + r.addr = net.JoinHostPort(host, v[0]) + } else { + return fmt.Errorf("missing shard aware port information %v", s.Options) + } } r.pool = ConnPool{