From 4aada7ad890a91cc6685509f9aea8ee31c6f7dff Mon Sep 17 00:00:00 2001
From: ekexium
Date: Tue, 24 Sep 2024 16:26:57 +0800
Subject: [PATCH] apply suggestions from review

Signed-off-by: ekexium
---
 internal/locate/region_cache.go |  9 ++--
 txnkv/transaction/2pc.go        | 79 +++++++++++++++++----------------
 2 files changed, 45 insertions(+), 43 deletions(-)

diff --git a/internal/locate/region_cache.go b/internal/locate/region_cache.go
index 6ada53055..01dfa385c 100644
--- a/internal/locate/region_cache.go
+++ b/internal/locate/region_cache.go
@@ -728,7 +728,7 @@ func NewRegionCache(pdClient pd.Client, opt ...RegionCacheOpt) *RegionCache {
 	}
 	c.bg.schedule(
 		func(ctx context.Context, _ time.Time) bool {
-			refreshFullStoreList(ctx, c.PDClient(), c.stores)
+			refreshFullStoreList(ctx, c.stores)
 			return false
 		}, refreshStoreListInterval,
 	)
@@ -736,8 +736,8 @@
 }
 
 // Try to refresh full store list. Errors are ignored.
-func refreshFullStoreList(ctx context.Context, pdClient pd.Client, stores storeCache) {
-	storeList, err := pdClient.GetAllStores(ctx)
+func refreshFullStoreList(ctx context.Context, stores storeCache) {
+	storeList, err := stores.fetchAllStores(ctx)
 	if err != nil {
 		logutil.Logger(ctx).Info("refresh full store list failed", zap.Error(err))
 		return
@@ -747,17 +747,16 @@
 		if exist {
 			continue
 		}
-		s := stores.getOrInsertDefault(store.GetId())
 		// GetAllStores is supposed to return only Up and Offline stores.
 		// This check is being defensive and to make it consistent with store resolve code.
 		if store == nil || store.GetState() == metapb.StoreState_Tombstone {
-			s.setResolveState(tombstone)
 			continue
 		}
 		addr := store.GetAddress()
 		if addr == "" {
 			continue
 		}
+		s := stores.getOrInsertDefault(store.GetId())
 		// TODO: maybe refactor this, together with other places initializing Store
 		s.addr = addr
 		s.peerAddr = store.GetPeerAddress()
diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go
index cf99f660a..2a01befe6 100644
--- a/txnkv/transaction/2pc.go
+++ b/txnkv/transaction/2pc.go
@@ -1376,8 +1376,8 @@ func keepAlive(
 const broadcastRpcTimeout = time.Second * 5
 const broadcastMaxConcurrency = 10
 
-// broadcastToAllStores asynchronously broadcasts the transaction status to all stores
-// errors are ignored.
+// broadcastToAllStores asynchronously broadcasts the transaction status to all stores.
+// Errors are ignored.
 func broadcastToAllStores(
 	txn *KVTxn,
 	store kvstore,
@@ -1388,53 +1388,56 @@
 ) {
 	broadcastFunc := func() {
 		stores := store.GetRegionCache().GetStoresByType(tikvrpc.TiKV)
-		req := tikvrpc.NewRequest(
-			tikvrpc.CmdBroadcastTxnStatus, &kvrpcpb.BroadcastTxnStatusRequest{
-				TxnStatus: []*kvrpcpb.TxnStatus{status},
-			},
-		)
-		req.Context.ClusterId = store.GetClusterID()
-		req.Context.ResourceControlContext = &kvrpcpb.ResourceControlContext{
-			ResourceGroupName: resourceGroupName,
-		}
-		req.Context.ResourceGroupTag = resourceGroupTag
+		concurrency := min(broadcastMaxConcurrency, len(stores))
+		rateLimit := make(chan struct{}, concurrency)
 		var wg sync.WaitGroup
-		concurrency := min(broadcastMaxConcurrency, len(stores))
-		taskChan := make(chan *locate.Store, concurrency)
-		for i := 0; i < concurrency; i++ {
+		for _, s := range stores {
+			rateLimit <- struct{}{}
 			wg.Add(1)
-			if err := txn.spawnWithStorePool(func() {
+			target := s
+
+			err := txn.spawnWithStorePool(func() {
 				defer wg.Done()
-				for s := range taskChan {
-					_, err := store.GetTiKVClient().SendRequest(
-						bo.GetCtx(),
-						s.GetAddr(),
-						req,
-						broadcastRpcTimeout,
+				defer func() { <-rateLimit }()
+
+				req := tikvrpc.NewRequest(
+					tikvrpc.CmdBroadcastTxnStatus, &kvrpcpb.BroadcastTxnStatusRequest{
+						TxnStatus: []*kvrpcpb.TxnStatus{status},
+					},
+				)
+				req.Context.ClusterId = store.GetClusterID()
+				req.Context.ResourceControlContext = &kvrpcpb.ResourceControlContext{
+					ResourceGroupName: resourceGroupName,
+				}
+				req.Context.ResourceGroupTag = resourceGroupTag
+
+				_, err := store.GetTiKVClient().SendRequest(
+					bo.GetCtx(),
+					target.GetAddr(),
+					req,
+					broadcastRpcTimeout,
+				)
+				if err != nil {
+					logutil.Logger(store.Ctx()).Info(
+						"broadcast txn status failed",
+						zap.Uint64("storeID", target.StoreID()),
+						zap.String("storeAddr", target.GetAddr()),
+						zap.Stringer("status", status),
+						zap.Error(err),
 					)
-					if err != nil {
-						logutil.Logger(store.Ctx()).Info(
-							"broadcast txn status failed",
-							zap.Uint64("storeID", s.StoreID()),
-							zap.String("storeAddr", s.GetAddr()),
-							zap.Stringer("status", status),
-							zap.Error(err),
-						)
-					}
 				}
-			}); err != nil {
-				wg.Done() // Ensure wg is decremented if spawning fails
+			})
+
+			if err != nil {
+				// If spawning the goroutine fails, release the slot and mark done
+				<-rateLimit
+				wg.Done()
 				logutil.Logger(store.Ctx()).Error("failed to spawn worker goroutine", zap.Error(err))
 			}
 		}
-
-		for _, s := range stores {
-			taskChan <- s
-		}
-
-		close(taskChan)
 		wg.Wait()
 	}
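
Note from review (not part of the patch above): the rewritten broadcastToAllStores drops the fixed worker pool fed by a task channel and instead bounds fan-out with a buffered-channel semaphore plus a sync.WaitGroup, building a fresh request per store so no *tikvrpc.Request is shared across goroutines. The standalone Go sketch below illustrates only that concurrency pattern; sendToStore and the store addresses are made-up placeholders, not client-go APIs.

package main

import (
	"fmt"
	"sync"
)

// sendToStore is a made-up placeholder for the per-store BroadcastTxnStatus RPC.
func sendToStore(addr string) error {
	fmt.Println("broadcast txn status to", addr)
	return nil
}

func main() {
	stores := []string{"tikv-1:20160", "tikv-2:20160", "tikv-3:20160", "tikv-4:20160"}

	const maxConcurrency = 2
	concurrency := maxConcurrency
	if len(stores) < concurrency {
		concurrency = len(stores)
	}

	// A buffered channel acts as a counting semaphore: at most `concurrency`
	// goroutines hold a slot at the same time.
	rateLimit := make(chan struct{}, concurrency)
	var wg sync.WaitGroup

	for _, s := range stores {
		rateLimit <- struct{}{} // acquire a slot before spawning the worker
		wg.Add(1)
		target := s

		go func() {
			defer wg.Done()
			defer func() { <-rateLimit }() // release the slot when the worker exits

			if err := sendToStore(target); err != nil {
				// Errors are only logged; the broadcast is best-effort.
				fmt.Println("broadcast failed:", target, err)
			}
		}()
	}
	wg.Wait()
}

Acquiring the semaphore slot before spawning keeps at most `concurrency` broadcasts in flight, and any spawn-failure path must both release the slot and call wg.Done(), which is what the patched error branch does.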