diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e1a591b720..c77ef3b770 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -132,8 +132,8 @@ jobs: - name: Set environment variables run: | - mkdir -p target/tmp - echo "TMPDIR=$(pwd)/target/tmp" >> "$GITHUB_ENV" + mkdir -p target/tmp/deadbeefbee + echo "TMPDIR=$(pwd)/target/tmp/deadbeefbee" >> "$GITHUB_ENV" echo "GOMEMLIMIT=6GiB" >> "$GITHUB_ENV" echo "GOGC=80" >> "$GITHUB_ENV" diff --git a/Dockerfile b/Dockerfile index 12010d5dac..367d76d4b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -160,6 +160,7 @@ COPY ./scripts/download-machine.sh . #RUN ./download-machine.sh consensus-v9 0xd1842bfbe047322b3f3b3635b5fe62eb611557784d17ac1d2b1ce9c170af6544 RUN ./download-machine.sh consensus-v10 0x6b94a7fc388fd8ef3def759297828dc311761e88d8179c7ee8d3887dc554f3c3 RUN ./download-machine.sh consensus-v10.1 0xda4e3ad5e7feacb817c21c8d0220da7650fe9051ece68a3f0b1c5d38bbb27b21 +RUN ./download-machine.sh consensus-v10.2 0x0754e09320c381566cc0449904c377a52bd34a6b9404432e80afd573b67f7b17 FROM golang:1.20-bullseye as node-builder WORKDIR /workspace diff --git a/arbnode/api.go b/arbnode/api.go index 057c03bf31..d28d7481d9 100644 --- a/arbnode/api.go +++ b/arbnode/api.go @@ -9,23 +9,17 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/staker" + "github.com/offchainlabs/nitro/validator" ) type BlockValidatorAPI struct { val *staker.BlockValidator } -func (a *BlockValidatorAPI) LatestValidatedBlock(ctx context.Context) (hexutil.Uint64, error) { - block := a.val.LastBlockValidated() - return hexutil.Uint64(block), nil -} - -func (a *BlockValidatorAPI) LatestValidatedBlockHash(ctx context.Context) (common.Hash, error) { - _, hash, _ := a.val.LastBlockValidatedAndHash() - return hash, nil +func (a *BlockValidatorAPI) LatestValidated(ctx context.Context) (*staker.GlobalStateValidatedInfo, error) { + return a.val.ReadLastValidatedInfo() } type BlockValidatorDebugAPI struct { @@ -34,25 +28,16 @@ type BlockValidatorDebugAPI struct { } type ValidateBlockResult struct { - Valid bool `json:"valid"` - Latency string `json:"latency"` + Valid bool `json:"valid"` + Latency string `json:"latency"` + GlobalState validator.GoGlobalState `json:"globalstate"` } -func (a *BlockValidatorDebugAPI) ValidateBlock( - ctx context.Context, blockNum rpc.BlockNumber, full bool, moduleRootOptional *common.Hash, +func (a *BlockValidatorDebugAPI) ValidateMessageNumber( + ctx context.Context, msgNum hexutil.Uint64, full bool, moduleRootOptional *common.Hash, ) (ValidateBlockResult, error) { result := ValidateBlockResult{} - if blockNum < 0 { - return result, errors.New("this method only accepts absolute block numbers") - } - header := a.blockchain.GetHeaderByNumber(uint64(blockNum)) - if header == nil { - return result, errors.New("block not found") - } - if !a.blockchain.Config().IsArbitrumNitro(header.Number) { - return result, types.ErrUseFallback - } var moduleRoot common.Hash if moduleRootOptional != nil { moduleRoot = *moduleRootOptional @@ -64,8 +49,11 @@ func (a *BlockValidatorDebugAPI) ValidateBlock( moduleRoot = moduleRoots[0] } start_time := time.Now() - valid, err := a.val.ValidateBlock(ctx, header, full, moduleRoot) - result.Valid = valid + valid, gs, err := a.val.ValidateResult(ctx, 
arbutil.MessageIndex(msgNum), full, moduleRoot) result.Latency = fmt.Sprintf("%vms", time.Since(start_time).Milliseconds()) + if gs != nil { + result.GlobalState = *gs + } + result.Valid = valid return result, err } diff --git a/arbnode/batch_poster.go b/arbnode/batch_poster.go index 3e5e6a738f..5aa07f5157 100644 --- a/arbnode/batch_poster.go +++ b/arbnode/batch_poster.go @@ -10,21 +10,24 @@ import ( "errors" "fmt" "math/big" + "sync/atomic" "time" "github.com/andybalholm/brotli" - flag "github.com/spf13/pflag" + "github.com/spf13/pflag" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" - "github.com/offchainlabs/nitro/arbnode/dataposter" + "github.com/offchainlabs/nitro/arbnode/dataposter/storage" "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbstate" "github.com/offchainlabs/nitro/arbutil" @@ -66,6 +69,8 @@ type BatchPoster struct { redisLock *SimpleRedisLock firstAccErr time.Time // first time a continuous missing accumulator occurred backlog uint64 // An estimate of the number of unposted batches + + batchReverted atomic.Bool // indicates whether data poster batch was reverted } type BatchPosterConfig struct { @@ -101,7 +106,7 @@ func (c *BatchPosterConfig) Validate() error { type BatchPosterConfigFetcher func() *BatchPosterConfig -func BatchPosterConfigAddOptions(prefix string, f *flag.FlagSet) { +func BatchPosterConfigAddOptions(prefix string, f *pflag.FlagSet) { f.Bool(prefix+".enable", DefaultBatchPosterConfig.Enable, "enable posting batches to l1") f.Bool(prefix+".disable-das-fallback-store-data-on-chain", DefaultBatchPosterConfig.DisableDasFallbackStoreDataOnChain, "If unable to batch to DAS, disable fallback storing data on chain") f.Int(prefix+".max-size", DefaultBatchPosterConfig.MaxBatchSize, "maximum batch size") @@ -158,7 +163,7 @@ var TestBatchPosterConfig = BatchPosterConfig{ L1Wallet: DefaultBatchPosterL1WalletConfig, } -func NewBatchPoster(l1Reader *headerreader.HeaderReader, inbox *InboxTracker, streamer *TransactionStreamer, syncMonitor *SyncMonitor, config BatchPosterConfigFetcher, deployInfo *chaininfo.RollupAddresses, transactOpts *bind.TransactOpts, daWriter das.DataAvailabilityServiceWriter) (*BatchPoster, error) { +func NewBatchPoster(dataPosterDB ethdb.Database, l1Reader *headerreader.HeaderReader, inbox *InboxTracker, streamer *TransactionStreamer, syncMonitor *SyncMonitor, config BatchPosterConfigFetcher, deployInfo *chaininfo.RollupAddresses, transactOpts *bind.TransactOpts, daWriter das.DataAvailabilityServiceWriter) (*BatchPoster, error) { seqInbox, err := bridgegen.NewSequencerInbox(deployInfo.SequencerInbox, l1Reader.Client()) if err != nil { return nil, err @@ -201,13 +206,91 @@ func NewBatchPoster(l1Reader *headerreader.HeaderReader, inbox *InboxTracker, st dataPosterConfigFetcher := func() *dataposter.DataPosterConfig { return &config().DataPoster } - b.dataPoster, err = dataposter.NewDataPoster(l1Reader, transactOpts, redisClient, redisLock, dataPosterConfigFetcher, b.getBatchPosterPosition) + b.dataPoster, err = dataposter.NewDataPoster(dataPosterDB, l1Reader, transactOpts, redisClient, redisLock, dataPosterConfigFetcher, b.getBatchPosterPosition) if 
err != nil { return nil, err } return b, nil } +// checkReverts checks whether the blocks with numbers in range [from, to] +// contain a reverted batch poster transaction. +func (b *BatchPoster) checkReverts(ctx context.Context, from, to int64) (bool, error) { + if from > to { + return false, fmt.Errorf("invalid range: from %d is greater than to %d", from, to) + } + for idx := from; idx <= to; idx++ { + number := big.NewInt(idx) + block, err := b.l1Reader.Client().BlockByNumber(ctx, number) + if err != nil { + return false, fmt.Errorf("getting block %v by number: %w", number, err) + } + for idx, tx := range block.Transactions() { + from, err := b.l1Reader.Client().TransactionSender(ctx, tx, block.Hash(), uint(idx)) + if err != nil { + return false, fmt.Errorf("getting sender of transaction %v: %w", tx.Hash(), err) + } + if bytes.Equal(from.Bytes(), b.dataPoster.From().Bytes()) { + r, err := b.l1Reader.Client().TransactionReceipt(ctx, tx.Hash()) + if err != nil { + return false, fmt.Errorf("getting a receipt for transaction %v: %w", tx.Hash(), err) + } + if r.Status == types.ReceiptStatusFailed { + log.Error("Transaction from batch poster reverted", "nonce", tx.Nonce(), "txHash", tx.Hash(), "blockNumber", r.BlockNumber, "blockHash", r.BlockHash) + return true, nil + } + } + } + } + return false, nil +} + +// pollForReverts runs in a goroutine that listens to L1 block headers and checks +// whether any transaction made by the batch poster was reverted. +func (b *BatchPoster) pollForReverts(ctx context.Context) { + headerCh, unsubscribe := b.l1Reader.Subscribe(false) + defer unsubscribe() + + last := int64(0) // number of last seen block + for { + // Poll until: + // - the L1 headers reader channel is closed, or + // - the context is cancelled, or + // - we see a reverted transaction from the data poster in a block. + select { + case h, closed := <-headerCh: + if closed { + log.Info("L1 headers channel has been closed") + return + } + // If this is the first block header seen, set last to number-1. + // We may see the same block number again if there is an L1 reorg; in + // that case we check the block again.
+ if last == 0 || last == h.Number.Int64() { + last = h.Number.Int64() - 1 + } + if h.Number.Int64()-last > 100 { + log.Warn("Large gap between last seen and current block number, skipping check for reverts", "last", last, "current", h.Number) + last = h.Number.Int64() + continue + } + + reverted, err := b.checkReverts(ctx, last+1, h.Number.Int64()) + if err != nil { + log.Error("Checking batch reverts", "error", err) + continue + } + if reverted { + b.batchReverted.Store(true) + return + } + last = h.Number.Int64() + case <-ctx.Done(): + return + } + } +} + func (b *BatchPoster) getBatchPosterPosition(ctx context.Context, blockNum *big.Int) (batchPosterPosition, error) { bigInboxBatchCount, err := b.seqInbox.BatchCount(&bind.CallOpts{Context: ctx, BlockNumber: blockNum}) if err != nil { @@ -554,6 +637,9 @@ func (b *BatchPoster) estimateGas(ctx context.Context, sequencerMessage []byte, } func (b *BatchPoster) maybePostSequencerBatch(ctx context.Context) (bool, error) { + if b.batchReverted.Load() { + return false, fmt.Errorf("batch was reverted, not posting any more batches") + } nonce, batchPosition, err := b.dataPoster.GetNextNonceAndMeta(ctx) if err != nil { return false, err @@ -636,7 +722,7 @@ func (b *BatchPoster) maybePostSequencerBatch(ctx context.Context) (bool, error) cert, err := b.daWriter.Store(ctx, sequencerMsg, uint64(time.Now().Add(config.DASRetentionPeriod).Unix()), []byte{}) // b.daWriter will append signature if enabled if errors.Is(err, das.BatchToDasFailed) { if config.DisableDasFallbackStoreDataOnChain { - return false, errors.New("Unable to batch to DAS and fallback storing data on chain is disabled") + return false, errors.New("unable to batch to DAS and fallback storing data on chain is disabled") } log.Warn("Falling back to storing data on chain", "err", err) } else if err != nil { @@ -697,6 +783,7 @@ func (b *BatchPoster) Start(ctxIn context.Context) { b.dataPoster.Start(ctxIn) b.redisLock.Start(ctxIn) b.StopWaiter.Start(ctxIn, b) + b.LaunchThread(b.pollForReverts) b.CallIteratively(func(ctx context.Context) time.Duration { var err error if common.HexToAddress(b.config().GasRefunderAddress) != (common.Address{}) { @@ -723,7 +810,7 @@ func (b *BatchPoster) Start(ctxIn context.Context) { if err != nil { b.building = nil logLevel := log.Error - if errors.Is(err, AccumulatorNotFoundErr) || errors.Is(err, dataposter.ErrStorageRace) { + if errors.Is(err, AccumulatorNotFoundErr) || errors.Is(err, storage.ErrStorageRace) { // Likely the inbox tracker just isn't caught up. // Let's see if this error disappears naturally. 
if b.firstAccErr == (time.Time{}) { diff --git a/arbnode/dataposter/data_poster.go b/arbnode/dataposter/data_poster.go index ff0dcfebcf..1dec6ad0c4 100644 --- a/arbnode/dataposter/data_poster.go +++ b/arbnode/dataposter/data_poster.go @@ -15,16 +15,22 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" "github.com/go-redis/redis/v8" + "github.com/offchainlabs/nitro/arbnode/dataposter/leveldb" + "github.com/offchainlabs/nitro/arbnode/dataposter/slice" + "github.com/offchainlabs/nitro/arbnode/dataposter/storage" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/headerreader" "github.com/offchainlabs/nitro/util/signature" "github.com/offchainlabs/nitro/util/stopwaiter" - flag "github.com/spf13/pflag" + "github.com/spf13/pflag" + + redisstorage "github.com/offchainlabs/nitro/arbnode/dataposter/redis" ) type queuedTransaction[Meta any] struct { @@ -36,6 +42,8 @@ type queuedTransaction[Meta any] struct { NextReplacement time.Time } +// Note: one of the implementation of this interface (Redis storage) does not +// support duplicate values. type QueueStorage[Item any] interface { GetContents(ctx context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) GetLast(ctx context.Context) (*Item, error) @@ -55,11 +63,12 @@ type DataPosterConfig struct { UrgencyGwei float64 `koanf:"urgency-gwei" reload:"hot"` MinFeeCapGwei float64 `koanf:"min-fee-cap-gwei" reload:"hot"` MinTipCapGwei float64 `koanf:"min-tip-cap-gwei" reload:"hot"` + EnableLevelDB bool `koanf:"enable-leveldb" reload:"hot"` } type DataPosterConfigFetcher func() *DataPosterConfig -func DataPosterConfigAddOptions(prefix string, f *flag.FlagSet) { +func DataPosterConfigAddOptions(prefix string, f *pflag.FlagSet) { f.String(prefix+".replacement-times", DefaultDataPosterConfig.ReplacementTimes, "comma-separated list of durations since first posting to attempt a replace-by-fee") f.Bool(prefix+".wait-for-l1-finality", DefaultDataPosterConfig.WaitForL1Finality, "only treat a transaction as confirmed after L1 finality has been achieved (recommended)") f.Uint64(prefix+".max-mempool-transactions", DefaultDataPosterConfig.MaxMempoolTransactions, "the maximum number of transactions to have queued in the mempool at once (0 = unlimited)") @@ -68,6 +77,7 @@ func DataPosterConfigAddOptions(prefix string, f *flag.FlagSet) { f.Float64(prefix+".urgency-gwei", DefaultDataPosterConfig.UrgencyGwei, "the urgency to use for maximum fee cap calculation") f.Float64(prefix+".min-fee-cap-gwei", DefaultDataPosterConfig.MinFeeCapGwei, "the minimum fee cap to post transactions at") f.Float64(prefix+".min-tip-cap-gwei", DefaultDataPosterConfig.MinTipCapGwei, "the minimum tip cap to post transactions at") + f.Bool(prefix+".enable-leveldb", DefaultDataPosterConfig.EnableLevelDB, "uses leveldb when enabled") signature.SimpleHmacConfigAddOptions(prefix+".redis-signer", f) } @@ -78,6 +88,7 @@ var DefaultDataPosterConfig = DataPosterConfig{ UrgencyGwei: 2., MaxMempoolTransactions: 64, MinTipCapGwei: 0.05, + EnableLevelDB: false, } var TestDataPosterConfig = DataPosterConfig{ @@ -88,6 +99,7 @@ var TestDataPosterConfig = DataPosterConfig{ UrgencyGwei: 2., MaxMempoolTransactions: 64, MinTipCapGwei: 0.05, + EnableLevelDB: false, } // DataPoster 
must be RLP serializable and deserializable @@ -114,7 +126,7 @@ type AttemptLocker interface { AttemptLock(context.Context) bool } -func NewDataPoster[Meta any](headerReader *headerreader.HeaderReader, auth *bind.TransactOpts, redisClient redis.UniversalClient, redisLock AttemptLocker, config DataPosterConfigFetcher, metadataRetriever func(ctx context.Context, blockNum *big.Int) (Meta, error)) (*DataPoster[Meta], error) { +func NewDataPoster[Meta any](db ethdb.Database, headerReader *headerreader.HeaderReader, auth *bind.TransactOpts, redisClient redis.UniversalClient, redisLock AttemptLocker, config DataPosterConfigFetcher, metadataRetriever func(ctx context.Context, blockNum *big.Int) (Meta, error)) (*DataPoster[Meta], error) { var replacementTimes []time.Duration var lastReplacementTime time.Duration for _, s := range strings.Split(config().ReplacementTimes, ",") { @@ -134,11 +146,14 @@ func NewDataPoster[Meta any](headerReader *headerreader.HeaderReader, auth *bind // To avoid special casing "don't replace again", replace in 10 years replacementTimes = append(replacementTimes, time.Hour*24*365*10) var queue QueueStorage[queuedTransaction[Meta]] - if redisClient == nil { - queue = NewSliceStorage[queuedTransaction[Meta]]() - } else { + switch { + case config().EnableLevelDB: + queue = leveldb.New[queuedTransaction[Meta]](db) + case redisClient == nil: + queue = slice.NewStorage[queuedTransaction[Meta]]() + default: var err error - queue, err = NewRedisStorage[queuedTransaction[Meta]](redisClient, "data-poster.queue", &config().RedisSigner) + queue, err = redisstorage.NewStorage[queuedTransaction[Meta]](redisClient, "data-poster.queue", &config().RedisSigner) if err != nil { return nil, err } @@ -460,7 +475,7 @@ func (p *DataPoster[Meta]) maybeLogError(err error, tx *queuedTransaction[Meta], return } logLevel := log.Error - if errors.Is(err, ErrStorageRace) { + if errors.Is(err, storage.ErrStorageRace) { p.errorCount[nonce]++ if p.errorCount[nonce] <= maxConsecutiveIntermittentErrors { logLevel = log.Debug diff --git a/arbnode/dataposter/leveldb/leveldb.go b/arbnode/dataposter/leveldb/leveldb.go new file mode 100644 index 0000000000..c271b71267 --- /dev/null +++ b/arbnode/dataposter/leveldb/leveldb.go @@ -0,0 +1,181 @@ +// Copyright 2021-2023, Offchain Labs, Inc. +// For license information, see https://github.com/nitro/blob/master/LICENSE + +package leveldb + +import ( + "bytes" + "context" + "errors" + "fmt" + "strconv" + + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" + "github.com/syndtr/goleveldb/leveldb" +) + +// Storage implements leveldb based storage for batch poster. +type Storage[Item any] struct { + db ethdb.Database +} + +var ( + // Value at this index holds the *index* of last item. + // Keys that we never want to be accidentally deleted by "Prune()" should be + // lexicographically less than minimum index (that is "0"), hence the prefix + // ".". 
+ lastItemIdxKey = []byte(".last_item_idx_key") + countKey = []byte(".count_key") +) + +func New[Item any](db ethdb.Database) *Storage[Item] { + return &Storage[Item]{db: db} +} + +func (s *Storage[Item]) decodeItem(data []byte) (*Item, error) { + var item Item + if err := rlp.DecodeBytes(data, &item); err != nil { + return nil, fmt.Errorf("decoding item: %w", err) + } + return &item, nil +} + +func idxToKey(idx uint64) []byte { + return []byte(fmt.Sprintf("%020d", idx)) +} + +func (s *Storage[Item]) GetContents(_ context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) { + var res []*Item + it := s.db.NewIterator([]byte(""), idxToKey(startingIndex)) + defer it.Release() + for i := 0; i < int(maxResults); i++ { + if !it.Next() { + break + } + item, err := s.decodeItem(it.Value()) + if err != nil { + return nil, err + } + res = append(res, item) + } + return res, it.Error() +} + +func (s *Storage[Item]) lastItemIdx(context.Context) ([]byte, error) { + return s.db.Get(lastItemIdxKey) +} + +func (s *Storage[Item]) GetLast(ctx context.Context) (*Item, error) { + size, err := s.Length(ctx) + if err != nil { + return nil, err + } + if size == 0 { + return nil, nil + } + lastItemIdx, err := s.lastItemIdx(ctx) + if err != nil { + return nil, fmt.Errorf("getting last item index: %w", err) + } + val, err := s.db.Get(lastItemIdx) + if err != nil { + return nil, err + } + return s.decodeItem(val) +} + +func (s *Storage[Item]) Prune(ctx context.Context, keepStartingAt uint64) error { + cnt, err := s.Length(ctx) + if err != nil { + return err + } + end := idxToKey(keepStartingAt) + it := s.db.NewIterator([]byte{}, idxToKey(0)) + defer it.Release() + b := s.db.NewBatch() + for it.Next() { + if bytes.Compare(it.Key(), end) >= 0 { + break + } + if err := b.Delete(it.Key()); err != nil { + return fmt.Errorf("deleting key: %w", err) + } + cnt-- + } + if err := b.Put(countKey, []byte(strconv.Itoa(cnt))); err != nil { + return fmt.Errorf("updating length counter: %w", err) + } + return b.Write() +} + +// valueAt gets returns the value at key. If it doesn't exist then it returns +// encoded bytes of nil. 
+func (s *Storage[Item]) valueAt(_ context.Context, key []byte) ([]byte, error) { + val, err := s.db.Get(key) + if err != nil { + if errors.Is(err, leveldb.ErrNotFound) { + return rlp.EncodeToBytes((*Item)(nil)) + } + return nil, err + } + return val, nil +} + +func (s *Storage[Item]) Put(ctx context.Context, index uint64, prev *Item, new *Item) error { + key := idxToKey(index) + stored, err := s.valueAt(ctx, key) + if err != nil { + return err + } + prevEnc, err := rlp.EncodeToBytes(prev) + if err != nil { + return fmt.Errorf("encoding previous item: %w", err) + } + if !bytes.Equal(stored, prevEnc) { + return fmt.Errorf("replacing different item than expected at index %v %v %v", index, stored, prevEnc) + } + newEnc, err := rlp.EncodeToBytes(new) + if err != nil { + return fmt.Errorf("encoding new item: %w", err) + } + b := s.db.NewBatch() + cnt, err := s.Length(ctx) + if err != nil { + return err + } + if err := b.Put(key, newEnc); err != nil { + return fmt.Errorf("updating value at: %v: %w", key, err) + } + lastItemIdx, err := s.lastItemIdx(ctx) + if err != nil && !errors.Is(err, leveldb.ErrNotFound) { + return err + } + if errors.Is(err, leveldb.ErrNotFound) { + lastItemIdx = []byte{} + } + if cnt == 0 || bytes.Compare(key, lastItemIdx) > 0 { + if err := b.Put(lastItemIdxKey, key); err != nil { + return fmt.Errorf("updating last item: %w", err) + } + if err := b.Put(countKey, []byte(strconv.Itoa(cnt+1))); err != nil { + return fmt.Errorf("updating length counter: %w", err) + } + } + return b.Write() +} + +func (s *Storage[Item]) Length(context.Context) (int, error) { + val, err := s.db.Get(countKey) + if err != nil { + if errors.Is(err, leveldb.ErrNotFound) { + return 0, nil + } + return 0, err + } + return strconv.Atoi(string(val)) +} + +func (s *Storage[Item]) IsPersistent() bool { + return true +} diff --git a/arbnode/dataposter/redis_storage.go b/arbnode/dataposter/redis/redisstorage.go similarity index 76% rename from arbnode/dataposter/redis_storage.go rename to arbnode/dataposter/redis/redisstorage.go index df3e894539..f7aed00e59 100644 --- a/arbnode/dataposter/redis_storage.go +++ b/arbnode/dataposter/redis/redisstorage.go @@ -1,7 +1,7 @@ // Copyright 2021-2022, Offchain Labs, Inc. // For license information, see https://github.com/nitro/blob/master/LICENSE -package dataposter +package redis import ( "bytes" @@ -11,22 +11,26 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/go-redis/redis/v8" + "github.com/offchainlabs/nitro/arbnode/dataposter/storage" "github.com/offchainlabs/nitro/util/signature" ) -// RedisStorage requires that Item is RLP encodable/decodable -type RedisStorage[Item any] struct { +// Storage implements redis sorted set backed storage. It does not support +// duplicate keys or values. That is, putting the same element on different +// indexes will not yield expected behavior. +// More at: https://redis.io/commands/zadd/. 
+type Storage[Item any] struct { client redis.UniversalClient signer *signature.SimpleHmac key string } -func NewRedisStorage[Item any](client redis.UniversalClient, key string, signerConf *signature.SimpleHmacConfig) (*RedisStorage[Item], error) { +func NewStorage[Item any](client redis.UniversalClient, key string, signerConf *signature.SimpleHmacConfig) (*Storage[Item], error) { signer, err := signature.NewSimpleHmac(signerConf) if err != nil { return nil, err } - return &RedisStorage[Item]{client, signer, key}, nil + return &Storage[Item]{client, signer, key}, nil } func joinHmacMsg(msg []byte, sig []byte) ([]byte, error) { @@ -36,7 +40,7 @@ func joinHmacMsg(msg []byte, sig []byte) ([]byte, error) { return append(sig, msg...), nil } -func (s *RedisStorage[Item]) peelVerifySignature(data []byte) ([]byte, error) { +func (s *Storage[Item]) peelVerifySignature(data []byte) ([]byte, error) { if len(data) < 32 { return nil, errors.New("data is too short to contain message signature") } @@ -48,7 +52,7 @@ func (s *RedisStorage[Item]) peelVerifySignature(data []byte) ([]byte, error) { return data[32:], nil } -func (s *RedisStorage[Item]) GetContents(ctx context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) { +func (s *Storage[Item]) GetContents(ctx context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) { query := redis.ZRangeArgs{ Key: s.key, ByScore: true, @@ -75,7 +79,7 @@ func (s *RedisStorage[Item]) GetContents(ctx context.Context, startingIndex uint return items, nil } -func (s *RedisStorage[Item]) GetLast(ctx context.Context) (*Item, error) { +func (s *Storage[Item]) GetLast(ctx context.Context) (*Item, error) { query := redis.ZRangeArgs{ Key: s.key, Start: 0, @@ -105,16 +109,14 @@ func (s *RedisStorage[Item]) GetLast(ctx context.Context) (*Item, error) { return ret, nil } -func (s *RedisStorage[Item]) Prune(ctx context.Context, keepStartingAt uint64) error { +func (s *Storage[Item]) Prune(ctx context.Context, keepStartingAt uint64) error { if keepStartingAt > 0 { return s.client.ZRemRangeByScore(ctx, s.key, "-inf", fmt.Sprintf("%v", keepStartingAt-1)).Err() } return nil } -var ErrStorageRace = errors.New("storage race error") - -func (s *RedisStorage[Item]) Put(ctx context.Context, index uint64, prevItem *Item, newItem *Item) error { +func (s *Storage[Item]) Put(ctx context.Context, index uint64, prevItem *Item, newItem *Item) error { if newItem == nil { return fmt.Errorf("tried to insert nil item at index %v", index) } @@ -132,11 +134,11 @@ func (s *RedisStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It pipe := tx.TxPipeline() if len(haveItems) == 0 { if prevItem != nil { - return fmt.Errorf("%w: tried to replace item at index %v but no item exists there", ErrStorageRace, index) + return fmt.Errorf("%w: tried to replace item at index %v but no item exists there", storage.ErrStorageRace, index) } } else if len(haveItems) == 1 { if prevItem == nil { - return fmt.Errorf("%w: tried to insert new item at index %v but an item exists there", ErrStorageRace, index) + return fmt.Errorf("%w: tried to insert new item at index %v but an item exists there", storage.ErrStorageRace, index) } verifiedItem, err := s.peelVerifySignature([]byte(haveItems[0])) if err != nil { @@ -147,7 +149,7 @@ func (s *RedisStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It return err } if !bytes.Equal(verifiedItem, prevItemEncoded) { - return fmt.Errorf("%w: replacing different item than expected at index %v", ErrStorageRace, index) + return 
fmt.Errorf("%w: replacing different item than expected at index %v", storage.ErrStorageRace, index) } err = pipe.ZRem(ctx, s.key, haveItems[0]).Err() if err != nil { @@ -179,7 +181,7 @@ func (s *RedisStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It if errors.Is(err, redis.TxFailedErr) { // Unfortunately, we can't wrap two errors. //nolint:errorlint - err = fmt.Errorf("%w: %v", ErrStorageRace, err.Error()) + err = fmt.Errorf("%w: %v", storage.ErrStorageRace, err.Error()) } return err } @@ -187,7 +189,7 @@ func (s *RedisStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It return s.client.Watch(ctx, action, s.key) } -func (s *RedisStorage[Item]) Length(ctx context.Context) (int, error) { +func (s *Storage[Item]) Length(ctx context.Context) (int, error) { count, err := s.client.ZCount(ctx, s.key, "-inf", "+inf").Result() if err != nil { return 0, err @@ -195,6 +197,6 @@ func (s *RedisStorage[Item]) Length(ctx context.Context) (int, error) { return int(count), nil } -func (s *RedisStorage[Item]) IsPersistent() bool { +func (s *Storage[Item]) IsPersistent() bool { return true } diff --git a/arbnode/dataposter/slice_storage.go b/arbnode/dataposter/slice/slicestorage.go similarity index 63% rename from arbnode/dataposter/slice_storage.go rename to arbnode/dataposter/slice/slicestorage.go index 4364523d99..b0a253086f 100644 --- a/arbnode/dataposter/slice_storage.go +++ b/arbnode/dataposter/slice/slicestorage.go @@ -1,28 +1,30 @@ // Copyright 2021-2022, Offchain Labs, Inc. // For license information, see https://github.com/nitro/blob/master/LICENSE -package dataposter +package slice import ( "context" "errors" "fmt" + "reflect" ) -type SliceStorage[Item any] struct { +type Storage[Item any] struct { firstNonce uint64 queue []*Item } -func NewSliceStorage[Item any]() *SliceStorage[Item] { - return &SliceStorage[Item]{} +func NewStorage[Item any]() *Storage[Item] { + return &Storage[Item]{} } -func (s *SliceStorage[Item]) GetContents(ctx context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) { +func (s *Storage[Item]) GetContents(_ context.Context, startingIndex uint64, maxResults uint64) ([]*Item, error) { ret := s.queue - if startingIndex >= s.firstNonce+uint64(len(s.queue)) { - ret = nil - } else if startingIndex > s.firstNonce { + if startingIndex >= s.firstNonce+uint64(len(s.queue)) || maxResults == 0 { + return nil, nil + } + if startingIndex > s.firstNonce { ret = ret[startingIndex-s.firstNonce:] } if uint64(len(ret)) > maxResults { @@ -31,14 +33,14 @@ func (s *SliceStorage[Item]) GetContents(ctx context.Context, startingIndex uint return ret, nil } -func (s *SliceStorage[Item]) GetLast(ctx context.Context) (*Item, error) { +func (s *Storage[Item]) GetLast(context.Context) (*Item, error) { if len(s.queue) == 0 { return nil, nil } return s.queue[len(s.queue)-1], nil } -func (s *SliceStorage[Item]) Prune(ctx context.Context, keepStartingAt uint64) error { +func (s *Storage[Item]) Prune(_ context.Context, keepStartingAt uint64) error { if keepStartingAt >= s.firstNonce+uint64(len(s.queue)) { s.queue = nil } else if keepStartingAt >= s.firstNonce { @@ -48,7 +50,7 @@ func (s *SliceStorage[Item]) Prune(ctx context.Context, keepStartingAt uint64) e return nil } -func (s *SliceStorage[Item]) Put(ctx context.Context, index uint64, prevItem *Item, newItem *Item) error { +func (s *Storage[Item]) Put(_ context.Context, index uint64, prevItem *Item, newItem *Item) error { if newItem == nil { return fmt.Errorf("tried to insert nil item at index %v", index) 
} @@ -68,8 +70,8 @@ func (s *SliceStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It if queueIdx > len(s.queue) { return fmt.Errorf("attempted to set out-of-bounds index %v in queue starting at %v of length %v", index, s.firstNonce, len(s.queue)) } - if prevItem != s.queue[queueIdx] { - return errors.New("prevItem isn't nil but item is just after end of queue") + if !reflect.DeepEqual(prevItem, s.queue[queueIdx]) { + return fmt.Errorf("replacing different item than expected at index: %v: %v %v", index, prevItem, s.queue[queueIdx]) } s.queue[queueIdx] = newItem } else { @@ -78,10 +80,10 @@ func (s *SliceStorage[Item]) Put(ctx context.Context, index uint64, prevItem *It return nil } -func (s *SliceStorage[Item]) Length(ctx context.Context) (int, error) { +func (s *Storage[Item]) Length(context.Context) (int, error) { return len(s.queue), nil } -func (s *SliceStorage[Item]) IsPersistent() bool { +func (s *Storage[Item]) IsPersistent() bool { return false } diff --git a/arbnode/dataposter/storage/storage.go b/arbnode/dataposter/storage/storage.go new file mode 100644 index 0000000000..555f7e1e5d --- /dev/null +++ b/arbnode/dataposter/storage/storage.go @@ -0,0 +1,13 @@ +package storage + +import ( + "errors" +) + +var ( + ErrStorageRace = errors.New("storage race error") + + DataPosterPrefix string = "d" // the prefix for all data poster keys + // TODO(anodar): move everything else from schema.go file to here once + // execution split is complete. +) diff --git a/arbnode/dataposter/storage_test.go b/arbnode/dataposter/storage_test.go new file mode 100644 index 0000000000..0ef83ed5ba --- /dev/null +++ b/arbnode/dataposter/storage_test.go @@ -0,0 +1,283 @@ +package dataposter + +import ( + "context" + "path" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/google/go-cmp/cmp" + "github.com/offchainlabs/nitro/arbnode/dataposter/leveldb" + "github.com/offchainlabs/nitro/arbnode/dataposter/redis" + "github.com/offchainlabs/nitro/arbnode/dataposter/slice" + "github.com/offchainlabs/nitro/util/arbmath" + "github.com/offchainlabs/nitro/util/redisutil" + "github.com/offchainlabs/nitro/util/signature" +) + +func newLevelDBStorage[Item any](t *testing.T) *leveldb.Storage[Item] { + t.Helper() + db, err := rawdb.NewLevelDBDatabase(path.Join(t.TempDir(), "level.db"), 0, 0, "default", false) + if err != nil { + t.Fatalf("NewLevelDBDatabase() unexpected error: %v", err) + } + return leveldb.New[Item](db) +} + +func newSliceStorage[Item any]() *slice.Storage[Item] { + return slice.NewStorage[Item]() +} + +func newRedisStorage[Item any](ctx context.Context, t *testing.T) *redis.Storage[Item] { + t.Helper() + redisUrl := redisutil.CreateTestRedis(ctx, t) + client, err := redisutil.RedisClientFromURL(redisUrl) + if err != nil { + t.Fatalf("RedisClientFromURL(%q) unexpected error: %v", redisUrl, err) + } + s, err := redis.NewStorage[Item](client, "", &signature.TestSimpleHmacConfig) + if err != nil { + t.Fatalf("redis.NewStorage() unexpected error: %v", err) + } + return s +} + +// Initializes the QueueStorage. Returns the same object (for convenience). +func initStorage(ctx context.Context, t *testing.T, s QueueStorage[string]) QueueStorage[string] { + t.Helper() + for i := 0; i < 20; i++ { + val := strconv.Itoa(i) + if err := s.Put(ctx, uint64(i), nil, &val); err != nil { + t.Fatalf("Error putting a key/value: %v", err) + } + } + return s +} + +// Returns a map of all empty storages. 
+func storages(t *testing.T) map[string]QueueStorage[string] { + t.Helper() + return map[string]QueueStorage[string]{ + "levelDB": newLevelDBStorage[string](t), + "slice": newSliceStorage[string](), + "redis": newRedisStorage[string](context.Background(), t), + } +} + +// Returns a map of all initialized storages. +func initStorages(ctx context.Context, t *testing.T) map[string]QueueStorage[string] { + t.Helper() + m := map[string]QueueStorage[string]{} + for k, v := range storages(t) { + m[k] = initStorage(ctx, t, v) + } + return m +} + +func strPtrs(values []string) []*string { + var res []*string + for _, val := range values { + v := val + res = append(res, &v) + } + return res +} + +func TestGetContents(t *testing.T) { + ctx := context.Background() + for name, s := range initStorages(ctx, t) { + for _, tc := range []struct { + desc string + startIdx uint64 + maxResults uint64 + want []*string + }{ + { + desc: "sequence with single digits", + startIdx: 5, + maxResults: 3, + want: strPtrs([]string{"5", "6", "7"}), + }, + { + desc: "corner case of single element", + startIdx: 0, + maxResults: 1, + want: strPtrs([]string{"0"}), + }, + { + desc: "no elements", + startIdx: 3, + maxResults: 0, + want: strPtrs([]string{}), + }, + { + // Making sure it's correctly ordered lexicographically. + desc: "sequence with variable number of digits", + startIdx: 9, + maxResults: 3, + want: strPtrs([]string{"9", "10", "11"}), + }, + { + desc: "max results goes over the last element", + startIdx: 13, + maxResults: 10, + want: strPtrs([]string{"13", "14", "15", "16", "17", "18", "19"}), + }, + } { + t.Run(name+"_"+tc.desc, func(t *testing.T) { + values, err := s.GetContents(ctx, tc.startIdx, tc.maxResults) + if err != nil { + t.Fatalf("GetContents(%d, %d) unexpected error: %v", tc.startIdx, tc.maxResults, err) + } + if diff := cmp.Diff(tc.want, values); diff != "" { + t.Errorf("GetContext(%d, %d) unexpected diff:\n%s", tc.startIdx, tc.maxResults, diff) + } + }) + } + } +} + +func TestGetLast(t *testing.T) { + cnt := 100 + for name, s := range storages(t) { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + for i := 0; i < cnt; i++ { + val := strconv.Itoa(i) + if err := s.Put(ctx, uint64(i), nil, &val); err != nil { + t.Fatalf("Error putting a key/value: %v", err) + } + got, err := s.GetLast(ctx) + if err != nil { + t.Fatalf("Error getting a last element: %v", err) + } + if *got != val { + t.Errorf("GetLast() = %q want %q", *got, val) + } + + } + }) + last := strconv.Itoa(cnt - 1) + t.Run(name+"_update_entries", func(t *testing.T) { + ctx := context.Background() + for i := 0; i < cnt-1; i++ { + prev := strconv.Itoa(i) + newVal := strconv.Itoa(cnt + i) + if err := s.Put(ctx, uint64(i), &prev, &newVal); err != nil { + t.Fatalf("Error putting a key/value: %v, prev: %v, new: %v", err, prev, newVal) + } + got, err := s.GetLast(ctx) + if err != nil { + t.Fatalf("Error getting a last element: %v", err) + } + if *got != last { + t.Errorf("GetLast() = %q want %q", *got, last) + } + gotCnt, err := s.Length(ctx) + if err != nil { + t.Fatalf("Length() unexpected error: %v", err) + } + if gotCnt != cnt { + t.Errorf("Length() = %d want %d", gotCnt, cnt) + } + } + }) + } +} + +func TestPrune(t *testing.T) { + ctx := context.Background() + for _, tc := range []struct { + desc string + pruneFrom uint64 + want []*string + }{ + { + desc: "prune all elements", + pruneFrom: 20, + }, + { + desc: "prune all but one", + pruneFrom: 19, + want: strPtrs([]string{"19"}), + }, + { + desc: "pruning first element", + 
pruneFrom: 1, + want: strPtrs([]string{"1", "2", "3", "4", "5", "6", "7", + "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19"}), + }, + { + desc: "pruning first 11 elements", + pruneFrom: 11, + want: strPtrs([]string{"11", "12", "13", "14", "15", "16", "17", "18", "19"}), + }, + { + desc: "pruning from higher than biggest index", + pruneFrom: 30, + want: strPtrs([]string{}), + }, + } { + // Storages must be re-initialized in each test-case. + for name, s := range initStorages(ctx, t) { + t.Run(name+"_"+tc.desc, func(t *testing.T) { + if err := s.Prune(ctx, tc.pruneFrom); err != nil { + t.Fatalf("Prune(%d) unexpected error: %v", tc.pruneFrom, err) + } + got, err := s.GetContents(ctx, 0, 20) + if err != nil { + t.Fatalf("GetContents() unexpected error: %v", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Prune(%d) unexpected diff:\n%s", tc.pruneFrom, diff) + } + }) + } + } +} + +func TestLength(t *testing.T) { + ctx := context.Background() + for _, tc := range []struct { + desc string + pruneFrom uint64 + }{ + { + desc: "not prune any elements", + }, + { + desc: "prune all but one", + pruneFrom: 19, + }, + { + desc: "pruning first element", + pruneFrom: 1, + }, + { + desc: "pruning first 11 elements", + pruneFrom: 11, + }, + { + desc: "pruning from higher than biggest index", + pruneFrom: 30, + }, + } { + // Storages must be re-initialized in each test-case. + for name, s := range initStorages(ctx, t) { + t.Run(name+"_"+tc.desc, func(t *testing.T) { + if err := s.Prune(ctx, tc.pruneFrom); err != nil { + t.Fatalf("Prune(%d) unexpected error: %v", tc.pruneFrom, err) + } + got, err := s.Length(ctx) + if err != nil { + t.Fatalf("Length() unexpected error: %v", err) + } + if want := arbmath.MaxInt(0, 20-int(tc.pruneFrom)); got != want { + t.Errorf("Length() = %d want %d", got, want) + } + }) + } + + } +} diff --git a/arbnode/delayed_sequencer.go b/arbnode/delayed_sequencer.go index efc1dec7d7..f45a85ac49 100644 --- a/arbnode/delayed_sequencer.go +++ b/arbnode/delayed_sequencer.go @@ -43,7 +43,7 @@ type DelayedSequencerConfig struct { type DelayedSequencerConfigFetcher func() *DelayedSequencerConfig func DelayedSequencerConfigAddOptions(prefix string, f *flag.FlagSet) { - f.Bool(prefix+".enable", DefaultSeqCoordinatorConfig.Enable, "enable sequence coordinator") + f.Bool(prefix+".enable", DefaultDelayedSequencerConfig.Enable, "enable delayed sequencer") f.Int64(prefix+".finalize-distance", DefaultDelayedSequencerConfig.FinalizeDistance, "how many blocks in the past L1 block is considered final (ignored when using Merge finality)") f.Bool(prefix+".require-full-finality", DefaultDelayedSequencerConfig.RequireFullFinality, "whether to wait for full finality before sequencing delayed messages") f.Bool(prefix+".use-merge-finality", DefaultDelayedSequencerConfig.UseMergeFinality, "whether to use The Merge's notion of finality before sequencing delayed messages") diff --git a/arbnode/execution/block_recorder.go b/arbnode/execution/block_recorder.go new file mode 100644 index 0000000000..dc5daa6f7b --- /dev/null +++ b/arbnode/execution/block_recorder.go @@ -0,0 +1,364 @@ +package execution + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/arbitrum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/offchainlabs/nitro/arbos" + "github.com/offchainlabs/nitro/arbos/arbosState" + 
"github.com/offchainlabs/nitro/arbos/arbostypes" + "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/validator" +) + +// BlockRecorder uses a separate statedatabase from the blockchain. +// It has access to any state in the ethdb (hard-disk) database, and can compute state as needed. +// We keep references for state of: +// Any block that matches PrepareForRecord that was done recently (according to PrepareDelay config) +// Most recent/advanced header we ever computed (lastHdr) +// Hopefully - some recent valid block. For that we always keep one candidate block until it becomes validated. +type BlockRecorder struct { + recordingDatabase *arbitrum.RecordingDatabase + execEngine *ExecutionEngine + + lastHdr *types.Header + lastHdrLock sync.Mutex + + validHdrCandidate *types.Header + validHdr *types.Header + validHdrLock sync.Mutex + + preparedQueue []*types.Header + preparedLock sync.Mutex +} + +type RecordResult struct { + Pos arbutil.MessageIndex + BlockHash common.Hash + Preimages map[common.Hash][]byte + BatchInfo []validator.BatchInfo +} + +func NewBlockRecorder(config *arbitrum.RecordingDatabaseConfig, execEngine *ExecutionEngine, ethDb ethdb.Database) *BlockRecorder { + recorder := &BlockRecorder{ + execEngine: execEngine, + recordingDatabase: arbitrum.NewRecordingDatabase(config, ethDb, execEngine.bc), + } + execEngine.SetRecorder(recorder) + return recorder +} + +func stateLogFunc(targetHeader, header *types.Header, hasState bool) { + if targetHeader == nil || header == nil { + return + } + gap := targetHeader.Number.Int64() - header.Number.Int64() + step := int64(500) + stage := "computing state" + if !hasState { + step = 3000 + stage = "looking for full block" + } + if (gap >= step) && (gap%step == 0) { + log.Info("Setting up validation", "stage", stage, "current", header.Number, "target", targetHeader.Number) + } +} + +// If msg is nil, this will record block creation up to the point where message would be accessed (for a "too far" proof) +// If keepreference == true, reference to state of prevHeader is added (no reference added if an error is returned) +func (r *BlockRecorder) RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg *arbostypes.MessageWithMetadata, +) (*RecordResult, error) { + + blockNum := r.execEngine.MessageIndexToBlockNumber(pos) + + var prevHeader *types.Header + if pos != 0 { + prevHeader = r.execEngine.bc.GetHeaderByNumber(uint64(blockNum - 1)) + if prevHeader == nil { + return nil, fmt.Errorf("pos %d prevHeader not found", pos) + } + } + + recordingdb, chaincontext, recordingKV, err := r.recordingDatabase.PrepareRecording(ctx, prevHeader, stateLogFunc) + if err != nil { + return nil, err + } + defer func() { r.recordingDatabase.Dereference(prevHeader) }() + + chainConfig := r.execEngine.bc.Config() + + // Get the chain ID, both to validate and because the replay binary also gets the chain ID, + // so we need to populate the recordingdb with preimages for retrieving the chain ID. 
+ if prevHeader != nil { + initialArbosState, err := arbosState.OpenSystemArbosState(recordingdb, nil, true) + if err != nil { + return nil, fmt.Errorf("error opening initial ArbOS state: %w", err) + } + chainId, err := initialArbosState.ChainId() + if err != nil { + return nil, fmt.Errorf("error getting chain ID from initial ArbOS state: %w", err) + } + if chainId.Cmp(chainConfig.ChainID) != 0 { + return nil, fmt.Errorf("unexpected chain ID %v in ArbOS state, expected %v", chainId, chainConfig.ChainID) + } + genesisNum, err := initialArbosState.GenesisBlockNum() + if err != nil { + return nil, fmt.Errorf("error getting genesis block number from initial ArbOS state: %w", err) + } + _, err = initialArbosState.ChainConfig() + if err != nil { + return nil, fmt.Errorf("error getting chain config from initial ArbOS state: %w", err) + } + expectedNum := chainConfig.ArbitrumChainParams.GenesisBlockNum + if genesisNum != expectedNum { + return nil, fmt.Errorf("unexpected genesis block number %v in ArbOS state, expected %v", genesisNum, expectedNum) + } + } + + var blockHash common.Hash + var readBatchInfo []validator.BatchInfo + if msg != nil { + batchFetcher := func(batchNum uint64) ([]byte, error) { + data, err := r.execEngine.streamer.FetchBatch(batchNum) + if err != nil { + return nil, err + } + readBatchInfo = append(readBatchInfo, validator.BatchInfo{ + Number: batchNum, + Data: data, + }) + return data, nil + } + // Re-fetch the batch instead of using our cached cost, + // as the replay binary won't have the cache populated. + msg.Message.BatchGasCost = nil + block, _, err := arbos.ProduceBlock( + msg.Message, + msg.DelayedMessagesRead, + prevHeader, + recordingdb, + chaincontext, + chainConfig, + batchFetcher, + ) + if err != nil { + return nil, err + } + blockHash = block.Hash() + } + + preimages, err := r.recordingDatabase.PreimagesFromRecording(chaincontext, recordingKV) + if err != nil { + return nil, err + } + + // check that we got the canonical hash + canonicalHash := r.execEngine.bc.GetCanonicalHash(uint64(blockNum)) + if canonicalHash != blockHash { + return nil, fmt.Errorf("block hash %v doesn't match canonical hash %v when recording", blockHash, canonicalHash) + } + + // these won't usually do much here (they will in PrepareForRecord), but it doesn't hurt to check + r.updateLastHdr(prevHeader) + r.updateValidCandidateHdr(prevHeader) + + return &RecordResult{pos, blockHash, preimages, readBatchInfo}, err +} + +func (r *BlockRecorder) updateLastHdr(hdr *types.Header) { + if hdr == nil { + return + } + r.lastHdrLock.Lock() + defer r.lastHdrLock.Unlock() + if r.lastHdr != nil { + if hdr.Number.Cmp(r.lastHdr.Number) <= 0 { + return + } + } + _, err := r.recordingDatabase.StateFor(hdr) + if err != nil { + log.Warn("failed to get state in updateLastHdr", "err", err) + return + } + r.recordingDatabase.Dereference(r.lastHdr) + r.lastHdr = hdr +} + +func (r *BlockRecorder) updateValidCandidateHdr(hdr *types.Header) { + if hdr == nil { + return + } + r.validHdrLock.Lock() + defer r.validHdrLock.Unlock() + // don't need a candidate that's newer than the current one (else it will never become valid) + if r.validHdrCandidate != nil && r.validHdrCandidate.Number.Cmp(hdr.Number) <= 0 { + return + } + // don't need a candidate that's older than known valid + if r.validHdr != nil && r.validHdr.Number.Cmp(hdr.Number) >= 0 { + return + } + _, err := r.recordingDatabase.StateFor(hdr) + if err != nil { + log.Warn("failed to get state in updateValidCandidateHdr", "err", err) + return + } + if r.validHdrCandidate != nil {
+ r.recordingDatabase.Dereference(r.validHdrCandidate) + } + r.validHdrCandidate = hdr +} + +func (r *BlockRecorder) MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) { + r.validHdrLock.Lock() + defer r.validHdrLock.Unlock() + if r.validHdrCandidate == nil { + return + } + validNum := r.execEngine.MessageIndexToBlockNumber(pos) + if r.validHdrCandidate.Number.Uint64() > validNum { + return + } + // make sure the valid block is canonical + canonicalResultHash := r.execEngine.bc.GetCanonicalHash(uint64(validNum)) + if canonicalResultHash != resultHash { + log.Warn("markvalid hash not canonical", "pos", pos, "result", resultHash, "canonical", canonicalResultHash) + return + } + // make sure the candidate is still canonical + canonicalHash := r.execEngine.bc.GetCanonicalHash(r.validHdrCandidate.Number.Uint64()) + candidateHash := r.validHdrCandidate.Hash() + if canonicalHash != candidateHash { + log.Error("valid candidate hash not canonical", "number", r.validHdrCandidate.Number, "candidate", candidateHash, "canonical", canonicalHash) + r.recordingDatabase.Dereference(r.validHdrCandidate) + r.validHdrCandidate = nil + return + } + r.recordingDatabase.Dereference(r.validHdr) + r.validHdr = r.validHdrCandidate + r.validHdrCandidate = nil +} + +// TODO: use config +func (r *BlockRecorder) preparedAddTrim(newRefs []*types.Header, size int) { + var oldRefs []*types.Header + r.preparedLock.Lock() + r.preparedQueue = append(r.preparedQueue, newRefs...) + if len(r.preparedQueue) > size { + oldRefs = r.preparedQueue[:len(r.preparedQueue)-size] + r.preparedQueue = r.preparedQueue[len(r.preparedQueue)-size:] + } + r.preparedLock.Unlock() + for _, ref := range oldRefs { + r.recordingDatabase.Dereference(ref) + } +} + +func (r *BlockRecorder) preparedTrimBeyond(hdr *types.Header) { + var oldRefs []*types.Header + var validRefs []*types.Header + r.preparedLock.Lock() + for _, queHdr := range r.preparedQueue { + if queHdr.Number.Cmp(hdr.Number) > 0 { + oldRefs = append(oldRefs, queHdr) + } else { + validRefs = append(validRefs, queHdr) + } + } + r.preparedQueue = validRefs + r.preparedLock.Unlock() + for _, ref := range oldRefs { + r.recordingDatabase.Dereference(ref) + } +} + +func (r *BlockRecorder) TrimAllPrepared(t *testing.T) { + r.preparedAddTrim(nil, 0) +} + +func (r *BlockRecorder) RecordingDBReferenceCount() int64 { + return r.recordingDatabase.ReferenceCount() +} + +func (r *BlockRecorder) PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error { + var references []*types.Header + if end < start { + return fmt.Errorf("illegal range start %d > end %d", start, end) + } + numOfBlocks := uint64(end + 1 - start) + hdrNum := r.execEngine.MessageIndexToBlockNumber(start) + if start > 0 { + hdrNum-- // need to get previous + } else { + numOfBlocks-- // genesis block doesn't need preparation, so recording one less block + } + lastHdrNum := hdrNum + numOfBlocks + for hdrNum <= lastHdrNum { + header := r.execEngine.bc.GetHeaderByNumber(uint64(hdrNum)) + if header == nil { + log.Warn("prepareblocks asked for non-found block", "hdrNum", hdrNum) + break + } + _, err := r.recordingDatabase.GetOrRecreateState(ctx, header, stateLogFunc) + if err != nil { + log.Warn("prepareblocks failed to get state for block", "hdrNum", hdrNum, "err", err) + break + } + references = append(references, header) + r.updateValidCandidateHdr(header) + r.updateLastHdr(header) + hdrNum++ + } + r.preparedAddTrim(references, 1000) + return nil +} + +func (r *BlockRecorder) ReorgTo(hdr *types.Header) {
r.validHdrLock.Lock() + if r.validHdr != nil && r.validHdr.Number.Cmp(hdr.Number) > 0 { + log.Warn("block recorder: reorging past previously-marked valid block", "reorg target num", hdr.Number, "hash", hdr.Hash(), "reorged past num", r.validHdr.Number, "hash", r.validHdr.Hash()) + r.recordingDatabase.Dereference(r.validHdr) + r.validHdr = nil + } + if r.validHdrCandidate != nil && r.validHdrCandidate.Number.Cmp(hdr.Number) > 0 { + r.recordingDatabase.Dereference(r.validHdrCandidate) + r.validHdrCandidate = nil + } + r.validHdrLock.Unlock() + r.lastHdrLock.Lock() + if r.lastHdr != nil && r.lastHdr.Number.Cmp(hdr.Number) > 0 { + r.recordingDatabase.Dereference(r.lastHdr) + r.lastHdr = nil + } + r.lastHdrLock.Unlock() + r.preparedTrimBeyond(hdr) +} + +func (r *BlockRecorder) WriteValidStateToDb() error { + r.validHdrLock.Lock() + defer r.validHdrLock.Unlock() + if r.validHdr == nil { + return nil + } + err := r.recordingDatabase.WriteStateToDatabase(r.validHdr) + r.recordingDatabase.Dereference(r.validHdr) + return err +} + +func (r *BlockRecorder) OrderlyShutdown() { + err := r.WriteValidStateToDb() + if err != nil { + log.Error("failed writing latest valid block state to DB", "err", err) + } +} diff --git a/arbnode/execution/blockchain.go b/arbnode/execution/blockchain.go index a4de72588a..88e7044e8d 100644 --- a/arbnode/execution/blockchain.go +++ b/arbnode/execution/blockchain.go @@ -40,7 +40,7 @@ type CachingConfig struct { func CachingConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".archive", DefaultCachingConfig.Archive, "retain past block state") f.Uint64(prefix+".block-count", DefaultCachingConfig.BlockCount, "minimum number of recent blocks to keep in memory") - f.Duration(prefix+".block-age", DefaultCachingConfig.BlockAge, "minimum age a block must be to be pruned") + f.Duration(prefix+".block-age", DefaultCachingConfig.BlockAge, "minimum age of recent blocks to keep in memory") f.Duration(prefix+".trie-time-limit", DefaultCachingConfig.TrieTimeLimit, "maximum block processing time before trie is written to hard-disk") f.Int(prefix+".trie-dirty-cache", DefaultCachingConfig.TrieDirtyCache, "amount of memory in megabytes to cache state diffs against disk with (larger cache lowers database growth)") f.Int(prefix+".trie-clean-cache", DefaultCachingConfig.TrieCleanCache, "amount of memory in megabytes to cache unchanged state trie nodes with") @@ -188,22 +188,6 @@ func shouldPreserveFalse(_ *types.Header) bool { return false } -func ReorgToBlock(chain *core.BlockChain, blockNum uint64) (*types.Block, error) { - genesisNum := chain.Config().ArbitrumChainParams.GenesisBlockNum - if blockNum < genesisNum { - return nil, fmt.Errorf("cannot reorg to block %v past nitro genesis of %v", blockNum, genesisNum) - } - reorgingToBlock := chain.GetBlockByNumber(blockNum) - if reorgingToBlock == nil { - return nil, fmt.Errorf("didn't find reorg target block number %v", blockNum) - } - err := chain.ReorgToOldBlock(reorgingToBlock) - if err != nil { - return nil, err - } - return reorgingToBlock, nil -} - func init() { gethhook.RequireHookedGeth() } diff --git a/arbnode/execution/executionengine.go b/arbnode/execution/executionengine.go index 88b42cbe4c..d8029650d7 100644 --- a/arbnode/execution/executionengine.go +++ b/arbnode/execution/executionengine.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -19,7 +20,6 @@ 
import ( "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbos/l1pricing" "github.com/offchainlabs/nitro/arbutil" - "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/sharedmetrics" "github.com/offchainlabs/nitro/util/stopwaiter" ) @@ -33,9 +33,9 @@ type TransactionStreamerInterface interface { type ExecutionEngine struct { stopwaiter.StopWaiter - bc *core.BlockChain - validator *staker.BlockValidator - streamer TransactionStreamerInterface + bc *core.BlockChain + streamer TransactionStreamerInterface + recorder *BlockRecorder resequenceChan chan []*arbostypes.MessageWithMetadata createBlocksMutex sync.Mutex @@ -57,14 +57,14 @@ func NewExecutionEngine(bc *core.BlockChain) (*ExecutionEngine, error) { }, nil } -func (s *ExecutionEngine) SetBlockValidator(validator *staker.BlockValidator) { +func (s *ExecutionEngine) SetRecorder(recorder *BlockRecorder) { if s.Started() { - panic("trying to set block validator after start") + panic("trying to set recorder after start") } - if s.validator != nil { - panic("trying to set block validator when already set") + if s.recorder != nil { + panic("trying to set recorder policy when already set") } - s.validator = validator + s.recorder = recorder } func (s *ExecutionEngine) EnableReorgSequencing() { @@ -79,15 +79,18 @@ func (s *ExecutionEngine) EnableReorgSequencing() { func (s *ExecutionEngine) SetTransactionStreamer(streamer TransactionStreamerInterface) { if s.Started() { - panic("trying to set reorg sequencing policy after start") + panic("trying to set transaction streamer after start") } if s.streamer != nil { - panic("trying to set reorg sequencing policy when already set") + panic("trying to set transaction streamer when already set") } s.streamer = streamer } func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata, oldMessages []*arbostypes.MessageWithMetadata) error { + if count == 0 { + return errors.New("cannot reorg out genesis") + } s.createBlocksMutex.Lock() resequencing := false defer func() { @@ -97,24 +100,15 @@ func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbost s.createBlocksMutex.Unlock() } }() - blockNum, err := s.MessageCountToBlockNumber(count) - if err != nil { - return err - } + blockNum := s.MessageIndexToBlockNumber(count - 1) // We can safely cast blockNum to a uint64 as it comes from MessageCountToBlockNumber targetBlock := s.bc.GetBlockByNumber(uint64(blockNum)) if targetBlock == nil { log.Warn("reorg target block not found", "block", blockNum) return nil } - if s.validator != nil { - err = s.validator.ReorgToBlock(targetBlock.NumberU64(), targetBlock.Hash()) - if err != nil { - return err - } - } - err = s.bc.ReorgToOldBlock(targetBlock) + err := s.bc.ReorgToOldBlock(targetBlock) if err != nil { return err } @@ -124,6 +118,9 @@ func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbost return err } } + if s.recorder != nil { + s.recorder.ReorgTo(targetBlock.Header()) + } if len(oldMessages) > 0 { s.resequenceChan <- oldMessages resequencing = true @@ -144,11 +141,7 @@ func (s *ExecutionEngine) HeadMessageNumber() (arbutil.MessageIndex, error) { if err != nil { return 0, err } - msgCount, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) - if err != nil { - return 0, err - } - return msgCount - 1, err + return s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) } func (s *ExecutionEngine) HeadMessageNumberSync(t *testing.T) (arbutil.MessageIndex, 
error) { @@ -347,7 +340,7 @@ func (s *ExecutionEngine) sequenceTransactionsWithBlockMutex(header *arbostypes. DelayedMessagesRead: delayedMessagesRead, } - pos, err := s.BlockNumberToMessageCount(lastBlockHeader.Number.Uint64()) + pos, err := s.BlockNumberToMessageIndex(lastBlockHeader.Number.Uint64() + 1) if err != nil { return nil, err } @@ -364,10 +357,6 @@ func (s *ExecutionEngine) sequenceTransactionsWithBlockMutex(header *arbostypes. return nil, err } - if s.validator != nil { - s.validator.NewBlock(block, lastBlockHeader, msgWithMeta) - } - return block, nil } @@ -386,7 +375,7 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp expectedDelayed := currentHeader.Nonce.Uint64() - pos, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) + lastMsg, err := s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) if err != nil { return nil, err } @@ -400,7 +389,7 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp DelayedMessagesRead: delayedSeqNum + 1, } - err = s.streamer.WriteMessageFromSequencer(pos, messageWithMeta) + err = s.streamer.WriteMessageFromSequencer(lastMsg+1, messageWithMeta) if err != nil { return nil, err } @@ -416,29 +405,25 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp return nil, err } - log.Info("ExecutionEngine: Added DelayedMessages", "pos", pos, "delayed", delayedSeqNum, "block-header", block.Header()) + log.Info("ExecutionEngine: Added DelayedMessages", "pos", lastMsg+1, "delayed", delayedSeqNum, "block-header", block.Header()) return block, nil } -func (s *ExecutionEngine) GetGenesisBlockNumber() (uint64, error) { - return s.bc.Config().ArbitrumChainParams.GenesisBlockNum, nil +func (s *ExecutionEngine) GetGenesisBlockNumber() uint64 { + return s.bc.Config().ArbitrumChainParams.GenesisBlockNum } -func (s *ExecutionEngine) BlockNumberToMessageCount(blockNum uint64) (arbutil.MessageIndex, error) { - genesis, err := s.GetGenesisBlockNumber() - if err != nil { - return 0, err +func (s *ExecutionEngine) BlockNumberToMessageIndex(blockNum uint64) (arbutil.MessageIndex, error) { + genesis := s.GetGenesisBlockNumber() + if blockNum < genesis { + return 0, fmt.Errorf("blockNum %d < genesis %d", blockNum, genesis) } - return arbutil.BlockNumberToMessageCount(blockNum, genesis), nil + return arbutil.MessageIndex(blockNum - genesis), nil } -func (s *ExecutionEngine) MessageCountToBlockNumber(messageNum arbutil.MessageIndex) (int64, error) { - genesis, err := s.GetGenesisBlockNumber() - if err != nil { - return 0, err - } - return arbutil.MessageCountToBlockNumber(messageNum, genesis), nil +func (s *ExecutionEngine) MessageIndexToBlockNumber(messageNum arbutil.MessageIndex) uint64 { + return uint64(messageNum) + s.GetGenesisBlockNumber() } // must hold createBlockMutex @@ -494,6 +479,23 @@ func (s *ExecutionEngine) appendBlock(block *types.Block, statedb *state.StateDB return nil } +type MessageResult struct { + BlockHash common.Hash + SendRoot common.Hash +} + +func (s *ExecutionEngine) resultFromHeader(header *types.Header) (*MessageResult, error) { + if header == nil { + return nil, fmt.Errorf("result not found") + } + info := types.DeserializeHeaderExtraInformation(header) + return &MessageResult{header.Hash(), info.SendRoot}, nil +} + +func (s *ExecutionEngine) ResultAtPos(pos arbutil.MessageIndex) (*MessageResult, error) { + return s.resultFromHeader(s.bc.GetHeaderByNumber(s.MessageIndexToBlockNumber(pos))) +} + func (s *ExecutionEngine) 
DigestMessage(num arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata) error { if !s.createBlocksMutex.TryLock() { return errors.New("createBlock mutex held") @@ -507,12 +509,12 @@ func (s *ExecutionEngine) digestMessageWithBlockMutex(num arbutil.MessageIndex, if err != nil { return err } - expNum, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) + curMsg, err := s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) if err != nil { return err } - if expNum != num { - return fmt.Errorf("wrong message number in digest got %d expected %d", num, expNum) + if curMsg+1 != num { + return fmt.Errorf("wrong message number in digest got %d expected %d", num, curMsg+1) } startTime := time.Now() @@ -526,10 +528,6 @@ func (s *ExecutionEngine) digestMessageWithBlockMutex(num arbutil.MessageIndex, return err } - if s.validator != nil { - s.validator.NewBlock(block, currentHeader, *msg) - } - if time.Now().After(s.nextScheduledVersionCheck) { s.nextScheduledVersionCheck = time.Now().Add(time.Minute) arbState, err := arbosState.OpenSystemArbosState(statedb, nil, true) diff --git a/arbnode/execution/node.go b/arbnode/execution/node.go index a0ff13bb29..7456c1d6a7 100644 --- a/arbnode/execution/node.go +++ b/arbnode/execution/node.go @@ -17,6 +17,7 @@ type ExecutionNode struct { FilterSystem *filters.FilterSystem ArbInterface *ArbInterface ExecEngine *ExecutionEngine + Recorder *BlockRecorder Sequencer *Sequencer // either nil or same as TxPublisher TxPublisher TransactionPublisher } @@ -30,6 +31,7 @@ func CreateExecutionNode( fwTarget string, fwConfig *ForwarderConfig, rpcConfig arbitrum.Config, + recordingDbConfig *arbitrum.RecordingDatabaseConfig, seqConfigFetcher SequencerConfigFetcher, precheckConfigFetcher TxPreCheckerConfigFetcher, ) (*ExecutionNode, error) { @@ -37,6 +39,7 @@ func CreateExecutionNode( if err != nil { return nil, err } + recorder := NewBlockRecorder(recordingDbConfig, execEngine, chainDB) var txPublisher TransactionPublisher var sequencer *Sequencer seqConfig := seqConfigFetcher() @@ -79,6 +82,7 @@ func CreateExecutionNode( filterSystem, arbInterface, execEngine, + recorder, sequencer, txPublisher, }, nil diff --git a/arbnode/inbox_test.go b/arbnode/inbox_test.go index e68cee49ff..21eef7499c 100644 --- a/arbnode/inbox_test.go +++ b/arbnode/inbox_test.go @@ -199,10 +199,17 @@ func TestTransactionStreamer(t *testing.T) { } // Check that state balances are consistent with blockchain's balances - lastBlockNumber := bc.CurrentHeader().Number.Uint64() expectedLastBlockNumber := blockStates[len(blockStates)-1].blockNumber - if lastBlockNumber != expectedLastBlockNumber { - Fail(t, "unexpected block number", lastBlockNumber, "vs", expectedLastBlockNumber) + for i := 0; ; i++ { + lastBlockNumber := bc.CurrentHeader().Number.Uint64() + if lastBlockNumber == expectedLastBlockNumber { + break + } else if lastBlockNumber > expectedLastBlockNumber { + Fail(t, "unexpected block number", lastBlockNumber, "vs", expectedLastBlockNumber) + } else if i == 10 { + Fail(t, "timeout waiting for block number", expectedLastBlockNumber, "current", lastBlockNumber) + } + time.Sleep(time.Millisecond * 100) } for _, state := range blockStates { diff --git a/arbnode/inbox_tracker.go b/arbnode/inbox_tracker.go index b6a1afd02b..c82e45fbee 100644 --- a/arbnode/inbox_tracker.go +++ b/arbnode/inbox_tracker.go @@ -127,7 +127,7 @@ func (t *InboxTracker) GetDelayedAcc(seqNum uint64) (common.Hash, error) { return common.Hash{}, err } if !hasKey { - return common.Hash{}, AccumulatorNotFoundErr + 
return common.Hash{}, fmt.Errorf("%w: not found delayed %d", AccumulatorNotFoundErr, seqNum) } } data, err := t.db.Get(key) @@ -175,7 +175,7 @@ func (t *InboxTracker) GetBatchMetadata(seqNum uint64) (BatchMetadata, error) { return BatchMetadata{}, err } if !hasKey { - return BatchMetadata{}, AccumulatorNotFoundErr + return BatchMetadata{}, fmt.Errorf("%w: no metadata for batch %d", AccumulatorNotFoundErr, seqNum) } data, err := t.db.Get(key) if err != nil { @@ -694,18 +694,6 @@ func (t *InboxTracker) AddSequencerBatches(ctx context.Context, client arbutil.L } t.batchMetaMutex.Unlock() - if t.validator != nil { - batchBytes := make([][]byte, 0, len(batches)) - for _, batch := range batches { - msg, err := batch.Serialize(ctx, client) - if err != nil { - return err - } - batchBytes = append(batchBytes, msg) - } - t.validator.ProcessBatches(startPos, batchBytes) - } - if t.txStreamer.broadcastServer != nil && pos > 1 { prevprevbatchmeta, err := t.GetBatchMetadata(pos - 2) if errors.Is(err, AccumulatorNotFoundErr) { diff --git a/arbnode/message_pruner.go b/arbnode/message_pruner.go index 1ba3886d8d..aeee07ca73 100644 --- a/arbnode/message_pruner.go +++ b/arbnode/message_pruner.go @@ -7,16 +7,16 @@ import ( "bytes" "context" "encoding/binary" - "math/big" + "fmt" + "sync" "time" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rpc" - "github.com/offchainlabs/nitro/staker" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/offchainlabs/nitro/validator" flag "github.com/spf13/pflag" ) @@ -25,13 +25,15 @@ type MessagePruner struct { stopwaiter.StopWaiter transactionStreamer *TransactionStreamer inboxTracker *InboxTracker - staker *staker.Staker config MessagePrunerConfigFetcher + pruningLock sync.Mutex + lastPruneDone time.Time } type MessagePrunerConfig struct { Enable bool `koanf:"enable"` MessagePruneInterval time.Duration `koanf:"prune-interval" reload:"hot"` + MinBatchesLeft uint64 `koanf:"min-batches-left" reload:"hot"` } type MessagePrunerConfigFetcher func() *MessagePrunerConfig @@ -39,85 +41,96 @@ type MessagePrunerConfigFetcher func() *MessagePrunerConfig var DefaultMessagePrunerConfig = MessagePrunerConfig{ Enable: true, MessagePruneInterval: time.Minute, + MinBatchesLeft: 2, } func MessagePrunerConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".enable", DefaultMessagePrunerConfig.Enable, "enable message pruning") f.Duration(prefix+".prune-interval", DefaultMessagePrunerConfig.MessagePruneInterval, "interval for running message pruner") + f.Uint64(prefix+".min-batches-left", DefaultMessagePrunerConfig.MinBatchesLeft, "min number of batches not pruned") } -func NewMessagePruner(transactionStreamer *TransactionStreamer, inboxTracker *InboxTracker, staker *staker.Staker, config MessagePrunerConfigFetcher) *MessagePruner { +func NewMessagePruner(transactionStreamer *TransactionStreamer, inboxTracker *InboxTracker, config MessagePrunerConfigFetcher) *MessagePruner { return &MessagePruner{ transactionStreamer: transactionStreamer, inboxTracker: inboxTracker, - staker: staker, config: config, } } func (m *MessagePruner) Start(ctxIn context.Context) { m.StopWaiter.Start(ctxIn, m) - m.CallIteratively(m.prune) } -func (m *MessagePruner) prune(ctx context.Context) time.Duration { - latestConfirmedNode, err := m.staker.Rollup().LatestConfirmed( - &bind.CallOpts{ - Context: ctx, - BlockNumber: 
big.NewInt(int64(rpc.FinalizedBlockNumber)), - }) - if err != nil { - log.Error("error getting latest confirmed node", "err", err) - return m.config().MessagePruneInterval - } - nodeInfo, err := m.staker.Rollup().LookupNode(ctx, latestConfirmedNode) - if err != nil { - log.Error("error getting latest confirmed node info", "node", latestConfirmedNode, "err", err) - return m.config().MessagePruneInterval +func (m *MessagePruner) UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) { + locked := m.pruningLock.TryLock() + if !locked { + return } - endBatchCount := nodeInfo.Assertion.AfterState.GlobalState.Batch - if endBatchCount == 0 { - return m.config().MessagePruneInterval + + if m.lastPruneDone.Add(m.config().MessagePruneInterval).After(time.Now()) { + m.pruningLock.Unlock() + return } - endBatchMetadata, err := m.inboxTracker.GetBatchMetadata(endBatchCount - 1) + err := m.LaunchThreadSafe(func(ctx context.Context) { + defer m.pruningLock.Unlock() + err := m.prune(ctx, count, globalState) + if err != nil && ctx.Err() == nil { + log.Error("error while pruning", "err", err) + } + }) if err != nil { - log.Error("error getting last batch metadata", "batch", endBatchCount-1, "err", err) - return m.config().MessagePruneInterval + log.Info("failed launching prune thread", "err", err) + m.pruningLock.Unlock() } - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, m.inboxTracker.db, m.transactionStreamer.db) - return m.config().MessagePruneInterval } -func deleteOldMessageFromDB(endBatchCount uint64, endBatchMetadata BatchMetadata, inboxTrackerDb ethdb.Database, transactionStreamerDb ethdb.Database) { - prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(inboxTrackerDb, sequencerBatchMetaPrefix, endBatchCount) +func (m *MessagePruner) prune(ctx context.Context, count arbutil.MessageIndex, globalState validator.GoGlobalState) error { + trimBatchCount := globalState.Batch + minBatchesLeft := m.config().MinBatchesLeft + batchCount, err := m.inboxTracker.GetBatchCount() if err != nil { - log.Error("error deleting batch metadata", "err", err) - return + return err } - if len(prunedKeysRange) > 0 { - log.Info("Pruned batches:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) + if batchCount < trimBatchCount+minBatchesLeft { + if batchCount < minBatchesLeft { + return nil + } + trimBatchCount = batchCount - minBatchesLeft } + if trimBatchCount < 1 { + return nil + } + endBatchMetadata, err := m.inboxTracker.GetBatchMetadata(trimBatchCount - 1) + if err != nil { + return err + } + msgCount := endBatchMetadata.MessageCount + delayedCount := endBatchMetadata.DelayedMessageCount - prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(transactionStreamerDb, messagePrefix, uint64(endBatchMetadata.MessageCount)) + return deleteOldMessageFromDB(ctx, msgCount, delayedCount, m.inboxTracker.db, m.transactionStreamer.db) +} + +func deleteOldMessageFromDB(ctx context.Context, messageCount arbutil.MessageIndex, delayedMessageCount uint64, inboxTrackerDb ethdb.Database, transactionStreamerDb ethdb.Database) error { + prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, transactionStreamerDb, messagePrefix, uint64(messageCount)) if err != nil { - log.Error("error deleting last batch messages", "err", err) - return + return fmt.Errorf("error deleting last batch messages: %w", err) } if len(prunedKeysRange) > 0 { log.Info("Pruned last batch messages:", "first pruned key", prunedKeysRange[0], "last pruned key", 
prunedKeysRange[len(prunedKeysRange)-1]) } - prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(inboxTrackerDb, rlpDelayedMessagePrefix, endBatchMetadata.DelayedMessageCount) + prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(ctx, inboxTrackerDb, rlpDelayedMessagePrefix, delayedMessageCount) if err != nil { - log.Error("error deleting last batch delayed messages", "err", err) - return + return fmt.Errorf("error deleting last batch delayed messages: %w", err) } if len(prunedKeysRange) > 0 { log.Info("Pruned last batch delayed messages:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) } + return nil } -func deleteFromLastPrunedUptoEndKey(db ethdb.Database, prefix []byte, endMinKey uint64) ([][]byte, error) { +func deleteFromLastPrunedUptoEndKey(ctx context.Context, db ethdb.Database, prefix []byte, endMinKey uint64) ([]uint64, error) { startIter := db.NewIterator(prefix, uint64ToKey(1)) if !startIter.Next() { return nil, nil @@ -125,7 +138,7 @@ func deleteFromLastPrunedUptoEndKey(db ethdb.Database, prefix []byte, endMinKey startMinKey := binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) startIter.Release() if endMinKey > startMinKey { - return deleteFromRange(db, prefix, startMinKey, endMinKey-1) + return deleteFromRange(ctx, db, prefix, startMinKey, endMinKey-1) } return nil, nil } diff --git a/arbnode/message_pruner_test.go b/arbnode/message_pruner_test.go index 16c1d6b71c..c0cb2cb4fe 100644 --- a/arbnode/message_pruner_test.go +++ b/arbnode/message_pruner_test.go @@ -4,87 +4,88 @@ package arbnode import ( + "context" "testing" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" + "github.com/offchainlabs/nitro/arbutil" ) func TestMessagePrunerWithPruningEligibleMessagePresent(t *testing.T) { - endBatchCount := uint64(2 * 100 * 1024) - endBatchMetadata := BatchMetadata{ - MessageCount: 2 * 100 * 1024, - DelayedMessageCount: 2 * 100 * 1024, - } - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(2 * 100 * 1024) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*100*1024, 2*100*1024) + err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - checkDbKeys(t, uint64(endBatchMetadata.MessageCount), transactionStreamerDb, messagePrefix) - checkDbKeys(t, endBatchMetadata.DelayedMessageCount, inboxTrackerDb, rlpDelayedMessagePrefix) + checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) + checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix) } func TestMessagePrunerTraverseEachMessageOnlyOnce(t *testing.T) { - endBatchCount := uint64(10) - endBatchMetadata := BatchMetadata{} - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - // In first iteration message till endBatchCount are tried to be deleted. - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - // In first iteration all the message till endBatchCount are deleted. 
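The pruning helpers in this hunk store each record as a table prefix followed by a big-endian uint64 and delete from the first key still present up to the cutoff, which is why a second pass never re-traverses already-deleted entries (the property TestMessagePrunerTraverseEachMessageOnlyOnce checks). A minimal, self-contained sketch of that key layout and range deletion, using a made-up prefix rather than the real table prefixes:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"

	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/ethdb"
)

// examplePrefix is a made-up table prefix for illustration; the real code uses
// messagePrefix, rlpDelayedMessagePrefix, etc.
var examplePrefix = []byte("x")

// key builds prefix || bigEndian(n), the same key layout the pruner iterates over.
func key(prefix []byte, n uint64) []byte {
	var b [8]byte
	binary.BigEndian.PutUint64(b[:], n)
	return append(append([]byte{}, prefix...), b[:]...)
}

// pruneBelow deletes every key under prefix strictly below endMinKey, starting
// from the first key that still exists, so repeated calls never re-visit
// already-deleted entries.
func pruneBelow(db ethdb.Database, prefix []byte, endMinKey uint64) error {
	it := db.NewIterator(prefix, nil)
	defer it.Release()
	batch := db.NewBatch()
	for it.Next() {
		k := binary.BigEndian.Uint64(bytes.TrimPrefix(it.Key(), prefix))
		if k >= endMinKey {
			break
		}
		if err := batch.Delete(it.Key()); err != nil {
			return err
		}
	}
	if err := it.Error(); err != nil {
		return err
	}
	return batch.Write()
}

func main() {
	db := rawdb.NewMemoryDatabase()
	for i := uint64(0); i < 10; i++ {
		if err := db.Put(key(examplePrefix, i), []byte{}); err != nil {
			panic(err)
		}
	}
	if err := pruneBelow(db, examplePrefix, 7); err != nil {
		panic(err)
	}
	for i := uint64(0); i < 10; i++ {
		has, _ := db.Has(key(examplePrefix, i))
		fmt.Println(i, has) // 0..6 pruned, 7..9 kept
	}
}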
- checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - // After first iteration endBatchCount/2 is reinserted in inbox db - err := inboxTrackerDb.Put(dbKey(sequencerBatchMetaPrefix, endBatchCount/2), []byte{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) + // In first iteration message till messagesCount are tried to be deleted. + err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) Require(t, err) - // In second iteration message till endBatchCount are again tried to be deleted. - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - // In second iteration all the message till endBatchCount are deleted again. - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) + // After first iteration messagesCount/2 is reinserted in inbox db + err = inboxTrackerDb.Put(dbKey(messagePrefix, messagesCount/2), []byte{}) + Require(t, err) + // In second iteration message till messagesCount are again tried to be deleted. + err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) + // In second iteration all the message till messagesCount are deleted again. + checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) } func TestMessagePrunerPruneTillLessThenEqualTo(t *testing.T) { - endBatchCount := uint64(10) - endBatchMetadata := BatchMetadata{} - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*endBatchCount, endBatchMetadata) - err := inboxTrackerDb.Delete(dbKey(sequencerBatchMetaPrefix, 9)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*messagesCount, 20) + err := inboxTrackerDb.Delete(dbKey(messagePrefix, 9)) + Require(t, err) + err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) Require(t, err) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - hasKey, err := inboxTrackerDb.Has(dbKey(sequencerBatchMetaPrefix, 10)) + hasKey, err := transactionStreamerDb.Has(dbKey(messagePrefix, messagesCount)) Require(t, err) if !hasKey { - Fail(t, "Key", 10, "with prefix", string(sequencerBatchMetaPrefix), "should be present after pruning") + Fail(t, "Key", 10, "with prefix", string(messagePrefix), "should be present after pruning") } } func TestMessagePrunerWithNoPruningEligibleMessagePresent(t *testing.T) { - endBatchCount := uint64(2) - endBatchMetadata := BatchMetadata{ - MessageCount: 2, - DelayedMessageCount: 2, - } - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) + err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - checkDbKeys(t, uint64(endBatchMetadata.MessageCount), 
transactionStreamerDb, messagePrefix) - checkDbKeys(t, endBatchMetadata.DelayedMessageCount, inboxTrackerDb, rlpDelayedMessagePrefix) + checkDbKeys(t, uint64(messagesCount), transactionStreamerDb, messagePrefix) + checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix) } -func setupDatabase(t *testing.T, endBatchCount uint64, endBatchMetadata BatchMetadata) (ethdb.Database, ethdb.Database) { - inboxTrackerDb := rawdb.NewMemoryDatabase() - for i := uint64(0); i < endBatchCount; i++ { - err := inboxTrackerDb.Put(dbKey(sequencerBatchMetaPrefix, i), []byte{}) - Require(t, err) - } +func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethdb.Database, ethdb.Database) { transactionStreamerDb := rawdb.NewMemoryDatabase() - for i := uint64(0); i < uint64(endBatchMetadata.MessageCount); i++ { + for i := uint64(0); i < uint64(messageCount); i++ { err := transactionStreamerDb.Put(dbKey(messagePrefix, i), []byte{}) Require(t, err) } - for i := uint64(0); i < endBatchMetadata.DelayedMessageCount; i++ { + inboxTrackerDb := rawdb.NewMemoryDatabase() + for i := uint64(0); i < delayedMessageCount; i++ { err := inboxTrackerDb.Put(dbKey(rlpDelayedMessagePrefix, i), []byte{}) Require(t, err) } @@ -93,6 +94,7 @@ func setupDatabase(t *testing.T, endBatchCount uint64, endBatchMetadata BatchMet } func checkDbKeys(t *testing.T, endCount uint64, db ethdb.Database, prefix []byte) { + t.Helper() for i := uint64(0); i < endCount; i++ { hasKey, err := db.Has(dbKey(prefix, i)) Require(t, err) diff --git a/arbnode/node.go b/arbnode/node.go index 02b9857877..8a4f38f28c 100644 --- a/arbnode/node.go +++ b/arbnode/node.go @@ -26,7 +26,6 @@ import ( "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" - "github.com/offchainlabs/nitro/arbnode/execution" "github.com/offchainlabs/nitro/arbnode/resourcemanager" "github.com/offchainlabs/nitro/arbutil" @@ -308,29 +307,30 @@ func DeployOnL1(ctx context.Context, l1client arbutil.L1Interface, deployAuth *b } type Config struct { - RPC arbitrum.Config `koanf:"rpc"` - Sequencer execution.SequencerConfig `koanf:"sequencer" reload:"hot"` - L1Reader headerreader.Config `koanf:"parent-chain-reader" reload:"hot"` - InboxReader InboxReaderConfig `koanf:"inbox-reader" reload:"hot"` - DelayedSequencer DelayedSequencerConfig `koanf:"delayed-sequencer" reload:"hot"` - BatchPoster BatchPosterConfig `koanf:"batch-poster" reload:"hot"` - MessagePruner MessagePrunerConfig `koanf:"message-pruner" reload:"hot"` - ForwardingTargetImpl string `koanf:"forwarding-target"` - Forwarder execution.ForwarderConfig `koanf:"forwarder"` - TxPreChecker execution.TxPreCheckerConfig `koanf:"tx-pre-checker" reload:"hot"` - BlockValidator staker.BlockValidatorConfig `koanf:"block-validator" reload:"hot"` - Feed broadcastclient.FeedConfig `koanf:"feed" reload:"hot"` - Staker staker.L1ValidatorConfig `koanf:"staker"` - SeqCoordinator SeqCoordinatorConfig `koanf:"seq-coordinator"` - DataAvailability das.DataAvailabilityConfig `koanf:"data-availability"` - SyncMonitor SyncMonitorConfig `koanf:"sync-monitor"` - Dangerous DangerousConfig `koanf:"dangerous"` - Caching execution.CachingConfig `koanf:"caching"` - Archive bool `koanf:"archive"` - TxLookupLimit uint64 `koanf:"tx-lookup-limit"` - TransactionStreamer TransactionStreamerConfig `koanf:"transaction-streamer" reload:"hot"` - Maintenance MaintenanceConfig `koanf:"maintenance" reload:"hot"` - ResourceManagement resourcemanager.Config `koanf:"resource-mgmt" reload:"hot"` 
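Every config block in this tree pairs a koanf-tagged struct (with reload:"hot" on fields that may change at runtime) with an AddOptions function that registers dotted flags on a pflag.FlagSet, the same shape as the min-batches-left and recording-database options added in this change. A generic sketch of the convention with made-up field names:

package main

import (
	"fmt"
	"time"

	"github.com/spf13/pflag"
)

// ExampleConfig is a hypothetical config block following the repo's convention:
// koanf tags name the keys, reload:"hot" marks fields that can be hot-reloaded.
type ExampleConfig struct {
	Enable        bool          `koanf:"enable"`
	PruneInterval time.Duration `koanf:"prune-interval" reload:"hot"`
	MinLeft       uint64        `koanf:"min-left" reload:"hot"`
}

var DefaultExampleConfig = ExampleConfig{
	Enable:        true,
	PruneInterval: time.Minute,
	MinLeft:       2,
}

// ExampleConfigAddOptions registers the flags under a dotted prefix, mirroring
// MessagePrunerConfigAddOptions above.
func ExampleConfigAddOptions(prefix string, f *pflag.FlagSet) {
	f.Bool(prefix+".enable", DefaultExampleConfig.Enable, "enable the example component")
	f.Duration(prefix+".prune-interval", DefaultExampleConfig.PruneInterval, "interval for running the example component")
	f.Uint64(prefix+".min-left", DefaultExampleConfig.MinLeft, "minimum number of items not pruned")
}

func main() {
	f := pflag.NewFlagSet("example", pflag.ContinueOnError)
	ExampleConfigAddOptions("node.example", f)
	_ = f.Parse([]string{"--node.example.min-left=4"})
	v, _ := f.GetUint64("node.example.min-left")
	fmt.Println(v) // 4
}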
+ RPC arbitrum.Config `koanf:"rpc"` + Sequencer execution.SequencerConfig `koanf:"sequencer" reload:"hot"` + L1Reader headerreader.Config `koanf:"parent-chain-reader" reload:"hot"` + InboxReader InboxReaderConfig `koanf:"inbox-reader" reload:"hot"` + DelayedSequencer DelayedSequencerConfig `koanf:"delayed-sequencer" reload:"hot"` + BatchPoster BatchPosterConfig `koanf:"batch-poster" reload:"hot"` + MessagePruner MessagePrunerConfig `koanf:"message-pruner" reload:"hot"` + ForwardingTargetImpl string `koanf:"forwarding-target"` + Forwarder execution.ForwarderConfig `koanf:"forwarder"` + TxPreChecker execution.TxPreCheckerConfig `koanf:"tx-pre-checker" reload:"hot"` + BlockValidator staker.BlockValidatorConfig `koanf:"block-validator" reload:"hot"` + RecordingDB arbitrum.RecordingDatabaseConfig `koanf:"recording-database"` + Feed broadcastclient.FeedConfig `koanf:"feed" reload:"hot"` + Staker staker.L1ValidatorConfig `koanf:"staker"` + SeqCoordinator SeqCoordinatorConfig `koanf:"seq-coordinator"` + DataAvailability das.DataAvailabilityConfig `koanf:"data-availability"` + SyncMonitor SyncMonitorConfig `koanf:"sync-monitor"` + Dangerous DangerousConfig `koanf:"dangerous"` + Caching execution.CachingConfig `koanf:"caching"` + Archive bool `koanf:"archive"` + TxLookupLimit uint64 `koanf:"tx-lookup-limit"` + TransactionStreamer TransactionStreamerConfig `koanf:"transaction-streamer" reload:"hot"` + Maintenance MaintenanceConfig `koanf:"maintenance" reload:"hot"` + ResourceManagement resourcemanager.Config `koanf:"resource-mgmt" reload:"hot"` } func (c *Config) Validate() error { @@ -394,6 +394,7 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet, feedInputEnable bool, feed execution.AddOptionsForNodeForwarderConfig(prefix+".forwarder", f) execution.TxPreCheckerConfigAddOptions(prefix+".tx-pre-checker", f) staker.BlockValidatorConfigAddOptions(prefix+".block-validator", f) + arbitrum.RecordingDatabaseConfigAddOptions(prefix+".recording-database", f) broadcastclient.FeedConfigAddOptions(prefix+".feed", f, feedInputEnable, feedOutputEnable) staker.L1ValidatorConfigAddOptions(prefix+".staker", f) SeqCoordinatorConfigAddOptions(prefix+".seq-coordinator", f) @@ -421,6 +422,7 @@ var ConfigDefault = Config{ ForwardingTargetImpl: "", TxPreChecker: execution.DefaultTxPreCheckerConfig, BlockValidator: staker.DefaultBlockValidatorConfig, + RecordingDB: arbitrum.DefaultRecordingDatabaseConfig, Feed: broadcastclient.FeedConfigDefault, Staker: staker.DefaultL1ValidatorConfig, SeqCoordinator: DefaultSeqCoordinatorConfig, @@ -478,18 +480,15 @@ func ConfigDefaultL2Test() *Config { } type DangerousConfig struct { - NoL1Listener bool `koanf:"no-l1-listener"` - ReorgToBlock int64 `koanf:"reorg-to-block"` + NoL1Listener bool `koanf:"no-l1-listener"` } var DefaultDangerousConfig = DangerousConfig{ NoL1Listener: false, - ReorgToBlock: -1, } func DangerousConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".no-l1-listener", DefaultDangerousConfig.NoL1Listener, "DANGEROUS! disables listening to L1. To be used in test nodes only") - f.Int64(prefix+".reorg-to-block", DefaultDangerousConfig.ReorgToBlock, "DANGEROUS! forces a reorg to an old block height. To be used for testing only. 
-1 to disable") } type Node struct { @@ -589,14 +588,6 @@ func createNodeImpl( l2Config := l2BlockChain.Config() l2ChainId := l2Config.ChainID.Uint64() - var reorgingToBlock *types.Block - if config.Dangerous.ReorgToBlock >= 0 { - reorgingToBlock, err = execution.ReorgToBlock(l2BlockChain, uint64(config.Dangerous.ReorgToBlock)) - if err != nil { - return nil, err - } - } - syncMonitor := NewSyncMonitor(&config.SyncMonitor) var classicOutbox *ClassicOutboxRetriever classicMsgDb, err := stack.OpenDatabase("classic-msg", 0, 0, "", true) @@ -620,7 +611,7 @@ func createNodeImpl( sequencerConfigFetcher := func() *execution.SequencerConfig { return &configFetcher.Get().Sequencer } txprecheckConfigFetcher := func() *execution.TxPreCheckerConfig { return &configFetcher.Get().TxPreChecker } exec, err := execution.CreateExecutionNode(stack, chainDb, l2BlockChain, l1Reader, syncMonitor, - config.ForwardingTarget(), &config.Forwarder, config.RPC, + config.ForwardingTarget(), &config.Forwarder, config.RPC, &config.RecordingDB, sequencerConfigFetcher, txprecheckConfigFetcher) if err != nil { return nil, err @@ -774,8 +765,7 @@ func createNodeImpl( inboxReader, inboxTracker, txStreamer, - l2BlockChain, - chainDb, + exec.Recorder, rawdb.NewTable(arbDb, BlockValidatorPrefix), daReader, func() *staker.BlockValidatorConfig { return &configFetcher.Get().BlockValidator }, @@ -798,7 +788,6 @@ func createNodeImpl( statelessBlockValidator, inboxTracker, txStreamer, - reorgingToBlock, func() *staker.BlockValidatorConfig { return &configFetcher.Get().BlockValidator }, fatalErrChan, ) @@ -808,6 +797,8 @@ func createNodeImpl( } var stakerObj *staker.Staker + var messagePruner *MessagePruner + if config.Staker.Enable { var wallet staker.ValidatorWalletInterface if config.Staker.UseSmartContractWallet || txOptsValidator == nil { @@ -834,7 +825,13 @@ func createNodeImpl( } } - stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, deployInfo.ValidatorUtils) + notifiers := make([]staker.LatestStakedNotifier, 0) + if config.MessagePruner.Enable && !config.Caching.Archive { + messagePruner = NewMessagePruner(txStreamer, inboxTracker, func() *MessagePrunerConfig { return &configFetcher.Get().MessagePruner }) + notifiers = append(notifiers, messagePruner) + } + + stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, notifiers, deployInfo.ValidatorUtils, fatalErrChan) if err != nil { return nil, err } @@ -861,15 +858,11 @@ func createNodeImpl( if txOptsBatchPoster == nil { return nil, errors.New("batchposter, but no TxOpts") } - batchPoster, err = NewBatchPoster(l1Reader, inboxTracker, txStreamer, syncMonitor, func() *BatchPosterConfig { return &configFetcher.Get().BatchPoster }, deployInfo, txOptsBatchPoster, daWriter) + batchPoster, err = NewBatchPoster(rawdb.NewTable(arbDb, BlockValidatorPrefix), l1Reader, inboxTracker, txStreamer, syncMonitor, func() *BatchPosterConfig { return &configFetcher.Get().BatchPoster }, deployInfo, txOptsBatchPoster, daWriter) if err != nil { return nil, err } } - var messagePruner *MessagePruner - if config.MessagePruner.Enable && !config.Caching.Archive && stakerObj != nil { - messagePruner = NewMessagePruner(txStreamer, inboxTracker, stakerObj, func() *MessagePrunerConfig { return &configFetcher.Get().MessagePruner }) - } // always create DelayedSequencer, it won't do anything if it is disabled delayedSequencer, err = NewDelayedSequencer(l1Reader, 
inboxReader, exec.ExecEngine, coordinator, func() *DelayedSequencerConfig { return &configFetcher.Get().DelayedSequencer }) if err != nil { @@ -1153,6 +1146,7 @@ func (n *Node) StopAndWait() { if n.StatelessBlockValidator != nil { n.StatelessBlockValidator.Stop() } + n.Execution.Recorder.OrderlyShutdown() if n.InboxReader != nil && n.InboxReader.Started() { n.InboxReader.StopAndWait() } @@ -1180,19 +1174,3 @@ func (n *Node) StopAndWait() { log.Error("error on stak close", "err", err) } } - -func CreateDefaultStackForTest(dataDir string) (*node.Node, error) { - stackConf := node.DefaultConfig - var err error - stackConf.DataDir = dataDir - stackConf.HTTPHost = "" - stackConf.HTTPModules = append(stackConf.HTTPModules, "eth") - stackConf.P2P.NoDiscovery = true - stackConf.P2P.ListenAddr = "" - - stack, err := node.New(&stackConf) - if err != nil { - return nil, fmt.Errorf("error creating protocol stack: %w", err) - } - return stack, nil -} diff --git a/arbnode/seq_coordinator.go b/arbnode/seq_coordinator.go index ecb38129ac..31cab83b1f 100644 --- a/arbnode/seq_coordinator.go +++ b/arbnode/seq_coordinator.go @@ -619,12 +619,12 @@ func (c *SeqCoordinator) update(ctx context.Context) time.Duration { log.Error("myurl main sequencer, but no sequencer exists") return c.noRedisError() } - processedMessages, err := c.streamer.exec.HeadMessageNumber() + processedMessages, err := c.streamer.GetProcessedMessageCount() if err != nil { log.Warn("coordinator: failed to read processed message count", "err", err) processedMessages = 0 } - if processedMessages+1 >= localMsgCount { + if processedMessages >= localMsgCount { // we're here because we don't currently hold the lock // sequencer is already either paused or forwarding c.sequencer.Pause() diff --git a/arbnode/sync_monitor.go b/arbnode/sync_monitor.go index a9380e65f6..d01c300fa9 100644 --- a/arbnode/sync_monitor.go +++ b/arbnode/sync_monitor.go @@ -70,13 +70,8 @@ func (s *SyncMonitor) SyncProgressMap() map[string]interface{} { syncing = true builtMessageCount = 0 } else { - blockNum, err := s.txStreamer.exec.MessageCountToBlockNumber(builtMessageCount) - if err != nil { - res["blockBuiltErr"] = err - syncing = true - } else { - res["blockNum"] = blockNum - } + blockNum := s.txStreamer.exec.MessageIndexToBlockNumber(builtMessageCount) + res["blockNum"] = blockNum builtMessageCount++ res["messageOfLastBlock"] = builtMessageCount } @@ -155,8 +150,8 @@ func (s *SyncMonitor) SafeBlockNumber(ctx context.Context) (uint64, error) { if err != nil { return 0, err } - block, err := s.txStreamer.exec.MessageCountToBlockNumber(msg) - return uint64(block), err + block := s.txStreamer.exec.MessageIndexToBlockNumber(msg - 1) + return block, nil } func (s *SyncMonitor) FinalizedBlockNumber(ctx context.Context) (uint64, error) { @@ -167,8 +162,8 @@ func (s *SyncMonitor) FinalizedBlockNumber(ctx context.Context) (uint64, error) if err != nil { return 0, err } - block, err := s.txStreamer.exec.MessageCountToBlockNumber(msg) - return uint64(block), err + block := s.txStreamer.exec.MessageIndexToBlockNumber(msg - 1) + return block, nil } func (s *SyncMonitor) Synced() bool { diff --git a/arbnode/transaction_streamer.go b/arbnode/transaction_streamer.go index a6a11b0b84..8922790b6c 100644 --- a/arbnode/transaction_streamer.go +++ b/arbnode/transaction_streamer.go @@ -46,6 +46,7 @@ type TransactionStreamer struct { chainConfig *params.ChainConfig exec *execution.ExecutionEngine execLastMsgCount arbutil.MessageIndex + validator *staker.BlockValidator db ethdb.Database 
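With this change the staker pushes updates out through the LatestStakedNotifier list built in node.go above, and the message pruner's UpdateLatestStaked uses a TryLock plus an interval check so a slow prune never blocks the staker. A simplified sketch of that pattern, with hypothetical stand-ins for staker.LatestStakedNotifier and validator.GoGlobalState:

package main

import (
	"fmt"
	"sync"
	"time"
)

// globalState is a stand-in for validator.GoGlobalState.
type globalState struct{ Batch uint64 }

// latestStakedNotifier is a stand-in for the staker.LatestStakedNotifier interface.
type latestStakedNotifier interface {
	UpdateLatestStaked(count uint64, gs globalState)
}

// pruner sketches MessagePruner's non-blocking gate: skip the update if a prune
// is already running or if the last one finished less than interval ago.
type pruner struct {
	mu            sync.Mutex
	lastPruneDone time.Time
	interval      time.Duration
}

func (p *pruner) UpdateLatestStaked(count uint64, gs globalState) {
	if !p.mu.TryLock() {
		return // a prune is already in flight
	}
	if time.Since(p.lastPruneDone) < p.interval {
		p.mu.Unlock()
		return
	}
	go func() {
		defer p.mu.Unlock()
		fmt.Println("pruning up to message", count, "batch", gs.Batch)
		p.lastPruneDone = time.Now()
	}()
}

func main() {
	notifiers := []latestStakedNotifier{&pruner{interval: time.Minute}}
	// The staker would call this each time it observes a new latest-staked node.
	for _, n := range notifiers {
		n.UpdateLatestStaked(1234, globalState{Batch: 42})
	}
	time.Sleep(10 * time.Millisecond)
}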
fatalErrChan chan<- error @@ -128,7 +129,13 @@ func uint64ToKey(x uint64) []byte { } func (s *TransactionStreamer) SetBlockValidator(validator *staker.BlockValidator) { - s.exec.SetBlockValidator(validator) + if s.Started() { + panic("trying to set coordinator after start") + } + if s.validator != nil { + panic("trying to set coordinator when already set") + } + s.validator = validator } func (s *TransactionStreamer) SetSeqCoordinator(coordinator *SeqCoordinator) { @@ -199,19 +206,24 @@ func deleteStartingAt(db ethdb.Database, batch ethdb.Batch, prefix []byte, minKe } // deleteFromRange deletes key ranging from startMinKey(inclusive) to endMinKey(exclusive) -func deleteFromRange(db ethdb.Database, prefix []byte, startMinKey uint64, endMinKey uint64) ([][]byte, error) { +// might have deleted some keys even if returning an error +func deleteFromRange(ctx context.Context, db ethdb.Database, prefix []byte, startMinKey uint64, endMinKey uint64) ([]uint64, error) { batch := db.NewBatch() startIter := db.NewIterator(prefix, uint64ToKey(startMinKey)) defer startIter.Release() - var prunedKeysRange [][]byte + var prunedKeysRange []uint64 for startIter.Next() { - if binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) >= endMinKey { + if ctx.Err() != nil { + return nil, ctx.Err() + } + currentKey := binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) + if currentKey >= endMinKey { break } if len(prunedKeysRange) == 0 || len(prunedKeysRange) == 1 { - prunedKeysRange = append(prunedKeysRange, startIter.Key()) + prunedKeysRange = append(prunedKeysRange, currentKey) } else { - prunedKeysRange[1] = startIter.Key() + prunedKeysRange[1] = currentKey } err := batch.Delete(startIter.Key()) if err != nil { @@ -328,6 +340,13 @@ func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageInde return err } + if s.validator != nil { + err = s.validator.Reorg(s.GetContext(), count) + if err != nil { + return err + } + } + err = deleteStartingAt(s.db, batch, messagePrefix, uint64ToKey(uint64(count))) if err != nil { return err @@ -387,6 +406,21 @@ func (s *TransactionStreamer) GetMessageCount() (arbutil.MessageIndex, error) { return arbutil.MessageIndex(pos), nil } +func (s *TransactionStreamer) GetProcessedMessageCount() (arbutil.MessageIndex, error) { + msgCount, err := s.GetMessageCount() + if err != nil { + return 0, err + } + digestedHead, err := s.exec.HeadMessageNumber() + if err != nil { + return 0, err + } + if msgCount > digestedHead+1 { + return digestedHead + 1, nil + } + return msgCount, nil +} + func (s *TransactionStreamer) AddMessages(pos arbutil.MessageIndex, messagesAreConfirmed bool, messages []arbostypes.MessageWithMetadata) error { return s.AddMessagesAndEndBatch(pos, messagesAreConfirmed, messages, nil) } @@ -883,6 +917,14 @@ func (s *TransactionStreamer) writeMessages(pos arbutil.MessageIndex, messages [ return nil } +// TODO: eventually there will be a table maintained by txStreamer itself +func (s *TransactionStreamer) ResultAtCount(count arbutil.MessageIndex) (*execution.MessageResult, error) { + if count == 0 { + return &execution.MessageResult{}, nil + } + return s.exec.ResultAtPos(count - 1) +} + // return value: true if should be called again immediately func (s *TransactionStreamer) executeNextMsg(ctx context.Context, exec *execution.ExecutionEngine) bool { if ctx.Err() != nil { @@ -898,7 +940,7 @@ func (s *TransactionStreamer) executeNextMsg(ctx context.Context, exec *executio log.Error("feedOneMsg failed to get message count", "err", 
err) return false } - s.execLastMsgCount = prevMessageCount + s.execLastMsgCount = msgCount pos, err := s.exec.HeadMessageNumber() if err != nil { log.Error("feedOneMsg failed to get exec engine message count", "err", err) diff --git a/arbos/block_processor.go b/arbos/block_processor.go index 9f208c4404..653718b7d3 100644 --- a/arbos/block_processor.go +++ b/arbos/block_processor.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" ) @@ -39,6 +40,7 @@ var L2ToL1TransactionEventID common.Hash var L2ToL1TxEventID common.Hash var EmitReedeemScheduledEvent func(*vm.EVM, uint64, uint64, [32]byte, [32]byte, common.Address, *big.Int, *big.Int) error var EmitTicketCreatedEvent func(*vm.EVM, [32]byte) error +var gasUsedSinceStartupCounter = metrics.NewRegisteredCounter("arb/sequencer/gasused", nil) type L1Info struct { poster common.Address @@ -343,11 +345,7 @@ func ProduceBlockAdvanced( log.Debug("error applying transaction", "tx", tx, "err", err) if !hooks.DiscardInvalidTxsEarly { // we'll still deduct a TxGas's worth from the block-local rate limiter even if the tx was invalid - if blockGasLeft > params.TxGas { - blockGasLeft -= params.TxGas - } else { - blockGasLeft = 0 - } + blockGasLeft = arbmath.SaturatingUSub(blockGasLeft, params.TxGas) if isUserTx { userTxsProcessed++ } @@ -416,11 +414,11 @@ func ProduceBlockAdvanced( } } - if blockGasLeft > computeUsed { - blockGasLeft -= computeUsed - } else { - blockGasLeft = 0 - } + blockGasLeft = arbmath.SaturatingUSub(blockGasLeft, computeUsed) + + // Add gas used since startup to prometheus metric. + gasUsed := arbmath.SaturatingUSub(receipt.GasUsed, receipt.GasUsedForL1) + gasUsedSinceStartupCounter.Inc(arbmath.SaturatingCast(gasUsed)) complete = append(complete, tx) receipts = append(receipts, receipt) diff --git a/cmd/daserver/daserver.go b/cmd/daserver/daserver.go index 7b6b504e40..1a587d6847 100644 --- a/cmd/daserver/daserver.go +++ b/cmd/daserver/daserver.go @@ -45,6 +45,8 @@ type DAServerConfig struct { Metrics bool `koanf:"metrics"` MetricsServer genericconf.MetricsServerConfig `koanf:"metrics-server"` + PProf bool `koanf:"pprof"` + PprofCfg genericconf.PProf `koanf:"pprof-cfg"` } var DefaultDAServerConfig = DAServerConfig{ @@ -60,6 +62,8 @@ var DefaultDAServerConfig = DAServerConfig{ ConfConfig: genericconf.ConfConfigDefault, Metrics: false, MetricsServer: genericconf.MetricsServerConfigDefault, + PProf: false, + PprofCfg: genericconf.PProfDefault, LogLevel: 3, } @@ -89,6 +93,9 @@ func parseDAServer(args []string) (*DAServerConfig, error) { f.Bool("metrics", DefaultDAServerConfig.Metrics, "enable metrics") genericconf.MetricsServerAddOptions("metrics-server", f) + f.Bool("pprof", DefaultDAServerConfig.PProf, "enable pprof") + genericconf.PProfAddOptions("pprof-cfg", f) + f.Int("log-level", int(log.LvlInfo), "log level; 1: ERROR, 2: WARN, 3: INFO, 4: DEBUG, 5: TRACE") das.DataAvailabilityConfigAddDaserverOptions("data-availability", f) genericconf.ConfConfigAddOptions("conf", f) @@ -135,6 +142,25 @@ func (c *L1ReaderCloser) String() string { return "l1 reader closer" } +// Checks metrics and PProf flag, runs them if enabled. +// Note: they are separate so one can enable/disable them as they wish, the only +// requirement is that they can't run on the same address and port. 
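The startMetrics helper below is repeated (with each binary's config type) in daserver, nitro-val, nitro, and relay: metrics and pprof are now toggled and addressed independently, with the single rule that they cannot share an address:port. A self-contained sketch of the same wiring; genericconf.StartPprof is replaced here by a plain net/http/pprof listener so the example has no nitro dependencies, and the config struct is hypothetical:

package main

import (
	"fmt"
	"net/http"
	_ "net/http/pprof" // registers /debug/pprof handlers on the default mux
	"time"

	"github.com/ethereum/go-ethereum/log"
	"github.com/ethereum/go-ethereum/metrics"
	"github.com/ethereum/go-ethereum/metrics/exp"
)

// endpointConfig is a made-up stand-in for the per-binary config types.
type endpointConfig struct {
	Metrics     bool
	MetricsAddr string // e.g. "127.0.0.1:6070"
	PProf       bool
	PProfAddr   string // e.g. "127.0.0.1:6071"
}

func startMetrics(cfg endpointConfig) error {
	if cfg.Metrics && cfg.PProf && cfg.MetricsAddr == cfg.PProfAddr {
		return fmt.Errorf("metrics and pprof cannot be enabled on the same address:port: %s", cfg.MetricsAddr)
	}
	if cfg.Metrics {
		go metrics.CollectProcessMetrics(3 * time.Second)
		exp.Setup(cfg.MetricsAddr) // serves /debug/metrics
	}
	if cfg.PProf {
		go func() {
			if err := http.ListenAndServe(cfg.PProfAddr, nil); err != nil {
				log.Error("pprof server failed", "err", err)
			}
		}()
	}
	return nil
}

func main() {
	err := startMetrics(endpointConfig{Metrics: true, MetricsAddr: "127.0.0.1:6070", PProf: true, PProfAddr: "127.0.0.1:6071"})
	if err != nil {
		log.Error("starting metrics", "err", err)
		return
	}
	select {} // keep the example servers alive
}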
+func startMetrics(cfg *DAServerConfig) error { + mAddr := fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port) + pAddr := fmt.Sprintf("%v:%v", cfg.PprofCfg.Addr, cfg.PprofCfg.Port) + if cfg.Metrics && cfg.PProf && mAddr == pAddr { + return fmt.Errorf("metrics and pprof cannot be enabled on the same address:port: %s", mAddr) + } + if cfg.Metrics { + go metrics.CollectProcessMetrics(cfg.MetricsServer.UpdateInterval) + exp.Setup(fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port)) + } + if cfg.PProf { + genericconf.StartPprof(pAddr) + } + return nil +} + func startup() error { // Some different defaults to DAS config in a node. das.DefaultDataAvailabilityConfig.Enable = true @@ -151,16 +177,8 @@ func startup() error { glogger.Verbosity(log.Lvl(serverConfig.LogLevel)) log.Root().SetHandler(glogger) - if serverConfig.Metrics { - if len(serverConfig.MetricsServer.Addr) == 0 { - fmt.Printf("Metrics is enabled, but missing --metrics-server.addr") - return nil - } - - go metrics.CollectProcessMetrics(serverConfig.MetricsServer.UpdateInterval) - - address := fmt.Sprintf("%v:%v", serverConfig.MetricsServer.Addr, serverConfig.MetricsServer.Port) - exp.Setup(address) + if err := startMetrics(serverConfig); err != nil { + return err } sigint := make(chan os.Signal, 1) diff --git a/cmd/genericconf/pprof.go b/cmd/genericconf/pprof.go index 8f756bbf45..e55bfddd32 100644 --- a/cmd/genericconf/pprof.go +++ b/cmd/genericconf/pprof.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" + // Blank import pprof registers its HTTP handlers. _ "net/http/pprof" // #nosec G108 "github.com/ethereum/go-ethereum/log" diff --git a/cmd/genericconf/server.go b/cmd/genericconf/server.go index 17c4a7a872..b99429191e 100644 --- a/cmd/genericconf/server.go +++ b/cmd/genericconf/server.go @@ -189,20 +189,32 @@ func AuthRPCConfigAddOptions(prefix string, f *flag.FlagSet) { type MetricsServerConfig struct { Addr string `koanf:"addr"` Port int `koanf:"port"` - Pprof bool `koanf:"pprof"` UpdateInterval time.Duration `koanf:"update-interval"` } var MetricsServerConfigDefault = MetricsServerConfig{ Addr: "127.0.0.1", Port: 6070, - Pprof: false, UpdateInterval: 3 * time.Second, } +type PProf struct { + Addr string `koanf:"addr"` + Port int `koanf:"port"` +} + +var PProfDefault = PProf{ + Addr: "127.0.0.1", + Port: 6071, +} + func MetricsServerAddOptions(prefix string, f *flag.FlagSet) { f.String(prefix+".addr", MetricsServerConfigDefault.Addr, "metrics server address") f.Int(prefix+".port", MetricsServerConfigDefault.Port, "metrics server port") - f.Bool(prefix+".pprof", MetricsServerConfigDefault.Pprof, "enable profiling for Go") f.Duration(prefix+".update-interval", MetricsServerConfigDefault.UpdateInterval, "metrics server update interval") } + +func PProfAddOptions(prefix string, f *flag.FlagSet) { + f.String(prefix+".addr", PProfDefault.Addr, "pprof server address") + f.Int(prefix+".port", PProfDefault.Port, "pprof server port") +} diff --git a/cmd/nitro-val/config.go b/cmd/nitro-val/config.go index 5ab1521f96..12a359cfa4 100644 --- a/cmd/nitro-val/config.go +++ b/cmd/nitro-val/config.go @@ -30,6 +30,8 @@ type ValidationNodeConfig struct { AuthRPC genericconf.AuthRPCConfig `koanf:"auth"` Metrics bool `koanf:"metrics"` MetricsServer genericconf.MetricsServerConfig `koanf:"metrics-server"` + PProf bool `koanf:"pprof"` + PprofCfg genericconf.PProf `koanf:"pprof-cfg"` Workdir string `koanf:"workdir" reload:"hot"` } @@ -67,6 +69,8 @@ var ValidationNodeConfigDefault = ValidationNodeConfig{ AuthRPC: 
genericconf.AuthRPCConfigDefault, Metrics: false, MetricsServer: genericconf.MetricsServerConfigDefault, + PProf: false, + PprofCfg: genericconf.PProfDefault, Workdir: "", } @@ -83,6 +87,8 @@ func ValidationNodeConfigAddOptions(f *flag.FlagSet) { genericconf.AuthRPCConfigAddOptions("auth", f) f.Bool("metrics", ValidationNodeConfigDefault.Metrics, "enable metrics") genericconf.MetricsServerAddOptions("metrics-server", f) + f.Bool("pprof", ValidationNodeConfigDefault.PProf, "enable pprof") + genericconf.PProfAddOptions("pprof-cfg", f) f.String("workdir", ValidationNodeConfigDefault.Workdir, "path used for purpose of resolving relative paths (ia. jwt secret file, log files), if empty then current working directory will be used.") } diff --git a/cmd/nitro-val/nitro_val.go b/cmd/nitro-val/nitro_val.go index 40d9fce5b6..e6b7bd882f 100644 --- a/cmd/nitro-val/nitro_val.go +++ b/cmd/nitro-val/nitro_val.go @@ -32,6 +32,25 @@ func main() { os.Exit(mainImpl()) } +// Checks metrics and PProf flag, runs them if enabled. +// Note: they are separate so one can enable/disable them as they wish, the only +// requirement is that they can't run on the same address and port. +func startMetrics(cfg *ValidationNodeConfig) error { + mAddr := fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port) + pAddr := fmt.Sprintf("%v:%v", cfg.PprofCfg.Addr, cfg.PprofCfg.Port) + if cfg.Metrics && cfg.PProf && mAddr == pAddr { + return fmt.Errorf("metrics and pprof cannot be enabled on the same address:port: %s", mAddr) + } + if cfg.Metrics { + go metrics.CollectProcessMetrics(cfg.MetricsServer.UpdateInterval) + exp.Setup(fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port)) + } + if cfg.PProf { + genericconf.StartPprof(pAddr) + } + return nil +} + // Returns the exit code func mainImpl() int { ctx, cancelFunc := context.WithCancel(context.Background()) @@ -96,20 +115,8 @@ func mainImpl() int { log.Crit("failed to initialize geth stack", "err", err) } - if nodeConfig.Metrics { - go metrics.CollectProcessMetrics(nodeConfig.MetricsServer.UpdateInterval) - - if nodeConfig.MetricsServer.Addr != "" { - address := fmt.Sprintf("%v:%v", nodeConfig.MetricsServer.Addr, nodeConfig.MetricsServer.Port) - if nodeConfig.MetricsServer.Pprof { - genericconf.StartPprof(address) - } else { - exp.Setup(address) - } - } - } else if nodeConfig.MetricsServer.Pprof { - flag.Usage() - log.Error("--metrics must be enabled in order to use pprof with the metrics server") + if err := startMetrics(nodeConfig); err != nil { + log.Error("Starting metrics: %v", err) return 1 } diff --git a/cmd/nitro/init.go b/cmd/nitro/init.go index 2284695961..399555098c 100644 --- a/cmd/nitro/init.go +++ b/cmd/nitro/init.go @@ -41,7 +41,7 @@ import ( "github.com/offchainlabs/nitro/cmd/ipfshelper" "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/statetransfer" - flag "github.com/spf13/pflag" + "github.com/spf13/pflag" ) type InitConfig struct { @@ -58,6 +58,7 @@ type InitConfig struct { ThenQuit bool `koanf:"then-quit"` Prune string `koanf:"prune"` PruneBloomSize uint64 `koanf:"prune-bloom-size"` + ResetToMsg int64 `koanf:"reset-to-message"` } var InitConfigDefault = InitConfig{ @@ -73,9 +74,10 @@ var InitConfigDefault = InitConfig{ ThenQuit: false, Prune: "", PruneBloomSize: 2048, + ResetToMsg: -1, } -func InitConfigAddOptions(prefix string, f *flag.FlagSet) { +func InitConfigAddOptions(prefix string, f *pflag.FlagSet) { f.Bool(prefix+".force", InitConfigDefault.Force, "if true: in case database exists init code will be 
reexecuted and genesis block compared to database") f.String(prefix+".url", InitConfigDefault.Url, "url to download initializtion data - will poll if download fails") f.String(prefix+".download-path", InitConfigDefault.DownloadPath, "path to save temp downloaded file") @@ -89,6 +91,7 @@ func InitConfigAddOptions(prefix string, f *flag.FlagSet) { f.Uint(prefix+".accounts-per-sync", InitConfigDefault.AccountsPerSync, "during init - sync database every X accounts. Lower value for low-memory systems. 0 disables.") f.String(prefix+".prune", InitConfigDefault.Prune, "pruning for a given use: \"full\" for full nodes serving RPC requests, or \"validator\" for validators") f.Uint64(prefix+".prune-bloom-size", InitConfigDefault.PruneBloomSize, "the amount of memory in megabytes to use for the pruning bloom filter (higher values prune better)") + f.Int64(prefix+".reset-to-message", InitConfigDefault.ResetToMsg, "forces a reset to an old message height. Also set max-reorg-resequence-depth=0 to force re-reading messages") } func downloadInit(ctx context.Context, initConfig *InitConfig) (string, error) { @@ -327,19 +330,23 @@ func findImportantRoots(ctx context.Context, chainDb ethdb.Database, stack *node } validatorDb := rawdb.NewTable(arbDb, arbnode.BlockValidatorPrefix) - lastValidated, err := staker.ReadLastValidatedFromDb(validatorDb) + lastValidated, err := staker.ReadLastValidatedInfo(validatorDb) if err != nil { return nil, err } if lastValidated != nil { - lastValidatedHeader := rawdb.ReadHeader(chainDb, lastValidated.BlockHash, lastValidated.BlockNumber) + var lastValidatedHeader *types.Header + headerNum := rawdb.ReadHeaderNumber(chainDb, lastValidated.GlobalState.BlockHash) + if headerNum != nil { + lastValidatedHeader = rawdb.ReadHeader(chainDb, lastValidated.GlobalState.BlockHash, *headerNum) + } if lastValidatedHeader != nil { err = roots.addHeader(lastValidatedHeader, false) if err != nil { return nil, err } } else { - log.Warn("missing latest validated block", "number", lastValidated.BlockNumber, "hash", lastValidated.BlockHash) + log.Warn("missing latest validated block", "hash", lastValidated.GlobalState.BlockHash) } } } else if initConfig.Prune == "full" { diff --git a/cmd/nitro/nitro.go b/cmd/nitro/nitro.go index 3074ca7f87..f1af1388cf 100644 --- a/cmd/nitro/nitro.go +++ b/cmd/nitro/nitro.go @@ -39,6 +39,7 @@ import ( "github.com/offchainlabs/nitro/arbnode" "github.com/offchainlabs/nitro/arbnode/execution" "github.com/offchainlabs/nitro/arbnode/resourcemanager" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/cmd/chaininfo" "github.com/offchainlabs/nitro/cmd/conf" "github.com/offchainlabs/nitro/cmd/genericconf" @@ -121,6 +122,25 @@ func main() { os.Exit(mainImpl()) } +// Checks metrics and PProf flag, runs them if enabled. +// Note: they are separate so one can enable/disable them as they wish, the only +// requirement is that they can't run on the same address and port. 
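findImportantRoots above now has only the validated GlobalState's block hash, so it resolves hash to number to header through rawdb and tolerates either lookup returning nothing. A minimal sketch of that lookup against an in-memory database; the header written here is a placeholder for illustration:

package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/ethdb"
)

// headerByHash resolves hash -> number -> header and returns nil when either
// mapping is missing, mirroring the nil checks added to findImportantRoots.
func headerByHash(db ethdb.Database, hash common.Hash) *types.Header {
	num := rawdb.ReadHeaderNumber(db, hash)
	if num == nil {
		return nil
	}
	return rawdb.ReadHeader(db, hash, *num)
}

func main() {
	db := rawdb.NewMemoryDatabase()
	// Placeholder header, only for illustration.
	header := &types.Header{Number: big.NewInt(7), Difficulty: big.NewInt(0)}
	rawdb.WriteHeader(db, header) // also stores the hash-to-number mapping
	if h := headerByHash(db, header.Hash()); h != nil {
		fmt.Println("found validated header", h.Number)
	} else {
		fmt.Println("missing latest validated block", header.Hash())
	}
}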
+func startMetrics(cfg *NodeConfig) error { + mAddr := fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port) + pAddr := fmt.Sprintf("%v:%v", cfg.PprofCfg.Addr, cfg.PprofCfg.Port) + if cfg.Metrics && cfg.PProf && mAddr == pAddr { + return fmt.Errorf("metrics and pprof cannot be enabled on the same address:port: %s", mAddr) + } + if cfg.Metrics { + go metrics.CollectProcessMetrics(cfg.MetricsServer.UpdateInterval) + exp.Setup(fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port)) + } + if cfg.PProf { + genericconf.StartPprof(pAddr) + } + return nil +} + // Returns the exit code func mainImpl() int { ctx, cancelFunc := context.WithCancel(context.Background()) @@ -216,6 +236,9 @@ func mainImpl() int { flag.Usage() log.Crit("error opening parent chain wallet", "path", l1Wallet.Pathname, "account", l1Wallet.Account, "err", err) } + if l1Wallet.OnlyCreateKey { + return 0 + } l1TransactionOptsBatchPoster = l1TransactionOpts l1TransactionOptsValidator = l1TransactionOpts } @@ -229,6 +252,9 @@ func mainImpl() int { flag.Usage() log.Crit("error opening Batch poster parent chain wallet", "path", nodeConfig.Node.BatchPoster.L1Wallet.Pathname, "account", nodeConfig.Node.BatchPoster.L1Wallet.Account, "err", err) } + if nodeConfig.Node.BatchPoster.L1Wallet.OnlyCreateKey { + return 0 + } } if validatorNeedsKey || nodeConfig.Node.Staker.L1Wallet.OnlyCreateKey { l1TransactionOptsValidator, _, err = util.OpenWallet("l1-validator", &nodeConfig.Node.Staker.L1Wallet, new(big.Int).SetUint64(nodeConfig.L1.ChainID)) @@ -236,6 +262,9 @@ func mainImpl() int { flag.Usage() log.Crit("error opening Validator parent chain wallet", "path", nodeConfig.Node.Staker.L1Wallet.Pathname, "account", nodeConfig.Node.Staker.L1Wallet.Account, "err", err) } + if nodeConfig.Node.Staker.L1Wallet.OnlyCreateKey { + return 0 + } } } @@ -340,6 +369,11 @@ func mainImpl() int { } } + if err := startMetrics(nodeConfig); err != nil { + log.Error("Starting metrics: %v", err) + return 1 + } + chainDb, l2BlockChain, err := openInitializeChainDb(ctx, stack, nodeConfig, new(big.Int).SetUint64(nodeConfig.L2.ChainID), execution.DefaultCacheConfigFor(stack, &nodeConfig.Node.Caching), l1Client, rollupAddrs) defer closeDb(chainDb, "chainDb") if l2BlockChain != nil { @@ -359,7 +393,7 @@ func mainImpl() int { return 1 } - if nodeConfig.Init.ThenQuit { + if nodeConfig.Init.ThenQuit && nodeConfig.Init.ResetToMsg < 0 { return 0 } @@ -369,23 +403,6 @@ func mainImpl() int { return 1 } - if nodeConfig.Metrics { - go metrics.CollectProcessMetrics(nodeConfig.MetricsServer.UpdateInterval) - - if nodeConfig.MetricsServer.Addr != "" { - address := fmt.Sprintf("%v:%v", nodeConfig.MetricsServer.Addr, nodeConfig.MetricsServer.Port) - if nodeConfig.MetricsServer.Pprof { - genericconf.StartPprof(address) - } else { - exp.Setup(address) - } - } - } else if nodeConfig.MetricsServer.Pprof { - flag.Usage() - log.Error("--metrics must be enabled in order to use pprof with the metrics server") - return 1 - } - fatalErrChan := make(chan error, 10) var valNode *valnode.ValidationNode @@ -462,12 +479,27 @@ func mainImpl() int { if err != nil { fatalErrChan <- fmt.Errorf("error starting node: %w", err) } + defer currentNode.StopAndWait() } sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) exitCode := 0 + + if err == nil && nodeConfig.Init.ResetToMsg > 0 { + err = currentNode.TxStreamer.ReorgTo(arbutil.MessageIndex(nodeConfig.Init.ResetToMsg)) + if err != nil { + fatalErrChan <- fmt.Errorf("error reseting message: 
%w", err) + exitCode = 1 + } + if nodeConfig.Init.ThenQuit { + close(sigint) + + return exitCode + } + } + select { case err := <-fatalErrChan: log.Error("shutting down due to fatal error", "err", err) @@ -480,8 +512,6 @@ func mainImpl() int { // cause future ctrl+c's to panic close(sigint) - currentNode.StopAndWait() - return exitCode } @@ -502,6 +532,8 @@ type NodeConfig struct { GraphQL genericconf.GraphQLConfig `koanf:"graphql"` Metrics bool `koanf:"metrics"` MetricsServer genericconf.MetricsServerConfig `koanf:"metrics-server"` + PProf bool `koanf:"pprof"` + PprofCfg genericconf.PProf `koanf:"pprof-cfg"` Init InitConfig `koanf:"init"` Rpc genericconf.RpcConfig `koanf:"rpc"` } @@ -519,6 +551,8 @@ var NodeConfigDefault = NodeConfig{ IPC: genericconf.IPCConfigDefault, Metrics: false, MetricsServer: genericconf.MetricsServerConfigDefault, + PProf: false, + PprofCfg: genericconf.PProfDefault, } func NodeConfigAddOptions(f *flag.FlagSet) { @@ -538,6 +572,9 @@ func NodeConfigAddOptions(f *flag.FlagSet) { genericconf.GraphQLConfigAddOptions("graphql", f) f.Bool("metrics", NodeConfigDefault.Metrics, "enable metrics") genericconf.MetricsServerAddOptions("metrics-server", f) + f.Bool("pprof", NodeConfigDefault.PProf, "enable pprof") + genericconf.PProfAddOptions("pprof-cfg", f) + InitConfigAddOptions("init", f) genericconf.RpcConfigAddOptions("rpc", f) } diff --git a/cmd/relay/relay.go b/cmd/relay/relay.go index 9f5669454f..57831c3f59 100644 --- a/cmd/relay/relay.go +++ b/cmd/relay/relay.go @@ -37,6 +37,25 @@ func printSampleUsage(progname string) { fmt.Printf("Sample usage: %s --node.feed.input.url= --chain.id= \n", progname) } +// Checks metrics and PProf flag, runs them if enabled. +// Note: they are separate so one can enable/disable them as they wish, the only +// requirement is that they can't run on the same address and port. 
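mainImpl above now defers StopAndWait and then blocks on one select between the fatal-error channel and SIGINT/SIGTERM, closing the signal channel afterwards so a second ctrl+c panics instead of being swallowed. A stripped-down sketch of that shutdown flow with a stand-in for the node:

package main

import (
	"errors"
	"os"
	"os/signal"
	"syscall"

	"github.com/ethereum/go-ethereum/log"
)

// fakeNode stands in for the real node; only the shutdown hook matters here.
type fakeNode struct{}

func (n *fakeNode) StopAndWait() { log.Info("node stopped") }

func mainImpl() int {
	fatalErrChan := make(chan error, 10)
	node := &fakeNode{}
	defer node.StopAndWait() // runs on every exit path, replacing the explicit call at the end

	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)

	// Something in the node would normally feed this channel; simulate a fatal error.
	go func() { fatalErrChan <- errors.New("simulated fatal error") }()

	exitCode := 0
	select {
	case err := <-fatalErrChan:
		log.Error("shutting down due to fatal error", "err", err)
		exitCode = 1
	case <-sigint:
		log.Info("shutting down because of sigint")
	}
	// Cause future ctrl+c's to panic rather than being ignored.
	close(sigint)
	return exitCode
}

func main() { os.Exit(mainImpl()) }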
+func startMetrics(cfg *relay.Config) error { + mAddr := fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port) + pAddr := fmt.Sprintf("%v:%v", cfg.PprofCfg.Addr, cfg.PprofCfg.Port) + if cfg.Metrics && cfg.PProf && mAddr == pAddr { + return fmt.Errorf("metrics and pprof cannot be enabled on the same address:port: %s", mAddr) + } + if cfg.Metrics { + go metrics.CollectProcessMetrics(cfg.MetricsServer.UpdateInterval) + exp.Setup(fmt.Sprintf("%v:%v", cfg.MetricsServer.Addr, cfg.MetricsServer.Port)) + } + if cfg.PProf { + genericconf.StartPprof(pAddr) + } + return nil +} + func startup() error { ctx := context.Background() @@ -68,16 +87,13 @@ func startup() error { if err != nil { return err } - err = newRelay.Start(ctx) - if err != nil { + + if err := startMetrics(relayConfig); err != nil { return err } - if relayConfig.Metrics && relayConfig.MetricsServer.Addr != "" { - go metrics.CollectProcessMetrics(relayConfig.MetricsServer.UpdateInterval) - - address := fmt.Sprintf("%v:%v", relayConfig.MetricsServer.Addr, relayConfig.MetricsServer.Port) - exp.Setup(address) + if err := newRelay.Start(ctx); err != nil { + return err } select { diff --git a/cmd/util/keystore.go b/cmd/util/keystore.go index cf00973295..56749f9722 100644 --- a/cmd/util/keystore.go +++ b/cmd/util/keystore.go @@ -16,6 +16,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/keystore" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" "github.com/offchainlabs/nitro/cmd/genericconf" "github.com/offchainlabs/nitro/util/signature" @@ -51,6 +52,10 @@ func OpenWallet(description string, walletConfig *genericconf.WalletConfig, chai if err != nil { return nil, nil, err } + if walletConfig.OnlyCreateKey { + log.Info(fmt.Sprintf("Wallet key created with address %s, backup wallet (%s) and remove --%s.wallet.only-create-key to run normally", account.Address.Hex(), walletConfig.Pathname, description)) + return nil, nil, nil + } var txOpts *bind.TransactOpts if chainId != nil { @@ -91,46 +96,37 @@ func openKeystore(ks *keystore.KeyStore, description string, walletConfig *gener } } - var account accounts.Account if creatingNew { - var err error - account, err = ks.NewAccount(password) - if err != nil { - return &accounts.Account{}, err + a, err := ks.NewAccount(password) + return &a, err + } + + var account accounts.Account + if walletConfig.Account == "" { + if len(ks.Accounts()) > 1 { + names := make([]string, 0, len(ks.Accounts())) + for _, acct := range ks.Accounts() { + names = append(names, acct.Address.Hex()) + } + return nil, fmt.Errorf("too many existing accounts, choose one: %s", strings.Join(names, ",")) } + account = ks.Accounts()[0] } else { - if walletConfig.Account == "" { - if len(ks.Accounts()) > 1 { - names := make([]string, 0, len(ks.Accounts())) - for _, acct := range ks.Accounts() { - names = append(names, acct.Address.Hex()) - } - return nil, fmt.Errorf("too many existing accounts, choose one: %s", strings.Join(names, ",")) - } - account = ks.Accounts()[0] - } else { - address := common.HexToAddress(walletConfig.Account) - var emptyAddress common.Address - if address == emptyAddress { - return nil, fmt.Errorf("supplied address is invalid: %s", walletConfig.Account) - } - var err error - account, err = ks.Find(accounts.Account{Address: address}) - if err != nil { - return nil, err - } + address := common.HexToAddress(walletConfig.Account) + var emptyAddress common.Address + if address == emptyAddress { + return nil, 
fmt.Errorf("supplied address is invalid: %s", walletConfig.Account) + } + var err error + account, err = ks.Find(accounts.Account{Address: address}) + if err != nil { + return nil, err } } - if creatingNew { - return nil, fmt.Errorf("wallet key created with address %s, backup wallet (%s) and remove --%s.wallet.only-create-key to run normally", account.Address.Hex(), walletConfig.Pathname, description) - } - - err := ks.Unlock(account, password) - if err != nil { + if err := ks.Unlock(account, password); err != nil { return nil, err } - return &account, nil } diff --git a/cmd/util/keystore_test.go b/cmd/util/keystore_test.go index 7752825291..17a0498d68 100644 --- a/cmd/util/keystore_test.go +++ b/cmd/util/keystore_test.go @@ -25,6 +25,7 @@ func openTestKeystore(description string, walletConfig *genericconf.WalletConfig } func createWallet(t *testing.T, pathname string) { + t.Helper() walletConf := genericconf.WalletConfigDefault walletConf.Pathname = pathname walletConf.OnlyCreateKey = true @@ -36,13 +37,8 @@ func createWallet(t *testing.T, pathname string) { return "", nil } - _, _, err := openTestKeystore("test", &walletConf, testPass) - if err == nil { - t.Fatalf("should have failed") - } - keyCreatedError := "wallet key created" - if !strings.Contains(err.Error(), keyCreatedError) { - t.Fatalf("incorrect failure: %v, should have been %s", err, keyCreatedError) + if _, _, err := openTestKeystore("test", &walletConf, testPass); err != nil { + t.Fatalf("openTestKeystore() unexpected error: %v", err) } if testPassCalled { t.Error("password prompted for when it should not have been") @@ -110,13 +106,8 @@ func TestNewKeystorePromptPasswordTerminal(t *testing.T) { return password, nil } - _, _, err := openTestKeystore("test", &walletConf, getPass) - if err == nil { - t.Fatalf("should have failed") - } - keyCreatedError := "wallet key created" - if !strings.Contains(err.Error(), keyCreatedError) { - t.Fatalf("incorrect failure: %v, should have been %s", err, keyCreatedError) + if _, _, err := openTestKeystore("test", &walletConf, getPass); err != nil { + t.Fatalf("openTestKeystore() unexpected error: %v", err) } if !testPassCalled { t.Error("password not prompted for") @@ -167,13 +158,8 @@ func TestExistingKeystoreAccountName(t *testing.T) { return password, nil } - _, _, err := openTestKeystore("test", &walletConf, testPass) - if err == nil { - t.Fatalf("should have failed") - } - keyCreatedError := "wallet key created" - if !strings.Contains(err.Error(), keyCreatedError) { - t.Fatalf("incorrect failure: %v, should have been %s", err, keyCreatedError) + if _, _, err := openTestKeystore("test", &walletConf, testPass); err != nil { + t.Fatalf("openTestKeystore() unexpected error: %v", err) } if !testPassCalled { t.Error("password not prompted for") @@ -206,6 +192,7 @@ func TestExistingKeystoreAccountName(t *testing.T) { t.Fatal("should have failed") } invalidAddressError := "address is invalid" + keyCreatedError := "wallet key created" if !strings.Contains(err.Error(), invalidAddressError) { t.Fatalf("incorrect failure: %v, should have been %s", err, keyCreatedError) } diff --git a/das/syncing_fallback_storage.go b/das/syncing_fallback_storage.go index 8af39d7d3a..7c67dbec68 100644 --- a/das/syncing_fallback_storage.go +++ b/das/syncing_fallback_storage.go @@ -370,20 +370,25 @@ func (s *l1SyncService) readMore(ctx context.Context) error { func (s *l1SyncService) mainThread(ctx context.Context) { headerChan, unsubscribe := s.l1Reader.Subscribe(false) defer unsubscribe() + errCount := 0 for { err 
:= s.readMore(ctx) if err != nil { if ctx.Err() != nil { return } - log.Error("error trying to sync from L1", "err", err) + errCount++ + if errCount > 5 { + log.Error("error trying to sync from L1", "err", err) + } select { case <-ctx.Done(): return - case <-time.After(s.config.DelayOnError): + case <-time.After(s.config.DelayOnError * time.Duration(errCount)): } continue } + errCount = 0 if s.catchingUp { // we're behind. Don't wait. continue diff --git a/go-ethereum b/go-ethereum index 28127f5941..3725b60e04 160000 --- a/go-ethereum +++ b/go-ethereum @@ -1 +1 @@ -Subproject commit 28127f5941faec6fe5227c29443d2074639495d0 +Subproject commit 3725b60e0494df145672ab67dd3ec18a85a2b5d1 diff --git a/go.mod b/go.mod index fc52f1f763..37ab04ff30 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/codeclysm/extract/v3 v3.0.2 github.com/dgraph-io/badger/v3 v3.2103.2 github.com/ethereum/go-ethereum v1.10.26 + github.com/google/go-cmp v0.5.9 github.com/hashicorp/golang-lru/v2 v2.0.1 github.com/ipfs/go-cid v0.3.2 github.com/ipfs/go-libipfs v0.6.2 diff --git a/go.sum b/go.sum index f351bb9545..ca552ef60a 100644 --- a/go.sum +++ b/go.sum @@ -493,6 +493,7 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/relay/relay.go b/relay/relay.go index b9d70c513b..f4fc33d9e3 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -146,6 +146,8 @@ type Config struct { LogType string `koanf:"log-type"` Metrics bool `koanf:"metrics"` MetricsServer genericconf.MetricsServerConfig `koanf:"metrics-server"` + PProf bool `koanf:"pprof"` + PprofCfg genericconf.PProf `koanf:"pprof-cfg"` Node NodeConfig `koanf:"node"` Queue int `koanf:"queue"` } @@ -157,6 +159,8 @@ var ConfigDefault = Config{ LogType: "plaintext", Metrics: false, MetricsServer: genericconf.MetricsServerConfigDefault, + PProf: false, + PprofCfg: genericconf.PProfDefault, Node: NodeConfigDefault, Queue: 1024, } @@ -168,6 +172,8 @@ func ConfigAddOptions(f *flag.FlagSet) { f.String("log-type", ConfigDefault.LogType, "log type") f.Bool("metrics", ConfigDefault.Metrics, "enable metrics") genericconf.MetricsServerAddOptions("metrics-server", f) + f.Bool("pprof", ConfigDefault.PProf, "enable pprof") + genericconf.PProfAddOptions("pprof-cfg", f) NodeConfigAddOptions("node", f) f.Int("queue", ConfigDefault.Queue, "size of relay queue") } diff --git a/staker/block_challenge_backend.go b/staker/block_challenge_backend.go index b0c3bb8655..42351789ba 100644 --- a/staker/block_challenge_backend.go +++ b/staker/block_challenge_backend.go @@ -11,7 +11,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/offchainlabs/nitro/arbutil" @@ -20,14 +19,13 @@ import ( ) type BlockChallengeBackend struct { - bc 
*core.BlockChain - startBlock int64 + streamer TransactionStreamerInterface + startMsgCount arbutil.MessageIndex startPosition uint64 endPosition uint64 startGs validator.GoGlobalState endGs validator.GoGlobalState inboxTracker InboxTrackerInterface - genesisBlockNumber uint64 tooFarStartsAtPosition uint64 } @@ -37,19 +35,10 @@ var _ ChallengeBackend = (*BlockChallengeBackend)(nil) func NewBlockChallengeBackend( initialState *challengegen.ChallengeManagerInitiatedChallenge, maxBatchesRead uint64, - bc *core.BlockChain, + streamer TransactionStreamerInterface, inboxTracker InboxTrackerInterface, - genesisBlockNumber uint64, ) (*BlockChallengeBackend, error) { startGs := validator.GoGlobalStateFromSolidity(initialState.StartState) - startBlockNum := arbutil.MessageCountToBlockNumber(0, genesisBlockNumber) - if startGs.BlockHash != (common.Hash{}) { - startBlock := bc.GetBlockByHash(startGs.BlockHash) - if startBlock == nil { - return nil, fmt.Errorf("failed to find start block %v", startGs.BlockHash) - } - startBlockNum = int64(startBlock.NumberU64()) - } var startMsgCount arbutil.MessageIndex if startGs.Batch > 0 { @@ -60,10 +49,6 @@ func NewBlockChallengeBackend( } } startMsgCount += arbutil.MessageIndex(startGs.PosInBatch) - expectedMsgCount := arbutil.SignedBlockNumberToMessageCount(startBlockNum, genesisBlockNumber) - if startMsgCount != expectedMsgCount { - return nil, fmt.Errorf("start block %v and start message count %v don't correspond", startBlockNum, startMsgCount) - } var endMsgCount arbutil.MessageIndex if maxBatchesRead > 0 { @@ -75,19 +60,18 @@ func NewBlockChallengeBackend( } return &BlockChallengeBackend{ - bc: bc, - startBlock: startBlockNum, + streamer: streamer, + startMsgCount: startMsgCount, startGs: startGs, startPosition: 0, endPosition: math.MaxUint64, endGs: validator.GoGlobalStateFromSolidity(initialState.EndState), inboxTracker: inboxTracker, - genesisBlockNumber: genesisBlockNumber, tooFarStartsAtPosition: uint64(endMsgCount - startMsgCount + 1), }, nil } -func (b *BlockChallengeBackend) findBatchFromMessageIndex(msgCount arbutil.MessageIndex) (uint64, error) { +func (b *BlockChallengeBackend) findBatchAfterMessageCount(msgCount arbutil.MessageIndex) (uint64, error) { if msgCount == 0 { return 0, nil } @@ -118,54 +102,46 @@ func (b *BlockChallengeBackend) findBatchFromMessageIndex(msgCount arbutil.Messa } } -func (b *BlockChallengeBackend) FindGlobalStateFromHeader(header *types.Header) (validator.GoGlobalState, error) { - if header == nil { - return validator.GoGlobalState{}, nil - } - msgCount := arbutil.BlockNumberToMessageCount(header.Number.Uint64(), b.genesisBlockNumber) - batch, err := b.findBatchFromMessageIndex(msgCount) +func (b *BlockChallengeBackend) FindGlobalStateFromMessageCount(count arbutil.MessageIndex) (validator.GoGlobalState, error) { + batch, err := b.findBatchAfterMessageCount(count) if err != nil { return validator.GoGlobalState{}, err } - var batchMsgCount arbutil.MessageIndex + var prevBatchMsgCount arbutil.MessageIndex if batch > 0 { - batchMsgCount, err = b.inboxTracker.GetBatchMessageCount(batch - 1) + prevBatchMsgCount, err = b.inboxTracker.GetBatchMessageCount(batch - 1) if err != nil { return validator.GoGlobalState{}, err } - if batchMsgCount > msgCount { + if prevBatchMsgCount > count { return validator.GoGlobalState{}, errors.New("findBatchFromMessageCount returned bad batch") } } - extraInfo := types.DeserializeHeaderExtraInformation(header) + res, err := b.streamer.ResultAtCount(count) + if err != nil { + return 
validator.GoGlobalState{}, err + } return validator.GoGlobalState{ - BlockHash: header.Hash(), - SendRoot: extraInfo.SendRoot, + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, Batch: batch, - PosInBatch: uint64(msgCount - batchMsgCount), + PosInBatch: uint64(count - prevBatchMsgCount), }, nil } const StatusFinished uint8 = 1 const StatusTooFar uint8 = 3 -func (b *BlockChallengeBackend) GetBlockNrAtStep(step uint64) int64 { - return b.startBlock + int64(step) +func (b *BlockChallengeBackend) GetMessageCountAtStep(step uint64) arbutil.MessageIndex { + return b.startMsgCount + arbutil.MessageIndex(step) } func (b *BlockChallengeBackend) GetInfoAtStep(step uint64) (validator.GoGlobalState, uint8, error) { - blockNum := b.GetBlockNrAtStep(step) + msgNum := b.GetMessageCountAtStep(step) if step >= b.tooFarStartsAtPosition { return validator.GoGlobalState{}, StatusTooFar, nil } - var header *types.Header - if blockNum != -1 { - header = b.bc.GetHeaderByNumber(uint64(blockNum)) - if header == nil { - return validator.GoGlobalState{}, 0, fmt.Errorf("failed to get block %v in block challenge", blockNum) - } - } - globalState, err := b.FindGlobalStateFromHeader(header) + globalState, err := b.FindGlobalStateFromMessageCount(msgNum) if err != nil { return validator.GoGlobalState{}, 0, err } diff --git a/staker/block_validator.go b/staker/block_validator.go index 56bb2729c7..0ff74a8014 100644 --- a/staker/block_validator.go +++ b/staker/block_validator.go @@ -9,72 +9,69 @@ import ( "fmt" "sync" "sync/atomic" + "testing" "time" flag "github.com/spf13/pflag" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" - "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/util/stopwaiter" "github.com/offchainlabs/nitro/validator" ) var ( - validatorPendingValidationsGauge = metrics.NewRegisteredGauge("arb/validator/validations/pending", nil) - validatorValidValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/valid", nil) - validatorFailedValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/failed", nil) - validatorLastBlockInLastBatchGauge = metrics.NewRegisteredGauge("arb/validator/last_block_in_last_batch", nil) - validatorLastBlockValidatedGauge = metrics.NewRegisteredGauge("arb/validator/last_block_validated", nil) + validatorPendingValidationsGauge = metrics.NewRegisteredGauge("arb/validator/validations/pending", nil) + validatorValidValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/valid", nil) + validatorFailedValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/failed", nil) + validatorMsgCountCurrentBatch = metrics.NewRegisteredGauge("arb/validator/msg_count_current_batch", nil) + validatorMsgCountValidatedGauge = metrics.NewRegisteredGauge("arb/validator/msg_count_validated", nil) ) type BlockValidator struct { stopwaiter.StopWaiter *StatelessBlockValidator - validations sync.Map - sequencerBatches sync.Map + reorgMutex sync.RWMutex - // acquiring multiple Mutexes must be done in order: - reorgMutex sync.Mutex - batchMutex sync.Mutex - blockMutex sync.Mutex - lastBlockValidatedMutex sync.Mutex + chainCaughtUp bool - reorgsPending int32 // atomic + 
// can only be accessed from creation thread or if holding reorg-write + nextCreateBatch []byte + nextCreateBatchMsgCount arbutil.MessageIndex + nextCreateBatchReread bool + nextCreateStartGS validator.GoGlobalState + nextCreatePrevDelayed uint64 - earliestBatchKept uint64 // atomic - nextBatchKept uint64 // behind batchMutex, 1 + the last batch number kept + // can only be accessed from from validation thread or if holding reorg-write + lastValidGS validator.GoGlobalState + valLoopPos arbutil.MessageIndex + legacyValidInfo *legacyLastBlockValidatedDbInfo - // protected by reorgMutex - globalPosNextSend GlobalStatePosition + // only from logger thread + lastValidInfoPrinted *GlobalStateValidatedInfo - // protected by BlockMutex: - nextBlockToValidate uint64 - lastValidationEntryBlock uint64 - - // behind lastBlockValidatedMutex - lastBlockValidatedUnknown bool - lastBlockValidated uint64 // also atomic - lastBlockValidatedHash common.Hash + // can be read (atomic.Load) by anyone holding reorg-read + // written (atomic.Set) by appropriate thread or (any way) holding reorg-write + createdA uint64 + recordSentA uint64 + validatedA uint64 + validations containers.SyncMap[arbutil.MessageIndex, *validationStatus] config BlockValidatorConfigFetcher - sendValidationsChan chan struct{} - progressChan chan uint64 - - lastHeaderForPrepareState *types.Header + createNodesChan chan struct{} + sendRecordChan chan struct{} + progressValidationsChan chan struct{} - // recentValid holds one recently valid header, to commit it to DB on shutdown - recentValidMutex sync.Mutex - awaitingValidation *types.Header - validHeader *types.Header + // for testing only + testingProgressMadeChan chan struct{} fatalErr chan<- error } @@ -148,13 +145,12 @@ var DefaultBlockValidatorDangerousConfig = BlockValidatorDangerousConfig{ type valStatusField uint32 const ( - Unprepared valStatusField = iota + Created valStatusField = iota RecordSent RecordFailed Prepared + SendingValidation ValidationSent - Failed - Valid ) type validationStatus struct { @@ -164,10 +160,6 @@ type validationStatus struct { Runs []validator.ValidationRun // if status >= ValidationSent } -func (s *validationStatus) setStatus(val valStatusField) { - atomic.StoreUint32(&s.Status, uint32(val)) -} - func (s *validationStatus) getStatus() valStatusField { uintStat := atomic.LoadUint32(&s.Status) return valStatusField(uintStat) @@ -181,24 +173,72 @@ func NewBlockValidator( statelessBlockValidator *StatelessBlockValidator, inbox InboxTrackerInterface, streamer TransactionStreamerInterface, - reorgingToBlock *types.Block, config BlockValidatorConfigFetcher, fatalErr chan<- error, ) (*BlockValidator, error) { - validator := &BlockValidator{ + ret := &BlockValidator{ StatelessBlockValidator: statelessBlockValidator, - sendValidationsChan: make(chan struct{}, 1), - progressChan: make(chan uint64, 1), + createNodesChan: make(chan struct{}, 1), + sendRecordChan: make(chan struct{}, 1), + progressValidationsChan: make(chan struct{}, 1), config: config, fatalErr: fatalErr, } - err := validator.readLastBlockValidatedDbInfo(reorgingToBlock) - if err != nil { - return nil, err + if !config().Dangerous.ResetBlockValidation { + validated, err := ret.ReadLastValidatedInfo() + if err != nil { + return nil, err + } + if validated != nil { + ret.lastValidGS = validated.GlobalState + } else { + legacyInfo, err := ret.legacyReadLastValidatedInfo() + if err != nil { + return nil, err + } + ret.legacyValidInfo = legacyInfo + } + } + // genesis block is impossible to validate 
unless genesis state is empty + if ret.lastValidGS.Batch == 0 && ret.legacyValidInfo == nil { + genesis, err := streamer.ResultAtCount(1) + if err != nil { + return nil, err + } + ret.lastValidGS = validator.GoGlobalState{ + BlockHash: genesis.BlockHash, + SendRoot: genesis.SendRoot, + Batch: 1, + PosInBatch: 0, + } } - streamer.SetBlockValidator(validator) - inbox.SetBlockValidator(validator) - return validator, nil + streamer.SetBlockValidator(ret) + inbox.SetBlockValidator(ret) + return ret, nil +} + +func atomicStorePos(addr *uint64, val arbutil.MessageIndex) { + atomic.StoreUint64(addr, uint64(val)) +} + +func atomicLoadPos(addr *uint64) arbutil.MessageIndex { + return arbutil.MessageIndex(atomic.LoadUint64(addr)) +} + +func (v *BlockValidator) created() arbutil.MessageIndex { + return atomicLoadPos(&v.createdA) +} + +func (v *BlockValidator) recordSent() arbutil.MessageIndex { + return atomicLoadPos(&v.recordSentA) +} + +func (v *BlockValidator) validated() arbutil.MessageIndex { + return atomicLoadPos(&v.validatedA) +} + +func (v *BlockValidator) Validated(t *testing.T) arbutil.MessageIndex { + return v.validated() } func (v *BlockValidator) possiblyFatal(err error) { @@ -217,161 +257,118 @@ func (v *BlockValidator) possiblyFatal(err error) { } } -func (v *BlockValidator) triggerSendValidations() { +func nonBlockingTrigger(channel chan struct{}) { select { - case v.sendValidationsChan <- struct{}{}: + case channel <- struct{}{}: default: } } -func (v *BlockValidator) recentlyValid(header *types.Header) { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.awaitingValidation == nil { - return - } - if v.awaitingValidation.Number.Cmp(header.Number) > 0 { - return +// called from NewBlockValidator, doesn't need to catch locks +func ReadLastValidatedInfo(db ethdb.Database) (*GlobalStateValidatedInfo, error) { + exists, err := db.Has(lastGlobalStateValidatedInfoKey) + if err != nil { + return nil, err } - if v.validHeader != nil { - v.recordingDatabase.Dereference(v.validHeader) + var validated GlobalStateValidatedInfo + if !exists { + return nil, nil } - v.validHeader = v.awaitingValidation - v.awaitingValidation = nil -} - -func (v *BlockValidator) recentStateComputed(header *types.Header) { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.awaitingValidation != nil { - return + gsBytes, err := db.Get(lastGlobalStateValidatedInfoKey) + if err != nil { + return nil, err } - _, err := v.recordingDatabase.StateFor(header) + err = rlp.DecodeBytes(gsBytes, &validated) if err != nil { - log.Error("failed to get state for block while validating", "err", err, "blockNum", header.Number, "hash", header.Hash()) - return + return nil, err } - v.awaitingValidation = header + return &validated, nil } -func (v *BlockValidator) recentShutdown() error { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.validHeader == nil { - return nil - } - err := v.recordingDatabase.WriteStateToDatabase(v.validHeader) - v.recordingDatabase.Dereference(v.validHeader) - return err +func (v *BlockValidator) ReadLastValidatedInfo() (*GlobalStateValidatedInfo, error) { + return ReadLastValidatedInfo(v.db) } -func ReadLastValidatedFromDb(db ethdb.Database) (*LastBlockValidatedDbInfo, error) { - exists, err := db.Has(lastBlockValidatedInfoKey) +func (v *BlockValidator) legacyReadLastValidatedInfo() (*legacyLastBlockValidatedDbInfo, error) { + exists, err := v.db.Has(legacyLastBlockValidatedInfoKey) if err != nil { return nil, err } + var validated 
legacyLastBlockValidatedDbInfo if !exists { return nil, nil } - - infoBytes, err := db.Get(lastBlockValidatedInfoKey) + gsBytes, err := v.db.Get(legacyLastBlockValidatedInfoKey) if err != nil { return nil, err } - - var info LastBlockValidatedDbInfo - err = rlp.DecodeBytes(infoBytes, &info) + err = rlp.DecodeBytes(gsBytes, &validated) if err != nil { return nil, err } - return &info, nil + return &validated, nil } -// only called by NewBlockValidator -func (v *BlockValidator) readLastBlockValidatedDbInfo(reorgingToBlock *types.Block) error { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() +var ErrGlobalStateNotInChain = errors.New("globalstate not in chain") - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() - - exists, err := v.db.Has(lastBlockValidatedInfoKey) +// false if chain not caught up to globalstate +// error is ErrGlobalStateNotInChain if globalstate not in chain (and chain caught up) +func GlobalStateToMsgCount(tracker InboxTrackerInterface, streamer TransactionStreamerInterface, gs validator.GoGlobalState) (bool, arbutil.MessageIndex, error) { + batchCount, err := tracker.GetBatchCount() if err != nil { - return err + return false, 0, err } - - if !exists || v.config().Dangerous.ResetBlockValidation { - // The db contains no validation info; start from the beginning. - // TODO: this skips validating the genesis block. - atomic.StoreUint64(&v.lastBlockValidated, v.genesisBlockNum) - validatorLastBlockValidatedGauge.Update(int64(v.genesisBlockNum)) - genesisBlock := v.blockchain.GetBlockByNumber(v.genesisBlockNum) - if genesisBlock == nil { - return fmt.Errorf("blockchain missing genesis block number %v", v.genesisBlockNum) + requiredBatchCount := gs.Batch + 1 + if gs.PosInBatch == 0 { + requiredBatchCount -= 1 + } + if batchCount < requiredBatchCount { + return false, 0, nil + } + var prevBatchMsgCount arbutil.MessageIndex + if gs.Batch > 0 { + prevBatchMsgCount, err = tracker.GetBatchMessageCount(gs.Batch - 1) + if err != nil { + return false, 0, err } - - v.lastBlockValidatedHash = genesisBlock.Hash() - v.nextBlockToValidate = v.genesisBlockNum + 1 - v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: 1, - PosInBatch: 0, + } + count := prevBatchMsgCount + if gs.PosInBatch > 0 { + curBatchMsgCount, err := tracker.GetBatchMessageCount(gs.Batch) + if err != nil { + return false, 0, fmt.Errorf("%w: getBatchMsgCount %d batchCount %d", err, gs.Batch, batchCount) + } + count += arbutil.MessageIndex(gs.PosInBatch) + if curBatchMsgCount < count { + return false, 0, fmt.Errorf("%w: batch %d posInBatch %d, maxPosInBatch %d", ErrGlobalStateNotInChain, gs.Batch, gs.PosInBatch, curBatchMsgCount-prevBatchMsgCount) } - return nil } - - info, err := ReadLastValidatedFromDb(v.db) + processed, err := streamer.GetProcessedMessageCount() if err != nil { - return err + return false, 0, err } - - if reorgingToBlock != nil && reorgingToBlock.NumberU64() >= info.BlockNumber { - // Disregard this reorg as it doesn't affect the last validated block - reorgingToBlock = nil + if processed < count { + return false, 0, nil } - - if reorgingToBlock == nil { - expectedHash := v.blockchain.GetCanonicalHash(info.BlockNumber) - if expectedHash != info.BlockHash && (expectedHash != common.Hash{}) { - return fmt.Errorf("last validated block %v stored with hash %v, but blockchain has hash %v", info.BlockNumber, info.BlockHash, expectedHash) - } + res, err := streamer.ResultAtCount(count) + if err != nil { + return false, 0, 
err } - - atomic.StoreUint64(&v.lastBlockValidated, info.BlockNumber) - validatorLastBlockValidatedGauge.Update(int64(info.BlockNumber)) - v.lastBlockValidatedHash = info.BlockHash - v.nextBlockToValidate = v.lastBlockValidated + 1 - v.globalPosNextSend = info.AfterPosition - - if reorgingToBlock != nil { - err = v.reorgToBlockImpl(reorgingToBlock.NumberU64(), reorgingToBlock.Hash()) - if err != nil { - return err - } + if res.BlockHash != gs.BlockHash || res.SendRoot != gs.SendRoot { + return false, 0, fmt.Errorf("%w: count %d hash %v expected %v, sendroot %v expected %v", ErrGlobalStateNotInChain, count, gs.BlockHash, res.BlockHash, gs.SendRoot, res.SendRoot) } - - return nil + return true, count, nil } -func (v *BlockValidator) sendRecord(s *validationStatus, mustDeref bool) error { +func (v *BlockValidator) sendRecord(s *validationStatus) error { if !v.Started() { - // this could only be sent by NewBlock, so mustDeref is not sent return nil } - prevHeader := s.Entry.PrevBlockHeader - if !s.replaceStatus(Unprepared, RecordSent) { - if mustDeref { - v.recordingDatabase.Dereference(prevHeader) - } + if !s.replaceStatus(Created, RecordSent) { return fmt.Errorf("failed status check for send record. Status: %v", s.getStatus()) } v.LaunchThread(func(ctx context.Context) { - if mustDeref { - defer v.recordingDatabase.Dereference(prevHeader) - } - err := v.ValidationEntryRecord(ctx, s.Entry, true) + err := v.ValidationEntryRecord(ctx, s.Entry) if ctx.Err() != nil { return } @@ -380,94 +377,31 @@ func (v *BlockValidator) sendRecord(s *validationStatus, mustDeref bool) error { log.Error("Error while recording", "err", err, "status", s.getStatus()) return } - v.recentStateComputed(prevHeader) - v.recordingDatabase.Dereference(prevHeader) // removes the reference added by ValidationEntryRecord if !s.replaceStatus(RecordSent, Prepared) { log.Error("Fault trying to update validation with recording", "entry", s.Entry, "status", s.getStatus()) return } - v.triggerSendValidations() + nonBlockingTrigger(v.progressValidationsChan) }) return nil } -func (v *BlockValidator) newValidationStatus(prevHeader, header *types.Header, msg *arbostypes.MessageWithMetadata) (*validationStatus, error) { - entry, err := newValidationEntry(prevHeader, header, msg) - if err != nil { - return nil, err - } - status := &validationStatus{ - Status: uint32(Unprepared), - Entry: entry, - } - return status, nil -} - -func (v *BlockValidator) NewBlock(block *types.Block, prevHeader *types.Header, msg arbostypes.MessageWithMetadata) { - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - blockNum := block.NumberU64() - v.lastBlockValidatedMutex.Lock() - if blockNum < v.lastBlockValidated { - v.lastBlockValidatedMutex.Unlock() - return - } - if v.lastBlockValidatedUnknown { - if block.Hash() == v.lastBlockValidatedHash { - v.lastBlockValidated = blockNum - validatorLastBlockValidatedGauge.Update(int64(blockNum)) - v.nextBlockToValidate = blockNum + 1 - v.lastBlockValidatedUnknown = false - log.Info("Block building caught up to staker", "blockNr", v.lastBlockValidated, "blockHash", v.lastBlockValidatedHash) - // note: this block is already valid - } - v.lastBlockValidatedMutex.Unlock() - return - } - if v.nextBlockToValidate+v.config().ForwardBlocks <= blockNum { - v.lastBlockValidatedMutex.Unlock() - return - } - v.lastBlockValidatedMutex.Unlock() - status, err := v.newValidationStatus(prevHeader, block.Header(), &msg) - if err != nil { - log.Error("failed creating validation status", "err", err) - return - } - // It's fine to 
separately load and then store as we have the blockMutex acquired - _, present := v.validations.Load(blockNum) - if present { - return - } - v.validations.Store(blockNum, status) - if v.lastValidationEntryBlock < blockNum { - v.lastValidationEntryBlock = blockNum - } - v.triggerSendValidations() -} - //nolint:gosec func (v *BlockValidator) writeToFile(validationEntry *validationEntry, moduleRoot common.Hash) error { input, err := validationEntry.ToInput() if err != nil { return err } - expOut, err := validationEntry.expectedEnd() - if err != nil { - return err - } - _, err = v.execSpawner.WriteToFile(input, expOut, moduleRoot).Await(v.GetContext()) + _, err = v.execSpawner.WriteToFile(input, validationEntry.End, moduleRoot).Await(v.GetContext()) return err } func (v *BlockValidator) SetCurrentWasmModuleRoot(hash common.Hash) error { - v.blockMutex.Lock() v.moduleMutex.Lock() - defer v.blockMutex.Unlock() defer v.moduleMutex.Unlock() if (hash == common.Hash{}) { - return errors.New("trying to set zero as wsmModuleRoot") + return errors.New("trying to set zero as wasmModuleRoot") } if hash == v.currentWasmModuleRoot { return nil @@ -490,527 +424,520 @@ func (v *BlockValidator) SetCurrentWasmModuleRoot(hash common.Hash) error { ) } -var ErrValidationCanceled = errors.New("validation of block cancelled") +func (v *BlockValidator) readBatch(ctx context.Context, batchNum uint64) (bool, []byte, arbutil.MessageIndex, error) { + batchCount, err := v.inboxTracker.GetBatchCount() + if err != nil { + return false, nil, 0, err + } + if batchCount <= batchNum { + return false, nil, 0, nil + } + batchMsgCount, err := v.inboxTracker.GetBatchMessageCount(batchNum) + if err != nil { + return false, nil, 0, err + } + batch, err := v.inboxReader.GetSequencerMessageBytes(ctx, batchNum) + if err != nil { + return false, nil, 0, err + } + return true, batch, batchMsgCount, nil +} + +func (v *BlockValidator) createNextValidationEntry(ctx context.Context) (bool, error) { + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() + pos := v.created() + if pos > v.validated()+arbutil.MessageIndex(v.config().ForwardBlocks) { + log.Trace("create validation entry: nothing to do", "pos", pos, "validated", v.validated()) + return false, nil + } + streamerMsgCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + return false, err + } + if pos >= streamerMsgCount { + log.Trace("create validation entry: nothing to do", "pos", pos, "streamerMsgCount", streamerMsgCount) + return false, nil + } + msg, err := v.streamer.GetMessage(pos) + if err != nil { + return false, err + } + endRes, err := v.streamer.ResultAtCount(pos + 1) + if err != nil { + return false, err + } + if v.nextCreateStartGS.PosInBatch == 0 || v.nextCreateBatchReread { + // new batch + found, batch, count, err := v.readBatch(ctx, v.nextCreateStartGS.Batch) + if !found { + return false, err + } + v.nextCreateBatch = batch + v.nextCreateBatchMsgCount = count + validatorMsgCountCurrentBatch.Update(int64(count)) + v.nextCreateBatchReread = false + } + endGS := validator.GoGlobalState{ + BlockHash: endRes.BlockHash, + SendRoot: endRes.SendRoot, + } + if pos+1 < v.nextCreateBatchMsgCount { + endGS.Batch = v.nextCreateStartGS.Batch + endGS.PosInBatch = v.nextCreateStartGS.PosInBatch + 1 + } else if pos+1 == v.nextCreateBatchMsgCount { + endGS.Batch = v.nextCreateStartGS.Batch + 1 + endGS.PosInBatch = 0 + } else { + return false, fmt.Errorf("illegal batch msg count %d pos %d batch %d", v.nextCreateBatchMsgCount, pos, endGS.Batch) + } + entry, err := 
newValidationEntry(pos, v.nextCreateStartGS, endGS, msg, v.nextCreateBatch, v.nextCreatePrevDelayed) + if err != nil { + return false, err + } + status := &validationStatus{ + Status: uint32(Created), + Entry: entry, + } + v.validations.Store(pos, status) + v.nextCreateStartGS = endGS + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + atomicStorePos(&v.createdA, pos+1) + log.Trace("create validation entry: created", "pos", pos) + return true, nil +} + +func (v *BlockValidator) iterativeValidationEntryCreator(ctx context.Context, ignored struct{}) time.Duration { + moreWork, err := v.createNextValidationEntry(ctx) + if err != nil { + processed, processedErr := v.streamer.GetProcessedMessageCount() + log.Error("error trying to create validation node", "err", err, "created", v.created()+1, "processed", processed, "processedErr", processedErr) + } + if moreWork { + return 0 + } + return v.config().ValidationPoll +} + +func (v *BlockValidator) sendNextRecordRequests(ctx context.Context) (bool, error) { + v.reorgMutex.RLock() + pos := v.recordSent() + created := v.created() + validated := v.validated() + v.reorgMutex.RUnlock() + + recordUntil := validated + arbutil.MessageIndex(v.config().PrerecordedBlocks) - 1 + if recordUntil > created-1 { + recordUntil = created - 1 + } + if recordUntil < pos { + return false, nil + } + log.Trace("preparing to record", "pos", pos, "until", recordUntil) + // prepare could take a long time so we do it without a lock + err := v.recorder.PrepareForRecord(ctx, pos, recordUntil) + if err != nil { + return false, err + } + + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() + createdNew := v.created() + recordSentNew := v.recordSent() + if createdNew < created || recordSentNew < pos { + // there was a relevant reorg - quit and restart + return true, nil + } + for pos <= recordUntil { + validationStatus, found := v.validations.Load(pos) + if !found { + return false, fmt.Errorf("not found entry for pos %d", pos) + } + currentStatus := validationStatus.getStatus() + if currentStatus != Created { + return false, fmt.Errorf("bad status trying to send recordings for pos %d status: %v", pos, currentStatus) + } + err := v.sendRecord(validationStatus) + if err != nil { + return false, err + } + pos += 1 + atomicStorePos(&v.recordSentA, pos) + log.Trace("next record request: sent", "pos", pos) + } + + return true, nil +} + +func (v *BlockValidator) iterativeValidationEntryRecorder(ctx context.Context, ignored struct{}) time.Duration { + moreWork, err := v.sendNextRecordRequests(ctx) + if err != nil { + log.Error("error trying to record for validation node", "err", err) + } + if moreWork { + return 0 + } + return v.config().ValidationPoll +} + +func (v *BlockValidator) iterativeValidationPrint(ctx context.Context) time.Duration { + validated, err := v.ReadLastValidatedInfo() + if err != nil { + log.Error("cannot read last validated data from database", "err", err) + return time.Second * 30 + } + if validated == nil { + return time.Second + } + if v.lastValidInfoPrinted != nil { + if v.lastValidInfoPrinted.GlobalState.BlockHash == validated.GlobalState.BlockHash { + return time.Second + } + } + var batchMsgs arbutil.MessageIndex + var printedCount int64 + if validated.GlobalState.Batch > 0 { + batchMsgs, err = v.inboxTracker.GetBatchMessageCount(validated.GlobalState.Batch) + } + if err != nil { + printedCount = -1 + } else { + printedCount = int64(batchMsgs) + int64(validated.GlobalState.PosInBatch) + } + log.Info("validated execution", "messageCount", printedCount, 
"globalstate", validated.GlobalState, "WasmRoots", validated.WasmRoots) + v.lastValidInfoPrinted = validated + return time.Second +} + +// return val: +// *MessageIndex - pointer to bad entry if there is one (requires reorg) +func (v *BlockValidator) advanceValidations(ctx context.Context) (*arbutil.MessageIndex, error) { + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() -func (v *BlockValidator) sendValidations(ctx context.Context) { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - var batchCount uint64 wasmRoots := v.GetModuleRootsToValidate() room := 100 // even if there is more room then that it's fine for _, spawner := range v.validationSpawners { here := spawner.Room() / len(wasmRoots) if here <= 0 { - return + room = 0 } if here < room { room = here } } - for atomic.LoadInt32(&v.reorgsPending) == 0 { - if room <= 0 { - return + pos := v.validated() - 1 // to reverse the first +1 in the loop +validationsLoop: + for { + if ctx.Err() != nil { + return nil, ctx.Err() } - if batchCount <= v.globalPosNextSend.BatchNumber { - var err error - batchCount, err = v.inboxTracker.GetBatchCount() - if err != nil { - log.Error("validator failed to get message count", "err", err) - return - } - if batchCount <= v.globalPosNextSend.BatchNumber { - return - } + v.valLoopPos = pos + 1 + v.reorgMutex.RUnlock() + v.reorgMutex.RLock() + pos = v.valLoopPos + if pos >= v.recordSent() { + log.Trace("advanceValidations: nothing to validate", "pos", pos) + return nil, nil } - seqBatchEntry, haveBatch := v.sequencerBatches.Load(v.globalPosNextSend.BatchNumber) - if !haveBatch && batchCount == v.globalPosNextSend.BatchNumber+1 { - // This is the latest batch. - // Wait a bit to see if the inbox tracker populates this sequencer batch, - // but if it's still missing after this wait, we'll query it from the inbox reader. 
- time.Sleep(time.Second) - seqBatchEntry, haveBatch = v.sequencerBatches.Load(v.globalPosNextSend.BatchNumber) + validationStatus, found := v.validations.Load(pos) + if !found { + return nil, fmt.Errorf("not found entry for pos %d", pos) } - if !haveBatch { - seqMsg, err := v.inboxReader.GetSequencerMessageBytes(ctx, v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("validator failed to read sequencer message", "err", err) - return - } - v.ProcessBatches(v.globalPosNextSend.BatchNumber, [][]byte{seqMsg}) - seqBatchEntry = seqMsg + currentStatus := validationStatus.getStatus() + if currentStatus == RecordFailed { + // retry + log.Warn("Recording for validation failed, retrying..", "pos", pos) + return &pos, nil } - v.blockMutex.Lock() - v.lastBlockValidatedMutex.Lock() - if v.lastBlockValidatedUnknown { - firstMsgInBatch := arbutil.MessageIndex(0) - if v.globalPosNextSend.BatchNumber > 0 { - var err error - firstMsgInBatch, err = v.inboxTracker.GetBatchMessageCount(v.globalPosNextSend.BatchNumber - 1) + if currentStatus == ValidationSent && pos == v.validated() { + if validationStatus.Entry.Start != v.lastValidGS { + log.Warn("Validation entry has wrong start state", "pos", pos, "start", validationStatus.Entry.Start, "expected", v.lastValidGS) + validationStatus.Cancel() + return &pos, nil + } + var wasmRoots []common.Hash + for i, run := range validationStatus.Runs { + if !run.Ready() { + log.Trace("advanceValidations: validation not ready", "pos", pos, "run", i) + continue validationsLoop + } + wasmRoots = append(wasmRoots, run.WasmModuleRoot()) + runEnd, err := run.Current() + if err == nil && runEnd != validationStatus.Entry.End { + err = fmt.Errorf("validation failed: expected %v got %v", validationStatus.Entry.End, runEnd) + writeErr := v.writeToFile(validationStatus.Entry, run.WasmModuleRoot()) + if writeErr != nil { + log.Warn("failed to write debug results file", "err", writeErr) + } + } if err != nil { - v.lastBlockValidatedMutex.Unlock() - v.blockMutex.Unlock() - log.Error("validator couldnt read message count", "err", err) - return + validatorFailedValidationsCounter.Inc(1) + v.possiblyFatal(err) + return &pos, nil // if not fatal - retry } + validatorValidValidationsCounter.Inc(1) } - v.lastBlockValidated = uint64(arbutil.MessageCountToBlockNumber(firstMsgInBatch+arbutil.MessageIndex(v.globalPosNextSend.PosInBatch), v.genesisBlockNum)) - validatorLastBlockValidatedGauge.Update(int64(v.lastBlockValidated)) - v.nextBlockToValidate = v.lastBlockValidated + 1 - v.lastBlockValidatedUnknown = false - log.Info("Inbox caught up to staker", "blockNr", v.lastBlockValidated, "blockHash", v.lastBlockValidatedHash) - } - v.lastBlockValidatedMutex.Unlock() - nextBlockToValidate := v.nextBlockToValidate - v.blockMutex.Unlock() - nextMsg := arbutil.BlockNumberToMessageCount(nextBlockToValidate, v.genesisBlockNum) - 1 - // valdationEntries is By blockNumber - entry, found := v.validations.Load(nextBlockToValidate) - if !found { - return - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to validate batch") - return - } - if validationStatus.getStatus() < Prepared { - return - } - startPos, endPos, err := GlobalStatePositionsFor(v.inboxTracker, nextMsg, v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("failed calculating position for validation", "err", err, "msg", nextMsg, "batch", v.globalPosNextSend.BatchNumber) - return - } - if startPos != v.globalPosNextSend { - log.Error("inconsistent 
pos mapping", "msg", nextMsg, "expected", v.globalPosNextSend, "found", startPos) - return - } - seqMsg, ok := seqBatchEntry.([]byte) - if !ok { - batchNum := validationStatus.Entry.StartPosition.BatchNumber - log.Error("sequencer message bad format", "blockNr", nextBlockToValidate, "msgNum", batchNum) - return - } - msgCountInBatch, err := v.inboxTracker.GetBatchMessageCount(v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("failed to get batch message count", "err", err, "batch", v.globalPosNextSend.BatchNumber) - return - } - lastBlockInBatch := arbutil.MessageCountToBlockNumber(msgCountInBatch, v.genesisBlockNum) - validatorLastBlockInLastBatchGauge.Update(lastBlockInBatch) - v.LaunchThread(func(ctx context.Context) { - validationCtx, cancel := context.WithCancel(ctx) - defer cancel() - validationStatus.Cancel = cancel - err := v.ValidationEntryAddSeqMessage(ctx, validationStatus.Entry, startPos, endPos, seqMsg) + err := v.writeLastValidated(validationStatus.Entry.End, wasmRoots) if err != nil { - validationStatus.replaceStatus(Prepared, RecordFailed) - if validationCtx.Err() == nil { - log.Error("error preparing validation", "err", err) - } - return + log.Error("failed writing new validated to database", "pos", pos, "err", err) + } + go v.recorder.MarkValid(pos, v.lastValidGS.BlockHash) + atomicStorePos(&v.validatedA, pos+1) + v.validations.Delete(pos) + nonBlockingTrigger(v.createNodesChan) + nonBlockingTrigger(v.sendRecordChan) + validatorMsgCountValidatedGauge.Update(int64(pos + 1)) + if v.testingProgressMadeChan != nil { + nonBlockingTrigger(v.testingProgressMadeChan) } + log.Trace("result validated", "count", v.validated(), "blockHash", v.lastValidGS.BlockHash) + continue + } + if room == 0 { + log.Trace("advanceValidations: no more room", "pos", pos) + return nil, nil + } + if currentStatus == Prepared { input, err := validationStatus.Entry.ToInput() - if err != nil { - validationStatus.replaceStatus(Prepared, RecordFailed) - if validationCtx.Err() == nil { - log.Error("error preparing validation", "err", err) - } - return + if err != nil && ctx.Err() == nil { + v.possiblyFatal(fmt.Errorf("%w: error preparing validation", err)) + continue } + replaced := validationStatus.replaceStatus(Prepared, SendingValidation) + if !replaced { + v.possiblyFatal(errors.New("failed to set SendingValidation status")) + } + validatorPendingValidationsGauge.Inc(1) + defer validatorPendingValidationsGauge.Dec(1) + var runs []validator.ValidationRun for _, moduleRoot := range wasmRoots { - for _, spawner := range v.validationSpawners { + for i, spawner := range v.validationSpawners { run := spawner.Launch(input, moduleRoot) - validationStatus.Runs = append(validationStatus.Runs, run) + log.Trace("advanceValidations: launched", "pos", validationStatus.Entry.Pos, "moduleRoot", moduleRoot, "spawner", i) + runs = append(runs, run) } } - validatorPendingValidationsGauge.Inc(1) - replaced := validationStatus.replaceStatus(Prepared, ValidationSent) - if !replaced { - v.possiblyFatal(errors.New("failed to set status")) - } - }) - room-- - v.blockMutex.Lock() - v.nextBlockToValidate++ - v.blockMutex.Unlock() - v.globalPosNextSend = endPos + validationCtx, cancel := context.WithCancel(ctx) + validationStatus.Runs = runs + validationStatus.Cancel = cancel + v.LaunchUntrackedThread(func() { + defer cancel() + replaced = validationStatus.replaceStatus(SendingValidation, ValidationSent) + if !replaced { + v.possiblyFatal(errors.New("failed to set status to ValidationSent")) + } + + // 
validationStatus might be removed from under us + // trigger validation progress when done + for _, run := range runs { + _, err := run.Await(validationCtx) + if err != nil { + return + } + } + nonBlockingTrigger(v.progressValidationsChan) + }) + room-- + } } } -func (v *BlockValidator) sendRecords(ctx context.Context) { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - v.blockMutex.Lock() - nextRecord := v.nextBlockToValidate - v.blockMutex.Unlock() - for atomic.LoadInt32(&v.reorgsPending) == 0 { - if nextRecord >= v.nextBlockToValidate+v.config().PrerecordedBlocks { - return - } - entry, found := v.validations.Load(nextRecord) - if !found { - header := v.blockchain.GetHeaderByNumber(nextRecord) - if header == nil { - // This block hasn't been created yet. - return - } - prevHeader := v.blockchain.GetHeaderByHash(header.ParentHash) - if prevHeader == nil && header.ParentHash != (common.Hash{}) { - log.Warn("failed to get prevHeader in block validator", "num", nextRecord-1, "hash", header.ParentHash) - return - } - msgNum := arbutil.BlockNumberToMessageCount(nextRecord, v.genesisBlockNum) - 1 - msg, err := v.streamer.GetMessage(msgNum) - if err != nil { - log.Warn("failed to get message in block validator", "err", err) - return - } - status, err := v.newValidationStatus(prevHeader, header, msg) - if err != nil { - log.Warn("failed to create validation status", "err", err) - return - } - v.blockMutex.Lock() - entry, found = v.validations.Load(nextRecord) - if !found { - v.validations.Store(nextRecord, status) - entry = status - } - v.blockMutex.Unlock() - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to send recordings") - return - } - currentStatus := validationStatus.getStatus() - if currentStatus == RecordFailed { - // retry - v.validations.Delete(nextRecord) - v.triggerSendValidations() - return - } - if currentStatus == Unprepared { - prevHeader := validationStatus.Entry.PrevBlockHeader - if prevHeader != nil { - _, err := v.recordingDatabase.GetOrRecreateState(ctx, prevHeader, stateLogFunc) - if err != nil { - log.Error("error trying to prepare state for recording", "err", err) - } - // add another reference that will be released by the record thread - _, err = v.recordingDatabase.StateFor(prevHeader) - if err != nil { - log.Error("error trying re-reference state for recording", "err", err) - } - if v.lastHeaderForPrepareState != nil { - v.recordingDatabase.Dereference(v.lastHeaderForPrepareState) - } - v.lastHeaderForPrepareState = prevHeader - } - err := v.sendRecord(validationStatus, true) - if err != nil { - log.Error("error trying to send preimage recording", "err", err) - } +func (v *BlockValidator) iterativeValidationProgress(ctx context.Context, ignored struct{}) time.Duration { + reorg, err := v.advanceValidations(ctx) + if err != nil { + log.Error("error trying to record for validation node", "err", err) + } else if reorg != nil { + err := v.Reorg(ctx, *reorg) + if err != nil { + log.Error("error trying to rorg validation", "pos", *reorg-1, "err", err) + v.possiblyFatal(err) } - nextRecord++ } + return v.config().ValidationPoll } -func (v *BlockValidator) writeLastValidatedToDb(blockNumber uint64, blockHash common.Hash, endPos GlobalStatePosition) error { - info := LastBlockValidatedDbInfo{ - BlockNumber: blockNumber, - BlockHash: blockHash, - AfterPosition: endPos, +var ErrValidationCanceled = errors.New("validation of block cancelled") + +func (v *BlockValidator) writeLastValidated(gs 
validator.GoGlobalState, wasmRoots []common.Hash) error { + v.lastValidGS = gs + info := GlobalStateValidatedInfo{ + GlobalState: gs, + WasmRoots: wasmRoots, } - encodedInfo, err := rlp.EncodeToBytes(info) + encoded, err := rlp.EncodeToBytes(info) if err != nil { return err } - err = v.db.Put(lastBlockValidatedInfoKey, encodedInfo) + err = v.db.Put(lastGlobalStateValidatedInfoKey, encoded) if err != nil { return err } return nil } -func (v *BlockValidator) progressValidated() { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - for atomic.LoadInt32(&v.reorgsPending) == 0 { - // Reads from blocksValidated can be non-atomic as all writes hold reorgMutex - checkingBlock := v.lastBlockValidated + 1 - entry, found := v.validations.Load(checkingBlock) - if !found { - return - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to advance validated counter") - return - } - if validationStatus.getStatus() < ValidationSent { - return - } - validationEntry := validationStatus.Entry - if validationEntry.BlockNumber != checkingBlock { - log.Error("bad block number for validation entry", "expected", checkingBlock, "found", validationEntry.BlockNumber) - return - } - // It's safe to read lastBlockValidatedHash without the lastBlockValidatedMutex as we have the reorgMutex - if v.lastBlockValidatedHash != validationEntry.PrevBlockHash { - log.Error("lastBlockValidatedHash is %v but validationEntry has prevBlockHash %v for block number %v", v.lastBlockValidatedHash, validationEntry.PrevBlockHash, v.lastBlockValidated) - return - } - expectedEnd, err := validationEntry.expectedEnd() - if err != nil { - v.possiblyFatal(err) - return - } - for _, run := range validationStatus.Runs { - if !run.Ready() { - return - } - runEnd, err := run.Current() - if err == nil && runEnd != expectedEnd { - err = fmt.Errorf("validation failed: expected %v got %v", expectedEnd, runEnd) - writeErr := v.writeToFile(validationEntry, run.WasmModuleRoot()) - if writeErr != nil { - log.Warn("failed to write validation debugging info", "err", err) - } - v.possiblyFatal(err) - } - if err != nil { - v.possiblyFatal(err) - validationStatus.setStatus(Failed) - validatorFailedValidationsCounter.Inc(1) - validatorPendingValidationsGauge.Dec(1) - return - } - } - for _, run := range validationStatus.Runs { - run.Cancel() - } - validationStatus.replaceStatus(ValidationSent, Valid) - validatorValidValidationsCounter.Inc(1) - validatorPendingValidationsGauge.Dec(1) - v.triggerSendValidations() - earliestBatchKept := atomic.LoadUint64(&v.earliestBatchKept) - seqMsgNr := validationEntry.StartPosition.BatchNumber - if earliestBatchKept < seqMsgNr { - for batch := earliestBatchKept; batch < seqMsgNr; batch++ { - v.sequencerBatches.Delete(batch) - } - atomic.StoreUint64(&v.earliestBatchKept, seqMsgNr) - } - - v.lastBlockValidatedMutex.Lock() - atomic.StoreUint64(&v.lastBlockValidated, checkingBlock) - validatorLastBlockValidatedGauge.Update(int64(checkingBlock)) - v.lastBlockValidatedHash = validationEntry.BlockHash - err = v.writeLastValidatedToDb(validationEntry.BlockNumber, validationEntry.BlockHash, validationEntry.EndPosition) - if err != nil { - log.Error("failed to write validated entry to database", "err", err) +func (v *BlockValidator) validGSIsNew(globalState validator.GoGlobalState) bool { + if v.legacyValidInfo != nil { + if v.legacyValidInfo.AfterPosition.BatchNumber > globalState.Batch { + return false } - v.lastBlockValidatedMutex.Unlock() - 
v.recentlyValid(validationEntry.BlockHeader) - - v.validations.Delete(checkingBlock) - select { - case v.progressChan <- checkingBlock: - default: + if v.legacyValidInfo.AfterPosition.BatchNumber == globalState.Batch && v.legacyValidInfo.AfterPosition.PosInBatch >= globalState.PosInBatch { + return false } + return true + } + if v.lastValidGS.Batch > globalState.Batch { + return false } + if v.lastValidGS.Batch == globalState.Batch && v.lastValidGS.PosInBatch >= globalState.PosInBatch { + return false + } + return true } -func (v *BlockValidator) AssumeValid(globalState validator.GoGlobalState) error { +// this accepts globalstate even if not caught up +func (v *BlockValidator) InitAssumeValid(globalState validator.GoGlobalState) error { if v.Started() { - return fmt.Errorf("cannot handle AssumeValid while running") + return fmt.Errorf("cannot handle InitAssumeValid while running") } - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() - // don't do anything if we already validated past that - if v.globalPosNextSend.BatchNumber > globalState.Batch { - return nil - } - if v.globalPosNextSend.BatchNumber == globalState.Batch && v.globalPosNextSend.PosInBatch > globalState.PosInBatch { + if !v.validGSIsNew(globalState) { return nil } - block := v.blockchain.GetBlockByHash(globalState.BlockHash) - if block == nil { - v.lastBlockValidatedUnknown = true - } else { - v.lastBlockValidated = block.NumberU64() - validatorLastBlockValidatedGauge.Update(int64(v.lastBlockValidated)) - v.nextBlockToValidate = v.lastBlockValidated + 1 - } - v.lastBlockValidatedHash = globalState.BlockHash - v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: globalState.Batch, - PosInBatch: globalState.PosInBatch, + v.legacyValidInfo = nil + + err := v.writeLastValidated(globalState, nil) + if err != nil { + log.Error("failed writing new validated to database", "pos", v.lastValidGS, "err", err) } + return nil } -func (v *BlockValidator) LastBlockValidated() uint64 { - return atomic.LoadUint64(&v.lastBlockValidated) -} +func (v *BlockValidator) UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) { -func (v *BlockValidator) LastBlockValidatedAndHash() (blockNumber uint64, blockHash common.Hash, wasmModuleRoots []common.Hash) { - v.lastBlockValidatedMutex.Lock() - blockValidated := v.lastBlockValidated - blockValidatedHash := v.lastBlockValidatedHash - v.lastBlockValidatedMutex.Unlock() + if count <= v.validated() { + return + } - // things can be removed from, but not added to, moduleRootsToValidate. By taking root hashes fter the block we know result is valid - moduleRootsValidated := v.GetModuleRootsToValidate() + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() - return blockValidated, blockValidatedHash, moduleRootsValidated -} + if count <= v.validated() { + return + } -// Because batches and blocks are handled at separate layers in the node, -// and because block generation from messages is asynchronous, -// this call is different than ReorgToBlock, which is currently called later. 
-func (v *BlockValidator) ReorgToBatchCount(count uint64) { - v.batchMutex.Lock() - defer v.batchMutex.Unlock() - v.reorgToBatchCountImpl(count) -} + if !v.chainCaughtUp { + if !v.validGSIsNew(globalState) { + return + } + v.legacyValidInfo = nil + err := v.writeLastValidated(globalState, nil) + if err != nil { + log.Error("error writing last validated", "err", err) + } + return + } -func (v *BlockValidator) reorgToBatchCountImpl(count uint64) { - localBatchCount := v.nextBatchKept - if localBatchCount < count { + countUint64 := uint64(count) + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + log.Error("getMessage error", "err", err, "count", count) return } - for i := count; i < localBatchCount; i++ { - v.sequencerBatches.Delete(i) + // delete no-longer relevant entries + for iPos := v.validated(); iPos < count && iPos < v.created(); iPos++ { + status, found := v.validations.Load(iPos) + if found && status != nil && status.Cancel != nil { + status.Cancel() + } + v.validations.Delete(iPos) + } + if v.created() < count { + v.nextCreateStartGS = globalState + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + v.nextCreateBatchReread = true + v.createdA = countUint64 + } + // under the reorg mutex we don't need atomic access + if v.recordSentA < countUint64 { + v.recordSentA = countUint64 + } + v.validatedA = countUint64 + v.valLoopPos = count + validatorMsgCountValidatedGauge.Update(int64(countUint64)) + err = v.writeLastValidated(globalState, nil) // we don't know which wasm roots were validated + if err != nil { + log.Error("failed writing valid state after reorg", "err", err) } - v.nextBatchKept = count + nonBlockingTrigger(v.createNodesChan) } -func (v *BlockValidator) ProcessBatches(pos uint64, batches [][]byte) { - v.batchMutex.Lock() - defer v.batchMutex.Unlock() - - v.reorgToBatchCountImpl(pos) - - // Attempt to fill in earliestBatchKept if it's empty - atomic.CompareAndSwapUint64(&v.earliestBatchKept, 0, pos) - - for i, msg := range batches { - v.sequencerBatches.Store(pos+uint64(i), msg) +// Because batches and blocks are handled at separate layers in the node, +// and because block generation from messages is asynchronous, +// this call is different than Reorg, which is currently called later. 
+func (v *BlockValidator) ReorgToBatchCount(count uint64) { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.nextCreateStartGS.Batch >= count { + v.nextCreateBatchReread = true } - v.nextBatchKept = pos + uint64(len(batches)) - v.triggerSendValidations() } -func (v *BlockValidator) ReorgToBlock(blockNum uint64, blockHash common.Hash) error { - atomic.AddInt32(&v.reorgsPending, 1) +func (v *BlockValidator) Reorg(ctx context.Context, count arbutil.MessageIndex) error { v.reorgMutex.Lock() defer v.reorgMutex.Unlock() - atomic.AddInt32(&v.reorgsPending, -1) - - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() - - if blockNum < v.lastValidationEntryBlock { - log.Warn("block validator processing reorg", "blockNum", blockNum) - err := v.reorgToBlockImpl(blockNum, blockHash) - if err != nil { - return fmt.Errorf("block validator reorg failed: %w", err) - } + if count <= 1 { + return errors.New("cannot reorg out genesis") } - - return nil -} - -// must hold reorgMutex, blockMutex, and lastBlockValidatedMutex -func (v *BlockValidator) reorgToBlockImpl(blockNum uint64, blockHash common.Hash) error { - for b := blockNum + 1; b <= v.lastValidationEntryBlock; b++ { - entry, found := v.validations.Load(b) - if !found { - continue - } - v.validations.Delete(b) - - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to reorg block validator") - continue - } - log.Debug("canceling validation due to reorg", "block", b) - if validationStatus.Cancel != nil { - validationStatus.Cancel() - } + if !v.chainCaughtUp { + return nil } - v.lastValidationEntryBlock = blockNum - if v.nextBlockToValidate <= blockNum+1 { + if v.created() < count { return nil } - msgIndex := arbutil.BlockNumberToMessageCount(blockNum, v.genesisBlockNum) - 1 - batchCount, err := v.inboxTracker.GetBatchCount() + _, endPosition, err := v.GlobalStatePositionsAtCount(count) if err != nil { + v.possiblyFatal(err) return err } - batch, err := FindBatchContainingMessageIndex(v.inboxTracker, msgIndex, batchCount) + res, err := v.streamer.ResultAtCount(count) if err != nil { + v.possiblyFatal(err) return err } - if batch >= batchCount { - // This reorg is past the latest batch. - // Attempt to recover by loading a next validation state at the start of the next batch. 
- v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: batch, - PosInBatch: 0, - } - msgCount, err := v.inboxTracker.GetBatchMessageCount(batch - 1) - if err != nil { - return err - } - nextBlockSigned := arbutil.MessageCountToBlockNumber(msgCount, v.genesisBlockNum) + 1 - if nextBlockSigned <= 0 { - return errors.New("reorg past genesis block") - } - blockNum = uint64(nextBlockSigned) - 1 - block := v.blockchain.GetBlockByNumber(blockNum) - if block == nil { - return fmt.Errorf("failed to get end of batch block %v", blockNum) - } - blockHash = block.Hash() - v.lastValidationEntryBlock = blockNum - } else { - _, v.globalPosNextSend, err = GlobalStatePositionsFor(v.inboxTracker, msgIndex, batch) - if err != nil { - return err - } - } - if v.nextBlockToValidate > blockNum+1 { - v.nextBlockToValidate = blockNum + 1 + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + v.possiblyFatal(err) + return err } - - if v.lastBlockValidated > blockNum { - atomic.StoreUint64(&v.lastBlockValidated, blockNum) - validatorLastBlockValidatedGauge.Update(int64(blockNum)) - v.lastBlockValidatedHash = blockHash - - err = v.writeLastValidatedToDb(blockNum, blockHash, v.globalPosNextSend) + for iPos := count; iPos < v.created(); iPos++ { + status, found := v.validations.Load(iPos) + if found && status != nil && status.Cancel != nil { + status.Cancel() + } + v.validations.Delete(iPos) + } + v.nextCreateStartGS = buildGlobalState(*res, endPosition) + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + v.nextCreateBatchReread = true + countUint64 := uint64(count) + v.createdA = countUint64 + // under the reorg mutex we don't need atomic access + if v.recordSentA > countUint64 { + v.recordSentA = countUint64 + } + if v.validatedA > countUint64 { + v.validatedA = countUint64 + validatorMsgCountValidatedGauge.Update(int64(countUint64)) + err := v.writeLastValidated(v.nextCreateStartGS, nil) // we don't know which wasm roots were validated if err != nil { - return err + log.Error("failed writing valid state after reorg", "err", err) } } - + nonBlockingTrigger(v.createNodesChan) return nil } @@ -1039,77 +966,185 @@ func (v *BlockValidator) Initialize(ctx context.Context) error { return nil } -func (v *BlockValidator) Start(ctxIn context.Context) error { - v.StopWaiter.Start(ctxIn, v) - err := stopwaiter.CallIterativelyWith[struct{}](v, - func(ctx context.Context, unused struct{}) time.Duration { - v.sendRecords(ctx) - v.sendValidations(ctx) - return v.config().ValidationPoll - }, - v.sendValidationsChan) +func (v *BlockValidator) checkLegacyValid() error { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.legacyValidInfo == nil { + return nil + } + batchCount, err := v.inboxTracker.GetBatchCount() if err != nil { return err } - v.CallIteratively(func(ctx context.Context) time.Duration { - v.progressValidated() - return v.config().ValidationPoll - }) - lastValid := uint64(0) - v.CallIteratively(func(ctx context.Context) time.Duration { - newValid, validHash, wasmModuleRoots := v.LastBlockValidatedAndHash() - if newValid != lastValid { - validHeader := v.blockchain.GetHeader(validHash, newValid) - if validHeader == nil { - foundHeader := v.blockchain.GetHeaderByNumber(newValid) - foundHash := common.Hash{} - if foundHeader != nil { - foundHash = foundHeader.Hash() - } - log.Warn("last valid block not in blockchain", "blockNum", newValid, "validatedBlockHash", validHash, "found-hash", foundHash) - } else { - validTimestamp := time.Unix(int64(validHeader.Time), 0) - log.Info("Validated blocks", 
"blockNum", newValid, "hash", validHash, - "timestamp", validTimestamp, "age", time.Since(validTimestamp), "wasm", wasmModuleRoots) - } - lastValid = newValid + requiredBatchCount := v.legacyValidInfo.AfterPosition.BatchNumber + 1 + if v.legacyValidInfo.AfterPosition.PosInBatch == 0 { + requiredBatchCount -= 1 + } + if batchCount < requiredBatchCount { + log.Warn("legacy valid batch ahead of db", "current", batchCount, "required", requiredBatchCount) + return nil + } + var msgCount arbutil.MessageIndex + if v.legacyValidInfo.AfterPosition.BatchNumber > 0 { + msgCount, err = v.inboxTracker.GetBatchMessageCount(v.legacyValidInfo.AfterPosition.BatchNumber - 1) + if err != nil { + return err } - return time.Second - }) + } + msgCount += arbutil.MessageIndex(v.legacyValidInfo.AfterPosition.PosInBatch) + processedCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + return err + } + if processedCount < msgCount { + log.Warn("legacy valid message count ahead of db", "current", processedCount, "required", msgCount) + return nil + } + result, err := v.streamer.ResultAtCount(msgCount) + if err != nil { + return err + } + if result.BlockHash != v.legacyValidInfo.BlockHash { + log.Error("legacy validated blockHash does not fit chain", "info.BlockHash", v.legacyValidInfo.BlockHash, "chain", result.BlockHash, "count", msgCount) + return fmt.Errorf("legacy validated blockHash does not fit chain") + } + validGS := validator.GoGlobalState{ + BlockHash: result.BlockHash, + SendRoot: result.SendRoot, + Batch: v.legacyValidInfo.AfterPosition.BatchNumber, + PosInBatch: v.legacyValidInfo.AfterPosition.PosInBatch, + } + err = v.writeLastValidated(validGS, nil) + if err == nil { + err = v.db.Delete(legacyLastBlockValidatedInfoKey) + if err != nil { + err = fmt.Errorf("deleting legacy: %w", err) + } + } + if err != nil { + log.Error("failed writing initial lastValid on upgrade from legacy", "new-info", v.lastValidGS, "err", err) + } else { + log.Info("updated last-valid from legacy", "lastValid", v.lastValidGS) + } + v.legacyValidInfo = nil return nil } -func (v *BlockValidator) StopAndWait() { - v.StopWaiter.StopAndWait() - err := v.recentShutdown() +// checks that the chain caught up to lastValidGS, used in startup +func (v *BlockValidator) checkValidatedGSCaughtUp() (bool, error) { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.chainCaughtUp { + return true, nil + } + if v.legacyValidInfo != nil { + return false, nil + } + if v.lastValidGS.Batch == 0 { + return false, errors.New("lastValid not initialized. 
cannot validate genesis") + } + caughtUp, count, err := GlobalStateToMsgCount(v.inboxTracker, v.streamer, v.lastValidGS) if err != nil { - log.Error("error storing valid state", "err", err) + return false, err } + if !caughtUp { + batchCount, err := v.inboxTracker.GetBatchCount() + if err != nil { + log.Error("failed reading batch count", "err", err) + batchCount = 0 + } + batchMsgCount, err := v.inboxTracker.GetBatchMessageCount(batchCount - 1) + if err != nil { + log.Error("failed reading batchMsgCount", "err", err) + batchMsgCount = 0 + } + processedMsgCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + log.Error("failed reading processedMsgCount", "err", err) + processedMsgCount = 0 + } + log.Info("validator catching up to last valid", "lastValid.Batch", v.lastValidGS.Batch, "lastValid.PosInBatch", v.lastValidGS.PosInBatch, "batchCount", batchCount, "batchMsgCount", batchMsgCount, "processedMsgCount", processedMsgCount) + return false, nil + } + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + return false, err + } + v.nextCreateBatchReread = true + v.nextCreateStartGS = v.lastValidGS + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + atomicStorePos(&v.createdA, count) + atomicStorePos(&v.recordSentA, count) + atomicStorePos(&v.validatedA, count) + validatorMsgCountValidatedGauge.Update(int64(count)) + v.chainCaughtUp = true + return true, nil } -// WaitForBlock can only be used from One thread -func (v *BlockValidator) WaitForBlock(ctx context.Context, blockNumber uint64, timeout time.Duration) bool { +func (v *BlockValidator) LaunchWorkthreadsWhenCaughtUp(ctx context.Context) { + for { + err := v.checkLegacyValid() + if err != nil { + log.Error("validator got error updating legacy validated info. Consider restarting with dangerous.reset-block-validation", "err", err) + } + caughtUp, err := v.checkValidatedGSCaughtUp() + if err != nil { + log.Error("validator got error waiting for chain to catch up. 
Consider restarting with dangerous.reset-block-validation", "err", err) + } + if caughtUp { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(v.config().ValidationPoll): + } + } + err := stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationEntryCreator, v.createNodesChan) + if err != nil { + v.possiblyFatal(err) + } + err = stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationEntryRecorder, v.sendRecordChan) + if err != nil { + v.possiblyFatal(err) + } + err = stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationProgress, v.progressValidationsChan) + if err != nil { + v.possiblyFatal(err) + } +} + +func (v *BlockValidator) Start(ctxIn context.Context) error { + v.StopWaiter.Start(ctxIn, v) + v.LaunchThread(v.LaunchWorkthreadsWhenCaughtUp) + v.CallIteratively(v.iterativeValidationPrint) + return nil +} + +func (v *BlockValidator) StopAndWait() { + v.StopWaiter.StopAndWait() +} + +// WaitForPos can only be used from One thread +func (v *BlockValidator) WaitForPos(t *testing.T, ctx context.Context, pos arbutil.MessageIndex, timeout time.Duration) bool { + triggerchan := make(chan struct{}) + v.testingProgressMadeChan = triggerchan timer := time.NewTimer(timeout) defer timer.Stop() + lastLoop := false for { - if atomic.LoadUint64(&v.lastBlockValidated) >= blockNumber { + if v.validated() > pos { return true } + if lastLoop { + return false + } select { case <-timer.C: - if atomic.LoadUint64(&v.lastBlockValidated) >= blockNumber { - return true - } - return false - case block, ok := <-v.progressChan: - if block >= blockNumber { - return true - } - if !ok { - return false - } + lastLoop = true + case <-triggerchan: case <-ctx.Done(): - return false + lastLoop = true } } } diff --git a/staker/block_validator_schema.go b/staker/block_validator_schema.go index e5d7dba71b..f6eb39f015 100644 --- a/staker/block_validator_schema.go +++ b/staker/block_validator_schema.go @@ -3,14 +3,23 @@ package staker -import "github.com/ethereum/go-ethereum/common" +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/offchainlabs/nitro/validator" +) -type LastBlockValidatedDbInfo struct { +type legacyLastBlockValidatedDbInfo struct { BlockNumber uint64 BlockHash common.Hash AfterPosition GlobalStatePosition } +type GlobalStateValidatedInfo struct { + GlobalState validator.GoGlobalState + WasmRoots []common.Hash +} + var ( - lastBlockValidatedInfoKey = []byte("_lastBlockValidatedInfo") // contains a rlp encoded lastBlockValidatedDbInfo + lastGlobalStateValidatedInfoKey = []byte("_lastGlobalStateValidatedInfo") // contains a rlp encoded lastBlockValidatedDbInfo + legacyLastBlockValidatedInfoKey = []byte("_lastBlockValidatedInfo") // LEGACY - contains a rlp encoded lastBlockValidatedDbInfo ) diff --git a/staker/challenge_manager.go b/staker/challenge_manager.go index 8d35d962d3..ac2ae8835a 100644 --- a/staker/challenge_manager.go +++ b/staker/challenge_manager.go @@ -14,10 +14,10 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind/backends" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/solgen/go/challengegen" "github.com/offchainlabs/nitro/validator" ) @@ -72,9 +72,9 @@ type ChallengeManager struct { 
wasmModuleRoot common.Hash // these fields are empty until working on execution challenge - initialMachineBlockNr int64 - executionChallengeBackend *ExecutionChallengeBackend - machineFinalStepCount uint64 + initialMachineMessageCount arbutil.MessageIndex + executionChallengeBackend *ExecutionChallengeBackend + machineFinalStepCount uint64 } // NewChallengeManager constructs a new challenge manager. @@ -86,8 +86,6 @@ func NewChallengeManager( fromAddr common.Address, challengeManagerAddr common.Address, challengeIndex uint64, - l2blockChain *core.BlockChain, - inboxTracker InboxTrackerInterface, val *StatelessBlockValidator, startL1Block uint64, confirmationBlocks int64, @@ -125,13 +123,11 @@ func NewChallengeManager( return nil, fmt.Errorf("error getting challenge %v info: %w", challengeIndex, err) } - genesisBlockNum := l2blockChain.Config().ArbitrumChainParams.GenesisBlockNum backend, err := NewBlockChallengeBackend( parsedLog, challengeInfo.MaxInboxMessages, - l2blockChain, - inboxTracker, - genesisBlockNum, + val.streamer, + val.inboxTracker, ) if err != nil { return nil, fmt.Errorf("error creating block challenge backend for challenge %v: %w", challengeIndex, err) @@ -462,23 +458,18 @@ func (m *ChallengeManager) IssueOneStepProof( } func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint64) error { - blockNum := m.blockChallengeBackend.GetBlockNrAtStep(step) - // Get the next message and block header, and record the full block creation - if m.initialMachineBlockNr == blockNum && m.executionChallengeBackend != nil { + initialCount := m.blockChallengeBackend.GetMessageCountAtStep(step) + if m.initialMachineMessageCount == initialCount && m.executionChallengeBackend != nil { return nil } m.executionChallengeBackend = nil - nextHeader := m.blockChallengeBackend.bc.GetHeaderByNumber(uint64(blockNum + 1)) - if nextHeader == nil { - return fmt.Errorf("next block header %v after challenge point unknown", blockNum+1) - } - entry, err := m.validator.CreateReadyValidationEntry(ctx, nextHeader) + entry, err := m.validator.CreateReadyValidationEntry(ctx, initialCount) if err != nil { - return fmt.Errorf("error creating validation entry for challenge %v block %v for execution challenge: %w", m.challengeIndex, blockNum, err) + return fmt.Errorf("error creating validation entry for challenge %v msg %v for execution challenge: %w", m.challengeIndex, initialCount, err) } input, err := entry.ToInput() if err != nil { - return fmt.Errorf("error getting validation entry input of challenge %v block %v: %w", m.challengeIndex, blockNum, err) + return fmt.Errorf("error getting validation entry input of challenge %v msg %v: %w", m.challengeIndex, initialCount, err) } var prunedBatches []validator.BatchInfo for _, batch := range input.BatchInfo { @@ -489,7 +480,7 @@ func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint input.BatchInfo = prunedBatches execRun, err := m.validator.execSpawner.CreateExecutionRun(m.wasmModuleRoot, input).Await(ctx) if err != nil { - return fmt.Errorf("error creating execution backend for block %v: %w", blockNum, err) + return fmt.Errorf("error creating execution backend for msg %v: %w", initialCount, err) } backend, err := NewExecutionChallengeBackend(execRun) if err != nil { @@ -504,16 +495,16 @@ func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint return fmt.Errorf("error getting execution challenge final state: %w", err) } if expectedStatus != computedStatus { - return fmt.Errorf("after block %v 
expected status %v but got %v", blockNum, expectedStatus, computedStatus) + return fmt.Errorf("after msg %v expected status %v but got %v", initialCount, expectedStatus, computedStatus) } if computedStatus == StatusFinished { if computedState != expectedState { - return fmt.Errorf("after block %v expected global state %v but got %v", blockNum, expectedState, computedState) + return fmt.Errorf("after msg %v expected global state %v but got %v", initialCount, expectedState, computedState) } } m.executionChallengeBackend = backend m.machineFinalStepCount = machineStepCount - m.initialMachineBlockNr = blockNum + m.initialMachineMessageCount = initialCount return nil } @@ -569,8 +560,7 @@ func (m *ChallengeManager) Act(ctx context.Context) (*types.Transaction, error) return nil, fmt.Errorf("error creating execution backend: %w", err) } machineStepCount := m.machineFinalStepCount - blockNum := m.initialMachineBlockNr - log.Info("issuing one step proof", "challenge", m.challengeIndex, "machineStepCount", machineStepCount, "blockNum", blockNum) + log.Info("issuing one step proof", "challenge", m.challengeIndex, "machineStepCount", machineStepCount, "initialCount", m.initialMachineMessageCount) return m.blockChallengeBackend.IssueExecChallenge( m.challengeCore, state, diff --git a/staker/l1_validator.go b/staker/l1_validator.go index f16d4c9538..aa9107fd90 100644 --- a/staker/l1_validator.go +++ b/staker/l1_validator.go @@ -11,11 +11,11 @@ import ( "time" "github.com/offchainlabs/nitro/arbstate" + "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/validator" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -41,16 +41,14 @@ const ( ) type L1Validator struct { - rollup *RollupWatcher - rollupAddress common.Address - validatorUtils *rollupgen.ValidatorUtils - client arbutil.L1Interface - builder *ValidatorTxBuilder - wallet ValidatorWalletInterface - callOpts bind.CallOpts - genesisBlockNumber uint64 - - l2Blockchain *core.BlockChain + rollup *RollupWatcher + rollupAddress common.Address + validatorUtils *rollupgen.ValidatorUtils + client arbutil.L1Interface + builder *ValidatorTxBuilder + wallet ValidatorWalletInterface + callOpts bind.CallOpts + das arbstate.DataAvailabilityReader inboxTracker InboxTrackerInterface txStreamer TransactionStreamerInterface @@ -63,7 +61,6 @@ func NewL1Validator( wallet ValidatorWalletInterface, validatorUtilsAddress common.Address, callOpts bind.CallOpts, - l2Blockchain *core.BlockChain, das arbstate.DataAvailabilityReader, inboxTracker InboxTrackerInterface, txStreamer TransactionStreamerInterface, @@ -84,24 +81,18 @@ func NewL1Validator( if err != nil { return nil, err } - genesisBlockNumber, err := txStreamer.GetGenesisBlockNumber() - if err != nil { - return nil, err - } return &L1Validator{ - rollup: rollup, - rollupAddress: wallet.RollupAddress(), - validatorUtils: validatorUtils, - client: client, - builder: builder, - wallet: wallet, - callOpts: callOpts, - genesisBlockNumber: genesisBlockNumber, - l2Blockchain: l2Blockchain, - das: das, - inboxTracker: inboxTracker, - txStreamer: txStreamer, - blockValidator: blockValidator, + rollup: rollup, + rollupAddress: wallet.RollupAddress(), + validatorUtils: validatorUtils, + client: client, + builder: builder, + wallet: wallet, + callOpts: callOpts, + das: das, + inboxTracker: 
inboxTracker, + txStreamer: txStreamer, + blockValidator: blockValidator, }, nil } @@ -231,44 +222,28 @@ type OurStakerInfo struct { *StakerInfo } -// Returns (block number, global state inbox position is invalid, error). -// If global state is invalid, block number is set to the last of the batch. -func (v *L1Validator) blockNumberFromGlobalState(gs validator.GoGlobalState) (int64, bool, error) { - var batchHeight arbutil.MessageIndex - if gs.Batch > 0 { - var err error - batchHeight, err = v.inboxTracker.GetBatchMessageCount(gs.Batch - 1) - if err != nil { - return 0, false, err - } - } - - // Validate the PosInBatch if it's non-zero - if gs.PosInBatch > 0 { - nextBatchHeight, err := v.inboxTracker.GetBatchMessageCount(gs.Batch) - if err != nil { - return 0, false, err - } - - if gs.PosInBatch >= uint64(nextBatchHeight-batchHeight) { - // This PosInBatch would enter the next batch. Return the last block before the next batch. - // We can be sure that MessageCountToBlockNumber will return a non-negative number as nextBatchHeight must be nonzero. - return arbutil.MessageCountToBlockNumber(nextBatchHeight, v.genesisBlockNumber), true, nil - } - } - - return arbutil.MessageCountToBlockNumber(batchHeight+arbutil.MessageIndex(gs.PosInBatch), v.genesisBlockNumber), false, nil -} - -func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurStakerInfo, strategy StakerStrategy, makeAssertionInterval time.Duration) (nodeAction, bool, error) { - startState, prevInboxMaxCount, startStateProposedL1, startStateProposedParentChain, err := lookupNodeStartState(ctx, v.rollup, stakerInfo.LatestStakedNode, stakerInfo.LatestStakedNodeHash) +func (v *L1Validator) generateNodeAction( + ctx context.Context, + stakerInfo *OurStakerInfo, + strategy StakerStrategy, + stakerConfig *L1ValidatorConfig, +) (nodeAction, bool, error) { + startState, prevInboxMaxCount, startStateProposedL1, startStateProposedParentChain, err := lookupNodeStartState( + ctx, v.rollup, stakerInfo.LatestStakedNode, stakerInfo.LatestStakedNodeHash, + ) if err != nil { - return nil, false, fmt.Errorf("error looking up node %v (hash %v) start state: %w", stakerInfo.LatestStakedNode, stakerInfo.LatestStakedNodeHash, err) + return nil, false, fmt.Errorf( + "error looking up node %v (hash %v) start state: %w", + stakerInfo.LatestStakedNode, stakerInfo.LatestStakedNodeHash, err, + ) } - startStateProposedHeader, err := v.client.HeaderByNumber(ctx, new(big.Int).SetUint64(startStateProposedParentChain)) + startStateProposedHeader, err := v.client.HeaderByNumber(ctx, arbmath.UintToBig(startStateProposedParentChain)) if err != nil { - return nil, false, fmt.Errorf("error looking up L1 header of block %v of node start state: %w", startStateProposedParentChain, err) + return nil, false, fmt.Errorf( + "error looking up L1 header of block %v of node start state: %w", + startStateProposedParentChain, err, + ) } startStateProposedTime := time.Unix(int64(startStateProposedHeader.Time), 0) @@ -279,68 +254,101 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta if err != nil { return nil, false, fmt.Errorf("error getting batch count from inbox tracker: %w", err) } - if localBatchCount < startState.RequiredBatches() { - log.Info("catching up to chain batches", "localBatches", localBatchCount, "target", startState.RequiredBatches()) + if localBatchCount < startState.RequiredBatches() || localBatchCount == 0 { + log.Info( + "catching up to chain batches", "localBatches", localBatchCount, + "target", 
startState.RequiredBatches(), + ) return nil, false, nil } - startBlock := v.l2Blockchain.GetBlockByHash(startState.GlobalState.BlockHash) - if startBlock == nil && (startState.GlobalState != validator.GoGlobalState{}) { - expectedBlockHeight, inboxPositionInvalid, err := v.blockNumberFromGlobalState(startState.GlobalState) - if err != nil { - return nil, false, fmt.Errorf("error getting block number from global state: %w", err) + caughtUp, startCount, err := GlobalStateToMsgCount(v.inboxTracker, v.txStreamer, startState.GlobalState) + if err != nil { + return nil, false, fmt.Errorf("start state not in chain: %w", err) + } + if !caughtUp { + target := GlobalStatePosition{ + BatchNumber: startState.GlobalState.Batch, + PosInBatch: startState.GlobalState.PosInBatch, } - if inboxPositionInvalid { - log.Error("invalid start global state inbox position", startState.GlobalState.BlockHash, "batch", startState.GlobalState.Batch, "pos", startState.GlobalState.PosInBatch) - return nil, false, errors.New("invalid start global state inbox position") + var current GlobalStatePosition + head, err := v.txStreamer.GetProcessedMessageCount() + if err != nil { + _, current, err = v.blockValidator.GlobalStatePositionsAtCount(head) } - latestHeader := v.l2Blockchain.CurrentBlock() - if latestHeader.Number.Int64() < expectedBlockHeight { - log.Info("catching up to chain blocks", "localBlocks", latestHeader.Number, "target", expectedBlockHeight) - return nil, false, nil + if err != nil { + log.Info("catching up to chain messages", "target", target) + } else { + log.Info("catching up to chain blocks", "target", target, "current", current) } - log.Error("unknown start block hash", "hash", startState.GlobalState.BlockHash, "batch", startState.GlobalState.Batch, "pos", startState.GlobalState.PosInBatch) - return nil, false, errors.New("unknown start block hash") + return nil, false, nil } - var lastBlockValidated uint64 + var validatedCount arbutil.MessageIndex + var validatedGlobalState validator.GoGlobalState if v.blockValidator != nil { - var expectedHash common.Hash - var validRoots []common.Hash - lastBlockValidated, expectedHash, validRoots = v.blockValidator.LastBlockValidatedAndHash() - haveHash := v.l2Blockchain.GetCanonicalHash(lastBlockValidated) - if haveHash != expectedHash { - return nil, false, fmt.Errorf("block validator validated block %v as hash %v but blockchain has hash %v", lastBlockValidated, expectedHash, haveHash) + valInfo, err := v.blockValidator.ReadLastValidatedInfo() + if err != nil || valInfo == nil { + return nil, false, err + } + validatedGlobalState = valInfo.GlobalState + caughtUp, validatedCount, err = GlobalStateToMsgCount( + v.inboxTracker, v.txStreamer, valInfo.GlobalState, + ) + if err != nil { + return nil, false, fmt.Errorf("%w: not found validated block in blockchain", err) + } + if !caughtUp { + log.Info("catching up to last validated block", "target", valInfo.GlobalState) + return nil, false, nil } if err := v.updateBlockValidatorModuleRoot(ctx); err != nil { return nil, false, fmt.Errorf("error updating block validator module root: %w", err) } wasmRootValid := false - for _, root := range validRoots { + for _, root := range valInfo.WasmRoots { if v.lastWasmModuleRoot == root { wasmRootValid = true break } } if !wasmRootValid { - return nil, false, fmt.Errorf("wasmroot doesn't match rollup : %v, valid: %v", v.lastWasmModuleRoot, validRoots) + if !stakerConfig.Dangerous.IgnoreRollupWasmModuleRoot { + return nil, false, fmt.Errorf( + "wasmroot doesn't match rollup : %v, 
valid: %v", + v.lastWasmModuleRoot, valInfo.WasmRoots, + ) + } + log.Warn("wasmroot doesn't match rollup", "rollup", v.lastWasmModuleRoot, "blockValidator", valInfo.WasmRoots) } } else { - lastBlockValidated = v.l2Blockchain.CurrentBlock().Number.Uint64() - - if localBatchCount > 0 { - messageCount, err := v.inboxTracker.GetBatchMessageCount(localBatchCount - 1) + validatedCount, err = v.txStreamer.GetProcessedMessageCount() + if err != nil || validatedCount == 0 { + return nil, false, err + } + var batchNum uint64 + messageCount, err := v.inboxTracker.GetBatchMessageCount(localBatchCount - 1) + if err != nil { + return nil, false, fmt.Errorf("error getting latest batch %v message count: %w", localBatchCount-1, err) + } + if validatedCount >= messageCount { + batchNum = localBatchCount - 1 + validatedCount = messageCount + } else { + batchNum, err = FindBatchContainingMessageIndex(v.inboxTracker, validatedCount-1, localBatchCount) if err != nil { - return nil, false, fmt.Errorf("error getting latest batch %v message count: %w", localBatchCount-1, err) - } - // Must be non-negative as a batch must contain at least one message - lastBatchBlock := uint64(arbutil.MessageCountToBlockNumber(messageCount, v.genesisBlockNumber)) - if lastBlockValidated > lastBatchBlock { - lastBlockValidated = lastBatchBlock + return nil, false, err } - } else { - lastBlockValidated = 0 } + execResult, err := v.txStreamer.ResultAtCount(validatedCount) + if err != nil { + return nil, false, err + } + _, gsPos, err := GlobalStatePositionsAtCount(v.inboxTracker, validatedCount, batchNum) + if err != nil { + return nil, false, fmt.Errorf("%w: failed calculating GSposition for count %d", err, validatedCount) + } + validatedGlobalState = buildGlobalState(*execResult, gsPos) } currentL1BlockNum, err := v.client.BlockNumber(ctx) @@ -379,98 +387,71 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta // We've found everything we could hope to find break } - if correctNode == nil { - afterGs := nd.AfterState().GlobalState - requiredBatches := nd.AfterState().RequiredBatches() - if localBatchCount < requiredBatches { - return nil, false, fmt.Errorf("waiting for validator to catch up to assertion batches: %v/%v", localBatchCount, requiredBatches) - } - if requiredBatches > 0 { - haveAcc, err := v.inboxTracker.GetBatchAcc(requiredBatches - 1) - if err != nil { - return nil, false, fmt.Errorf("error getting batch %v accumulator: %w", requiredBatches-1, err) - } - if haveAcc != nd.AfterInboxBatchAcc { - return nil, false, fmt.Errorf("missed sequencer batches reorg: at seq num %v have acc %v but assertion has acc %v", requiredBatches-1, haveAcc, nd.AfterInboxBatchAcc) - } - } - lastBlockNum, inboxPositionInvalid, err := v.blockNumberFromGlobalState(afterGs) - if err != nil { - return nil, false, fmt.Errorf("error getting block number from global state: %w", err) - } - if int64(lastBlockValidated) < lastBlockNum { - return nil, false, fmt.Errorf("waiting for validator to catch up to assertion blocks: %v/%v", lastBlockValidated, lastBlockNum) - } - var expectedBlockHash common.Hash - var expectedSendRoot common.Hash - if lastBlockNum >= 0 { - lastBlock := v.l2Blockchain.GetBlockByNumber(uint64(lastBlockNum)) - if lastBlock == nil { - return nil, false, fmt.Errorf("block %v not in database despite being validated", lastBlockNum) - } - lastBlockExtra := types.DeserializeHeaderExtraInformation(lastBlock.Header()) - expectedBlockHash = lastBlock.Hash() - expectedSendRoot = lastBlockExtra.SendRoot - } - - 
var expectedNumBlocks uint64 - if startBlock == nil { - expectedNumBlocks = uint64(lastBlockNum + 1) - } else { - expectedNumBlocks = uint64(lastBlockNum) - startBlock.NumberU64() - } - valid := !inboxPositionInvalid && - nd.Assertion.NumBlocks == expectedNumBlocks && - afterGs.BlockHash == expectedBlockHash && - afterGs.SendRoot == expectedSendRoot && - nd.Assertion.AfterState.MachineStatus == validator.MachineStatusFinished - if valid { - log.Info( - "found correct assertion", - "node", nd.NodeNum, - "blockNum", lastBlockNum, - "blockHash", afterGs.BlockHash, - ) - correctNode = existingNodeAction{ - number: nd.NodeNum, - hash: nd.NodeHash, - } - continue - } else { - log.Error( - "found incorrect assertion", - "node", nd.NodeNum, - "inboxPositionInvalid", inboxPositionInvalid, - "computedBlockNum", lastBlockNum, - "numBlocks", nd.Assertion.NumBlocks, - "expectedNumBlocks", expectedNumBlocks, - "blockHash", afterGs.BlockHash, - "expectedBlockHash", expectedBlockHash, - "sendRoot", afterGs.SendRoot, - "expectedSendRoot", expectedSendRoot, - "machineStatus", nd.Assertion.AfterState.MachineStatus, - ) - } - } else { + if correctNode != nil { log.Error("found younger sibling to correct assertion (implicitly invalid)", "node", nd.NodeNum) + wrongNodesExist = true + continue + } + afterGS := nd.AfterState().GlobalState + requiredBatch := afterGS.Batch + if afterGS.PosInBatch == 0 && afterGS.Batch > 0 { + requiredBatch -= 1 + } + if localBatchCount <= requiredBatch { + log.Info("staker: waiting for node to catch up to assertion batch", "current", localBatchCount, "target", requiredBatch-1) + return nil, false, nil + } + nodeBatchMsgCount, err := v.inboxTracker.GetBatchMessageCount(requiredBatch) + if err != nil { + return nil, false, err + } + if validatedCount < nodeBatchMsgCount { + log.Info("staker: waiting for validator to catch up to assertion batch messages", "current", validatedCount, "target", nodeBatchMsgCount) + return nil, false, nil + } + if nd.Assertion.AfterState.MachineStatus != validator.MachineStatusFinished { + wrongNodesExist = true + log.Error("Found incorrect assertion: Machine status not finished", "node", nd.NodeNum, "machineStatus", nd.Assertion.AfterState.MachineStatus) + continue + } + caughtUp, nodeMsgCount, err := GlobalStateToMsgCount(v.inboxTracker, v.txStreamer, afterGS) + if errors.Is(err, ErrGlobalStateNotInChain) { + wrongNodesExist = true + log.Error("Found incorrect assertion", "node", nd.NodeNum, "afterGS", afterGS, "err", err) + continue + } + if err != nil { + return nil, false, fmt.Errorf("error getting message number from global state: %w", err) + } + if !caughtUp { + return nil, false, fmt.Errorf("unexpected no-caught-up parsing assertion. Current: %d target: %v", validatedCount, afterGS) + } + log.Info( + "found correct assertion", + "node", nd.NodeNum, + "count", nodeMsgCount, + "blockHash", afterGS.BlockHash, + ) + correctNode = existingNodeAction{ + number: nd.NodeNum, + hash: nd.NodeHash, } - // If we've hit this point, the node is "wrong" - wrongNodesExist = true } if correctNode != nil || strategy == WatchtowerStrategy { return correctNode, wrongNodesExist, nil } + makeAssertionInterval := stakerConfig.MakeAssertionInterval if wrongNodesExist || (strategy >= MakeNodesStrategy && time.Since(startStateProposedTime) >= makeAssertionInterval) { // There's no correct node; create one. 
var lastNodeHashIfExists *common.Hash if len(successorNodes) > 0 { lastNodeHashIfExists = &successorNodes[len(successorNodes)-1].NodeHash } - action, err := v.createNewNodeAction(ctx, stakerInfo, lastBlockValidated, localBatchCount, prevInboxMaxCount, startBlock, startState, lastNodeHashIfExists) + action, err := v.createNewNodeAction(ctx, stakerInfo, prevInboxMaxCount, startCount, startState, validatedCount, validatedGlobalState, lastNodeHashIfExists) if err != nil { - return nil, wrongNodesExist, fmt.Errorf("error generating create new node action (from start block %v to last block validated %v): %w", startBlock, lastBlockValidated, err) + return nil, wrongNodesExist, fmt.Errorf("error generating create new node action (from pos %d to %d): %w", startCount, validatedCount, err) } return action, wrongNodesExist, nil } @@ -481,79 +462,33 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta func (v *L1Validator) createNewNodeAction( ctx context.Context, stakerInfo *OurStakerInfo, - lastBlockValidated uint64, - localBatchCount uint64, prevInboxMaxCount *big.Int, - startBlock *types.Block, + startCount arbutil.MessageIndex, startState *validator.ExecutionState, + validatedCount arbutil.MessageIndex, + validatedGS validator.GoGlobalState, lastNodeHashIfExists *common.Hash, ) (nodeAction, error) { if !prevInboxMaxCount.IsUint64() { return nil, fmt.Errorf("inbox max count %v isn't a uint64", prevInboxMaxCount) } - minBatchCount := prevInboxMaxCount.Uint64() - if localBatchCount < minBatchCount { - // not enough batches in database - return nil, nil - } - - if localBatchCount == 0 { - // we haven't validated anything - return nil, nil - } - if startBlock != nil && lastBlockValidated <= startBlock.NumberU64() { + if validatedCount <= startCount { // we haven't validated any new blocks return nil, nil } - var assertionCoversBatch uint64 - var afterGsBatch uint64 - var afterGsPosInBatch uint64 - for i := localBatchCount - 1; i+1 >= minBatchCount && i > 0; i-- { - batchMessageCount, err := v.inboxTracker.GetBatchMessageCount(i) - if err != nil { - return nil, fmt.Errorf("error getting batch %v message count: %w", i, err) - } - prevBatchMessageCount, err := v.inboxTracker.GetBatchMessageCount(i - 1) - if err != nil { - return nil, fmt.Errorf("error getting previous batch %v message count: %w", i-1, err) - } - // Must be non-negative as a batch must contain at least one message - lastBlockNum := uint64(arbutil.MessageCountToBlockNumber(batchMessageCount, v.genesisBlockNumber)) - prevBlockNum := uint64(arbutil.MessageCountToBlockNumber(prevBatchMessageCount, v.genesisBlockNumber)) - if lastBlockValidated > lastBlockNum { - return nil, fmt.Errorf("%v blocks have been validated but only %v appear in the latest batch", lastBlockValidated, lastBlockNum) - } - if lastBlockValidated > prevBlockNum { - // We found the batch containing the last validated block - if i+1 == minBatchCount && lastBlockValidated < lastBlockNum { - // We haven't reached the minimum assertion size yet - break - } - assertionCoversBatch = i - if lastBlockValidated < lastBlockNum { - afterGsBatch = i - afterGsPosInBatch = lastBlockValidated - prevBlockNum - } else { - afterGsBatch = i + 1 - afterGsPosInBatch = 0 - } - break - } - } - if assertionCoversBatch == 0 { - // we haven't validated the next batch completely + if validatedGS.Batch < prevInboxMaxCount.Uint64() { + // didn't validate enough batches + log.Info("staker: not enough batches validated to create new assertion", "validated.Batch", 
validatedGS.Batch, "posInBatch", validatedGS.PosInBatch, "required batch", prevInboxMaxCount) return nil, nil } - validatedBatchAcc, err := v.inboxTracker.GetBatchAcc(assertionCoversBatch) - if err != nil { - return nil, fmt.Errorf("error getting batch %v accumulator: %w", assertionCoversBatch, err) + batchValidated := validatedGS.Batch + if validatedGS.PosInBatch == 0 { + batchValidated-- } - - assertingBlock := v.l2Blockchain.GetBlockByNumber(lastBlockValidated) - if assertingBlock == nil { - return nil, fmt.Errorf("missing validated block %v", lastBlockValidated) + validatedBatchAcc, err := v.inboxTracker.GetBatchAcc(batchValidated) + if err != nil { + return nil, fmt.Errorf("error getting batch %v accumulator: %w", batchValidated, err) } - assertingBlockExtra := types.DeserializeHeaderExtraInformation(assertingBlock.Header()) hasSiblingByte := [1]byte{0} prevNum := stakerInfo.LatestStakedNode @@ -562,21 +497,11 @@ func (v *L1Validator) createNewNodeAction( lastHash = *lastNodeHashIfExists hasSiblingByte[0] = 1 } - var assertionNumBlocks uint64 - if startBlock == nil { - assertionNumBlocks = assertingBlock.NumberU64() + 1 - } else { - assertionNumBlocks = assertingBlock.NumberU64() - startBlock.NumberU64() - } + assertionNumBlocks := uint64(validatedCount - startCount) assertion := &Assertion{ BeforeState: startState, AfterState: &validator.ExecutionState{ - GlobalState: validator.GoGlobalState{ - BlockHash: assertingBlock.Hash(), - SendRoot: assertingBlockExtra.SendRoot, - Batch: afterGsBatch, - PosInBatch: afterGsPosInBatch, - }, + GlobalState: validatedGS, MachineStatus: validator.MachineStatusFinished, }, NumBlocks: assertionNumBlocks, diff --git a/staker/staker.go b/staker/staker.go index baaf81f9a2..09a05daad2 100644 --- a/staker/staker.go +++ b/staker/staker.go @@ -19,9 +19,11 @@ import ( "github.com/ethereum/go-ethereum/metrics" flag "github.com/spf13/pflag" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/cmd/genericconf" "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/offchainlabs/nitro/validator" ) var ( @@ -170,14 +172,17 @@ func L1ValidatorConfigAddOptions(prefix string, f *flag.FlagSet) { } type DangerousConfig struct { - WithoutBlockValidator bool `koanf:"without-block-validator"` + IgnoreRollupWasmModuleRoot bool `koanf:"ignore-rollup-wasm-module-root"` + WithoutBlockValidator bool `koanf:"without-block-validator"` } var DefaultDangerousConfig = DangerousConfig{ - WithoutBlockValidator: false, + IgnoreRollupWasmModuleRoot: false, + WithoutBlockValidator: false, } func DangerousConfigAddOptions(prefix string, f *flag.FlagSet) { + f.Bool(prefix+".ignore-rollup-wasm-module-root", DefaultL1ValidatorConfig.Dangerous.IgnoreRollupWasmModuleRoot, "DANGEROUS! make assertions even when the wasm module root is wrong") f.Bool(prefix+".without-block-validator", DefaultL1ValidatorConfig.Dangerous.WithoutBlockValidator, "DANGEROUS! 
allows running an L1 validator without a block validator") } @@ -186,10 +191,15 @@ type nodeAndHash struct { hash common.Hash } +type LatestStakedNotifier interface { + UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) +} + type Staker struct { *L1Validator stopwaiter.StopWaiter l1Reader L1ReaderInterface + notifiers []LatestStakedNotifier activeChallenge *ChallengeManager baseCallOpts bind.CallOpts config L1ValidatorConfig @@ -199,6 +209,7 @@ type Staker struct { bringActiveUntilNode uint64 inboxReader InboxReaderInterface statelessBlockValidator *StatelessBlockValidator + fatalErr chan<- error } func NewStaker( @@ -208,7 +219,9 @@ func NewStaker( config L1ValidatorConfig, blockValidator *BlockValidator, statelessBlockValidator *StatelessBlockValidator, + notifiers []LatestStakedNotifier, validatorUtilsAddress common.Address, + fatalErr chan<- error, ) (*Staker, error) { if err := config.Validate(); err != nil { @@ -216,20 +229,25 @@ func NewStaker( } client := l1Reader.Client() val, err := NewL1Validator(client, wallet, validatorUtilsAddress, callOpts, - statelessBlockValidator.blockchain, statelessBlockValidator.daService, statelessBlockValidator.inboxTracker, statelessBlockValidator.streamer, blockValidator) + statelessBlockValidator.daService, statelessBlockValidator.inboxTracker, statelessBlockValidator.streamer, blockValidator) if err != nil { return nil, err } stakerLastSuccessfulActionGauge.Update(time.Now().Unix()) + if config.StartFromStaked { + notifiers = append(notifiers, blockValidator) + } return &Staker{ L1Validator: val, l1Reader: l1Reader, + notifiers: notifiers, baseCallOpts: callOpts, config: config, highGasBlocksBuffer: big.NewInt(config.L1PostingStrategy.HighGasDelayBlocks), lastActCalledBlock: nil, inboxReader: statelessBlockValidator.inboxReader, statelessBlockValidator: statelessBlockValidator, + fatalErr: fatalErr, }, nil } @@ -257,9 +275,54 @@ func (s *Staker) Initialize(ctx context.Context) error { return err } - return s.blockValidator.AssumeValid(stakedInfo.AfterState().GlobalState) + return s.blockValidator.InitAssumeValid(stakedInfo.AfterState().GlobalState) + } + return nil +} + +func (s *Staker) checkLatestStaked(ctx context.Context) error { + latestStaked, _, err := s.validatorUtils.LatestStaked(&s.baseCallOpts, s.rollupAddress, s.wallet.AddressOrZero()) + if err != nil { + return fmt.Errorf("couldn't get LatestStaked: %w", err) + } + stakerLatestStakedNodeGauge.Update(int64(latestStaked)) + if latestStaked == 0 { + return nil + } + + stakedInfo, err := s.rollup.LookupNode(ctx, latestStaked) + if err != nil { + return fmt.Errorf("couldn't look up latest node: %w", err) + } + + stakedGlobalState := stakedInfo.AfterState().GlobalState + caughtUp, count, err := GlobalStateToMsgCount(s.inboxTracker, s.txStreamer, stakedGlobalState) + if err != nil { + if errors.Is(err, ErrGlobalStateNotInChain) && s.fatalErr != nil { + fatal := fmt.Errorf("latest staked not in chain: %w", err) + s.fatalErr <- fatal + } + return fmt.Errorf("staker: latest staked %w", err) + } + + if !caughtUp { + log.Info("latest valid not yet in our node", "staked", stakedGlobalState) + return nil } + processedCount, err := s.txStreamer.GetProcessedMessageCount() + if err != nil { + return err + } + + if processedCount < count { + log.Info("execution catching up to last validated", "validatedCount", count, "processedCount", processedCount) + return nil + } + + for _, notifier := range s.notifiers { + notifier.UpdateLatestStaked(count, stakedGlobalState) + } 
return nil } @@ -317,6 +380,13 @@ func (s *Staker) Start(ctxIn context.Context) { } return backoff }) + s.CallIteratively(func(ctx context.Context) time.Duration { + err := s.checkLatestStaked(ctx) + if err != nil && ctx.Err() == nil { + log.Error("staker: error checking latest staked", "err", err) + } + return s.config.StakerInterval + }) } func (s *Staker) IsWhitelisted(ctx context.Context) (bool, error) { @@ -609,8 +679,6 @@ func (s *Staker) handleConflict(ctx context.Context, info *StakerInfo) error { *s.builder.wallet.Address(), s.wallet.ChallengeManagerAddress(), *info.CurrentChallenge, - s.l2Blockchain, - s.inboxTracker, s.statelessBlockValidator, latestConfirmedCreated, s.config.ConfirmationBlocks, @@ -628,7 +696,7 @@ func (s *Staker) handleConflict(ctx context.Context, info *StakerInfo) error { func (s *Staker) advanceStake(ctx context.Context, info *OurStakerInfo, effectiveStrategy StakerStrategy) error { active := effectiveStrategy >= StakeLatestStrategy - action, wrongNodesExist, err := s.generateNodeAction(ctx, info, effectiveStrategy, s.config.MakeAssertionInterval) + action, wrongNodesExist, err := s.generateNodeAction(ctx, info, effectiveStrategy, &s.config) if err != nil { return fmt.Errorf("error generating node action: %w", err) } diff --git a/staker/stateless_block_validator.go b/staker/stateless_block_validator.go index c56eb3e9a7..0242daa3c7 100644 --- a/staker/stateless_block_validator.go +++ b/staker/stateless_block_validator.go @@ -8,22 +8,20 @@ import ( "errors" "fmt" "sync" + "testing" + "github.com/offchainlabs/nitro/arbnode/execution" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/validator/server_api" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/validator" - "github.com/ethereum/go-ethereum/arbitrum" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" - "github.com/offchainlabs/nitro/arbos" - "github.com/offchainlabs/nitro/arbos/arbosState" "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbstate" ) @@ -34,14 +32,13 @@ type StatelessBlockValidator struct { execSpawner validator.ExecutionSpawner validationSpawners []validator.ValidationSpawner - inboxReader InboxReaderInterface - inboxTracker InboxTrackerInterface - streamer TransactionStreamerInterface - blockchain *core.BlockChain - db ethdb.Database - daService arbstate.DataAvailabilityReader - genesisBlockNum uint64 - recordingDatabase *arbitrum.RecordingDatabase + recorder BlockRecorder + + inboxReader InboxReaderInterface + inboxTracker InboxTrackerInterface + streamer TransactionStreamerInterface + db ethdb.Database + daService arbstate.DataAvailabilityReader moduleMutex sync.Mutex currentWasmModuleRoot common.Hash @@ -52,6 +49,16 @@ type BlockValidatorRegistrer interface { SetBlockValidator(*BlockValidator) } +type BlockRecorder interface { + RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg *arbostypes.MessageWithMetadata, + ) (*execution.RecordResult, error) + MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) + PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error +} + type InboxTrackerInterface interface { BlockValidatorRegistrer GetDelayedMessageBytes(uint64) ([]byte, error) @@ -62,8 +69,9 @@ type InboxTrackerInterface interface { type TransactionStreamerInterface 
interface { BlockValidatorRegistrer + GetProcessedMessageCount() (arbutil.MessageIndex, error) GetMessage(seqNum arbutil.MessageIndex) (*arbostypes.MessageWithMetadata, error) - GetGenesisBlockNumber() (uint64, error) + ResultAtCount(count arbutil.MessageIndex) (*execution.MessageResult, error) PauseReorgs() ResumeReorgs() } @@ -83,9 +91,11 @@ type GlobalStatePosition struct { PosInBatch uint64 } -func GlobalStatePositionsFor( +// return the globalState position before and after processing message at the specified count +// batch-number must be provided by caller +func GlobalStatePositionsAtCount( tracker InboxTrackerInterface, - pos arbutil.MessageIndex, + count arbutil.MessageIndex, batch uint64, ) (GlobalStatePosition, GlobalStatePosition, error) { msgCountInBatch, err := tracker.GetBatchMessageCount(batch) @@ -99,17 +109,18 @@ func GlobalStatePositionsFor( return GlobalStatePosition{}, GlobalStatePosition{}, err } } - if msgCountInBatch <= pos { - return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d has up to message %d, failed getting for %d", batch, msgCountInBatch-1, pos) + if msgCountInBatch < count { + return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d has msgCount %d, failed getting for %d", batch, msgCountInBatch-1, count) } - if firstInBatch > pos { - return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d starts from %d, failed getting for %d", batch, firstInBatch, pos) + if firstInBatch >= count { + return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d starts from %d, failed getting for %d", batch, firstInBatch, count) } - startPos := GlobalStatePosition{batch, uint64(pos - firstInBatch)} - if msgCountInBatch == pos+1 { + posInBatch := uint64(count - firstInBatch - 1) + startPos := GlobalStatePosition{batch, posInBatch} + if msgCountInBatch == count { return startPos, GlobalStatePosition{batch + 1, 0}, nil } - return startPos, GlobalStatePosition{batch, uint64(pos + 1 - firstInBatch)}, nil + return startPos, GlobalStatePosition{batch, posInBatch + 1}, nil } func FindBatchContainingMessageIndex( @@ -150,154 +161,96 @@ type ValidationEntryStage uint32 const ( Empty ValidationEntryStage = iota ReadyForRecord - Recorded Ready ) type validationEntry struct { Stage ValidationEntryStage // Valid since ReadyforRecord: - BlockNumber uint64 - PrevBlockHash common.Hash - PrevBlockHeader *types.Header - BlockHash common.Hash - BlockHeader *types.Header - HasDelayedMsg bool - DelayedMsgNr uint64 - msg *arbostypes.MessageWithMetadata - // Valid since Recorded: + Pos arbutil.MessageIndex + Start validator.GoGlobalState + End validator.GoGlobalState + HasDelayedMsg bool + DelayedMsgNr uint64 + // valid when created, removed after recording + msg *arbostypes.MessageWithMetadata + // Has batch when created - others could be added on record + BatchInfo []validator.BatchInfo + // Valid since Ready Preimages map[common.Hash][]byte - BatchInfo []validator.BatchInfo DelayedMsg []byte - // Valid since Ready: - StartPosition GlobalStatePosition - EndPosition GlobalStatePosition -} - -func (v *validationEntry) start() (validator.GoGlobalState, error) { - start := v.StartPosition - globalState := validator.GoGlobalState{ - Batch: start.BatchNumber, - PosInBatch: start.PosInBatch, - BlockHash: v.PrevBlockHash, - } - if v.PrevBlockHeader != nil { - prevExtraInfo := types.DeserializeHeaderExtraInformation(v.PrevBlockHeader) - globalState.SendRoot = prevExtraInfo.SendRoot - } - return globalState, nil -} - -func (v 
*validationEntry) expectedEnd() (validator.GoGlobalState, error) { - extraInfo := types.DeserializeHeaderExtraInformation(v.BlockHeader) - end := v.EndPosition - return validator.GoGlobalState{ - Batch: end.BatchNumber, - PosInBatch: end.PosInBatch, - BlockHash: v.BlockHash, - SendRoot: extraInfo.SendRoot, - }, nil } func (e *validationEntry) ToInput() (*validator.ValidationInput, error) { if e.Stage != Ready { return nil, errors.New("cannot create input from non-ready entry") } - startState, err := e.start() - if err != nil { - return nil, err - } return &validator.ValidationInput{ - Id: e.BlockNumber, + Id: uint64(e.Pos), HasDelayedMsg: e.HasDelayedMsg, DelayedMsgNr: e.DelayedMsgNr, Preimages: e.Preimages, BatchInfo: e.BatchInfo, DelayedMsg: e.DelayedMsg, - StartState: startState, + StartState: e.Start, }, nil } -func usingDelayedMsg(prevHeader *types.Header, header *types.Header) (bool, uint64) { - if prevHeader == nil { - return true, 0 - } - if header.Nonce == prevHeader.Nonce { - return false, 0 - } - return true, prevHeader.Nonce.Uint64() -} - func newValidationEntry( - prevHeader *types.Header, - header *types.Header, + pos arbutil.MessageIndex, + start validator.GoGlobalState, + end validator.GoGlobalState, msg *arbostypes.MessageWithMetadata, + batch []byte, + prevDelayed uint64, ) (*validationEntry, error) { - hasDelayedMsg, delayedMsgNr := usingDelayedMsg(prevHeader, header) - validationEntry := &validationEntry{ + batchInfo := validator.BatchInfo{ + Number: start.Batch, + Data: batch, + } + hasDelayed := false + var delayedNum uint64 + if msg.DelayedMessagesRead == prevDelayed+1 { + hasDelayed = true + delayedNum = prevDelayed + } else if msg.DelayedMessagesRead != prevDelayed { + return nil, fmt.Errorf("illegal validation entry delayedMessage %d, previous %d", msg.DelayedMessagesRead, prevDelayed) + } + return &validationEntry{ Stage: ReadyForRecord, - BlockNumber: header.Number.Uint64(), - BlockHash: header.Hash(), - BlockHeader: header, - HasDelayedMsg: hasDelayedMsg, - DelayedMsgNr: delayedMsgNr, + Pos: pos, + Start: start, + End: end, + HasDelayedMsg: hasDelayed, + DelayedMsgNr: delayedNum, msg: msg, - } - if prevHeader != nil { - validationEntry.PrevBlockHash = prevHeader.Hash() - validationEntry.PrevBlockHeader = prevHeader - } - return validationEntry, nil -} - -func newRecordedValidationEntry( - prevHeader *types.Header, - header *types.Header, - preimages map[common.Hash][]byte, - batchInfos []validator.BatchInfo, - delayedMsg []byte, -) (*validationEntry, error) { - entry, err := newValidationEntry(prevHeader, header, nil) - if err != nil { - return nil, err - } - entry.Preimages = preimages - entry.BatchInfo = batchInfos - entry.DelayedMsg = delayedMsg - entry.Stage = Recorded - return entry, nil + BatchInfo: []validator.BatchInfo{batchInfo}, + }, nil } func NewStatelessBlockValidator( inboxReader InboxReaderInterface, inbox InboxTrackerInterface, streamer TransactionStreamerInterface, - blockchain *core.BlockChain, - blockchainDb ethdb.Database, + recorder BlockRecorder, arbdb ethdb.Database, das arbstate.DataAvailabilityReader, config func() *BlockValidatorConfig, stack *node.Node, ) (*StatelessBlockValidator, error) { - genesisBlockNum, err := streamer.GetGenesisBlockNumber() - if err != nil { - return nil, err - } valConfFetcher := func() *rpcclient.ClientConfig { return &config().ValidationServer } valClient := server_api.NewValidationClient(valConfFetcher, stack) execClient := server_api.NewExecutionClient(valConfFetcher, stack) validator := 
&StatelessBlockValidator{ config: config(), execSpawner: execClient, + recorder: recorder, validationSpawners: []validator.ValidationSpawner{valClient}, inboxReader: inboxReader, inboxTracker: inbox, streamer: streamer, - blockchain: blockchain, db: arbdb, daService: das, - genesisBlockNum: genesisBlockNum, - recordingDatabase: arbitrum.NewRecordingDatabase(blockchainDb, blockchain), } return validator, nil } @@ -313,158 +266,38 @@ func (v *StatelessBlockValidator) GetModuleRootsToValidate() []common.Hash { return validatingModuleRoots } -func stateLogFunc(targetHeader, header *types.Header, hasState bool) { - if targetHeader == nil || header == nil { - return - } - gap := targetHeader.Number.Int64() - header.Number.Int64() - step := int64(500) - stage := "computing state" - if !hasState { - step = 3000 - stage = "looking for full block" - } - if (gap >= step) && (gap%step == 0) { - log.Info("Setting up validation", "stage", stage, "current", header.Number, "target", targetHeader.Number) - } -} - -// If msg is nil, this will record block creation up to the point where message would be accessed (for a "too far" proof) -// If keepreference == true, reference to state of prevHeader is added (no reference added if an error is returned) -func (v *StatelessBlockValidator) RecordBlockCreation( - ctx context.Context, - prevHeader *types.Header, - msg *arbostypes.MessageWithMetadata, - keepReference bool, -) (common.Hash, map[common.Hash][]byte, []validator.BatchInfo, error) { - - recordingdb, chaincontext, recordingKV, err := v.recordingDatabase.PrepareRecording(ctx, prevHeader, stateLogFunc) - if err != nil { - return common.Hash{}, nil, nil, err +func (v *StatelessBlockValidator) ValidationEntryRecord(ctx context.Context, e *validationEntry) error { + if e.Stage != ReadyForRecord { + return fmt.Errorf("validation entry should be ReadyForRecord, is: %v", e.Stage) } - defer func() { v.recordingDatabase.Dereference(prevHeader) }() - - chainConfig := v.blockchain.Config() - - // Get the chain ID, both to validate and because the replay binary also gets the chain ID, - // so we need to populate the recordingdb with preimages for retrieving the chain ID. 
- if prevHeader != nil { - initialArbosState, err := arbosState.OpenSystemArbosState(recordingdb, nil, true) - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error opening initial ArbOS state: %w", err) - } - chainId, err := initialArbosState.ChainId() - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting chain ID from initial ArbOS state: %w", err) - } - if chainId.Cmp(chainConfig.ChainID) != 0 { - return common.Hash{}, nil, nil, fmt.Errorf("unexpected chain ID %v in ArbOS state, expected %v", chainId, chainConfig.ChainID) - } - _, err = initialArbosState.ChainConfig() - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting chain config from initial ArbOS state: %w", err) - } - genesisNum, err := initialArbosState.GenesisBlockNum() + if e.Pos != 0 { + recording, err := v.recorder.RecordBlockCreation(ctx, e.Pos, e.msg) if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting genesis block number from initial ArbOS state: %w", err) + return err } - expectedNum := chainConfig.ArbitrumChainParams.GenesisBlockNum - if genesisNum != expectedNum { - return common.Hash{}, nil, nil, fmt.Errorf("unexpected genesis block number %v in ArbOS state, expected %v", genesisNum, expectedNum) + if recording.BlockHash != e.End.BlockHash { + return fmt.Errorf("recording failed: pos %d, hash expected %v, got %v", e.Pos, e.End.BlockHash, recording.BlockHash) } - } + e.BatchInfo = append(e.BatchInfo, recording.BatchInfo...) - var blockHash common.Hash - var readBatchInfo []validator.BatchInfo - if msg != nil { - batchFetcher := func(batchNum uint64) ([]byte, error) { - data, err := v.inboxReader.GetSequencerMessageBytes(ctx, batchNum) - if err != nil { - return nil, err - } - readBatchInfo = append(readBatchInfo, validator.BatchInfo{ - Number: batchNum, - Data: data, - }) - return data, nil + if recording.Preimages != nil { + e.Preimages = recording.Preimages } - // Re-fetch the batch instead of using our cached cost, - // as the replay binary won't have the cache populated. 
- msg.Message.BatchGasCost = nil - block, _, err := arbos.ProduceBlock( - msg.Message, - msg.DelayedMessagesRead, - prevHeader, - recordingdb, - chaincontext, - chainConfig, - batchFetcher, - ) - if err != nil { - return common.Hash{}, nil, nil, err - } - blockHash = block.Hash() - } - - preimages, err := v.recordingDatabase.PreimagesFromRecording(chaincontext, recordingKV) - if err != nil { - return common.Hash{}, nil, nil, err - } - if keepReference { - prevHeader = nil - } - return blockHash, preimages, readBatchInfo, err -} - -func (v *StatelessBlockValidator) ValidationEntryRecord(ctx context.Context, e *validationEntry, keepReference bool) error { - if e.Stage != ReadyForRecord { - return fmt.Errorf("validation entry should be ReadyForRecord, is: %v", e.Stage) - } - if e.PrevBlockHeader == nil { - e.Stage = Recorded - return nil - } - blockhash, preimages, readBatchInfo, err := v.RecordBlockCreation(ctx, e.PrevBlockHeader, e.msg, keepReference) - if err != nil { - return err - } - if blockhash != e.BlockHash { - return fmt.Errorf("recording failed: blockNum %d, hash expected %v, got %v", e.BlockNumber, e.BlockHash, blockhash) } if e.HasDelayedMsg { delayedMsg, err := v.inboxTracker.GetDelayedMessageBytes(e.DelayedMsgNr) if err != nil { log.Error( "error while trying to read delayed msg for proving", - "err", err, "seq", e.DelayedMsgNr, "blockNr", e.BlockNumber, + "err", err, "seq", e.DelayedMsgNr, "pos", e.Pos, ) return fmt.Errorf("error while trying to read delayed msg for proving: %w", err) } e.DelayedMsg = delayedMsg } - e.Preimages = preimages - e.BatchInfo = readBatchInfo - e.msg = nil // no longer needed - e.Stage = Recorded - return nil -} - -func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx context.Context, e *validationEntry, - startPos, endPos GlobalStatePosition, seqMsg []byte) error { - if e.Stage != Recorded { - return fmt.Errorf("validation entry stage should be Recorded, is: %v", e.Stage) - } if e.Preimages == nil { e.Preimages = make(map[common.Hash][]byte) } - e.StartPosition = startPos - e.EndPosition = endPos - seqMsgBatchInfo := validator.BatchInfo{ - Number: startPos.BatchNumber, - Data: seqMsg, - } - e.BatchInfo = append(e.BatchInfo, seqMsgBatchInfo) - for _, batch := range e.BatchInfo { if len(batch.Data) <= 40 { continue @@ -473,10 +306,7 @@ func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx context.Conte continue } if v.daService == nil { - log.Error("No DAS configured, but sequencer message found with DAS header") - if v.blockchain.Config().ArbitrumChainParams.DataAvailabilityCommittee { - return errors.New("processing data availability chain without DAS configured") - } + log.Warn("No DAS configured, but sequencer message found with DAS header") } else { _, err := arbstate.RecoverPayloadFromDasBatch( ctx, batch.Number, batch.Data, v.daService, e.Preimages, arbstate.KeysetValidate, @@ -486,70 +316,76 @@ func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx context.Conte } } } + + e.msg = nil // no longer needed e.Stage = Ready return nil } -func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context, header *types.Header) (*validationEntry, error) { - if header == nil { - return nil, errors.New("header not found") +func buildGlobalState(res execution.MessageResult, pos GlobalStatePosition) validator.GoGlobalState { + return validator.GoGlobalState{ + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, + Batch: pos.BatchNumber, + PosInBatch: pos.PosInBatch, } - blockNum := 
header.Number.Uint64() - msgIndex := arbutil.BlockNumberToMessageCount(blockNum, v.genesisBlockNum) - 1 - prevHeader := v.blockchain.GetHeader(header.ParentHash, blockNum-1) - if prevHeader == nil && blockNum > 0 { - return nil, fmt.Errorf("prev header not found for block number %v with hash %s and parent hash %s", blockNum, header.Hash(), header.ParentHash) +} + +// return the globalState position before and after processing message at the specified count +func (v *StatelessBlockValidator) GlobalStatePositionsAtCount(count arbutil.MessageIndex) (GlobalStatePosition, GlobalStatePosition, error) { + if count == 0 { + return GlobalStatePosition{}, GlobalStatePosition{}, errors.New("no initial state for count==0") + } + if count == 1 { + return GlobalStatePosition{}, GlobalStatePosition{1, 0}, nil } - msg, err := v.streamer.GetMessage(msgIndex) + batchCount, err := v.inboxTracker.GetBatchCount() if err != nil { - return nil, err + return GlobalStatePosition{}, GlobalStatePosition{}, err } - var preimages map[common.Hash][]byte - var readBatchInfo []validator.BatchInfo - if prevHeader != nil { - var resHash common.Hash - var err error - resHash, preimages, readBatchInfo, err = v.RecordBlockCreation(ctx, prevHeader, msg, false) - if err != nil { - return nil, fmt.Errorf("failed to get block data to validate: %w", err) - } - if resHash != header.Hash() { - return nil, fmt.Errorf("wrong hash expected %s got %s", header.Hash(), resHash) - } + batch, err := FindBatchContainingMessageIndex(v.inboxTracker, count-1, batchCount) + if err != nil { + return GlobalStatePosition{}, GlobalStatePosition{}, err } + return GlobalStatePositionsAtCount(v.inboxTracker, count, batch) +} - batchCount, err := v.inboxTracker.GetBatchCount() +func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context, pos arbutil.MessageIndex) (*validationEntry, error) { + msg, err := v.streamer.GetMessage(pos) if err != nil { return nil, err } - batch, err := FindBatchContainingMessageIndex(v.inboxTracker, msgIndex, batchCount) + result, err := v.streamer.ResultAtCount(pos + 1) if err != nil { return nil, err } - - startPos, endPos, err := GlobalStatePositionsFor(v.inboxTracker, msgIndex, batch) - if err != nil { - return nil, fmt.Errorf("failed calculating position for validation: %w", err) - } - - usingDelayed, delaydNr := usingDelayedMsg(prevHeader, header) - var delayed []byte - if usingDelayed { - delayed, err = v.inboxTracker.GetDelayedMessageBytes(delaydNr) + var prevDelayed uint64 + if pos > 0 { + prev, err := v.streamer.GetMessage(pos - 1) if err != nil { - return nil, fmt.Errorf("error while trying to read delayed msg for proving: %w", err) + return nil, err } + prevDelayed = prev.DelayedMessagesRead } - entry, err := newRecordedValidationEntry(prevHeader, header, preimages, readBatchInfo, delayed) + prevResult, err := v.streamer.ResultAtCount(pos) if err != nil { - return nil, fmt.Errorf("failed to create validation entry %w", err) + return nil, err } - + startPos, endPos, err := v.GlobalStatePositionsAtCount(pos + 1) + if err != nil { + return nil, fmt.Errorf("failed calculating position for validation: %w", err) + } + start := buildGlobalState(*prevResult, startPos) + end := buildGlobalState(*result, endPos) seqMsg, err := v.inboxReader.GetSequencerMessageBytes(ctx, startPos.BatchNumber) if err != nil { return nil, err } - err = v.ValidationEntryAddSeqMessage(ctx, entry, startPos, endPos, seqMsg) + entry, err := newValidationEntry(pos, start, end, msg, seqMsg, prevDelayed) + if err != nil { + return 
nil, err + } + err = v.ValidationEntryRecord(ctx, entry) if err != nil { return nil, err } @@ -557,20 +393,16 @@ func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context return entry, nil } -func (v *StatelessBlockValidator) ValidateBlock( - ctx context.Context, header *types.Header, useExec bool, moduleRoot common.Hash, -) (bool, error) { - entry, err := v.CreateReadyValidationEntry(ctx, header) - if err != nil { - return false, err - } - expEnd, err := entry.expectedEnd() +func (v *StatelessBlockValidator) ValidateResult( + ctx context.Context, pos arbutil.MessageIndex, useExec bool, moduleRoot common.Hash, +) (bool, *validator.GoGlobalState, error) { + entry, err := v.CreateReadyValidationEntry(ctx, pos) if err != nil { - return false, err + return false, nil, err } input, err := entry.ToInput() if err != nil { - return false, err + return false, nil, err } var spawners []validator.ValidationSpawner if useExec { @@ -579,7 +411,7 @@ func (v *StatelessBlockValidator) ValidateBlock( spawners = v.validationSpawners } if len(spawners) == 0 { - return false, errors.New("no validation defined") + return false, &entry.End, errors.New("no validation defined") } var runs []validator.ValidationRun for _, spawner := range spawners { @@ -593,15 +425,15 @@ func (v *StatelessBlockValidator) ValidateBlock( }() for _, run := range runs { gsEnd, err := run.Await(ctx) - if err != nil || gsEnd != expEnd { - return false, err + if err != nil || gsEnd != entry.End { + return false, &gsEnd, err } } - return true, nil + return true, &entry.End, nil } -func (v *StatelessBlockValidator) RecordDBReferenceCount() int64 { - return v.recordingDatabase.ReferenceCount() +func (v *StatelessBlockValidator) OverrideRecorder(t *testing.T, recorder BlockRecorder) { + v.recorder = recorder } func (v *StatelessBlockValidator) Start(ctx_in context.Context) error { diff --git a/system_tests/batch_poster_test.go b/system_tests/batch_poster_test.go index 8c656cb2d3..6f6c041c41 100644 --- a/system_tests/batch_poster_test.go +++ b/system_tests/batch_poster_test.go @@ -82,7 +82,7 @@ func testBatchPosterParallel(t *testing.T, useRedis bool) { for i := 0; i < parallelBatchPosters; i++ { // Make a copy of the batch poster config so NewBatchPoster calling Validate() on it doesn't race batchPosterConfig := conf.BatchPoster - batchPoster, err := arbnode.NewBatchPoster(nodeA.L1Reader, nodeA.InboxTracker, nodeA.TxStreamer, nodeA.SyncMonitor, func() *arbnode.BatchPosterConfig { return &batchPosterConfig }, nodeA.DeployInfo, &seqTxOpts, nil) + batchPoster, err := arbnode.NewBatchPoster(nil, nodeA.L1Reader, nodeA.InboxTracker, nodeA.TxStreamer, nodeA.SyncMonitor, func() *arbnode.BatchPosterConfig { return &batchPosterConfig }, nodeA.DeployInfo, &seqTxOpts, nil) Require(t, err) batchPoster.Start(ctx) defer batchPoster.StopAndWait() diff --git a/system_tests/block_validator_test.go b/system_tests/block_validator_test.go index b3fd8ddb6c..7fe1a65969 100644 --- a/system_tests/block_validator_test.go +++ b/system_tests/block_validator_test.go @@ -20,6 +20,7 @@ import ( "github.com/offchainlabs/nitro/arbnode" "github.com/offchainlabs/nitro/arbos/l2pricing" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/solgen/go/precompilesgen" ) @@ -43,6 +44,12 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops chainConfig.ArbitrumChainParams.InitialArbOSVersion = 10 } + var delayEvery int + if workloadLoops > 1 { + l1NodeConfigA.BatchPoster.MaxBatchPostDelay = time.Millisecond * 500 
+ delayEvery = workloadLoops / 3 + } + l2info, nodeA, l2client, l1info, _, l1client, l1stack := createTestNodeOnL1WithConfig(t, ctx, true, l1NodeConfigA, chainConfig, nil) defer requireClose(t, l1stack) defer nodeA.StopAndWait() @@ -105,6 +112,9 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops if workload != depleteGas { Require(t, err) } + if delayEvery > 0 && i%delayEvery == (delayEvery-1) { + <-time.After(time.Second) + } } } else { auth := l2info.GetDefaultTransactOpts("Owner", ctx) @@ -176,10 +186,12 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops } t.Log("waiting for block: ", lastBlock.NumberU64()) timeout := getDeadlineTimeout(t, time.Minute*10) - if !nodeB.BlockValidator.WaitForBlock(ctx, lastBlock.NumberU64(), timeout) { + // messageindex is same as block number here + if !nodeB.BlockValidator.WaitForPos(t, ctx, arbutil.MessageIndex(lastBlock.NumberU64()), timeout) { Fatal(t, "did not validate all blocks") } - finalRefCount := nodeB.BlockValidator.RecordDBReferenceCount() + nodeB.Execution.Recorder.TrimAllPrepared(t) + finalRefCount := nodeB.Execution.Recorder.RecordingDBReferenceCount() lastBlockNow, err := l2clientB.BlockByNumber(ctx, nil) Require(t, err) // up to 3 extra references: awaiting validation, recently valid, lastValidatedHeader diff --git a/system_tests/common_test.go b/system_tests/common_test.go index 7c9b6b1191..5be818b0a1 100644 --- a/system_tests/common_test.go +++ b/system_tests/common_test.go @@ -15,9 +15,6 @@ import ( "time" "github.com/offchainlabs/nitro/arbnode/execution" - "github.com/offchainlabs/nitro/validator/server_api" - "github.com/offchainlabs/nitro/validator/valnode" - "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbos/util" "github.com/offchainlabs/nitro/arbstate" @@ -28,7 +25,9 @@ import ( "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/headerreader" "github.com/offchainlabs/nitro/util/signature" + "github.com/offchainlabs/nitro/validator/server_api" "github.com/offchainlabs/nitro/validator/server_common" + "github.com/offchainlabs/nitro/validator/valnode" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/keystore" diff --git a/system_tests/forwarder_test.go b/system_tests/forwarder_test.go index 3691caf5d2..d4cf0d8eb7 100644 --- a/system_tests/forwarder_test.go +++ b/system_tests/forwarder_test.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "math/big" + "os" "path/filepath" "strings" "sync" @@ -157,7 +158,18 @@ func createSequencer( // tmpPath returns file path with specified filename from temporary directory of the test. func tmpPath(t *testing.T, filename string) string { - return filepath.Join(t.TempDir(), filename) + t.Helper() + // create a unique, maximum 10 characters-long temporary directory {name} with path as $TMPDIR/{name} + tmpDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + t.Cleanup(func() { + if err = os.RemoveAll(tmpDir); err != nil { + t.Errorf("Failed to cleanup temp dir: %v", err) + } + }) + return filepath.Join(tmpDir, filename) } // testNodes creates specified number of paths for ipc from temporary directory of the test. 
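A minimal usage sketch for the reworked tmpPath helper above (the test name and filename here are hypothetical): callers keep the same call shape, but the returned path now sits in a fresh os.MkdirTemp directory directly under $TMPDIR that is removed via t.Cleanup, presumably so that ipc socket paths stay short.

package arbtest

import "testing"

// Hypothetical example test, not part of this change: tmpPath still returns
// <dir>/<filename>, but <dir> is now a short temporary directory created with
// os.MkdirTemp and removed by t.Cleanup when the test finishes.
func TestTmpPathUsageExample(t *testing.T) {
	ipcPath := tmpPath(t, "test.ipc") // e.g. $TMPDIR/<random>/test.ipc
	t.Log("using ipc path:", ipcPath)
}
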
diff --git a/system_tests/full_challenge_impl_test.go b/system_tests/full_challenge_impl_test.go index 26e2d4a64e..de2ee5cd34 100644 --- a/system_tests/full_challenge_impl_test.go +++ b/system_tests/full_challenge_impl_test.go @@ -1,6 +1,10 @@ // Copyright 2021-2022, Offchain Labs, Inc. // For license information, see https://github.com/nitro/blob/master/LICENSE +// race detection makes things slow and miss timeouts +//go:build !race +// +build !race + package arbtest import ( @@ -18,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" @@ -137,13 +142,15 @@ func writeTxToBatch(writer io.Writer, tx *types.Transaction) error { return err } -func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, backend *ethclient.Client, sequencer *bind.TransactOpts, seqInbox *mocksgen.SequencerInboxStub, seqInboxAddr common.Address, isChallenger bool) { +const makeBatch_MsgsPerBatch = int64(5) + +func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, backend *ethclient.Client, sequencer *bind.TransactOpts, seqInbox *mocksgen.SequencerInboxStub, seqInboxAddr common.Address, modStep int64) { ctx := context.Background() batchBuffer := bytes.NewBuffer([]byte{}) - for i := int64(0); i < 10; i++ { + for i := int64(0); i < makeBatch_MsgsPerBatch; i++ { value := i - if i == 5 && isChallenger { + if i == modStep { value++ } err := writeTxToBatch(batchBuffer, l2Info.PrepareTx("Owner", "Destination", 1000000, big.NewInt(value), []byte{})) @@ -153,9 +160,9 @@ func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, b Require(t, err) message := append([]byte{0}, compressed...) 
- maxUint256 := new(big.Int).Lsh(common.Big1, 256) - maxUint256.Sub(maxUint256, common.Big1) - tx, err := seqInbox.AddSequencerL2BatchFromOrigin0(sequencer, maxUint256, message, big.NewInt(1), common.Address{}, big.NewInt(0), big.NewInt(0)) + seqNum := new(big.Int).Lsh(common.Big1, 256) + seqNum.Sub(seqNum, common.Big1) + tx, err := seqInbox.AddSequencerL2BatchFromOrigin0(sequencer, seqNum, message, big.NewInt(1), common.Address{}, big.NewInt(0), big.NewInt(0)) Require(t, err) receipt, err := EnsureTxSucceeded(ctx, backend, tx) Require(t, err) @@ -218,8 +225,7 @@ func setupSequencerInboxStub(ctx context.Context, t *testing.T, l1Info *Blockcha return bridgeAddr, seqInbox, seqInboxAddr } -func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { - t.Parallel() +func RunChallengeTest(t *testing.T, asserterIsCorrect bool, useStubs bool, challengeMsgIdx int64) { glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.TerminalFormat(false))) glogger.Verbosity(log.LvlInfo) log.Root().SetHandler(glogger) @@ -241,7 +247,13 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { conf.BatchPoster.Enable = false conf.InboxReader.CheckDelay = time.Second - _, valStack := createTestValidationNode(t, ctx, &valnode.TestValidationConfig) + var valStack *node.Node + var mockSpawn *mockSpawner + if useStubs { + mockSpawn, valStack = createMockValidationNode(t, ctx, &valnode.TestValidationConfig.Arbitrator) + } else { + _, valStack = createTestValidationNode(t, ctx, &valnode.TestValidationConfig) + } configByValidationNode(t, conf, valStack) fatalErrChan := make(chan error, 10) @@ -274,8 +286,22 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { asserterL2Info.GenerateAccount("Destination") challengerL2Info.SetFullAccountInfo("Destination", asserterL2Info.GetInfoWithPrivKey("Destination")) - makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, false) - makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, true) + + if challengeMsgIdx < 1 || challengeMsgIdx > 3*makeBatch_MsgsPerBatch { + Fatal(t, "challengeMsgIdx illegal") + } + + // seqNum := common.Big2 + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-1) + + // seqNum.Add(seqNum, common.Big1) + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-makeBatch_MsgsPerBatch-1) + + // seqNum.Add(seqNum, common.Big1) + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-makeBatch_MsgsPerBatch*2-1) trueSeqInboxAddr := challengerSeqInboxAddr trueDelayedBridge := challengerBridgeAddr @@ -291,9 +317,14 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { if err != nil { Fatal(t, err) } - wasmModuleRoot := locator.LatestWasmModuleRoot() - if (wasmModuleRoot == common.Hash{}) { - Fatal(t, "latest machine not found") + var wasmModuleRoot common.Hash + if useStubs { + wasmModuleRoot = mockWasmModuleRoot + } else { + 
wasmModuleRoot = locator.LatestWasmModuleRoot() + if (wasmModuleRoot == common.Hash{}) { + Fatal(t, "latest machine not found") + } } asserterGenesis := asserterL2.Execution.ArbInterface.BlockChain().Genesis() @@ -314,7 +345,7 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { } asserterEndGlobalState := validator.GoGlobalState{ BlockHash: asserterLatestBlock.Hash(), - Batch: 2, + Batch: 4, PosInBatch: 0, } numBlocks := asserterLatestBlock.Number.Uint64() - asserterGenesis.NumberU64() @@ -337,29 +368,37 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { confirmLatestBlock(ctx, t, l1Info, l1Backend) - asserterValidator, err := staker.NewStatelessBlockValidator(asserterL2.InboxReader, asserterL2.InboxTracker, asserterL2.TxStreamer, asserterL2Blockchain, asserterL2ChainDb, asserterL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) + asserterValidator, err := staker.NewStatelessBlockValidator(asserterL2.InboxReader, asserterL2.InboxTracker, asserterL2.TxStreamer, asserterL2.Execution.Recorder, asserterL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) if err != nil { Fatal(t, err) } + if useStubs { + asserterRecorder := newMockRecorder(asserterValidator, asserterL2.TxStreamer) + asserterValidator.OverrideRecorder(t, asserterRecorder) + } err = asserterValidator.Start(ctx) if err != nil { Fatal(t, err) } defer asserterValidator.Stop() - asserterManager, err := staker.NewChallengeManager(ctx, l1Backend, &asserterTxOpts, asserterTxOpts.From, challengeManagerAddr, 1, asserterL2Blockchain, asserterL2.InboxTracker, asserterValidator, 0, 0) + asserterManager, err := staker.NewChallengeManager(ctx, l1Backend, &asserterTxOpts, asserterTxOpts.From, challengeManagerAddr, 1, asserterValidator, 0, 0) if err != nil { Fatal(t, err) } - challengerValidator, err := staker.NewStatelessBlockValidator(challengerL2.InboxReader, challengerL2.InboxTracker, challengerL2.TxStreamer, challengerL2Blockchain, challengerL2ChainDb, challengerL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) + challengerValidator, err := staker.NewStatelessBlockValidator(challengerL2.InboxReader, challengerL2.InboxTracker, challengerL2.TxStreamer, challengerL2.Execution.Recorder, challengerL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) if err != nil { Fatal(t, err) } + if useStubs { + challengerRecorder := newMockRecorder(challengerValidator, challengerL2.TxStreamer) + challengerValidator.OverrideRecorder(t, challengerRecorder) + } err = challengerValidator.Start(ctx) if err != nil { Fatal(t, err) } defer challengerValidator.Stop() - challengerManager, err := staker.NewChallengeManager(ctx, l1Backend, &challengerTxOpts, challengerTxOpts.From, challengeManagerAddr, 1, challengerL2Blockchain, challengerL2.InboxTracker, challengerValidator, 0, 0) + challengerManager, err := staker.NewChallengeManager(ctx, l1Backend, &challengerTxOpts, challengerTxOpts.From, challengeManagerAddr, 1, challengerValidator, 0, 0) if err != nil { Fatal(t, err) } @@ -394,6 +433,19 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { if tx == nil { Fatal(t, "no move") } + + if useStubs { + if len(mockSpawn.ExecSpawned) != 0 { + if len(mockSpawn.ExecSpawned) != 1 { + Fatal(t, "bad number of spawned execRuns: ", len(mockSpawn.ExecSpawned)) + } + if mockSpawn.ExecSpawned[0] != uint64(challengeMsgIdx) { + Fatal(t, "wrong spawned execRuns: ", mockSpawn.ExecSpawned[0], " expected: ", challengeMsgIdx) + } + return + } + } + _, err = EnsureTxSucceeded(ctx, 
l1Backend, tx) if err != nil { if !currentCorrect && strings.Contains(err.Error(), "BAD_SEQINBOX_MESSAGE") { @@ -419,3 +471,17 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { Fatal(t, "challenge timed out without winner") } + +func TestMockChallengeManagerAsserterIncorrect(t *testing.T) { + t.Parallel() + for i := int64(1); i <= makeBatch_MsgsPerBatch*3; i++ { + RunChallengeTest(t, false, true, i) + } +} + +func TestMockChallengeManagerAsserterCorrect(t *testing.T) { + t.Parallel() + for i := int64(1); i <= makeBatch_MsgsPerBatch*3; i++ { + RunChallengeTest(t, true, true, i) + } +} diff --git a/system_tests/full_challenge_test.go b/system_tests/full_challenge_test.go index 367ac33464..a960e7f640 100644 --- a/system_tests/full_challenge_test.go +++ b/system_tests/full_challenge_test.go @@ -15,9 +15,11 @@ import ( ) func TestChallengeManagerFullAsserterIncorrect(t *testing.T) { - RunChallengeTest(t, false) + t.Parallel() + RunChallengeTest(t, false, false, makeBatch_MsgsPerBatch+1) } func TestChallengeManagerFullAsserterCorrect(t *testing.T) { - RunChallengeTest(t, true) + t.Parallel() + RunChallengeTest(t, true, false, makeBatch_MsgsPerBatch+2) } diff --git a/system_tests/retryable_test.go b/system_tests/retryable_test.go index 7b0c3a7563..5e4bca1a64 100644 --- a/system_tests/retryable_test.go +++ b/system_tests/retryable_test.go @@ -6,9 +6,11 @@ package arbtest import ( "context" "math/big" + "strings" "testing" "time" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -58,13 +60,14 @@ func retryableSetup(t *testing.T) ( } var submissionTxs []*types.Transaction for _, message := range messages { - if message.Message.Header.Kind != arbostypes.L1MessageType_SubmitRetryable { + k := message.Message.Header.Kind + if k != arbostypes.L1MessageType_SubmitRetryable && k != arbostypes.L1MessageType_EthDeposit { continue } txs, err := arbos.ParseL2Transactions(message.Message, params.ArbitrumDevTestChainConfig().ChainID, nil) Require(t, err) for _, tx := range txs { - if tx.Type() == types.ArbitrumSubmitRetryableTxType { + if tx.Type() == types.ArbitrumSubmitRetryableTxType || tx.Type() == types.ArbitrumDepositTxType { submissionTxs = append(submissionTxs, tx) } } @@ -395,3 +398,121 @@ func waitForL1DelayBlocks(t *testing.T, ctx context.Context, l1client *ethclient }) } } + +func TestDepositETH(t *testing.T) { + t.Parallel() + _, l1info, l2client, l1client, delayedInbox, lookupSubmitRetryableL2TxHash, ctx, teardown := retryableSetup(t) + defer teardown() + + faucetAddr := l1info.GetAddress("Faucet") + + oldBalance, err := l2client.BalanceAt(ctx, faucetAddr, nil) + if err != nil { + t.Fatalf("BalanceAt(%v) unexpected error: %v", faucetAddr, err) + } + + txOpts := l1info.GetDefaultTransactOpts("Faucet", ctx) + txOpts.Value = big.NewInt(13) + + l1tx, err := delayedInbox.DepositEth0(&txOpts) + if err != nil { + t.Fatalf("DepositEth0() unexected error: %v", err) + } + + l1Receipt, err := EnsureTxSucceeded(ctx, l1client, l1tx) + if err != nil { + t.Fatalf("EnsureTxSucceeded() unexpected error: %v", err) + } + if l1Receipt.Status != types.ReceiptStatusSuccessful { + t.Errorf("Got transaction status: %v, want: %v", l1Receipt.Status, types.ReceiptStatusSuccessful) + } + waitForL1DelayBlocks(t, ctx, l1client, l1info) + + txHash := lookupSubmitRetryableL2TxHash(l1Receipt) + l2Receipt, err := WaitForTx(ctx, l2client, txHash, time.Second*5) + if err != nil { 
+ t.Fatalf("WaitForTx(%v) unexpected error: %v", txHash, err) + } + if l2Receipt.Status != types.ReceiptStatusSuccessful { + t.Errorf("Got transaction status: %v, want: %v", l2Receipt.Status, types.ReceiptStatusSuccessful) + } + newBalance, err := l2client.BalanceAt(ctx, faucetAddr, l2Receipt.BlockNumber) + if err != nil { + t.Fatalf("BalanceAt(%v) unexpected error: %v", faucetAddr, err) + } + if got := new(big.Int); got.Sub(newBalance, oldBalance).Cmp(txOpts.Value) != 0 { + t.Errorf("Got transferred: %v, want: %v", got, txOpts.Value) + } +} + +func TestL1FundedUnsignedTransaction(t *testing.T) { + t.Parallel() + ctx := context.Background() + l2Info, node, l2Client, l1Info, _, l1Client, l1Stack := createTestNodeOnL1(t, ctx, true) + defer requireClose(t, l1Stack) + defer node.StopAndWait() + + faucetL2Addr := util.RemapL1Address(l1Info.GetAddress("Faucet")) + // Transfer balance to Faucet's corresponding L2 address, so that there is + // enough balance on its' account for executing L2 transaction. + TransferBalanceTo(t, "Faucet", faucetL2Addr, big.NewInt(1e18), l2Info, l2Client, ctx) + + l2TxOpts := l2Info.GetDefaultTransactOpts("Faucet", ctx) + contractAddr, _ := deploySimple(t, ctx, l2TxOpts, l2Client) + contractABI, err := abi.JSON(strings.NewReader(mocksgen.SimpleABI)) + if err != nil { + t.Fatalf("Error parsing contract ABI: %v", err) + } + data, err := contractABI.Pack("checkCalls", true, true, false, false, false, false) + if err != nil { + t.Fatalf("Error packing method's call data: %v", err) + } + nonce, err := l2Client.NonceAt(ctx, faucetL2Addr, nil) + if err != nil { + t.Fatalf("Error getting nonce at address: %v, error: %v", faucetL2Addr, err) + } + unsignedTx := types.NewTx(&types.ArbitrumUnsignedTx{ + ChainId: l2Info.Signer.ChainID(), + From: faucetL2Addr, + Nonce: nonce, + GasFeeCap: l2Info.GasPrice, + Gas: 1e6, + To: &contractAddr, + Value: common.Big0, + Data: data, + }) + + delayedInbox, err := bridgegen.NewInbox(l1Info.GetAddress("Inbox"), l1Client) + if err != nil { + t.Fatalf("Error getting Go binding of L1 Inbox contract: %v", err) + } + + txOpts := l1Info.GetDefaultTransactOpts("Faucet", ctx) + l1tx, err := delayedInbox.SendUnsignedTransaction( + &txOpts, + arbmath.UintToBig(unsignedTx.Gas()), + unsignedTx.GasFeeCap(), + arbmath.UintToBig(unsignedTx.Nonce()), + *unsignedTx.To(), + unsignedTx.Value(), + unsignedTx.Data(), + ) + if err != nil { + t.Fatalf("Error sending unsigned transaction: %v", err) + } + receipt, err := EnsureTxSucceeded(ctx, l1Client, l1tx) + if err != nil { + t.Fatalf("EnsureTxSucceeded(%v) unexpected error: %v", l1tx.Hash(), err) + } + if receipt.Status != types.ReceiptStatusSuccessful { + t.Errorf("L1 transaction: %v has failed", l1tx.Hash()) + } + waitForL1DelayBlocks(t, ctx, l1Client, l1Info) + receipt, err = EnsureTxSucceeded(ctx, l2Client, unsignedTx) + if err != nil { + t.Fatalf("EnsureTxSucceeded(%v) unexpected error: %v", unsignedTx.Hash(), err) + } + if receipt.Status != types.ReceiptStatusSuccessful { + t.Errorf("L2 transaction: %v has failed", receipt.TxHash) + } +} diff --git a/system_tests/seqinbox_test.go b/system_tests/seqinbox_test.go index 56d727dd26..bf3e7c86c1 100644 --- a/system_tests/seqinbox_test.go +++ b/system_tests/seqinbox_test.go @@ -267,11 +267,13 @@ func testSequencerInboxReaderImpl(t *testing.T, validator bool) { if validator && i%15 == 0 { for i := 0; ; i++ { - lastValidated := arbNode.BlockValidator.LastBlockValidated() - if lastValidated == expectedBlockNumber { + expectedPos, err := 
arbNode.Execution.ExecEngine.BlockNumberToMessageIndex(expectedBlockNumber) + Require(t, err) + lastValidated := arbNode.BlockValidator.Validated(t) + if lastValidated == expectedPos+1 { break } else if i >= 1000 { - Fatal(t, "timed out waiting for block validator; have", lastValidated, "want", expectedBlockNumber) + Fatal(t, "timed out waiting for block validator; have", lastValidated, "want", expectedPos+1) } time.Sleep(time.Second) } diff --git a/system_tests/staker_test.go b/system_tests/staker_test.go index e5ed3879e9..7a3ae41814 100644 --- a/system_tests/staker_test.go +++ b/system_tests/staker_test.go @@ -149,8 +149,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) l2nodeA.InboxReader, l2nodeA.InboxTracker, l2nodeA.TxStreamer, - l2nodeA.Execution.ArbInterface.BlockChain(), - l2nodeA.Execution.ChainDB, + l2nodeA.Execution.Recorder, l2nodeA.ArbDB, nil, StaticFetcherFrom(t, &blockValidatorConfig), @@ -166,7 +165,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessA, + nil, l2nodeA.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) err = stakerA.Initialize(ctx) @@ -183,8 +184,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) l2nodeB.InboxReader, l2nodeB.InboxTracker, l2nodeB.TxStreamer, - l2nodeB.Execution.ArbInterface.BlockChain(), - l2nodeB.Execution.ChainDB, + l2nodeB.Execution.Recorder, l2nodeB.ArbDB, nil, StaticFetcherFrom(t, &blockValidatorConfig), @@ -200,7 +200,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessB, + nil, l2nodeB.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) err = stakerB.Initialize(ctx) @@ -220,7 +222,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessA, + nil, l2nodeA.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) if stakerC.Strategy() != staker.WatchtowerStrategy { @@ -290,7 +294,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) } if err != nil && faultyStaker && i%2 == 1 { // Check if this is an expected error from the faulty staker. - if strings.Contains(err.Error(), "agreed with entire challenge") || strings.Contains(err.Error(), "after block -1 expected global state") { + if strings.Contains(err.Error(), "agreed with entire challenge") || strings.Contains(err.Error(), "after msg 0 expected global state") { // Expected error upon realizing you're losing the challenge. Get ready for a timeout. if !challengeMangerTimedOut { // Upgrade the ChallengeManager contract to an implementation which says challenges are always timed out @@ -318,7 +322,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) } } else if strings.Contains(err.Error(), "insufficient funds") && sawStakerZombie { // Expected error when trying to re-stake after losing initial stake. - } else if strings.Contains(err.Error(), "unknown start block hash") && sawStakerZombie { + } else if strings.Contains(err.Error(), "start state not in chain") && sawStakerZombie { // Expected error when trying to re-stake after the challenger's nodes getting confirmed. } else if strings.Contains(err.Error(), "STAKER_IS_ZOMBIE") && sawStakerZombie { // Expected error when the staker is a zombie and thus can't advance its stake. 
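For orientation on the recorder dependency that the hunks above thread through NewStatelessBlockValidator and the staker tests: the following is only a sketch of its shape, inferred from the call sites in this diff (ValidationEntryRecord earlier and mockBlockRecorder in validation_mock_test.go below); the real types live in arbnode/execution and may carry additional methods or fields.

package execution_sketch // illustrative only; the real package is arbnode/execution

import (
	"context"

	"github.com/ethereum/go-ethereum/common"

	"github.com/offchainlabs/nitro/arbos/arbostypes"
	"github.com/offchainlabs/nitro/arbutil"
	"github.com/offchainlabs/nitro/validator"
)

// RecordResult mirrors the fields this diff reads from a recording: the
// recorded block hash plus the preimages and batch data the replay binary
// needs (see ValidationEntryRecord above and mockBlockRecorder below).
type RecordResult struct {
	Pos       arbutil.MessageIndex
	BlockHash common.Hash
	Preimages map[common.Hash][]byte
	BatchInfo []validator.BatchInfo
}

// BlockRecorder as consumed by staker.StatelessBlockValidator in this diff:
// record a block at a message position, mark recorded results as validated,
// and pre-warm state for a range of positions.
type BlockRecorder interface {
	RecordBlockCreation(ctx context.Context, pos arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata) (*RecordResult, error)
	MarkValid(pos arbutil.MessageIndex, resultHash common.Hash)
	PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error
}
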
diff --git a/system_tests/twonodeslong_test.go b/system_tests/twonodeslong_test.go index c2a5979c8d..3987e5cf7b 100644 --- a/system_tests/twonodeslong_test.go +++ b/system_tests/twonodeslong_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/offchainlabs/nitro/arbos/l2pricing" + "github.com/offchainlabs/nitro/arbutil" "github.com/ethereum/go-ethereum/core/types" ) @@ -173,7 +174,8 @@ func testTwoNodesLong(t *testing.T, dasModeStr string) { lastBlockHeader, err := l2clientB.HeaderByNumber(ctx, nil) Require(t, err) timeout := getDeadlineTimeout(t, time.Minute*30) - if !nodeB.BlockValidator.WaitForBlock(ctx, lastBlockHeader.Number.Uint64(), timeout) { + // messageindex is same as block number here + if !nodeB.BlockValidator.WaitForPos(t, ctx, arbutil.MessageIndex(lastBlockHeader.Number.Uint64()), timeout) { Fatal(t, "did not validate all blocks") } } diff --git a/system_tests/validation_mock_test.go b/system_tests/validation_mock_test.go index c853568479..bfa2d67839 100644 --- a/system_tests/validation_mock_test.go +++ b/system_tests/validation_mock_test.go @@ -3,7 +3,6 @@ package arbtest import ( "bytes" "context" - "errors" "math/big" "testing" "time" @@ -12,6 +11,11 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbnode" + "github.com/offchainlabs/nitro/arbnode/execution" + "github.com/offchainlabs/nitro/arbos/arbostypes" + "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/validator" @@ -20,6 +24,8 @@ import ( ) type mockSpawner struct { + ExecSpawned []uint64 + LaunchDelay time.Duration } var blockHashKey = common.HexToHash("0x11223344") @@ -50,11 +56,8 @@ func (s *mockSpawner) Launch(entry *validator.ValidationInput, moduleRoot common Promise: containers.NewPromise[validator.GoGlobalState](nil), root: moduleRoot, } - if moduleRoot != mockWasmModuleRoot { - run.ProduceError(errors.New("unsupported root")) - } else { - run.Produce(globalstateFromTestPreimages(entry.Preimages)) - } + <-time.After(s.LaunchDelay) + run.Produce(globalstateFromTestPreimages(entry.Preimages)) return run } @@ -66,9 +69,7 @@ func (s *mockSpawner) Name() string { return "mock" } func (s *mockSpawner) Room() int { return 4 } func (s *mockSpawner) CreateExecutionRun(wasmModuleRoot common.Hash, input *validator.ValidationInput) containers.PromiseInterface[validator.ExecutionRun] { - if wasmModuleRoot != mockWasmModuleRoot { - return containers.NewReadyPromise[validator.ExecutionRun](nil, errors.New("unsupported root")) - } + s.ExecSpawned = append(s.ExecSpawned, input.Id) return containers.NewReadyPromise[validator.ExecutionRun](&mockExecRun{ startState: input.StartState, endState: globalstateFromTestPreimages(input.Preimages), @@ -130,7 +131,7 @@ func (r *mockExecRun) PrepareRange(uint64, uint64) containers.PromiseInterface[s func (r *mockExecRun) Close() {} -func createMockValidationNode(t *testing.T, ctx context.Context, config *server_arb.ArbitratorSpawnerConfig) *node.Node { +func createMockValidationNode(t *testing.T, ctx context.Context, config *server_arb.ArbitratorSpawnerConfig) (*mockSpawner, *node.Node) { stackConf := node.DefaultConfig stackConf.HTTPPort = 0 stackConf.DataDir = "" @@ -170,7 +171,7 @@ func createMockValidationNode(t *testing.T, ctx context.Context, config *server_ serverAPI.StopOnly() }() - return stack + return 
spawner, stack } // mostly tests translation to/from json and running over network @@ -178,7 +179,7 @@ func TestValidationServerAPI(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - validationDefault := createMockValidationNode(t, ctx, nil) + _, validationDefault := createMockValidationNode(t, ctx, nil) client := server_api.NewExecutionClient(StaticFetcherFrom(t, &rpcclient.TestClientConfig), validationDefault) err := client.Start(ctx) Require(t, err) @@ -238,14 +239,94 @@ func TestValidationServerAPI(t *testing.T) { } } +func TestValidationClientRoom(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mockSpawner, spawnerStack := createMockValidationNode(t, ctx, nil) + client := server_api.NewExecutionClient(StaticFetcherFrom(t, &rpcclient.TestClientConfig), spawnerStack) + err := client.Start(ctx) + Require(t, err) + + wasmRoot, err := client.LatestWasmModuleRoot().Await(ctx) + Require(t, err) + + if client.Room() != 4 { + Fatal(t, "wrong initial room ", client.Room()) + } + + hash1 := common.HexToHash("0x11223344556677889900aabbccddeeff") + hash2 := common.HexToHash("0x11111111122222223333333444444444") + + startState := validator.GoGlobalState{ + BlockHash: hash1, + SendRoot: hash2, + Batch: 300, + PosInBatch: 3000, + } + endState := validator.GoGlobalState{ + BlockHash: hash2, + SendRoot: hash1, + Batch: 3000, + PosInBatch: 300, + } + + valInput := validator.ValidationInput{ + StartState: startState, + Preimages: globalstateToTestPreimages(endState), + } + + valRuns := make([]validator.ValidationRun, 0, 4) + + for i := 0; i < 4; i++ { + valRun := client.Launch(&valInput, wasmRoot) + valRuns = append(valRuns, valRun) + } + + for i := range valRuns { + _, err := valRuns[i].Await(ctx) + Require(t, err) + } + + if client.Room() != 4 { + Fatal(t, "wrong room after launch", client.Room()) + } + + mockSpawner.LaunchDelay = time.Hour + + valRuns = make([]validator.ValidationRun, 0, 3) + + for i := 0; i < 4; i++ { + valRun := client.Launch(&valInput, wasmRoot) + valRuns = append(valRuns, valRun) + room := client.Room() + if room != 3-i { + Fatal(t, "wrong room after launch ", room, " expected: ", 4-i) + } + } + + for i := range valRuns { + valRuns[i].Cancel() + _, err := valRuns[i].Await(ctx) + if err == nil { + Fatal(t, "no error returned after cancel i:", i) + } + } + + room := client.Room() + if room != 4 { + Fatal(t, "wrong room after canceling runs: ", room) + } +} + func TestExecutionKeepAlive(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - validationDefault := createMockValidationNode(t, ctx, nil) + _, validationDefault := createMockValidationNode(t, ctx, nil) shortTimeoutConfig := server_arb.DefaultArbitratorSpawnerConfig shortTimeoutConfig.ExecRunTimeout = time.Second - validationShortTO := createMockValidationNode(t, ctx, &shortTimeoutConfig) + _, validationShortTO := createMockValidationNode(t, ctx, &shortTimeoutConfig) configFetcher := StaticFetcherFrom(t, &rpcclient.TestClientConfig) clientDefault := server_api.NewExecutionClient(configFetcher, validationDefault) @@ -274,3 +355,43 @@ func TestExecutionKeepAlive(t *testing.T) { t.Error("getStep should have timed out but didn't") } } + +type mockBlockRecorder struct { + validator *staker.StatelessBlockValidator + streamer *arbnode.TransactionStreamer +} + +func (m *mockBlockRecorder) RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg 
*arbostypes.MessageWithMetadata, +) (*execution.RecordResult, error) { + _, globalpos, err := m.validator.GlobalStatePositionsAtCount(pos + 1) + if err != nil { + return nil, err + } + res, err := m.streamer.ResultAtCount(pos + 1) + if err != nil { + return nil, err + } + globalState := validator.GoGlobalState{ + Batch: globalpos.BatchNumber, + PosInBatch: globalpos.PosInBatch, + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, + } + return &execution.RecordResult{ + Pos: pos, + BlockHash: res.BlockHash, + Preimages: globalstateToTestPreimages(globalState), + }, nil +} + +func (m *mockBlockRecorder) MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) {} +func (m *mockBlockRecorder) PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error { + return nil +} + +func newMockRecorder(validator *staker.StatelessBlockValidator, streamer *arbnode.TransactionStreamer) *mockBlockRecorder { + return &mockBlockRecorder{validator, streamer} +} diff --git a/util/containers/syncmap.go b/util/containers/syncmap.go new file mode 100644 index 0000000000..7952a32252 --- /dev/null +++ b/util/containers/syncmap.go @@ -0,0 +1,24 @@ +package containers + +import "sync" + +type SyncMap[K any, V any] struct { + internal sync.Map +} + +func (m *SyncMap[K, V]) Load(key K) (V, bool) { + val, found := m.internal.Load(key) + if !found { + var empty V + return empty, false + } + return val.(V), true +} + +func (m *SyncMap[K, V]) Store(key K, val V) { + m.internal.Store(key, val) +} + +func (m *SyncMap[K, V]) Delete(key K) { + m.internal.Delete(key) +} diff --git a/util/jsonapi/preimages.go b/util/jsonapi/preimages.go new file mode 100644 index 0000000000..d669b7046e --- /dev/null +++ b/util/jsonapi/preimages.go @@ -0,0 +1,172 @@ +// Copyright 2023, Offchain Labs, Inc. 
+// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE + +package jsonapi + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + + "github.com/ethereum/go-ethereum/common" +) + +type PreimagesMapJson struct { + Map map[common.Hash][]byte +} + +func NewPreimagesMapJson(inner map[common.Hash][]byte) PreimagesMapJson { + return PreimagesMapJson{inner} +} + +func (m *PreimagesMapJson) MarshalJSON() ([]byte, error) { + encoding := base64.StdEncoding + size := 2 // {} + size += (5 + encoding.EncodedLen(32)) * len(m.Map) // "000..000":"" + if len(m.Map) > 0 { + size += len(m.Map) - 1 // commas + } + for _, value := range m.Map { + size += encoding.EncodedLen(len(value)) + } + out := make([]byte, size) + i := 0 + out[i] = '{' + i++ + for key, value := range m.Map { + if i > 1 { + out[i] = ',' + i++ + } + out[i] = '"' + i++ + encoding.Encode(out[i:], key[:]) + i += encoding.EncodedLen(len(key)) + out[i] = '"' + i++ + out[i] = ':' + i++ + out[i] = '"' + i++ + encoding.Encode(out[i:], value) + i += encoding.EncodedLen(len(value)) + out[i] = '"' + i++ + } + out[i] = '}' + i++ + if i != len(out) { + return nil, fmt.Errorf("preimage map wrote %v bytes but expected to write %v", i, len(out)) + } + return out, nil +} + +func readNonWhitespace(data *[]byte) (byte, error) { + c := byte('\t') + for c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r' || c == ' ' { + if len(*data) == 0 { + return 0, io.ErrUnexpectedEOF + } + c = (*data)[0] + *data = (*data)[1:] + } + return c, nil +} + +func expectCharacter(data *[]byte, expected rune) error { + got, err := readNonWhitespace(data) + if err != nil { + return fmt.Errorf("while looking for '%v' got %w", expected, err) + } + if rune(got) != expected { + return fmt.Errorf("while looking for '%v' got '%v'", expected, rune(got)) + } + return nil +} + +func getStrLen(data []byte) (int, error) { + // We don't allow strings to contain an escape sequence. + // Searching for a backslash here would be duplicated work. + // If the returned string length includes a backslash, base64 decoding will fail and error there. 
+ strLen := bytes.IndexByte(data, '"') + if strLen == -1 { + return 0, fmt.Errorf("%w: hit end of preimages map looking for end quote", io.ErrUnexpectedEOF) + } + return strLen, nil +} + +func (m *PreimagesMapJson) UnmarshalJSON(data []byte) error { + err := expectCharacter(&data, '{') + if err != nil { + return err + } + m.Map = make(map[common.Hash][]byte) + encoding := base64.StdEncoding + // Used to store base64 decoded data + // Returned unmarshalled preimage slices will just be parts of this one + buf := make([]byte, encoding.DecodedLen(len(data))) + for { + c, err := readNonWhitespace(&data) + if err != nil { + return fmt.Errorf("while looking for key in preimages map got %w", err) + } + if len(m.Map) == 0 && c == '}' { + break + } else if c != '"' { + return fmt.Errorf("expected '\"' to begin key in preimages map but got '%v'", c) + } + strLen, err := getStrLen(data) + if err != nil { + return err + } + maxKeyLen := encoding.DecodedLen(strLen) + if maxKeyLen > len(buf) { + return fmt.Errorf("preimage key base64 possible length %v is greater than buffer size of %v", maxKeyLen, len(buf)) + } + keyLen, err := encoding.Decode(buf, data[:strLen]) + if err != nil { + return fmt.Errorf("error base64 decoding preimage key: %w", err) + } + var key common.Hash + if keyLen != len(key) { + return fmt.Errorf("expected preimage to be %v bytes long, but got %v bytes", len(key), keyLen) + } + copy(key[:], buf[:len(key)]) + // We don't need to advance buf here because we already copied the data we needed out of it + data = data[strLen+1:] + err = expectCharacter(&data, ':') + if err != nil { + return err + } + err = expectCharacter(&data, '"') + if err != nil { + return err + } + strLen, err = getStrLen(data) + if err != nil { + return err + } + maxValueLen := encoding.DecodedLen(strLen) + if maxValueLen > len(buf) { + return fmt.Errorf("preimage value base64 possible length %v is greater than buffer size of %v", maxValueLen, len(buf)) + } + valueLen, err := encoding.Decode(buf, data[:strLen]) + if err != nil { + return fmt.Errorf("error base64 decoding preimage value: %w", err) + } + m.Map[key] = buf[:valueLen] + buf = buf[valueLen:] + data = data[strLen+1:] + c, err = readNonWhitespace(&data) + if err != nil { + return fmt.Errorf("after value in preimages map got %w", err) + } + if c == '}' { + break + } else if c != ',' { + return fmt.Errorf("expected ',' or '}' after value in preimages map but got '%v'", c) + } + } + return nil +} diff --git a/util/jsonapi/preimages_test.go b/util/jsonapi/preimages_test.go new file mode 100644 index 0000000000..3074a1e698 --- /dev/null +++ b/util/jsonapi/preimages_test.go @@ -0,0 +1,57 @@ +// Copyright 2023, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE + +package jsonapi + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/offchainlabs/nitro/util/testhelpers" +) + +func Require(t *testing.T, err error, printables ...interface{}) { + t.Helper() + testhelpers.RequireImpl(t, err, printables...) 
+}
+
+func TestPreimagesMapJson(t *testing.T) {
+	t.Parallel()
+	for _, preimages := range []PreimagesMapJson{
+		{},
+		{make(map[common.Hash][]byte)},
+		{map[common.Hash][]byte{
+			{}: {},
+		}},
+		{map[common.Hash][]byte{
+			{1}: {1},
+			{2}: {1, 2},
+			{3}: {1, 2, 3},
+		}},
+	} {
+		t.Run(fmt.Sprintf("%v preimages", len(preimages.Map)), func(t *testing.T) {
+			// These test cases are fast enough that t.Parallel() probably isn't worth it
+			serialized, err := preimages.MarshalJSON()
+			Require(t, err, "Failed to marshal preimages")
+
+			// Make sure that `serialized` is a valid JSON map
+			stringMap := make(map[string]string)
+			err = json.Unmarshal(serialized, &stringMap)
+			Require(t, err, "Failed to unmarshal preimages as string map")
+			if len(stringMap) != len(preimages.Map) {
+				t.Errorf("Got %v entries in string map but only had %v preimages", len(stringMap), len(preimages.Map))
+			}
+
+			var deserialized PreimagesMapJson
+			err = deserialized.UnmarshalJSON(serialized)
+			Require(t, err)
+
+			if (len(preimages.Map) > 0 || len(deserialized.Map) > 0) && !reflect.DeepEqual(preimages, deserialized) {
+				t.Errorf("Preimages map %v serialized to %v but then deserialized to different map %v", preimages, string(serialized), deserialized)
+			}
+		})
+	}
+}
diff --git a/util/rpcclient/rpcclient.go b/util/rpcclient/rpcclient.go
index 2d43ded0d7..0a09459070 100644
--- a/util/rpcclient/rpcclient.go
+++ b/util/rpcclient/rpcclient.go
@@ -85,27 +85,37 @@ func (c *RpcClient) Close() {
 	}
 }
 
-func limitedMarshal(limit int, arg interface{}) string {
-	marshalled, err := json.Marshal(arg)
+type limitedMarshal struct {
+	limit int
+	value any
+}
+
+func (m limitedMarshal) String() string {
+	marshalled, err := json.Marshal(m.value)
 	var str string
 	if err != nil {
-		str = "\"CANNOT MARSHALL:" + err.Error() + "\""
+		str = "\"CANNOT MARSHALL: " + err.Error() + "\""
 	} else {
 		str = string(marshalled)
 	}
-	if limit == 0 || len(str) <= limit {
+	if m.limit == 0 || len(str) <= m.limit {
 		return str
 	}
-	prefix := str[:limit/2-1]
-	postfix := str[len(str)-limit/2+1:]
+	prefix := str[:m.limit/2-1]
+	postfix := str[len(str)-m.limit/2+1:]
 	return fmt.Sprintf("%v..%v", prefix, postfix)
 }
 
-func logArgs(limit int, args ...interface{}) string {
+type limitedArgumentsMarshal struct {
+	limit int
+	args  []any
+}
+
+func (m limitedArgumentsMarshal) String() string {
 	res := "["
-	for i, arg := range args {
-		res += limitedMarshal(limit, arg)
-		if i < len(args)-1 {
+	for i, arg := range m.args {
+		res += limitedMarshal{m.limit, arg}.String()
+		if i < len(m.args)-1 {
 			res += ", "
 		}
 	}
@@ -118,7 +128,7 @@ func (c *RpcClient) CallContext(ctx_in context.Context, result interface{}, meth
 		return errors.New("not connected")
 	}
 	logId := atomic.AddUint64(&c.logId, 1)
-	log.Trace("sending RPC request", "method", method, "logId", logId, "args", logArgs(int(c.config().ArgLogLimit), args...))
+	log.Trace("sending RPC request", "method", method, "logId", logId, "args", limitedArgumentsMarshal{int(c.config().ArgLogLimit), args})
 	var err error
 	for i := 0; i < int(c.config().Retries)+1; i++ {
 		if ctx_in.Err() != nil {
@@ -138,9 +148,8 @@ func (c *RpcClient) CallContext(ctx_in context.Context, result interface{}, meth
 		limit := int(c.config().ArgLogLimit)
 		if err != nil && err.Error() != "already known" {
 			logger = log.Info
-			limit = 0
 		}
-		logger("rpc response", "method", method, "logId", logId, "err", err, "result", limitedMarshal(limit, result), "attempt", i, "args", logArgs(limit, args...))
+		logger("rpc response", "method", method, "logId", logId, "err", err, "result", limitedMarshal{limit, result}, "attempt", i, "args", limitedArgumentsMarshal{limit, args})
 		if err == nil {
 			return nil
 		}
diff --git a/util/rpcclient/rpcclient_test.go b/util/rpcclient/rpcclient_test.go
index 1b44b9479e..c9d813ab0d 100644
--- a/util/rpcclient/rpcclient_test.go
+++ b/util/rpcclient/rpcclient_test.go
@@ -15,17 +15,18 @@ import (
 
 func TestLogArgs(t *testing.T) {
 	t.Parallel()
-	str := logArgs(0, 1, 2, 3, "hello, world")
+	args := []any{1, 2, 3, "hello, world"}
+	str := limitedArgumentsMarshal{0, args}.String()
 	if str != "[1, 2, 3, \"hello, world\"]" {
 		Fail(t, "unexpected logs limit 0 got:", str)
 	}
 
-	str = logArgs(100, 1, 2, 3, "hello, world")
+	str = limitedArgumentsMarshal{100, args}.String()
 	if str != "[1, 2, 3, \"hello, world\"]" {
 		Fail(t, "unexpected logs limit 100 got:", str)
 	}
 
-	str = logArgs(6, 1, 2, 3, "hello, world")
+	str = limitedArgumentsMarshal{6, args}.String()
 	if str != "[1, 2, 3, \"h..d\"]" {
 		Fail(t, "unexpected logs limit 6 got:", str)
 	}
diff --git a/util/stopwaiter/stopwaiter.go b/util/stopwaiter/stopwaiter.go
index f449ed1c44..1e70e328eb 100644
--- a/util/stopwaiter/stopwaiter.go
+++ b/util/stopwaiter/stopwaiter.go
@@ -198,6 +198,9 @@ func (s *StopWaiterSafe) CallIterativelySafe(foo func(context.Context) time.Dura
 		if ctx.Err() != nil {
 			return
 		}
+		if interval == time.Duration(0) {
+			continue
+		}
 		timer := time.NewTimer(interval)
 		select {
 		case <-ctx.Done():
@@ -233,6 +236,9 @@ func CallIterativelyWith[T any](
 			return
 		}
 		val = defaultVal
+		if interval == time.Duration(0) {
+			continue
+		}
 		timer := time.NewTimer(interval)
 		select {
 		case <-ctx.Done():
diff --git a/validator/server_api/json.go b/validator/server_api/json.go
index 89c13f2dcb..95108757d7 100644
--- a/validator/server_api/json.go
+++ b/validator/server_api/json.go
@@ -1,10 +1,13 @@
+// Copyright 2023, Offchain Labs, Inc.
+// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE
+
 package server_api
 
 import (
 	"encoding/base64"
 
 	"github.com/ethereum/go-ethereum/common"
-
+	"github.com/offchainlabs/nitro/util/jsonapi"
 	"github.com/offchainlabs/nitro/validator"
 )
 
@@ -17,7 +20,7 @@ type ValidationInputJson struct {
 	Id            uint64
 	HasDelayedMsg bool
 	DelayedMsgNr  uint64
-	PreimagesB64  map[string]string
+	PreimagesB64  jsonapi.PreimagesMapJson
 	BatchInfo     []BatchInfoJson
 	DelayedMsgB64 string
 	StartState    validator.GoGlobalState
@@ -30,12 +33,7 @@ func ValidationInputToJson(entry *validator.ValidationInput) *ValidationInputJso
 		DelayedMsgNr:  entry.DelayedMsgNr,
 		DelayedMsgB64: base64.StdEncoding.EncodeToString(entry.DelayedMsg),
 		StartState:    entry.StartState,
-		PreimagesB64:  make(map[string]string),
-	}
-	for hash, data := range entry.Preimages {
-		encHash := base64.StdEncoding.EncodeToString(hash.Bytes())
-		encData := base64.StdEncoding.EncodeToString(data)
-		res.PreimagesB64[encHash] = encData
+		PreimagesB64:  jsonapi.NewPreimagesMapJson(entry.Preimages),
 	}
 	for _, binfo := range entry.BatchInfo {
 		encData := base64.StdEncoding.EncodeToString(binfo.Data)
@@ -50,24 +48,13 @@ func ValidationInputFromJson(entry *ValidationInputJson) (*validator.ValidationI
 		HasDelayedMsg: entry.HasDelayedMsg,
 		DelayedMsgNr:  entry.DelayedMsgNr,
 		StartState:    entry.StartState,
-		Preimages:     make(map[common.Hash][]byte),
+		Preimages:     entry.PreimagesB64.Map,
 	}
 	delayed, err := base64.StdEncoding.DecodeString(entry.DelayedMsgB64)
 	if err != nil {
 		return nil, err
 	}
 	valInput.DelayedMsg = delayed
-	for encHash, encData := range entry.PreimagesB64 {
-		hash, err := base64.StdEncoding.DecodeString(encHash)
-		if err != nil {
-			return nil, err
-		}
-		data, err := base64.StdEncoding.DecodeString(encData)
-		if err != nil {
-			return nil, err
-		}
-		valInput.Preimages[common.BytesToHash(hash)] = data
-	}
 	for _, binfo := range entry.BatchInfo {
 		data, err := base64.StdEncoding.DecodeString(binfo.DataB64)
 		if err != nil {
diff --git a/validator/server_api/validation_client.go b/validator/server_api/validation_client.go
index 4f678fde9e..d6143ca917 100644
--- a/validator/server_api/validation_client.go
+++ b/validator/server_api/validation_client.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/base64"
 	"errors"
+	"sync/atomic"
 	"time"
 
 	"github.com/offchainlabs/nitro/validator"
@@ -23,6 +24,7 @@ type ValidationClient struct {
 	stopwaiter.StopWaiter
 	client *rpcclient.RpcClient
 	name   string
+	room   int32
 }
 
 func NewValidationClient(config rpcclient.ClientConfigFetcher, stack *node.Node) *ValidationClient {
@@ -32,14 +34,15 @@ func NewValidationClient(config rpcclient.ClientConfigFetcher, stack *node.Node)
 }
 
 func (c *ValidationClient) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun {
-	valrun := server_common.NewValRun(moduleRoot)
-	c.LaunchThread(func(ctx context.Context) {
+	atomic.AddInt32(&c.room, -1)
+	promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](c, func(ctx context.Context) (validator.GoGlobalState, error) {
 		input := ValidationInputToJson(entry)
 		var res validator.GoGlobalState
 		err := c.client.CallContext(ctx, &res, Namespace+"_validate", input, moduleRoot)
-		valrun.ConsumeResult(res, err)
+		atomic.AddInt32(&c.room, 1)
+		return res, err
 	})
-	return valrun
+	return server_common.NewValRun(promise, moduleRoot)
 }
 
 func (c *ValidationClient) Start(ctx_in context.Context) error {
@@ -57,6 +60,18 @@ func (c *ValidationClient) Start(ctx_in context.Context) error {
 	if len(name) == 0 {
 		return errors.New("couldn't read name from server")
 	}
+	var room int
+	err = c.client.CallContext(c.GetContext(), &room, Namespace+"_room")
+	if err != nil {
+		return err
+	}
+	if room < 2 {
+		log.Warn("validation server not enough room, overriding to 2", "name", name, "room", room)
+		room = 2
+	} else {
+		log.Info("connected to validation server", "name", name, "room", room)
+	}
+	atomic.StoreInt32(&c.room, int32(room))
 	c.name = name
 	return nil
 }
@@ -76,13 +91,11 @@ func (c *ValidationClient) Name() string {
 }
 
 func (c *ValidationClient) Room() int {
-	var res int
-	err := c.client.CallContext(c.GetContext(), &res, Namespace+"_room")
-	if err != nil {
-		log.Error("error contacting validation server", "name", c.name, "err", err)
+	room32 := atomic.LoadInt32(&c.room)
+	if room32 < 0 {
 		return 0
 	}
-	return res
+	return int(room32)
 }
 
 type ExecutionClient struct {
diff --git a/validator/server_arb/validator_spawner.go b/validator/server_arb/validator_spawner.go
index a073d24c3c..f9d0705f59 100644
--- a/validator/server_arb/validator_spawner.go
+++ b/validator/server_arb/validator_spawner.go
@@ -57,32 +57,6 @@ type ArbitratorSpawner struct {
 	config ArbitratorSpawnerConfigFecher
 }
 
-type valRun struct {
-	containers.Promise[validator.GoGlobalState]
-	root common.Hash
-}
-
-func (r *valRun) WasmModuleRoot() common.Hash {
-	return r.root
-}
-
-func (r *valRun) Close() {}
-
-func NewvalRun(root common.Hash) *valRun {
-	return &valRun{
-		Promise: containers.NewPromise[validator.GoGlobalState](nil),
-		root:    root,
-	}
-}
-
-func (r *valRun) consumeResult(res validator.GoGlobalState, err error) {
-	if err != nil {
-		r.ProduceError(err)
-	} else {
-		r.Produce(res)
-	}
-}
-
 func NewArbitratorSpawner(locator *server_common.MachineLocator, config ArbitratorSpawnerConfigFecher) (*ArbitratorSpawner, error) {
 	// TODO: preload machines
 	spawner := &ArbitratorSpawner{
@@ -180,12 +154,11 @@ func (v *ArbitratorSpawner) execute(
 
 func (v *ArbitratorSpawner) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun {
 	atomic.AddInt32(&v.count, 1)
-	run := NewvalRun(moduleRoot)
-	v.LaunchThread(func(ctx context.Context) {
+	promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](v, func(ctx context.Context) (validator.GoGlobalState, error) {
 		defer atomic.AddInt32(&v.count, -1)
-		run.consumeResult(v.execute(ctx, entry, moduleRoot))
+		return v.execute(ctx, entry, moduleRoot)
 	})
-	return run
+	return server_common.NewValRun(promise, moduleRoot)
 }
 
 func (v *ArbitratorSpawner) Room() int {
@@ -193,11 +166,7 @@ func (v *ArbitratorSpawner) Room() int {
 	if avail == 0 {
 		avail = runtime.NumCPU()
 	}
-	current := int(atomic.LoadInt32(&v.count))
-	if current >= avail {
-		return 0
-	}
-	return avail - current
+	return avail
 }
 
 var launchTime = time.Now().Format("2006_01_02__15_04")
diff --git a/validator/server_common/valrun.go b/validator/server_common/valrun.go
index 1331c29852..8486664008 100644
--- a/validator/server_common/valrun.go
+++ b/validator/server_common/valrun.go
@@ -7,7 +7,7 @@ import (
 )
 
 type ValRun struct {
-	containers.Promise[validator.GoGlobalState]
+	containers.PromiseInterface[validator.GoGlobalState]
 	root common.Hash
 }
 
@@ -15,17 +15,9 @@ func (r *ValRun) WasmModuleRoot() common.Hash {
 	return r.root
 }
 
-func NewValRun(root common.Hash) *ValRun {
+func NewValRun(promise containers.PromiseInterface[validator.GoGlobalState], root common.Hash) *ValRun {
 	return &ValRun{
-		Promise: containers.NewPromise[validator.GoGlobalState](nil),
-		root:    root,
-	}
-}
-
-func (r *ValRun) ConsumeResult(res validator.GoGlobalState, err error) {
-	if err != nil {
-		r.ProduceError(err)
-	} else {
-		r.Produce(res)
+		PromiseInterface: promise,
+		root:             root,
 	}
 }
diff --git a/validator/server_jit/spawner.go b/validator/server_jit/spawner.go
index 7a3394bcae..6de006b182 100644
--- a/validator/server_jit/spawner.go
+++ b/validator/server_jit/spawner.go
@@ -90,12 +90,11 @@ func (s *JitSpawner) Name() string {
 
 func (v *JitSpawner) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun {
 	atomic.AddInt32(&v.count, 1)
-	run := server_common.NewValRun(moduleRoot)
-	go func() {
-		run.ConsumeResult(v.execute(v.GetContext(), entry, moduleRoot))
-		atomic.AddInt32(&v.count, -1)
-	}()
-	return run
+	promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](v, func(ctx context.Context) (validator.GoGlobalState, error) {
+		defer atomic.AddInt32(&v.count, -1)
+		return v.execute(ctx, entry, moduleRoot)
+	})
+	return server_common.NewValRun(promise, moduleRoot)
 }
 
 func (v *JitSpawner) Room() int {
@@ -103,7 +102,7 @@ func (v *JitSpawner) Room() int {
 	if avail == 0 {
 		avail = runtime.NumCPU()
 	}
-	return avail - int(atomic.LoadInt32(&v.count))
+	return avail
 }
 
 func (v *JitSpawner) Stop() {
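
Note on the validator/server_api/json.go change: preimage encoding moves out of the hand-rolled base64 string map and into jsonapi.PreimagesMapJson. The sketch below is a minimal round-trip, assuming only the API visible in this patch (NewPreimagesMapJson, the Map field, and the MarshalJSON/UnmarshalJSON methods exercised by TestPreimagesMapJson); it is illustrative, not part of the patch.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"

	"github.com/offchainlabs/nitro/util/jsonapi"
)

func main() {
	preimages := map[common.Hash][]byte{
		common.HexToHash("0x01"): {0xde, 0xad, 0xbe, 0xef},
	}

	// Wrap the raw preimage map, as ValidationInputToJson now does.
	pm := jsonapi.NewPreimagesMapJson(preimages)

	serialized, err := pm.MarshalJSON()
	if err != nil {
		panic(err)
	}

	// Decode into a fresh value; ValidationInputFromJson then reads .Map directly.
	var decoded jsonapi.PreimagesMapJson
	if err := decoded.UnmarshalJSON(serialized); err != nil {
		panic(err)
	}
	fmt.Println(len(decoded.Map)) // expected: 1
}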
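
Note on the util/rpcclient change: limitedMarshal and limitedArgumentsMarshal implement fmt.Stringer, so the potentially large JSON encoding of RPC arguments is produced only when a log record is actually formatted, rather than eagerly at every call site. The standalone sketch below shows the same lazy-Stringer pattern; truncatingJSON is a hypothetical name, not a type from this patch.

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// truncatingJSON mimics the limitedMarshal idea: JSON-encode the wrapped value
// only when String() is called, and trim the result to roughly `limit` bytes.
type truncatingJSON struct {
	limit int
	value any
}

func (t truncatingJSON) String() string {
	b, err := json.Marshal(t.value)
	if err != nil {
		return "\"CANNOT MARSHAL: " + err.Error() + "\""
	}
	s := string(b)
	if t.limit == 0 || len(s) <= t.limit {
		return s
	}
	return fmt.Sprintf("%v..%v", s[:t.limit/2-1], s[len(s)-t.limit/2+1:])
}

func main() {
	args := []any{1, 2, 3, "hello, world"}
	// The formatter calls String() while building the message; a leveled logger
	// that drops the record never pays for json.Marshal.
	log.Printf("sending RPC request args=%v", truncatingJSON{limit: 6, value: args})
}

With a leveled logger such as go-ethereum's, records below the configured verbosity are discarded before formatting, which is presumably why the helpers were converted from eager functions into Stringer values.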
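
Note on the util/stopwaiter change: when the callback returns a zero interval, the loop simply runs again instead of allocating a timer that would fire immediately. Below is a condensed, self-contained sketch of that loop shape; callIteratively is a simplified stand-in, not the actual StopWaiterSafe method.

package main

import (
	"context"
	"fmt"
	"time"
)

// callIteratively re-runs foo until the context is done; a zero interval means
// "run again right away", so no timer is created for that iteration.
func callIteratively(ctx context.Context, foo func(context.Context) time.Duration) {
	for {
		interval := foo(ctx)
		if ctx.Err() != nil {
			return
		}
		if interval == time.Duration(0) {
			continue
		}
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	runs := 0
	callIteratively(ctx, func(context.Context) time.Duration {
		runs++
		if runs < 3 {
			return 0 // immediate re-run, no timer allocated
		}
		cancel() // stop the loop after the third run
		return time.Second
	})
	fmt.Println("runs:", runs) // expected: 3
}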