From 19eebe4b962110fcf89efb58008064fb0d65152b Mon Sep 17 00:00:00 2001
From: Aaron Beitch
Date: Thu, 6 Oct 2022 09:13:39 -0700
Subject: [PATCH] Implement SendBatch

---
 rpc.go      | 189 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 rpc_test.go | 101 ++++++++++++++++++++++++++++
 2 files changed, 290 insertions(+)

diff --git a/rpc.go b/rpc.go
index a6a58f49..cb3bc026 100644
--- a/rpc.go
+++ b/rpc.go
@@ -13,6 +13,7 @@ import (
 	"io"
 	"math"
 	"strconv"
+	"sync"
 	"time"
 
 	log "github.com/sirupsen/logrus"
@@ -158,6 +159,194 @@ func (c *client) getRegionAndClientForRPC(ctx context.Context, rpc hrpc.Call) (
 	}
 }
 
+var (
+	// BatchError is the error returned when some or all RPCs in a
+	// SendBatch fail; the caller should inspect the individual
+	// results for more information.
+	BatchError = errors.New(
+		"errors executing batch. Inspect the individual results for more information")
+
+	// NotExecutedError is returned when an RPC in a batch is not
+	// executed due to encountering a different error in the batch.
+	NotExecutedError = errors.New(
+		"RPC in batch not executed due to another error")
+)
+
+// SendBatch will execute all the Calls in batch. Every Call must have
+// the same table and must be Batchable.
+//
+// SendBatch will discover the correct region and region server for
+// each Call and dispatch the Calls accordingly. SendBatch is not an
+// atomic operation: some Calls may fail while others succeed. Calls
+// sharing a region will execute in the order passed into SendBatch.
+//
+// SendBatch returns a slice of results. The i'th result will be for
+// the i'th call in the passed-in batch. A non-nil error will be
+// returned for any error that affects the whole batch or if any of
+// the individual requests causes an error. Check the Error field of
+// the returned RPCResults to find the per-Call errors.
+//
+// TODO: If retryable errors are encountered on a subset of the Calls
+// in the batch, the caller is required to filter out the successful
+// Calls and retry with the failed Calls. Should we handle that here
+// instead?
+func (c *client) SendBatch(ctx context.Context, batch []hrpc.Call) (
+	res []hrpc.RPCResult, err error) {
+	if len(batch) == 0 {
+		return nil, nil
+	}
+
+	start := time.Now()
+	description := "SendBatch"
+	ctx, sp := observability.StartSpan(ctx, description)
+	defer func() {
+		result := "ok"
+		if err != nil {
+			result = "error"
+			sp.SetStatus(codes.Error, err.Error())
+		}
+
+		o := operationDurationSeconds.WithLabelValues(description, result)
+
+		observability.ObserveWithTrace(ctx, o, time.Since(start).Seconds())
+		sp.End()
+	}()
+
+	table := batch[0].Table()
+	res = make([]hrpc.RPCResult, len(batch))
+	rpcToRes := make(map[hrpc.Call]int, len(batch))
+	for i, rpc := range batch {
+		// Map each Call to its index in res so that we can set the
+		// correct result as Calls complete.
+		if j, ok := rpcToRes[rpc]; ok {
+			res[i].Error = fmt.Errorf("duplicate call in batch at index %d", j)
+			err = BatchError
+			continue
+		}
+		rpcToRes[rpc] = i
+
+		// Initialize res with NotExecutedError. As RPCs are executed this
+		// will be replaced by a more specific error, or nil if no error
+		// occurs.
+		res[i].Error = NotExecutedError
+
+		if !bytes.Equal(rpc.Table(), table) {
+			res[i].Error = fmt.Errorf("multiple tables in batch request: %q and %q",
+				string(table), string(rpc.Table()))
+			err = BatchError
+		} else if b, ok := rpc.(hrpc.Batchable); !ok || b.SkipBatch() {
+			res[i].Error = errors.New("non-batchable call passed to SendBatch")
+			err = BatchError
+		}
+	}
+	if err != nil {
+		return res, err
+	}
+
+	rpcByClient, err := c.findClients(ctx, batch, res)
+	if err != nil {
+		return res, err
+	}
+
+	// Send each group of RPCs to its region client to be executed.
+	var wg sync.WaitGroup
+	var errOnce sync.Once
+	wg.Add(len(rpcByClient))
+	for client, rpcs := range rpcByClient {
+		// TODO: Move this to the RegionClient interface so we don't
+		// need to type assert here.
+		qb := client.(interface {
+			QueueBatch(ctx context.Context, rpcs []hrpc.Call)
+		})
+		go func(client hrpc.RegionClient, rpcs []hrpc.Call) {
+			defer wg.Done()
+			qb.QueueBatch(ctx, rpcs)
+			ctx, sp := observability.StartSpan(ctx, "waitForResult")
+			defer sp.End()
+			err1 := c.waitForCompletion(ctx, client, rpcs, res, rpcToRes)
+			if err1 != nil {
+				// The first error encountered (BatchError, if
+				// non-nil) will be returned by SendBatch.
+				errOnce.Do(func() { err = err1 })
+			}
+		}(client, rpcs)
+	}
+
+	wg.Wait()
+	return res, err
+}
+
+// findClients takes a batch of rpcs and discovers the region and
+// region client associated with each. A map is returned with the
+// rpcs grouped by their region client. If any error is encountered,
+// the corresponding slot in res is updated with that error and a
+// BatchError is returned.
+//
+// findClients will not return on the first error encountered. It
+// iterates through all the RPCs to ensure that every unknown region
+// encountered in the batch starts being initialized.
+func (c *client) findClients(ctx context.Context, batch []hrpc.Call, res []hrpc.RPCResult) (
+	map[hrpc.RegionClient][]hrpc.Call, error) {
+
+	rpcByClient := make(map[hrpc.RegionClient][]hrpc.Call)
+	var err error
+
+	for i, rpc := range batch {
+		// rpcErr is deliberately not named err: declaring it with :=
+		// would shadow err and make findClients always return nil.
+		regClient, rpcErr := c.getRegionAndClientForRPC(ctx, rpc)
+		if rpcErr != nil {
+			res[i].Error = rpcErr
+			err = BatchError
+			continue // see if any more RPCs are missing regions
+		}
+		rpcByClient[regClient] = append(rpcByClient[regClient], rpc)
+	}
+	return rpcByClient, err
+}
+
+func (c *client) waitForCompletion(ctx context.Context, rc hrpc.RegionClient,
+	rpcs []hrpc.Call, results []hrpc.RPCResult, rpcToRes map[hrpc.Call]int) error {
+
+	var err error
+	cancelledIndex := -1
+loop:
+	for i, rpc := range rpcs {
+		select {
+		case res := <-rpc.ResultChan():
+			results[rpcToRes[rpc]] = res
+			if res.Error != nil {
+				c.handleResultError(res.Error, rpc.Region(), rc)
+				err = BatchError
+			}
+		case <-ctx.Done():
+			cancelledIndex = i
+			err = BatchError
+			// A bare break would only exit the select, not the loop.
+			break loop
+		}
+	}
+	if cancelledIndex == -1 {
+		return err
+	}
+
+	// The context has been cancelled. Do a non-blocking check for a
+	// ready result on each remaining RPC; otherwise mark it with the
+	// context error.
+	for _, rpc := range rpcs[cancelledIndex:] {
+		select {
+		case res := <-rpc.ResultChan():
+			results[rpcToRes[rpc]] = res
+			if res.Error != nil {
+				c.handleResultError(res.Error, rpc.Region(), rc)
+			}
+		default:
+			results[rpcToRes[rpc]].Error = ctx.Err()
+		}
+	}
+
+	return err
+}
+
 func (c *client) handleResultError(err error, reg hrpc.RegionInfo, rc hrpc.RegionClient) {
 	// Check for errors
 	switch err.(type) {
diff --git a/rpc_test.go b/rpc_test.go
index 1dbd68fe..6e19e9f2 100644
--- a/rpc_test.go
+++ b/rpc_test.go
@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"reflect"
 	"strconv"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -819,3 +820,103 @@ func BenchmarkPrometheusWith(b *testing.B) {
 		})
 	}
 }
+
+func TestSendBatch(t *testing.T) {
+	// TODO: Test cases:
+	// - getRegionAndClientForRPC returns an error
+	// - Client for region is nil
+	// - waitForCompletion returns errors from one client but not another
+	// - waitForCompletion returns errors for all clients
+	// - some results are returned while the context is cancelled
+	// - success case: verify no errors are set
+	// - ensure there are no data races when writing results
+}
+
+func TestSendBatchBadInput(t *testing.T) {
+	ctrl := test.NewController(t)
+	defer ctrl.Finish()
+
+	zkc := mockZk.NewMockClient(ctrl)
+	zkc.EXPECT().LocateResource(zk.Meta).Return("regionserver:1", nil).AnyTimes()
+	c := newMockClient(zkc)
+
+	newRPC := func(table string, batchable bool) hrpc.Call {
+		if batchable {
+			rpc, err := hrpc.NewPutStr(context.Background(), table, "key",
+				map[string]map[string][]byte{"cf": {"foo": []byte("bar")}})
+			if err != nil {
+				t.Fatal(err)
+			}
+			return rpc
+		}
+		rpc, err := hrpc.NewScanStr(context.Background(), table)
+		if err != nil {
+			t.Fatal(err)
+		}
+		return rpc
+	}
+
+	rpc1 := newRPC("table", true)
+	rpc2 := newRPC("table", true)
+
+	for _, tc := range []struct {
+		name   string
+		batch  []hrpc.Call
+		expErr []string
+	}{{
+		name:   "duplicate",
+		batch:  []hrpc.Call{rpc1, rpc1},
+		expErr: []string{NotExecutedError.Error(), "duplicate call"},
+	}, {
+		name:  "duplicate2",
+		batch: []hrpc.Call{rpc1, rpc2, rpc2, newRPC("table", true), rpc1},
+		expErr: []string{NotExecutedError.Error(), NotExecutedError.Error(),
+			"duplicate call", NotExecutedError.Error(), "duplicate call"},
+	}, {
+		name:   "tables",
+		batch:  []hrpc.Call{newRPC("table", true), newRPC("different_table", true)},
+		expErr: []string{NotExecutedError.Error(), "multiple tables"},
+	}, {
+		name:   "batchable",
+		batch:  []hrpc.Call{newRPC("table", false)},
+		expErr: []string{"non-batchable"},
+	}, {
+		name: "various errors",
+		batch: []hrpc.Call{rpc1,
+			newRPC("table", false),
+			rpc1,
+			newRPC("table2", true),
+			newRPC("table", true)},
+		expErr: []string{NotExecutedError.Error(),
+			"non-batchable",
+			"duplicate call",
+			"multiple tables",
+			NotExecutedError.Error()},
+	}} {
+		t.Run(tc.name, func(t *testing.T) {
+			if len(tc.batch) != len(tc.expErr) {
+				t.Fatalf("test case provides mismatched batch (%d) and expErr (%d) sizes",
+					len(tc.batch), len(tc.expErr))
+			}
+
+			results, err := c.SendBatch(context.Background(), tc.batch)
+			if err != BatchError {
+				t.Errorf("expected BatchError from SendBatch, got %v", err)
+			}
+			if len(results) != len(tc.batch) {
+				t.Fatalf("result size (%d) does not match batch size (%d)",
+					len(results), len(tc.batch))
+			}
+			for i, res := range results {
+				if res.Error == nil {
+					t.Errorf("expected error in res[%d], but got nil for request %v",
+						i, tc.batch[i])
+					continue
+				}
+				if !strings.Contains(res.Error.Error(), tc.expErr[i]) {
+					t.Errorf("expected error to contain %q, but got %q", tc.expErr[i],
res.Error) + } + } + }) + } +}
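
--
Usage sketch (reviewer note, not part of the patch): one way a caller
could drive SendBatch, including the caller-side retry that the TODO
on SendBatch currently leaves to the caller. Only SendBatch,
hrpc.RPCResult, BatchError, and NotExecutedError come from this
change. The table and column names and the isRetryable predicate are
hypothetical, and this assumes SendBatch is exposed on the
gohbase.Client interface and that a completed Call can be
resubmitted.

	// sendRows issues one Put per row as a single batch, then retries
	// the calls that failed with an error the (hypothetical)
	// isRetryable predicate considers transient.
	func sendRows(ctx context.Context, c gohbase.Client, rows map[string][]byte) error {
		batch := make([]hrpc.Call, 0, len(rows))
		for key, value := range rows {
			p, err := hrpc.NewPutStr(ctx, "mytable", key,
				map[string]map[string][]byte{"cf": {"col": value}})
			if err != nil {
				return err
			}
			batch = append(batch, p)
		}

		results, err := c.SendBatch(ctx, batch)
		if err == nil {
			return nil
		}
		// A non-nil error means at least one call failed. Inspect the
		// per-Call errors and, per the TODO in SendBatch, retry the
		// failed calls ourselves.
		var retry []hrpc.Call
		for i, res := range results {
			if res.Error != nil && isRetryable(res.Error) {
				retry = append(retry, batch[i])
			}
		}
		if len(retry) == 0 {
			return err
		}
		_, err = c.SendBatch(ctx, retry)
		return err
	}

Because the results slice is index-aligned with the input batch, the
failed Calls can be collected for retry without any extra bookkeeping
on the caller's side.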