Showing 2 changed files with 398 additions and 0 deletions.
@@ -0,0 +1,218 @@
// Copyright (C) 2024 The GoHBase Authors. All rights reserved.
// This file is part of GoHBase.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.

package region

import (
	"bufio"
	"io"
	"sync"

	"github.com/tsuna/gohbase/hrpc"
	"google.golang.org/protobuf/proto"
)

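// congestionControl bounds the number of outstanding RPCs on a region
// client connection, growing the send window by one on each success
// and halving it on retryable errors, similar to TCP's
// additive-increase/multiplicative-decrease scheme.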
type congestionControl struct {
	c *client

	// retry allows the readLoop to send failed RPCs back to the
	// write loop for retrying. They are prioritized over new
	// requests.
	retry chan hrpc.Call

	// sendWindow is the dynamic limit on the number of outstanding
	// requests.
	sendWindow int
	minWindow  int
	maxWindow  int
	// sema limits the total number of outstanding requests. Before
	// sending a request, a token must be taken from the sema; if no
	// token is available, the sender blocks. A token is returned to
	// the sema for each response that is received. Tokens are taken
	// in writeLoop and returned in readLoop.
	sema *semaphore

	// IO functions, overridable for testing
	trySend func(hrpc.Call) error
	receive func(io.Reader) (hrpc.Call, proto.Message, error)
}

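// newCongestion wraps client c with congestion control. The send
// window starts at half of maxWindowSize and is kept within
// [minWindowSize, maxWindowSize].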
func newCongestion(c *client, minWindowSize, maxWindowSize int) *congestionControl {
	cc := &congestionControl{
		c: c,
		// retry is buffered to maxWindowSize to make it unlikely
		// that it fills
		retry: make(chan hrpc.Call, maxWindowSize),

		sendWindow: maxWindowSize / 2,
		minWindow:  minWindowSize,
		maxWindow:  maxWindowSize,
		sema:       newSemaphore(maxWindowSize/2, maxWindowSize),

		trySend: c.trySend,
		receive: c.receive,
	}
	return cc
}

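// run starts the write and read loops and blocks until both exit.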
func (cc *congestionControl) run(req <-chan hrpc.Call, done <-chan struct{}) {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		cc.writeLoop(req, done)
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		cc.readLoop()
	}()
	wg.Wait()
}

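// send takes a token from the sema, blocking if none are available,
// and then attempts to send the RPC. A ServerError from trySend is
// fatal to the client and is returned; for other errors the result is
// delivered to the RPC's caller and the token is returned to the
// sema.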
func (cc *congestionControl) send(rpc hrpc.Call) error {
	cc.sema.take1()
	if err := cc.trySend(rpc); err != nil {
		returnResult(rpc, nil, err)
		if _, ok := err.(ServerError); ok {
			// trySend will fail the client
			return err
		}
		cc.sema.release1()
	}
	return nil
}

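// writeLoop sends incoming and retried RPCs until done is closed or a
// send fails the client. The retry channel is drained first so that
// retries are prioritized over new requests.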
func (cc *congestionControl) writeLoop(req <-chan hrpc.Call, done <-chan struct{}) {
	for {
		// Prioritize retry requests
		select {
		case rpc := <-cc.retry:
			if err := cc.send(rpc); err != nil {
				return
			}
		case <-done:
			return
		default:
		}

		select {
		case rpc := <-req:
			if err := cc.send(rpc); err != nil {
				return
			}
		case rpc := <-cc.retry:
			if err := cc.send(rpc); err != nil {
				return
			}
		case <-done:
			return
		}
	}
}

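// read receives a single response. On a RetryableError it halves the
// send window (not going below minWindow) and requeues the RPC on the
// retry channel; on success or an error unrelated to congestion it
// grows the window by one, up to maxWindow. A ServerError is returned
// so the caller can fail the client.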
func (cc *congestionControl) read(r io.Reader) error {
	// TODO: Requests with priority shouldn't affect the sema or send window
	rpc, resp, err := cc.receive(r)
	if err != nil {
		if _, ok := err.(ServerError); ok {
			returnResult(rpc, resp, err)
			return err
		}
		if _, ok := err.(RetryableError); ok {
			newSendWindow := cc.sendWindow / 2
			if newSendWindow < cc.minWindow {
				newSendWindow = cc.minWindow
			}
			if newSendWindow != cc.sendWindow {
				// Get the semaphore in line with our window size by
				// adding the difference between the newSendWindow and
				// the existing sendWindow.
				cc.sema.add(newSendWindow - cc.sendWindow)
				cc.sendWindow = newSendWindow
			}
			cc.sema.release1()
			// Prioritize this request by putting it on cc.retry.
			select {
			case cc.retry <- rpc:
			default:
				// read cannot block, or else it may result in a
				// deadlock between the writer and reader. If cc.retry
				// doesn't have room for this request, return the
				// result to the caller to retry.
				returnResult(rpc, resp, err)
			}
			return nil
		}
		// other errors will be returned to the client to be handled
	}
	returnResult(rpc, resp, err)

	// Request succeeded or hit an error unrelated to congestion, so
	// expand sendWindow.
	cc.sendWindow = min(cc.sendWindow+1, cc.maxWindow)
	// Increase the sema by 2: 1 for the received response and 1
	// because the window size has increased.
	cc.sema.add(2)
	return nil
}

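// readLoop reads responses until the connection fails. On exit it
// adds maxWindow*2 tokens so that a writeLoop goroutine blocked in
// take1 is woken even if the token count had gone negative.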
func (cc *congestionControl) readLoop() {
	defer cc.sema.add(cc.maxWindow * 2) // unblock a possible waiter in writeLoop
	r := bufio.NewReader(cc.c.conn)
	for {
		if err := cc.read(r); err != nil {
			// fail the client and let the callers establish a new one
			cc.c.fail(err)
			return
		}
	}
}

// semaphore implements a semaphore using a condition variable and
// mutex. This is more commonly implemented in Go with a buffered
// channel, but a buffered channel doesn't allow consuming multiple
// tokens at once or going negative in the number of consumed tokens,
// which this semaphore does. These abilities make it simpler to
// reduce the window size after an error.
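// For example, read calls add with a negative value to debit several
// tokens at once when the window halves, driving v below zero until
// enough responses have been read back.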
type semaphore struct {
	c sync.Cond
	m sync.Mutex
	v int

	max int
}

func newSemaphore(v int, max int) *semaphore {
	s := &semaphore{v: v, max: max}
	s.c.L = &s.m
	return s
}

// take1 waits for v to be positive and then subtracts 1.
func (s *semaphore) take1() {
	s.m.Lock()
	defer s.m.Unlock()
	for s.v <= 0 {
		s.c.Wait()
	}
	s.v--
}

func (s *semaphore) release1() {
	s.add(1)
}

func (s *semaphore) add(v int) {
	s.m.Lock()
	defer s.m.Unlock()
	prev := s.v
	s.v = min(s.v+v, s.max)
	if prev <= 0 && s.v > 0 {
		// because we have a max of one waiter (the writeLoop
		// goroutine) we just need to signal once when s.v becomes
		// positive.
		s.c.Signal()
	}
}
@@ -0,0 +1,180 @@
// Copyright (C) 2024 The GoHBase Authors. All rights reserved.
// This file is part of GoHBase.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.

package region

import (
	"errors"
	"io"
	"testing"

	"github.com/tsuna/gohbase/hrpc"
	"google.golang.org/protobuf/proto"
)

type resp struct {
	rpc hrpc.Call
	err error
}

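// fakeRegionServer stands in for the remote side of a region client:
// trySend publishes outgoing RPCs on reqs, and receive returns canned
// responses from resps, letting the test drive both loops
// deterministically.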
type fakeRegionServer struct {
	reqs  chan hrpc.Call
	resps chan resp
}

func (rs *fakeRegionServer) trySend(rpc hrpc.Call) error {
	rs.reqs <- rpc
	return nil
}

func (rs *fakeRegionServer) receive(io.Reader) (hrpc.Call, proto.Message, error) {
	resp := <-rs.resps
	return resp.rpc, nil, resp.err
}

func newRPC() hrpc.Call {
	get, err := hrpc.NewGet(nil, nil, nil)
	if err != nil {
		panic(err)
	}
	return get
}

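// TestCongestion steps the congestion controller through additive
// growth on successful responses and multiplicative decrease on
// retryable errors, checking sendWindow and the semaphore count at
// each step.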
func TestCongestion(t *testing.T) {
	rs := fakeRegionServer{
		reqs:  make(chan hrpc.Call, 1),
		resps: make(chan resp, 1),
	}
	cc := newCongestion(nil, 1, 10)
	cc.trySend = rs.trySend
	cc.receive = rs.receive

	if cc.sendWindow != 5 || cc.sema.v != 5 {
		t.Fatalf("unexpected starting values: sendWindow: %d sema.v: %d", cc.sendWindow, cc.sema.v)
	}

	var rpcs []hrpc.Call
	for i := range 5 {
		rpc := newRPC()
		if err := cc.send(rpc); err != nil {
			t.Fatalf("unexpected error from send: %s", err)
		}
		if cc.sema.v != 5-(i+1) {
			t.Errorf("expected sema to be at %d, but got: %d", 5-(i+1), cc.sema.v)
		}
		rpcs = append(rpcs, <-rs.reqs)
	}

	// The next send should block, so do it in a goroutine
	go func() {
		rpc := newRPC()
		if err := cc.send(rpc); err != nil {
			t.Errorf("unexpected error from send: %s", err)
		}
	}()

	select {
	case rpc := <-rs.reqs:
		t.Errorf("send should have been blocked, but an rpc was sent: %v", rpc)
	default:
	}

	// receive a response, which should unblock our goroutine above
	rs.resps <- resp{rpcs[0], nil}
	if err := cc.read(nil); err != nil {
		t.Errorf("unexpected error from receive: %s", err)
	}
	if res := <-rpcs[0].ResultChan(); res.Error != nil {
		t.Errorf("unexpected error on response: %s", res.Error)
	}
	if cc.sendWindow != 6 {
		t.Errorf("expected sendWindow to expand to 6, got %d", cc.sendWindow)
	}
	rpcs = rpcs[1:]
	// accept send from above goroutine
	rpcs = append(rpcs, <-rs.reqs)
	// sema was at 0, then we received a successful response, which
	// should have added 2, and then sent another request, so it
	// should be 1 now.
	if cc.sema.v != 1 {
		t.Errorf("expected sema to be at 1, but it's %d", cc.sema.v)
	}
	for i, rpc := range rpcs {
		rs.resps <- resp{rpc, nil}
		if err := cc.read(nil); err != nil {
			t.Errorf("unexpected error from receive: %s", err)
		}
		if res := <-rpc.ResultChan(); res.Error != nil {
			t.Errorf("unexpected error on response: %s", res.Error)
		}
		// sendWindow should be increasing by one on each successful receive
		if cc.sendWindow != min(6+(i+1), 10) {
			t.Errorf("expected sendWindow to be %d, got %d", min(6+(i+1), 10), cc.sendWindow)
		}
		// sema.v should be increasing by two on each successful receive
		if cc.sema.v != min(1+(i+1)*2, 10) {
			t.Errorf("expected sema.v to be %d, got %d", min(1+(i+1)*2, 10), cc.sema.v)
		}
	}

	rpcs = rpcs[:0]

	// Send 10 requests to fill the send window and then fail them all
	for i := range 10 {
		rpc := newRPC()
		if err := cc.send(rpc); err != nil {
			t.Fatalf("unexpected error from send: %s", err)
		}
		if cc.sema.v != 10-(i+1) {
			t.Errorf("expected sema to be at %d, but got: %d", 10-(i+1), cc.sema.v)
		}
		rpcs = append(rpcs, <-rs.reqs)
	}
	// sendWindow should halve on each failure
	expectedSendWindow := []int{5, 2, 1, 1, 1, 1, 1, 1, 1, 1}
	// sema will be decreasing by the change in sendWindow, but
	// also increasing by 1 for each rpc read.
	expectedSema := []int{-4, -6, -6, -5, -4, -3, -2, -1, 0, 1}
	for i, rpc := range rpcs {
		rs.resps <- resp{rpc, RetryableError{}}
		if err := cc.read(nil); err != nil {
			t.Errorf("unexpected error from receive: %s", err)
		}
		if cc.sendWindow != expectedSendWindow[i] {
			t.Errorf("expected sendWindow to be %d, got %d", expectedSendWindow[i], cc.sendWindow)
		}
		if cc.sema.v != expectedSema[i] {
			t.Errorf("expected sema.v to be %d, got %d", expectedSema[i], cc.sema.v)
		}
	}

	// Send one more RPC and fail it
	if err := cc.send(newRPC()); err != nil {
		t.Fatalf("unexpected error from send: %s", err)
	}
	rpc := <-rs.reqs
	rs.resps <- resp{rpc, RetryableError{}}
	if err := cc.read(nil); err != nil {
		t.Errorf("unexpected error from receive: %s", err)
	}
	// read should not have blocked sending to the retry channel and
	// instead returned the error to the caller:
	if res := <-rpc.ResultChan(); !errors.Is(res.Error, RetryableError{}) {
		t.Errorf("expected to see Retryable error result on rpc, got: %v", res)
	}

	// All the failed rpcs should be on the retry channel
	for i, rpc := range rpcs {
		if got := <-cc.retry; got != rpc {
			t.Errorf("unexpected rpc %d on retry chan, exp: %p got: %p", i, rpc, got)
		}
	}
}

// func TestRetries(t *testing.T) {
// 	// create connection that fails every other request. Send 200
// 	// requests with maxWindowSize 100 and verify they all make it
// 	// through.
// }