Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(abciclient)!: limit concurrent gRPC connections #775

Merged
merged 12 commits into from
Apr 22, 2024
8 changes: 6 additions & 2 deletions abci/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
sync "github.com/sasha-s/go-deadlock"

"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/libs/service"
)
Expand Down Expand Up @@ -36,12 +37,15 @@ type Client interface {

// NewClient returns a new ABCI client of the specified transport type.
// It returns an error if the transport is not "socket" or "grpc"
func NewClient(logger log.Logger, addr, transport string, mustConnect bool) (Client, error) {
func NewClient(logger log.Logger, cfg config.AbciConfig, mustConnect bool) (Client, error) {
transport := cfg.Transport
addr := cfg.Address

switch transport {
case "socket":
return NewSocketClient(logger, addr, mustConnect), nil
case "grpc":
return NewGRPCClient(logger, addr, mustConnect), nil
return NewGRPCClient(logger, addr, cfg.GrpcConcurrency, mustConnect), nil
case "routed":
return NewRoutedClientWithAddr(logger, addr, mustConnect)
default:
Expand Down
99 changes: 97 additions & 2 deletions abci/client/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"strings"
"time"

sync "github.com/sasha-s/go-deadlock"
Expand All @@ -30,36 +31,130 @@ type grpcClient struct {
mtx sync.Mutex
addr string
err error

// map between method name (in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`)
// and a channel that will be used to limit the number of concurrent requests for that method.
//
// If the value is nil, no limit is enforced.
//
// Not thread-safe, only modify this before starting the client.
concurrency map[string]chan struct{}
}

var _ Client = (*grpcClient)(nil)

// NewGRPCClient creates a gRPC client, which will connect to addr upon the
// start. Note Client#Start returns an error if connection is unsuccessful and
// mustConnect is true.
func NewGRPCClient(logger log.Logger, addr string, mustConnect bool) Client {
func NewGRPCClient(logger log.Logger, addr string, concurrency map[string]uint16, mustConnect bool) Client {
cli := &grpcClient{
logger: logger,
addr: addr,
mustConnect: mustConnect,
concurrency: make(map[string]chan struct{}, 20),
}
cli.BaseService = *service.NewBaseService(logger, "grpcClient", cli)
cli.SetMaxConcurrentStreams(concurrency)

return cli
}

func methodID(method string) string {
if pos := strings.LastIndex(method, "/"); pos > 0 {
method = method[pos+1:]
}

return strings.ToLower(method)
}

// SetMaxConcurrentStreams sets the maximum number of concurrent streams to be
// allowed on this client.
//
// Not thread-safe, only use this before starting the client.
//
// If limit is 0, no limit is enforced.
func (cli *grpcClient) SetMaxConcurrentStreamsForMethod(method string, n uint16) {
if cli.IsRunning() {
panic("cannot set max concurrent streams after starting the client")
}
var ch chan struct{}
if n != 0 {
ch = make(chan struct{}, n)
}

cli.mtx.Lock()
cli.concurrency[methodID(method)] = ch
cli.mtx.Unlock()
}

// SetMaxConcurrentStreams sets the maximum number of concurrent streams to be
// allowed on this client.
// # Arguments
//
// * `methods` - A map between method name (in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`)
// and the maximum number of concurrent streams to be allowed for that method.
//
// Special method name "*" can be used to set the default limit for methods not explicitly listed.
//
// If the value is 0, no limit is enforced.
//
// Not thread-safe, only use this before starting the client.
func (cli *grpcClient) SetMaxConcurrentStreams(methods map[string]uint16) {
for method, n := range methods {
cli.SetMaxConcurrentStreamsForMethod(method, n)
}
}

func dialerFunc(_ctx context.Context, addr string) (net.Conn, error) {
return tmnet.Connect(addr)
}

// rateLimit blocks until the client is allowed to send a request.
// It returns a function that should be called after the request is done.
//
// method should be the method name in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`.
// Special method name "*" can be used to define the default limit.
// If no limit is set for the method, the default limit is used.
func (cli *grpcClient) rateLimit(method string) context.CancelFunc {
ch := cli.concurrency[methodID(method)]
// handle default
if ch == nil {
ch = cli.concurrency["*"]
}
if ch == nil {
return func() {}
}

cli.logger.Trace("grpcClient rateLimit", "addr", cli.addr)
ch <- struct{}{}
return func() { <-ch }
}

func (cli *grpcClient) unaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
done := cli.rateLimit(method)
defer done()

return invoker(ctx, method, req, reply, cc, opts...)
}

func (cli *grpcClient) streamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
done := cli.rateLimit(method)
defer done()

return streamer(ctx, desc, cc, method, opts...)
}

func (cli *grpcClient) OnStart(ctx context.Context) error {
timer := time.NewTimer(0)
defer timer.Stop()

RETRY_LOOP:
for {
conn, err := grpc.Dial(cli.addr,
conn, err := grpc.NewClient(cli.addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(dialerFunc),
grpc.WithChainUnaryInterceptor(cli.unaryClientInterceptor),
grpc.WithChainStreamInterceptor(cli.streamClientInterceptor),
)
if err != nil {
if cli.mustConnect {
Expand Down
140 changes: 95 additions & 45 deletions abci/client/grpc_client_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package abciclient_test
package abciclient

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/assert"

abciclient "github.com/dashpay/tenderdash/abci/client"
abciserver "github.com/dashpay/tenderdash/abci/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/libs/log"
Expand All @@ -18,47 +19,87 @@ import (

// TestGRPCClientServerParallel tests that gRPC client and server can handle multiple parallel requests
func TestGRPCClientServerParallel(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

logger := log.NewNopLogger()
app := &mockApplication{t: t}

socket := t.TempDir() + "/grpc_test"
client, _, err := makeGRPCClientServer(ctx, t, logger, app, socket)
if err != nil {
t.Fatal(err)
const (
timeout = 1 * time.Second
tick = 10 * time.Millisecond
)

type testCase struct {
threads int
infoConcurrency uint16
defautConcurrency uint16
}

// we'll use that mutex to ensure threads don't finish before we check status
app.mtx.Lock()

const threads = 5
// started will be marked as done as soon as app.Info() handler executes on the server
app.started.Add(threads)
// done will be used to wait for all threads to finish
var done sync.WaitGroup
done.Add(threads)

for i := 0; i < threads; i++ {
thread := uint64(i)
go func() {
_, _ = client.Info(ctx, &types.RequestInfo{BlockVersion: thread})
done.Done()
}()
testCases := []testCase{
{threads: 1, infoConcurrency: 1},
{threads: 2, infoConcurrency: 1},
{threads: 2, infoConcurrency: 2},
{threads: 5, infoConcurrency: 0},
{threads: 5, infoConcurrency: 0, defautConcurrency: 2},
{threads: 5, infoConcurrency: 1},
{threads: 5, infoConcurrency: 2},
{threads: 5, infoConcurrency: 5},
}

// wait for threads to execute
// note it doesn't mean threads are really done, as they are waiting on the mtx
// so if all `started` are marked as done, it means all threads have started
// in parallel
app.started.Wait()

// unlock the mutex so that threads can finish their execution
app.mtx.Unlock()
logger := log.NewNopLogger()

// wait for all threads to really finish
done.Wait()
for _, tc := range testCases {
t.Run(fmt.Sprintf("t_%d-i_%d,d_%d", tc.threads, tc.infoConcurrency, tc.defautConcurrency), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

app := &mockApplication{t: t, concurrencyLimit: int32(tc.infoConcurrency)}

socket := t.TempDir() + "/grpc_test"
limits := map[string]uint16{
"/tendermint.abci.ABCIApplication/Info": tc.infoConcurrency,
"*": tc.defautConcurrency,
}

client, _, err := makeGRPCClientServer(ctx, t, logger, app, socket, limits)
if err != nil {
t.Fatal(err)
}

// we'll use that mutex to ensure threads don't finish before we check status
app.mtx.Lock()

// done will be used to wait for all threads to finish
var done sync.WaitGroup

for i := 0; i < tc.threads; i++ {
done.Add(1)
thread := uint64(i)
go func() {
// we use BlockVersion for logging purposes, so we put thread id there
_, _ = client.Info(ctx, &types.RequestInfo{BlockVersion: thread})
done.Done()
}()
}

expectThreads := int32(tc.infoConcurrency)
if expectThreads == 0 {
expectThreads = int32(tc.defautConcurrency)
}
if expectThreads == 0 {
expectThreads = int32(tc.threads)
}

// wait for all threads to start execution
assert.Eventually(t, func() bool {
return app.running.Load() == expectThreads
}, timeout, tick, "not all threads started in time")

// ensure no other threads will start
time.Sleep(2 * tick)

// unlock the mutex so that threads can finish their execution
app.mtx.Unlock()

// wait for all threads to really finish
done.Wait()
})
}
}

func makeGRPCClientServer(
Expand All @@ -67,7 +108,8 @@ func makeGRPCClientServer(
logger log.Logger,
app types.Application,
name string,
) (abciclient.Client, service.Service, error) {
concurrency map[string]uint16,
) (Client, service.Service, error) {
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
t.Cleanup(leaktest.Check(t))
Expand All @@ -82,7 +124,7 @@ func makeGRPCClientServer(
return nil, nil, err
}

client := abciclient.NewGRPCClient(logger.With("module", "abci-client"), socket, true)
client := NewGRPCClient(logger.With("module", "abci-client"), socket, concurrency, true)

if err := client.Start(ctx); err != nil {
cancel()
Expand All @@ -96,19 +138,27 @@ func makeGRPCClientServer(
type mockApplication struct {
types.BaseApplication
mtx sync.Mutex
// we'll use that to ensure all threads have started
started sync.WaitGroup

running atomic.Int32
// concurrencyLimit of concurrent requests
concurrencyLimit int32

t *testing.T
}

func (m *mockApplication) Info(_ctx context.Context, req *types.RequestInfo) (res *types.ResponseInfo, err error) {
m.t.Logf("Info %d called", req.BlockVersion)
// mark wg as done to signal that we have executed
m.started.Done()
// we will wait here until all threads mark wg as done
running := m.running.Add(1)
defer m.running.Add(-1)

if m.concurrencyLimit > 0 {
assert.LessOrEqual(m.t, running, m.concurrencyLimit, "too many requests running in parallel")
}

// we will wait here until all expected threads are running
m.mtx.Lock()
defer m.mtx.Unlock()
m.t.Logf("Info %d finished", req.BlockVersion)

return &types.ResponseInfo{}, nil
}
4 changes: 3 additions & 1 deletion abci/client/routed_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/go-multierror"

"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/libs/service"
)
Expand Down Expand Up @@ -71,7 +72,8 @@ func NewRoutedClientWithAddr(logger log.Logger, addr string, mustConnect bool) (
// Create a new client if it doesn't exist
clientName := fmt.Sprintf("%s:%s", transport, address)
if _, ok := clients[clientName]; !ok {
c, err := NewClient(logger, address, transport, mustConnect)
cfg := config.AbciConfig{Address: address, Transport: transport}
c, err := NewClient(logger, cfg, mustConnect)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion abci/cmd/abci-cli/abci-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/dashpay/tenderdash/abci/server"
servertest "github.com/dashpay/tenderdash/abci/tests/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/proto/tendermint/crypto"
tmproto "github.com/dashpay/tenderdash/proto/tendermint/types"
Expand Down Expand Up @@ -64,7 +65,8 @@ func RootCmmand(logger log.Logger) *cobra.Command {

if client == nil {
var err error
client, err = abciclient.NewClient(logger.With("module", "abci-client"), flagAddress, flagAbci, false)
cfg := config.AbciConfig{Address: flagAddress, Transport: flagAbci}
client, err = abciclient.NewClient(logger.With("module", "abci-client"), cfg, false)
if err != nil {
return err
}
Expand Down
Loading
Loading