Skip to content

Commit

Permalink
Add getRegionAndClientForRPC
Browse files Browse the repository at this point in the history
This function can be used by both SendRPC and SendBatch to associate a
region and region client with an RPC.
  • Loading branch information
aaronbee committed Jun 23, 2023
1 parent c9faad0 commit 88eba23
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 51 deletions.
86 changes: 50 additions & 36 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,44 +86,75 @@ func (c *client) SendRPC(rpc hrpc.Call) (msg proto.Message, err error) {
sp.End()
}()

reg, err := c.getRegionForRpc(ctx, rpc)
if err != nil {
return nil, err
}

backoff := backoffStart
for {
msg, err = c.sendRPCToRegion(ctx, rpc, reg)
regClient, err := c.getRegionAndClientForRPC(ctx, rpc)
if err != nil {
return nil, err
}
msg, err = c.sendRPCToRegion(ctx, rpc, regClient)
switch err.(type) {
case region.RetryableError:
sp.AddEvent("retrySleep")
backoff, err = sleepAndIncreaseBackoff(ctx, backoff)
if err != nil {
return msg, err
}
continue // retry
case region.ServerError, region.NotServingRegionError:
continue // retry
}
return msg, err
}
}

func (c *client) getRegionAndClientForRPC(ctx context.Context, rpc hrpc.Call) (
hrpc.RegionClient, error) {
for {
reg, err := c.getRegionForRpc(ctx, rpc)
if err != nil {
return nil, err
}
if ch := reg.AvailabilityChan(); ch != nil { // region is currently unavailable
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.done:
return nil, ErrClientClosed
case <-ch:
}
}

client := reg.Client()
if client == nil {
// There was an error getting the region client. Mark the
// region as unavailable.
if reg.MarkUnavailable() {
// If this was the first goroutine to mark the region as
// unavailable, start a goroutine to reestablish a connection
go c.reestablishRegion(reg)
}
if ch := reg.AvailabilityChan(); ch != nil {
// The region is unavailable. Wait for it to become available,
// a new region or for the deadline to be exceeded.
select {
case <-ctx.Done():
return nil, rpc.Context().Err()
return nil, ctx.Err()
case <-c.done:
return nil, ErrClientClosed
case <-ch:
}
}
if reg.Context().Err() != nil {
// region is dead because it was split or merged,
// lookup a new one and retry
reg, err = c.getRegionForRpc(ctx, rpc)
if err != nil {
return nil, err
}
// retry lookup
continue
}
client = reg.Client()
if client == nil {
continue
}
default:
return msg, err
}
rpc.SetRegion(reg)
return client, nil
}
}

Expand Down Expand Up @@ -171,31 +202,14 @@ func sendBlocking(ctx context.Context, rc hrpc.RegionClient, rpc hrpc.Call) (
}
}

func (c *client) sendRPCToRegion(ctx context.Context, rpc hrpc.Call, reg hrpc.RegionInfo) (
func (c *client) sendRPCToRegion(ctx context.Context, rpc hrpc.Call, regClient hrpc.RegionClient) (
proto.Message, error) {
if reg.IsUnavailable() {
return nil, region.NotServingRegionError{}
}
rpc.SetRegion(reg)

// Queue the RPC to be sent to the region
client := reg.Client()
if client == nil {
// There was an error queueing the RPC.
// Mark the region as unavailable.
if reg.MarkUnavailable() {
// If this was the first goroutine to mark the region as
// unavailable, start a goroutine to reestablish a connection
go c.reestablishRegion(reg)
}
return nil, region.NotServingRegionError{}
}
res, err := sendBlocking(ctx, client, rpc)
res, err := sendBlocking(ctx, regClient, rpc)
if err != nil {
return nil, err
}
if res.Error != nil {
c.handleResultError(res.Error, reg, client)
c.handleResultError(res.Error, rpc.Region(), regClient)
}
return res.Msg, res.Error
}
Expand Down
24 changes: 9 additions & 15 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ func TestSendRPCToRegionClientDownDelayed(t *testing.T) {
origlReg.SetClient(rc)

mockCall := mock.NewMockCall(ctrl)
mockCall.EXPECT().SetRegion(origlReg).Times(1)
mockCall.EXPECT().Region().Return(origlReg).Times(1)
result := make(chan hrpc.RPCResult, 1)
mockCall.EXPECT().ResultChan().Return(result).Times(1)

Expand All @@ -464,19 +464,12 @@ func TestSendRPCToRegionClientDownDelayed(t *testing.T) {
result <- hrpc.RPCResult{Error: region.ServerError{}}
})

var wg sync.WaitGroup
wg.Add(1)
go func() {
_, err := c.sendRPCToRegion(context.Background(), mockCall, origlReg)
switch err.(type) {
case region.ServerError, region.NotServingRegionError:
default:
t.Errorf("Got unexpected error: %v", err)
}
wg.Done()
}()

wg.Wait()
_, err := c.sendRPCToRegion(context.Background(), mockCall, rc)
switch err.(type) {
case region.ServerError, region.NotServingRegionError:
default:
t.Errorf("Got unexpected error: %v", err)
}

// check that we did not down new client
if len(c.clients.regions) != 1 {
Expand Down Expand Up @@ -745,6 +738,7 @@ func TestConcurrentRetryableError(t *testing.T) {
for i := range calls {
mockCall := mock.NewMockCall(ctrl)
mockCall.EXPECT().SetRegion(origlReg).AnyTimes()
mockCall.EXPECT().Region().Return(origlReg).AnyTimes()
result := make(chan hrpc.RPCResult, 1)
result <- hrpc.RPCResult{Error: region.NotServingRegionError{}}
mockCall.EXPECT().ResultChan().Return(result).AnyTimes()
Expand All @@ -755,7 +749,7 @@ func TestConcurrentRetryableError(t *testing.T) {
for _, mockCall := range calls {
wg.Add(1)
go func(mockCall hrpc.Call) {
_, err := c.sendRPCToRegion(context.Background(), mockCall, origlReg)
_, err := c.sendRPCToRegion(context.Background(), mockCall, rc)
if _, ok := err.(region.NotServingRegionError); !ok {
t.Errorf("Got unexpected error: %v", err)
}
Expand Down

0 comments on commit 88eba23

Please sign in to comment.