From 5accbc3863fccfc5b7178ec8f70e300b95a4eb7e Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Wed, 12 Jul 2023 10:14:19 -0600 Subject: [PATCH 1/3] Cache the start key for message pruning --- arbnode/message_pruner.go | 46 +++++++++++++++++++++------------- arbnode/message_pruner_test.go | 38 +++++++++++++++------------- 2 files changed, 49 insertions(+), 35 deletions(-) diff --git a/arbnode/message_pruner.go b/arbnode/message_pruner.go index aeee07ca73..a0aa86050f 100644 --- a/arbnode/message_pruner.go +++ b/arbnode/message_pruner.go @@ -23,11 +23,13 @@ import ( type MessagePruner struct { stopwaiter.StopWaiter - transactionStreamer *TransactionStreamer - inboxTracker *InboxTracker - config MessagePrunerConfigFetcher - pruningLock sync.Mutex - lastPruneDone time.Time + transactionStreamer *TransactionStreamer + inboxTracker *InboxTracker + config MessagePrunerConfigFetcher + pruningLock sync.Mutex + lastPruneDone time.Time + cachedPrunedMessages uint64 + cachedPrunedDelayedMessages uint64 } type MessagePrunerConfig struct { @@ -108,11 +110,11 @@ func (m *MessagePruner) prune(ctx context.Context, count arbutil.MessageIndex, g msgCount := endBatchMetadata.MessageCount delayedCount := endBatchMetadata.DelayedMessageCount - return deleteOldMessageFromDB(ctx, msgCount, delayedCount, m.inboxTracker.db, m.transactionStreamer.db) + return m.deleteOldMessagesFromDB(ctx, msgCount, delayedCount) } -func deleteOldMessageFromDB(ctx context.Context, messageCount arbutil.MessageIndex, delayedMessageCount uint64, inboxTrackerDb ethdb.Database, transactionStreamerDb ethdb.Database) error { - prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, transactionStreamerDb, messagePrefix, uint64(messageCount)) +func (m *MessagePruner) deleteOldMessagesFromDB(ctx context.Context, messageCount arbutil.MessageIndex, delayedMessageCount uint64) error { + prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, m.transactionStreamer.db, messagePrefix, &m.cachedPrunedMessages, uint64(messageCount)) if err != nil { return fmt.Errorf("error deleting last batch messages: %w", err) } @@ -120,7 +122,7 @@ func deleteOldMessageFromDB(ctx context.Context, messageCount arbutil.MessageInd log.Info("Pruned last batch messages:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) } - prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(ctx, inboxTrackerDb, rlpDelayedMessagePrefix, delayedMessageCount) + prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(ctx, m.inboxTracker.db, rlpDelayedMessagePrefix, &m.cachedPrunedDelayedMessages, delayedMessageCount) if err != nil { return fmt.Errorf("error deleting last batch delayed messages: %w", err) } @@ -130,15 +132,25 @@ func deleteOldMessageFromDB(ctx context.Context, messageCount arbutil.MessageInd return nil } -func deleteFromLastPrunedUptoEndKey(ctx context.Context, db ethdb.Database, prefix []byte, endMinKey uint64) ([]uint64, error) { - startIter := db.NewIterator(prefix, uint64ToKey(1)) - if !startIter.Next() { +// deleteFromLastPrunedUptoEndKey is similar to deleteFromRange but automatically populates the start key +// cachedStartMinKey must not be nil. It's set to the new start key at the end of this function if successful. +func deleteFromLastPrunedUptoEndKey(ctx context.Context, db ethdb.Database, prefix []byte, cachedStartMinKey *uint64, endMinKey uint64) ([]uint64, error) { + startMinKey := *cachedStartMinKey + if startMinKey == 0 { + startIter := db.NewIterator(prefix, uint64ToKey(1)) + if !startIter.Next() { + return nil, nil + } + startMinKey = binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) + startIter.Release() + } + if endMinKey <= startMinKey { + *cachedStartMinKey = startMinKey return nil, nil } - startMinKey := binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) - startIter.Release() - if endMinKey > startMinKey { - return deleteFromRange(ctx, db, prefix, startMinKey, endMinKey-1) + keys, err := deleteFromRange(ctx, db, prefix, startMinKey, endMinKey-1) + if err == nil { + *cachedStartMinKey = endMinKey - 1 } - return nil, nil + return keys, err } diff --git a/arbnode/message_pruner_test.go b/arbnode/message_pruner_test.go index c0cb2cb4fe..0212ed2364 100644 --- a/arbnode/message_pruner_test.go +++ b/arbnode/message_pruner_test.go @@ -17,8 +17,8 @@ func TestMessagePrunerWithPruningEligibleMessagePresent(t *testing.T) { defer cancel() messagesCount := uint64(2 * 100 * 1024) - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*100*1024, 2*100*1024) - err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + inboxTrackerDb, transactionStreamerDb, pruner := setupDatabase(t, 2*100*1024, 2*100*1024) + err := pruner.deleteOldMessagesFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount) Require(t, err) checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) @@ -26,22 +26,21 @@ func TestMessagePrunerWithPruningEligibleMessagePresent(t *testing.T) { } -func TestMessagePrunerTraverseEachMessageOnlyOnce(t *testing.T) { +func TestMessagePrunerTwoHalves(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() messagesCount := uint64(10) - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) - // In first iteration message till messagesCount are tried to be deleted. - err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + _, transactionStreamerDb, pruner := setupDatabase(t, messagesCount, messagesCount) + // In first iteration message till messagesCount/2 are tried to be deleted. + err := pruner.deleteOldMessagesFromDB(ctx, arbutil.MessageIndex(messagesCount/2), messagesCount/2) Require(t, err) - // After first iteration messagesCount/2 is reinserted in inbox db - err = inboxTrackerDb.Put(dbKey(messagePrefix, messagesCount/2), []byte{}) + // In first iteration all the message till messagesCount/2 are deleted. + checkDbKeys(t, messagesCount/2, transactionStreamerDb, messagePrefix) + // In second iteration message till messagesCount are tried to be deleted. + err = pruner.deleteOldMessagesFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount) Require(t, err) - // In second iteration message till messagesCount are again tried to be deleted. - err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) - Require(t, err) - // In second iteration all the message till messagesCount are deleted again. + // In second iteration all the message till messagesCount are deleted. checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) } @@ -50,10 +49,10 @@ func TestMessagePrunerPruneTillLessThenEqualTo(t *testing.T) { defer cancel() messagesCount := uint64(10) - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*messagesCount, 20) + inboxTrackerDb, transactionStreamerDb, pruner := setupDatabase(t, 2*messagesCount, 20) err := inboxTrackerDb.Delete(dbKey(messagePrefix, 9)) Require(t, err) - err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + err = pruner.deleteOldMessagesFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount) Require(t, err) hasKey, err := transactionStreamerDb.Has(dbKey(messagePrefix, messagesCount)) Require(t, err) @@ -67,8 +66,8 @@ func TestMessagePrunerWithNoPruningEligibleMessagePresent(t *testing.T) { defer cancel() messagesCount := uint64(10) - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) - err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + inboxTrackerDb, transactionStreamerDb, pruner := setupDatabase(t, messagesCount, messagesCount) + err := pruner.deleteOldMessagesFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount) Require(t, err) checkDbKeys(t, uint64(messagesCount), transactionStreamerDb, messagePrefix) @@ -76,7 +75,7 @@ func TestMessagePrunerWithNoPruningEligibleMessagePresent(t *testing.T) { } -func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethdb.Database, ethdb.Database) { +func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethdb.Database, ethdb.Database, *MessagePruner) { transactionStreamerDb := rawdb.NewMemoryDatabase() for i := uint64(0); i < uint64(messageCount); i++ { @@ -90,7 +89,10 @@ func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethd Require(t, err) } - return inboxTrackerDb, transactionStreamerDb + return inboxTrackerDb, transactionStreamerDb, &MessagePruner{ + transactionStreamer: &TransactionStreamer{db: transactionStreamerDb}, + inboxTracker: &InboxTracker{db: inboxTrackerDb}, + } } func checkDbKeys(t *testing.T, endCount uint64, db ethdb.Database, prefix []byte) { From 65a852972ff9a0b50446815f7f6ae15e58005ceb Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Wed, 12 Jul 2023 10:49:03 -0600 Subject: [PATCH 2/3] Use latest confirmed instead of latest staked for message pruner --- arbnode/message_pruner.go | 2 +- arbnode/node.go | 6 +- staker/staker.go | 102 +++++++++++++++++++++------- staker/stateless_block_validator.go | 1 + system_tests/staker_test.go | 3 + util/headerreader/header_reader.go | 4 ++ 6 files changed, 91 insertions(+), 27 deletions(-) diff --git a/arbnode/message_pruner.go b/arbnode/message_pruner.go index a0aa86050f..b469ecdbef 100644 --- a/arbnode/message_pruner.go +++ b/arbnode/message_pruner.go @@ -64,7 +64,7 @@ func (m *MessagePruner) Start(ctxIn context.Context) { m.StopWaiter.Start(ctxIn, m) } -func (m *MessagePruner) UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) { +func (m *MessagePruner) UpdateLatestConfirmed(count arbutil.MessageIndex, globalState validator.GoGlobalState) { locked := m.pruningLock.TryLock() if !locked { return diff --git a/arbnode/node.go b/arbnode/node.go index 8a4f38f28c..bd5605346b 100644 --- a/arbnode/node.go +++ b/arbnode/node.go @@ -825,13 +825,13 @@ func createNodeImpl( } } - notifiers := make([]staker.LatestStakedNotifier, 0) + var confirmedNotifiers []staker.LatestConfirmedNotifier if config.MessagePruner.Enable && !config.Caching.Archive { messagePruner = NewMessagePruner(txStreamer, inboxTracker, func() *MessagePrunerConfig { return &configFetcher.Get().MessagePruner }) - notifiers = append(notifiers, messagePruner) + confirmedNotifiers = append(confirmedNotifiers, messagePruner) } - stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, notifiers, deployInfo.ValidatorUtils, fatalErrChan) + stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, nil, confirmedNotifiers, deployInfo.ValidatorUtils, fatalErrChan) if err != nil { return nil, err } diff --git a/staker/staker.go b/staker/staker.go index 09a05daad2..f360a60a7d 100644 --- a/staker/staker.go +++ b/staker/staker.go @@ -17,6 +17,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/rpc" flag "github.com/spf13/pflag" "github.com/offchainlabs/nitro/arbutil" @@ -30,6 +31,7 @@ var ( stakerBalanceGauge = metrics.NewRegisteredGaugeFloat64("arb/staker/balance", nil) stakerAmountStakedGauge = metrics.NewRegisteredGauge("arb/staker/amount_staked", nil) stakerLatestStakedNodeGauge = metrics.NewRegisteredGauge("arb/staker/staked_node", nil) + stakerLatestConfirmedNodeGauge = metrics.NewRegisteredGauge("arb/staker/confirmed_node", nil) stakerLastSuccessfulActionGauge = metrics.NewRegisteredGauge("arb/staker/action/last_success", nil) stakerActionSuccessCounter = metrics.NewRegisteredCounter("arb/staker/action/success", nil) stakerActionFailureCounter = metrics.NewRegisteredCounter("arb/staker/action/failure", nil) @@ -195,11 +197,16 @@ type LatestStakedNotifier interface { UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) } +type LatestConfirmedNotifier interface { + UpdateLatestConfirmed(count arbutil.MessageIndex, globalState validator.GoGlobalState) +} + type Staker struct { *L1Validator stopwaiter.StopWaiter l1Reader L1ReaderInterface - notifiers []LatestStakedNotifier + stakedNotifiers []LatestStakedNotifier + confirmedNotifiers []LatestConfirmedNotifier activeChallenge *ChallengeManager baseCallOpts bind.CallOpts config L1ValidatorConfig @@ -219,7 +226,8 @@ func NewStaker( config L1ValidatorConfig, blockValidator *BlockValidator, statelessBlockValidator *StatelessBlockValidator, - notifiers []LatestStakedNotifier, + stakedNotifiers []LatestStakedNotifier, + confirmedNotifiers []LatestConfirmedNotifier, validatorUtilsAddress common.Address, fatalErr chan<- error, ) (*Staker, error) { @@ -235,12 +243,13 @@ func NewStaker( } stakerLastSuccessfulActionGauge.Update(time.Now().Unix()) if config.StartFromStaked { - notifiers = append(notifiers, blockValidator) + stakedNotifiers = append(stakedNotifiers, blockValidator) } return &Staker{ L1Validator: val, l1Reader: l1Reader, - notifiers: notifiers, + stakedNotifiers: stakedNotifiers, + confirmedNotifiers: confirmedNotifiers, baseCallOpts: callOpts, config: config, highGasBlocksBuffer: big.NewInt(config.L1PostingStrategy.HighGasDelayBlocks), @@ -280,8 +289,42 @@ func (s *Staker) Initialize(ctx context.Context) error { return nil } +func (s *Staker) latestNodeDetailsForUpdate(ctx context.Context, description string, node uint64) (arbutil.MessageIndex, *validator.GoGlobalState, error) { + stakedInfo, err := s.rollup.LookupNode(ctx, node) + if err != nil { + return 0, nil, fmt.Errorf("couldn't look up latest %v assertion %v: %w", description, node, err) + } + + globalState := stakedInfo.AfterState().GlobalState + caughtUp, count, err := GlobalStateToMsgCount(s.inboxTracker, s.txStreamer, globalState) + if err != nil { + if errors.Is(err, ErrGlobalStateNotInChain) && s.fatalErr != nil { + fatal := fmt.Errorf("latest %v assertion %v not in chain: %w", description, node, err) + s.fatalErr <- fatal + } + return 0, nil, fmt.Errorf("latest %v assertion %v: %w", description, node, err) + } + + if !caughtUp { + log.Info(fmt.Sprintf("latest %v assertion not yet in our node", description), "assertion", node, "state", globalState) + return 0, nil, nil + } + + processedCount, err := s.txStreamer.GetProcessedMessageCount() + if err != nil { + return 0, nil, err + } + + if processedCount < count { + log.Info("execution catching up to rollup", "lookingFor", description, "rollupCount", count, "processedCount", processedCount) + return 0, nil, nil + } + + return count, &globalState, nil +} + func (s *Staker) checkLatestStaked(ctx context.Context) error { - latestStaked, _, err := s.validatorUtils.LatestStaked(&s.baseCallOpts, s.rollupAddress, s.wallet.AddressOrZero()) + latestStaked, _, err := s.validatorUtils.LatestStaked(s.getCallOpts(ctx), s.rollupAddress, s.wallet.AddressOrZero()) if err != nil { return fmt.Errorf("couldn't get LatestStaked: %w", err) } @@ -290,38 +333,44 @@ func (s *Staker) checkLatestStaked(ctx context.Context) error { return nil } - stakedInfo, err := s.rollup.LookupNode(ctx, latestStaked) + count, globalState, err := s.latestNodeDetailsForUpdate(ctx, "staked", latestStaked) if err != nil { - return fmt.Errorf("couldn't look up latest node: %w", err) + return err + } + if globalState == nil { + return nil } - stakedGlobalState := stakedInfo.AfterState().GlobalState - caughtUp, count, err := GlobalStateToMsgCount(s.inboxTracker, s.txStreamer, stakedGlobalState) - if err != nil { - if errors.Is(err, ErrGlobalStateNotInChain) && s.fatalErr != nil { - fatal := fmt.Errorf("latest staked not in chain: %w", err) - s.fatalErr <- fatal - } - return fmt.Errorf("staker: latest staked %w", err) + for _, notifier := range s.stakedNotifiers { + notifier.UpdateLatestStaked(count, *globalState) } + return nil +} - if !caughtUp { - log.Info("latest valid not yet in our node", "staked", stakedGlobalState) +func (s *Staker) checkLatestConfirmed(ctx context.Context) error { + callOpts := s.getCallOpts(ctx) + if s.l1Reader.UseFinalityData() { + callOpts.BlockNumber = big.NewInt(int64(rpc.FinalizedBlockNumber)) + } + latestConfirmed, err := s.rollup.LatestConfirmed(callOpts) + if err != nil { + return fmt.Errorf("couldn't get LatestConfirmed: %w", err) + } + stakerLatestConfirmedNodeGauge.Update(int64(latestConfirmed)) + if latestConfirmed == 0 { return nil } - processedCount, err := s.txStreamer.GetProcessedMessageCount() + count, globalState, err := s.latestNodeDetailsForUpdate(ctx, "confirmed", latestConfirmed) if err != nil { return err } - - if processedCount < count { - log.Info("execution catching up to last validated", "validatedCount", count, "processedCount", processedCount) + if globalState == nil { return nil } - for _, notifier := range s.notifiers { - notifier.UpdateLatestStaked(count, stakedGlobalState) + for _, notifier := range s.confirmedNotifiers { + notifier.UpdateLatestConfirmed(count, *globalState) } return nil } @@ -387,6 +436,13 @@ func (s *Staker) Start(ctxIn context.Context) { } return s.config.StakerInterval }) + s.CallIteratively(func(ctx context.Context) time.Duration { + err := s.checkLatestConfirmed(ctx) + if err != nil && ctx.Err() == nil { + log.Error("staker: error checking latest confirmed", "err", err) + } + return s.config.StakerInterval + }) } func (s *Staker) IsWhitelisted(ctx context.Context) (bool, error) { diff --git a/staker/stateless_block_validator.go b/staker/stateless_block_validator.go index 0242daa3c7..7add3e258d 100644 --- a/staker/stateless_block_validator.go +++ b/staker/stateless_block_validator.go @@ -84,6 +84,7 @@ type L1ReaderInterface interface { Client() arbutil.L1Interface Subscribe(bool) (<-chan *types.Header, func()) WaitForTxApproval(ctx context.Context, tx *types.Transaction) (*types.Receipt, error) + UseFinalityData() bool } type GlobalStatePosition struct { diff --git a/system_tests/staker_test.go b/system_tests/staker_test.go index 7a3ae41814..8a8ef29bf3 100644 --- a/system_tests/staker_test.go +++ b/system_tests/staker_test.go @@ -166,6 +166,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) nil, statelessA, nil, + nil, l2nodeA.DeployInfo.ValidatorUtils, nil, ) @@ -201,6 +202,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) nil, statelessB, nil, + nil, l2nodeB.DeployInfo.ValidatorUtils, nil, ) @@ -223,6 +225,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) nil, statelessA, nil, + nil, l2nodeA.DeployInfo.ValidatorUtils, nil, ) diff --git a/util/headerreader/header_reader.go b/util/headerreader/header_reader.go index f25fed02d9..1d52e8f78d 100644 --- a/util/headerreader/header_reader.go +++ b/util/headerreader/header_reader.go @@ -434,6 +434,10 @@ func (s *HeaderReader) Client() arbutil.L1Interface { return s.client } +func (s *HeaderReader) UseFinalityData() bool { + return s.config().UseFinalityData +} + func (s *HeaderReader) Start(ctxIn context.Context) { s.StopWaiter.Start(ctxIn, s) s.LaunchThread(s.broadcastLoop) From 65c79942b1303848e0e85d0481fb47f9cc0a9b51 Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Thu, 13 Jul 2023 21:14:12 -0600 Subject: [PATCH 3/3] Unify latest staked and latest confirmed checking --- staker/staker.go | 117 +++++++++++++++---------------------- staker/validator_wallet.go | 28 +++++---- 2 files changed, 63 insertions(+), 82 deletions(-) diff --git a/staker/staker.go b/staker/staker.go index f360a60a7d..230c381d05 100644 --- a/staker/staker.go +++ b/staker/staker.go @@ -242,7 +242,7 @@ func NewStaker( return nil, err } stakerLastSuccessfulActionGauge.Update(time.Now().Unix()) - if config.StartFromStaked { + if config.StartFromStaked && blockValidator != nil { stakedNotifiers = append(stakedNotifiers, blockValidator) } return &Staker{ @@ -289,90 +289,50 @@ func (s *Staker) Initialize(ctx context.Context) error { return nil } -func (s *Staker) latestNodeDetailsForUpdate(ctx context.Context, description string, node uint64) (arbutil.MessageIndex, *validator.GoGlobalState, error) { - stakedInfo, err := s.rollup.LookupNode(ctx, node) +func (s *Staker) getLatestStakedState(ctx context.Context, staker common.Address) (uint64, arbutil.MessageIndex, *validator.GoGlobalState, error) { + callOpts := s.getCallOpts(ctx) + if s.l1Reader.UseFinalityData() { + callOpts.BlockNumber = big.NewInt(int64(rpc.FinalizedBlockNumber)) + } + latestStaked, _, err := s.validatorUtils.LatestStaked(s.getCallOpts(ctx), s.rollupAddress, staker) + if err != nil { + return 0, 0, nil, fmt.Errorf("couldn't get LatestStaked(%v): %w", staker, err) + } + if latestStaked == 0 { + return latestStaked, 0, nil, nil + } + + stakedInfo, err := s.rollup.LookupNode(ctx, latestStaked) if err != nil { - return 0, nil, fmt.Errorf("couldn't look up latest %v assertion %v: %w", description, node, err) + return 0, 0, nil, fmt.Errorf("couldn't look up latest assertion of %v (%v): %w", staker, latestStaked, err) } globalState := stakedInfo.AfterState().GlobalState caughtUp, count, err := GlobalStateToMsgCount(s.inboxTracker, s.txStreamer, globalState) if err != nil { if errors.Is(err, ErrGlobalStateNotInChain) && s.fatalErr != nil { - fatal := fmt.Errorf("latest %v assertion %v not in chain: %w", description, node, err) + fatal := fmt.Errorf("latest assertion of %v (%v) not in chain: %w", staker, latestStaked, err) s.fatalErr <- fatal } - return 0, nil, fmt.Errorf("latest %v assertion %v: %w", description, node, err) + return 0, 0, nil, fmt.Errorf("latest assertion of %v (%v): %w", staker, latestStaked, err) } if !caughtUp { - log.Info(fmt.Sprintf("latest %v assertion not yet in our node", description), "assertion", node, "state", globalState) - return 0, nil, nil + log.Info("latest assertion not yet in our node", "staker", staker, "assertion", latestStaked, "state", globalState) + return latestStaked, 0, nil, nil } processedCount, err := s.txStreamer.GetProcessedMessageCount() if err != nil { - return 0, nil, err + return 0, 0, nil, err } if processedCount < count { - log.Info("execution catching up to rollup", "lookingFor", description, "rollupCount", count, "processedCount", processedCount) - return 0, nil, nil + log.Info("execution catching up to rollup", "staker", staker, "rollupCount", count, "processedCount", processedCount) + return latestStaked, 0, nil, nil } - return count, &globalState, nil -} - -func (s *Staker) checkLatestStaked(ctx context.Context) error { - latestStaked, _, err := s.validatorUtils.LatestStaked(s.getCallOpts(ctx), s.rollupAddress, s.wallet.AddressOrZero()) - if err != nil { - return fmt.Errorf("couldn't get LatestStaked: %w", err) - } - stakerLatestStakedNodeGauge.Update(int64(latestStaked)) - if latestStaked == 0 { - return nil - } - - count, globalState, err := s.latestNodeDetailsForUpdate(ctx, "staked", latestStaked) - if err != nil { - return err - } - if globalState == nil { - return nil - } - - for _, notifier := range s.stakedNotifiers { - notifier.UpdateLatestStaked(count, *globalState) - } - return nil -} - -func (s *Staker) checkLatestConfirmed(ctx context.Context) error { - callOpts := s.getCallOpts(ctx) - if s.l1Reader.UseFinalityData() { - callOpts.BlockNumber = big.NewInt(int64(rpc.FinalizedBlockNumber)) - } - latestConfirmed, err := s.rollup.LatestConfirmed(callOpts) - if err != nil { - return fmt.Errorf("couldn't get LatestConfirmed: %w", err) - } - stakerLatestConfirmedNodeGauge.Update(int64(latestConfirmed)) - if latestConfirmed == 0 { - return nil - } - - count, globalState, err := s.latestNodeDetailsForUpdate(ctx, "confirmed", latestConfirmed) - if err != nil { - return err - } - if globalState == nil { - return nil - } - - for _, notifier := range s.confirmedNotifiers { - notifier.UpdateLatestConfirmed(count, *globalState) - } - return nil + return latestStaked, count, &globalState, nil } func (s *Staker) Start(ctxIn context.Context) { @@ -430,16 +390,31 @@ func (s *Staker) Start(ctxIn context.Context) { return backoff }) s.CallIteratively(func(ctx context.Context) time.Duration { - err := s.checkLatestStaked(ctx) + wallet := s.wallet.AddressOrZero() + staked, stakedMsgCount, stakedGlobalState, err := s.getLatestStakedState(ctx, wallet) if err != nil && ctx.Err() == nil { log.Error("staker: error checking latest staked", "err", err) } - return s.config.StakerInterval - }) - s.CallIteratively(func(ctx context.Context) time.Duration { - err := s.checkLatestConfirmed(ctx) - if err != nil && ctx.Err() == nil { - log.Error("staker: error checking latest confirmed", "err", err) + stakerLatestStakedNodeGauge.Update(int64(staked)) + if stakedGlobalState != nil { + for _, notifier := range s.stakedNotifiers { + notifier.UpdateLatestStaked(stakedMsgCount, *stakedGlobalState) + } + } + confirmed := staked + confirmedMsgCount := stakedMsgCount + confirmedGlobalState := stakedGlobalState + if wallet != (common.Address{}) { + confirmed, confirmedMsgCount, confirmedGlobalState, err = s.getLatestStakedState(ctx, common.Address{}) + if err != nil && ctx.Err() == nil { + log.Error("staker: error checking latest confirmed", "err", err) + } + } + stakerLatestConfirmedNodeGauge.Update(int64(confirmed)) + if confirmedGlobalState != nil { + for _, notifier := range s.confirmedNotifiers { + notifier.UpdateLatestConfirmed(confirmedMsgCount, *confirmedGlobalState) + } } return s.config.StakerInterval }) diff --git a/staker/validator_wallet.go b/staker/validator_wallet.go index c36efa7b61..d878749f35 100644 --- a/staker/validator_wallet.go +++ b/staker/validator_wallet.go @@ -8,6 +8,7 @@ import ( "errors" "math/big" "strings" + "sync/atomic" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts/abi" @@ -38,7 +39,9 @@ func init() { type ValidatorWalletInterface interface { Initialize(context.Context) error + // Address must be able to be called concurrently with other functions Address() *common.Address + // Address must be able to be called concurrently with other functions AddressOrZero() common.Address TxSenderAddress() *common.Address RollupAddress() common.Address @@ -53,7 +56,7 @@ type ValidatorWalletInterface interface { type ContractValidatorWallet struct { con *rollupgen.ValidatorWallet - address *common.Address + address atomic.Pointer[common.Address] onWalletCreated func(common.Address) l1Reader L1ReaderInterface auth *bind.TransactOpts @@ -79,9 +82,8 @@ func NewContractValidatorWallet(address *common.Address, walletFactoryAddr, roll if err != nil { return nil, err } - return &ContractValidatorWallet{ + wallet := &ContractValidatorWallet{ con: con, - address: address, onWalletCreated: onWalletCreated, l1Reader: l1Reader, auth: auth, @@ -89,7 +91,10 @@ func NewContractValidatorWallet(address *common.Address, walletFactoryAddr, roll rollupAddress: rollupAddress, rollup: rollup, rollupFromBlock: rollupFromBlock, - }, nil + } + // Go complains if we make an address variable before wallet and copy it in + wallet.address.Store(address) + return wallet, nil } func (v *ContractValidatorWallet) validateWallet(ctx context.Context) error { @@ -127,15 +132,16 @@ func (v *ContractValidatorWallet) Initialize(ctx context.Context) error { // May be the nil if the wallet hasn't been deployed yet func (v *ContractValidatorWallet) Address() *common.Address { - return v.address + return v.address.Load() } // May be zero if the wallet hasn't been deployed yet func (v *ContractValidatorWallet) AddressOrZero() common.Address { - if v.address == nil { + addr := v.address.Load() + if addr == nil { return common.Address{} } - return *v.address + return *addr } func (v *ContractValidatorWallet) TxSenderAddress() *common.Address { @@ -183,7 +189,7 @@ func (v *ContractValidatorWallet) populateWallet(ctx context.Context, createIfMi } return nil } - if v.address == nil { + if v.address.Load() == nil { auth, err := v.getAuth(ctx, nil) if err != nil { return err @@ -195,12 +201,12 @@ func (v *ContractValidatorWallet) populateWallet(ctx context.Context, createIfMi if addr == nil { return nil } - v.address = addr + v.address.Store(addr) if v.onWalletCreated != nil { v.onWalletCreated(*addr) } } - con, err := rollupgen.NewValidatorWallet(*v.address, v.l1Reader.Client()) + con, err := rollupgen.NewValidatorWallet(*v.Address(), v.l1Reader.Client()) if err != nil { return err } @@ -260,7 +266,7 @@ func (v *ContractValidatorWallet) ExecuteTransactions(ctx context.Context, build totalAmount = totalAmount.Add(totalAmount, tx.Value()) } - balanceInContract, err := v.l1Reader.Client().BalanceAt(ctx, *v.address, nil) + balanceInContract, err := v.l1Reader.Client().BalanceAt(ctx, *v.Address(), nil) if err != nil { return nil, err }