diff --git a/rpc.go b/rpc.go index fbe5e17a..a6a58f49 100644 --- a/rpc.go +++ b/rpc.go @@ -86,14 +86,13 @@ func (c *client) SendRPC(rpc hrpc.Call) (msg proto.Message, err error) { sp.End() }() - reg, err := c.getRegionForRpc(ctx, rpc) - if err != nil { - return nil, err - } - backoff := backoffStart for { - msg, err = c.sendRPCToRegion(ctx, rpc, reg) + regClient, err := c.getRegionAndClientForRPC(ctx, rpc) + if err != nil { + return nil, err + } + msg, err = c.sendRPCToRegion(ctx, rpc, regClient) switch err.(type) { case region.RetryableError: sp.AddEvent("retrySleep") @@ -101,13 +100,44 @@ func (c *client) SendRPC(rpc hrpc.Call) (msg proto.Message, err error) { if err != nil { return msg, err } + continue // retry case region.ServerError, region.NotServingRegionError: + continue // retry + } + return msg, err + } +} + +func (c *client) getRegionAndClientForRPC(ctx context.Context, rpc hrpc.Call) ( + hrpc.RegionClient, error) { + for { + reg, err := c.getRegionForRpc(ctx, rpc) + if err != nil { + return nil, err + } + if ch := reg.AvailabilityChan(); ch != nil { // region is currently unavailable + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.done: + return nil, ErrClientClosed + case <-ch: + } + } + + client := reg.Client() + if client == nil { + // There was an error getting the region client. Mark the + // region as unavailable. + if reg.MarkUnavailable() { + // If this was the first goroutine to mark the region as + // unavailable, start a goroutine to reestablish a connection + go c.reestablishRegion(reg) + } if ch := reg.AvailabilityChan(); ch != nil { - // The region is unavailable. Wait for it to become available, - // a new region or for the deadline to be exceeded. select { case <-ctx.Done(): - return nil, rpc.Context().Err() + return nil, ctx.Err() case <-c.done: return nil, ErrClientClosed case <-ch: @@ -115,15 +145,16 @@ func (c *client) SendRPC(rpc hrpc.Call) (msg proto.Message, err error) { } if reg.Context().Err() != nil { // region is dead because it was split or merged, - // lookup a new one and retry - reg, err = c.getRegionForRpc(ctx, rpc) - if err != nil { - return nil, err - } + // retry lookup + continue + } + client = reg.Client() + if client == nil { + continue } - default: - return msg, err } + rpc.SetRegion(reg) + return client, nil } } @@ -171,31 +202,14 @@ func sendBlocking(ctx context.Context, rc hrpc.RegionClient, rpc hrpc.Call) ( } } -func (c *client) sendRPCToRegion(ctx context.Context, rpc hrpc.Call, reg hrpc.RegionInfo) ( +func (c *client) sendRPCToRegion(ctx context.Context, rpc hrpc.Call, regClient hrpc.RegionClient) ( proto.Message, error) { - if reg.IsUnavailable() { - return nil, region.NotServingRegionError{} - } - rpc.SetRegion(reg) - - // Queue the RPC to be sent to the region - client := reg.Client() - if client == nil { - // There was an error queueing the RPC. - // Mark the region as unavailable. - if reg.MarkUnavailable() { - // If this was the first goroutine to mark the region as - // unavailable, start a goroutine to reestablish a connection - go c.reestablishRegion(reg) - } - return nil, region.NotServingRegionError{} - } - res, err := sendBlocking(ctx, client, rpc) + res, err := sendBlocking(ctx, regClient, rpc) if err != nil { return nil, err } if res.Error != nil { - c.handleResultError(res.Error, reg, client) + c.handleResultError(res.Error, rpc.Region(), regClient) } return res.Msg, res.Error } diff --git a/rpc_test.go b/rpc_test.go index 192aee3e..1dbd68fe 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -443,7 +443,7 @@ func TestSendRPCToRegionClientDownDelayed(t *testing.T) { origlReg.SetClient(rc) mockCall := mock.NewMockCall(ctrl) - mockCall.EXPECT().SetRegion(origlReg).Times(1) + mockCall.EXPECT().Region().Return(origlReg).Times(1) result := make(chan hrpc.RPCResult, 1) mockCall.EXPECT().ResultChan().Return(result).Times(1) @@ -464,19 +464,12 @@ func TestSendRPCToRegionClientDownDelayed(t *testing.T) { result <- hrpc.RPCResult{Error: region.ServerError{}} }) - var wg sync.WaitGroup - wg.Add(1) - go func() { - _, err := c.sendRPCToRegion(context.Background(), mockCall, origlReg) - switch err.(type) { - case region.ServerError, region.NotServingRegionError: - default: - t.Errorf("Got unexpected error: %v", err) - } - wg.Done() - }() - - wg.Wait() + _, err := c.sendRPCToRegion(context.Background(), mockCall, rc) + switch err.(type) { + case region.ServerError, region.NotServingRegionError: + default: + t.Errorf("Got unexpected error: %v", err) + } // check that we did not down new client if len(c.clients.regions) != 1 { @@ -745,6 +738,7 @@ func TestConcurrentRetryableError(t *testing.T) { for i := range calls { mockCall := mock.NewMockCall(ctrl) mockCall.EXPECT().SetRegion(origlReg).AnyTimes() + mockCall.EXPECT().Region().Return(origlReg).AnyTimes() result := make(chan hrpc.RPCResult, 1) result <- hrpc.RPCResult{Error: region.NotServingRegionError{}} mockCall.EXPECT().ResultChan().Return(result).AnyTimes() @@ -755,7 +749,7 @@ func TestConcurrentRetryableError(t *testing.T) { for _, mockCall := range calls { wg.Add(1) go func(mockCall hrpc.Call) { - _, err := c.sendRPCToRegion(context.Background(), mockCall, origlReg) + _, err := c.sendRPCToRegion(context.Background(), mockCall, rc) if _, ok := err.(region.NotServingRegionError); !ok { t.Errorf("Got unexpected error: %v", err) }