Skip to content

Commit

Permalink
transport: replace net.IP with netip.Addr
Browse files Browse the repository at this point in the history
Fixes #274
  • Loading branch information
Kulezi committed Sep 12, 2022
1 parent 88fc8ef commit a93c76b
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 44 deletions.
5 changes: 4 additions & 1 deletion frame/buffer_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frame
import (
"fmt"
"log"
"net/netip"
)

// All the read functions call readByte or readInto as they would want to read a single byte or copy a slice of bytes.
Expand Down Expand Up @@ -160,7 +161,9 @@ func (b *Buffer) ReadInet() Inet {
log.Printf("unknown ip length")
}
}
return Inet{IP: b.readCopy(int(n)), Port: b.ReadInt()}

ip, _ := netip.AddrFromSlice(b.readCopy(int(n)))
return Inet{IP: ip, Port: b.ReadInt()}
}

func (b *Buffer) ReadString() string {
Expand Down
9 changes: 5 additions & 4 deletions frame/buffer_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ func (b *Buffer) WriteValue(v Value) {
}

func (b *Buffer) WriteInet(v Inet) {
addr := v.IP.AsSlice()
if Debug {
if l := len(v.IP); l != 4 && l != 16 {
log.Printf("unknown IP length")
if len(addr) != 4 && len(addr) != 16 {
log.Printf("unknown ip length")
}
}
b.WriteByte(Byte(len(v.IP)))
b.Write(v.IP)
b.WriteByte(Byte(len(addr)))
b.Write(addr)
b.WriteInt(v.Port)
}

Expand Down
27 changes: 13 additions & 14 deletions frame/cqlvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"encoding/binary"
"fmt"
"math"
"net"
"net/netip"
"unicode"
"unicode/utf8"
)
Expand Down Expand Up @@ -153,16 +153,17 @@ func (c CqlValue) AsText() (string, error) {
return string(c.Value), nil
}

func (c CqlValue) AsIP() (net.IP, error) {
func (c CqlValue) AsIP() (netip.Addr, error) {
if c.Type.ID != InetID {
return nil, fmt.Errorf("%v is not of Inet type", c)
return netip.Addr{}, fmt.Errorf("%v is not of Inet type", c)
}

if len(c.Value) != 4 && len(c.Value) != 16 {
return nil, fmt.Errorf("invalid ip length")
ret, ok := netip.AddrFromSlice(c.Value)
if !ok {
return netip.Addr{}, fmt.Errorf("invalid ip length")
}

return c.Value, nil
return ret, nil
}

func (c CqlValue) AsFloat32() (float32, error) {
Expand Down Expand Up @@ -414,17 +415,15 @@ func CqlFromTimeUUID(b [16]byte) (CqlValue, error) {
return c, nil
}

func CqlFromIP(ip net.IP) (CqlValue, error) {
if len(ip) != 4 || len(ip) != 16 {
return CqlValue{}, fmt.Errorf("invalid ip address")
func CqlFromIP(ip netip.Addr) (CqlValue, error) {
if ip.BitLen() == 0 {
return CqlValue{}, fmt.Errorf("zero addr is not supported")
}

c := CqlValue{
return CqlValue{
Type: &Option{ID: InetID},
Value: make(Bytes, len(ip)),
}
copy(c.Value, ip)
return c, nil
Value: ip.AsSlice(),
}, nil
}

func CqlFromFloat32(v float32) CqlValue {
Expand Down
10 changes: 8 additions & 2 deletions frame/cqlvalue_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frame
import (
"math"
"net"
"net/netip"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -142,12 +143,17 @@ func FuzzCqlValueText(f *testing.F) {
}

func FuzzCqlValueIP(f *testing.F) {
testCases := [][]byte{{1, 2, 3}, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()}
testCases := [][]byte{net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()}
for _, tc := range testCases {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, data []byte) {
in, err := CqlFromIP(data)
ip, ok := netip.AddrFromSlice(data)
if !ok {
t.Skip()
}

in, err := CqlFromIP(ip)
if err != nil {
// We skip tests with incorrect CqlValue.
t.Skip()
Expand Down
16 changes: 8 additions & 8 deletions frame/cqlvalue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frame

import (
"math"
"net"
"net/netip"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -509,7 +509,7 @@ func TestCqlValueAsIP(t *testing.T) {
name string
content CqlValue
valid bool
expected net.IP
expected netip.Addr
}{
{
name: "wrong length",
Expand All @@ -530,19 +530,19 @@ func TestCqlValueAsIP(t *testing.T) {
name: "valid v4",
content: CqlValue{
Type: &Option{ID: InetID},
Value: Bytes(net.IP{127, 0, 0, 1}),
Value: Bytes{127, 0, 0, 1},
},
valid: true,
expected: net.IP{127, 0, 0, 1},
expected: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
},
{
name: "valid v6",
content: CqlValue{
Type: &Option{ID: InetID},
Value: Bytes(net.IP{127, 0, 0, 1}.To16()),
Value: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
},
valid: true,
expected: net.IP{127, 0, 0, 1}.To16(),
expected: netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}),
},
}

Expand All @@ -557,8 +557,8 @@ func TestCqlValueAsIP(t *testing.T) {
}
return
}
if diff := cmp.Diff(v, tc.expected); diff != "" {
t.Fatalf(diff)
if v != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, v)
}
})
}
Expand Down
17 changes: 9 additions & 8 deletions frame/response/event_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package response

import (
"net/netip"
"testing"

"github.com/scylladb/scylla-go-driver/frame"
Expand All @@ -21,15 +22,15 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different
var b frame.Buffer
b.WriteString("UP")
b.WriteInet(frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
})
return b.Bytes()
}(),
expected: StatusChange{
Status: "UP",
Address: frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
},
},
Expand All @@ -42,8 +43,8 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different
var buf frame.Buffer
buf.Write(tc.content)
a := ParseStatusChange(&buf)
if diff := cmp.Diff(*a, tc.expected); diff != "" {
t.Fatal(diff)
if *a != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, *a)
}
})
}
Expand All @@ -62,15 +63,15 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen
var b frame.Buffer
b.WriteString("NEW_NODE")
b.WriteInet(frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
})
return b.Bytes()
}(),
expected: TopologyChange{
Change: "NEW_NODE",
Address: frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
},
},
Expand All @@ -83,8 +84,8 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen
var buf frame.Buffer
buf.Write(tc.content)
a := ParseTopologyChange(&buf)
if diff := cmp.Diff(*a, tc.expected); diff != "" {
t.Fatal(diff)
if *a != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, *a)
}
})
}
Expand Down
6 changes: 3 additions & 3 deletions frame/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frame

import (
"errors"
"net"
"net/netip"
)

// Generic types from CQL binary protocol.
Expand Down Expand Up @@ -40,13 +40,13 @@ func (v Value) Clone() Value {

// https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L241-L245
type Inet struct {
IP Bytes
IP netip.Addr
Port Int
}

// String only takes care of IP part of the address.
func (i Inet) String() string {
return net.IP(i.IP).String()
return i.IP.String()
}

// https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L183-L201
Expand Down
10 changes: 7 additions & 3 deletions transport/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"
"net"
"net/netip"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -320,20 +321,23 @@ func (c *Cluster) parseNodeFromRow(r frame.Row) (*Node, error) {
}
// Possible IP addresses starts from addrIndex in both system.local and system.peers queries.
// They are grouped with decreasing priority.
var addr net.IP
var addr netip.Addr
for i := addrIndex; i < len(r); i++ {
addr, err = r[i].AsIP()
if err == nil && !addr.IsUnspecified() {
break
} else if err == nil && addr.IsUnspecified() {
host, _, err := net.SplitHostPort(c.control.conn.RemoteAddr().String())
if err == nil {
addr = net.ParseIP(host)
addr, err = netip.ParseAddr(host)
if err != nil {
addr = netip.AddrFrom4([4]byte{0, 0, 0, 0})
}
break
}
}
}
if addr == nil || addr.IsUnspecified() {
if addr.IsUnspecified() {
return nil, fmt.Errorf("all addr columns conatin invalid IP")
}
return &Node{
Expand Down
3 changes: 2 additions & 1 deletion transport/cluster_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package transport
import (
"context"
"fmt"
"net/netip"
"os/signal"
"syscall"
"testing"
Expand Down Expand Up @@ -40,7 +41,7 @@ func TestClusterIntegration(t *testing.T) {
defer cancel()

addr := frame.Inet{
IP: []byte{192, 168, 100, 100},
IP: netip.MustParseAddr(TestHost),
Port: 9042,
}

Expand Down

0 comments on commit a93c76b

Please sign in to comment.