diff --git a/client/metrics.go b/client/metrics.go
index a11362669b35..a83b4a364076 100644
--- a/client/metrics.go
+++ b/client/metrics.go
@@ -39,13 +39,14 @@ func initAndRegisterMetrics(constLabels prometheus.Labels) {
 }
 
 var (
-    cmdDuration         *prometheus.HistogramVec
-    cmdFailedDuration   *prometheus.HistogramVec
-    requestDuration     *prometheus.HistogramVec
-    tsoBestBatchSize    prometheus.Histogram
-    tsoBatchSize        prometheus.Histogram
-    tsoBatchSendLatency prometheus.Histogram
-    requestForwarded    *prometheus.GaugeVec
+    cmdDuration              *prometheus.HistogramVec
+    cmdFailedDuration        *prometheus.HistogramVec
+    requestDuration          *prometheus.HistogramVec
+    tsoBestBatchSize         prometheus.Histogram
+    tsoBatchSize             prometheus.Histogram
+    tsoBatchSendLatency      prometheus.Histogram
+    requestForwarded         *prometheus.GaugeVec
+    ongoingRequestCountGauge *prometheus.GaugeVec
 )
 
 func initMetrics(constLabels prometheus.Labels) {
@@ -117,6 +118,15 @@ func initMetrics(constLabels prometheus.Labels) {
             Help:        "The status to indicate if the request is forwarded",
             ConstLabels: constLabels,
         }, []string{"host", "delegate"})
+
+    ongoingRequestCountGauge = prometheus.NewGaugeVec(
+        prometheus.GaugeOpts{
+            Namespace:   "pd_client",
+            Subsystem:   "request",
+            Name:        "ongoing_requests_count",
+            Help:        "Current count of ongoing batch tso requests",
+            ConstLabels: constLabels,
+        }, []string{"stream"})
 }
 
 var (
diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go
index a713b7a187d8..10c85f5aa1f3 100644
--- a/client/tso_batch_controller.go
+++ b/client/tso_batch_controller.go
@@ -30,18 +30,16 @@ type tsoBatchController struct {
     // bestBatchSize is a dynamic size that changed based on the current batch effect.
     bestBatchSize int
 
-    tsoRequestCh          chan *tsoRequest
     collectedRequests     []*tsoRequest
     collectedRequestCount int
 
     batchStartTime time.Time
 }
 
-func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController {
+func newTSOBatchController(maxBatchSize int) *tsoBatchController {
     return &tsoBatchController{
         maxBatchSize:  maxBatchSize,
         bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */
-        tsoRequestCh:          tsoRequestCh,
         collectedRequests:     make([]*tsoRequest, maxBatchSize+1),
         collectedRequestCount: 0,
     }
@@ -49,12 +47,12 @@ func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tso
 
 // fetchPendingRequests will start a new round of the batch collecting from the channel.
 // It returns true if everything goes well, otherwise false which means we should stop the service.
-func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
+func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh <-chan struct{}, maxBatchWaitInterval time.Duration) error {
     var firstRequest *tsoRequest
     select {
     case <-ctx.Done():
         return ctx.Err()
-    case firstRequest = <-tbc.tsoRequestCh:
+    case firstRequest = <-tsoRequestCh:
     }
     // Start to batch when the first TSO request arrives.
     tbc.batchStartTime = time.Now()
@@ -65,7 +63,7 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatc
 fetchPendingRequestsLoop:
     for tbc.collectedRequestCount < tbc.maxBatchSize {
         select {
-        case tsoReq := <-tbc.tsoRequestCh:
+        case tsoReq := <-tsoRequestCh:
             tbc.pushRequest(tsoReq)
         case <-ctx.Done():
             return ctx.Err()
@@ -88,7 +86,7 @@ fetchPendingRequestsLoop:
         defer after.Stop()
         for tbc.collectedRequestCount < tbc.bestBatchSize {
             select {
-            case tsoReq := <-tbc.tsoRequestCh:
+            case tsoReq := <-tsoRequestCh:
                 tbc.pushRequest(tsoReq)
             case <-ctx.Done():
                 return ctx.Err()
@@ -103,7 +101,7 @@ fetchPendingRequestsLoop:
     // we can adjust the `tbc.bestBatchSize` dynamically later.
     for tbc.collectedRequestCount < tbc.maxBatchSize {
         select {
-        case tsoReq := <-tbc.tsoRequestCh:
+        case tsoReq := <-tsoRequestCh:
             tbc.pushRequest(tsoReq)
         case <-ctx.Done():
             return ctx.Err()
@@ -149,18 +147,10 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in
     tbc.collectedRequestCount = 0
 }
 
-func (tbc *tsoBatchController) revokePendingRequests(err error) {
-    for i := 0; i < len(tbc.tsoRequestCh); i++ {
-        req := <-tbc.tsoRequestCh
-        req.tryDone(err)
-    }
-}
-
 func (tbc *tsoBatchController) clear() {
     log.Info("[pd] clear the tso batch controller",
         zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize),
-        zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh)))
+        zap.Int("collected-request-count", tbc.collectedRequestCount))
     tsoErr := errors.WithStack(errClosing)
     tbc.finishCollectedRequests(0, 0, 0, tsoErr)
-    tbc.revokePendingRequests(tsoErr)
 }
diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index a7c99057275c..368177dfe90b 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -76,10 +76,18 @@ type tsoDispatcher struct {
     provider tsoServiceProvider
 
     // URL -> *connectionContext
-    connectionCtxs  *sync.Map
-    batchController *tsoBatchController
-    tsDeadlineCh    chan *deadline
-    lastTSOInfo     *tsoInfo
+    connectionCtxs *sync.Map
+    tsoRequestCh   chan *tsoRequest
+    tsDeadlineCh   chan *deadline
+    lastTSOInfo    *tsoInfo
+    // For reusing tsoBatchController objects
+    batchBufferPool *sync.Pool
+
+    // For controlling amount of concurrently processing RPC requests.
+    // A token must be acquired here before sending an RPC request, and the token must be put back after finishing the
+    // RPC. This is used like a semaphore, but we don't use semaphore directly here as it cannot be selected with
+    // other channels.
+    tokenCh chan struct{}
 
     updateConnectionCtxsCh chan struct{}
 }
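// Aside: a minimal sketch (not from the patch) of why a buffered channel can stand in for a
// semaphore here. Unlike semaphore.Weighted from golang.org/x/sync, acquiring from a channel can
// be combined with cancellation (or any other channel) in a single select. All names below are
// illustrative only.
//
//    tokenCh := make(chan struct{}, maxConcurrentRPCs)
//    for i := 0; i < maxConcurrentRPCs; i++ {
//        tokenCh <- struct{}{} // fill the semaphore with tokens
//    }
//
//    // Acquire a token, or give up when the context is cancelled.
//    select {
//    case <-tokenCh:
//    case <-ctx.Done():
//        return ctx.Err()
//    }
//    // ... send the RPC ...
//    tokenCh <- struct{}{} // put the token back once the RPC has finished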
@@ -91,24 +99,27 @@ func newTSODispatcher(
     provider tsoServiceProvider,
 ) *tsoDispatcher {
     dispatcherCtx, dispatcherCancel := context.WithCancel(ctx)
-    tsoBatchController := newTSOBatchController(
-        make(chan *tsoRequest, maxBatchSize*2),
-        maxBatchSize,
-    )
+    tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2)
     failpoint.Inject("shortDispatcherChannel", func() {
-        tsoBatchController = newTSOBatchController(
-            make(chan *tsoRequest, 1),
-            maxBatchSize,
-        )
+        tsoRequestCh = make(chan *tsoRequest, 1)
     })
+
+    tokenCh := make(chan struct{}, 64)
+
     td := &tsoDispatcher{
-        ctx:             dispatcherCtx,
-        cancel:          dispatcherCancel,
-        dc:              dc,
-        provider:        provider,
-        connectionCtxs:  &sync.Map{},
-        batchController: tsoBatchController,
-        tsDeadlineCh:    make(chan *deadline, 1),
+        ctx:            dispatcherCtx,
+        cancel:         dispatcherCancel,
+        dc:             dc,
+        provider:       provider,
+        connectionCtxs: &sync.Map{},
+        tsoRequestCh:   tsoRequestCh,
+        tsDeadlineCh:   make(chan *deadline, 1),
+        batchBufferPool: &sync.Pool{
+            New: func() any {
+                return newTSOBatchController(maxBatchSize * 2)
+            },
+        },
+        tokenCh:                tokenCh,
         updateConnectionCtxsCh: make(chan struct{}, 1),
     }
     go td.watchTSDeadline()
@@ -146,13 +157,21 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() {
     }
 }
 
+func (td *tsoDispatcher) revokePendingRequests(err error) {
+    for i := 0; i < len(td.tsoRequestCh); i++ {
+        req := <-td.tsoRequestCh
+        req.tryDone(err)
+    }
+}
+
 func (td *tsoDispatcher) close() {
     td.cancel()
-    td.batchController.clear()
+    tsoErr := errors.WithStack(errClosing)
+    td.revokePendingRequests(tsoErr)
 }
 
 func (td *tsoDispatcher) push(request *tsoRequest) {
-    td.batchController.tsoRequestCh <- request
+    td.tsoRequestCh <- request
 }
 
 func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
@@ -163,8 +182,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
         svcDiscovery   = provider.getServiceDiscovery()
         option         = provider.getOption()
         connectionCtxs = td.connectionCtxs
-        batchController = td.batchController
+        batchController *tsoBatchController
     )
+
+    // Currently only 1 concurrency is supported. Put one token in.
+    td.tokenCh <- struct{}{}
+
     log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc))
     // Clean up the connectionCtxs when the dispatcher exits.
     defer func() {
@@ -175,7 +198,11 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) {
             return true
         })
         // Clear the tso batch controller.
-        batchController.clear()
+        if batchController != nil {
+            batchController.clear()
+        }
+        tsoErr := errors.WithStack(errClosing)
+        td.revokePendingRequests(tsoErr)
         wg.Done()
     }()
     // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event.
@@ -203,7 +230,7 @@ tsoBatchLoop:
         maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval()
         // Once the TSO requests are collected, must make sure they could be finished or revoked eventually,
         // otherwise the upper caller may get blocked on waiting for the results.
-        if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil {
+        if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil {
            // Finish the collected requests if the fetch failed.
            batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
            if err == context.Canceled {
@@ -284,7 +311,7 @@ tsoBatchLoop:
            case td.tsDeadlineCh <- dl:
            }
            // processRequests guarantees that the collected requests could be finished properly.
-           err = td.processRequests(stream, dc, td.batchController)
+           err = td.processRequests(stream, dc, batchController)
            close(done)
            // If error happens during tso stream handling, reset stream and run the next trial.
            if err != nil {
@@ -422,25 +449,30 @@ func (td *tsoDispatcher) processRequests(
        keyspaceID         = svcDiscovery.GetKeyspaceID()
        reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID()
    )
-   respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests(
+
+   cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) {
+       curTSOInfo := &tsoInfo{
+           tsoServer:           stream.getServerURL(),
+           reqKeyspaceGroupID:  reqKeyspaceGroupID,
+           respKeyspaceGroupID: result.respKeyspaceGroupID,
+           respReceivedAt:      time.Now(),
+           physical:            result.physical,
+           logical:             result.logical,
+       }
+       // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
+       firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
+       td.compareAndSwapTS(curTSOInfo, firstLogical)
+       tbc.finishCollectedRequests(result.physical, firstLogical, result.suffixBits, err)
+       td.batchBufferPool.Put(tbc)
+   }
+
+   err := stream.processRequests(
        clusterID, keyspaceID, reqKeyspaceGroupID,
-       dcLocation, count, tbc.batchStartTime)
+       dcLocation, count, tbc.batchStartTime, cb)
    if err != nil {
        tbc.finishCollectedRequests(0, 0, 0, err)
        return err
    }
-   curTSOInfo := &tsoInfo{
-       tsoServer:           stream.getServerURL(),
-       reqKeyspaceGroupID:  reqKeyspaceGroupID,
-       respKeyspaceGroupID: respKeyspaceGroupID,
-       respReceivedAt:      time.Now(),
-       physical:            physical,
-       logical:             logical,
-   }
-   // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
-   firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits)
-   td.compareAndSwapTS(curTSOInfo, firstLogical)
-   tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil)
    return nil
 }
diff --git a/client/tso_stream.go b/client/tso_stream.go
index da9cab95ba00..c9584c0c0475 100644
--- a/client/tso_stream.go
+++ b/client/tso_stream.go
@@ -16,13 +16,19 @@ package pd
 
 import (
    "context"
+   "fmt"
    "io"
+   "sync"
+   "sync/atomic"
    "time"
 
    "github.com/pingcap/errors"
    "github.com/pingcap/kvproto/pkg/pdpb"
    "github.com/pingcap/kvproto/pkg/tsopb"
+   "github.com/pingcap/log"
+   "github.com/prometheus/client_golang/prometheus"
    "github.com/tikv/pd/client/errs"
+   "go.uber.org/zap"
    "google.golang.org/grpc"
 )
 
@@ -62,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun
    stream, err := b.client.Tso(ctx)
    done <- struct{}{}
    if err == nil {
-       return &tsoStream{stream: pdTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil
+       return newTSOStream(b.serverURL, pdTSOStreamAdapter{stream}), nil
    }
    return nil, err
 }
@@ -81,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build(
    stream, err := b.client.Tso(ctx)
    done <- struct{}{}
    if err == nil {
-       return &tsoStream{stream: tsoTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil
+       return newTSOStream(b.serverURL, tsoTSOStreamAdapter{stream}), nil
    }
    return nil, err
 }
@@ -172,51 +178,206 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) {
    }, nil
 }
 
+type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error)
+
+type batchedRequests struct {
+   startTime          time.Time
+   count              int64
+   reqKeyspaceGroupID uint32
+   callback           onFinishedCallback
+}
+
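// Aside: a minimal sketch (not from the patch) of the request/response pairing that
// `batchedRequests` enables. Because the gRPC stream delivers responses in the order the batches
// were sent, a FIFO queue of per-batch metadata is enough to match each response to its batch.
// Names are illustrative only.
//
//    pending := make(chan batchedRequests, 64)
//
//    // Sender: remember the batch, then send it on the stream.
//    pending <- batchedRequests{startTime: time.Now(), count: n, callback: cb}
//    if err := stream.Send( /* ... */ ); err != nil { /* handle */ }
//
//    // Receiver: the next response always belongs to the oldest pending batch.
//    res, err := stream.Recv()
//    req := <-pending
//    req.callback(res, req.reqKeyspaceGroupID, serverURL, err)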
+// tsoStream represents an abstracted stream for requesting TSO.
+// This type is designed to be decoupled from its users, so tsoDispatcher won't be directly accessed here.
+// Also in order to avoid potential memory allocations that might happen when passing closures as the callback,
+// we instead use the `batchedRequestsNotifier` as the abstraction, and accept a generic type instead of a dynamic
+// interface type.
 type tsoStream struct {
    serverURL string
    // The internal gRPC stream.
    //   - `pdpb.PD_TsoClient` for a leader/follower in the PD cluster.
    //   - `tsopb.TSO_TsoClient` for a primary/secondary in the TSO cluster.
    stream grpcTSOStreamAdapter
+   // An identifier of the tsoStream object for metrics reporting and diagnosing.
+   streamID string
+
+   pendingRequests chan batchedRequests
+
+   estimateLatencyMicros atomic.Uint64
+
+   cancel context.CancelFunc
+   wg     sync.WaitGroup
+
+   // For syncing between sender and receiver to guarantee all requests are finished when closing.
+   state atomic.Int32
+
+   ongoingRequestCountGauge prometheus.Gauge
+   ongoingRequests          atomic.Int32
+}
+
+const (
+   streamStateIdle int32 = iota
+   streamStateSending
+   streamStateClosing
+)
+
+var streamIDAlloc atomic.Int32
+
+// TODO: Pass a context?
+func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream {
+   streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1))
+   ctx, cancel := context.WithCancel(context.Background())
+   s := &tsoStream{
+       serverURL: serverURL,
+       stream:    stream,
+       streamID:  streamID,
+
+       pendingRequests: make(chan batchedRequests, 64),
+
+       cancel: cancel,
+
+       ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID),
+   }
+   s.wg.Add(1)
+   go s.recvLoop(ctx)
+   return s
 }
 
 func (s *tsoStream) getServerURL() string {
    return s.serverURL
 }
 
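// Aside: a simplified sketch (not from the patch) of how the `state` field coordinates the single
// sender with the receiver that is shutting the stream down. Error paths and logging are omitted.
//
//    // Sender side: claim the stream; refuse new requests once it is closing.
//    prev := s.state.Swap(streamStateSending)
//    if prev == streamStateClosing {
//        s.state.Store(prev)
//        return errs.ErrClientTSOStreamClosed
//    }
//    // ... enqueue the batch and send it ...
//    s.state.Store(prev) // back to idle, so the closer can make progress
//
//    // Closer side (in recvLoop's defer): spin until no send is in flight, then mark the stream
//    // closing so no further batch can be enqueued before pendingRequests is drained and closed.
//    for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) {
//    }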
+// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready,
+// it will be passed to `notifier.finish`.
+//
+// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines.
 func (s *tsoStream) processRequests(
-   clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time,
-) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) {
+   clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, callback onFinishedCallback,
+) error {
    start := time.Now()
-   if err = s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
+
+   // Check if the stream is closing or closed, in which case no more requests should be put in.
+   // Note that the prevState should be restored very soon, as the receiver may check
+   prevState := s.state.Swap(streamStateSending)
+   switch prevState {
+   case streamStateIdle:
+       // Expected case
+       break
+   case streamStateClosing:
+       s.state.Store(prevState)
+       log.Info("tsoStream closed")
+       return errs.ErrClientTSOStreamClosed
+   case streamStateSending:
+       log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID))
+   default:
+       log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", prevState))
+   }
+
+   select {
+   case s.pendingRequests <- batchedRequests{
+       startTime:          start,
+       count:              count,
+       reqKeyspaceGroupID: keyspaceGroupID,
+       callback:           callback,
+   }:
+   default:
+       s.state.Store(prevState)
+       return errors.New("unexpected channel full")
+   }
+   s.state.Store(prevState)
+
+   if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
        if err == io.EOF {
-           err = errs.ErrClientTSOStreamClosed
-       } else {
-           err = errors.WithStack(err)
+           return errs.ErrClientTSOStreamClosed
        }
-       return
+       return errors.WithStack(err)
    }
    tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds())
-   res, err := s.stream.Recv()
-   duration := time.Since(start).Seconds()
-   if err != nil {
-       requestFailedDurationTSO.Observe(duration)
-       if err == io.EOF {
-           err = errs.ErrClientTSOStreamClosed
-       } else {
-           err = errors.WithStack(err)
+   s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1)))
+   return nil
+}
+
+func (s *tsoStream) recvLoop(ctx context.Context) {
+   var finishWithErr error
+
+   defer func() {
+       s.cancel()
+       for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) {
+           switch state := s.state.Load(); state {
+           case streamStateIdle, streamStateSending:
+               // streamStateSending should switch to streamStateIdle very quickly. Spin until successfully setting to
+               // streamStateClosing.
+               continue
+           case streamStateClosing:
+               log.Warn("unexpected double closing of tsoStream", zap.String("stream", s.streamID))
+           default:
+               log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", state))
+           }
-       return
-   }
-   requestDurationTSO.Observe(duration)
-   tsoBatchSize.Observe(float64(count))
-   if res.count != uint32(count) {
-       err = errors.WithStack(errTSOLength)
-       return
+       }
+       // The loop must end with an error (including context.Canceled).
+       if finishWithErr == nil {
+           log.Fatal("tsoStream recvLoop ended without error", zap.String("stream", s.streamID))
+       }
+       log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr))
+
+       close(s.pendingRequests)
+       for req := range s.pendingRequests {
+           req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr)
+       }
+
+       s.wg.Done()
+       s.ongoingRequests.Store(0)
+       s.ongoingRequestCountGauge.Set(0)
+   }()
+
+recvLoop:
+   for {
+       select {
+       case <-ctx.Done():
+           finishWithErr = context.Canceled
+           break recvLoop
+       default:
+       }
+
+       res, err := s.stream.Recv()
+
+       // Load the corresponding batchedRequests
+       var req batchedRequests
+       select {
+       case req = <-s.pendingRequests:
+       default:
+           finishWithErr = errors.New("tsoStream timing order broken")
+           break recvLoop
+       }
+
+       durationSeconds := time.Since(req.startTime).Seconds()
+
+       if err != nil {
+           requestFailedDurationTSO.Observe(durationSeconds)
+           if err == io.EOF {
+               finishWithErr = errs.ErrClientTSOStreamClosed
+           } else {
+               finishWithErr = errors.WithStack(err)
+           }
+           break
+       }
+
+       latencySeconds := durationSeconds
+       requestDurationTSO.Observe(latencySeconds)
+       tsoBatchSize.Observe(float64(res.count))
+
+       if res.count != uint32(req.count) {
+           finishWithErr = errors.WithStack(errTSOLength)
+           break
+       }
+
+       req.callback(res, req.reqKeyspaceGroupID, s.serverURL, nil)
+       s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1)))
    }
+}
-   respKeyspaceGroupID = res.respKeyspaceGroupID
-   physical, logical, suffixBits = res.physical, res.logical, res.suffixBits
-   return
+
+func (s *tsoStream) Close() {
+   s.cancel()
+   s.wg.Wait()
 }
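With these pieces in place, `tsoStream.processRequests` no longer blocks on the TSO response: the send path only enqueues the batch metadata and returns, while `recvLoop` invokes the callback once the matching response arrives. A rough usage sketch of the new API, assuming a connected `*tsoStream` named `stream` and placeholder request parameters (error handling trimmed):

    done := make(chan struct{})
    cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) {
        // Invoked on the stream's recvLoop goroutine when the response (or an error) arrives.
        close(done)
    }
    if err := stream.processRequests(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count, time.Now(), cb); err != nil {
        // The batch could not be handed to the stream; the caller finishes the collected
        // requests itself, as td.processRequests does in tso_dispatcher.go.
    }
    <-done
    stream.Close() // cancels recvLoop and waits for it; still-pending callbacks receive an error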