diff --git a/coord/coordinator.go b/coord/coordinator.go index a68a2da..a7a3955 100644 --- a/coord/coordinator.go +++ b/coord/coordinator.go @@ -25,41 +25,30 @@ type StateMachine[S any, E any] interface { Advance(context.Context, E) S } -// eventQueue is a bounded, typed queue for events -// NOTE: this type is incompatible with the semantics of event.Queue which blocks on Dequeue -type eventQueue[E any] struct { - events chan E +type StateMachineAction[E any] struct { + Event E + AdvanceFn func(context.Context, E) } -func newEventQueue[E any](capacity int) *eventQueue[E] { - return &eventQueue[E]{ - events: make(chan E, capacity), - } -} - -// Enqueue adds an event to the queue. It blocks if the queue is at capacity. -func (q *eventQueue[E]) Enqueue(ctx context.Context, e E) { - q.events <- e +func (a *StateMachineAction[E]) Run(ctx context.Context) { + a.AdvanceFn(ctx, a.Event) } -// Dequeue reads an event from the queue. It returns the event and a true value -// if an event was read or the zero value of the event type and false if no event -// was read. This method is non-blocking. -func (q *eventQueue[E]) Dequeue(ctx context.Context) (E, bool) { - select { - case e := <-q.events: - return e, true - default: - var v E - return v, false - } +func (a *StateMachineAction[E]) String() string { + return fmt.Sprintf("StateMachineAction[%T]", a.Event) } // FindNodeRequestFunc is a function that creates a request to find the supplied node id // TODO: consider this being a first class method of the Endpoint type FindNodeRequestFunc[K kad.Key[K], A kad.Address[A]] func(kad.NodeID[K]) (address.ProtocolID, kad.Request[K, A]) +// ActionQueue accepts actions and queues them for later execution +type ActionQueue interface { + EnqueueAction(context.Context, event.Action) +} + // A Coordinator coordinates the state machines that comprise a Kademlia DHT +// It is only one possible configuration of the DHT components, others are possible. // Currently this is only queries and bootstrapping but will expand to include other state machines such as // routing table refresh, and reproviding. type Coordinator[K kad.Key[K], A kad.Address[A]] struct { @@ -72,21 +61,12 @@ type Coordinator[K kad.Key[K], A kad.Address[A]] struct { // pool is the query pool state machine, responsible for running user-submitted queries pool StateMachine[query.PoolState, query.PoolEvent] - // poolEvents is a fifo queue of events that are to be processed by the pool state machine - poolEvents *eventQueue[query.PoolEvent] - // bootstrap is the bootstrap state machine, responsible for bootstrapping the routing table bootstrap StateMachine[routing.BootstrapState, routing.BootstrapEvent] - // bootstrapEvents is a fifo queue of events that are to be processed by the bootstrap state machine - bootstrapEvents *eventQueue[routing.BootstrapEvent] - // include is the include state machine, responsible for including candidate nodes into the routing table include StateMachine[routing.IncludeState, routing.IncludeEvent] - // includeEvents is a fifo queue of events that are to be processed by the include state machine - includeEvents *eventQueue[routing.IncludeEvent] - // rt is the routing table used to look up nodes by distance rt kad.RoutingTable[K, kad.NodeID[K]] @@ -97,11 +77,7 @@ type Coordinator[K kad.Key[K], A kad.Address[A]] struct { // TODO: thiis should be a function of the endpoint findNodeFn FindNodeRequestFunc[K, A] - // queue not used - queue event.EventQueue - - // planner not used - planner event.AwareActionPlanner + sched ActionQueue outboundEvents chan KademliaEvent } @@ -169,7 +145,7 @@ func DefaultConfig() *Config { } } -func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], fn FindNodeRequestFunc[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], cfg *Config) (*Coordinator[K, A], error) { +func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpoint.Endpoint[K, A], fn FindNodeRequestFunc[K, A], rt kad.RoutingTable[K, kad.NodeID[K]], sched ActionQueue, cfg *Config) (*Coordinator[K, A], error) { if cfg == nil { cfg = DefaultConfig() } else if err := cfg.Validate(); err != nil { @@ -213,20 +189,16 @@ func NewCoordinator[K kad.Key[K], A kad.Address[A]](self kad.NodeID[K], ep endpo return nil, fmt.Errorf("include: %w", err) } return &Coordinator[K, A]{ - self: self, - cfg: *cfg, - ep: ep, - findNodeFn: fn, - rt: rt, - pool: qp, - poolEvents: newEventQueue[query.PoolEvent](20), // 20 is abitrary, move to config - bootstrap: bootstrap, - bootstrapEvents: newEventQueue[routing.BootstrapEvent](20), // 20 is abitrary, move to config - include: include, - includeEvents: newEventQueue[routing.IncludeEvent](20), // 20 is abitrary, move to config - outboundEvents: make(chan KademliaEvent, 20), - queue: event.NewChanQueue(DefaultChanqueueCapacity), - planner: event.NewSimplePlanner(cfg.Clock), + self: self, + cfg: *cfg, + ep: ep, + findNodeFn: fn, + rt: rt, + pool: qp, + bootstrap: bootstrap, + include: include, + outboundEvents: make(chan KademliaEvent, 20), + sched: sched, }, nil } @@ -234,126 +206,91 @@ func (c *Coordinator[K, A]) Events() <-chan KademliaEvent { return c.outboundEvents } -func (c *Coordinator[K, A]) RunOne(ctx context.Context) bool { - ctx, span := util.StartSpan(ctx, "Coordinator.RunOne") - defer span.End() - - // Process state machines in priority order - - // Give the bootstrap state machine priority - if c.advanceBootstrap(ctx) { - return true - } - - // Attempt to advance the include state machine so candidate nodes - // are added to the routing table - if c.advanceInclude(ctx) { - return true - } - - // Attempt to advance an outbound query - if c.advancePool(ctx) { - return true - } - - return false +func (c *Coordinator[K, A]) scheduleBootstrapEvent(ctx context.Context, ev routing.BootstrapEvent) { + // TODO: enqueue with higher priority when we have priority queues + c.sched.EnqueueAction(ctx, &StateMachineAction[routing.BootstrapEvent]{ + Event: ev, + AdvanceFn: c.advanceBootstrap, + }) } -func (c *Coordinator[K, A]) advanceBootstrap(ctx context.Context) bool { - bev, ok := c.bootstrapEvents.Dequeue(ctx) - if !ok { - bev = &routing.EventBootstrapPoll{} - } - - bstate := c.bootstrap.Advance(ctx, bev) +func (c *Coordinator[K, A]) advanceBootstrap(ctx context.Context, ev routing.BootstrapEvent) { + bstate := c.bootstrap.Advance(ctx, ev) switch st := bstate.(type) { case *routing.StateBootstrapMessage[K, A]: c.sendBootstrapFindNode(ctx, st.NodeID, st.QueryID, st.Stats) - return true case *routing.StateBootstrapWaiting: - // bootstrap waiting for a message response, proceed with other state machines - return false - + // bootstrap waiting for a message response, nothing to do case *routing.StateBootstrapFinished: c.outboundEvents <- &KademliaBootstrapFinishedEvent{ Stats: st.Stats, } - return true - case *routing.StateBootstrapIdle: - // bootstrap not running, can proceed to other state machines - return false + // bootstrap not running, nothing to do default: panic(fmt.Sprintf("unexpected bootstrap state: %T", st)) } } -func (c *Coordinator[K, A]) advanceInclude(ctx context.Context) bool { +func (c *Coordinator[K, A]) scheduleIncludeEvent(ctx context.Context, ev routing.IncludeEvent) { + c.sched.EnqueueAction(ctx, &StateMachineAction[routing.IncludeEvent]{ + Event: ev, + AdvanceFn: c.advanceInclude, + }) +} + +func (c *Coordinator[K, A]) advanceInclude(ctx context.Context, ev routing.IncludeEvent) { // Attempt to advance the include state machine so candidate nodes // are added to the routing table - iev, ok := c.includeEvents.Dequeue(ctx) - if !ok { - iev = &routing.EventIncludePoll{} - } - istate := c.include.Advance(ctx, iev) + istate := c.include.Advance(ctx, ev) switch st := istate.(type) { case *routing.StateIncludeFindNodeMessage[K, A]: // include wants to send a find node message to a node c.sendIncludeFindNode(ctx, st.NodeInfo) - return true case *routing.StateIncludeRoutingUpdated[K, A]: // a node has been included in the routing table c.outboundEvents <- &KademliaRoutingUpdatedEvent[K, A]{ NodeInfo: st.NodeInfo, } - return true case *routing.StateIncludeWaitingAtCapacity: // nothing to do except wait for message response or timeout - return false case *routing.StateIncludeWaitingWithCapacity: // nothing to do except wait for message response or timeout - return false case *routing.StateIncludeWaitingFull: // nothing to do except wait for message response or timeout - return false case *routing.StateIncludeIdle: // nothing to do except wait for message response or timeout - return false default: panic(fmt.Sprintf("unexpected include state: %T", st)) } } -func (c *Coordinator[K, A]) advancePool(ctx context.Context) bool { - pev, ok := c.poolEvents.Dequeue(ctx) - if !ok { - pev = &query.EventPoolPoll{} - } +func (c *Coordinator[K, A]) schedulePoolEvent(ctx context.Context, ev query.PoolEvent) { + c.sched.EnqueueAction(ctx, &StateMachineAction[query.PoolEvent]{ + Event: ev, + AdvanceFn: c.advancePool, + }) +} - state := c.pool.Advance(ctx, pev) +func (c *Coordinator[K, A]) advancePool(ctx context.Context, ev query.PoolEvent) { + state := c.pool.Advance(ctx, ev) switch st := state.(type) { case *query.StatePoolQueryMessage[K, A]: c.sendQueryMessage(ctx, st.ProtocolID, st.NodeID, st.Message, st.QueryID, st.Stats) - return true case *query.StatePoolWaitingAtCapacity: // nothing to do except wait for message response or timeout - return false case *query.StatePoolWaitingWithCapacity: // nothing to do except wait for message response or timeout - return false case *query.StatePoolQueryFinished: c.outboundEvents <- &KademliaOutboundQueryFinishedEvent{ QueryID: st.QueryID, Stats: st.Stats, } - return true case *query.StatePoolQueryTimeout: // TODO - return false case *query.StatePoolIdle: // nothing to do - return false default: panic(fmt.Sprintf("unexpected pool state: %T", st)) } @@ -370,12 +307,11 @@ func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID addres return } - qev := &query.EventPoolMessageFailure[K]{ + c.advancePool(ctx, &query.EventPoolMessageFailure[K]{ NodeID: to, QueryID: queryID, Error: err, - } - c.poolEvents.Enqueue(ctx, qev) + }) } onMessageResponse := func(ctx context.Context, resp kad.Response[K, A], err error) { @@ -400,12 +336,11 @@ func (c *Coordinator[K, A]) sendQueryMessage(ctx context.Context, protoID addres Stats: stats, } - qev := &query.EventPoolMessageResponse[K, A]{ + c.advancePool(ctx, &query.EventPoolMessageResponse[K, A]{ NodeID: to, QueryID: queryID, Response: resp, - } - c.poolEvents.Enqueue(ctx, qev) + }) } err := c.ep.SendRequestHandleResponse(ctx, protoID, to, msg, msg.EmptyResponse(), 0, onMessageResponse) @@ -425,11 +360,10 @@ func (c *Coordinator[K, A]) sendBootstrapFindNode(ctx context.Context, to kad.No return } - bev := &routing.EventBootstrapMessageFailure[K]{ + c.advanceBootstrap(ctx, &routing.EventBootstrapMessageFailure[K]{ NodeID: to, Error: err, - } - c.bootstrapEvents.Enqueue(ctx, bev) + }) } onMessageResponse := func(ctx context.Context, resp kad.Response[K, A], err error) { @@ -454,11 +388,10 @@ func (c *Coordinator[K, A]) sendBootstrapFindNode(ctx context.Context, to kad.No Stats: stats, } - bev := &routing.EventBootstrapMessageResponse[K, A]{ + c.advanceBootstrap(ctx, &routing.EventBootstrapMessageResponse[K, A]{ NodeID: to, Response: resp, - } - c.bootstrapEvents.Enqueue(ctx, bev) + }) } protoID, msg := c.findNodeFn(c.self) @@ -479,11 +412,10 @@ func (c *Coordinator[K, A]) sendIncludeFindNode(ctx context.Context, to kad.Node return } - iev := &routing.EventIncludeMessageFailure[K, A]{ + c.advanceInclude(ctx, &routing.EventIncludeMessageFailure[K, A]{ NodeInfo: to, Error: err, - } - c.includeEvents.Enqueue(ctx, iev) + }) } onMessageResponse := func(ctx context.Context, resp kad.Response[K, A], err error) { @@ -492,19 +424,10 @@ func (c *Coordinator[K, A]) sendIncludeFindNode(ctx context.Context, to kad.Node return } - iev := &routing.EventIncludeMessageResponse[K, A]{ + c.advanceInclude(ctx, &routing.EventIncludeMessageResponse[K, A]{ NodeInfo: to, Response: resp, - } - c.includeEvents.Enqueue(ctx, iev) - - if resp != nil { - candidates := resp.CloserNodes() - if len(candidates) > 0 { - // ignore error here - c.AddNodes(ctx, candidates) - } - } + }) } // this might be new node addressing info @@ -522,24 +445,25 @@ func (c *Coordinator[K, A]) StartQuery(ctx context.Context, queryID query.QueryI defer span.End() knownClosestPeers := c.rt.NearestNodes(msg.Target(), 20) - qev := &query.EventPoolAddQuery[K, A]{ + c.schedulePoolEvent(ctx, &query.EventPoolAddQuery[K, A]{ QueryID: queryID, Target: msg.Target(), ProtocolID: protocolID, Message: msg, KnownClosestNodes: knownClosestPeers, - } - c.poolEvents.Enqueue(ctx, qev) + }) + return nil } func (c *Coordinator[K, A]) StopQuery(ctx context.Context, queryID query.QueryID) error { ctx, span := util.StartSpan(ctx, "Coordinator.StopQuery") defer span.End() - qev := &query.EventPoolStopQuery{ + + c.schedulePoolEvent(ctx, &query.EventPoolStopQuery{ QueryID: queryID, - } - c.poolEvents.Enqueue(ctx, qev) + }) + return nil } @@ -553,11 +477,10 @@ func (c *Coordinator[K, A]) AddNodes(ctx context.Context, infos []kad.NodeInfo[K // skip self continue } - // inject a new node into the coordinator's includeEvents queue - iev := &routing.EventIncludeAddCandidate[K, A]{ + + c.scheduleIncludeEvent(ctx, &routing.EventIncludeAddCandidate[K, A]{ NodeInfo: info, - } - c.includeEvents.Enqueue(ctx, iev) + }) } return nil @@ -568,13 +491,11 @@ func (c *Coordinator[K, A]) AddNodes(ctx context.Context, infos []kad.NodeInfo[K func (c *Coordinator[K, A]) Bootstrap(ctx context.Context, seeds []kad.NodeID[K]) error { protoID, msg := c.findNodeFn(c.self) - bev := &routing.EventBootstrapStart[K, A]{ + c.scheduleBootstrapEvent(ctx, &routing.EventBootstrapStart[K, A]{ ProtocolID: protoID, Message: msg, KnownClosestNodes: seeds, - } - - c.bootstrapEvents.Enqueue(ctx, bev) + }) return nil } @@ -623,31 +544,3 @@ func (*KademliaUnroutablePeerEvent[K]) kademliaEvent() {} func (*KademliaRoutablePeerEvent[K]) kademliaEvent() {} func (*KademliaOutboundQueryFinishedEvent) kademliaEvent() {} func (*KademliaBootstrapFinishedEvent) kademliaEvent() {} - -// var _ scheduler.Scheduler = (*Coordinator[key.Key8])(nil) -func (c *Coordinator[K, A]) Clock() clock.Clock { - return c.cfg.Clock -} - -func (c *Coordinator[K, A]) EnqueueAction(ctx context.Context, a event.Action) { - c.queue.Enqueue(ctx, a) -} - -func (c *Coordinator[K, A]) ScheduleAction(ctx context.Context, t time.Time, a event.Action) event.PlannedAction { - if c.cfg.Clock.Now().After(t) { - c.EnqueueAction(ctx, a) - return nil - } - return c.planner.ScheduleAction(ctx, t, a) -} - -func (c *Coordinator[K, A]) RemovePlannedAction(ctx context.Context, a event.PlannedAction) bool { - return c.planner.RemoveAction(ctx, a) -} - -// NextActionTime returns the time of the next action to run, or the current -// time if there are actions to be run in the queue, or util.MaxTime if there -// are no scheduled to run. -func (c *Coordinator[K, A]) NextActionTime(ctx context.Context) time.Time { - return c.cfg.Clock.Now() -} diff --git a/coord/coordinator_test.go b/coord/coordinator_test.go index 79dbc39..cf9f68a 100644 --- a/coord/coordinator_test.go +++ b/coord/coordinator_test.go @@ -22,7 +22,9 @@ import ( "github.com/plprobelab/go-kademlia/sim" ) -func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], *sim.LiteSimulator) { +var _ event.Action = (*StateMachineAction[query.PoolEvent])(nil) + +func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8, kadtest.StrAddr], []*sim.Endpoint[key.Key8, kadtest.StrAddr], []*simplert.SimpleRT[key.Key8, kad.NodeID[key.Key8]], []event.AwareScheduler, *sim.LiteSimulator) { // create node identifiers nodeCount := 4 ids := make([]*kadtest.ID[key.Key8], nodeCount) @@ -83,7 +85,7 @@ func setupSimulation(t *testing.T, ctx context.Context) ([]kad.NodeInfo[key.Key8 siml := sim.NewLiteSimulator(clk) sim.AddSchedulers(siml, schedulers...) - return addrs, eps, rts, siml + return addrs, eps, rts, schedulers, siml } // connectNodes adds nodes to each other's peerstores and routing tables @@ -176,7 +178,7 @@ func TestExhaustiveQuery(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() - nodes, eps, rts, siml := setupSimulation(t, ctx) + nodes, eps, rts, scheds, siml := setupSimulation(t, ctx) clk := siml.Clock() @@ -184,26 +186,14 @@ func TestExhaustiveQuery(t *testing.T) { ccfg.Clock = clk ccfg.PeerstoreTTL = peerstoreTTL - go func(ctx context.Context) { - for { - select { - case <-time.After(10 * time.Millisecond): - siml.Run(ctx) - case <-ctx.Done(): - return - } - } - }(ctx) - // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } - siml.Add(c) events := c.Events() queryID := query.QueryID("query1") @@ -213,6 +203,9 @@ func TestExhaustiveQuery(t *testing.T) { t.Fatalf("failed to start query: %v", err) } + // progress the schedulers + siml.Run(ctx) + // the query run by the coordinator should have received a response from nodes[1] ev, err := expectEventType(t, ctx, events, &KademliaOutboundQueryProgressedEvent[key.Key8, kadtest.StrAddr]{}) require.NoError(t, err) @@ -253,7 +246,7 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() - nodes, eps, rts, siml := setupSimulation(t, ctx) + nodes, eps, rts, scheds, siml := setupSimulation(t, ctx) clk := siml.Clock() @@ -261,26 +254,14 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { ccfg.Clock = clk ccfg.PeerstoreTTL = peerstoreTTL - go func(ctx context.Context) { - for { - select { - case <-time.After(10 * time.Millisecond): - siml.Run(ctx) - case <-ctx.Done(): - return - } - } - }(ctx) - // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } - siml.Add(c) events := c.Events() queryID := query.QueryID("query1") @@ -290,6 +271,9 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { t.Fatalf("failed to start query: %v", err) } + // progress the schedulers + siml.Run(ctx) + // the query run by the coordinator should have received a response from nodes[1] with closer nodes // nodes[0] and nodes[2] which should trigger a routing table update ev, err := expectEventType(t, ctx, events, &KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]{}) @@ -307,17 +291,13 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { tev = ev.(*KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]) require.Equal(t, nodes[3].ID(), tev.NodeInfo.ID()) - - // the query run by the coordinator should have completed - _, err = expectEventType(t, ctx, events, &KademliaOutboundQueryFinishedEvent{}) - require.NoError(t, err) } func TestBootstrap(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() - nodes, eps, rts, siml := setupSimulation(t, ctx) + nodes, eps, rts, scheds, siml := setupSimulation(t, ctx) clk := siml.Clock() @@ -325,23 +305,11 @@ func TestBootstrap(t *testing.T) { ccfg.Clock = clk ccfg.PeerstoreTTL = peerstoreTTL - go func(ctx context.Context) { - for { - select { - case <-time.After(10 * time.Millisecond): - siml.Run(ctx) - case <-ctx.Done(): - return - } - } - }(ctx) - self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } - siml.Add(c) events := c.Events() queryID := query.QueryID("bootstrap") @@ -352,6 +320,9 @@ func TestBootstrap(t *testing.T) { err = c.Bootstrap(ctx, seeds) require.NoError(t, err) + // progress the schedulers + siml.Run(ctx) + // the query run by the coordinator should have received a response from nodes[1] ev, err := expectEventType(t, ctx, events, &KademliaOutboundQueryProgressedEvent[key.Key8, kadtest.StrAddr]{}) require.NoError(t, err) @@ -391,7 +362,7 @@ func TestIncludeNode(t *testing.T) { ctx, cancel := kadtest.Ctx(t) defer cancel() - nodes, eps, rts, siml := setupSimulation(t, ctx) + nodes, eps, rts, scheds, siml := setupSimulation(t, ctx) clk := siml.Clock() @@ -399,25 +370,6 @@ func TestIncludeNode(t *testing.T) { ccfg.Clock = clk ccfg.PeerstoreTTL = peerstoreTTL - go func(ctx context.Context) { - for { - select { - case <-time.After(10 * time.Millisecond): - siml.Run(ctx) - case <-ctx.Done(): - return - } - } - }(ctx) - - self := nodes[0].ID() - c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], ccfg) - if err != nil { - log.Fatalf("unexpected error creating coordinator: %v", err) - } - siml.Add(c) - events := c.Events() - candidate := nodes[3] // not in nodes[0] routing table // the routing table should not contain the node yet @@ -425,10 +377,20 @@ func TestIncludeNode(t *testing.T) { require.NoError(t, err) require.Nil(t, foundNode) + self := nodes[0].ID() + c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg) + if err != nil { + log.Fatalf("unexpected error creating coordinator: %v", err) + } + events := c.Events() + // inject a new node into the coordinator's includeEvents queue err = c.AddNodes(ctx, []kad.NodeInfo[key.Key8, kadtest.StrAddr]{candidate}) require.NoError(t, err) + // progress the schedulers + siml.Run(ctx) + // the include state machine runs in the background and eventually should add the node to routing table ev, err := expectEventType(t, ctx, events, &KademliaRoutingUpdatedEvent[key.Key8, kadtest.StrAddr]{}) require.NoError(t, err) diff --git a/examples/statemachine/main.go b/examples/statemachine/main.go index bba44d8..17085e8 100644 --- a/examples/statemachine/main.go +++ b/examples/statemachine/main.go @@ -33,7 +33,7 @@ func main() { ctx, cancel := context.WithCancel(ctx) defer cancel() - nodes, eps, rts, siml := setupSimulation(ctx) + nodes, eps, rts, scheds, siml := setupSimulation(ctx) tp, err := tracerProvider("http://localhost:14268/api/traces") if err != nil { @@ -61,11 +61,10 @@ func main() { ccfg.Clock = siml.Clock() ccfg.PeerstoreTTL = peerstoreTTL - kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], findNodeFn, rts[0], ccfg) + kad, err := coord.NewCoordinator[key.Key256, net.IP](nodes[0].ID(), eps[0], findNodeFn, rts[0], scheds[0], ccfg) if err != nil { log.Fatal(err) } - siml.Add(kad) ih := NewIpfsDht(kad) ih.Start(ctx) @@ -97,7 +96,7 @@ func main() { const peerstoreTTL = 10 * time.Minute -func setupSimulation(ctx context.Context) ([]kad.NodeInfo[key.Key256, net.IP], []*sim.Endpoint[key.Key256, net.IP], []kad.RoutingTable[key.Key256, kad.NodeID[key.Key256]], *sim.LiteSimulator) { +func setupSimulation(ctx context.Context) ([]kad.NodeInfo[key.Key256, net.IP], []*sim.Endpoint[key.Key256, net.IP], []kad.RoutingTable[key.Key256, kad.NodeID[key.Key256]], []event.AwareScheduler, *sim.LiteSimulator) { // create node identifiers nodeCount := 4 ids := make([]*kadtest.ID[key.Key256], nodeCount) @@ -183,7 +182,7 @@ func setupSimulation(ctx context.Context) ([]kad.NodeInfo[key.Key256, net.IP], [ siml := sim.NewLiteSimulator(clk) sim.AddSchedulers(siml, schedulers...) - return addrs, eps, rts, siml + return addrs, eps, rts, schedulers, siml } // connectNodes adds nodes to each other's peerstores and routing tables diff --git a/key/key.go b/key/key.go index 9a8808f..d75a503 100644 --- a/key/key.go +++ b/key/key.go @@ -3,12 +3,16 @@ package key import ( "bytes" "encoding/hex" + "errors" "fmt" "math" "github.com/plprobelab/go-kademlia/kad" ) +// ErrInvalidDataLength is the error returned when attempting to construct a key from binary data of the wrong length. +var ErrInvalidDataLength = errors.New("invalid data length") + const bitPanicMsg = "bit index out of range" // Key256 is a 256-bit Kademlia key. @@ -21,7 +25,7 @@ var _ kad.Key[Key256] = Key256{} // NewKey256 returns a 256-bit Kademlia key whose bits are set from the supplied bytes. func NewKey256(data []byte) Key256 { if len(data) != 32 { - panic("invalid data length for key") + panic(ErrInvalidDataLength) } var b [32]byte copy(b[:], data) @@ -86,7 +90,15 @@ func (k Key256) CommonPrefixLength(o Key256) int { // Compare compares the numeric value of the key with another key of the same type. func (k Key256) Compare(o Key256) int { - return bytes.Compare(k.b[:], o.b[:]) + if k.b != nil && o.b != nil { + return bytes.Compare(k.b[:], o.b[:]) + } + + var zero [32]byte + if k.b == nil { + return bytes.Compare(zero[:], o.b[:]) + } + return bytes.Compare(zero[:], k.b[:]) } // HexString returns a string containing the hexadecimal representation of the key. @@ -97,6 +109,16 @@ func (k Key256) HexString() string { return hex.EncodeToString(k.b[:]) } +// MarshalBinary marshals the key into a byte slice. +// The bytes may be passed to NewKey256 to construct a new key with the same value. +func (k Key256) MarshalBinary() ([]byte, error) { + buf := make([]byte, 32) + if k.b != nil { + copy(buf, (*k.b)[:]) + } + return buf, nil +} + // Key32 is a 32-bit Kademlia key, suitable for testing and simulation of small networks. type Key32 uint32 diff --git a/key/key_test.go b/key/key_test.go index ad2d6ed..86b6baa 100644 --- a/key/key_test.go +++ b/key/key_test.go @@ -34,6 +34,8 @@ func TestKey256(t *testing.T) { } tester.RunTests(t) + + testBinaryMarshaler(t, tester.KeyX, NewKey256) } func TestKey32(t *testing.T) { @@ -79,6 +81,7 @@ func TestBitStrKey7(t *testing.T) { tester.RunTests(t) } +// KeyTester tests a kad.Key's implementation type KeyTester[K kad.Key[K]] struct { // Key 0 is zero Key0 K @@ -127,6 +130,12 @@ func (kt *KeyTester[K]) TestXor(t *testing.T) { xored = kt.Key1.Xor(kt.Key2) require.Equal(t, kt.Key1xor2, xored) + + var empty K // zero value of key + xored = kt.Key0.Xor(empty) + require.Equal(t, kt.Key0, xored) + xored = empty.Xor(kt.Key0) + require.Equal(t, kt.Key0, xored) } func (kt *KeyTester[K]) TestCommonPrefixLength(t *testing.T) { @@ -141,6 +150,12 @@ func (kt *KeyTester[K]) TestCommonPrefixLength(t *testing.T) { cpl = kt.Key0.CommonPrefixLength(kt.Key010) require.Equal(t, 1, cpl) + + var empty K // zero value of key + cpl = kt.Key0.CommonPrefixLength(empty) + require.Equal(t, kt.Key0.BitLen(), cpl) + cpl = empty.CommonPrefixLength(kt.Key0) + require.Equal(t, kt.Key0.BitLen(), cpl) } func (kt *KeyTester[K]) TestCompare(t *testing.T) { @@ -167,6 +182,12 @@ func (kt *KeyTester[K]) TestCompare(t *testing.T) { res = kt.Key1.Compare(kt.Key2) require.Equal(t, -1, res) + + var empty K // zero value of key + res = kt.Key0.Compare(empty) + require.Equal(t, 0, res) + res = empty.Compare(kt.Key0) + require.Equal(t, 0, res) } func (kt *KeyTester[K]) TestBit(t *testing.T) { @@ -184,6 +205,11 @@ func (kt *KeyTester[K]) TestBit(t *testing.T) { } require.Equal(t, uint(1), kt.Key2.Bit(kt.Key2.BitLen()-2), fmt.Sprintf("Key1.Bit(%d)=%d", kt.Key2.BitLen()-2, kt.Key2.BitLen()-2)) require.Equal(t, uint(0), kt.Key2.Bit(kt.Key2.BitLen()-1), fmt.Sprintf("Key1.Bit(%d)=%d", kt.Key2.BitLen()-2, kt.Key2.BitLen()-1)) + + var empty K // zero value of key + for i := 0; i < empty.BitLen(); i++ { + require.Equal(t, uint(0), empty.Bit(i), fmt.Sprintf("empty.Bit(%d)=%d", i, kt.Key0.Bit(i))) + } } func (kt *KeyTester[K]) TestBitString(t *testing.T) { @@ -230,6 +256,21 @@ func (kt *KeyTester[K]) TestHexString(t *testing.T) { } } +// testBinaryMarshaler tests the behaviour of a kad.Key implementation that also implements the BinaryMarshaler interface +func testBinaryMarshaler[K interface { + kad.Key[K] + MarshalBinary() ([]byte, error) +}](t *testing.T, k K, newFunc func([]byte) K, +) { + b, err := k.MarshalBinary() + require.NoError(t, err) + + other := newFunc(b) + + res := k.Compare(other) + require.Equal(t, 0, res) +} + // BitStrKey is a key represented by a string of 1's and 0's type BitStrKey string @@ -253,6 +294,12 @@ func (k BitStrKey) Bit(i int) uint { func (k BitStrKey) Xor(o BitStrKey) BitStrKey { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return BitStrKey(o) + } + if len(o) == 0 && k.isZero() { + return BitStrKey(k) + } panic("BitStrKey: other key has different length") } buf := make([]byte, len(k)) @@ -268,6 +315,12 @@ func (k BitStrKey) Xor(o BitStrKey) BitStrKey { func (k BitStrKey) CommonPrefixLength(o BitStrKey) int { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return len(o) + } + if len(o) == 0 && k.isZero() { + return len(k) + } panic("BitStrKey: other key has different length") } for i := 0; i < len(k); i++ { @@ -280,6 +333,12 @@ func (k BitStrKey) CommonPrefixLength(o BitStrKey) int { func (k BitStrKey) Compare(o BitStrKey) int { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return 0 + } + if len(o) == 0 && k.isZero() { + return 0 + } panic("BitStrKey: other key has different length") } for i := 0; i < len(k); i++ { @@ -292,3 +351,12 @@ func (k BitStrKey) Compare(o BitStrKey) int { } return 0 } + +func (k BitStrKey) isZero() bool { + for i := 0; i < len(k); i++ { + if k[i] != '0' { + return false + } + } + return true +} diff --git a/libp2p/libp2pendpoint_test.go b/libp2p/libp2pendpoint_test.go index 0062f73..4dce986 100644 --- a/libp2p/libp2pendpoint_test.go +++ b/libp2p/libp2pendpoint_test.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/p2p/net/swarm" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" "github.com/plprobelab/go-kademlia/event" @@ -135,12 +136,16 @@ func TestConnections(t *testing.T) { connectedness, err = endpoints[0].Connectedness(ids[1]) require.NoError(t, err) require.Equal(t, endpoint.Connected, connectedness) + // test peerinfo peerinfo, err = endpoints[0].PeerInfo(ids[1]) require.NoError(t, err) - require.Len(t, peerinfo.Addrs, len(addrs[1].Addrs)) - for _, addr := range peerinfo.Addrs { - require.Contains(t, addrs[1].Addrs, addr) + // filter out loopback addresses + expectedAddrs := ma.FilterAddrs(addrs[1].Addrs, func(a ma.Multiaddr) bool { + return !manet.IsIPLoopback(a) + }) + for _, addr := range expectedAddrs { + require.Contains(t, peerinfo.Addrs, addr, addr.String(), expectedAddrs) } peerinfo, err = endpoints[0].PeerInfo(ids[2]) require.NoError(t, err) @@ -181,7 +186,7 @@ func TestAsyncDial(t *testing.T) { // AsyncDialAndReport adds the dial action to the event queue, so we // need to run the scheduler for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } wg.Done() }() @@ -216,7 +221,7 @@ func TestAsyncDial(t *testing.T) { // AsyncDialAndReport adds the dial action to the event queue, so we // need to run the scheduler for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } wg.Done() }() @@ -327,7 +332,7 @@ func TestSuccessfulRequest(t *testing.T) { go func() { // run server 1 for !scheds[1].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[1].Clock().Sleep(time.Millisecond) } require.False(t, scheds[1].RunOne(ctx)) // only 1 action should run on server wg.Done() @@ -335,7 +340,7 @@ func TestSuccessfulRequest(t *testing.T) { go func() { // timeout is queued in the scheduler 0 for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } require.False(t, scheds[0].RunOne(ctx)) wg.Done() @@ -384,7 +389,7 @@ func TestReqUnknownPeer(t *testing.T) { go func() { // timeout is queued in the scheduler 0 for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } require.False(t, scheds[0].RunOne(ctx)) wg.Done() @@ -425,7 +430,7 @@ func TestReqTimeout(t *testing.T) { go func() { // timeout is queued in the scheduler 0 for !scheds[0].RunOne(ctx) { - time.Sleep(1 * time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } require.False(t, scheds[0].RunOne(ctx)) wg.Done() @@ -473,7 +478,7 @@ func TestReqHandlerError(t *testing.T) { wg.Add(2) go func() { for !scheds[1].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[1].Clock().Sleep(time.Millisecond) } require.False(t, scheds[1].RunOne(ctx)) cancel() @@ -482,7 +487,7 @@ func TestReqHandlerError(t *testing.T) { go func() { // timeout is queued in the scheduler 0 for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } require.False(t, scheds[0].RunOne(ctx)) wg.Done() @@ -527,7 +532,7 @@ func TestReqHandlerReturnsWrongType(t *testing.T) { wg.Add(2) go func() { for !scheds[1].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[1].Clock().Sleep(time.Millisecond) } require.False(t, scheds[1].RunOne(ctx)) cancel() @@ -536,7 +541,7 @@ func TestReqHandlerReturnsWrongType(t *testing.T) { go func() { // timeout is queued in the scheduler 0 for !scheds[0].RunOne(ctx) { - time.Sleep(time.Millisecond) + scheds[0].Clock().Sleep(time.Millisecond) } require.False(t, scheds[0].RunOne(ctx)) wg.Done() diff --git a/routing/include.go b/routing/include.go index d789912..d6fec3f 100644 --- a/routing/include.go +++ b/routing/include.go @@ -105,6 +105,24 @@ func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeSta switch tev := ev.(type) { case *EventIncludeAddCandidate[K, A]: + // Ignore if already running a check + _, checking := b.checks[key.HexString(tev.NodeInfo.ID().Key())] + if checking { + break + } + + // Ignore if node already in routing table + // TODO: promote this interface (or something similar) to kad.RoutingTable + if rtf, ok := b.rt.(interface { + Find(context.Context, kad.NodeID[K]) (kad.NodeInfo[K, A], error) + }); ok { + n, _ := rtf.Find(ctx, tev.NodeInfo.ID()) + if n != nil { + // node already in routing table + break + } + } + // TODO: potentially time out a check and make room in the queue if !b.candidates.HasCapacity() { return &StateIncludeWaitingFull{}