Skip to content

Commit

Permalink
pool: Make endpoint tests more robust and fix race.
Browse files Browse the repository at this point in the history
This modifies the endpoint tests to respect the context when sending to
and receiving from the endpoint channels similar to what was recently
done for the chainstate tests.

It also fixes a race in the test due to the add callback being invoked
in the run goroutine.
  • Loading branch information
davecgh committed Sep 20, 2023
1 parent 491a43e commit 70baf85
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions pool/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func makeConn(listener *net.TCPListener, serverCh chan net.Conn) (net.Conn, net.
}

func testEndpoint(t *testing.T) {
const maxConnsPerHost = 3
powLimit := chaincfg.SimNetParams().PowLimit
iterations := math.Pow(2, float64(256-powLimit.BitLen()))
maxGenTime := time.Second * 20
Expand All @@ -36,14 +37,14 @@ func testEndpoint(t *testing.T) {
new(big.Rat).SetInt(powLimit), maxGenTime)
connections := make(map[string]uint32)
var connectionsMtx sync.RWMutex
var connectionsWg sync.WaitGroup
removeConn := make(chan struct{}, maxConnsPerHost)
eCfg := &EndpointConfig{
ActiveNet: chaincfg.SimNetParams(),
db: db,
SoloPool: true,
Blake256Pad: blake256Pad,
NonceIterations: iterations,
MaxConnectionsPerHost: 3,
MaxConnectionsPerHost: maxConnsPerHost,
FetchMinerDifficulty: func(miner string) (*DifficultyInfo, error) {
return poolDiffs.fetchMinerDifficulty(miner)
},
Expand All @@ -57,7 +58,6 @@ func testEndpoint(t *testing.T) {
return true
},
AddConnection: func(host string) {
connectionsWg.Add(1)
connectionsMtx.Lock()
connections[host]++
connectionsMtx.Unlock()
Expand All @@ -66,7 +66,7 @@ func testEndpoint(t *testing.T) {
connectionsMtx.Lock()
connections[host]--
connectionsMtx.Unlock()
connectionsWg.Done()
removeConn <- struct{}{}
},
FetchHostConnections: func(host string) uint32 {
connectionsMtx.RLock()
Expand All @@ -91,7 +91,18 @@ func testEndpoint(t *testing.T) {
endpoint.run(ctx)
wg.Done()
}()
time.Sleep(time.Millisecond * 100)
sendToConnChanOrFatal := func(msg *connection) {
select {
case endpoint.connCh <- msg:
case <-ctx.Done():
t.Fatalf("unexpected enpoing shutdown")
}
select {
case <-msg.Done:
case <-ctx.Done():
t.Fatalf("unexpected enpoing shutdown")
}
}

laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3031")
if err != nil {
Expand Down Expand Up @@ -134,8 +145,7 @@ func testEndpoint(t *testing.T) {
Conn: connA,
Done: make(chan struct{}),
}
endpoint.connCh <- msgA
<-msgA.Done
sendToConnChanOrFatal(msgA)
addr := connA.RemoteAddr()
tcpAddr, err := net.ResolveTCPAddr(addr.Network(), addr.String())
if err != nil {
Expand All @@ -161,8 +171,7 @@ func testEndpoint(t *testing.T) {
Conn: connB,
Done: make(chan struct{}),
}
endpoint.connCh <- msgB
<-msgB.Done
sendToConnChanOrFatal(msgB)
connC, srvC, err := makeConn(ln, serverCh)
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -173,8 +182,7 @@ func testEndpoint(t *testing.T) {
Conn: connC,
Done: make(chan struct{}),
}
endpoint.connCh <- msgC
<-msgC.Done
sendToConnChanOrFatal(msgC)

// Ensure the connected clients to the host got incremented to 3.
hostConnections = endpoint.cfg.FetchHostConnections(host)
Expand All @@ -194,8 +202,7 @@ func testEndpoint(t *testing.T) {
Conn: connD,
Done: make(chan struct{}),
}
endpoint.connCh <- msgD
<-msgD.Done
sendToConnChanOrFatal(msgD)

// Ensure the connected clients count to the host stayed at 3 because
// the recent connection got rejected due to MaxConnectionCountPerHost
Expand All @@ -206,7 +213,7 @@ func testEndpoint(t *testing.T) {
"for host %s, got %d", 3, host, hostConnections)
}

// Remove all clients.
// Remove all clients and wait for their removal.
endpoint.clientsMtx.Lock()
clients := make([]*Client, 0, len(endpoint.clients))
for _, cl := range endpoint.clients {
Expand All @@ -216,7 +223,13 @@ func testEndpoint(t *testing.T) {
for _, cl := range clients {
cl.shutdown()
}
connectionsWg.Wait()
for i := 0; i < maxConnsPerHost; i++ {
select {
case <-removeConn:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for connection removal")
}
}

// Ensure there are no connected clients to the host.
hostConnections = endpoint.cfg.FetchHostConnections(host)
Expand Down

0 comments on commit 70baf85

Please sign in to comment.