Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: replace net.IP with netip.Addr #285

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion frame/buffer_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frame
import (
"fmt"
"log"
"net/netip"
)

// All the read functions call readByte or readInto as they would want to read a single byte or copy a slice of bytes.
Expand Down Expand Up @@ -160,7 +161,9 @@ func (b *Buffer) ReadInet() Inet {
log.Printf("unknown ip length")
}
}
return Inet{IP: b.readCopy(int(n)), Port: b.ReadInt()}

ip, _ := netip.AddrFromSlice(b.readCopy(int(n)))
return Inet{IP: ip, Port: b.ReadInt()}
}

func (b *Buffer) ReadString() string {
Expand Down
9 changes: 5 additions & 4 deletions frame/buffer_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ func (b *Buffer) WriteValue(v Value) {
}

func (b *Buffer) WriteInet(v Inet) {
addr := v.IP.AsSlice()
if Debug {
if l := len(v.IP); l != 4 && l != 16 {
log.Printf("unknown IP length")
if len(addr) != 4 && len(addr) != 16 {
log.Printf("unknown ip length")
}
}
b.WriteByte(Byte(len(v.IP)))
b.Write(v.IP)
b.WriteByte(Byte(len(addr)))
b.Write(addr)
b.WriteInt(v.Port)
}

Expand Down
27 changes: 13 additions & 14 deletions frame/cqlvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"encoding/binary"
"fmt"
"math"
"net"
"net/netip"
"unicode"
"unicode/utf8"
)
Expand Down Expand Up @@ -153,16 +153,17 @@ func (c CqlValue) AsText() (string, error) {
return string(c.Value), nil
}

func (c CqlValue) AsIP() (net.IP, error) {
func (c CqlValue) AsIP() (netip.Addr, error) {
if c.Type.ID != InetID {
return nil, fmt.Errorf("%v is not of Inet type", c)
return netip.Addr{}, fmt.Errorf("%v is not of Inet type", c)
}

if len(c.Value) != 4 && len(c.Value) != 16 {
return nil, fmt.Errorf("invalid ip length")
ret, ok := netip.AddrFromSlice(c.Value)
if !ok {
return netip.Addr{}, fmt.Errorf("invalid ip length")
}

return c.Value, nil
return ret, nil
}

func (c CqlValue) AsFloat32() (float32, error) {
Expand Down Expand Up @@ -414,17 +415,15 @@ func CqlFromTimeUUID(b [16]byte) (CqlValue, error) {
return c, nil
}

func CqlFromIP(ip net.IP) (CqlValue, error) {
if len(ip) != 4 || len(ip) != 16 {
return CqlValue{}, fmt.Errorf("invalid ip address")
func CqlFromIP(ip netip.Addr) (CqlValue, error) {
if ip.BitLen() == 0 {
return CqlValue{}, fmt.Errorf("zero addr is not supported")
}

c := CqlValue{
return CqlValue{
Type: &Option{ID: InetID},
Value: make(Bytes, len(ip)),
}
copy(c.Value, ip)
return c, nil
Value: ip.AsSlice(),
}, nil
}

func CqlFromFloat32(v float32) CqlValue {
Expand Down
10 changes: 8 additions & 2 deletions frame/cqlvalue_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frame
import (
"math"
"net"
"net/netip"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -142,12 +143,17 @@ func FuzzCqlValueText(f *testing.F) {
}

func FuzzCqlValueIP(f *testing.F) {
testCases := [][]byte{{1, 2, 3}, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()}
testCases := [][]byte{net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()}
for _, tc := range testCases {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, data []byte) {
in, err := CqlFromIP(data)
ip, ok := netip.AddrFromSlice(data)
if !ok {
t.Skip()
}

in, err := CqlFromIP(ip)
if err != nil {
// We skip tests with incorrect CqlValue.
t.Skip()
Expand Down
16 changes: 8 additions & 8 deletions frame/cqlvalue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frame

import (
"math"
"net"
"net/netip"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -509,7 +509,7 @@ func TestCqlValueAsIP(t *testing.T) {
name string
content CqlValue
valid bool
expected net.IP
expected netip.Addr
}{
{
name: "wrong length",
Expand All @@ -530,19 +530,19 @@ func TestCqlValueAsIP(t *testing.T) {
name: "valid v4",
content: CqlValue{
Type: &Option{ID: InetID},
Value: Bytes(net.IP{127, 0, 0, 1}),
Value: Bytes{127, 0, 0, 1},
},
valid: true,
expected: net.IP{127, 0, 0, 1},
expected: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
},
{
name: "valid v6",
content: CqlValue{
Type: &Option{ID: InetID},
Value: Bytes(net.IP{127, 0, 0, 1}.To16()),
Value: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
},
valid: true,
expected: net.IP{127, 0, 0, 1}.To16(),
expected: netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}),
},
}

Expand All @@ -557,8 +557,8 @@ func TestCqlValueAsIP(t *testing.T) {
}
return
}
if diff := cmp.Diff(v, tc.expected); diff != "" {
t.Fatalf(diff)
if v != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, v)
}
})
}
Expand Down
17 changes: 9 additions & 8 deletions frame/response/event_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package response

import (
"net/netip"
"testing"

"github.com/scylladb/scylla-go-driver/frame"
Expand All @@ -21,15 +22,15 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different
var b frame.Buffer
b.WriteString("UP")
b.WriteInet(frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
})
return b.Bytes()
}(),
expected: StatusChange{
Status: "UP",
Address: frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
},
},
Expand All @@ -42,8 +43,8 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different
var buf frame.Buffer
buf.Write(tc.content)
a := ParseStatusChange(&buf)
if diff := cmp.Diff(*a, tc.expected); diff != "" {
t.Fatal(diff)
if *a != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, *a)
}
})
}
Expand All @@ -62,15 +63,15 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen
var b frame.Buffer
b.WriteString("NEW_NODE")
b.WriteInet(frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
})
return b.Bytes()
}(),
expected: TopologyChange{
Change: "NEW_NODE",
Address: frame.Inet{
IP: []byte{127, 0, 0, 1},
IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
Port: 9042,
},
},
Expand All @@ -83,8 +84,8 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen
var buf frame.Buffer
buf.Write(tc.content)
a := ParseTopologyChange(&buf)
if diff := cmp.Diff(*a, tc.expected); diff != "" {
t.Fatal(diff)
if *a != tc.expected {
t.Fatalf("expected %v, got %v", tc.expected, *a)
}
})
}
Expand Down
6 changes: 3 additions & 3 deletions frame/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frame

import (
"errors"
"net"
"net/netip"
)

// Generic types from CQL binary protocol.
Expand Down Expand Up @@ -40,13 +40,13 @@ func (v Value) Clone() Value {

// https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L241-L245
type Inet struct {
IP Bytes
IP netip.Addr
Port Int
}

// String only takes care of IP part of the address.
func (i Inet) String() string {
return net.IP(i.IP).String()
return i.IP.String()
}

// https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L183-L201
Expand Down
12 changes: 8 additions & 4 deletions transport/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"
"net"
"net/netip"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -320,21 +321,24 @@ func (c *Cluster) parseNodeFromRow(r frame.Row) (*Node, error) {
}
// Possible IP addresses starts from addrIndex in both system.local and system.peers queries.
// They are grouped with decreasing priority.
var addr net.IP
var addr netip.Addr
for i := addrIndex; i < len(r); i++ {
addr, err = r[i].AsIP()
if err == nil && !addr.IsUnspecified() {
break
} else if err == nil && addr.IsUnspecified() {
host, _, err := net.SplitHostPort(c.control.conn.RemoteAddr().String())
if err == nil {
addr = net.ParseIP(host)
addr, err = netip.ParseAddr(host)
if err != nil {
addr = netip.AddrFrom4([4]byte{0, 0, 0, 0})
}
break
}
}
}
if addr == nil || addr.IsUnspecified() {
return nil, fmt.Errorf("all addr columns conatin invalid IP")
if addr.IsUnspecified() {
return nil, fmt.Errorf("all addr columns contain invalid IP")
}
return &Node{
hostID: hostID,
Expand Down
3 changes: 2 additions & 1 deletion transport/cluster_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package transport
import (
"context"
"fmt"
"net/netip"
"os/signal"
"syscall"
"testing"
Expand Down Expand Up @@ -40,7 +41,7 @@ func TestClusterIntegration(t *testing.T) {
defer cancel()

addr := frame.Inet{
IP: []byte{192, 168, 100, 100},
IP: netip.MustParseAddr(TestHost),
Port: 9042,
}

Expand Down