Skip to content

Commit

Permalink
apply suggestions from review
Browse files Browse the repository at this point in the history
Signed-off-by: ekexium <[email protected]>
  • Loading branch information
ekexium committed Sep 24, 2024
1 parent 4a9b779 commit 4aada7a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 43 deletions.
9 changes: 4 additions & 5 deletions internal/locate/region_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,16 +728,16 @@ func NewRegionCache(pdClient pd.Client, opt ...RegionCacheOpt) *RegionCache {
}
c.bg.schedule(
func(ctx context.Context, _ time.Time) bool {
refreshFullStoreList(ctx, c.PDClient(), c.stores)
refreshFullStoreList(ctx, c.stores)
return false
}, refreshStoreListInterval,
)
return c
}

// refreshFullStoreList tries to refresh the full store list. Errors are logged and otherwise ignored.
func refreshFullStoreList(ctx context.Context, pdClient pd.Client, stores storeCache) {
storeList, err := pdClient.GetAllStores(ctx)
func refreshFullStoreList(ctx context.Context, stores storeCache) {
storeList, err := stores.fetchAllStores(ctx)
if err != nil {
logutil.Logger(ctx).Info("refresh full store list failed", zap.Error(err))
return
Expand All @@ -747,17 +747,16 @@ func refreshFullStoreList(ctx context.Context, pdClient pd.Client, stores storeC
if exist {
continue
}
s := stores.getOrInsertDefault(store.GetId())
// GetAllStores is supposed to return only Up and Offline stores.
// This check is defensive and keeps the behavior consistent with the store resolve code.
if store == nil || store.GetState() == metapb.StoreState_Tombstone {
s.setResolveState(tombstone)
continue
}
addr := store.GetAddress()
if addr == "" {
continue
}
s := stores.getOrInsertDefault(store.GetId())
// TODO: maybe refactor this, together with other places initializing Store
s.addr = addr
s.peerAddr = store.GetPeerAddress()
Expand Down
79 changes: 41 additions & 38 deletions txnkv/transaction/2pc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1376,8 +1376,8 @@ func keepAlive(
const broadcastRpcTimeout = time.Second * 5
const broadcastMaxConcurrency = 10

// broadcastToAllStores asynchronously broadcasts the transaction status to all stores
// errors are ignored.
// broadcastToAllStores asynchronously broadcasts the transaction status to all stores.
// Errors are ignored.
func broadcastToAllStores(
txn *KVTxn,
store kvstore,
Expand All @@ -1388,53 +1388,56 @@ func broadcastToAllStores(
) {
broadcastFunc := func() {
stores := store.GetRegionCache().GetStoresByType(tikvrpc.TiKV)
req := tikvrpc.NewRequest(
tikvrpc.CmdBroadcastTxnStatus, &kvrpcpb.BroadcastTxnStatusRequest{
TxnStatus: []*kvrpcpb.TxnStatus{status},
},
)
req.Context.ClusterId = store.GetClusterID()
req.Context.ResourceControlContext = &kvrpcpb.ResourceControlContext{
ResourceGroupName: resourceGroupName,
}
req.Context.ResourceGroupTag = resourceGroupTag
concurrency := min(broadcastMaxConcurrency, len(stores))
rateLimit := make(chan struct{}, concurrency)

var wg sync.WaitGroup
concurrency := min(broadcastMaxConcurrency, len(stores))
taskChan := make(chan *locate.Store, concurrency)

for i := 0; i < concurrency; i++ {
for _, s := range stores {
rateLimit <- struct{}{}
wg.Add(1)
if err := txn.spawnWithStorePool(func() {
target := s

err := txn.spawnWithStorePool(func() {
defer wg.Done()
for s := range taskChan {
_, err := store.GetTiKVClient().SendRequest(
bo.GetCtx(),
s.GetAddr(),
req,
broadcastRpcTimeout,
defer func() { <-rateLimit }()

req := tikvrpc.NewRequest(
tikvrpc.CmdBroadcastTxnStatus, &kvrpcpb.BroadcastTxnStatusRequest{
TxnStatus: []*kvrpcpb.TxnStatus{status},
},
)
req.Context.ClusterId = store.GetClusterID()
req.Context.ResourceControlContext = &kvrpcpb.ResourceControlContext{
ResourceGroupName: resourceGroupName,
}
req.Context.ResourceGroupTag = resourceGroupTag

_, err := store.GetTiKVClient().SendRequest(
bo.GetCtx(),
target.GetAddr(),
req,
broadcastRpcTimeout,
)
if err != nil {
logutil.Logger(store.Ctx()).Info(
"broadcast txn status failed",
zap.Uint64("storeID", target.StoreID()),
zap.String("storeAddr", target.GetAddr()),
zap.Stringer("status", status),
zap.Error(err),
)
if err != nil {
logutil.Logger(store.Ctx()).Info(
"broadcast txn status failed",
zap.Uint64("storeID", s.StoreID()),
zap.String("storeAddr", s.GetAddr()),
zap.Stringer("status", status),
zap.Error(err),
)
}
}
}); err != nil {
wg.Done() // Ensure wg is decremented if spawning fails
})

if err != nil {
// If spawning the goroutine fails, release the slot and mark done
<-rateLimit
wg.Done()
logutil.Logger(store.Ctx()).Error("failed to spawn worker goroutine", zap.Error(err))
}
}

for _, s := range stores {
taskChan <- s
}

close(taskChan)
wg.Wait()
}

Expand Down

0 comments on commit 4aada7a

Please sign in to comment.