wip
Signed-off-by: MyonKeminta <[email protected]>
MyonKeminta committed Aug 2, 2024
1 parent 9af28fc commit 881f542
Showing 4 changed files with 283 additions and 90 deletions.
24 changes: 17 additions & 7 deletions client/metrics.go
@@ -39,13 +39,14 @@ func initAndRegisterMetrics(constLabels prometheus.Labels) {
}

var (
cmdDuration *prometheus.HistogramVec
cmdFailedDuration *prometheus.HistogramVec
requestDuration *prometheus.HistogramVec
tsoBestBatchSize prometheus.Histogram
tsoBatchSize prometheus.Histogram
tsoBatchSendLatency prometheus.Histogram
requestForwarded *prometheus.GaugeVec
cmdDuration *prometheus.HistogramVec
cmdFailedDuration *prometheus.HistogramVec
requestDuration *prometheus.HistogramVec
tsoBestBatchSize prometheus.Histogram
tsoBatchSize prometheus.Histogram
tsoBatchSendLatency prometheus.Histogram
requestForwarded *prometheus.GaugeVec
ongoingRequestCountGauge *prometheus.GaugeVec
)

func initMetrics(constLabels prometheus.Labels) {
@@ -117,6 +118,15 @@ func initMetrics(constLabels prometheus.Labels) {
Help: "The status to indicate if the request is forwarded",
ConstLabels: constLabels,
}, []string{"host", "delegate"})

ongoingRequestCountGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: "pd_client",
Subsystem: "request",
Name: "ongoing_requests_count",
Help: "Current count of ongoing batch tso requests",
ConstLabels: constLabels,
}, []string{"stream"})
}

var (
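This commit adds an ongoingRequestCountGauge labeled by stream. As a rough, hypothetical sketch of how such a per-stream gauge is typically driven (the registration, label value, and sendBatch helper below are illustrative assumptions, not part of this commit), the gauge is raised when a batch is handed to a stream and lowered once the RPC finishes:

package main

import "github.com/prometheus/client_golang/prometheus"

// Hypothetical sketch: mirrors the gauge added in this commit; the helper and
// call sites are illustrative assumptions.
var ongoingRequests = prometheus.NewGaugeVec(prometheus.GaugeOpts{
	Namespace: "pd_client",
	Subsystem: "request",
	Name:      "ongoing_requests_count",
	Help:      "Current count of ongoing batch tso requests",
}, []string{"stream"})

// sendBatch tracks batchSize requests as "ongoing" on streamURL while send runs.
func sendBatch(streamURL string, batchSize int, send func() error) error {
	g := ongoingRequests.WithLabelValues(streamURL)
	g.Add(float64(batchSize))
	defer g.Sub(float64(batchSize))
	return send()
}

func main() {
	_ = sendBatch("pd-1", 8, func() error { return nil })
}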
24 changes: 7 additions & 17 deletions client/tso_batch_controller.go
@@ -30,31 +30,29 @@ type tsoBatchController struct {
// bestBatchSize is a dynamic size that changes based on the current batch effect.
bestBatchSize int

tsoRequestCh chan *tsoRequest
collectedRequests []*tsoRequest
collectedRequestCount int

batchStartTime time.Time
}

func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController {
func newTSOBatchController(maxBatchSize int) *tsoBatchController {
return &tsoBatchController{
maxBatchSize: maxBatchSize,
bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will converge to (current_batch_size - 4) */
tsoRequestCh: tsoRequestCh,
collectedRequests: make([]*tsoRequest, maxBatchSize+1),
collectedRequestCount: 0,
}
}

// fetchPendingRequests will start a new round of batch collecting from the channel.
// It returns nil if everything goes well, otherwise an error that means we should stop the service.
func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh <-chan struct{}, maxBatchWaitInterval time.Duration) error {
var firstRequest *tsoRequest
select {
case <-ctx.Done():
return ctx.Err()
case firstRequest = <-tbc.tsoRequestCh:
case firstRequest = <-tsoRequestCh:
}
// Start to batch when the first TSO request arrives.
tbc.batchStartTime = time.Now()
@@ -65,7 +63,7 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatc
fetchPendingRequestsLoop:
for tbc.collectedRequestCount < tbc.maxBatchSize {
select {
case tsoReq := <-tbc.tsoRequestCh:
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case <-ctx.Done():
return ctx.Err()
@@ -88,7 +86,7 @@ fetchPendingRequestsLoop:
defer after.Stop()
for tbc.collectedRequestCount < tbc.bestBatchSize {
select {
case tsoReq := <-tbc.tsoRequestCh:
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case <-ctx.Done():
return ctx.Err()
@@ -103,7 +101,7 @@ fetchPendingRequestsLoop:
// we can adjust the `tbc.bestBatchSize` dynamically later.
for tbc.collectedRequestCount < tbc.maxBatchSize {
select {
case tsoReq := <-tbc.tsoRequestCh:
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case <-ctx.Done():
return ctx.Err()
@@ -149,18 +147,10 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in
tbc.collectedRequestCount = 0
}

func (tbc *tsoBatchController) revokePendingRequests(err error) {
for i := 0; i < len(tbc.tsoRequestCh); i++ {
req := <-tbc.tsoRequestCh
req.tryDone(err)
}
}

func (tbc *tsoBatchController) clear() {
log.Info("[pd] clear the tso batch controller",
zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize),
zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh)))
zap.Int("collected-request-count", tbc.collectedRequestCount))
tsoErr := errors.WithStack(errClosing)
tbc.finishCollectedRequests(0, 0, 0, tsoErr)
tbc.revokePendingRequests(tsoErr)
}
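The controller no longer owns the request channel; fetchPendingRequests now receives it (together with a token channel) from the dispatcher. The collection pattern itself is unchanged: block until the first request arrives, then greedily drain whatever is already queued, up to the batch limit. A simplified, self-contained sketch of that pattern (placeholder int requests instead of the real tsoRequest):

package main

import (
	"context"
	"fmt"
)

// collectBatch blocks for the first request, then takes whatever else is
// already buffered, stopping at maxBatchSize.
func collectBatch(ctx context.Context, reqCh <-chan int, maxBatchSize int) ([]int, error) {
	var batch []int
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case first := <-reqCh:
		batch = append(batch, first)
	}
	for len(batch) < maxBatchSize {
		select {
		case req := <-reqCh:
			batch = append(batch, req)
		default:
			return batch, nil // nothing else is queued right now
		}
	}
	return batch, nil
}

func main() {
	ch := make(chan int, 4)
	ch <- 1
	ch <- 2
	ch <- 3
	batch, _ := collectBatch(context.Background(), ch, 8)
	fmt.Println(batch) // [1 2 3]
}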
110 changes: 71 additions & 39 deletions client/tso_dispatcher.go
@@ -76,10 +76,18 @@ type tsoDispatcher struct {

provider tsoServiceProvider
// URL -> *connectionContext
connectionCtxs *sync.Map
batchController *tsoBatchController
tsDeadlineCh chan *deadline
lastTSOInfo *tsoInfo
connectionCtxs *sync.Map
tsoRequestCh chan *tsoRequest
tsDeadlineCh chan *deadline
lastTSOInfo *tsoInfo
// For reusing tsoBatchController objects
batchBufferPool *sync.Pool

// For controlling the number of concurrently processed RPC requests.
// A token must be acquired here before sending an RPC request, and the token must be put back after finishing the
// RPC. This is used like a semaphore, but we don't use a semaphore directly because it cannot be selected
// together with other channels.
tokenCh chan struct{}

updateConnectionCtxsCh chan struct{}
}
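The tokenCh comment above describes a channel used as a semaphore precisely because a channel can participate in select alongside other channels. A minimal sketch of that pattern (hypothetical helper names; the capacity of 1 is only for illustration):

package main

import (
	"context"
	"fmt"
)

// Each value buffered in tokenCh is one free token: acquiring takes a value
// out (or gives up when the context is cancelled), releasing puts it back.
func acquire(ctx context.Context, tokenCh chan struct{}) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-tokenCh:
		return nil
	}
}

func release(tokenCh chan struct{}) {
	tokenCh <- struct{}{}
}

func main() {
	tokenCh := make(chan struct{}, 1)
	tokenCh <- struct{}{} // seed one token: at most one RPC in flight

	if err := acquire(context.Background(), tokenCh); err == nil {
		// ... send one RPC ...
		release(tokenCh)
	}
	fmt.Println("done")
}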
@@ -91,24 +99,27 @@ func newTSODispatcher(
provider tsoServiceProvider,
) *tsoDispatcher {
dispatcherCtx, dispatcherCancel := context.WithCancel(ctx)
tsoBatchController := newTSOBatchController(
make(chan *tsoRequest, maxBatchSize*2),
maxBatchSize,
)
tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2)
failpoint.Inject("shortDispatcherChannel", func() {
tsoBatchController = newTSOBatchController(
make(chan *tsoRequest, 1),
maxBatchSize,
)
tsoRequestCh = make(chan *tsoRequest, 1)
})

tokenCh := make(chan struct{}, 64)

td := &tsoDispatcher{
ctx: dispatcherCtx,
cancel: dispatcherCancel,
dc: dc,
provider: provider,
connectionCtxs: &sync.Map{},
batchController: tsoBatchController,
tsDeadlineCh: make(chan *deadline, 1),
ctx: dispatcherCtx,
cancel: dispatcherCancel,
dc: dc,
provider: provider,
connectionCtxs: &sync.Map{},
tsoRequestCh: tsoRequestCh,
tsDeadlineCh: make(chan *deadline, 1),
batchBufferPool: &sync.Pool{
New: func() any {
return newTSOBatchController(maxBatchSize * 2)
},
},
tokenCh: tokenCh,
updateConnectionCtxsCh: make(chan struct{}, 1),
}
go td.watchTSDeadline()
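batchBufferPool lets the dispatcher reuse tsoBatchController objects across batches instead of allocating a fresh one per RPC. A small sketch of that sync.Pool reuse pattern (placeholder buffer type, not the real controller):

package main

import (
	"fmt"
	"sync"
)

type batchBuffer struct {
	collected []string
}

func main() {
	pool := &sync.Pool{
		New: func() any { return &batchBuffer{collected: make([]string, 0, 16)} },
	}

	buf := pool.Get().(*batchBuffer) // possibly a recycled buffer
	buf.collected = buf.collected[:0]
	buf.collected = append(buf.collected, "req-1", "req-2")
	fmt.Println(len(buf.collected)) // 2

	pool.Put(buf) // hand it back once the batch has been finished
}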
@@ -146,13 +157,21 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() {
}
}

func (td *tsoDispatcher) revokePendingRequests(err error) {
for i := 0; i < len(td.tsoRequestCh); i++ {
req := <-td.tsoRequestCh
req.tryDone(err)
}
}
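revokePendingRequests has moved from the batch controller to the dispatcher, since the dispatcher now owns tsoRequestCh. The idea, sketched below with a placeholder request type, is to drain only the requests currently buffered in the channel and fail each of them so that no caller stays blocked:

package main

import (
	"errors"
	"fmt"
)

type pendingReq struct {
	done chan error
}

// revokePending fails every request that is already buffered in ch.
func revokePending(ch chan *pendingReq, err error) {
	for i := 0; i < len(ch); i++ {
		req := <-ch
		req.done <- err
	}
}

func main() {
	ch := make(chan *pendingReq, 4)
	r := &pendingReq{done: make(chan error, 1)}
	ch <- r
	revokePending(ch, errors.New("the client is closing"))
	fmt.Println(<-r.done) // the client is closing
}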

func (td *tsoDispatcher) close() {
td.cancel()
td.batchController.clear()
tsoErr := errors.WithStack(errClosing)
td.revokePendingRequests(tsoErr)
}

func (td *tsoDispatcher) push(request *tsoRequest) {
td.batchController.tsoRequestCh <- request
td.tsoRequestCh <- request
}

func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
@@ -163,8 +182,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
svcDiscovery = provider.getServiceDiscovery()
option = provider.getOption()
connectionCtxs = td.connectionCtxs
batchController = td.batchController
batchController *tsoBatchController
)

// Currently only a concurrency of 1 is supported, so put a single token in.
td.tokenCh <- struct{}{}

log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc))
// Clean up the connectionCtxs when the dispatcher exits.
defer func() {
@@ -175,7 +198,11 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
return true
})
// Clear the tso batch controller.
batchController.clear()
if batchController != nil {
batchController.clear()
}
tsoErr := errors.WithStack(errClosing)
td.revokePendingRequests(tsoErr)
wg.Done()
}()
// Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event.
@@ -203,7 +230,7 @@ tsoBatchLoop:
maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval()
// Once the TSO requests are collected, we must make sure they are eventually finished or revoked,
// otherwise the upper caller may get blocked waiting for the results.
if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil {
if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil {
// Finish the collected requests if the fetch failed.
batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
if err == context.Canceled {
@@ -284,7 +311,7 @@ tsoBatchLoop:
case td.tsDeadlineCh <- dl:
}
// processRequests guarantees that the collected requests could be finished properly.
err = td.processRequests(stream, dc, td.batchController)
err = td.processRequests(stream, dc, batchController)
close(done)
// If error happens during tso stream handling, reset stream and run the next trial.
if err != nil {
@@ -422,25 +449,30 @@ func (td *tsoDispatcher) processRequests(
keyspaceID = svcDiscovery.GetKeyspaceID()
reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID()
)
respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests(

cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) {
curTSOInfo := &tsoInfo{
tsoServer: stream.getServerURL(),
reqKeyspaceGroupID: reqKeyspaceGroupID,
respKeyspaceGroupID: result.respKeyspaceGroupID,
respReceivedAt: time.Now(),
physical: result.physical,
logical: result.logical,
}
// `logical` is the largest ts's logical part here, so we need to do the subtraction before finishing each TSO request.
firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
td.compareAndSwapTS(curTSOInfo, firstLogical)
tbc.finishCollectedRequests(result.physical, firstLogical, result.suffixBits, err)
td.batchBufferPool.Put(tbc)
}

err := stream.processRequests(
clusterID, keyspaceID, reqKeyspaceGroupID,
dcLocation, count, tbc.batchStartTime)
dcLocation, count, tbc.batchStartTime, cb)
if err != nil {
tbc.finishCollectedRequests(0, 0, 0, err)
return err
}
curTSOInfo := &tsoInfo{
tsoServer: stream.getServerURL(),
reqKeyspaceGroupID: reqKeyspaceGroupID,
respKeyspaceGroupID: respKeyspaceGroupID,
respReceivedAt: time.Now(),
physical: physical,
logical: logical,
}
// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits)
td.compareAndSwapTS(curTSOInfo, firstLogical)
tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil)
return nil
}

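processRequests no longer waits for the stream's reply inline: it hands the stream a callback that computes the first logical timestamp from the largest one returned, finishes the collected requests, and returns the batch buffer to the pool. The arithmetic for recovering the first logical value is sketched below (hypothetical names, suffix bits ignored for simplicity):

package main

import "fmt"

// A TSO response carries the physical part and the *largest* logical part
// for a batch of count timestamps, so the first logical value is
// largest - count + 1 and request i receives firstLogical + i.
type tsoResult struct {
	physical int64
	logical  int64 // largest logical in the batch
	count    int64
}

func finishBatch(res tsoResult, done func(i, physical, logical int64)) {
	firstLogical := res.logical - res.count + 1
	for i := int64(0); i < res.count; i++ {
		done(i, res.physical, firstLogical+i)
	}
}

func main() {
	finishBatch(tsoResult{physical: 42, logical: 9, count: 3},
		func(i, physical, logical int64) {
			fmt.Printf("req %d -> ts(%d, %d)\n", i, physical, logical)
		})
	// req 0 -> ts(42, 7)
	// req 1 -> ts(42, 8)
	// req 2 -> ts(42, 9)
}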
