Skip to content

Commit

Permalink
improved simplequery and corrected bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumemichel committed Jun 30, 2023
1 parent 84ef09d commit eda993d
Show file tree
Hide file tree
Showing 10 changed files with 350 additions and 63 deletions.
15 changes: 12 additions & 3 deletions examples/dispatchquery/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ func queryTest(ctx context.Context) {
schedA := ss.NewSimpleScheduler(clk)
endpointA := fakeendpoint.NewFakeEndpoint(selfA, schedA, router)
servA := basicserver.NewBasicServer(rtA, endpointA)
endpointA.AddRequestHandler(protoID, servA.HandleRequest, nil)
err = endpointA.AddRequestHandler(protoID, nil, servA.HandleRequest)
if err != nil {
panic(err)
}

// create peer B
pidB, err := peer.Decode("12BoooooBETA")
Expand All @@ -78,7 +81,10 @@ func queryTest(ctx context.Context) {
schedB := ss.NewSimpleScheduler(clk)
endpointB := fakeendpoint.NewFakeEndpoint(selfB, schedB, router)
servB := basicserver.NewBasicServer(rtB, endpointB)
endpointB.AddRequestHandler(protoID, servB.HandleRequest, nil)
err = endpointB.AddRequestHandler(protoID, nil, servB.HandleRequest)
if err != nil {
panic(err)
}

// create peer C
pidC, err := peer.Decode("12BooooGAMMA")
Expand All @@ -93,7 +99,10 @@ func queryTest(ctx context.Context) {
schedC := ss.NewSimpleScheduler(clk)
endpointC := fakeendpoint.NewFakeEndpoint(selfC, schedC, router)
servC := basicserver.NewBasicServer(rtC, endpointC)
endpointC.AddRequestHandler(protoID, servC.HandleRequest, nil)
err = endpointC.AddRequestHandler(protoID, nil, servC.HandleRequest)
if err != nil {
panic(err)
}

// connect peer A and B
endpointA.MaybeAddToPeerstore(ctx, naddrB, peerstoreTTL)
Expand Down
5 changes: 4 additions & 1 deletion examples/fullsim/findnode.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ func findNode(ctx context.Context) {
// create a server instance for the node
servers[i] = basicserver.NewBasicServer(rts[i], eps[i])
// add the server request handler for protoID to the endpoint
eps[i].AddRequestHandler(protoID, servers[i].HandleRequest, nil)
err := eps[i].AddRequestHandler(protoID, nil, servers[i].HandleRequest)
if err != nil {
panic(err)
}
}

// A connects to B
Expand Down
1 change: 1 addition & 0 deletions network/endpoint/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ var (
ErrUnknownPeer = errors.New("unknown peer")
ErrInvalidPeer = errors.New("invalid peer")
ErrTimeout = errors.New("request timeout")
ErrNilRequestHandler = errors.New("nil request handler")
ErrNilResponseHandler = errors.New("nil response handler")
ErrResponseReceivedAfterTimeout = errors.New("response received after timeout")
)
7 changes: 5 additions & 2 deletions network/endpoint/fakeendpoint/fakeendpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ func (e *FakeEndpoint) HandleMessage(ctx context.Context, id address.NodeID,
return
}

if _, ok := e.serverProtos[protoID]; ok {
if handler, ok := e.serverProtos[protoID]; ok && handler != nil {
// it isn't a response, so treat it as a request
resp, err := e.serverProtos[protoID](ctx, id, msg)
resp, err := handler(ctx, id, msg)
if err != nil {
span.RecordError(err)
return
Expand All @@ -207,6 +207,9 @@ func (e *FakeEndpoint) HandleMessage(ctx context.Context, id address.NodeID,

func (e *FakeEndpoint) AddRequestHandler(protoID address.ProtocolID,
req message.MinKadMessage, reqHandler endpoint.RequestHandlerFn) error {
if reqHandler == nil {
return endpoint.ErrNilRequestHandler
}
e.serverProtos[protoID] = reqHandler
return nil
}
Expand Down
5 changes: 4 additions & 1 deletion network/endpoint/fakeendpoint/fakeendpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ func TestFakeEndpoint(t *testing.T) {
fakeEndpoint0 := NewFakeEndpoint(node0, sched0, router)
rt0 := simplert.NewSimpleRT(node0.Key(), 2)
serv0 := basicserver.NewBasicServer(rt0, fakeEndpoint0)
fakeEndpoint0.AddRequestHandler(protoID, nil, serv0.HandleRequest)
err = fakeEndpoint0.AddRequestHandler(protoID, nil, serv0.HandleRequest)
require.NoError(t, err)
err = fakeEndpoint0.AddRequestHandler(protoID, nil, nil)
require.Equal(t, endpoint.ErrNilRequestHandler, err)
// remove a request handler that doesn't exist
fakeEndpoint0.RemoveRequestHandler("/test/0.0.1")

Expand Down
7 changes: 6 additions & 1 deletion network/endpoint/libp2pendpoint/libp2pendpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ func (e *Libp2pEndpoint) AsyncDialAndReport(ctx context.Context,

if reportFn != nil {
// report dial result where it is needed
reportFn(ctx, success)
e.sched.EnqueueAction(ctx, ba.BasicAction(func(ctx context.Context) {
reportFn(ctx, success)
}))
}
}()
return nil
Expand Down Expand Up @@ -274,6 +276,9 @@ func (e *Libp2pEndpoint) AddRequestHandler(protoID address.ProtocolID,
if !ok {
return ErrRequireProtoKadMessage
}
if reqHandler == nil {
return endpoint.ErrNilRequestHandler
}
// when a new request comes in, we need to queue it
streamHandler := func(s network.Stream) {
e.sched.EnqueueAction(e.ctx, ba.BasicAction(func(ctx context.Context) {
Expand Down
16 changes: 16 additions & 0 deletions network/endpoint/libp2pendpoint/libp2pendpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ func TestLibp2pEndpoint(t *testing.T) {
wg.Done()
})
require.NoError(t, err)
go func() {
// AsyncDialAndReport adds the dial action to the event queue, so we
// need to run the scheduler
for !scheds[0].RunOne(ctx) {
time.Sleep(time.Millisecond)
}
}()
wg.Wait()
// test async dial and report from 0 to 1 (already connected)
err = endpoints[0].AsyncDialAndReport(ctx, ids[1], func(ctx context.Context, success bool) {
Expand All @@ -155,6 +162,13 @@ func TestLibp2pEndpoint(t *testing.T) {
require.False(t, success)
wg.Done()
})
go func() {
// AsyncDialAndReport adds the dial action to the event queue, so we
// need to run the scheduler
for !scheds[0].RunOne(ctx) {
time.Sleep(time.Millisecond)
}
}()
wg.Wait()
// test asyc dial with invalid peerid
err = endpoints[0].AsyncDialAndReport(ctx, invalidID, nil)
Expand All @@ -168,6 +182,8 @@ func TestLibp2pEndpoint(t *testing.T) {
}
err = endpoints[1].AddRequestHandler(protoID, &ipfsv1.Message{}, requestHandler)
require.NoError(t, err)
err = endpoints[1].AddRequestHandler(protoID, &ipfsv1.Message{}, nil)
require.Equal(t, endpoint.ErrNilRequestHandler, err)

// send request from 0 to 1
wg.Add(1)
Expand Down
14 changes: 14 additions & 0 deletions query/simplequery/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type Config struct {
// for a request. It is used to determine whether the query should be
// stopped and whether the peerlist should be updated.
HandleResultsFunc HandleResultFn
// NotifyFailureFn is a function that is called when the query fails. It is
// used to notify the user that the query failed.
NotifyFailureFunc NotifyFailureFn

// RoutingTable is the routing table used to find closer peers. It is
// updated with newly discovered peers.
Expand Down Expand Up @@ -79,6 +82,7 @@ var DefaultConfig = func(cfg *Config) error {
resp message.MinKadResponseMessage) (bool, []address.NodeID) {
return false, resp.CloserNodes()
}
cfg.NotifyFailureFunc = func(context.Context) {}

return nil
}
Expand Down Expand Up @@ -134,6 +138,16 @@ func WithHandleResultsFunc(fn HandleResultFn) Option {
}
}

func WithNotifyFailureFunc(fn NotifyFailureFn) Option {
return func(cfg *Config) error {
if fn == nil {
return fmt.Errorf("NotifyFailureFunc cannot be nil")
}
cfg.NotifyFailureFunc = fn
return nil
}
}

func WithRoutingTable(rt routingtable.RoutingTable) Option {
return func(cfg *Config) error {
if rt == nil {
Expand Down
114 changes: 65 additions & 49 deletions query/simplequery/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ import (
"go.opentelemetry.io/otel/trace"
)

// note that the returned []address.NodeID are expected to be of the same type
// as the type returned by the routing table's NearestPeers method. the
// address.NodeID returned by resp.CloserNodes() is not necessarily of the same
// type as the one returned by the routing table's NearestPeers method. so
// address.NodeID s may need to be converted in this function.
type HandleResultFn func(context.Context, address.NodeID,
message.MinKadResponseMessage) (bool, []address.NodeID)

type NotifyFailureFn func(context.Context)

type SimpleQuery struct {
ctx context.Context
done bool
Expand All @@ -37,8 +44,10 @@ type SimpleQuery struct {
inflightRequests int // requests that are either in flight or scheduled
peerlist *peerList

// success condition
// response handling
handleResultFn HandleResultFn
// failure callback
notifyFailureFn NotifyFailureFn
}

// NewSimpleQuery creates a new SimpleQuery. It initializes the query by adding
Expand All @@ -49,7 +58,7 @@ type SimpleQuery struct {
// parameters. The query keeps track of the closest known peers to the target
// key, and the peers that have been queried so far.
func NewSimpleQuery(ctx context.Context, req message.MinKadRequestMessage,
opts ...Option) *SimpleQuery {
opts ...Option) (*SimpleQuery, error) {

ctx, span := util.StartSpan(ctx, "SimpleQuery.NewSimpleQuery",
trace.WithAttributes(attribute.String("Target", req.Target().Hex())))
Expand All @@ -59,46 +68,68 @@ func NewSimpleQuery(ctx context.Context, req message.MinKadRequestMessage,
var cfg Config
if err := cfg.Apply(append([]Option{DefaultConfig}, opts...)...); err != nil {
span.RecordError(err)
return nil
return nil, err
}

closestPeers, err := cfg.RoutingTable.NearestPeers(ctx, req.Target(),
cfg.NumberUsefulCloserPeers)
if err != nil {
span.RecordError(err)
return nil
return nil, err
}

pl := newPeerList(req.Target())
pl.addToPeerlist(closestPeers)

q := &SimpleQuery{
ctx: ctx,
req: req,
protoID: cfg.ProtocolID,
concurrency: cfg.Concurrency,
timeout: cfg.RequestTimeout,
peerstoreTTL: cfg.PeerstoreTTL,
rt: cfg.RoutingTable,
msgEndpoint: cfg.Endpoint,
sched: cfg.Scheduler,
handleResultFn: cfg.HandleResultsFunc,
peerlist: pl,
}

// we don't want more pending requests than the number of peers we can query
requestsEvents := q.concurrency
if len(closestPeers) < q.concurrency {
requestsEvents = len(closestPeers)
}
for i := 0; i < requestsEvents; i++ {
// add concurrency requests to the event queue
q.sched.EnqueueAction(ctx, ba.BasicAction(q.newRequest))
ctx: ctx,
req: req,
protoID: cfg.ProtocolID,
concurrency: cfg.Concurrency,
timeout: cfg.RequestTimeout,
peerstoreTTL: cfg.PeerstoreTTL,
rt: cfg.RoutingTable,
msgEndpoint: cfg.Endpoint,
sched: cfg.Scheduler,
handleResultFn: cfg.HandleResultsFunc,
notifyFailureFn: cfg.NotifyFailureFunc,
peerlist: pl,
}
span.AddEvent("Enqueued " + strconv.Itoa(requestsEvents) + " SimpleQuery.newRequest")
q.inflightRequests = requestsEvents

return q
q.enqueueNewRequests(ctx)

return q, nil
}

func (q *SimpleQuery) enqueueNewRequests(ctx context.Context) {
ctx, span := util.StartSpan(ctx, "SimpleQuery.enqueueNewRequests")
defer span.End()

// we always want to have the maximal number of requests in flight
newRequestsToSend := q.concurrency - q.inflightRequests
if q.peerlist.queuedCount < newRequestsToSend {
newRequestsToSend = q.peerlist.queuedCount
}

if newRequestsToSend == 0 && q.inflightRequests == 0 {
// no more requests to send and no requests in flight, query has failed
// and is done
q.done = true
span.AddEvent("all peers queried")
q.notifyFailureFn(ctx)
}

span.AddEvent("newRequestsToSend: " + strconv.Itoa(newRequestsToSend) +
" q.inflightRequests: " + strconv.Itoa(q.inflightRequests))

for i := 0; i < newRequestsToSend; i++ {
// add new pending request(s) for this query to eventqueue
q.sched.EnqueueAction(ctx, ba.BasicAction(q.newRequest))

}
q.inflightRequests += newRequestsToSend
span.AddEvent("Enqueued " + strconv.Itoa(newRequestsToSend) +
" SimpleQuery.newRequest")
}

func (q *SimpleQuery) checkIfDone() error {
Expand Down Expand Up @@ -131,10 +162,12 @@ func (q *SimpleQuery) newRequest(ctx context.Context) {
}

id := q.peerlist.popClosestQueued()
if id == nil || id.String() == "" {
// TODO: handle this case
if id == nil {
// TODO: should never happen
q.done = true
span.AddEvent("all peers queried")
q.inflightRequests--
q.notifyFailureFn(ctx)
return
}
span.AddEvent("peer selected: " + id.String())
Expand Down Expand Up @@ -211,23 +244,7 @@ func (q *SimpleQuery) handleResponse(ctx context.Context, id address.NodeID, res

q.peerlist.addToPeerlist(usefulNodeID)

// we always want to have the maximal number of requests in flight
newRequestsToSend := q.concurrency - q.inflightRequests
if q.peerlist.queuedCount < newRequestsToSend {
newRequestsToSend = q.peerlist.queuedCount
}

span.AddEvent("newRequestsToSend: " + strconv.Itoa(newRequestsToSend) + " q.inflightRequests: " + strconv.Itoa(q.inflightRequests))

for i := 0; i < newRequestsToSend; i++ {
// add new pending request(s) for this query to eventqueue
q.sched.EnqueueAction(ctx, ba.BasicAction(q.newRequest))

}
q.inflightRequests += newRequestsToSend
span.AddEvent("Enqueued " + strconv.Itoa(newRequestsToSend) +
" SimpleQuery.newRequest")

q.enqueueNewRequests(ctx)
}

func (q *SimpleQuery) requestError(ctx context.Context, id address.NodeID, err error) {
Expand All @@ -250,6 +267,5 @@ func (q *SimpleQuery) requestError(ctx context.Context, id address.NodeID, err e

q.peerlist.updatePeerStatusInPeerlist(id, unreachable)

// add pending request for this query to eventqueue
q.sched.EnqueueAction(ctx, ba.BasicAction(q.newRequest))
q.enqueueNewRequests(ctx)
}
Loading

0 comments on commit eda993d

Please sign in to comment.