diff --git a/key/errors.go b/key/errors.go deleted file mode 100644 index 6fb9a39..0000000 --- a/key/errors.go +++ /dev/null @@ -1,9 +0,0 @@ -package key - -import ( - "fmt" -) - -func ErrInvalidKey(l int) error { - return fmt.Errorf("invalid key: should be %d bytes long", l) -} diff --git a/key/key.go b/key/key.go index 902559d..b4c3fd5 100644 --- a/key/key.go +++ b/key/key.go @@ -19,51 +19,59 @@ func (k KadKey) String() string { return k.Hex() } -func (a KadKey) Xor(b KadKey) (KadKey, error) { - if a.Size() != b.Size() { - return nil, ErrInvalidKey(a.Size()) +func shortLong(a, b KadKey) (min, max KadKey) { + if len(a) < len(b) { + return a, b } + return b, a +} - xored := make([]byte, a.Size()) - for i := 0; i < a.Size(); i++ { +func (a KadKey) Xor(b KadKey) KadKey { + short, long := shortLong(a, b) + xored := make([]byte, len(long)) + for i := 0; i < len(short); i++ { xored[i] = a[i] ^ b[i] } - return xored, nil + copy(xored[len(short):], long[len(short):]) + return xored } -func (a KadKey) CommonPrefixLength(b KadKey) (int, error) { - if a.Size() != b.Size() { - return 0, ErrInvalidKey(a.Size()) - } +func (a KadKey) CommonPrefixLength(b KadKey) int { + short, _ := shortLong(a, b) var xored byte - for i := 0; i < a.Size(); i++ { + for i := 0; i < len(short); i++ { xored = a[i] ^ b[i] if xored != 0 { - return i*8 + 7 - int(math.Log2(float64(xored))), nil + return i*8 + 7 - int(math.Log2(float64(xored))) } } - return 8 * a.Size(), nil + return 8 * len(short) } // Compare returns -1 if a < b, 0 if a == b, and 1 if a > b -func (a KadKey) Compare(b KadKey) (int8, error) { - if a.Size() != b.Size() { - return 2, ErrInvalidKey(a.Size()) - } +func (a KadKey) Compare(b KadKey) int { + short, _ := shortLong(a, b) - for i := 0; i < a.Size(); i++ { + for i := 0; i < len(short); i++ { if a[i] < b[i] { - return -1, nil + return -1 } if a[i] > b[i] { - return 1, nil + return 1 } } - return 0, nil + if len(a) == len(b) { + return 0 + } else if len(a) < len(b) { + // if both keys don't have the same size, and the shorter is a prefix + // of the longer, then the shorter is considered smaller + return -1 + } else { + return 1 + } } -func (a KadKey) Equal(b KadKey) (bool, error) { - cmp, err := a.Compare(b) - return cmp == 0, err +func (a KadKey) Equal(b KadKey) bool { + return a.Compare(b) == 0 } diff --git a/key/key_test.go b/key/key_test.go index afb5fa9..ab927f0 100644 --- a/key/key_test.go +++ b/key/key_test.go @@ -9,16 +9,8 @@ import ( var keysize = 4 -func zeroBytes(n int) []byte { - bytes := make([]byte, n) - for i := 0; i < n; i++ { - bytes[i] = 0 - } - return bytes -} - func TestKadKeyString(t *testing.T) { - zeroKadid := KadKey(zeroBytes(keysize)) + zeroKadid := KadKey(make([]byte, keysize)) zeroHex := strings.Repeat("00", keysize) require.Equal(t, zeroHex, zeroKadid.String()) @@ -38,72 +30,84 @@ func TestKadKeyString(t *testing.T) { } func TestXor(t *testing.T) { - key0 := KadKey(zeroBytes(keysize)) // 00000...000 + key0 := KadKey(make([]byte, keysize)) // 00000...000 randKey := KadKey([]byte{0x23, 0xe4, 0xdd, 0x03}) // arbitrary key - xored, err := key0.Xor(key0) - require.NoError(t, err) + xored := key0.Xor(key0) require.Equal(t, key0, xored) - xored, _ = randKey.Xor(key0) + xored = randKey.Xor(key0) require.Equal(t, randKey, xored) - xored, _ = key0.Xor(randKey) + xored = key0.Xor(randKey) require.Equal(t, randKey, xored) - xored, _ = randKey.Xor(randKey) + xored = randKey.Xor(randKey) require.Equal(t, key0, xored) - invalidKey := KadKey([]byte{0x23, 0xe4, 0xdd}) // invalid key - _, err = key0.Xor(invalidKey) - require.Equal(t, ErrInvalidKey(4), err) + shorterKey := KadKey([]byte{0x23, 0xe4, 0xdd}) // shorter key + xored = key0.Xor(shorterKey) + expected := append(shorterKey, make([]byte, key0.Size()-shorterKey.Size())...) + require.Equal(t, expected, xored) + xored = shorterKey.Xor(key0) + require.Equal(t, expected, xored) + xored = key0.Xor(nil) + require.Equal(t, key0, xored) } func TestCommonPrefixLength(t *testing.T) { - key0 := KadKey(zeroBytes(keysize)) // 00000...000 - key1 := KadKey(append(zeroBytes(keysize-1), 0x01)) // 00000...001 - key2 := KadKey(append([]byte{0x80}, zeroBytes(keysize-1)...)) // 10000...000 - key3 := KadKey(append([]byte{0x40}, zeroBytes(keysize-1)...)) // 01000...000 + key0 := KadKey(make([]byte, keysize)) // 00000...000 + key1 := KadKey(append(make([]byte, keysize-1), 0x01)) // 00000...001 + key2 := KadKey(append([]byte{0x80}, make([]byte, keysize-1)...)) // 10000...000 + key3 := KadKey(append([]byte{0x40}, make([]byte, keysize-1)...)) // 01000...000 - cpl, err := key0.CommonPrefixLength(key0) - require.NoError(t, err) + cpl := key0.CommonPrefixLength(key0) require.Equal(t, keysize*8, cpl) - cpl, _ = key0.CommonPrefixLength(key1) + cpl = key0.CommonPrefixLength(key1) require.Equal(t, keysize*8-1, cpl) - cpl, _ = key0.CommonPrefixLength(key2) + cpl = key0.CommonPrefixLength(key2) require.Equal(t, 0, cpl) - cpl, _ = key0.CommonPrefixLength(key3) + cpl = key0.CommonPrefixLength(key3) require.Equal(t, 1, cpl) - invalidKey := KadKey([]byte{0x23, 0xe4, 0xdd}) // invalid key - _, err = key0.CommonPrefixLength(invalidKey) - require.Equal(t, ErrInvalidKey(4), err) + cpl = key0.CommonPrefixLength(nil) + require.Equal(t, 0, cpl) + cpl = key0.CommonPrefixLength([]byte{0x00}) + require.Equal(t, 8, cpl) + cpl = key0.CommonPrefixLength([]byte{0x00, 0x40}) + require.Equal(t, 9, cpl) + cpl = key0.CommonPrefixLength([]byte{0x80}) + require.Equal(t, 0, cpl) } func TestCompare(t *testing.T) { nKeys := 5 keys := make([]KadKey, nKeys) // ascending order - keys[0] = KadKey(zeroBytes(keysize)) // 00000...000 - keys[1] = KadKey(append(zeroBytes(keysize-1), 0x01)) // 00000...001 - keys[2] = KadKey(append(zeroBytes(keysize-1), 0x02)) // 00000...010 - keys[3] = KadKey(append([]byte{0x40}, zeroBytes(keysize-1)...)) // 01000...000 - keys[4] = KadKey(append([]byte{0x80}, zeroBytes(keysize-1)...)) // 10000...000 + keys[0] = KadKey(make([]byte, keysize)) // 00000...000 + keys[1] = KadKey(append(make([]byte, keysize-1), 0x01)) // 00000...001 + keys[2] = KadKey(append(make([]byte, keysize-1), 0x02)) // 00000...010 + keys[3] = KadKey(append([]byte{0x40}, make([]byte, keysize-1)...)) // 01000...000 + keys[4] = KadKey(append([]byte{0x80}, make([]byte, keysize-1)...)) // 10000...000 for i := 0; i < nKeys; i++ { for j := 0; j < nKeys; j++ { - res, _ := keys[i].Compare(keys[j]) + res := keys[i].Compare(keys[j]) if i < j { - require.Equal(t, int8(-1), res) + require.Equal(t, -1, res) } else if i > j { - require.Equal(t, int8(1), res) + require.Equal(t, 1, res) } else { - require.Equal(t, int8(0), res) - equal, err := keys[i].Equal(keys[j]) - require.NoError(t, err) + require.Equal(t, 0, res) + equal := keys[i].Equal(keys[j]) require.True(t, equal) } } } - invalidKey := KadKey([]byte{0x23, 0xe4, 0xdd}) // invalid key - _, err := keys[0].Compare(invalidKey) - require.Equal(t, ErrInvalidKey(4), err) + // compare keys of different sizes + key := keys[4] // 10000...000 (32 bits) + require.Equal(t, 1, key.Compare([]byte{})) // b is prefix of a -> 1 + require.Equal(t, 1, key.Compare([]byte{0x00})) // a[0] > b [0] -> 1 + require.Equal(t, 1, key.Compare([]byte{0x80})) // b is prefix of a -> 1 + require.Equal(t, -1, key.Compare([]byte{0x81})) // a[4] < b[4] -> -1 + require.Equal(t, 0, key.Compare([]byte{0x80, 0x00, 0x00, 0x00})) // a == b -> 0 + require.Equal(t, -1, key.Compare([]byte{0x80, 0x00, 0x00, 0x00, 0x00})) // a is prefix of b -> -1 } diff --git a/network/endpoint/fakeendpoint/fakeendpoint_test.go b/network/endpoint/fakeendpoint/fakeendpoint_test.go index 3a81494..6176e46 100644 --- a/network/endpoint/fakeendpoint/fakeendpoint_test.go +++ b/network/endpoint/fakeendpoint/fakeendpoint_test.go @@ -39,12 +39,11 @@ func TestFakeEndpoint(t *testing.T) { fakeEndpoint := NewFakeEndpoint(selfID, sched, router) - b, err := selfID.Key().Equal(fakeEndpoint.KadKey()) - require.NoError(t, err) + b := selfID.Key().Equal(fakeEndpoint.KadKey()) require.True(t, b) node0 := si.StringID("node0") - err = fakeEndpoint.DialPeer(ctx, node0) + err := fakeEndpoint.DialPeer(ctx, node0) require.Equal(t, endpoint.ErrUnknownPeer, err) connectedness, err := fakeEndpoint.Connectedness(node0) diff --git a/network/message/ipfsv1/helpers_test.go b/network/message/ipfsv1/helpers_test.go index d114136..6de95a8 100644 --- a/network/message/ipfsv1/helpers_test.go +++ b/network/message/ipfsv1/helpers_test.go @@ -26,8 +26,7 @@ func TestFindPeerRequest(t *testing.T) { require.Equal(t, msg.GetKey(), []byte(p)) - b, err := msg.Target().Equal(pid.Key()) - require.NoError(t, err) + b := msg.Target().Equal(pid.Key()) require.True(t, b) require.Equal(t, 0, len(msg.CloserNodes())) diff --git a/network/message/simmessage/simmessage_test.go b/network/message/simmessage/simmessage_test.go index 951702c..72ecdba 100644 --- a/network/message/simmessage/simmessage_test.go +++ b/network/message/simmessage/simmessage_test.go @@ -17,7 +17,7 @@ func TestSimRequest(t *testing.T) { require.Equal(t, &SimMessage{}, msg.EmptyResponse()) - b, _ := msg.Target().Equal(target.Key()) + b := msg.Target().Equal(target.Key()) require.True(t, b) require.Nil(t, msg.CloserNodes()) } diff --git a/query/simplequery/peerlist.go b/query/simplequery/peerlist.go index 8f7d3f3..31fca83 100644 --- a/query/simplequery/peerlist.go +++ b/query/simplequery/peerlist.go @@ -76,7 +76,7 @@ func (pl *peerList) addToPeerlist(ids []address.NodeID) { currOld := true // current element is from old list closestQueuedReached := false - r, _ := oldHead.distance.Compare(newHead.distance) + r := oldHead.distance.Compare(newHead.distance) if r > 0 { pl.closest = newHead pl.closestQueued = newHead @@ -153,7 +153,7 @@ func (pl *peerList) addToPeerlist(ids []address.NodeID) { if oldHead == nil || newHead == nil { break } - r, _ = oldHead.distance.Compare(newHead.distance) + r = oldHead.distance.Compare(newHead.distance) } // append the remaining list to the end @@ -189,15 +189,13 @@ func sliceToPeerInfos(target key.KadKey, ids []address.NodeID) *peerInfo { // sort the new list sort.Slice(newPeers, func(i, j int) bool { - r, _ := newPeers[i].distance.Compare(newPeers[j].distance) - return r < 0 + return newPeers[i].distance.Compare(newPeers[j].distance) < 0 }) // convert slice to linked list and remove duplicates curr := newPeers[0] for i := 1; i < len(newPeers); i++ { - r, _ := curr.distance.Compare(newPeers[i].distance) - if r != 0 { + if !curr.distance.Equal(newPeers[i].distance) { curr.next = newPeers[i] curr = curr.next } @@ -210,9 +208,8 @@ func addrInfoToPeerInfo(target key.KadKey, id address.NodeID) *peerInfo { if id == nil || id.String() == "" || target.Size() != id.Key().Size() { return nil } - dist, _ := target.Xor(id.Key()) return &peerInfo{ - distance: dist, + distance: target.Xor(id.Key()), status: queued, id: id, } @@ -227,7 +224,7 @@ func (pl *peerList) enqueueUnreachablePeer(pi *peerInfo) { pl.queuedCount++ // if curr is closer to target than closestQueued, update closestQueued - if r, _ := pi.distance.Compare(pl.closestQueued.distance); r < 0 { + if pi.distance.Compare(pl.closestQueued.distance) < 0 { pl.closestQueued = pi } } diff --git a/query/simplequery/query.go b/query/simplequery/query.go index 94af242..25c6524 100644 --- a/query/simplequery/query.go +++ b/query/simplequery/query.go @@ -250,12 +250,12 @@ func (q *SimpleQuery) handleResponse(ctx context.Context, id address.NodeID, // remove all occurneces of q.self from usefulNodeIDs writeIndex := 0 for _, id := range usefulNodeIDs { - if c, err := q.rt.Self().Compare(id.Key()); err == nil && c != 0 { + if q.rt.Self().Size() != id.Key().Size() { + span.AddEvent("wrong KadKey length") + } else if !q.rt.Self().Equal(id.Key()) { // id is valid and isn't self usefulNodeIDs[writeIndex] = id writeIndex++ - } else if err != nil { - span.AddEvent("wrong KadKey length") } else { span.AddEvent("never add self to query peerlist") } diff --git a/query/simplequery/query_test.go b/query/simplequery/query_test.go index 70002af..bcaaa00 100644 --- a/query/simplequery/query_test.go +++ b/query/simplequery/query_test.go @@ -286,7 +286,7 @@ func getHandleResults(t *testing.T, req message.MinKadRequestMessage, var found bool for i, n := range resp.CloserNodes() { ids[i] = n.NodeID() - if match, err := ids[i].Key().Equal(req.Target()); err == nil && match { + if ids[i].Key().Equal(req.Target()) { // the target was found, stop the query found = true } @@ -330,7 +330,7 @@ func TestElementaryQuery(t *testing.T) { currID := 0 // while currID != target.Key() - for c, _ := ids[currID].NodeID().Key().Equal(req.Target()); !c; { + for !ids[currID].NodeID().Key().Equal(req.Target()) { // get closest peer to target from the sollicited peer closest, err := rts[currID].NearestPeers(ctx, req.Target(), 1) require.NoError(t, err) @@ -349,9 +349,6 @@ func TestElementaryQuery(t *testing.T) { closestKey[i] = n.Key() } expectedResponses = append(expectedResponses, closestKey) - - // test if the current ID is the target - c, _ = ids[currID].NodeID().Key().Equal(req.Target()) } // handleResults is called when a peer receives a response from a peer. If diff --git a/routing/errors.go b/routing/errors.go new file mode 100644 index 0000000..6c31ba0 --- /dev/null +++ b/routing/errors.go @@ -0,0 +1,7 @@ +package routing + +import "errors" + +var ( + ErrWrongKeySize = errors.New("wrong key size") +) diff --git a/routing/simplert/table.go b/routing/simplert/table.go index d6a1500..971c189 100644 --- a/routing/simplert/table.go +++ b/routing/simplert/table.go @@ -38,13 +38,6 @@ func New(self key.KadKey, bucketSize int) *SimpleRT { return &rt } -func (rt *SimpleRT) keyError(kadId key.KadKey) error { - if rt.self.Size() != kadId.Size() { - return key.ErrInvalidKey(rt.self.Size()) - } - return nil -} - func (rt *SimpleRT) Self() key.KadKey { return rt.self } @@ -58,10 +51,11 @@ func (rt *SimpleRT) BucketSize() int { } func (rt *SimpleRT) BucketIdForKey(kadId key.KadKey) (int, error) { - bid, err := rt.self.CommonPrefixLength(kadId) - if err != nil { - return 0, err + if rt.self.Size() != kadId.Size() { + return 0, routing.ErrWrongKeySize } + + bid := rt.self.CommonPrefixLength(kadId) nBuckets := len(rt.buckets) if bid >= nBuckets { bid = nBuckets - 1 @@ -84,9 +78,8 @@ func (rt *SimpleRT) addPeer(ctx context.Context, kadId key.KadKey, id address.No )) defer span.End() - if err := rt.keyError(kadId); err != nil { - span.RecordError(err) - return false, err + if rt.self.Size() != kadId.Size() { + return false, routing.ErrWrongKeySize } // no need to check the error here, it's already been checked in keyError @@ -129,7 +122,7 @@ func (rt *SimpleRT) addPeer(ctx context.Context, kadId key.KadKey, id address.No span.AddEvent("splitting last bucket (" + strconv.Itoa(lastBucketId) + ")") for _, p := range rt.buckets[lastBucketId] { - if cpl, _ := p.kadId.CommonPrefixLength(rt.self); cpl == lastBucketId { + if p.kadId.CommonPrefixLength(rt.self) == lastBucketId { farBucket = append(farBucket, p) } else { closeBucket = append(closeBucket, p) @@ -137,8 +130,8 @@ func (rt *SimpleRT) addPeer(ctx context.Context, kadId key.KadKey, id address.No strconv.Itoa(lastBucketId+1) + ")") } } - cpl, _ := rt.self.CommonPrefixLength(kadId) - if len(farBucket) == rt.bucketSize && cpl == lastBucketId { + if len(farBucket) == rt.bucketSize && + rt.self.CommonPrefixLength(kadId) == lastBucketId { // if all peers in the last bucket have the CPL matching this bucket, // don't split it and discard the new peer return false, nil @@ -161,8 +154,7 @@ func (rt *SimpleRT) addPeer(ctx context.Context, kadId key.KadKey, id address.No func (rt *SimpleRT) alreadyInBucket(kadId key.KadKey, bucketId int) bool { for _, p := range rt.buckets[bucketId] { // error already checked in keyError by the caller - c, _ := kadId.Compare(p.kadId) - if c == 0 { + if kadId.Equal(p.kadId) { return true } } @@ -175,15 +167,13 @@ func (rt *SimpleRT) RemoveKey(ctx context.Context, kadId key.KadKey) (bool, erro )) defer span.End() - if err := rt.keyError(kadId); err != nil { - span.RecordError(err) - return false, err + if rt.self.Size() != kadId.Size() { + return false, routing.ErrWrongKeySize } bid, _ := rt.BucketIdForKey(kadId) for i, p := range rt.buckets[bid] { - c, _ := kadId.Compare(p.kadId) - if c == 0 { + if kadId.Equal(p.kadId) { // remove peer from bucket rt.buckets[bid][i] = rt.buckets[bid][len(rt.buckets[bid])-1] rt.buckets[bid] = rt.buckets[bid][:len(rt.buckets[bid])-1] @@ -203,15 +193,13 @@ func (rt *SimpleRT) Find(ctx context.Context, kadId key.KadKey) (address.NodeID, )) defer span.End() - if err := rt.keyError(kadId); err != nil { - span.RecordError(err) - return nil, err + if rt.self.Size() != kadId.Size() { + return nil, routing.ErrWrongKeySize } bid, _ := rt.BucketIdForKey(kadId) for _, p := range rt.buckets[bid] { - c, _ := kadId.Compare(p.kadId) - if c == 0 { + if kadId.Equal(p.kadId) { return p.id, nil } } @@ -227,9 +215,8 @@ func (rt *SimpleRT) NearestPeers(ctx context.Context, kadId key.KadKey, n int) ( )) defer span.End() - if err := rt.keyError(kadId); err != nil { - span.RecordError(err) - return nil, err + if rt.self.Size() != kadId.Size() { + return nil, routing.ErrWrongKeySize } bid, _ := rt.BucketIdForKey(kadId) @@ -243,8 +230,7 @@ func (rt *SimpleRT) NearestPeers(ctx context.Context, kadId key.KadKey, n int) ( peers = make([]peerInfo, 0) for i := 0; i < len(rt.buckets); i++ { for _, p := range rt.buckets[i] { - c, _ := rt.self.Compare(p.kadId) - if c != 0 { + if !rt.self.Equal(p.kadId) { peers = append(peers, p) } } diff --git a/routing/simplert/table_test.go b/routing/simplert/table_test.go index e19eca5..39591df 100644 --- a/routing/simplert/table_test.go +++ b/routing/simplert/table_test.go @@ -10,6 +10,7 @@ import ( "github.com/plprobelab/go-kademlia/network/address" "github.com/plprobelab/go-kademlia/network/address/peerid" si "github.com/plprobelab/go-kademlia/network/address/stringid" + "github.com/plprobelab/go-kademlia/routing" "github.com/stretchr/testify/require" ) @@ -227,22 +228,22 @@ func TestInvalidKeys(t *testing.T) { rt := New(key0, 2) success, err := rt.addPeer(ctx, invalidKey, dummyNodeId) - require.Equal(t, err, key.ErrInvalidKey(32)) + require.Equal(t, routing.ErrWrongKeySize, err) require.False(t, success) bid, err := rt.BucketIdForKey(invalidKey) - require.Equal(t, err, key.ErrInvalidKey(32)) + require.Equal(t, routing.ErrWrongKeySize, err) require.Equal(t, bid, 0) success, err = rt.RemoveKey(ctx, invalidKey) - require.Equal(t, err, key.ErrInvalidKey(32)) + require.Equal(t, routing.ErrWrongKeySize, err) require.False(t, success) nodeID, err := rt.Find(ctx, invalidKey) - require.Equal(t, err, key.ErrInvalidKey(32)) + require.Equal(t, routing.ErrWrongKeySize, err) require.Nil(t, nodeID) nodeIDs, err := rt.NearestPeers(ctx, invalidKey, 2) - require.Equal(t, err, key.ErrInvalidKey(32)) + require.Equal(t, routing.ErrWrongKeySize, err) require.Nil(t, nodeIDs) }