Skip to content

Commit

Permalink
Implement SendBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronbee committed Jun 23, 2023
1 parent 88eba23 commit 19eebe4
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 0 deletions.
186 changes: 186 additions & 0 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"math"
"strconv"
"sync"
"time"

log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -158,6 +159,191 @@ func (c *client) getRegionAndClientForRPC(ctx context.Context, rpc hrpc.Call) (
}
}

// BatchError is the sentinel error SendBatch returns when some or all
// of the RPCs in a batch fail. It carries no detail on its own: callers
// are expected to walk the per-Call results to learn what went wrong.
var BatchError = errors.New("errors executing batch. " +
	"Inspect the individual results for more information")

// NotExecutedError is the per-Call error left on an RPC that was part
// of a batch but was never executed because a different error was
// encountered in the same batch.
var NotExecutedError = errors.New("RPC in batch not executed due to another error")

// SendBatch executes every Call in batch. All Calls must target the
// same table and every Call must be Batchable.
//
// For each Call, SendBatch looks up the owning region and region
// server and dispatches the Call there. The operation is not atomic:
// some Calls may succeed while others fail. Calls that share a region
// are executed in the order in which they appear in batch.
//
// The returned slice holds one result per Call; the i'th result
// belongs to the i'th Call of batch. The returned error is non-nil
// when something affected the whole batch or when any individual Call
// failed. Inspect the Error field of each RPCResult for the per-Call
// error. An empty batch yields a nil slice and a nil error.
//
// TODO: If retryable errors are encountered on a subset of the Calls
// in the batch the caller is required to filter out the successful
// Calls and retry with the failed Calls. Should we handle that here
// instead?
func (c *client) SendBatch(ctx context.Context, batch []hrpc.Call) (
	res []hrpc.RPCResult, err error) {
	if len(batch) == 0 {
		return nil, nil
	}

	const description = "SendBatch"
	began := time.Now()
	ctx, span := observability.StartSpan(ctx, description)
	defer func() {
		// The named result err is read here so that the metric label
		// and the span status reflect the final outcome of the call.
		outcome := "ok"
		if err != nil {
			outcome = "error"
			span.SetStatus(codes.Error, err.Error())
		}

		observer := operationDurationSeconds.WithLabelValues(description, outcome)
		observability.ObserveWithTrace(ctx, observer, time.Since(began).Seconds())
		span.End()
	}()

	// Validate the whole batch up front, recording a per-Call error
	// for every offending Call rather than stopping at the first one.
	wantTable := batch[0].Table()
	res = make([]hrpc.RPCResult, len(batch))
	// indexOf remembers where each Call's result lives in res so that
	// completions can be written to the right slot later on.
	indexOf := make(map[hrpc.Call]int, len(batch))
	for i, call := range batch {
		if first, seen := indexOf[call]; seen {
			res[i].Error = fmt.Errorf("duplicate call in batch at index %d", first)
			err = BatchError
			continue
		}
		indexOf[call] = i

		batchable, isBatchable := call.(hrpc.Batchable)
		switch {
		case !bytes.Equal(call.Table(), wantTable):
			res[i].Error = fmt.Errorf("multiple tables in batch request: %q and %q",
				string(wantTable), string(call.Table()))
			err = BatchError
		case !isBatchable || batchable.SkipBatch():
			res[i].Error = errors.New("non-batchable call passed to SendBatch")
			err = BatchError
		default:
			// Placeholder until the RPC actually runs; it is replaced
			// by a more specific error, or by nil on success.
			res[i].Error = NotExecutedError
		}
	}
	if err != nil {
		return res, err
	}

	byClient, err := c.findClients(ctx, batch, res)
	if err != nil {
		return res, err
	}

	// Hand each group of RPCs to its region client and wait for the
	// results concurrently, one goroutine per region client.
	var (
		pending  sync.WaitGroup
		firstErr sync.Once
	)
	for rc, calls := range byClient {
		// TODO: Move this to the RegionClient interface so we don't
		// need to type assert here
		queuer := rc.(interface {
			QueueBatch(ctx context.Context, rpcs []hrpc.Call)
		})
		pending.Add(1)
		go func(rc hrpc.RegionClient, calls []hrpc.Call) {
			defer pending.Done()
			queuer.QueueBatch(ctx, calls)
			waitCtx, waitSpan := observability.StartSpan(ctx, "waitForResult")
			defer waitSpan.End()
			waitErr := c.waitForCompletion(waitCtx, rc, calls, res, indexOf)
			if waitErr != nil {
				// Only the first error reported by any goroutine
				// (should be BatchError if non-nil) is returned by
				// SendBatch.
				firstErr.Do(func() { err = waitErr })
			}
		}(rc, calls)
	}

	pending.Wait()
	return res, err
}

// findClients takes a batch of rpcs and discovers the region and
// region client associated with each. A map is returned with rpcs
// grouped by their region client. If any error is encountered, the
// corresponding slot in res will be updated with that error and a
// BatchError is returned.
//
// findClients will not return on the first error encountered. It
// will iterate through all the RPCs to ensure that all unknown
// regions encountered in the batch will start being initialized.
//
// res must be indexed in parallel with batch: a lookup failure for
// batch[i] is recorded in res[i].Error.
func (c *client) findClients(ctx context.Context, batch []hrpc.Call, res []hrpc.RPCResult) (
	map[hrpc.RegionClient][]hrpc.Call, error) {

	rpcByClient := make(map[hrpc.RegionClient][]hrpc.Call)
	var err error

	for i, rpc := range batch {
		// The per-RPC error gets its own name. Declaring it as err with
		// := would shadow the function-level err, so the BatchError
		// assignment below would be lost and nil would always be
		// returned.
		regClient, lookupErr := c.getRegionAndClientForRPC(ctx, rpc)
		if lookupErr != nil {
			res[i].Error = lookupErr
			err = BatchError
			continue // see if any more RPCs are missing regions
		}
		rpcByClient[regClient] = append(rpcByClient[regClient], rpc)
	}
	return rpcByClient, err
}

// waitForCompletion collects the result of every RPC in rpcs, in
// order, and stores each one in results at the index that rpcToRes
// gives for that RPC. Any result carrying an error is passed to
// handleResultError together with the RPC's region and rc.
//
// If ctx is done before all results have arrived, waiting stops. The
// RPC being waited on and every RPC after it get one non-blocking
// check for an already available result; those without one have
// their result's Error set to ctx.Err().
//
// It returns BatchError if any result had an error or if ctx was done
// before all results were collected, and nil otherwise.
func (c *client) waitForCompletion(ctx context.Context, rc hrpc.RegionClient,
	rpcs []hrpc.Call, results []hrpc.RPCResult, rpcToRes map[hrpc.Call]int) error {

	var err error
	cancelledIndex := -1
wait:
	for i, rpc := range rpcs {
		select {
		case res := <-rpc.ResultChan():
			results[rpcToRes[rpc]] = res
			if res.Error != nil {
				c.handleResultError(res.Error, rpc.Region(), rc)
				err = BatchError
			}
		case <-ctx.Done():
			cancelledIndex = i
			err = BatchError
			// A bare break would only leave the select and the loop
			// would keep running. cancelledIndex would then move
			// forward, and the pass below could skip RPCs or overwrite
			// results that were already received.
			break wait
		}
	}
	if cancelledIndex == -1 {
		return err
	}

	// The context has been cancelled. Do a non-blocking check if
	// a result is ready on each RPC, otherwise mark it with the
	// context error. Everything before cancelledIndex already has
	// its result recorded.
	for _, rpc := range rpcs[cancelledIndex:] {
		select {
		case res := <-rpc.ResultChan():
			results[rpcToRes[rpc]] = res
			if res.Error != nil {
				c.handleResultError(res.Error, rpc.Region(), rc)
			}
		default:
			results[rpcToRes[rpc]].Error = ctx.Err()
		}
	}

	return err
}

func (c *client) handleResultError(err error, reg hrpc.RegionInfo, rc hrpc.RegionClient) {
// Check for errors
switch err.(type) {
Expand Down
101 changes: 101 additions & 0 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -819,3 +820,103 @@ func BenchmarkPrometheusWith(b *testing.B) {
})
}
}

// TestSendBatch is a placeholder: it currently has no body and so
// asserts nothing. The list below records the scenarios it is meant
// to cover once implemented.
func TestSendBatch(t *testing.T) {
	// TODO: Test cases:
	// - getRegionForRpc returns an error
	// - Client for region is nil
	// - waitForCompletion returns errors from one client but not another
	// - waitForCompletion returns errors for all clients
	// - some results are returned while context is canceled
	// - success case - verify no errors set
	// - ensure no data races on writing results
}

// TestSendBatchBadInput feeds SendBatch batches that must be rejected
// during input validation (duplicate Calls, mixed tables, Calls that
// are not batchable) and checks that BatchError is returned and that
// every result carries the expected per-Call error.
func TestSendBatchBadInput(t *testing.T) {
	ctrl := test.NewController(t)
	defer ctrl.Finish()

	zkClient := mockZk.NewMockClient(ctrl)
	zkClient.EXPECT().LocateResource(zk.Meta).Return("regionserver:1", nil).AnyTimes()
	c := newMockClient(zkClient)

	// put builds a Put for table, used here as the batchable kind of
	// Call. The test is aborted if construction fails.
	put := func(table string) hrpc.Call {
		p, err := hrpc.NewPutStr(context.Background(), table, "key",
			map[string]map[string][]byte{"cf": {"foo": []byte("bar")}})
		if err != nil {
			t.Fatal(err)
		}
		return p
	}
	// scan builds a Scan for table, used here as the non-batchable
	// kind of Call. The test is aborted if construction fails.
	scan := func(table string) hrpc.Call {
		s, err := hrpc.NewScanStr(context.Background(), table)
		if err != nil {
			t.Fatal(err)
		}
		return s
	}

	// Two distinct Calls that are reused across cases to provoke the
	// duplicate detection.
	first := put("table")
	second := put("table")

	notExecuted := NotExecutedError.Error()

	type badInput struct {
		name   string
		batch  []hrpc.Call
		expErr []string // substring expected in the i'th result's error
	}
	cases := []badInput{
		{
			name:   "duplicate",
			batch:  []hrpc.Call{first, first},
			expErr: []string{notExecuted, "duplicate call"},
		},
		{
			name:  "duplicate2",
			batch: []hrpc.Call{first, second, second, put("table"), first},
			expErr: []string{
				notExecuted,
				notExecuted,
				"duplicate call",
				notExecuted,
				"duplicate call",
			},
		},
		{
			name:   "tables",
			batch:  []hrpc.Call{put("table"), put("different_table")},
			expErr: []string{notExecuted, "multiple tables"},
		},
		{
			name:   "batchable",
			batch:  []hrpc.Call{scan("table")},
			expErr: []string{"non-batchable"},
		},
		{
			name: "various errors",
			batch: []hrpc.Call{
				first,
				scan("table"),
				first,
				put("table2"),
				put("table"),
			},
			expErr: []string{
				notExecuted,
				"non-batchable",
				"duplicate call",
				"multiple tables",
				notExecuted,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			nBatch, nExp := len(tc.batch), len(tc.expErr)
			if nBatch != nExp {
				t.Fatalf("test case provides mismatched batch (%d) and expErr (%d) sizes",
					nBatch, nExp)
			}

			results, err := c.SendBatch(context.Background(), tc.batch)
			if err != BatchError {
				t.Errorf("expected BatchError from SendBatch, got %v", err)
			}
			if len(results) != nBatch {
				t.Fatalf("result size (%d) does not match batch size (%d)",
					len(results), nBatch)
			}
			for i := range results {
				got := results[i].Error
				if got == nil {
					t.Errorf("expected error in res[%d], but got nil for request %v",
						i, tc.batch[i])
					continue
				}
				if want := tc.expErr[i]; !strings.Contains(got.Error(), want) {
					t.Errorf("expected error to contain %q, but got %q", want, got)
				}
			}
		})
	}
}

0 comments on commit 19eebe4

Please sign in to comment.