Skip to content

Commit

Permalink
feat(abciclient): support timeouts in abci calls (#749)
Browse files Browse the repository at this point in the history
* feat(abciclient): expire checktx abci calls on messages from p2p

* fix(abciclient): race when accessing socket client waker

* chore(p2p): checktx timeout 1s

* chore(rpc): broadcasttx timeout 1s

* chore(mempool,rpc): checktx timeout 1s using const

* fix(rpc): broadcastasync is broken due to invalid ctx cancellation

* revert: waker not needed in abci client

* chore: fix build issue

* chore: fix race condition

* chore(abciclient): fix goroutine leak

* chore: self-review

* test(abciclient): test socket client timeouts
  • Loading branch information
lklimek authored Mar 7, 2024
1 parent e6ee68d commit 608bc11
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 41 deletions.
7 changes: 5 additions & 2 deletions abci/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ type requestAndResponse struct {
*types.Request
*types.Response

mtx sync.Mutex
mtx sync.Mutex
// context for the request; we check if it's not expired before sending
ctx context.Context
signal chan struct{}
}

func makeReqRes(req *types.Request) *requestAndResponse {
func makeReqRes(ctx context.Context, req *types.Request) *requestAndResponse {
return &requestAndResponse{
Request: req,
Response: nil,
ctx: ctx,
signal: make(chan struct{}),
}
}
Expand Down
82 changes: 55 additions & 27 deletions abci/client/socket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type socketClient struct {
mustConnect bool
conn net.Conn

// Requests queue
reqQueue chan *requestAndResponse

mtx sync.Mutex
Expand Down Expand Up @@ -116,37 +117,62 @@ func (cli *socketClient) Error() error {

//----------------------------------------

// enqueue hands reqres over to the client's request queue.
//
// The call blocks until one of two things happens: the queue accepts the
// request (nil is returned), or ctx is done (ctx.Err() is returned).
func (cli *socketClient) enqueue(ctx context.Context, reqres *requestAndResponse) error {
	var err error

	select {
	case cli.reqQueue <- reqres:
		// accepted by the queue; err stays nil
	case <-ctx.Done():
		err = ctx.Err()
	}

	return err
}

// dequeue waits for the next request on the client's request queue and
// returns it.
//
// When ctx is done before a request becomes available, nil is returned
// instead.
func (cli *socketClient) dequeue(ctx context.Context) *requestAndResponse {
	var next *requestAndResponse

	select {
	case <-ctx.Done():
		// nothing received; next stays nil
	case next = <-cli.reqQueue:
	}

	return next
}

func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer) {
bw := bufio.NewWriter(conn)
for {
select {
case <-ctx.Done():
// dequeue will block until a message arrives
for reqres := cli.dequeue(ctx); reqres != nil && ctx.Err() == nil; reqres = cli.dequeue(ctx) {
if err := reqres.ctx.Err(); err != nil {
// request expired, skip it
cli.logger.Debug("abci.socketClient request expired, skipping", "req", reqres.Request.Value, "error", err)
continue
}

// N.B. We must track request before sending it out, otherwise the
// server may reply before we do it, and the receiver will fail for an
// unsolicited reply.
cli.trackRequest(reqres)

if err := types.WriteMessage(reqres.Request, bw); err != nil {
cli.stopForError(fmt.Errorf("write to buffer: %w", err))
return
case reqres := <-cli.reqQueue:
// N.B. We must enqueue before sending out the request, otherwise the
// server may reply before we do it, and the receiver will fail for an
// unsolicited reply.
cli.trackRequest(reqres)

if err := types.WriteMessage(reqres.Request, bw); err != nil {
cli.stopForError(fmt.Errorf("write to buffer: %w", err))
return
}
}

if err := bw.Flush(); err != nil {
cli.stopForError(fmt.Errorf("flush buffer: %w", err))
return
}
if err := bw.Flush(); err != nil {
cli.stopForError(fmt.Errorf("flush buffer: %w", err))
return
}
}

cli.logger.Debug("context canceled, stopping sendRequestsRoutine")
}

func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader) {
r := bufio.NewReader(conn)
for {
if ctx.Err() != nil {
return
}
for ctx.Err() == nil {
res := &types.Response{}

if err := types.ReadMessage(r, res); err != nil {
Expand All @@ -166,6 +192,8 @@ func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader
}
}
}

cli.logger.Debug("context canceled, stopping recvResponseRoutine")
}

func (cli *socketClient) trackRequest(reqres *requestAndResponse) {
Expand Down Expand Up @@ -209,15 +237,15 @@ func (cli *socketClient) doRequest(ctx context.Context, req *types.Request) (*ty
return nil, errors.New("client has stopped")
}

reqres := makeReqRes(req)

select {
case cli.reqQueue <- reqres:
case <-ctx.Done():
return nil, fmt.Errorf("can't queue req: %w", ctx.Err())
reqres := makeReqRes(ctx, req)
if err := cli.enqueue(ctx, reqres); err != nil {
return nil, err
}

// wait for response for our request
select {
case <-reqres.ctx.Done():
return nil, reqres.ctx.Err()
case <-reqres.signal:
if err := cli.Error(); err != nil {
return nil, err
Expand Down
120 changes: 120 additions & 0 deletions abci/client/socket_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package abciclient

import (
"context"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/abci/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/abci/types/mocks"
"github.com/dashpay/tenderdash/libs/log"
)

// TestSocketClientTimeout tests that the socket client times out correctly.
//
// Each test case starts a socket server backed by a mock application and a
// socket client connected to it over a unix socket. An Info call is issued in
// a background goroutine using the long-lived mainCtx, and then a CheckTx call
// is issued with a context limited to the case's timeout. The test asserts
// whether CheckTx returned an error and whether the mock CheckTx handler had
// been entered by the time the call returned.
func TestSocketClientTimeout(t *testing.T) {
const (
// Expected outcomes of the CheckTx call, see the switch at the end of the test:
//   Success              - no error, CheckTx handler was executed
//   FailDuringEnqueue    - error, CheckTx handler was not executed
//   FailDuringProcessing - error, CheckTx handler was executed
Success = 0
FailDuringEnqueue = 1
FailDuringProcessing = 2

// baseTime is the unit in which all timeouts and sleeps below are expressed.
baseTime = 10 * time.Millisecond
)
type testCase struct {
name string
// timeout is the deadline put on the context passed to CheckTx
timeout time.Duration
// enqueueSleep is how long the mock Info handler sleeps
enqueueSleep time.Duration
// processingSleep is how long the mock CheckTx handler sleeps
processingSleep time.Duration
// expect is one of Success, FailDuringEnqueue, FailDuringProcessing
expect int
}
// NOTE(review): the cases presumably rely on the slow Info call delaying the
// start of CheckTx handling by enqueueSleep; that ordering is not
// visible in this test - confirm against the client/server implementation.
testCases := []testCase{
{name: "immediate", timeout: baseTime, enqueueSleep: 0, processingSleep: 0, expect: Success},
{name: "small enqueue delay", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 0, expect: Success},
{name: "small processing delay", timeout: 4 * baseTime, enqueueSleep: 0, processingSleep: 1 * baseTime, expect: Success},
{name: "within timeout", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 1 * baseTime, expect: Success},
{name: "timeout during enqueue", timeout: 3 * baseTime, enqueueSleep: 4 * baseTime, processingSleep: 1 * baseTime, expect: FailDuringEnqueue},
{name: "timeout during processing", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 4 * baseTime, expect: FailDuringProcessing},
}

logger := log.NewTestingLogger(t)

for i, tc := range testCases {
// per-iteration copies, so the subtest closure does not share the loop variables
i := i
tc := tc
t.Run(tc.name, func(t *testing.T) {

// wait until all threads end, otherwise we'll get data race in t.Log()
// (this defer is registered first, so it runs last, after the client and
// server have been stopped and the contexts canceled)
wg := sync.WaitGroup{}
defer wg.Wait()

// mainCtx is used for the server, the client and the background Info call
mainCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

// each subtest gets its own socket path inside a fresh temporary directory
socket := "unix://" + t.TempDir() + "/socket." + strconv.Itoa(i)

// set by the mock CheckTx handler as soon as it is entered
checkTxExecuted := atomic.Bool{}

// All expectations are marked with Maybe(), i.e. they are optional.
// NOTE(review): wg.Add(1) is called from inside the mock handlers, which may
// run concurrently with the deferred wg.Wait(); sync.WaitGroup expects Add
// calls that start from a zero counter to happen before Wait - confirm this
// cannot race here.
app := mocks.NewApplication(t)
app.On("Echo", mock.Anything, mock.Anything).Return(&types.ResponseEcho{}, nil).Maybe()
app.On("Info", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
wg.Add(1)
logger.Debug("Info before sleep")
time.Sleep(tc.enqueueSleep)
logger.Debug("Info after sleep")
wg.Done()
}).Return(&types.ResponseInfo{}, nil).Maybe()
app.On("CheckTx", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
wg.Add(1)
logger.Debug("CheckTx before sleep")
checkTxExecuted.Store(true)
time.Sleep(tc.processingSleep)
logger.Debug("CheckTx after sleep")
wg.Done()
}).Return(&types.ResponseCheckTx{}, nil).Maybe()

// start the socket server that serves the mock application
service, err := server.NewServer(logger, socket, "socket", app)
require.NoError(t, err)
svr := service.(*server.SocketServer)
err = svr.Start(mainCtx)
require.NoError(t, err)
defer svr.Stop()

// start the client under test, connected to the same socket
cli := NewSocketClient(logger, socket, true).(*socketClient)

err = cli.Start(mainCtx)
require.NoError(t, err)
defer cli.Stop()

// reqCtx carries the per-case timeout and is used only for CheckTx
reqCtx, reqCancel := context.WithTimeout(context.Background(), tc.timeout)
defer reqCancel()
// Info is here just to block for some time, so we don't want to enforce timeout on it

wg.Add(1)
go func() {
_, _ = cli.Info(mainCtx, &types.RequestInfo{})
wg.Done()
}()

// NOTE(review): ordering of Info before CheckTx relies on this sleep only;
// this may be flaky on a loaded machine.
time.Sleep(1 * time.Millisecond) // ensure the goroutine has started

_, err = cli.CheckTx(reqCtx, &types.RequestCheckTx{})
// checkTxExecuted is read right after CheckTx returns, so it reflects whether
// the mock handler had been entered by that moment
switch tc.expect {
case Success:
require.NoError(t, err)
require.True(t, checkTxExecuted.Load())
case FailDuringEnqueue:
require.Error(t, err)
require.False(t, checkTxExecuted.Load())
case FailDuringProcessing:
require.Error(t, err)
require.True(t, checkTxExecuted.Load())
}
})
}
}
24 changes: 16 additions & 8 deletions internal/mempool/p2p_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"time"

"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/internal/p2p/client"
Expand All @@ -12,6 +13,12 @@ import (
"github.com/dashpay/tenderdash/types"
)

// CheckTxTimeout is the upper bound on how long we wait for a CheckTx call
// to return.
//
// TODO: Change to config option
const CheckTxTimeout = time.Second

type (
mempoolP2PMessageHandler struct {
logger log.Logger
Expand Down Expand Up @@ -53,7 +60,10 @@ func (h *mempoolP2PMessageHandler) Handle(ctx context.Context, _ *client.Client,
SenderNodeID: envelope.From,
}
for _, tx := range protoTxs {
if err := h.checker.CheckTx(ctx, tx, nil, txInfo); err != nil {
subCtx, subCtxCancel := context.WithTimeout(ctx, CheckTxTimeout)
defer subCtxCancel()

if err := h.checker.CheckTx(subCtx, tx, nil, txInfo); err != nil {
if errors.Is(err, types.ErrTxInCache) {
// if the tx is in the cache,
// then we've been gossiped a
Expand All @@ -63,13 +73,11 @@ func (h *mempoolP2PMessageHandler) Handle(ctx context.Context, _ *client.Client,
// problem.
continue
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// Do not propagate context
// cancellation errors, but do
// not continue to check
// transactions from this
// message if we are shutting down.
return err

// In case of ctx cancelation, we return error as we are most likely shutting down.
// Otherwise we just reject the tx.
if errCtx := ctx.Err(); errCtx != nil {
return errCtx
}
logger.Error("checktx failed for tx",
"tx", fmt.Sprintf("%X", types.Tx(tx).Hash()),
Expand Down
9 changes: 7 additions & 2 deletions internal/p2p/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,15 @@ func (c *Client) GetSyncStatus(ctx context.Context) error {
}

// SendTxs sends one or more transactions to the peer in a single message
func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx types.Tx) error {
func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx ...types.Tx) error {
txs := make([][]byte, len(tx))
for i := 0; i < len(tx); i++ {
txs[i] = tx[i]
}

return c.Send(ctx, p2p.Envelope{
To: peerID,
Message: &protomem.Txs{Txs: [][]byte{tx}},
Message: &protomem.Txs{Txs: txs},
})
}

Expand Down
16 changes: 14 additions & 2 deletions internal/rpc/core/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,17 @@ import (
// More:
// https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_async
// Deprecated and should be removed in 0.37
func (env *Environment) BroadcastTxAsync(ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
go func() { _ = env.Mempool.CheckTx(ctx, req.Tx, nil, mempool.TxInfo{}) }()
func (env *Environment) BroadcastTxAsync(_ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
go func() {
// We need to create a new context here, because the original context
// may be canceled after parent function returns.
ctx, cancel := context.WithTimeout(context.Background(), mempool.CheckTxTimeout)
defer cancel()

if res, err := env.BroadcastTx(ctx, req); err != nil || res.Code != abci.CodeTypeOK {
env.Logger.Error("error on broadcastTxAsync", "err", err, "result", res, "tx", req.Tx.Hash())
}
}()

return &coretypes.ResultBroadcastTx{Hash: req.Tx.Hash()}, nil
}
Expand All @@ -37,6 +46,9 @@ func (env *Environment) BroadcastTxSync(ctx context.Context, req *coretypes.Requ
// DeliverTx result.
// More: https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_sync
func (env *Environment) BroadcastTx(ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
ctx, cancel := context.WithTimeout(ctx, mempool.CheckTxTimeout)
defer cancel()

resCh := make(chan *abci.ResponseCheckTx, 1)
err := env.Mempool.CheckTx(
ctx,
Expand Down

0 comments on commit 608bc11

Please sign in to comment.