diff --git a/arbnode/api.go b/arbnode/api.go index 057c03bf31..d28d7481d9 100644 --- a/arbnode/api.go +++ b/arbnode/api.go @@ -9,23 +9,17 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/staker" + "github.com/offchainlabs/nitro/validator" ) type BlockValidatorAPI struct { val *staker.BlockValidator } -func (a *BlockValidatorAPI) LatestValidatedBlock(ctx context.Context) (hexutil.Uint64, error) { - block := a.val.LastBlockValidated() - return hexutil.Uint64(block), nil -} - -func (a *BlockValidatorAPI) LatestValidatedBlockHash(ctx context.Context) (common.Hash, error) { - _, hash, _ := a.val.LastBlockValidatedAndHash() - return hash, nil +func (a *BlockValidatorAPI) LatestValidated(ctx context.Context) (*staker.GlobalStateValidatedInfo, error) { + return a.val.ReadLastValidatedInfo() } type BlockValidatorDebugAPI struct { @@ -34,25 +28,16 @@ type BlockValidatorDebugAPI struct { } type ValidateBlockResult struct { - Valid bool `json:"valid"` - Latency string `json:"latency"` + Valid bool `json:"valid"` + Latency string `json:"latency"` + GlobalState validator.GoGlobalState `json:"globalstate"` } -func (a *BlockValidatorDebugAPI) ValidateBlock( - ctx context.Context, blockNum rpc.BlockNumber, full bool, moduleRootOptional *common.Hash, +func (a *BlockValidatorDebugAPI) ValidateMessageNumber( + ctx context.Context, msgNum hexutil.Uint64, full bool, moduleRootOptional *common.Hash, ) (ValidateBlockResult, error) { result := ValidateBlockResult{} - if blockNum < 0 { - return result, errors.New("this method only accepts absolute block numbers") - } - header := a.blockchain.GetHeaderByNumber(uint64(blockNum)) - if header == nil { - return result, errors.New("block not found") - } - if !a.blockchain.Config().IsArbitrumNitro(header.Number) { - return result, types.ErrUseFallback - } var moduleRoot common.Hash if moduleRootOptional != nil { moduleRoot = *moduleRootOptional @@ -64,8 +49,11 @@ func (a *BlockValidatorDebugAPI) ValidateBlock( moduleRoot = moduleRoots[0] } start_time := time.Now() - valid, err := a.val.ValidateBlock(ctx, header, full, moduleRoot) - result.Valid = valid + valid, gs, err := a.val.ValidateResult(ctx, arbutil.MessageIndex(msgNum), full, moduleRoot) result.Latency = fmt.Sprintf("%vms", time.Since(start_time).Milliseconds()) + if gs != nil { + result.GlobalState = *gs + } + result.Valid = valid return result, err } diff --git a/arbnode/execution/block_recorder.go b/arbnode/execution/block_recorder.go new file mode 100644 index 0000000000..dc5daa6f7b --- /dev/null +++ b/arbnode/execution/block_recorder.go @@ -0,0 +1,364 @@ +package execution + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/arbitrum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/offchainlabs/nitro/arbos" + "github.com/offchainlabs/nitro/arbos/arbosState" + "github.com/offchainlabs/nitro/arbos/arbostypes" + "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/validator" +) + +// BlockRecorder uses a separate statedatabase from the blockchain. +// It has access to any state in the ethdb (hard-disk) database, and can compute state as needed. 
+// We keep references for state of: +// Any block that matches PrepareForRecord that was done recently (according to PrepareDelay config) +// Most recent/advanced header we ever computed (lastHdr) +// Hopefully - some recent valid block. For that we always keep one candidate block until it becomes validated. +type BlockRecorder struct { + recordingDatabase *arbitrum.RecordingDatabase + execEngine *ExecutionEngine + + lastHdr *types.Header + lastHdrLock sync.Mutex + + validHdrCandidate *types.Header + validHdr *types.Header + validHdrLock sync.Mutex + + preparedQueue []*types.Header + preparedLock sync.Mutex +} + +type RecordResult struct { + Pos arbutil.MessageIndex + BlockHash common.Hash + Preimages map[common.Hash][]byte + BatchInfo []validator.BatchInfo +} + +func NewBlockRecorder(config *arbitrum.RecordingDatabaseConfig, execEngine *ExecutionEngine, ethDb ethdb.Database) *BlockRecorder { + recorder := &BlockRecorder{ + execEngine: execEngine, + recordingDatabase: arbitrum.NewRecordingDatabase(config, ethDb, execEngine.bc), + } + execEngine.SetRecorder(recorder) + return recorder +} + +func stateLogFunc(targetHeader, header *types.Header, hasState bool) { + if targetHeader == nil || header == nil { + return + } + gap := targetHeader.Number.Int64() - header.Number.Int64() + step := int64(500) + stage := "computing state" + if !hasState { + step = 3000 + stage = "looking for full block" + } + if (gap >= step) && (gap%step == 0) { + log.Info("Setting up validation", "stage", stage, "current", header.Number, "target", targetHeader.Number) + } +} + +// If msg is nil, this will record block creation up to the point where message would be accessed (for a "too far" proof) +// If keepreference == true, reference to state of prevHeader is added (no reference added if an error is returned) +func (r *BlockRecorder) RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg *arbostypes.MessageWithMetadata, +) (*RecordResult, error) { + + blockNum := r.execEngine.MessageIndexToBlockNumber(pos) + + var prevHeader *types.Header + if pos != 0 { + prevHeader = r.execEngine.bc.GetHeaderByNumber(uint64(blockNum - 1)) + if prevHeader == nil { + return nil, fmt.Errorf("pos %d prevHeader not found", pos) + } + } + + recordingdb, chaincontext, recordingKV, err := r.recordingDatabase.PrepareRecording(ctx, prevHeader, stateLogFunc) + if err != nil { + return nil, err + } + defer func() { r.recordingDatabase.Dereference(prevHeader) }() + + chainConfig := r.execEngine.bc.Config() + + // Get the chain ID, both to validate and because the replay binary also gets the chain ID, + // so we need to populate the recordingdb with preimages for retrieving the chain ID. 
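    Editor's aside: the recording flow described in the comments above is easiest to see end to end from a caller's point of view. The sketch below is illustrative only — the function recordAndMarkValid, the package name, and the wiring are hypothetical (the real callers are the stateless block validator and staker elsewhere in this PR) — but it uses only the BlockRecorder API added in this file.

    // Hypothetical caller, shown only to illustrate the BlockRecorder lifecycle.
    package example

    import (
        "context"

        "github.com/offchainlabs/nitro/arbnode/execution"
        "github.com/offchainlabs/nitro/arbos/arbostypes"
        "github.com/offchainlabs/nitro/arbutil"
    )

    func recordAndMarkValid(ctx context.Context, rec *execution.BlockRecorder, pos arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata) error {
        // Hold state references for the message we are about to record.
        if err := rec.PrepareForRecord(ctx, pos, pos); err != nil {
            return err
        }
        res, err := rec.RecordBlockCreation(ctx, pos, msg)
        if err != nil {
            return err
        }
        // res.Preimages and res.BatchInfo would be handed to the validator here (omitted).
        // Once the result is known to be valid, let the recorder promote its candidate header.
        rec.MarkValid(pos, res.BlockHash)
        return nil
    }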
+	if prevHeader != nil {
+		initialArbosState, err := arbosState.OpenSystemArbosState(recordingdb, nil, true)
+		if err != nil {
+			return nil, fmt.Errorf("error opening initial ArbOS state: %w", err)
+		}
+		chainId, err := initialArbosState.ChainId()
+		if err != nil {
+			return nil, fmt.Errorf("error getting chain ID from initial ArbOS state: %w", err)
+		}
+		if chainId.Cmp(chainConfig.ChainID) != 0 {
+			return nil, fmt.Errorf("unexpected chain ID %v in ArbOS state, expected %v", chainId, chainConfig.ChainID)
+		}
+		genesisNum, err := initialArbosState.GenesisBlockNum()
+		if err != nil {
+			return nil, fmt.Errorf("error getting genesis block number from initial ArbOS state: %w", err)
+		}
+		_, err = initialArbosState.ChainConfig()
+		if err != nil {
+			return nil, fmt.Errorf("error getting chain config from initial ArbOS state: %w", err)
+		}
+		expectedNum := chainConfig.ArbitrumChainParams.GenesisBlockNum
+		if genesisNum != expectedNum {
+			return nil, fmt.Errorf("unexpected genesis block number %v in ArbOS state, expected %v", genesisNum, expectedNum)
+		}
+	}
+
+	var blockHash common.Hash
+	var readBatchInfo []validator.BatchInfo
+	if msg != nil {
+		batchFetcher := func(batchNum uint64) ([]byte, error) {
+			data, err := r.execEngine.streamer.FetchBatch(batchNum)
+			if err != nil {
+				return nil, err
+			}
+			readBatchInfo = append(readBatchInfo, validator.BatchInfo{
+				Number: batchNum,
+				Data:   data,
+			})
+			return data, nil
+		}
+		// Re-fetch the batch instead of using our cached cost,
+		// as the replay binary won't have the cache populated.
+		msg.Message.BatchGasCost = nil
+		block, _, err := arbos.ProduceBlock(
+			msg.Message,
+			msg.DelayedMessagesRead,
+			prevHeader,
+			recordingdb,
+			chaincontext,
+			chainConfig,
+			batchFetcher,
+		)
+		if err != nil {
+			return nil, err
+		}
+		blockHash = block.Hash()
+	}
+
+	preimages, err := r.recordingDatabase.PreimagesFromRecording(chaincontext, recordingKV)
+	if err != nil {
+		return nil, err
+	}
+
+	// check we got the canonical hash
+	canonicalHash := r.execEngine.bc.GetCanonicalHash(uint64(blockNum))
+	if canonicalHash != blockHash {
+		return nil, fmt.Errorf("block hash %v doesn't match canonical hash %v when recording", blockHash, canonicalHash)
+	}
+
+	// these won't usually do much here (they will in PrepareForRecord), but it doesn't hurt to check
+	r.updateLastHdr(prevHeader)
+	r.updateValidCandidateHdr(prevHeader)
+
+	return &RecordResult{pos, blockHash, preimages, readBatchInfo}, err
+}
+
+func (r *BlockRecorder) updateLastHdr(hdr *types.Header) {
+	if hdr == nil {
+		return
+	}
+	r.lastHdrLock.Lock()
+	defer r.lastHdrLock.Unlock()
+	if r.lastHdr != nil {
+		if hdr.Number.Cmp(r.lastHdr.Number) <= 0 {
+			return
+		}
+	}
+	_, err := r.recordingDatabase.StateFor(hdr)
+	if err != nil {
+		log.Warn("failed to get state in updateLastHdr", "err", err)
+		return
+	}
+	r.recordingDatabase.Dereference(r.lastHdr)
+	r.lastHdr = hdr
+}
+
+func (r *BlockRecorder) updateValidCandidateHdr(hdr *types.Header) {
+	if hdr == nil {
+		return
+	}
+	r.validHdrLock.Lock()
+	defer r.validHdrLock.Unlock()
+	// don't need a candidate that's newer than the current one (else it will never become valid)
+	if r.validHdrCandidate != nil && r.validHdrCandidate.Number.Cmp(hdr.Number) <= 0 {
+		return
+	}
+	// don't need a candidate that's older than known valid
+	if r.validHdr != nil && r.validHdr.Number.Cmp(hdr.Number) >= 0 {
+		return
+	}
+	_, err := r.recordingDatabase.StateFor(hdr)
+	if err != nil {
+		log.Warn("failed to get state in updateValidCandidateHdr", "err", err)
+		return
+	}
+	if r.validHdrCandidate != nil {
+ r.recordingDatabase.Dereference(r.validHdrCandidate) + } + r.validHdrCandidate = hdr +} + +func (r *BlockRecorder) MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) { + r.validHdrLock.Lock() + defer r.validHdrLock.Unlock() + if r.validHdrCandidate == nil { + return + } + validNum := r.execEngine.MessageIndexToBlockNumber(pos) + if r.validHdrCandidate.Number.Uint64() > validNum { + return + } + // make sure the valid is canonical + canonicalResultHash := r.execEngine.bc.GetCanonicalHash(uint64(validNum)) + if canonicalResultHash != resultHash { + log.Warn("markvalid hash not canonical", "pos", pos, "result", resultHash, "canonical", canonicalResultHash) + return + } + // make sure the candidate is still canonical + canonicalHash := r.execEngine.bc.GetCanonicalHash(r.validHdrCandidate.Number.Uint64()) + candidateHash := r.validHdrCandidate.Hash() + if canonicalHash != candidateHash { + log.Error("vlid candidate hash not canonical", "number", r.validHdrCandidate.Number, "candidate", candidateHash, "canonical", canonicalHash) + r.recordingDatabase.Dereference(r.validHdrCandidate) + r.validHdrCandidate = nil + return + } + r.recordingDatabase.Dereference(r.validHdr) + r.validHdr = r.validHdrCandidate + r.validHdrCandidate = nil +} + +// TODO: use config +func (r *BlockRecorder) preparedAddTrim(newRefs []*types.Header, size int) { + var oldRefs []*types.Header + r.preparedLock.Lock() + r.preparedQueue = append(r.preparedQueue, newRefs...) + if len(r.preparedQueue) > size { + oldRefs = r.preparedQueue[:len(r.preparedQueue)-size] + r.preparedQueue = r.preparedQueue[len(r.preparedQueue)-size:] + } + r.preparedLock.Unlock() + for _, ref := range oldRefs { + r.recordingDatabase.Dereference(ref) + } +} + +func (r *BlockRecorder) preparedTrimBeyond(hdr *types.Header) { + var oldRefs []*types.Header + var validRefs []*types.Header + r.preparedLock.Lock() + for _, queHdr := range r.preparedQueue { + if queHdr.Number.Cmp(hdr.Number) > 0 { + oldRefs = append(oldRefs, queHdr) + } else { + validRefs = append(validRefs, queHdr) + } + } + r.preparedQueue = validRefs + r.preparedLock.Unlock() + for _, ref := range oldRefs { + r.recordingDatabase.Dereference(ref) + } +} + +func (r *BlockRecorder) TrimAllPrepared(t *testing.T) { + r.preparedAddTrim(nil, 0) +} + +func (r *BlockRecorder) RecordingDBReferenceCount() int64 { + return r.recordingDatabase.ReferenceCount() +} + +func (r *BlockRecorder) PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error { + var references []*types.Header + if end < start { + return fmt.Errorf("illegal range start %d > end %d", start, end) + } + numOfBlocks := uint64(end + 1 - start) + hdrNum := r.execEngine.MessageIndexToBlockNumber(start) + if start > 0 { + hdrNum-- // need to get previous + } else { + numOfBlocks-- // genesis block doesn't need preparation, so recording one less block + } + lastHdrNum := hdrNum + numOfBlocks + for hdrNum <= lastHdrNum { + header := r.execEngine.bc.GetHeaderByNumber(uint64(hdrNum)) + if header == nil { + log.Warn("prepareblocks asked for non-found block", "hdrNum", hdrNum) + break + } + _, err := r.recordingDatabase.GetOrRecreateState(ctx, header, stateLogFunc) + if err != nil { + log.Warn("prepareblocks failed to get state for block", "hdrNum", hdrNum, "err", err) + break + } + references = append(references, header) + r.updateValidCandidateHdr(header) + r.updateLastHdr(header) + hdrNum++ + } + r.preparedAddTrim(references, 1000) + return nil +} + +func (r *BlockRecorder) ReorgTo(hdr *types.Header) { + 
r.validHdrLock.Lock() + if r.validHdr != nil && r.validHdr.Number.Cmp(hdr.Number) > 0 { + log.Warn("block recorder: reorging past previously-marked valid block", "reorg target num", hdr.Number, "hash", hdr.Hash(), "reorged past num", r.validHdr.Number, "hash", r.validHdr.Hash()) + r.recordingDatabase.Dereference(r.validHdr) + r.validHdr = nil + } + if r.validHdrCandidate != nil && r.validHdrCandidate.Number.Cmp(hdr.Number) > 0 { + r.recordingDatabase.Dereference(r.validHdrCandidate) + r.validHdrCandidate = nil + } + r.validHdrLock.Unlock() + r.lastHdrLock.Lock() + if r.lastHdr != nil && r.lastHdr.Number.Cmp(hdr.Number) > 0 { + r.recordingDatabase.Dereference(r.lastHdr) + r.lastHdr = nil + } + r.lastHdrLock.Unlock() + r.preparedTrimBeyond(hdr) +} + +func (r *BlockRecorder) WriteValidStateToDb() error { + r.validHdrLock.Lock() + defer r.validHdrLock.Unlock() + if r.validHdr == nil { + return nil + } + err := r.recordingDatabase.WriteStateToDatabase(r.validHdr) + r.recordingDatabase.Dereference(r.validHdr) + return err +} + +func (r *BlockRecorder) OrderlyShutdown() { + err := r.WriteValidStateToDb() + if err != nil { + log.Error("failed writing latest valid block state to DB", "err", err) + } +} diff --git a/arbnode/execution/blockchain.go b/arbnode/execution/blockchain.go index a4de72588a..a2725712dc 100644 --- a/arbnode/execution/blockchain.go +++ b/arbnode/execution/blockchain.go @@ -188,22 +188,6 @@ func shouldPreserveFalse(_ *types.Header) bool { return false } -func ReorgToBlock(chain *core.BlockChain, blockNum uint64) (*types.Block, error) { - genesisNum := chain.Config().ArbitrumChainParams.GenesisBlockNum - if blockNum < genesisNum { - return nil, fmt.Errorf("cannot reorg to block %v past nitro genesis of %v", blockNum, genesisNum) - } - reorgingToBlock := chain.GetBlockByNumber(blockNum) - if reorgingToBlock == nil { - return nil, fmt.Errorf("didn't find reorg target block number %v", blockNum) - } - err := chain.ReorgToOldBlock(reorgingToBlock) - if err != nil { - return nil, err - } - return reorgingToBlock, nil -} - func init() { gethhook.RequireHookedGeth() } diff --git a/arbnode/execution/executionengine.go b/arbnode/execution/executionengine.go index 88b42cbe4c..d8029650d7 100644 --- a/arbnode/execution/executionengine.go +++ b/arbnode/execution/executionengine.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -19,7 +20,6 @@ import ( "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbos/l1pricing" "github.com/offchainlabs/nitro/arbutil" - "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/sharedmetrics" "github.com/offchainlabs/nitro/util/stopwaiter" ) @@ -33,9 +33,9 @@ type TransactionStreamerInterface interface { type ExecutionEngine struct { stopwaiter.StopWaiter - bc *core.BlockChain - validator *staker.BlockValidator - streamer TransactionStreamerInterface + bc *core.BlockChain + streamer TransactionStreamerInterface + recorder *BlockRecorder resequenceChan chan []*arbostypes.MessageWithMetadata createBlocksMutex sync.Mutex @@ -57,14 +57,14 @@ func NewExecutionEngine(bc *core.BlockChain) (*ExecutionEngine, error) { }, nil } -func (s *ExecutionEngine) SetBlockValidator(validator *staker.BlockValidator) { +func (s *ExecutionEngine) SetRecorder(recorder *BlockRecorder) { if s.Started() { - panic("trying to set block validator after start") + 
panic("trying to set recorder after start") } - if s.validator != nil { - panic("trying to set block validator when already set") + if s.recorder != nil { + panic("trying to set recorder policy when already set") } - s.validator = validator + s.recorder = recorder } func (s *ExecutionEngine) EnableReorgSequencing() { @@ -79,15 +79,18 @@ func (s *ExecutionEngine) EnableReorgSequencing() { func (s *ExecutionEngine) SetTransactionStreamer(streamer TransactionStreamerInterface) { if s.Started() { - panic("trying to set reorg sequencing policy after start") + panic("trying to set transaction streamer after start") } if s.streamer != nil { - panic("trying to set reorg sequencing policy when already set") + panic("trying to set transaction streamer when already set") } s.streamer = streamer } func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata, oldMessages []*arbostypes.MessageWithMetadata) error { + if count == 0 { + return errors.New("cannot reorg out genesis") + } s.createBlocksMutex.Lock() resequencing := false defer func() { @@ -97,24 +100,15 @@ func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbost s.createBlocksMutex.Unlock() } }() - blockNum, err := s.MessageCountToBlockNumber(count) - if err != nil { - return err - } + blockNum := s.MessageIndexToBlockNumber(count - 1) // We can safely cast blockNum to a uint64 as it comes from MessageCountToBlockNumber targetBlock := s.bc.GetBlockByNumber(uint64(blockNum)) if targetBlock == nil { log.Warn("reorg target block not found", "block", blockNum) return nil } - if s.validator != nil { - err = s.validator.ReorgToBlock(targetBlock.NumberU64(), targetBlock.Hash()) - if err != nil { - return err - } - } - err = s.bc.ReorgToOldBlock(targetBlock) + err := s.bc.ReorgToOldBlock(targetBlock) if err != nil { return err } @@ -124,6 +118,9 @@ func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbost return err } } + if s.recorder != nil { + s.recorder.ReorgTo(targetBlock.Header()) + } if len(oldMessages) > 0 { s.resequenceChan <- oldMessages resequencing = true @@ -144,11 +141,7 @@ func (s *ExecutionEngine) HeadMessageNumber() (arbutil.MessageIndex, error) { if err != nil { return 0, err } - msgCount, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) - if err != nil { - return 0, err - } - return msgCount - 1, err + return s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) } func (s *ExecutionEngine) HeadMessageNumberSync(t *testing.T) (arbutil.MessageIndex, error) { @@ -347,7 +340,7 @@ func (s *ExecutionEngine) sequenceTransactionsWithBlockMutex(header *arbostypes. DelayedMessagesRead: delayedMessagesRead, } - pos, err := s.BlockNumberToMessageCount(lastBlockHeader.Number.Uint64()) + pos, err := s.BlockNumberToMessageIndex(lastBlockHeader.Number.Uint64() + 1) if err != nil { return nil, err } @@ -364,10 +357,6 @@ func (s *ExecutionEngine) sequenceTransactionsWithBlockMutex(header *arbostypes. 
return nil, err } - if s.validator != nil { - s.validator.NewBlock(block, lastBlockHeader, msgWithMeta) - } - return block, nil } @@ -386,7 +375,7 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp expectedDelayed := currentHeader.Nonce.Uint64() - pos, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) + lastMsg, err := s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) if err != nil { return nil, err } @@ -400,7 +389,7 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp DelayedMessagesRead: delayedSeqNum + 1, } - err = s.streamer.WriteMessageFromSequencer(pos, messageWithMeta) + err = s.streamer.WriteMessageFromSequencer(lastMsg+1, messageWithMeta) if err != nil { return nil, err } @@ -416,29 +405,25 @@ func (s *ExecutionEngine) sequenceDelayedMessageWithBlockMutex(message *arbostyp return nil, err } - log.Info("ExecutionEngine: Added DelayedMessages", "pos", pos, "delayed", delayedSeqNum, "block-header", block.Header()) + log.Info("ExecutionEngine: Added DelayedMessages", "pos", lastMsg+1, "delayed", delayedSeqNum, "block-header", block.Header()) return block, nil } -func (s *ExecutionEngine) GetGenesisBlockNumber() (uint64, error) { - return s.bc.Config().ArbitrumChainParams.GenesisBlockNum, nil +func (s *ExecutionEngine) GetGenesisBlockNumber() uint64 { + return s.bc.Config().ArbitrumChainParams.GenesisBlockNum } -func (s *ExecutionEngine) BlockNumberToMessageCount(blockNum uint64) (arbutil.MessageIndex, error) { - genesis, err := s.GetGenesisBlockNumber() - if err != nil { - return 0, err +func (s *ExecutionEngine) BlockNumberToMessageIndex(blockNum uint64) (arbutil.MessageIndex, error) { + genesis := s.GetGenesisBlockNumber() + if blockNum < genesis { + return 0, fmt.Errorf("blockNum %d < genesis %d", blockNum, genesis) } - return arbutil.BlockNumberToMessageCount(blockNum, genesis), nil + return arbutil.MessageIndex(blockNum - genesis), nil } -func (s *ExecutionEngine) MessageCountToBlockNumber(messageNum arbutil.MessageIndex) (int64, error) { - genesis, err := s.GetGenesisBlockNumber() - if err != nil { - return 0, err - } - return arbutil.MessageCountToBlockNumber(messageNum, genesis), nil +func (s *ExecutionEngine) MessageIndexToBlockNumber(messageNum arbutil.MessageIndex) uint64 { + return uint64(messageNum) + s.GetGenesisBlockNumber() } // must hold createBlockMutex @@ -494,6 +479,23 @@ func (s *ExecutionEngine) appendBlock(block *types.Block, statedb *state.StateDB return nil } +type MessageResult struct { + BlockHash common.Hash + SendRoot common.Hash +} + +func (s *ExecutionEngine) resultFromHeader(header *types.Header) (*MessageResult, error) { + if header == nil { + return nil, fmt.Errorf("result not found") + } + info := types.DeserializeHeaderExtraInformation(header) + return &MessageResult{header.Hash(), info.SendRoot}, nil +} + +func (s *ExecutionEngine) ResultAtPos(pos arbutil.MessageIndex) (*MessageResult, error) { + return s.resultFromHeader(s.bc.GetHeaderByNumber(s.MessageIndexToBlockNumber(pos))) +} + func (s *ExecutionEngine) DigestMessage(num arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata) error { if !s.createBlocksMutex.TryLock() { return errors.New("createBlock mutex held") @@ -507,12 +509,12 @@ func (s *ExecutionEngine) digestMessageWithBlockMutex(num arbutil.MessageIndex, if err != nil { return err } - expNum, err := s.BlockNumberToMessageCount(currentHeader.Number.Uint64()) + curMsg, err := s.BlockNumberToMessageIndex(currentHeader.Number.Uint64()) if err != 
nil { return err } - if expNum != num { - return fmt.Errorf("wrong message number in digest got %d expected %d", num, expNum) + if curMsg+1 != num { + return fmt.Errorf("wrong message number in digest got %d expected %d", num, curMsg+1) } startTime := time.Now() @@ -526,10 +528,6 @@ func (s *ExecutionEngine) digestMessageWithBlockMutex(num arbutil.MessageIndex, return err } - if s.validator != nil { - s.validator.NewBlock(block, currentHeader, *msg) - } - if time.Now().After(s.nextScheduledVersionCheck) { s.nextScheduledVersionCheck = time.Now().Add(time.Minute) arbState, err := arbosState.OpenSystemArbosState(statedb, nil, true) diff --git a/arbnode/execution/node.go b/arbnode/execution/node.go index a0ff13bb29..7456c1d6a7 100644 --- a/arbnode/execution/node.go +++ b/arbnode/execution/node.go @@ -17,6 +17,7 @@ type ExecutionNode struct { FilterSystem *filters.FilterSystem ArbInterface *ArbInterface ExecEngine *ExecutionEngine + Recorder *BlockRecorder Sequencer *Sequencer // either nil or same as TxPublisher TxPublisher TransactionPublisher } @@ -30,6 +31,7 @@ func CreateExecutionNode( fwTarget string, fwConfig *ForwarderConfig, rpcConfig arbitrum.Config, + recordingDbConfig *arbitrum.RecordingDatabaseConfig, seqConfigFetcher SequencerConfigFetcher, precheckConfigFetcher TxPreCheckerConfigFetcher, ) (*ExecutionNode, error) { @@ -37,6 +39,7 @@ func CreateExecutionNode( if err != nil { return nil, err } + recorder := NewBlockRecorder(recordingDbConfig, execEngine, chainDB) var txPublisher TransactionPublisher var sequencer *Sequencer seqConfig := seqConfigFetcher() @@ -79,6 +82,7 @@ func CreateExecutionNode( filterSystem, arbInterface, execEngine, + recorder, sequencer, txPublisher, }, nil diff --git a/arbnode/inbox_test.go b/arbnode/inbox_test.go index e68cee49ff..21eef7499c 100644 --- a/arbnode/inbox_test.go +++ b/arbnode/inbox_test.go @@ -199,10 +199,17 @@ func TestTransactionStreamer(t *testing.T) { } // Check that state balances are consistent with blockchain's balances - lastBlockNumber := bc.CurrentHeader().Number.Uint64() expectedLastBlockNumber := blockStates[len(blockStates)-1].blockNumber - if lastBlockNumber != expectedLastBlockNumber { - Fail(t, "unexpected block number", lastBlockNumber, "vs", expectedLastBlockNumber) + for i := 0; ; i++ { + lastBlockNumber := bc.CurrentHeader().Number.Uint64() + if lastBlockNumber == expectedLastBlockNumber { + break + } else if lastBlockNumber > expectedLastBlockNumber { + Fail(t, "unexpected block number", lastBlockNumber, "vs", expectedLastBlockNumber) + } else if i == 10 { + Fail(t, "timeout waiting for block number", expectedLastBlockNumber, "current", lastBlockNumber) + } + time.Sleep(time.Millisecond * 100) } for _, state := range blockStates { diff --git a/arbnode/inbox_tracker.go b/arbnode/inbox_tracker.go index b6a1afd02b..c82e45fbee 100644 --- a/arbnode/inbox_tracker.go +++ b/arbnode/inbox_tracker.go @@ -127,7 +127,7 @@ func (t *InboxTracker) GetDelayedAcc(seqNum uint64) (common.Hash, error) { return common.Hash{}, err } if !hasKey { - return common.Hash{}, AccumulatorNotFoundErr + return common.Hash{}, fmt.Errorf("%w: not found delayed %d", AccumulatorNotFoundErr, seqNum) } } data, err := t.db.Get(key) @@ -175,7 +175,7 @@ func (t *InboxTracker) GetBatchMetadata(seqNum uint64) (BatchMetadata, error) { return BatchMetadata{}, err } if !hasKey { - return BatchMetadata{}, AccumulatorNotFoundErr + return BatchMetadata{}, fmt.Errorf("%w: no metadata for batch %d", AccumulatorNotFoundErr, seqNum) } data, err := t.db.Get(key) if err != 
nil { @@ -694,18 +694,6 @@ func (t *InboxTracker) AddSequencerBatches(ctx context.Context, client arbutil.L } t.batchMetaMutex.Unlock() - if t.validator != nil { - batchBytes := make([][]byte, 0, len(batches)) - for _, batch := range batches { - msg, err := batch.Serialize(ctx, client) - if err != nil { - return err - } - batchBytes = append(batchBytes, msg) - } - t.validator.ProcessBatches(startPos, batchBytes) - } - if t.txStreamer.broadcastServer != nil && pos > 1 { prevprevbatchmeta, err := t.GetBatchMetadata(pos - 2) if errors.Is(err, AccumulatorNotFoundErr) { diff --git a/arbnode/message_pruner.go b/arbnode/message_pruner.go index 1ba3886d8d..aeee07ca73 100644 --- a/arbnode/message_pruner.go +++ b/arbnode/message_pruner.go @@ -7,16 +7,16 @@ import ( "bytes" "context" "encoding/binary" - "math/big" + "fmt" + "sync" "time" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rpc" - "github.com/offchainlabs/nitro/staker" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/offchainlabs/nitro/validator" flag "github.com/spf13/pflag" ) @@ -25,13 +25,15 @@ type MessagePruner struct { stopwaiter.StopWaiter transactionStreamer *TransactionStreamer inboxTracker *InboxTracker - staker *staker.Staker config MessagePrunerConfigFetcher + pruningLock sync.Mutex + lastPruneDone time.Time } type MessagePrunerConfig struct { Enable bool `koanf:"enable"` MessagePruneInterval time.Duration `koanf:"prune-interval" reload:"hot"` + MinBatchesLeft uint64 `koanf:"min-batches-left" reload:"hot"` } type MessagePrunerConfigFetcher func() *MessagePrunerConfig @@ -39,85 +41,96 @@ type MessagePrunerConfigFetcher func() *MessagePrunerConfig var DefaultMessagePrunerConfig = MessagePrunerConfig{ Enable: true, MessagePruneInterval: time.Minute, + MinBatchesLeft: 2, } func MessagePrunerConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".enable", DefaultMessagePrunerConfig.Enable, "enable message pruning") f.Duration(prefix+".prune-interval", DefaultMessagePrunerConfig.MessagePruneInterval, "interval for running message pruner") + f.Uint64(prefix+".min-batches-left", DefaultMessagePrunerConfig.MinBatchesLeft, "min number of batches not pruned") } -func NewMessagePruner(transactionStreamer *TransactionStreamer, inboxTracker *InboxTracker, staker *staker.Staker, config MessagePrunerConfigFetcher) *MessagePruner { +func NewMessagePruner(transactionStreamer *TransactionStreamer, inboxTracker *InboxTracker, config MessagePrunerConfigFetcher) *MessagePruner { return &MessagePruner{ transactionStreamer: transactionStreamer, inboxTracker: inboxTracker, - staker: staker, config: config, } } func (m *MessagePruner) Start(ctxIn context.Context) { m.StopWaiter.Start(ctxIn, m) - m.CallIteratively(m.prune) } -func (m *MessagePruner) prune(ctx context.Context) time.Duration { - latestConfirmedNode, err := m.staker.Rollup().LatestConfirmed( - &bind.CallOpts{ - Context: ctx, - BlockNumber: big.NewInt(int64(rpc.FinalizedBlockNumber)), - }) - if err != nil { - log.Error("error getting latest confirmed node", "err", err) - return m.config().MessagePruneInterval - } - nodeInfo, err := m.staker.Rollup().LookupNode(ctx, latestConfirmedNode) - if err != nil { - log.Error("error getting latest confirmed node info", "node", latestConfirmedNode, "err", err) - return m.config().MessagePruneInterval +func (m *MessagePruner) UpdateLatestStaked(count arbutil.MessageIndex, 
globalState validator.GoGlobalState) { + locked := m.pruningLock.TryLock() + if !locked { + return } - endBatchCount := nodeInfo.Assertion.AfterState.GlobalState.Batch - if endBatchCount == 0 { - return m.config().MessagePruneInterval + + if m.lastPruneDone.Add(m.config().MessagePruneInterval).After(time.Now()) { + m.pruningLock.Unlock() + return } - endBatchMetadata, err := m.inboxTracker.GetBatchMetadata(endBatchCount - 1) + err := m.LaunchThreadSafe(func(ctx context.Context) { + defer m.pruningLock.Unlock() + err := m.prune(ctx, count, globalState) + if err != nil && ctx.Err() == nil { + log.Error("error while pruning", "err", err) + } + }) if err != nil { - log.Error("error getting last batch metadata", "batch", endBatchCount-1, "err", err) - return m.config().MessagePruneInterval + log.Info("failed launching prune thread", "err", err) + m.pruningLock.Unlock() } - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, m.inboxTracker.db, m.transactionStreamer.db) - return m.config().MessagePruneInterval } -func deleteOldMessageFromDB(endBatchCount uint64, endBatchMetadata BatchMetadata, inboxTrackerDb ethdb.Database, transactionStreamerDb ethdb.Database) { - prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(inboxTrackerDb, sequencerBatchMetaPrefix, endBatchCount) +func (m *MessagePruner) prune(ctx context.Context, count arbutil.MessageIndex, globalState validator.GoGlobalState) error { + trimBatchCount := globalState.Batch + minBatchesLeft := m.config().MinBatchesLeft + batchCount, err := m.inboxTracker.GetBatchCount() if err != nil { - log.Error("error deleting batch metadata", "err", err) - return + return err } - if len(prunedKeysRange) > 0 { - log.Info("Pruned batches:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) + if batchCount < trimBatchCount+minBatchesLeft { + if batchCount < minBatchesLeft { + return nil + } + trimBatchCount = batchCount - minBatchesLeft } + if trimBatchCount < 1 { + return nil + } + endBatchMetadata, err := m.inboxTracker.GetBatchMetadata(trimBatchCount - 1) + if err != nil { + return err + } + msgCount := endBatchMetadata.MessageCount + delayedCount := endBatchMetadata.DelayedMessageCount - prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(transactionStreamerDb, messagePrefix, uint64(endBatchMetadata.MessageCount)) + return deleteOldMessageFromDB(ctx, msgCount, delayedCount, m.inboxTracker.db, m.transactionStreamer.db) +} + +func deleteOldMessageFromDB(ctx context.Context, messageCount arbutil.MessageIndex, delayedMessageCount uint64, inboxTrackerDb ethdb.Database, transactionStreamerDb ethdb.Database) error { + prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, transactionStreamerDb, messagePrefix, uint64(messageCount)) if err != nil { - log.Error("error deleting last batch messages", "err", err) - return + return fmt.Errorf("error deleting last batch messages: %w", err) } if len(prunedKeysRange) > 0 { log.Info("Pruned last batch messages:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) } - prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(inboxTrackerDb, rlpDelayedMessagePrefix, endBatchMetadata.DelayedMessageCount) + prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(ctx, inboxTrackerDb, rlpDelayedMessagePrefix, delayedMessageCount) if err != nil { - log.Error("error deleting last batch delayed messages", "err", err) - return + return fmt.Errorf("error deleting last batch delayed messages: %w", err) } if 
len(prunedKeysRange) > 0 { log.Info("Pruned last batch delayed messages:", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1]) } + return nil } -func deleteFromLastPrunedUptoEndKey(db ethdb.Database, prefix []byte, endMinKey uint64) ([][]byte, error) { +func deleteFromLastPrunedUptoEndKey(ctx context.Context, db ethdb.Database, prefix []byte, endMinKey uint64) ([]uint64, error) { startIter := db.NewIterator(prefix, uint64ToKey(1)) if !startIter.Next() { return nil, nil @@ -125,7 +138,7 @@ func deleteFromLastPrunedUptoEndKey(db ethdb.Database, prefix []byte, endMinKey startMinKey := binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) startIter.Release() if endMinKey > startMinKey { - return deleteFromRange(db, prefix, startMinKey, endMinKey-1) + return deleteFromRange(ctx, db, prefix, startMinKey, endMinKey-1) } return nil, nil } diff --git a/arbnode/message_pruner_test.go b/arbnode/message_pruner_test.go index 16c1d6b71c..c0cb2cb4fe 100644 --- a/arbnode/message_pruner_test.go +++ b/arbnode/message_pruner_test.go @@ -4,87 +4,88 @@ package arbnode import ( + "context" "testing" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" + "github.com/offchainlabs/nitro/arbutil" ) func TestMessagePrunerWithPruningEligibleMessagePresent(t *testing.T) { - endBatchCount := uint64(2 * 100 * 1024) - endBatchMetadata := BatchMetadata{ - MessageCount: 2 * 100 * 1024, - DelayedMessageCount: 2 * 100 * 1024, - } - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(2 * 100 * 1024) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*100*1024, 2*100*1024) + err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - checkDbKeys(t, uint64(endBatchMetadata.MessageCount), transactionStreamerDb, messagePrefix) - checkDbKeys(t, endBatchMetadata.DelayedMessageCount, inboxTrackerDb, rlpDelayedMessagePrefix) + checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) + checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix) } func TestMessagePrunerTraverseEachMessageOnlyOnce(t *testing.T) { - endBatchCount := uint64(10) - endBatchMetadata := BatchMetadata{} - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - // In first iteration message till endBatchCount are tried to be deleted. - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - // In first iteration all the message till endBatchCount are deleted. - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - // After first iteration endBatchCount/2 is reinserted in inbox db - err := inboxTrackerDb.Put(dbKey(sequencerBatchMetaPrefix, endBatchCount/2), []byte{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) + // In first iteration message till messagesCount are tried to be deleted. 
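    Editor's aside on the pruning threshold added earlier in this file's diff (illustrative numbers, not from the change itself): with MinBatchesLeft at its default of 2, a node that has seen batchCount = 10 batches and a latest-staked global state at batch 9 clamps trimBatchCount to 10 - 2 = 8, reads the metadata of batch 7, and deletes messages and delayed messages strictly below that batch's MessageCount and DelayedMessageCount; nothing is pruned until more than MinBatchesLeft batches exist.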
+ err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) Require(t, err) - // In second iteration message till endBatchCount are again tried to be deleted. - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - // In second iteration all the message till endBatchCount are deleted again. - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) + // After first iteration messagesCount/2 is reinserted in inbox db + err = inboxTrackerDb.Put(dbKey(messagePrefix, messagesCount/2), []byte{}) + Require(t, err) + // In second iteration message till messagesCount are again tried to be deleted. + err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) + // In second iteration all the message till messagesCount are deleted again. + checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix) } func TestMessagePrunerPruneTillLessThenEqualTo(t *testing.T) { - endBatchCount := uint64(10) - endBatchMetadata := BatchMetadata{} - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*endBatchCount, endBatchMetadata) - err := inboxTrackerDb.Delete(dbKey(sequencerBatchMetaPrefix, 9)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, 2*messagesCount, 20) + err := inboxTrackerDb.Delete(dbKey(messagePrefix, 9)) + Require(t, err) + err = deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) Require(t, err) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) - hasKey, err := inboxTrackerDb.Has(dbKey(sequencerBatchMetaPrefix, 10)) + hasKey, err := transactionStreamerDb.Has(dbKey(messagePrefix, messagesCount)) Require(t, err) if !hasKey { - Fail(t, "Key", 10, "with prefix", string(sequencerBatchMetaPrefix), "should be present after pruning") + Fail(t, "Key", 10, "with prefix", string(messagePrefix), "should be present after pruning") } } func TestMessagePrunerWithNoPruningEligibleMessagePresent(t *testing.T) { - endBatchCount := uint64(2) - endBatchMetadata := BatchMetadata{ - MessageCount: 2, - DelayedMessageCount: 2, - } - inboxTrackerDb, transactionStreamerDb := setupDatabase(t, endBatchCount, endBatchMetadata) - deleteOldMessageFromDB(endBatchCount, endBatchMetadata, inboxTrackerDb, transactionStreamerDb) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesCount := uint64(10) + inboxTrackerDb, transactionStreamerDb := setupDatabase(t, messagesCount, messagesCount) + err := deleteOldMessageFromDB(ctx, arbutil.MessageIndex(messagesCount), messagesCount, inboxTrackerDb, transactionStreamerDb) + Require(t, err) - checkDbKeys(t, endBatchCount, inboxTrackerDb, sequencerBatchMetaPrefix) - checkDbKeys(t, uint64(endBatchMetadata.MessageCount), transactionStreamerDb, messagePrefix) - checkDbKeys(t, endBatchMetadata.DelayedMessageCount, inboxTrackerDb, rlpDelayedMessagePrefix) + checkDbKeys(t, uint64(messagesCount), transactionStreamerDb, messagePrefix) + checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix) } -func setupDatabase(t *testing.T, endBatchCount uint64, endBatchMetadata BatchMetadata) (ethdb.Database, ethdb.Database) { - inboxTrackerDb := rawdb.NewMemoryDatabase() - for i := uint64(0); i < endBatchCount; i++ { 
- err := inboxTrackerDb.Put(dbKey(sequencerBatchMetaPrefix, i), []byte{}) - Require(t, err) - } +func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethdb.Database, ethdb.Database) { transactionStreamerDb := rawdb.NewMemoryDatabase() - for i := uint64(0); i < uint64(endBatchMetadata.MessageCount); i++ { + for i := uint64(0); i < uint64(messageCount); i++ { err := transactionStreamerDb.Put(dbKey(messagePrefix, i), []byte{}) Require(t, err) } - for i := uint64(0); i < endBatchMetadata.DelayedMessageCount; i++ { + inboxTrackerDb := rawdb.NewMemoryDatabase() + for i := uint64(0); i < delayedMessageCount; i++ { err := inboxTrackerDb.Put(dbKey(rlpDelayedMessagePrefix, i), []byte{}) Require(t, err) } @@ -93,6 +94,7 @@ func setupDatabase(t *testing.T, endBatchCount uint64, endBatchMetadata BatchMet } func checkDbKeys(t *testing.T, endCount uint64, db ethdb.Database, prefix []byte) { + t.Helper() for i := uint64(0); i < endCount; i++ { hasKey, err := db.Has(dbKey(prefix, i)) Require(t, err) diff --git a/arbnode/node.go b/arbnode/node.go index f1fc8e9e63..9178baf65d 100644 --- a/arbnode/node.go +++ b/arbnode/node.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/rpc" "github.com/offchainlabs/nitro/arbnode/execution" + "github.com/offchainlabs/nitro/arbnode/resourcemanager" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/broadcastclient" "github.com/offchainlabs/nitro/broadcastclients" @@ -307,28 +308,30 @@ func DeployOnL1(ctx context.Context, l1client arbutil.L1Interface, deployAuth *b } type Config struct { - RPC arbitrum.Config `koanf:"rpc"` - Sequencer execution.SequencerConfig `koanf:"sequencer" reload:"hot"` - L1Reader headerreader.Config `koanf:"parent-chain-reader" reload:"hot"` - InboxReader InboxReaderConfig `koanf:"inbox-reader" reload:"hot"` - DelayedSequencer DelayedSequencerConfig `koanf:"delayed-sequencer" reload:"hot"` - BatchPoster BatchPosterConfig `koanf:"batch-poster" reload:"hot"` - MessagePruner MessagePrunerConfig `koanf:"message-pruner" reload:"hot"` - ForwardingTargetImpl string `koanf:"forwarding-target"` - Forwarder execution.ForwarderConfig `koanf:"forwarder"` - TxPreChecker execution.TxPreCheckerConfig `koanf:"tx-pre-checker" reload:"hot"` - BlockValidator staker.BlockValidatorConfig `koanf:"block-validator" reload:"hot"` - Feed broadcastclient.FeedConfig `koanf:"feed" reload:"hot"` - Staker staker.L1ValidatorConfig `koanf:"staker"` - SeqCoordinator SeqCoordinatorConfig `koanf:"seq-coordinator"` - DataAvailability das.DataAvailabilityConfig `koanf:"data-availability"` - SyncMonitor SyncMonitorConfig `koanf:"sync-monitor"` - Dangerous DangerousConfig `koanf:"dangerous"` - Caching execution.CachingConfig `koanf:"caching"` - Archive bool `koanf:"archive"` - TxLookupLimit uint64 `koanf:"tx-lookup-limit"` - TransactionStreamer TransactionStreamerConfig `koanf:"transaction-streamer" reload:"hot"` - Maintenance MaintenanceConfig `koanf:"maintenance" reload:"hot"` + RPC arbitrum.Config `koanf:"rpc"` + Sequencer execution.SequencerConfig `koanf:"sequencer" reload:"hot"` + L1Reader headerreader.Config `koanf:"parent-chain-reader" reload:"hot"` + InboxReader InboxReaderConfig `koanf:"inbox-reader" reload:"hot"` + DelayedSequencer DelayedSequencerConfig `koanf:"delayed-sequencer" reload:"hot"` + BatchPoster BatchPosterConfig `koanf:"batch-poster" reload:"hot"` + MessagePruner MessagePrunerConfig `koanf:"message-pruner" reload:"hot"` + ForwardingTargetImpl string `koanf:"forwarding-target"` + Forwarder 
execution.ForwarderConfig `koanf:"forwarder"` + TxPreChecker execution.TxPreCheckerConfig `koanf:"tx-pre-checker" reload:"hot"` + BlockValidator staker.BlockValidatorConfig `koanf:"block-validator" reload:"hot"` + RecordingDB arbitrum.RecordingDatabaseConfig `koanf:"recording-database"` + Feed broadcastclient.FeedConfig `koanf:"feed" reload:"hot"` + Staker staker.L1ValidatorConfig `koanf:"staker"` + SeqCoordinator SeqCoordinatorConfig `koanf:"seq-coordinator"` + DataAvailability das.DataAvailabilityConfig `koanf:"data-availability"` + SyncMonitor SyncMonitorConfig `koanf:"sync-monitor"` + Dangerous DangerousConfig `koanf:"dangerous"` + Caching execution.CachingConfig `koanf:"caching"` + Archive bool `koanf:"archive"` + TxLookupLimit uint64 `koanf:"tx-lookup-limit"` + TransactionStreamer TransactionStreamerConfig `koanf:"transaction-streamer" reload:"hot"` + Maintenance MaintenanceConfig `koanf:"maintenance" reload:"hot"` + ResourceManagement resourcemanager.Config `koanf:"resource-mgmt" reload:"hot"` } func (c *Config) Validate() error { @@ -392,6 +395,7 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet, feedInputEnable bool, feed execution.AddOptionsForNodeForwarderConfig(prefix+".forwarder", f) execution.TxPreCheckerConfigAddOptions(prefix+".tx-pre-checker", f) staker.BlockValidatorConfigAddOptions(prefix+".block-validator", f) + arbitrum.RecordingDatabaseConfigAddOptions(prefix+".recording-database", f) broadcastclient.FeedConfigAddOptions(prefix+".feed", f, feedInputEnable, feedOutputEnable) staker.L1ValidatorConfigAddOptions(prefix+".staker", f) SeqCoordinatorConfigAddOptions(prefix+".seq-coordinator", f) @@ -402,6 +406,7 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet, feedInputEnable bool, feed f.Uint64(prefix+".tx-lookup-limit", ConfigDefault.TxLookupLimit, "retain the ability to lookup transactions by hash for the past N blocks (0 = all blocks)") TransactionStreamerConfigAddOptions(prefix+".transaction-streamer", f) MaintenanceConfigAddOptions(prefix+".maintenance", f) + resourcemanager.ConfigAddOptions(prefix+".resource-mgmt", f) archiveMsg := fmt.Sprintf("retain past block state (deprecated, please use %v.caching.archive)", prefix) f.Bool(prefix+".archive", ConfigDefault.Archive, archiveMsg) @@ -418,6 +423,7 @@ var ConfigDefault = Config{ ForwardingTargetImpl: "", TxPreChecker: execution.DefaultTxPreCheckerConfig, BlockValidator: staker.DefaultBlockValidatorConfig, + RecordingDB: arbitrum.DefaultRecordingDatabaseConfig, Feed: broadcastclient.FeedConfigDefault, Staker: staker.DefaultL1ValidatorConfig, SeqCoordinator: DefaultSeqCoordinatorConfig, @@ -428,6 +434,7 @@ var ConfigDefault = Config{ TxLookupLimit: 126_230_400, // 1 year at 4 blocks per second Caching: execution.DefaultCachingConfig, TransactionStreamer: DefaultTransactionStreamerConfig, + ResourceManagement: resourcemanager.DefaultConfig, } func ConfigDefaultL1Test() *Config { @@ -474,18 +481,15 @@ func ConfigDefaultL2Test() *Config { } type DangerousConfig struct { - NoL1Listener bool `koanf:"no-l1-listener"` - ReorgToBlock int64 `koanf:"reorg-to-block"` + NoL1Listener bool `koanf:"no-l1-listener"` } var DefaultDangerousConfig = DangerousConfig{ NoL1Listener: false, - ReorgToBlock: -1, } func DangerousConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".no-l1-listener", DefaultDangerousConfig.NoL1Listener, "DANGEROUS! disables listening to L1. To be used in test nodes only") - f.Int64(prefix+".reorg-to-block", DefaultDangerousConfig.ReorgToBlock, "DANGEROUS! 
forces a reorg to an old block height. To be used for testing only. -1 to disable") } type Node struct { @@ -585,14 +589,6 @@ func createNodeImpl( l2Config := l2BlockChain.Config() l2ChainId := l2Config.ChainID.Uint64() - var reorgingToBlock *types.Block - if config.Dangerous.ReorgToBlock >= 0 { - reorgingToBlock, err = execution.ReorgToBlock(l2BlockChain, uint64(config.Dangerous.ReorgToBlock)) - if err != nil { - return nil, err - } - } - syncMonitor := NewSyncMonitor(&config.SyncMonitor) var classicOutbox *ClassicOutboxRetriever classicMsgDb, err := stack.OpenDatabase("classic-msg", 0, 0, "", true) @@ -616,7 +612,7 @@ func createNodeImpl( sequencerConfigFetcher := func() *execution.SequencerConfig { return &configFetcher.Get().Sequencer } txprecheckConfigFetcher := func() *execution.TxPreCheckerConfig { return &configFetcher.Get().TxPreChecker } exec, err := execution.CreateExecutionNode(stack, chainDb, l2BlockChain, l1Reader, syncMonitor, - config.ForwardingTarget(), &config.Forwarder, config.RPC, + config.ForwardingTarget(), &config.Forwarder, config.RPC, &config.RecordingDB, sequencerConfigFetcher, txprecheckConfigFetcher) if err != nil { return nil, err @@ -770,8 +766,7 @@ func createNodeImpl( inboxReader, inboxTracker, txStreamer, - l2BlockChain, - chainDb, + exec.Recorder, rawdb.NewTable(arbDb, BlockValidatorPrefix), daReader, func() *staker.BlockValidatorConfig { return &configFetcher.Get().BlockValidator }, @@ -794,7 +789,6 @@ func createNodeImpl( statelessBlockValidator, inboxTracker, txStreamer, - reorgingToBlock, func() *staker.BlockValidatorConfig { return &configFetcher.Get().BlockValidator }, fatalErrChan, ) @@ -804,6 +798,8 @@ func createNodeImpl( } var stakerObj *staker.Staker + var messagePruner *MessagePruner + if config.Staker.Enable { var wallet staker.ValidatorWalletInterface if config.Staker.UseSmartContractWallet || txOptsValidator == nil { @@ -830,7 +826,13 @@ func createNodeImpl( } } - stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, deployInfo.ValidatorUtils) + notifiers := make([]staker.LatestStakedNotifier, 0) + if config.MessagePruner.Enable && !config.Caching.Archive { + messagePruner = NewMessagePruner(txStreamer, inboxTracker, func() *MessagePrunerConfig { return &configFetcher.Get().MessagePruner }) + notifiers = append(notifiers, messagePruner) + } + + stakerObj, err = staker.NewStaker(l1Reader, wallet, bind.CallOpts{}, config.Staker, blockValidator, statelessBlockValidator, notifiers, deployInfo.ValidatorUtils, fatalErrChan) if err != nil { return nil, err } @@ -862,10 +864,6 @@ func createNodeImpl( return nil, err } } - var messagePruner *MessagePruner - if config.MessagePruner.Enable && !config.Caching.Archive && stakerObj != nil { - messagePruner = NewMessagePruner(txStreamer, inboxTracker, stakerObj, func() *MessagePrunerConfig { return &configFetcher.Get().MessagePruner }) - } // always create DelayedSequencer, it won't do anything if it is disabled delayedSequencer, err = NewDelayedSequencer(l1Reader, inboxReader, exec.ExecEngine, coordinator, func() *DelayedSequencerConfig { return &configFetcher.Get().DelayedSequencer }) if err != nil { @@ -1149,6 +1147,7 @@ func (n *Node) StopAndWait() { if n.StatelessBlockValidator != nil { n.StatelessBlockValidator.Stop() } + n.Execution.Recorder.OrderlyShutdown() if n.InboxReader != nil && n.InboxReader.Started() { n.InboxReader.StopAndWait() } @@ -1176,19 +1175,3 @@ func (n *Node) StopAndWait() { log.Error("error on stak close", 
"err", err) } } - -func CreateDefaultStackForTest(dataDir string) (*node.Node, error) { - stackConf := node.DefaultConfig - var err error - stackConf.DataDir = dataDir - stackConf.HTTPHost = "" - stackConf.HTTPModules = append(stackConf.HTTPModules, "eth") - stackConf.P2P.NoDiscovery = true - stackConf.P2P.ListenAddr = "" - - stack, err := node.New(&stackConf) - if err != nil { - return nil, fmt.Errorf("error creating protocol stack: %w", err) - } - return stack, nil -} diff --git a/arbnode/resourcemanager/resource_management.go b/arbnode/resourcemanager/resource_management.go new file mode 100644 index 0000000000..acb5355987 --- /dev/null +++ b/arbnode/resourcemanager/resource_management.go @@ -0,0 +1,207 @@ +// Copyright 2023, Offchain Labs, Inc. +// For license information, see https://github.com/nitro/blob/master/LICENSE + +package resourcemanager + +import ( + "bufio" + "errors" + "fmt" + "net/http" + "os" + "regexp" + "strconv" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/node" + "github.com/spf13/pflag" +) + +var ( + limitCheckDurationHistogram = metrics.NewRegisteredHistogram("arb/rpc/limitcheck/duration", nil, metrics.NewBoundedHistogramSample()) + limitCheckSuccessCounter = metrics.NewRegisteredCounter("arb/rpc/limitcheck/success", nil) + limitCheckFailureCounter = metrics.NewRegisteredCounter("arb/rpc/limitcheck/failure", nil) +) + +// Init adds the resource manager's httpServer to a custom hook in geth. +// Geth will add it to the stack of http.Handlers so that it is run +// prior to RPC request handling. +// +// Must be run before the go-ethereum stack is set up (ethereum/go-ethereum/node.New). +func Init(conf *Config) { + if conf.MemoryLimitPercent > 0 { + node.WrapHTTPHandler = func(srv http.Handler) (http.Handler, error) { + return newHttpServer(srv, newLimitChecker(conf)), nil + } + } +} + +// Config contains the configuration for resourcemanager functionality. +// Currently only a memory limit is supported, other limits may be added +// in the future. +type Config struct { + MemoryLimitPercent int `koanf:"mem-limit-percent" reload:"hot"` +} + +// DefaultConfig has the defaul resourcemanager configuration, +// all limits are disabled. +var DefaultConfig = Config{ + MemoryLimitPercent: 0, +} + +// ConfigAddOptions adds the configuration options for resourcemanager. +func ConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.Int(prefix+".mem-limit-percent", DefaultConfig.MemoryLimitPercent, "RPC calls are throttled if system memory utilization exceeds this percent value, zero (default) is disabled") +} + +// httpServer implements http.Handler and wraps calls to inner with a resource +// limit check. +type httpServer struct { + inner http.Handler + c limitChecker +} + +func newHttpServer(inner http.Handler, c limitChecker) *httpServer { + return &httpServer{inner: inner, c: c} +} + +// ServeHTTP passes req to inner unless any configured system resource +// limit is exceeded, in which case it returns a HTTP 429 error. 
+func (s *httpServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + start := time.Now() + exceeded, err := s.c.isLimitExceeded() + limitCheckDurationHistogram.Update(time.Since(start).Nanoseconds()) + if err != nil { + log.Error("Error checking memory limit", "err", err, "checker", s.c) + } else if exceeded { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + limitCheckFailureCounter.Inc(1) + return + } + + limitCheckSuccessCounter.Inc(1) + s.inner.ServeHTTP(w, req) +} + +type limitChecker interface { + isLimitExceeded() (bool, error) + String() string +} + +// newLimitChecker attempts to auto-discover the mechanism by which it +// can check system limits. Currently Cgroups V1 is supported, +// with Cgroups V2 likely to be implmemented next. If no supported +// mechanism is discovered, it logs an error and fails open, ie +// it creates a trivialLimitChecker that does no checks. +func newLimitChecker(conf *Config) limitChecker { + c := newCgroupsV1MemoryLimitChecker(DefaultCgroupsV1MemoryDirectory, conf.MemoryLimitPercent) + if isSupported(c) { + log.Info("Cgroups v1 detected, enabling memory limit RPC throttling") + return c + } + + log.Error("No method for determining memory usage and limits was discovered, disabled memory limit RPC throttling") + return &trivialLimitChecker{} +} + +// trivialLimitChecker checks no limits, so its limits are never exceeded. +type trivialLimitChecker struct{} + +func (_ trivialLimitChecker) isLimitExceeded() (bool, error) { + return false, nil +} + +func (_ trivialLimitChecker) String() string { return "trivial" } + +const DefaultCgroupsV1MemoryDirectory = "/sys/fs/cgroup/memory/" + +type cgroupsV1MemoryLimitChecker struct { + cgroupDir string + memoryLimitPercent int + + limitFile, usageFile, statsFile string +} + +func newCgroupsV1MemoryLimitChecker(cgroupDir string, memoryLimitPercent int) *cgroupsV1MemoryLimitChecker { + return &cgroupsV1MemoryLimitChecker{ + cgroupDir: cgroupDir, + memoryLimitPercent: memoryLimitPercent, + limitFile: cgroupDir + "/memory.limit_in_bytes", + usageFile: cgroupDir + "/memory.usage_in_bytes", + statsFile: cgroupDir + "/memory.stat", + } +} + +func isSupported(c limitChecker) bool { + _, err := c.isLimitExceeded() + return err == nil +} + +// isLimitExceeded checks if the system memory used exceeds the limit +// scaled by the configured memoryLimitPercent. 
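// Editor's note (not part of the diff): as a worked example with illustrative
// numbers, a cgroup with memory.limit_in_bytes = 8 GiB and mem-limit-percent = 95
// throttles once usage_in_bytes - total_inactive_file reaches 7.6 GiB; with
// usage at 7.8 GiB and 0.5 GiB of inactive file cache the working set is
// 7.3 GiB, so requests are still served.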
+// +// See the following page for details of calculating the memory used, +// which is reported as container_memory_working_set_bytes in prometheus: +// https://mihai-albert.com/2022/02/13/out-of-memory-oom-in-kubernetes-part-3-memory-metrics-sources-and-tools-to-collect-them/ +func (c *cgroupsV1MemoryLimitChecker) isLimitExceeded() (bool, error) { + var limit, usage, inactive int + var err error + limit, err = readIntFromFile(c.limitFile) + if err != nil { + return false, err + } + usage, err = readIntFromFile(c.usageFile) + if err != nil { + return false, err + } + inactive, err = readInactive(c.statsFile) + if err != nil { + return false, err + } + return usage-inactive >= ((limit * c.memoryLimitPercent) / 100), nil +} + +func readIntFromFile(fileName string) (int, error) { + file, err := os.Open(fileName) + if err != nil { + return 0, err + } + + var limit int + if _, err = fmt.Fscanf(file, "%d", &limit); err != nil { + return 0, err + } + return limit, nil +} + +var re = regexp.MustCompile(`total_inactive_file (\d+)`) + +func readInactive(fileName string) (int, error) { + file, err := os.Open(fileName) + if err != nil { + return 0, err + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + + matches := re.FindStringSubmatch(line) + + if len(matches) >= 2 { + inactive, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, err + } + return inactive, nil + } + } + + return 0, errors.New("total_inactive_file not found in " + fileName) +} + +func (c cgroupsV1MemoryLimitChecker) String() string { + return "CgroupsV1MemoryLimitChecker" +} diff --git a/arbnode/resourcemanager/resource_management_test.go b/arbnode/resourcemanager/resource_management_test.go new file mode 100644 index 0000000000..fe470e706b --- /dev/null +++ b/arbnode/resourcemanager/resource_management_test.go @@ -0,0 +1,78 @@ +// Copyright 2023, Offchain Labs, Inc. 
+// For license information, see https://github.com/nitro/blob/master/LICENSE + +package resourcemanager + +import ( + "fmt" + "os" + "testing" +) + +func updateFakeCgroupv1Files(c *cgroupsV1MemoryLimitChecker, limit, usage, inactive int) error { + limitFile, err := os.Create(c.limitFile) + if err != nil { + return err + } + _, err = fmt.Fprintf(limitFile, "%d\n", limit) + if err != nil { + return err + } + + usageFile, err := os.Create(c.usageFile) + if err != nil { + return err + } + _, err = fmt.Fprintf(usageFile, "%d\n", usage) + if err != nil { + return err + } + + statsFile, err := os.Create(c.statsFile) + if err != nil { + return err + } + _, err = fmt.Fprintf(statsFile, `total_cache 1029980160 +total_rss 1016209408 +total_inactive_file %d +total_active_file 321544192 +`, inactive) + if err != nil { + return err + } + return nil +} + +func TestCgroupsv1MemoryLimit(t *testing.T) { + cgroupDir := t.TempDir() + c := newCgroupsV1MemoryLimitChecker(cgroupDir, 95) + _, err := c.isLimitExceeded() + if err == nil { + t.Error("Should fail open if can't read files") + } + + err = updateFakeCgroupv1Files(c, 1000, 1000, 51) + if err != nil { + t.Error(err) + } + exceeded, err := c.isLimitExceeded() + if err != nil { + t.Error(err) + } + if exceeded { + t.Error("Expected under limit") + } + + err = updateFakeCgroupv1Files(c, 1000, 1000, 50) + if err != nil { + t.Error(err) + } + exceeded, err = c.isLimitExceeded() + if err != nil { + t.Error(err) + } + if !exceeded { + t.Error("Expected over limit") + } + +} diff --git a/arbnode/seq_coordinator.go b/arbnode/seq_coordinator.go index ecb38129ac..31cab83b1f 100644 --- a/arbnode/seq_coordinator.go +++ b/arbnode/seq_coordinator.go @@ -619,12 +619,12 @@ func (c *SeqCoordinator) update(ctx context.Context) time.Duration { log.Error("myurl main sequencer, but no sequencer exists") return c.noRedisError() } - processedMessages, err := c.streamer.exec.HeadMessageNumber() + processedMessages, err := c.streamer.GetProcessedMessageCount() if err != nil { log.Warn("coordinator: failed to read processed message count", "err", err) processedMessages = 0 } - if processedMessages+1 >= localMsgCount { + if processedMessages >= localMsgCount { // we're here because we don't currently hold the lock // sequencer is already either paused or forwarding c.sequencer.Pause() diff --git a/arbnode/sync_monitor.go b/arbnode/sync_monitor.go index a9380e65f6..d01c300fa9 100644 --- a/arbnode/sync_monitor.go +++ b/arbnode/sync_monitor.go @@ -70,13 +70,8 @@ func (s *SyncMonitor) SyncProgressMap() map[string]interface{} { syncing = true builtMessageCount = 0 } else { - blockNum, err := s.txStreamer.exec.MessageCountToBlockNumber(builtMessageCount) - if err != nil { - res["blockBuiltErr"] = err - syncing = true - } else { - res["blockNum"] = blockNum - } + blockNum := s.txStreamer.exec.MessageIndexToBlockNumber(builtMessageCount) + res["blockNum"] = blockNum builtMessageCount++ res["messageOfLastBlock"] = builtMessageCount } @@ -155,8 +150,8 @@ func (s *SyncMonitor) SafeBlockNumber(ctx context.Context) (uint64, error) { if err != nil { return 0, err } - block, err := s.txStreamer.exec.MessageCountToBlockNumber(msg) - return uint64(block), err + block := s.txStreamer.exec.MessageIndexToBlockNumber(msg - 1) + return block, nil } func (s *SyncMonitor) FinalizedBlockNumber(ctx context.Context) (uint64, error) { @@ -167,8 +162,8 @@ func (s *SyncMonitor) FinalizedBlockNumber(ctx context.Context) (uint64, error) if err != nil { return 0, err } - block, err := 
s.txStreamer.exec.MessageCountToBlockNumber(msg) - return uint64(block), err + block := s.txStreamer.exec.MessageIndexToBlockNumber(msg - 1) + return block, nil } func (s *SyncMonitor) Synced() bool { diff --git a/arbnode/transaction_streamer.go b/arbnode/transaction_streamer.go index a6a11b0b84..7752c69e8e 100644 --- a/arbnode/transaction_streamer.go +++ b/arbnode/transaction_streamer.go @@ -46,6 +46,7 @@ type TransactionStreamer struct { chainConfig *params.ChainConfig exec *execution.ExecutionEngine execLastMsgCount arbutil.MessageIndex + validator *staker.BlockValidator db ethdb.Database fatalErrChan chan<- error @@ -128,7 +129,13 @@ func uint64ToKey(x uint64) []byte { } func (s *TransactionStreamer) SetBlockValidator(validator *staker.BlockValidator) { - s.exec.SetBlockValidator(validator) + if s.Started() { + panic("trying to set coordinator after start") + } + if s.validator != nil { + panic("trying to set coordinator when already set") + } + s.validator = validator } func (s *TransactionStreamer) SetSeqCoordinator(coordinator *SeqCoordinator) { @@ -199,19 +206,24 @@ func deleteStartingAt(db ethdb.Database, batch ethdb.Batch, prefix []byte, minKe } // deleteFromRange deletes key ranging from startMinKey(inclusive) to endMinKey(exclusive) -func deleteFromRange(db ethdb.Database, prefix []byte, startMinKey uint64, endMinKey uint64) ([][]byte, error) { +// might have deleted some keys even if returning an error +func deleteFromRange(ctx context.Context, db ethdb.Database, prefix []byte, startMinKey uint64, endMinKey uint64) ([]uint64, error) { batch := db.NewBatch() startIter := db.NewIterator(prefix, uint64ToKey(startMinKey)) defer startIter.Release() - var prunedKeysRange [][]byte + var prunedKeysRange []uint64 for startIter.Next() { - if binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) >= endMinKey { + if ctx.Err() != nil { + return nil, ctx.Err() + } + currentKey := binary.BigEndian.Uint64(bytes.TrimPrefix(startIter.Key(), prefix)) + if currentKey >= endMinKey { break } if len(prunedKeysRange) == 0 || len(prunedKeysRange) == 1 { - prunedKeysRange = append(prunedKeysRange, startIter.Key()) + prunedKeysRange = append(prunedKeysRange, currentKey) } else { - prunedKeysRange[1] = startIter.Key() + prunedKeysRange[1] = currentKey } err := batch.Delete(startIter.Key()) if err != nil { @@ -328,6 +340,13 @@ func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageInde return err } + if s.validator != nil { + err = s.validator.Reorg(s.GetContext(), count) + if err != nil { + return err + } + } + err = deleteStartingAt(s.db, batch, messagePrefix, uint64ToKey(uint64(count))) if err != nil { return err @@ -387,6 +406,21 @@ func (s *TransactionStreamer) GetMessageCount() (arbutil.MessageIndex, error) { return arbutil.MessageIndex(pos), nil } +func (s *TransactionStreamer) GetProcessedMessageCount() (arbutil.MessageIndex, error) { + msgCount, err := s.GetMessageCount() + if err != nil { + return 0, err + } + digestedHead, err := s.exec.HeadMessageNumber() + if err != nil { + return 0, err + } + if msgCount > digestedHead+1 { + return digestedHead + 1, nil + } + return msgCount, nil +} + func (s *TransactionStreamer) AddMessages(pos arbutil.MessageIndex, messagesAreConfirmed bool, messages []arbostypes.MessageWithMetadata) error { return s.AddMessagesAndEndBatch(pos, messagesAreConfirmed, messages, nil) } @@ -883,6 +917,14 @@ func (s *TransactionStreamer) writeMessages(pos arbutil.MessageIndex, messages [ return nil } +// TODO: eventually there will be a table 
maintained by txStreamer itself +func (s *TransactionStreamer) ResultAtCount(count arbutil.MessageIndex) (*execution.MessageResult, error) { + if count == 0 { + return &execution.MessageResult{}, nil + } + return s.exec.ResultAtPos(count - 1) +} + // return value: true if should be called again immediately func (s *TransactionStreamer) executeNextMsg(ctx context.Context, exec *execution.ExecutionEngine) bool { if ctx.Err() != nil { diff --git a/cmd/nitro/init.go b/cmd/nitro/init.go index 2284695961..fcc27287d0 100644 --- a/cmd/nitro/init.go +++ b/cmd/nitro/init.go @@ -58,6 +58,7 @@ type InitConfig struct { ThenQuit bool `koanf:"then-quit"` Prune string `koanf:"prune"` PruneBloomSize uint64 `koanf:"prune-bloom-size"` + ResetToMsg int64 `koanf:"reset-to-message"` } var InitConfigDefault = InitConfig{ @@ -73,6 +74,7 @@ var InitConfigDefault = InitConfig{ ThenQuit: false, Prune: "", PruneBloomSize: 2048, + ResetToMsg: -1, } func InitConfigAddOptions(prefix string, f *flag.FlagSet) { @@ -89,6 +91,7 @@ func InitConfigAddOptions(prefix string, f *flag.FlagSet) { f.Uint(prefix+".accounts-per-sync", InitConfigDefault.AccountsPerSync, "during init - sync database every X accounts. Lower value for low-memory systems. 0 disables.") f.String(prefix+".prune", InitConfigDefault.Prune, "pruning for a given use: \"full\" for full nodes serving RPC requests, or \"validator\" for validators") f.Uint64(prefix+".prune-bloom-size", InitConfigDefault.PruneBloomSize, "the amount of memory in megabytes to use for the pruning bloom filter (higher values prune better)") + f.Int64(prefix+".reset-to-message", InitConfigDefault.ResetToMsg, "forces a reset to an old message height. Also set max-reorg-resequence-depth=0 to force re-reading messages") } func downloadInit(ctx context.Context, initConfig *InitConfig) (string, error) { @@ -327,19 +330,23 @@ func findImportantRoots(ctx context.Context, chainDb ethdb.Database, stack *node } validatorDb := rawdb.NewTable(arbDb, arbnode.BlockValidatorPrefix) - lastValidated, err := staker.ReadLastValidatedFromDb(validatorDb) + lastValidated, err := staker.ReadLastValidatedInfo(validatorDb) if err != nil { return nil, err } if lastValidated != nil { - lastValidatedHeader := rawdb.ReadHeader(chainDb, lastValidated.BlockHash, lastValidated.BlockNumber) + var lastValidatedHeader *types.Header + headerNum := rawdb.ReadHeaderNumber(chainDb, lastValidated.GlobalState.BlockHash) + if headerNum != nil { + lastValidatedHeader = rawdb.ReadHeader(chainDb, lastValidated.GlobalState.BlockHash, *headerNum) + } if lastValidatedHeader != nil { err = roots.addHeader(lastValidatedHeader, false) if err != nil { return nil, err } } else { - log.Warn("missing latest validated block", "number", lastValidated.BlockNumber, "hash", lastValidated.BlockHash) + log.Warn("missing latest validated block", "hash", lastValidated.GlobalState.BlockHash) } } } else if initConfig.Prune == "full" { diff --git a/cmd/nitro/nitro.go b/cmd/nitro/nitro.go index 0035171078..09420d8c61 100644 --- a/cmd/nitro/nitro.go +++ b/cmd/nitro/nitro.go @@ -38,6 +38,8 @@ import ( "github.com/offchainlabs/nitro/arbnode" "github.com/offchainlabs/nitro/arbnode/execution" + "github.com/offchainlabs/nitro/arbnode/resourcemanager" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/cmd/chaininfo" "github.com/offchainlabs/nitro/cmd/conf" "github.com/offchainlabs/nitro/cmd/genericconf" @@ -321,6 +323,8 @@ func mainImpl() int { nodeConfig.Node.TxLookupLimit = 0 } + resourcemanager.Init(&nodeConfig.Node.ResourceManagement) 
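// [Editor's sketch, not part of this patch] The Init doc comment earlier in this patch
// says it must run before node.New, as done just above; this sketch illustrates that
// kind of ordering constraint with hypothetical stand-ins. wrapHTTPHandler and
// buildStack below stand in for the geth hook and stack construction; they are not
// real APIs.
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

var wrapHTTPHandler func(http.Handler) (http.Handler, error)

// buildStack mimics construction-time wrapping: the hook is read exactly once, here.
func buildStack(inner http.Handler) (http.Handler, error) {
	if wrapHTTPHandler != nil {
		return wrapHTTPHandler(inner)
	}
	return inner, nil
}

func main() {
	// 1) Install the hook first (what Init does when mem-limit-percent > 0)...
	wrapHTTPHandler = func(srv http.Handler) (http.Handler, error) {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-Wrapped", "true")
			srv.ServeHTTP(w, r)
		}), nil
	}
	// 2) ...then build the stack; installing the hook afterwards would have no effect.
	h, _ := buildStack(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintln(w, "ok") }))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
	fmt.Println(rec.Header().Get("X-Wrapped")) // "true"
}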
+ stack, err := node.New(&stackConf) if err != nil { flag.Usage() @@ -356,7 +360,7 @@ func mainImpl() int { return 1 } - if nodeConfig.Init.ThenQuit { + if nodeConfig.Init.ThenQuit && nodeConfig.Init.ResetToMsg < 0 { return 0 } @@ -459,12 +463,27 @@ func mainImpl() int { if err != nil { fatalErrChan <- fmt.Errorf("error starting node: %w", err) } + defer currentNode.StopAndWait() } sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) exitCode := 0 + + if err == nil && nodeConfig.Init.ResetToMsg > 0 { + err = currentNode.TxStreamer.ReorgTo(arbutil.MessageIndex(nodeConfig.Init.ResetToMsg)) + if err != nil { + fatalErrChan <- fmt.Errorf("error reseting message: %w", err) + exitCode = 1 + } + if nodeConfig.Init.ThenQuit { + close(sigint) + + return exitCode + } + } + select { case err := <-fatalErrChan: log.Error("shutting down due to fatal error", "err", err) @@ -477,8 +496,6 @@ func mainImpl() int { // cause future ctrl+c's to panic close(sigint) - currentNode.StopAndWait() - return exitCode } diff --git a/go-ethereum b/go-ethereum index 8e6a8ad494..fcda31cae2 160000 --- a/go-ethereum +++ b/go-ethereum @@ -1 +1 @@ -Subproject commit 8e6a8ad4942591011e833e6ebceca6bd668f3db0 +Subproject commit fcda31cae2d5a29699c5c5e58038a72ce0d196ac diff --git a/staker/block_challenge_backend.go b/staker/block_challenge_backend.go index b0c3bb8655..42351789ba 100644 --- a/staker/block_challenge_backend.go +++ b/staker/block_challenge_backend.go @@ -11,7 +11,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/offchainlabs/nitro/arbutil" @@ -20,14 +19,13 @@ import ( ) type BlockChallengeBackend struct { - bc *core.BlockChain - startBlock int64 + streamer TransactionStreamerInterface + startMsgCount arbutil.MessageIndex startPosition uint64 endPosition uint64 startGs validator.GoGlobalState endGs validator.GoGlobalState inboxTracker InboxTrackerInterface - genesisBlockNumber uint64 tooFarStartsAtPosition uint64 } @@ -37,19 +35,10 @@ var _ ChallengeBackend = (*BlockChallengeBackend)(nil) func NewBlockChallengeBackend( initialState *challengegen.ChallengeManagerInitiatedChallenge, maxBatchesRead uint64, - bc *core.BlockChain, + streamer TransactionStreamerInterface, inboxTracker InboxTrackerInterface, - genesisBlockNumber uint64, ) (*BlockChallengeBackend, error) { startGs := validator.GoGlobalStateFromSolidity(initialState.StartState) - startBlockNum := arbutil.MessageCountToBlockNumber(0, genesisBlockNumber) - if startGs.BlockHash != (common.Hash{}) { - startBlock := bc.GetBlockByHash(startGs.BlockHash) - if startBlock == nil { - return nil, fmt.Errorf("failed to find start block %v", startGs.BlockHash) - } - startBlockNum = int64(startBlock.NumberU64()) - } var startMsgCount arbutil.MessageIndex if startGs.Batch > 0 { @@ -60,10 +49,6 @@ func NewBlockChallengeBackend( } } startMsgCount += arbutil.MessageIndex(startGs.PosInBatch) - expectedMsgCount := arbutil.SignedBlockNumberToMessageCount(startBlockNum, genesisBlockNumber) - if startMsgCount != expectedMsgCount { - return nil, fmt.Errorf("start block %v and start message count %v don't correspond", startBlockNum, startMsgCount) - } var endMsgCount arbutil.MessageIndex if maxBatchesRead > 0 { @@ -75,19 +60,18 @@ func NewBlockChallengeBackend( } return &BlockChallengeBackend{ - bc: bc, - startBlock: startBlockNum, + streamer: streamer, + 
startMsgCount: startMsgCount, startGs: startGs, startPosition: 0, endPosition: math.MaxUint64, endGs: validator.GoGlobalStateFromSolidity(initialState.EndState), inboxTracker: inboxTracker, - genesisBlockNumber: genesisBlockNumber, tooFarStartsAtPosition: uint64(endMsgCount - startMsgCount + 1), }, nil } -func (b *BlockChallengeBackend) findBatchFromMessageIndex(msgCount arbutil.MessageIndex) (uint64, error) { +func (b *BlockChallengeBackend) findBatchAfterMessageCount(msgCount arbutil.MessageIndex) (uint64, error) { if msgCount == 0 { return 0, nil } @@ -118,54 +102,46 @@ func (b *BlockChallengeBackend) findBatchFromMessageIndex(msgCount arbutil.Messa } } -func (b *BlockChallengeBackend) FindGlobalStateFromHeader(header *types.Header) (validator.GoGlobalState, error) { - if header == nil { - return validator.GoGlobalState{}, nil - } - msgCount := arbutil.BlockNumberToMessageCount(header.Number.Uint64(), b.genesisBlockNumber) - batch, err := b.findBatchFromMessageIndex(msgCount) +func (b *BlockChallengeBackend) FindGlobalStateFromMessageCount(count arbutil.MessageIndex) (validator.GoGlobalState, error) { + batch, err := b.findBatchAfterMessageCount(count) if err != nil { return validator.GoGlobalState{}, err } - var batchMsgCount arbutil.MessageIndex + var prevBatchMsgCount arbutil.MessageIndex if batch > 0 { - batchMsgCount, err = b.inboxTracker.GetBatchMessageCount(batch - 1) + prevBatchMsgCount, err = b.inboxTracker.GetBatchMessageCount(batch - 1) if err != nil { return validator.GoGlobalState{}, err } - if batchMsgCount > msgCount { + if prevBatchMsgCount > count { return validator.GoGlobalState{}, errors.New("findBatchFromMessageCount returned bad batch") } } - extraInfo := types.DeserializeHeaderExtraInformation(header) + res, err := b.streamer.ResultAtCount(count) + if err != nil { + return validator.GoGlobalState{}, err + } return validator.GoGlobalState{ - BlockHash: header.Hash(), - SendRoot: extraInfo.SendRoot, + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, Batch: batch, - PosInBatch: uint64(msgCount - batchMsgCount), + PosInBatch: uint64(count - prevBatchMsgCount), }, nil } const StatusFinished uint8 = 1 const StatusTooFar uint8 = 3 -func (b *BlockChallengeBackend) GetBlockNrAtStep(step uint64) int64 { - return b.startBlock + int64(step) +func (b *BlockChallengeBackend) GetMessageCountAtStep(step uint64) arbutil.MessageIndex { + return b.startMsgCount + arbutil.MessageIndex(step) } func (b *BlockChallengeBackend) GetInfoAtStep(step uint64) (validator.GoGlobalState, uint8, error) { - blockNum := b.GetBlockNrAtStep(step) + msgNum := b.GetMessageCountAtStep(step) if step >= b.tooFarStartsAtPosition { return validator.GoGlobalState{}, StatusTooFar, nil } - var header *types.Header - if blockNum != -1 { - header = b.bc.GetHeaderByNumber(uint64(blockNum)) - if header == nil { - return validator.GoGlobalState{}, 0, fmt.Errorf("failed to get block %v in block challenge", blockNum) - } - } - globalState, err := b.FindGlobalStateFromHeader(header) + globalState, err := b.FindGlobalStateFromMessageCount(msgNum) if err != nil { return validator.GoGlobalState{}, 0, err } diff --git a/staker/block_validator.go b/staker/block_validator.go index 56bb2729c7..9096324cff 100644 --- a/staker/block_validator.go +++ b/staker/block_validator.go @@ -9,72 +9,69 @@ import ( "fmt" "sync" "sync/atomic" + "testing" "time" flag "github.com/spf13/pflag" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" 
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" - "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/util/stopwaiter" "github.com/offchainlabs/nitro/validator" ) var ( - validatorPendingValidationsGauge = metrics.NewRegisteredGauge("arb/validator/validations/pending", nil) - validatorValidValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/valid", nil) - validatorFailedValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/failed", nil) - validatorLastBlockInLastBatchGauge = metrics.NewRegisteredGauge("arb/validator/last_block_in_last_batch", nil) - validatorLastBlockValidatedGauge = metrics.NewRegisteredGauge("arb/validator/last_block_validated", nil) + validatorPendingValidationsGauge = metrics.NewRegisteredGauge("arb/validator/validations/pending", nil) + validatorValidValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/valid", nil) + validatorFailedValidationsCounter = metrics.NewRegisteredCounter("arb/validator/validations/failed", nil) + validatorMsgCountCurrentBatch = metrics.NewRegisteredGauge("arb/validator/msg_count_current_batch", nil) + validatorMsgCountValidatedGauge = metrics.NewRegisteredGauge("arb/validator/msg_count_validated", nil) ) type BlockValidator struct { stopwaiter.StopWaiter *StatelessBlockValidator - validations sync.Map - sequencerBatches sync.Map + reorgMutex sync.RWMutex - // acquiring multiple Mutexes must be done in order: - reorgMutex sync.Mutex - batchMutex sync.Mutex - blockMutex sync.Mutex - lastBlockValidatedMutex sync.Mutex + chainCaughtUp bool - reorgsPending int32 // atomic + // can only be accessed from creation thread or if holding reorg-write + nextCreateBatch []byte + nextCreateBatchMsgCount arbutil.MessageIndex + nextCreateBatchReread bool + nextCreateStartGS validator.GoGlobalState + nextCreatePrevDelayed uint64 - earliestBatchKept uint64 // atomic - nextBatchKept uint64 // behind batchMutex, 1 + the last batch number kept + // can only be accessed from from validation thread or if holding reorg-write + lastValidGS validator.GoGlobalState + valLoopPos arbutil.MessageIndex + legacyValidInfo *legacyLastBlockValidatedDbInfo - // protected by reorgMutex - globalPosNextSend GlobalStatePosition + // only from logger thread + lastValidInfoPrinted *GlobalStateValidatedInfo - // protected by BlockMutex: - nextBlockToValidate uint64 - lastValidationEntryBlock uint64 - - // behind lastBlockValidatedMutex - lastBlockValidatedUnknown bool - lastBlockValidated uint64 // also atomic - lastBlockValidatedHash common.Hash + // can be read (atomic.Load) by anyone holding reorg-read + // written (atomic.Set) by appropriate thread or (any way) holding reorg-write + createdA uint64 + recordSentA uint64 + validatedA uint64 + validations containers.SyncMap[arbutil.MessageIndex, *validationStatus] config BlockValidatorConfigFetcher - sendValidationsChan chan struct{} - progressChan chan uint64 - - lastHeaderForPrepareState *types.Header + createNodesChan chan struct{} + sendRecordChan chan struct{} + progressValidationsChan chan struct{} - // recentValid holds one recently valid header, to commit it to DB on shutdown - recentValidMutex sync.Mutex - awaitingValidation *types.Header - validHeader *types.Header + // for testing only + testingProgressMadeChan chan 
struct{} fatalErr chan<- error } @@ -148,13 +145,12 @@ var DefaultBlockValidatorDangerousConfig = BlockValidatorDangerousConfig{ type valStatusField uint32 const ( - Unprepared valStatusField = iota + Created valStatusField = iota RecordSent RecordFailed Prepared + SendingValidation ValidationSent - Failed - Valid ) type validationStatus struct { @@ -164,10 +160,6 @@ type validationStatus struct { Runs []validator.ValidationRun // if status >= ValidationSent } -func (s *validationStatus) setStatus(val valStatusField) { - atomic.StoreUint32(&s.Status, uint32(val)) -} - func (s *validationStatus) getStatus() valStatusField { uintStat := atomic.LoadUint32(&s.Status) return valStatusField(uintStat) @@ -181,24 +173,72 @@ func NewBlockValidator( statelessBlockValidator *StatelessBlockValidator, inbox InboxTrackerInterface, streamer TransactionStreamerInterface, - reorgingToBlock *types.Block, config BlockValidatorConfigFetcher, fatalErr chan<- error, ) (*BlockValidator, error) { - validator := &BlockValidator{ + ret := &BlockValidator{ StatelessBlockValidator: statelessBlockValidator, - sendValidationsChan: make(chan struct{}, 1), - progressChan: make(chan uint64, 1), + createNodesChan: make(chan struct{}, 1), + sendRecordChan: make(chan struct{}, 1), + progressValidationsChan: make(chan struct{}, 1), config: config, fatalErr: fatalErr, } - err := validator.readLastBlockValidatedDbInfo(reorgingToBlock) - if err != nil { - return nil, err + if !config().Dangerous.ResetBlockValidation { + validated, err := ret.ReadLastValidatedInfo() + if err != nil { + return nil, err + } + if validated != nil { + ret.lastValidGS = validated.GlobalState + } else { + legacyInfo, err := ret.legacyReadLastValidatedInfo() + if err != nil { + return nil, err + } + ret.legacyValidInfo = legacyInfo + } + } + // genesis block is impossible to validate unless genesis state is empty + if ret.lastValidGS.Batch == 0 && ret.legacyValidInfo == nil { + genesis, err := streamer.ResultAtCount(1) + if err != nil { + return nil, err + } + ret.lastValidGS = validator.GoGlobalState{ + BlockHash: genesis.BlockHash, + SendRoot: genesis.SendRoot, + Batch: 1, + PosInBatch: 0, + } } - streamer.SetBlockValidator(validator) - inbox.SetBlockValidator(validator) - return validator, nil + streamer.SetBlockValidator(ret) + inbox.SetBlockValidator(ret) + return ret, nil +} + +func atomicStorePos(addr *uint64, val arbutil.MessageIndex) { + atomic.StoreUint64(addr, uint64(val)) +} + +func atomicLoadPos(addr *uint64) arbutil.MessageIndex { + return arbutil.MessageIndex(atomic.LoadUint64(addr)) +} + +func (v *BlockValidator) created() arbutil.MessageIndex { + return atomicLoadPos(&v.createdA) +} + +func (v *BlockValidator) recordSent() arbutil.MessageIndex { + return atomicLoadPos(&v.recordSentA) +} + +func (v *BlockValidator) validated() arbutil.MessageIndex { + return atomicLoadPos(&v.validatedA) +} + +func (v *BlockValidator) Validated(t *testing.T) arbutil.MessageIndex { + return v.validated() } func (v *BlockValidator) possiblyFatal(err error) { @@ -217,161 +257,118 @@ func (v *BlockValidator) possiblyFatal(err error) { } } -func (v *BlockValidator) triggerSendValidations() { +func nonBlockingTrigger(channel chan struct{}) { select { - case v.sendValidationsChan <- struct{}{}: + case channel <- struct{}{}: default: } } -func (v *BlockValidator) recentlyValid(header *types.Header) { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.awaitingValidation == nil { - return - } - if v.awaitingValidation.Number.Cmp(header.Number) > 
0 { - return +// called from NewBlockValidator, doesn't need to catch locks +func ReadLastValidatedInfo(db ethdb.Database) (*GlobalStateValidatedInfo, error) { + exists, err := db.Has(lastGlobalStateValidatedInfoKey) + if err != nil { + return nil, err } - if v.validHeader != nil { - v.recordingDatabase.Dereference(v.validHeader) + var validated GlobalStateValidatedInfo + if !exists { + return nil, nil } - v.validHeader = v.awaitingValidation - v.awaitingValidation = nil -} - -func (v *BlockValidator) recentStateComputed(header *types.Header) { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.awaitingValidation != nil { - return + gsBytes, err := db.Get(lastGlobalStateValidatedInfoKey) + if err != nil { + return nil, err } - _, err := v.recordingDatabase.StateFor(header) + err = rlp.DecodeBytes(gsBytes, &validated) if err != nil { - log.Error("failed to get state for block while validating", "err", err, "blockNum", header.Number, "hash", header.Hash()) - return + return nil, err } - v.awaitingValidation = header + return &validated, nil } -func (v *BlockValidator) recentShutdown() error { - v.recentValidMutex.Lock() - defer v.recentValidMutex.Unlock() - if v.validHeader == nil { - return nil - } - err := v.recordingDatabase.WriteStateToDatabase(v.validHeader) - v.recordingDatabase.Dereference(v.validHeader) - return err +func (v *BlockValidator) ReadLastValidatedInfo() (*GlobalStateValidatedInfo, error) { + return ReadLastValidatedInfo(v.db) } -func ReadLastValidatedFromDb(db ethdb.Database) (*LastBlockValidatedDbInfo, error) { - exists, err := db.Has(lastBlockValidatedInfoKey) +func (v *BlockValidator) legacyReadLastValidatedInfo() (*legacyLastBlockValidatedDbInfo, error) { + exists, err := v.db.Has(legacyLastBlockValidatedInfoKey) if err != nil { return nil, err } + var validated legacyLastBlockValidatedDbInfo if !exists { return nil, nil } - - infoBytes, err := db.Get(lastBlockValidatedInfoKey) + gsBytes, err := v.db.Get(legacyLastBlockValidatedInfoKey) if err != nil { return nil, err } - - var info LastBlockValidatedDbInfo - err = rlp.DecodeBytes(infoBytes, &info) + err = rlp.DecodeBytes(gsBytes, &validated) if err != nil { return nil, err } - return &info, nil + return &validated, nil } -// only called by NewBlockValidator -func (v *BlockValidator) readLastBlockValidatedDbInfo(reorgingToBlock *types.Block) error { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() +var ErrGlobalStateNotInChain = errors.New("globalstate not in chain") - exists, err := v.db.Has(lastBlockValidatedInfoKey) +// false if chain not caught up to globalstate +// error is ErrGlobalStateNotInChain if globalstate not in chain (and chain caught up) +func GlobalStateToMsgCount(tracker InboxTrackerInterface, streamer TransactionStreamerInterface, gs validator.GoGlobalState) (bool, arbutil.MessageIndex, error) { + batchCount, err := tracker.GetBatchCount() if err != nil { - return err + return false, 0, err } - - if !exists || v.config().Dangerous.ResetBlockValidation { - // The db contains no validation info; start from the beginning. - // TODO: this skips validating the genesis block. 
- atomic.StoreUint64(&v.lastBlockValidated, v.genesisBlockNum) - validatorLastBlockValidatedGauge.Update(int64(v.genesisBlockNum)) - genesisBlock := v.blockchain.GetBlockByNumber(v.genesisBlockNum) - if genesisBlock == nil { - return fmt.Errorf("blockchain missing genesis block number %v", v.genesisBlockNum) + requiredBatchCount := gs.Batch + 1 + if gs.PosInBatch == 0 { + requiredBatchCount -= 1 + } + if batchCount < requiredBatchCount { + return false, 0, nil + } + var prevBatchMsgCount arbutil.MessageIndex + if gs.Batch > 0 { + prevBatchMsgCount, err = tracker.GetBatchMessageCount(gs.Batch - 1) + if err != nil { + return false, 0, err } - - v.lastBlockValidatedHash = genesisBlock.Hash() - v.nextBlockToValidate = v.genesisBlockNum + 1 - v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: 1, - PosInBatch: 0, + } + count := prevBatchMsgCount + if gs.PosInBatch > 0 { + curBatchMsgCount, err := tracker.GetBatchMessageCount(gs.Batch) + if err != nil { + return false, 0, fmt.Errorf("%w: getBatchMsgCount %d batchCount %d", err, gs.Batch, batchCount) + } + count += arbutil.MessageIndex(gs.PosInBatch) + if curBatchMsgCount < count { + return false, 0, fmt.Errorf("%w: batch %d posInBatch %d, maxPosInBatch %d", ErrGlobalStateNotInChain, gs.Batch, gs.PosInBatch, curBatchMsgCount-prevBatchMsgCount) } - return nil } - - info, err := ReadLastValidatedFromDb(v.db) + processed, err := streamer.GetProcessedMessageCount() if err != nil { - return err + return false, 0, err } - - if reorgingToBlock != nil && reorgingToBlock.NumberU64() >= info.BlockNumber { - // Disregard this reorg as it doesn't affect the last validated block - reorgingToBlock = nil + if processed < count { + return false, 0, nil } - - if reorgingToBlock == nil { - expectedHash := v.blockchain.GetCanonicalHash(info.BlockNumber) - if expectedHash != info.BlockHash && (expectedHash != common.Hash{}) { - return fmt.Errorf("last validated block %v stored with hash %v, but blockchain has hash %v", info.BlockNumber, info.BlockHash, expectedHash) - } + res, err := streamer.ResultAtCount(count) + if err != nil { + return false, 0, err } - - atomic.StoreUint64(&v.lastBlockValidated, info.BlockNumber) - validatorLastBlockValidatedGauge.Update(int64(info.BlockNumber)) - v.lastBlockValidatedHash = info.BlockHash - v.nextBlockToValidate = v.lastBlockValidated + 1 - v.globalPosNextSend = info.AfterPosition - - if reorgingToBlock != nil { - err = v.reorgToBlockImpl(reorgingToBlock.NumberU64(), reorgingToBlock.Hash()) - if err != nil { - return err - } + if res.BlockHash != gs.BlockHash || res.SendRoot != gs.SendRoot { + return false, 0, fmt.Errorf("%w: count %d hash %v expected %v, sendroot %v expected %v", ErrGlobalStateNotInChain, count, gs.BlockHash, res.BlockHash, gs.SendRoot, res.SendRoot) } - - return nil + return true, count, nil } -func (v *BlockValidator) sendRecord(s *validationStatus, mustDeref bool) error { +func (v *BlockValidator) sendRecord(s *validationStatus) error { if !v.Started() { - // this could only be sent by NewBlock, so mustDeref is not sent return nil } - prevHeader := s.Entry.PrevBlockHeader - if !s.replaceStatus(Unprepared, RecordSent) { - if mustDeref { - v.recordingDatabase.Dereference(prevHeader) - } + if !s.replaceStatus(Created, RecordSent) { return fmt.Errorf("failed status check for send record. 
Status: %v", s.getStatus()) } v.LaunchThread(func(ctx context.Context) { - if mustDeref { - defer v.recordingDatabase.Dereference(prevHeader) - } - err := v.ValidationEntryRecord(ctx, s.Entry, true) + err := v.ValidationEntryRecord(ctx, s.Entry) if ctx.Err() != nil { return } @@ -380,94 +377,31 @@ func (v *BlockValidator) sendRecord(s *validationStatus, mustDeref bool) error { log.Error("Error while recording", "err", err, "status", s.getStatus()) return } - v.recentStateComputed(prevHeader) - v.recordingDatabase.Dereference(prevHeader) // removes the reference added by ValidationEntryRecord if !s.replaceStatus(RecordSent, Prepared) { log.Error("Fault trying to update validation with recording", "entry", s.Entry, "status", s.getStatus()) return } - v.triggerSendValidations() + nonBlockingTrigger(v.progressValidationsChan) }) return nil } -func (v *BlockValidator) newValidationStatus(prevHeader, header *types.Header, msg *arbostypes.MessageWithMetadata) (*validationStatus, error) { - entry, err := newValidationEntry(prevHeader, header, msg) - if err != nil { - return nil, err - } - status := &validationStatus{ - Status: uint32(Unprepared), - Entry: entry, - } - return status, nil -} - -func (v *BlockValidator) NewBlock(block *types.Block, prevHeader *types.Header, msg arbostypes.MessageWithMetadata) { - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - blockNum := block.NumberU64() - v.lastBlockValidatedMutex.Lock() - if blockNum < v.lastBlockValidated { - v.lastBlockValidatedMutex.Unlock() - return - } - if v.lastBlockValidatedUnknown { - if block.Hash() == v.lastBlockValidatedHash { - v.lastBlockValidated = blockNum - validatorLastBlockValidatedGauge.Update(int64(blockNum)) - v.nextBlockToValidate = blockNum + 1 - v.lastBlockValidatedUnknown = false - log.Info("Block building caught up to staker", "blockNr", v.lastBlockValidated, "blockHash", v.lastBlockValidatedHash) - // note: this block is already valid - } - v.lastBlockValidatedMutex.Unlock() - return - } - if v.nextBlockToValidate+v.config().ForwardBlocks <= blockNum { - v.lastBlockValidatedMutex.Unlock() - return - } - v.lastBlockValidatedMutex.Unlock() - status, err := v.newValidationStatus(prevHeader, block.Header(), &msg) - if err != nil { - log.Error("failed creating validation status", "err", err) - return - } - // It's fine to separately load and then store as we have the blockMutex acquired - _, present := v.validations.Load(blockNum) - if present { - return - } - v.validations.Store(blockNum, status) - if v.lastValidationEntryBlock < blockNum { - v.lastValidationEntryBlock = blockNum - } - v.triggerSendValidations() -} - //nolint:gosec func (v *BlockValidator) writeToFile(validationEntry *validationEntry, moduleRoot common.Hash) error { input, err := validationEntry.ToInput() if err != nil { return err } - expOut, err := validationEntry.expectedEnd() - if err != nil { - return err - } - _, err = v.execSpawner.WriteToFile(input, expOut, moduleRoot).Await(v.GetContext()) + _, err = v.execSpawner.WriteToFile(input, validationEntry.End, moduleRoot).Await(v.GetContext()) return err } func (v *BlockValidator) SetCurrentWasmModuleRoot(hash common.Hash) error { - v.blockMutex.Lock() v.moduleMutex.Lock() - defer v.blockMutex.Unlock() defer v.moduleMutex.Unlock() if (hash == common.Hash{}) { - return errors.New("trying to set zero as wsmModuleRoot") + return errors.New("trying to set zero as wasmModuleRoot") } if hash == v.currentWasmModuleRoot { return nil @@ -490,527 +424,520 @@ func (v *BlockValidator) 
SetCurrentWasmModuleRoot(hash common.Hash) error { ) } -var ErrValidationCanceled = errors.New("validation of block cancelled") +func (v *BlockValidator) readBatch(ctx context.Context, batchNum uint64) (bool, []byte, arbutil.MessageIndex, error) { + batchCount, err := v.inboxTracker.GetBatchCount() + if err != nil { + return false, nil, 0, err + } + if batchCount <= batchNum { + return false, nil, 0, nil + } + batchMsgCount, err := v.inboxTracker.GetBatchMessageCount(batchNum) + if err != nil { + return false, nil, 0, err + } + batch, err := v.inboxReader.GetSequencerMessageBytes(ctx, batchNum) + if err != nil { + return false, nil, 0, err + } + return true, batch, batchMsgCount, nil +} + +func (v *BlockValidator) createNextValidationEntry(ctx context.Context) (bool, error) { + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() + pos := v.created() + if pos > v.validated()+arbutil.MessageIndex(v.config().ForwardBlocks) { + log.Trace("create validation entry: nothing to do", "pos", pos, "validated", v.validated()) + return false, nil + } + streamerMsgCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + return false, err + } + if pos >= streamerMsgCount { + log.Trace("create validation entry: nothing to do", "pos", pos, "streamerMsgCount", streamerMsgCount) + return false, nil + } + msg, err := v.streamer.GetMessage(pos) + if err != nil { + return false, err + } + endRes, err := v.streamer.ResultAtCount(pos + 1) + if err != nil { + return false, err + } + if v.nextCreateStartGS.PosInBatch == 0 || v.nextCreateBatchReread { + // new batch + found, batch, count, err := v.readBatch(ctx, v.nextCreateStartGS.Batch) + if !found { + return false, err + } + v.nextCreateBatch = batch + v.nextCreateBatchMsgCount = count + validatorMsgCountCurrentBatch.Update(int64(count)) + v.nextCreateBatchReread = false + } + endGS := validator.GoGlobalState{ + BlockHash: endRes.BlockHash, + SendRoot: endRes.SendRoot, + } + if pos+1 < v.nextCreateBatchMsgCount { + endGS.Batch = v.nextCreateStartGS.Batch + endGS.PosInBatch = v.nextCreateStartGS.PosInBatch + 1 + } else if pos+1 == v.nextCreateBatchMsgCount { + endGS.Batch = v.nextCreateStartGS.Batch + 1 + endGS.PosInBatch = 0 + } else { + return false, fmt.Errorf("illegal batch msg count %d pos %d batch %d", v.nextCreateBatchMsgCount, pos, endGS.Batch) + } + entry, err := newValidationEntry(pos, v.nextCreateStartGS, endGS, msg, v.nextCreateBatch, v.nextCreatePrevDelayed) + if err != nil { + return false, err + } + status := &validationStatus{ + Status: uint32(Created), + Entry: entry, + } + v.validations.Store(pos, status) + v.nextCreateStartGS = endGS + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + atomicStorePos(&v.createdA, pos+1) + log.Trace("create validation entry: created", "pos", pos) + return true, nil +} + +func (v *BlockValidator) iterativeValidationEntryCreator(ctx context.Context, ignored struct{}) time.Duration { + moreWork, err := v.createNextValidationEntry(ctx) + if err != nil { + processed, processedErr := v.streamer.GetProcessedMessageCount() + log.Error("error trying to create validation node", "err", err, "created", v.created()+1, "processed", processed, "processedErr", processedErr) + } + if moreWork { + return 0 + } + return v.config().ValidationPoll +} + +func (v *BlockValidator) sendNextRecordRequests(ctx context.Context) (bool, error) { + v.reorgMutex.RLock() + pos := v.recordSent() + created := v.created() + validated := v.validated() + v.reorgMutex.RUnlock() + + recordUntil := validated + 
arbutil.MessageIndex(v.config().PrerecordedBlocks) - 1 + if recordUntil > created-1 { + recordUntil = created - 1 + } + if recordUntil < pos { + return false, nil + } + log.Trace("preparing to record", "pos", pos, "until", recordUntil) + // prepare could take a long time so we do it without a lock + err := v.recorder.PrepareForRecord(ctx, pos, recordUntil) + if err != nil { + return false, err + } + + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() + createdNew := v.created() + recordSentNew := v.recordSent() + if createdNew < created || recordSentNew < pos { + // there was a relevant reorg - quit and restart + return true, nil + } + for pos <= recordUntil { + validationStatus, found := v.validations.Load(pos) + if !found { + return false, fmt.Errorf("not found entry for pos %d", pos) + } + currentStatus := validationStatus.getStatus() + if currentStatus != Created { + return false, fmt.Errorf("bad status trying to send recordings for pos %d status: %v", pos, currentStatus) + } + err := v.sendRecord(validationStatus) + if err != nil { + return false, err + } + pos += 1 + atomicStorePos(&v.recordSentA, pos) + log.Trace("next record request: sent", "pos", pos) + } + + return true, nil +} + +func (v *BlockValidator) iterativeValidationEntryRecorder(ctx context.Context, ignored struct{}) time.Duration { + moreWork, err := v.sendNextRecordRequests(ctx) + if err != nil { + log.Error("error trying to record for validation node", "err", err) + } + if moreWork { + return 0 + } + return v.config().ValidationPoll +} + +func (v *BlockValidator) iterativeValidationPrint(ctx context.Context) time.Duration { + validated, err := v.ReadLastValidatedInfo() + if err != nil { + log.Error("cannot read last validated data from database", "err", err) + return time.Second * 30 + } + if validated == nil { + return time.Second + } + if v.lastValidInfoPrinted != nil { + if v.lastValidInfoPrinted.GlobalState.BlockHash == validated.GlobalState.BlockHash { + return time.Second + } + } + var batchMsgs arbutil.MessageIndex + var printedCount int64 + if validated.GlobalState.Batch > 0 { + batchMsgs, err = v.inboxTracker.GetBatchMessageCount(validated.GlobalState.Batch) + } + if err != nil { + printedCount = -1 + } else { + printedCount = int64(batchMsgs) + int64(validated.GlobalState.PosInBatch) + } + log.Info("validated execution", "messageCount", printedCount, "globalstate", validated.GlobalState, "WasmRoots", validated.WasmRoots) + v.lastValidInfoPrinted = validated + return time.Second +} + +// return val: +// *MessageIndex - pointer to bad entry if there is one (requires reorg) +func (v *BlockValidator) advanceValidations(ctx context.Context) (*arbutil.MessageIndex, error) { + v.reorgMutex.RLock() + defer v.reorgMutex.RUnlock() -func (v *BlockValidator) sendValidations(ctx context.Context) { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - var batchCount uint64 wasmRoots := v.GetModuleRootsToValidate() room := 100 // even if there is more room then that it's fine for _, spawner := range v.validationSpawners { here := spawner.Room() / len(wasmRoots) if here <= 0 { - return + room = 0 } if here < room { room = here } } - for atomic.LoadInt32(&v.reorgsPending) == 0 { - if room <= 0 { - return + pos := v.validated() - 1 // to reverse the first +1 in the loop +validationsLoop: + for { + if ctx.Err() != nil { + return nil, ctx.Err() } - if batchCount <= v.globalPosNextSend.BatchNumber { - var err error - batchCount, err = v.inboxTracker.GetBatchCount() - if err != nil { - log.Error("validator failed to get message 
count", "err", err) - return - } - if batchCount <= v.globalPosNextSend.BatchNumber { - return - } + v.valLoopPos = pos + 1 + v.reorgMutex.RUnlock() + v.reorgMutex.RLock() + pos = v.valLoopPos + if pos >= v.recordSent() { + log.Trace("advanceValidations: nothing to validate", "pos", pos) + return nil, nil } - seqBatchEntry, haveBatch := v.sequencerBatches.Load(v.globalPosNextSend.BatchNumber) - if !haveBatch && batchCount == v.globalPosNextSend.BatchNumber+1 { - // This is the latest batch. - // Wait a bit to see if the inbox tracker populates this sequencer batch, - // but if it's still missing after this wait, we'll query it from the inbox reader. - time.Sleep(time.Second) - seqBatchEntry, haveBatch = v.sequencerBatches.Load(v.globalPosNextSend.BatchNumber) + validationStatus, found := v.validations.Load(pos) + if !found { + return nil, fmt.Errorf("not found entry for pos %d", pos) } - if !haveBatch { - seqMsg, err := v.inboxReader.GetSequencerMessageBytes(ctx, v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("validator failed to read sequencer message", "err", err) - return - } - v.ProcessBatches(v.globalPosNextSend.BatchNumber, [][]byte{seqMsg}) - seqBatchEntry = seqMsg + currentStatus := validationStatus.getStatus() + if currentStatus == RecordFailed { + // retry + log.Warn("Recording for validation failed, retrying..", "pos", pos) + return &pos, nil } - v.blockMutex.Lock() - v.lastBlockValidatedMutex.Lock() - if v.lastBlockValidatedUnknown { - firstMsgInBatch := arbutil.MessageIndex(0) - if v.globalPosNextSend.BatchNumber > 0 { - var err error - firstMsgInBatch, err = v.inboxTracker.GetBatchMessageCount(v.globalPosNextSend.BatchNumber - 1) + if currentStatus == ValidationSent && pos == v.validated() { + if validationStatus.Entry.Start != v.lastValidGS { + log.Warn("Validation entry has wrong start state", "pos", pos, "start", validationStatus.Entry.Start, "expected", v.lastValidGS) + validationStatus.Cancel() + return &pos, nil + } + var wasmRoots []common.Hash + for i, run := range validationStatus.Runs { + if !run.Ready() { + log.Trace("advanceValidations: validation not ready", "pos", pos, "run", i) + continue validationsLoop + } + wasmRoots = append(wasmRoots, run.WasmModuleRoot()) + runEnd, err := run.Current() + if err == nil && runEnd != validationStatus.Entry.End { + err = fmt.Errorf("validation failed: expected %v got %v", validationStatus.Entry.End, runEnd) + writeErr := v.writeToFile(validationStatus.Entry, run.WasmModuleRoot()) + if writeErr != nil { + log.Warn("failed to write debug results file", "err", writeErr) + } + } if err != nil { - v.lastBlockValidatedMutex.Unlock() - v.blockMutex.Unlock() - log.Error("validator couldnt read message count", "err", err) - return + validatorFailedValidationsCounter.Inc(1) + v.possiblyFatal(err) + return &pos, nil // if not fatal - retry } + validatorValidValidationsCounter.Inc(1) } - v.lastBlockValidated = uint64(arbutil.MessageCountToBlockNumber(firstMsgInBatch+arbutil.MessageIndex(v.globalPosNextSend.PosInBatch), v.genesisBlockNum)) - validatorLastBlockValidatedGauge.Update(int64(v.lastBlockValidated)) - v.nextBlockToValidate = v.lastBlockValidated + 1 - v.lastBlockValidatedUnknown = false - log.Info("Inbox caught up to staker", "blockNr", v.lastBlockValidated, "blockHash", v.lastBlockValidatedHash) - } - v.lastBlockValidatedMutex.Unlock() - nextBlockToValidate := v.nextBlockToValidate - v.blockMutex.Unlock() - nextMsg := arbutil.BlockNumberToMessageCount(nextBlockToValidate, v.genesisBlockNum) - 1 - // 
valdationEntries is By blockNumber - entry, found := v.validations.Load(nextBlockToValidate) - if !found { - return - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to validate batch") - return - } - if validationStatus.getStatus() < Prepared { - return - } - startPos, endPos, err := GlobalStatePositionsFor(v.inboxTracker, nextMsg, v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("failed calculating position for validation", "err", err, "msg", nextMsg, "batch", v.globalPosNextSend.BatchNumber) - return - } - if startPos != v.globalPosNextSend { - log.Error("inconsistent pos mapping", "msg", nextMsg, "expected", v.globalPosNextSend, "found", startPos) - return - } - seqMsg, ok := seqBatchEntry.([]byte) - if !ok { - batchNum := validationStatus.Entry.StartPosition.BatchNumber - log.Error("sequencer message bad format", "blockNr", nextBlockToValidate, "msgNum", batchNum) - return - } - msgCountInBatch, err := v.inboxTracker.GetBatchMessageCount(v.globalPosNextSend.BatchNumber) - if err != nil { - log.Error("failed to get batch message count", "err", err, "batch", v.globalPosNextSend.BatchNumber) - return - } - lastBlockInBatch := arbutil.MessageCountToBlockNumber(msgCountInBatch, v.genesisBlockNum) - validatorLastBlockInLastBatchGauge.Update(lastBlockInBatch) - v.LaunchThread(func(ctx context.Context) { - validationCtx, cancel := context.WithCancel(ctx) - defer cancel() - validationStatus.Cancel = cancel - err := v.ValidationEntryAddSeqMessage(ctx, validationStatus.Entry, startPos, endPos, seqMsg) + err := v.writeLastValidated(validationStatus.Entry.End, wasmRoots) if err != nil { - validationStatus.replaceStatus(Prepared, RecordFailed) - if validationCtx.Err() == nil { - log.Error("error preparing validation", "err", err) - } - return + log.Error("failed writing new validated to database", "pos", pos, "err", err) } + go v.recorder.MarkValid(pos, v.lastValidGS.BlockHash) + atomicStorePos(&v.validatedA, pos+1) + v.validations.Delete(pos) + nonBlockingTrigger(v.createNodesChan) + nonBlockingTrigger(v.sendRecordChan) + validatorMsgCountValidatedGauge.Update(int64(pos + 1)) + if v.testingProgressMadeChan != nil { + nonBlockingTrigger(v.testingProgressMadeChan) + } + log.Trace("result validated", "count", v.validated(), "blockHash", v.lastValidGS.BlockHash) + continue + } + if room == 0 { + log.Trace("advanceValidations: no more room", "pos", pos) + return nil, nil + } + if currentStatus == Prepared { input, err := validationStatus.Entry.ToInput() - if err != nil { - validationStatus.replaceStatus(Prepared, RecordFailed) - if validationCtx.Err() == nil { - log.Error("error preparing validation", "err", err) - } - return + if err != nil && ctx.Err() == nil { + v.possiblyFatal(fmt.Errorf("%w: error preparing validation", err)) + continue + } + replaced := validationStatus.replaceStatus(Prepared, SendingValidation) + if !replaced { + v.possiblyFatal(errors.New("failed to set SendingValidation status")) } + validatorPendingValidationsGauge.Inc(1) + defer validatorPendingValidationsGauge.Dec(1) + var runs []validator.ValidationRun for _, moduleRoot := range wasmRoots { - for _, spawner := range v.validationSpawners { + for i, spawner := range v.validationSpawners { run := spawner.Launch(input, moduleRoot) - validationStatus.Runs = append(validationStatus.Runs, run) + log.Trace("advanceValidations: launched", "pos", validationStatus.Entry.Pos, "moduleRoot", moduleRoot, "spawner", i) + runs = append(runs, run) } } 
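// [Editor's sketch, not part of this patch] The launch/await pattern used around this
// point, reduced to its shape: start one run per (module root, spawner) pair, then let
// a detached goroutine wait for all of them and poke a buffered trigger channel so the
// progress loop wakes up without blocking. All names here are stand-ins.
package main

import (
	"fmt"
	"sync"
	"time"
)

func nonBlockingTrigger(ch chan struct{}) {
	select {
	case ch <- struct{}{}:
	default:
	}
}

func main() {
	progress := make(chan struct{}, 1)
	moduleRoots, spawners := 2, 2
	var wg sync.WaitGroup
	for root := 0; root < moduleRoots; root++ {
		for sp := 0; sp < spawners; sp++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				time.Sleep(10 * time.Millisecond) // stand-in for run.Await(ctx)
			}()
		}
	}
	// Detached waiter: once every run has finished, trigger progress exactly once.
	go func() {
		wg.Wait()
		nonBlockingTrigger(progress)
	}()
	<-progress
	fmt.Println("all runs done; progress loop triggered")
}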
- validatorPendingValidationsGauge.Inc(1) - replaced := validationStatus.replaceStatus(Prepared, ValidationSent) - if !replaced { - v.possiblyFatal(errors.New("failed to set status")) - } - }) - room-- - v.blockMutex.Lock() - v.nextBlockToValidate++ - v.blockMutex.Unlock() - v.globalPosNextSend = endPos + validationCtx, cancel := context.WithCancel(ctx) + validationStatus.Runs = runs + validationStatus.Cancel = cancel + v.LaunchUntrackedThread(func() { + defer cancel() + replaced = validationStatus.replaceStatus(SendingValidation, ValidationSent) + if !replaced { + v.possiblyFatal(errors.New("failed to set status to ValidationSent")) + } + + // validationStatus might be removed from under us + // trigger validation progress when done + for _, run := range runs { + _, err := run.Await(validationCtx) + if err != nil { + return + } + } + nonBlockingTrigger(v.progressValidationsChan) + }) + room-- + } } } -func (v *BlockValidator) sendRecords(ctx context.Context) { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - v.blockMutex.Lock() - nextRecord := v.nextBlockToValidate - v.blockMutex.Unlock() - for atomic.LoadInt32(&v.reorgsPending) == 0 { - if nextRecord >= v.nextBlockToValidate+v.config().PrerecordedBlocks { - return - } - entry, found := v.validations.Load(nextRecord) - if !found { - header := v.blockchain.GetHeaderByNumber(nextRecord) - if header == nil { - // This block hasn't been created yet. - return - } - prevHeader := v.blockchain.GetHeaderByHash(header.ParentHash) - if prevHeader == nil && header.ParentHash != (common.Hash{}) { - log.Warn("failed to get prevHeader in block validator", "num", nextRecord-1, "hash", header.ParentHash) - return - } - msgNum := arbutil.BlockNumberToMessageCount(nextRecord, v.genesisBlockNum) - 1 - msg, err := v.streamer.GetMessage(msgNum) - if err != nil { - log.Warn("failed to get message in block validator", "err", err) - return - } - status, err := v.newValidationStatus(prevHeader, header, msg) - if err != nil { - log.Warn("failed to create validation status", "err", err) - return - } - v.blockMutex.Lock() - entry, found = v.validations.Load(nextRecord) - if !found { - v.validations.Store(nextRecord, status) - entry = status - } - v.blockMutex.Unlock() - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to send recordings") - return - } - currentStatus := validationStatus.getStatus() - if currentStatus == RecordFailed { - // retry - v.validations.Delete(nextRecord) - v.triggerSendValidations() - return - } - if currentStatus == Unprepared { - prevHeader := validationStatus.Entry.PrevBlockHeader - if prevHeader != nil { - _, err := v.recordingDatabase.GetOrRecreateState(ctx, prevHeader, stateLogFunc) - if err != nil { - log.Error("error trying to prepare state for recording", "err", err) - } - // add another reference that will be released by the record thread - _, err = v.recordingDatabase.StateFor(prevHeader) - if err != nil { - log.Error("error trying re-reference state for recording", "err", err) - } - if v.lastHeaderForPrepareState != nil { - v.recordingDatabase.Dereference(v.lastHeaderForPrepareState) - } - v.lastHeaderForPrepareState = prevHeader - } - err := v.sendRecord(validationStatus, true) - if err != nil { - log.Error("error trying to send preimage recording", "err", err) - } +func (v *BlockValidator) iterativeValidationProgress(ctx context.Context, ignored struct{}) time.Duration { + reorg, err := v.advanceValidations(ctx) + if err != nil { + log.Error("error 
trying to record for validation node", "err", err) + } else if reorg != nil { + err := v.Reorg(ctx, *reorg) + if err != nil { + log.Error("error trying to rorg validation", "pos", *reorg-1, "err", err) + v.possiblyFatal(err) } - nextRecord++ } + return v.config().ValidationPoll } -func (v *BlockValidator) writeLastValidatedToDb(blockNumber uint64, blockHash common.Hash, endPos GlobalStatePosition) error { - info := LastBlockValidatedDbInfo{ - BlockNumber: blockNumber, - BlockHash: blockHash, - AfterPosition: endPos, +var ErrValidationCanceled = errors.New("validation of block cancelled") + +func (v *BlockValidator) writeLastValidated(gs validator.GoGlobalState, wasmRoots []common.Hash) error { + v.lastValidGS = gs + info := GlobalStateValidatedInfo{ + GlobalState: gs, + WasmRoots: wasmRoots, } - encodedInfo, err := rlp.EncodeToBytes(info) + encoded, err := rlp.EncodeToBytes(info) if err != nil { return err } - err = v.db.Put(lastBlockValidatedInfoKey, encodedInfo) + err = v.db.Put(lastGlobalStateValidatedInfoKey, encoded) if err != nil { return err } return nil } -func (v *BlockValidator) progressValidated() { - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - for atomic.LoadInt32(&v.reorgsPending) == 0 { - // Reads from blocksValidated can be non-atomic as all writes hold reorgMutex - checkingBlock := v.lastBlockValidated + 1 - entry, found := v.validations.Load(checkingBlock) - if !found { - return - } - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to advance validated counter") - return - } - if validationStatus.getStatus() < ValidationSent { - return - } - validationEntry := validationStatus.Entry - if validationEntry.BlockNumber != checkingBlock { - log.Error("bad block number for validation entry", "expected", checkingBlock, "found", validationEntry.BlockNumber) - return - } - // It's safe to read lastBlockValidatedHash without the lastBlockValidatedMutex as we have the reorgMutex - if v.lastBlockValidatedHash != validationEntry.PrevBlockHash { - log.Error("lastBlockValidatedHash is %v but validationEntry has prevBlockHash %v for block number %v", v.lastBlockValidatedHash, validationEntry.PrevBlockHash, v.lastBlockValidated) - return - } - expectedEnd, err := validationEntry.expectedEnd() - if err != nil { - v.possiblyFatal(err) - return - } - for _, run := range validationStatus.Runs { - if !run.Ready() { - return - } - runEnd, err := run.Current() - if err == nil && runEnd != expectedEnd { - err = fmt.Errorf("validation failed: expected %v got %v", expectedEnd, runEnd) - writeErr := v.writeToFile(validationEntry, run.WasmModuleRoot()) - if writeErr != nil { - log.Warn("failed to write validation debugging info", "err", err) - } - v.possiblyFatal(err) - } - if err != nil { - v.possiblyFatal(err) - validationStatus.setStatus(Failed) - validatorFailedValidationsCounter.Inc(1) - validatorPendingValidationsGauge.Dec(1) - return - } - } - for _, run := range validationStatus.Runs { - run.Cancel() - } - validationStatus.replaceStatus(ValidationSent, Valid) - validatorValidValidationsCounter.Inc(1) - validatorPendingValidationsGauge.Dec(1) - v.triggerSendValidations() - earliestBatchKept := atomic.LoadUint64(&v.earliestBatchKept) - seqMsgNr := validationEntry.StartPosition.BatchNumber - if earliestBatchKept < seqMsgNr { - for batch := earliestBatchKept; batch < seqMsgNr; batch++ { - v.sequencerBatches.Delete(batch) - } - atomic.StoreUint64(&v.earliestBatchKept, seqMsgNr) - } - - v.lastBlockValidatedMutex.Lock() - 
atomic.StoreUint64(&v.lastBlockValidated, checkingBlock) - validatorLastBlockValidatedGauge.Update(int64(checkingBlock)) - v.lastBlockValidatedHash = validationEntry.BlockHash - err = v.writeLastValidatedToDb(validationEntry.BlockNumber, validationEntry.BlockHash, validationEntry.EndPosition) - if err != nil { - log.Error("failed to write validated entry to database", "err", err) +func (v *BlockValidator) validGSIsNew(globalState validator.GoGlobalState) bool { + if v.legacyValidInfo != nil { + if v.legacyValidInfo.AfterPosition.BatchNumber > globalState.Batch { + return false } - v.lastBlockValidatedMutex.Unlock() - v.recentlyValid(validationEntry.BlockHeader) - - v.validations.Delete(checkingBlock) - select { - case v.progressChan <- checkingBlock: - default: + if v.legacyValidInfo.AfterPosition.BatchNumber == globalState.Batch && v.legacyValidInfo.AfterPosition.PosInBatch >= globalState.PosInBatch { + return false } + return true } + if v.lastValidGS.Batch > globalState.Batch { + return false + } + if v.lastValidGS.Batch == globalState.Batch && v.lastValidGS.PosInBatch >= globalState.PosInBatch { + return false + } + return true } -func (v *BlockValidator) AssumeValid(globalState validator.GoGlobalState) error { +// this accepts globalstate even if not caught up +func (v *BlockValidator) InitAssumeValid(globalState validator.GoGlobalState) error { if v.Started() { - return fmt.Errorf("cannot handle AssumeValid while running") + return fmt.Errorf("cannot handle InitAssumeValid while running") } - v.reorgMutex.Lock() - defer v.reorgMutex.Unlock() - - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() - // don't do anything if we already validated past that - if v.globalPosNextSend.BatchNumber > globalState.Batch { - return nil - } - if v.globalPosNextSend.BatchNumber == globalState.Batch && v.globalPosNextSend.PosInBatch > globalState.PosInBatch { + if !v.validGSIsNew(globalState) { return nil } - block := v.blockchain.GetBlockByHash(globalState.BlockHash) - if block == nil { - v.lastBlockValidatedUnknown = true - } else { - v.lastBlockValidated = block.NumberU64() - validatorLastBlockValidatedGauge.Update(int64(v.lastBlockValidated)) - v.nextBlockToValidate = v.lastBlockValidated + 1 - } - v.lastBlockValidatedHash = globalState.BlockHash - v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: globalState.Batch, - PosInBatch: globalState.PosInBatch, + v.legacyValidInfo = nil + + err := v.writeLastValidated(v.lastValidGS, nil) + if err != nil { + log.Error("failed writing new validated to database", "pos", v.lastValidGS, "err", err) } + return nil } -func (v *BlockValidator) LastBlockValidated() uint64 { - return atomic.LoadUint64(&v.lastBlockValidated) -} +func (v *BlockValidator) UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) { -func (v *BlockValidator) LastBlockValidatedAndHash() (blockNumber uint64, blockHash common.Hash, wasmModuleRoots []common.Hash) { - v.lastBlockValidatedMutex.Lock() - blockValidated := v.lastBlockValidated - blockValidatedHash := v.lastBlockValidatedHash - v.lastBlockValidatedMutex.Unlock() + if count <= v.validated() { + return + } - // things can be removed from, but not added to, moduleRootsToValidate. 
By taking root hashes fter the block we know result is valid - moduleRootsValidated := v.GetModuleRootsToValidate() + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() - return blockValidated, blockValidatedHash, moduleRootsValidated -} + if count <= v.validated() { + return + } -// Because batches and blocks are handled at separate layers in the node, -// and because block generation from messages is asynchronous, -// this call is different than ReorgToBlock, which is currently called later. -func (v *BlockValidator) ReorgToBatchCount(count uint64) { - v.batchMutex.Lock() - defer v.batchMutex.Unlock() - v.reorgToBatchCountImpl(count) -} + if !v.chainCaughtUp { + if !v.validGSIsNew(globalState) { + return + } + v.legacyValidInfo = nil + err := v.writeLastValidated(globalState, nil) + if err != nil { + log.Error("error writing last validated", "err", err) + } + return + } -func (v *BlockValidator) reorgToBatchCountImpl(count uint64) { - localBatchCount := v.nextBatchKept - if localBatchCount < count { + countUint64 := uint64(count) + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + log.Error("getMessage error", "err", err, "count", count) return } - for i := count; i < localBatchCount; i++ { - v.sequencerBatches.Delete(i) + // delete no-longer relevant entries + for iPos := v.validated(); iPos < count && iPos < v.created(); iPos++ { + status, found := v.validations.Load(iPos) + if found && status != nil && status.Cancel != nil { + status.Cancel() + } + v.validations.Delete(iPos) + } + if v.created() < count { + v.nextCreateStartGS = globalState + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + v.nextCreateBatchReread = true + v.createdA = countUint64 + } + // under the reorg mutex we don't need atomic access + if v.recordSentA < countUint64 { + v.recordSentA = countUint64 + } + v.validatedA = countUint64 + v.valLoopPos = count + validatorMsgCountValidatedGauge.Update(int64(countUint64)) + err = v.writeLastValidated(v.lastValidGS, nil) // we don't know which wasm roots were validated + if err != nil { + log.Error("failed writing valid state after reorg", "err", err) } - v.nextBatchKept = count + nonBlockingTrigger(v.createNodesChan) } -func (v *BlockValidator) ProcessBatches(pos uint64, batches [][]byte) { - v.batchMutex.Lock() - defer v.batchMutex.Unlock() - - v.reorgToBatchCountImpl(pos) - - // Attempt to fill in earliestBatchKept if it's empty - atomic.CompareAndSwapUint64(&v.earliestBatchKept, 0, pos) - - for i, msg := range batches { - v.sequencerBatches.Store(pos+uint64(i), msg) +// Because batches and blocks are handled at separate layers in the node, +// and because block generation from messages is asynchronous, +// this call is different than Reorg, which is currently called later. 
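An aside before the reorg entry points that follow: validGSIsNew above reduces to a lexicographic comparison on (Batch, PosInBatch). A minimal standalone sketch of that ordering; the helper and package names are illustrative and not part of this change:

package sketch

import "github.com/offchainlabs/nitro/validator"

// gsIsAhead reports whether candidate is strictly ahead of current,
// comparing (Batch, PosInBatch) lexicographically, which is the ordering
// validGSIsNew relies on.
func gsIsAhead(candidate, current validator.GoGlobalState) bool {
	if candidate.Batch != current.Batch {
		return candidate.Batch > current.Batch
	}
	return candidate.PosInBatch > current.PosInBatch
}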
+func (v *BlockValidator) ReorgToBatchCount(count uint64) { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.nextCreateStartGS.Batch >= count { + v.nextCreateBatchReread = true } - v.nextBatchKept = pos + uint64(len(batches)) - v.triggerSendValidations() } -func (v *BlockValidator) ReorgToBlock(blockNum uint64, blockHash common.Hash) error { - atomic.AddInt32(&v.reorgsPending, 1) +func (v *BlockValidator) Reorg(ctx context.Context, count arbutil.MessageIndex) error { v.reorgMutex.Lock() defer v.reorgMutex.Unlock() - atomic.AddInt32(&v.reorgsPending, -1) - - v.blockMutex.Lock() - defer v.blockMutex.Unlock() - - v.lastBlockValidatedMutex.Lock() - defer v.lastBlockValidatedMutex.Unlock() - - if blockNum < v.lastValidationEntryBlock { - log.Warn("block validator processing reorg", "blockNum", blockNum) - err := v.reorgToBlockImpl(blockNum, blockHash) - if err != nil { - return fmt.Errorf("block validator reorg failed: %w", err) - } + if count <= 1 { + return errors.New("cannot reorg out genesis") } - - return nil -} - -// must hold reorgMutex, blockMutex, and lastBlockValidatedMutex -func (v *BlockValidator) reorgToBlockImpl(blockNum uint64, blockHash common.Hash) error { - for b := blockNum + 1; b <= v.lastValidationEntryBlock; b++ { - entry, found := v.validations.Load(b) - if !found { - continue - } - v.validations.Delete(b) - - validationStatus, ok := entry.(*validationStatus) - if !ok || (validationStatus == nil) { - log.Error("bad entry trying to reorg block validator") - continue - } - log.Debug("canceling validation due to reorg", "block", b) - if validationStatus.Cancel != nil { - validationStatus.Cancel() - } + if !v.chainCaughtUp { + return nil } - v.lastValidationEntryBlock = blockNum - if v.nextBlockToValidate <= blockNum+1 { + if v.created() < count { return nil } - msgIndex := arbutil.BlockNumberToMessageCount(blockNum, v.genesisBlockNum) - 1 - batchCount, err := v.inboxTracker.GetBatchCount() + _, endPosition, err := v.GlobalStatePositionsAtCount(count) if err != nil { + v.possiblyFatal(err) return err } - batch, err := FindBatchContainingMessageIndex(v.inboxTracker, msgIndex, batchCount) + res, err := v.streamer.ResultAtCount(count) if err != nil { + v.possiblyFatal(err) return err } - if batch >= batchCount { - // This reorg is past the latest batch. - // Attempt to recover by loading a next validation state at the start of the next batch. 
- v.globalPosNextSend = GlobalStatePosition{ - BatchNumber: batch, - PosInBatch: 0, - } - msgCount, err := v.inboxTracker.GetBatchMessageCount(batch - 1) - if err != nil { - return err - } - nextBlockSigned := arbutil.MessageCountToBlockNumber(msgCount, v.genesisBlockNum) + 1 - if nextBlockSigned <= 0 { - return errors.New("reorg past genesis block") - } - blockNum = uint64(nextBlockSigned) - 1 - block := v.blockchain.GetBlockByNumber(blockNum) - if block == nil { - return fmt.Errorf("failed to get end of batch block %v", blockNum) - } - blockHash = block.Hash() - v.lastValidationEntryBlock = blockNum - } else { - _, v.globalPosNextSend, err = GlobalStatePositionsFor(v.inboxTracker, msgIndex, batch) - if err != nil { - return err - } - } - if v.nextBlockToValidate > blockNum+1 { - v.nextBlockToValidate = blockNum + 1 + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + v.possiblyFatal(err) + return err } - - if v.lastBlockValidated > blockNum { - atomic.StoreUint64(&v.lastBlockValidated, blockNum) - validatorLastBlockValidatedGauge.Update(int64(blockNum)) - v.lastBlockValidatedHash = blockHash - - err = v.writeLastValidatedToDb(blockNum, blockHash, v.globalPosNextSend) + for iPos := count; iPos < v.created(); iPos++ { + status, found := v.validations.Load(iPos) + if found && status != nil && status.Cancel != nil { + status.Cancel() + } + v.validations.Delete(iPos) + } + v.nextCreateStartGS = buildGlobalState(*res, endPosition) + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + v.nextCreateBatchReread = true + countUint64 := uint64(count) + v.createdA = countUint64 + // under the reorg mutex we don't need atomic access + if v.recordSentA > countUint64 { + v.recordSentA = countUint64 + } + if v.validatedA > countUint64 { + v.validatedA = countUint64 + validatorMsgCountValidatedGauge.Update(int64(countUint64)) + err := v.writeLastValidated(v.lastValidGS, nil) // we don't know which wasm roots were validated if err != nil { - return err + log.Error("failed writing valid state after reorg", "err", err) } } - + nonBlockingTrigger(v.createNodesChan) return nil } @@ -1039,77 +966,182 @@ func (v *BlockValidator) Initialize(ctx context.Context) error { return nil } -func (v *BlockValidator) Start(ctxIn context.Context) error { - v.StopWaiter.Start(ctxIn, v) - err := stopwaiter.CallIterativelyWith[struct{}](v, - func(ctx context.Context, unused struct{}) time.Duration { - v.sendRecords(ctx) - v.sendValidations(ctx) - return v.config().ValidationPoll - }, - v.sendValidationsChan) +func (v *BlockValidator) checkLegacyValid() error { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.legacyValidInfo == nil { + return nil + } + batchCount, err := v.inboxTracker.GetBatchCount() if err != nil { return err } - v.CallIteratively(func(ctx context.Context) time.Duration { - v.progressValidated() - return v.config().ValidationPoll - }) - lastValid := uint64(0) - v.CallIteratively(func(ctx context.Context) time.Duration { - newValid, validHash, wasmModuleRoots := v.LastBlockValidatedAndHash() - if newValid != lastValid { - validHeader := v.blockchain.GetHeader(validHash, newValid) - if validHeader == nil { - foundHeader := v.blockchain.GetHeaderByNumber(newValid) - foundHash := common.Hash{} - if foundHeader != nil { - foundHash = foundHeader.Hash() - } - log.Warn("last valid block not in blockchain", "blockNum", newValid, "validatedBlockHash", validHash, "found-hash", foundHash) - } else { - validTimestamp := time.Unix(int64(validHeader.Time), 0) - log.Info("Validated blocks", 
"blockNum", newValid, "hash", validHash, - "timestamp", validTimestamp, "age", time.Since(validTimestamp), "wasm", wasmModuleRoots) - } - lastValid = newValid + requiredBatchCount := v.legacyValidInfo.AfterPosition.BatchNumber + 1 + if v.legacyValidInfo.AfterPosition.PosInBatch == 0 { + requiredBatchCount -= 1 + } + if batchCount < requiredBatchCount { + log.Warn("legacy valid batch ahead of db", "current", batchCount, "required", requiredBatchCount) + return nil + } + msgCount, err := v.inboxTracker.GetBatchMessageCount(v.legacyValidInfo.AfterPosition.BatchNumber) + if err != nil { + return err + } + msgCount += arbutil.MessageIndex(v.legacyValidInfo.AfterPosition.PosInBatch) + processedCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + return err + } + if processedCount < msgCount { + log.Warn("legacy valid message count ahead of db", "current", processedCount, "required", msgCount) + return nil + } + result, err := v.streamer.ResultAtCount(msgCount) + if err != nil { + return err + } + if result.BlockHash != v.legacyValidInfo.BlockHash { + log.Error("legacy validated blockHash does not fit chain", "info.BlockHash", v.legacyValidInfo.BlockHash, "chain", result.BlockHash, "count", msgCount) + return fmt.Errorf("legacy validated blockHash does not fit chain") + } + validGS := validator.GoGlobalState{ + BlockHash: result.BlockHash, + SendRoot: result.SendRoot, + Batch: v.legacyValidInfo.AfterPosition.BatchNumber, + PosInBatch: v.legacyValidInfo.AfterPosition.PosInBatch, + } + err = v.writeLastValidated(validGS, nil) + if err == nil { + err = v.db.Delete(legacyLastBlockValidatedInfoKey) + if err != nil { + err = fmt.Errorf("deleting legacy: %w", err) } - return time.Second - }) + } + if err != nil { + log.Error("failed writing initial lastValid on upgrade from legacy", "new-info", v.lastValidGS, "err", err) + } else { + log.Info("updated last-valid from legacy", "lastValid", v.lastValidGS) + } + v.legacyValidInfo = nil return nil } -func (v *BlockValidator) StopAndWait() { - v.StopWaiter.StopAndWait() - err := v.recentShutdown() +// checks that the chain caught up to lastValidGS, used in startup +func (v *BlockValidator) checkValidatedGSCaughtUp() (bool, error) { + v.reorgMutex.Lock() + defer v.reorgMutex.Unlock() + if v.chainCaughtUp { + return true, nil + } + if v.legacyValidInfo != nil { + return false, nil + } + if v.lastValidGS.Batch == 0 { + return false, errors.New("lastValid not initialized. 
cannot validate genesis") + } + caughtUp, count, err := GlobalStateToMsgCount(v.inboxTracker, v.streamer, v.lastValidGS) if err != nil { - log.Error("error storing valid state", "err", err) + return false, err } + if !caughtUp { + batchCount, err := v.inboxTracker.GetBatchCount() + if err != nil { + log.Error("failed reading batch count", "err", err) + batchCount = 0 + } + batchMsgCount, err := v.inboxTracker.GetBatchMessageCount(batchCount - 1) + if err != nil { + log.Error("failed reading batchMsgCount", "err", err) + batchMsgCount = 0 + } + processedMsgCount, err := v.streamer.GetProcessedMessageCount() + if err != nil { + log.Error("failed reading processedMsgCount", "err", err) + processedMsgCount = 0 + } + log.Info("validator catching up to last valid", "lastValid.Batch", v.lastValidGS.Batch, "lastValid.PosInBatch", v.lastValidGS.PosInBatch, "batchCount", batchCount, "batchMsgCount", batchMsgCount, "processedMsgCount", processedMsgCount) + return false, nil + } + msg, err := v.streamer.GetMessage(count - 1) + if err != nil { + return false, err + } + v.nextCreateBatchReread = true + v.nextCreateStartGS = v.lastValidGS + v.nextCreatePrevDelayed = msg.DelayedMessagesRead + atomicStorePos(&v.createdA, count) + atomicStorePos(&v.recordSentA, count) + atomicStorePos(&v.validatedA, count) + validatorMsgCountValidatedGauge.Update(int64(count)) + v.chainCaughtUp = true + return true, nil } -// WaitForBlock can only be used from One thread -func (v *BlockValidator) WaitForBlock(ctx context.Context, blockNumber uint64, timeout time.Duration) bool { +func (v *BlockValidator) LaunchWorkthreadsWhenCaughtUp(ctx context.Context) { + for { + err := v.checkLegacyValid() + if err != nil { + log.Error("validator got error updating legacy validated info. Consider restarting with dangerous.reset-block-validation", "err", err) + } + caughtUp, err := v.checkValidatedGSCaughtUp() + if err != nil { + log.Error("validator got error waiting for chain to catch up. 
Consider restarting with dangerous.reset-block-validation", "err", err) + } + if caughtUp { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(v.config().ValidationPoll): + } + } + err := stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationEntryCreator, v.createNodesChan) + if err != nil { + v.possiblyFatal(err) + } + err = stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationEntryRecorder, v.sendRecordChan) + if err != nil { + v.possiblyFatal(err) + } + err = stopwaiter.CallIterativelyWith[struct{}](&v.StopWaiterSafe, v.iterativeValidationProgress, v.progressValidationsChan) + if err != nil { + v.possiblyFatal(err) + } +} + +func (v *BlockValidator) Start(ctxIn context.Context) error { + v.StopWaiter.Start(ctxIn, v) + v.LaunchThread(v.LaunchWorkthreadsWhenCaughtUp) + v.CallIteratively(v.iterativeValidationPrint) + return nil +} + +func (v *BlockValidator) StopAndWait() { + v.StopWaiter.StopAndWait() +} + +// WaitForPos can only be used from One thread +func (v *BlockValidator) WaitForPos(t *testing.T, ctx context.Context, pos arbutil.MessageIndex, timeout time.Duration) bool { + triggerchan := make(chan struct{}) + v.testingProgressMadeChan = triggerchan timer := time.NewTimer(timeout) defer timer.Stop() + lastLoop := false for { - if atomic.LoadUint64(&v.lastBlockValidated) >= blockNumber { + if v.validated() > pos { return true } + if lastLoop { + return false + } select { case <-timer.C: - if atomic.LoadUint64(&v.lastBlockValidated) >= blockNumber { - return true - } - return false - case block, ok := <-v.progressChan: - if block >= blockNumber { - return true - } - if !ok { - return false - } + lastLoop = true + case <-triggerchan: case <-ctx.Done(): - return false + lastLoop = true } } } diff --git a/staker/block_validator_schema.go b/staker/block_validator_schema.go index e5d7dba71b..f6eb39f015 100644 --- a/staker/block_validator_schema.go +++ b/staker/block_validator_schema.go @@ -3,14 +3,23 @@ package staker -import "github.com/ethereum/go-ethereum/common" +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/offchainlabs/nitro/validator" +) -type LastBlockValidatedDbInfo struct { +type legacyLastBlockValidatedDbInfo struct { BlockNumber uint64 BlockHash common.Hash AfterPosition GlobalStatePosition } +type GlobalStateValidatedInfo struct { + GlobalState validator.GoGlobalState + WasmRoots []common.Hash +} + var ( - lastBlockValidatedInfoKey = []byte("_lastBlockValidatedInfo") // contains a rlp encoded lastBlockValidatedDbInfo + lastGlobalStateValidatedInfoKey = []byte("_lastGlobalStateValidatedInfo") // contains a rlp encoded lastBlockValidatedDbInfo + legacyLastBlockValidatedInfoKey = []byte("_lastBlockValidatedInfo") // LEGACY - contains a rlp encoded lastBlockValidatedDbInfo ) diff --git a/staker/challenge_manager.go b/staker/challenge_manager.go index 8d35d962d3..ac2ae8835a 100644 --- a/staker/challenge_manager.go +++ b/staker/challenge_manager.go @@ -14,10 +14,10 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind/backends" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/solgen/go/challengegen" "github.com/offchainlabs/nitro/validator" ) @@ -72,9 +72,9 @@ type ChallengeManager struct { 
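The schema change above stores the validated state RLP-encoded under lastGlobalStateValidatedInfoKey, with writeLastValidated earlier in this diff as the writer. A hedged sketch of the matching read path follows; the function name and the not-found handling are illustrative, and the real reader is ReadLastValidatedInfo:

package sketch

import (
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/rlp"
	"github.com/offchainlabs/nitro/staker"
)

// readLastValidatedSketch decodes the value written by writeLastValidated;
// the key literal matches lastGlobalStateValidatedInfoKey in the schema above.
func readLastValidatedSketch(db ethdb.Database) (*staker.GlobalStateValidatedInfo, error) {
	encoded, err := db.Get([]byte("_lastGlobalStateValidatedInfo"))
	if err != nil {
		// the real reader treats "key not found" as "nothing validated yet"
		return nil, err
	}
	var info staker.GlobalStateValidatedInfo
	if err := rlp.DecodeBytes(encoded, &info); err != nil {
		return nil, err
	}
	return &info, nil
}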
wasmModuleRoot common.Hash // these fields are empty until working on execution challenge - initialMachineBlockNr int64 - executionChallengeBackend *ExecutionChallengeBackend - machineFinalStepCount uint64 + initialMachineMessageCount arbutil.MessageIndex + executionChallengeBackend *ExecutionChallengeBackend + machineFinalStepCount uint64 } // NewChallengeManager constructs a new challenge manager. @@ -86,8 +86,6 @@ func NewChallengeManager( fromAddr common.Address, challengeManagerAddr common.Address, challengeIndex uint64, - l2blockChain *core.BlockChain, - inboxTracker InboxTrackerInterface, val *StatelessBlockValidator, startL1Block uint64, confirmationBlocks int64, @@ -125,13 +123,11 @@ func NewChallengeManager( return nil, fmt.Errorf("error getting challenge %v info: %w", challengeIndex, err) } - genesisBlockNum := l2blockChain.Config().ArbitrumChainParams.GenesisBlockNum backend, err := NewBlockChallengeBackend( parsedLog, challengeInfo.MaxInboxMessages, - l2blockChain, - inboxTracker, - genesisBlockNum, + val.streamer, + val.inboxTracker, ) if err != nil { return nil, fmt.Errorf("error creating block challenge backend for challenge %v: %w", challengeIndex, err) @@ -462,23 +458,18 @@ func (m *ChallengeManager) IssueOneStepProof( } func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint64) error { - blockNum := m.blockChallengeBackend.GetBlockNrAtStep(step) - // Get the next message and block header, and record the full block creation - if m.initialMachineBlockNr == blockNum && m.executionChallengeBackend != nil { + initialCount := m.blockChallengeBackend.GetMessageCountAtStep(step) + if m.initialMachineMessageCount == initialCount && m.executionChallengeBackend != nil { return nil } m.executionChallengeBackend = nil - nextHeader := m.blockChallengeBackend.bc.GetHeaderByNumber(uint64(blockNum + 1)) - if nextHeader == nil { - return fmt.Errorf("next block header %v after challenge point unknown", blockNum+1) - } - entry, err := m.validator.CreateReadyValidationEntry(ctx, nextHeader) + entry, err := m.validator.CreateReadyValidationEntry(ctx, initialCount) if err != nil { - return fmt.Errorf("error creating validation entry for challenge %v block %v for execution challenge: %w", m.challengeIndex, blockNum, err) + return fmt.Errorf("error creating validation entry for challenge %v msg %v for execution challenge: %w", m.challengeIndex, initialCount, err) } input, err := entry.ToInput() if err != nil { - return fmt.Errorf("error getting validation entry input of challenge %v block %v: %w", m.challengeIndex, blockNum, err) + return fmt.Errorf("error getting validation entry input of challenge %v msg %v: %w", m.challengeIndex, initialCount, err) } var prunedBatches []validator.BatchInfo for _, batch := range input.BatchInfo { @@ -489,7 +480,7 @@ func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint input.BatchInfo = prunedBatches execRun, err := m.validator.execSpawner.CreateExecutionRun(m.wasmModuleRoot, input).Await(ctx) if err != nil { - return fmt.Errorf("error creating execution backend for block %v: %w", blockNum, err) + return fmt.Errorf("error creating execution backend for msg %v: %w", initialCount, err) } backend, err := NewExecutionChallengeBackend(execRun) if err != nil { @@ -504,16 +495,16 @@ func (m *ChallengeManager) createExecutionBackend(ctx context.Context, step uint return fmt.Errorf("error getting execution challenge final state: %w", err) } if expectedStatus != computedStatus { - return fmt.Errorf("after block %v 
expected status %v but got %v", blockNum, expectedStatus, computedStatus) + return fmt.Errorf("after msg %v expected status %v but got %v", initialCount, expectedStatus, computedStatus) } if computedStatus == StatusFinished { if computedState != expectedState { - return fmt.Errorf("after block %v expected global state %v but got %v", blockNum, expectedState, computedState) + return fmt.Errorf("after msg %v expected global state %v but got %v", initialCount, expectedState, computedState) } } m.executionChallengeBackend = backend m.machineFinalStepCount = machineStepCount - m.initialMachineBlockNr = blockNum + m.initialMachineMessageCount = initialCount return nil } @@ -569,8 +560,7 @@ func (m *ChallengeManager) Act(ctx context.Context) (*types.Transaction, error) return nil, fmt.Errorf("error creating execution backend: %w", err) } machineStepCount := m.machineFinalStepCount - blockNum := m.initialMachineBlockNr - log.Info("issuing one step proof", "challenge", m.challengeIndex, "machineStepCount", machineStepCount, "blockNum", blockNum) + log.Info("issuing one step proof", "challenge", m.challengeIndex, "machineStepCount", machineStepCount, "initialCount", m.initialMachineMessageCount) return m.blockChallengeBackend.IssueExecChallenge( m.challengeCore, state, diff --git a/staker/l1_validator.go b/staker/l1_validator.go index f16d4c9538..7bdce64ba6 100644 --- a/staker/l1_validator.go +++ b/staker/l1_validator.go @@ -15,7 +15,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -41,16 +40,14 @@ const ( ) type L1Validator struct { - rollup *RollupWatcher - rollupAddress common.Address - validatorUtils *rollupgen.ValidatorUtils - client arbutil.L1Interface - builder *ValidatorTxBuilder - wallet ValidatorWalletInterface - callOpts bind.CallOpts - genesisBlockNumber uint64 - - l2Blockchain *core.BlockChain + rollup *RollupWatcher + rollupAddress common.Address + validatorUtils *rollupgen.ValidatorUtils + client arbutil.L1Interface + builder *ValidatorTxBuilder + wallet ValidatorWalletInterface + callOpts bind.CallOpts + das arbstate.DataAvailabilityReader inboxTracker InboxTrackerInterface txStreamer TransactionStreamerInterface @@ -63,7 +60,6 @@ func NewL1Validator( wallet ValidatorWalletInterface, validatorUtilsAddress common.Address, callOpts bind.CallOpts, - l2Blockchain *core.BlockChain, das arbstate.DataAvailabilityReader, inboxTracker InboxTrackerInterface, txStreamer TransactionStreamerInterface, @@ -84,24 +80,18 @@ func NewL1Validator( if err != nil { return nil, err } - genesisBlockNumber, err := txStreamer.GetGenesisBlockNumber() - if err != nil { - return nil, err - } return &L1Validator{ - rollup: rollup, - rollupAddress: wallet.RollupAddress(), - validatorUtils: validatorUtils, - client: client, - builder: builder, - wallet: wallet, - callOpts: callOpts, - genesisBlockNumber: genesisBlockNumber, - l2Blockchain: l2Blockchain, - das: das, - inboxTracker: inboxTracker, - txStreamer: txStreamer, - blockValidator: blockValidator, + rollup: rollup, + rollupAddress: wallet.RollupAddress(), + validatorUtils: validatorUtils, + client: client, + builder: builder, + wallet: wallet, + callOpts: callOpts, + das: das, + inboxTracker: inboxTracker, + txStreamer: txStreamer, + blockValidator: blockValidator, }, nil } @@ -231,35 +221,6 @@ type OurStakerInfo struct { *StakerInfo } 
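The removed blockNumberFromGlobalState just below mapped a global state to an L2 block number; its replacement, GlobalStateToMsgCount, works in message counts instead. A hedged sketch of the core arithmetic, with the caught-up and block-hash checks omitted and a local interface standing in for the inbox tracker:

package sketch

import (
	"github.com/offchainlabs/nitro/arbutil"
	"github.com/offchainlabs/nitro/validator"
)

// batchMsgCounter stands in for the slice of InboxTrackerInterface this
// sketch needs; the signature matches how it is used throughout the diff.
type batchMsgCounter interface {
	GetBatchMessageCount(seqNum uint64) (arbutil.MessageIndex, error)
}

// msgCountForGlobalState: the message count implied by a global state is the
// count at the end of the previous batch plus PosInBatch.
func msgCountForGlobalState(tracker batchMsgCounter, gs validator.GoGlobalState) (arbutil.MessageIndex, error) {
	var base arbutil.MessageIndex
	if gs.Batch > 0 {
		prev, err := tracker.GetBatchMessageCount(gs.Batch - 1)
		if err != nil {
			return 0, err
		}
		base = prev
	}
	return base + arbutil.MessageIndex(gs.PosInBatch), nil
}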
-// Returns (block number, global state inbox position is invalid, error). -// If global state is invalid, block number is set to the last of the batch. -func (v *L1Validator) blockNumberFromGlobalState(gs validator.GoGlobalState) (int64, bool, error) { - var batchHeight arbutil.MessageIndex - if gs.Batch > 0 { - var err error - batchHeight, err = v.inboxTracker.GetBatchMessageCount(gs.Batch - 1) - if err != nil { - return 0, false, err - } - } - - // Validate the PosInBatch if it's non-zero - if gs.PosInBatch > 0 { - nextBatchHeight, err := v.inboxTracker.GetBatchMessageCount(gs.Batch) - if err != nil { - return 0, false, err - } - - if gs.PosInBatch >= uint64(nextBatchHeight-batchHeight) { - // This PosInBatch would enter the next batch. Return the last block before the next batch. - // We can be sure that MessageCountToBlockNumber will return a non-negative number as nextBatchHeight must be nonzero. - return arbutil.MessageCountToBlockNumber(nextBatchHeight, v.genesisBlockNumber), true, nil - } - } - - return arbutil.MessageCountToBlockNumber(batchHeight+arbutil.MessageIndex(gs.PosInBatch), v.genesisBlockNumber), false, nil -} - func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurStakerInfo, strategy StakerStrategy, makeAssertionInterval time.Duration) (nodeAction, bool, error) { startState, prevInboxMaxCount, startStateProposedL1, startStateProposedParentChain, err := lookupNodeStartState(ctx, v.rollup, stakerInfo.LatestStakedNode, stakerInfo.LatestStakedNodeHash) if err != nil { @@ -279,68 +240,90 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta if err != nil { return nil, false, fmt.Errorf("error getting batch count from inbox tracker: %w", err) } - if localBatchCount < startState.RequiredBatches() { + if localBatchCount < startState.RequiredBatches() || localBatchCount == 0 { log.Info("catching up to chain batches", "localBatches", localBatchCount, "target", startState.RequiredBatches()) return nil, false, nil } - startBlock := v.l2Blockchain.GetBlockByHash(startState.GlobalState.BlockHash) - if startBlock == nil && (startState.GlobalState != validator.GoGlobalState{}) { - expectedBlockHeight, inboxPositionInvalid, err := v.blockNumberFromGlobalState(startState.GlobalState) - if err != nil { - return nil, false, fmt.Errorf("error getting block number from global state: %w", err) + caughtUp, startCount, err := GlobalStateToMsgCount(v.inboxTracker, v.txStreamer, startState.GlobalState) + if err != nil { + return nil, false, fmt.Errorf("start state not in chain: %w", err) + } + if !caughtUp { + target := GlobalStatePosition{ + BatchNumber: startState.GlobalState.Batch, + PosInBatch: startState.GlobalState.PosInBatch, } - if inboxPositionInvalid { - log.Error("invalid start global state inbox position", startState.GlobalState.BlockHash, "batch", startState.GlobalState.Batch, "pos", startState.GlobalState.PosInBatch) - return nil, false, errors.New("invalid start global state inbox position") + var current GlobalStatePosition + head, err := v.txStreamer.GetProcessedMessageCount() + if err != nil { + _, current, err = v.blockValidator.GlobalStatePositionsAtCount(head) } - latestHeader := v.l2Blockchain.CurrentBlock() - if latestHeader.Number.Int64() < expectedBlockHeight { - log.Info("catching up to chain blocks", "localBlocks", latestHeader.Number, "target", expectedBlockHeight) - return nil, false, nil + if err != nil { + log.Info("catching up to chain messages", "target", target) + } else { + log.Info("catching up to chain 
blocks", "target", target, "current", current) } - log.Error("unknown start block hash", "hash", startState.GlobalState.BlockHash, "batch", startState.GlobalState.Batch, "pos", startState.GlobalState.PosInBatch) - return nil, false, errors.New("unknown start block hash") + return nil, false, nil } - var lastBlockValidated uint64 + var validatedCount arbutil.MessageIndex + var validatedGlobalState validator.GoGlobalState if v.blockValidator != nil { - var expectedHash common.Hash - var validRoots []common.Hash - lastBlockValidated, expectedHash, validRoots = v.blockValidator.LastBlockValidatedAndHash() - haveHash := v.l2Blockchain.GetCanonicalHash(lastBlockValidated) - if haveHash != expectedHash { - return nil, false, fmt.Errorf("block validator validated block %v as hash %v but blockchain has hash %v", lastBlockValidated, expectedHash, haveHash) + valInfo, err := v.blockValidator.ReadLastValidatedInfo() + if err != nil || valInfo == nil { + return nil, false, err + } + validatedGlobalState = valInfo.GlobalState + caughtUp, validatedCount, err = GlobalStateToMsgCount(v.inboxTracker, v.txStreamer, valInfo.GlobalState) + if err != nil { + return nil, false, fmt.Errorf("%w: not found validated block in blockchain", err) + } + if !caughtUp { + log.Info("catching up to last validated block", "target", valInfo.GlobalState) + return nil, false, nil } if err := v.updateBlockValidatorModuleRoot(ctx); err != nil { return nil, false, fmt.Errorf("error updating block validator module root: %w", err) } wasmRootValid := false - for _, root := range validRoots { + for _, root := range valInfo.WasmRoots { if v.lastWasmModuleRoot == root { wasmRootValid = true break } } if !wasmRootValid { - return nil, false, fmt.Errorf("wasmroot doesn't match rollup : %v, valid: %v", v.lastWasmModuleRoot, validRoots) + return nil, false, fmt.Errorf("wasmroot doesn't match rollup : %v, valid: %v", v.lastWasmModuleRoot, valInfo.WasmRoots) } } else { - lastBlockValidated = v.l2Blockchain.CurrentBlock().Number.Uint64() - - if localBatchCount > 0 { - messageCount, err := v.inboxTracker.GetBatchMessageCount(localBatchCount - 1) + validatedCount, err = v.txStreamer.GetProcessedMessageCount() + if err != nil || validatedCount == 0 { + return nil, false, err + } + var batchNum uint64 + messageCount, err := v.inboxTracker.GetBatchMessageCount(localBatchCount - 1) + if err != nil { + return nil, false, fmt.Errorf("error getting latest batch %v message count: %w", localBatchCount-1, err) + } + if validatedCount >= messageCount { + batchNum = localBatchCount - 1 + validatedCount = messageCount + } else { + batchNum, err = FindBatchContainingMessageIndex(v.inboxTracker, validatedCount-1, localBatchCount) if err != nil { - return nil, false, fmt.Errorf("error getting latest batch %v message count: %w", localBatchCount-1, err) - } - // Must be non-negative as a batch must contain at least one message - lastBatchBlock := uint64(arbutil.MessageCountToBlockNumber(messageCount, v.genesisBlockNumber)) - if lastBlockValidated > lastBatchBlock { - lastBlockValidated = lastBatchBlock + return nil, false, err } - } else { - lastBlockValidated = 0 } + execResult, err := v.txStreamer.ResultAtCount(validatedCount) + if err != nil { + return nil, false, err + } + _, gsPos, err := GlobalStatePositionsAtCount(v.inboxTracker, validatedCount, batchNum) + if err != nil { + return nil, false, fmt.Errorf("%w: failed calculating GSposition for count %d", err, validatedCount) + } + validatedGlobalState = buildGlobalState(*execResult, gsPos) } 
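An aside on buildGlobalState, called just above and again in Reorg: a plausible shape for it is to combine the execution result with the inbox position. This is a hedged sketch only; the actual helper is defined elsewhere in this PR.

package sketch

import (
	"github.com/offchainlabs/nitro/arbnode/execution"
	"github.com/offchainlabs/nitro/staker"
	"github.com/offchainlabs/nitro/validator"
)

// buildGlobalStateSketch combines an execution result (block hash, send root)
// with an inbox position (batch, pos-in-batch) into a GoGlobalState.
func buildGlobalStateSketch(res execution.MessageResult, pos staker.GlobalStatePosition) validator.GoGlobalState {
	return validator.GoGlobalState{
		BlockHash:  res.BlockHash,
		SendRoot:   res.SendRoot,
		Batch:      pos.BatchNumber,
		PosInBatch: pos.PosInBatch,
	}
}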
currentL1BlockNum, err := v.client.BlockNumber(ctx) @@ -379,83 +362,55 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta // We've found everything we could hope to find break } - if correctNode == nil { - afterGs := nd.AfterState().GlobalState - requiredBatches := nd.AfterState().RequiredBatches() - if localBatchCount < requiredBatches { - return nil, false, fmt.Errorf("waiting for validator to catch up to assertion batches: %v/%v", localBatchCount, requiredBatches) - } - if requiredBatches > 0 { - haveAcc, err := v.inboxTracker.GetBatchAcc(requiredBatches - 1) - if err != nil { - return nil, false, fmt.Errorf("error getting batch %v accumulator: %w", requiredBatches-1, err) - } - if haveAcc != nd.AfterInboxBatchAcc { - return nil, false, fmt.Errorf("missed sequencer batches reorg: at seq num %v have acc %v but assertion has acc %v", requiredBatches-1, haveAcc, nd.AfterInboxBatchAcc) - } - } - lastBlockNum, inboxPositionInvalid, err := v.blockNumberFromGlobalState(afterGs) - if err != nil { - return nil, false, fmt.Errorf("error getting block number from global state: %w", err) - } - if int64(lastBlockValidated) < lastBlockNum { - return nil, false, fmt.Errorf("waiting for validator to catch up to assertion blocks: %v/%v", lastBlockValidated, lastBlockNum) - } - var expectedBlockHash common.Hash - var expectedSendRoot common.Hash - if lastBlockNum >= 0 { - lastBlock := v.l2Blockchain.GetBlockByNumber(uint64(lastBlockNum)) - if lastBlock == nil { - return nil, false, fmt.Errorf("block %v not in database despite being validated", lastBlockNum) - } - lastBlockExtra := types.DeserializeHeaderExtraInformation(lastBlock.Header()) - expectedBlockHash = lastBlock.Hash() - expectedSendRoot = lastBlockExtra.SendRoot - } - - var expectedNumBlocks uint64 - if startBlock == nil { - expectedNumBlocks = uint64(lastBlockNum + 1) - } else { - expectedNumBlocks = uint64(lastBlockNum) - startBlock.NumberU64() - } - valid := !inboxPositionInvalid && - nd.Assertion.NumBlocks == expectedNumBlocks && - afterGs.BlockHash == expectedBlockHash && - afterGs.SendRoot == expectedSendRoot && - nd.Assertion.AfterState.MachineStatus == validator.MachineStatusFinished - if valid { - log.Info( - "found correct assertion", - "node", nd.NodeNum, - "blockNum", lastBlockNum, - "blockHash", afterGs.BlockHash, - ) - correctNode = existingNodeAction{ - number: nd.NodeNum, - hash: nd.NodeHash, - } - continue - } else { - log.Error( - "found incorrect assertion", - "node", nd.NodeNum, - "inboxPositionInvalid", inboxPositionInvalid, - "computedBlockNum", lastBlockNum, - "numBlocks", nd.Assertion.NumBlocks, - "expectedNumBlocks", expectedNumBlocks, - "blockHash", afterGs.BlockHash, - "expectedBlockHash", expectedBlockHash, - "sendRoot", afterGs.SendRoot, - "expectedSendRoot", expectedSendRoot, - "machineStatus", nd.Assertion.AfterState.MachineStatus, - ) - } - } else { + if correctNode != nil { log.Error("found younger sibling to correct assertion (implicitly invalid)", "node", nd.NodeNum) + wrongNodesExist = true + continue + } + afterGS := nd.AfterState().GlobalState + requiredBatch := afterGS.Batch + if afterGS.PosInBatch == 0 && afterGS.Batch > 0 { + requiredBatch -= 1 + } + if localBatchCount <= requiredBatch { + log.Info("staker: waiting for node to catch up to assertion batch", "current", localBatchCount, "target", requiredBatch-1) + return nil, false, nil + } + nodeBatchMsgCount, err := v.inboxTracker.GetBatchMessageCount(requiredBatch) + if err != nil { + return nil, false, err + } + if 
validatedCount < nodeBatchMsgCount { + log.Info("staker: waiting for validator to catch up to assertion batch messages", "current", validatedCount, "target", nodeBatchMsgCount) + return nil, false, nil + } + if nd.Assertion.AfterState.MachineStatus != validator.MachineStatusFinished { + wrongNodesExist = true + log.Error("Found incorrect assertion: Machine status not finished", "node", nd.NodeNum, "machineStatus", nd.Assertion.AfterState.MachineStatus) + continue + } + caughtUp, nodeMsgCount, err := GlobalStateToMsgCount(v.inboxTracker, v.txStreamer, afterGS) + if errors.Is(err, ErrGlobalStateNotInChain) { + wrongNodesExist = true + log.Error("Found incorrect assertion", "node", nd.NodeNum, "afterGS", afterGS, "err", err) + continue + } + if err != nil { + return nil, false, fmt.Errorf("error getting message number from global state: %w", err) + } + if !caughtUp { + return nil, false, fmt.Errorf("unexpected no-caught-up parsing assertion. Current: %d target: %v", validatedCount, afterGS) + } + log.Info( + "found correct assertion", + "node", nd.NodeNum, + "count", nodeMsgCount, + "blockHash", afterGS.BlockHash, + ) + correctNode = existingNodeAction{ + number: nd.NodeNum, + hash: nd.NodeHash, } - // If we've hit this point, the node is "wrong" - wrongNodesExist = true } if correctNode != nil || strategy == WatchtowerStrategy { @@ -468,9 +423,9 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta if len(successorNodes) > 0 { lastNodeHashIfExists = &successorNodes[len(successorNodes)-1].NodeHash } - action, err := v.createNewNodeAction(ctx, stakerInfo, lastBlockValidated, localBatchCount, prevInboxMaxCount, startBlock, startState, lastNodeHashIfExists) + action, err := v.createNewNodeAction(ctx, stakerInfo, prevInboxMaxCount, startCount, startState, validatedCount, validatedGlobalState, lastNodeHashIfExists) if err != nil { - return nil, wrongNodesExist, fmt.Errorf("error generating create new node action (from start block %v to last block validated %v): %w", startBlock, lastBlockValidated, err) + return nil, wrongNodesExist, fmt.Errorf("error generating create new node action (from pos %d to %d): %w", startCount, validatedCount, err) } return action, wrongNodesExist, nil } @@ -481,79 +436,33 @@ func (v *L1Validator) generateNodeAction(ctx context.Context, stakerInfo *OurSta func (v *L1Validator) createNewNodeAction( ctx context.Context, stakerInfo *OurStakerInfo, - lastBlockValidated uint64, - localBatchCount uint64, prevInboxMaxCount *big.Int, - startBlock *types.Block, + startCount arbutil.MessageIndex, startState *validator.ExecutionState, + validatedCount arbutil.MessageIndex, + validatedGS validator.GoGlobalState, lastNodeHashIfExists *common.Hash, ) (nodeAction, error) { if !prevInboxMaxCount.IsUint64() { return nil, fmt.Errorf("inbox max count %v isn't a uint64", prevInboxMaxCount) } - minBatchCount := prevInboxMaxCount.Uint64() - if localBatchCount < minBatchCount { - // not enough batches in database - return nil, nil - } - - if localBatchCount == 0 { - // we haven't validated anything - return nil, nil - } - if startBlock != nil && lastBlockValidated <= startBlock.NumberU64() { + if validatedCount <= startCount { // we haven't validated any new blocks return nil, nil } - var assertionCoversBatch uint64 - var afterGsBatch uint64 - var afterGsPosInBatch uint64 - for i := localBatchCount - 1; i+1 >= minBatchCount && i > 0; i-- { - batchMessageCount, err := v.inboxTracker.GetBatchMessageCount(i) - if err != nil { - return nil, fmt.Errorf("error getting 
batch %v message count: %w", i, err) - } - prevBatchMessageCount, err := v.inboxTracker.GetBatchMessageCount(i - 1) - if err != nil { - return nil, fmt.Errorf("error getting previous batch %v message count: %w", i-1, err) - } - // Must be non-negative as a batch must contain at least one message - lastBlockNum := uint64(arbutil.MessageCountToBlockNumber(batchMessageCount, v.genesisBlockNumber)) - prevBlockNum := uint64(arbutil.MessageCountToBlockNumber(prevBatchMessageCount, v.genesisBlockNumber)) - if lastBlockValidated > lastBlockNum { - return nil, fmt.Errorf("%v blocks have been validated but only %v appear in the latest batch", lastBlockValidated, lastBlockNum) - } - if lastBlockValidated > prevBlockNum { - // We found the batch containing the last validated block - if i+1 == minBatchCount && lastBlockValidated < lastBlockNum { - // We haven't reached the minimum assertion size yet - break - } - assertionCoversBatch = i - if lastBlockValidated < lastBlockNum { - afterGsBatch = i - afterGsPosInBatch = lastBlockValidated - prevBlockNum - } else { - afterGsBatch = i + 1 - afterGsPosInBatch = 0 - } - break - } - } - if assertionCoversBatch == 0 { - // we haven't validated the next batch completely + if validatedGS.Batch < prevInboxMaxCount.Uint64() { + // didn't validate enough batches + log.Info("staker: not enough batches validated to create new assertion", "validated.Batch", validatedGS.Batch, "posInBatch", validatedGS.PosInBatch, "required batch", prevInboxMaxCount) return nil, nil } - validatedBatchAcc, err := v.inboxTracker.GetBatchAcc(assertionCoversBatch) - if err != nil { - return nil, fmt.Errorf("error getting batch %v accumulator: %w", assertionCoversBatch, err) + batchValidated := validatedGS.Batch + if validatedGS.PosInBatch == 0 { + batchValidated-- } - - assertingBlock := v.l2Blockchain.GetBlockByNumber(lastBlockValidated) - if assertingBlock == nil { - return nil, fmt.Errorf("missing validated block %v", lastBlockValidated) + validatedBatchAcc, err := v.inboxTracker.GetBatchAcc(batchValidated) + if err != nil { + return nil, fmt.Errorf("error getting batch %v accumulator: %w", batchValidated, err) } - assertingBlockExtra := types.DeserializeHeaderExtraInformation(assertingBlock.Header()) hasSiblingByte := [1]byte{0} prevNum := stakerInfo.LatestStakedNode @@ -562,21 +471,11 @@ func (v *L1Validator) createNewNodeAction( lastHash = *lastNodeHashIfExists hasSiblingByte[0] = 1 } - var assertionNumBlocks uint64 - if startBlock == nil { - assertionNumBlocks = assertingBlock.NumberU64() + 1 - } else { - assertionNumBlocks = assertingBlock.NumberU64() - startBlock.NumberU64() - } + assertionNumBlocks := uint64(validatedCount - startCount) assertion := &Assertion{ BeforeState: startState, AfterState: &validator.ExecutionState{ - GlobalState: validator.GoGlobalState{ - BlockHash: assertingBlock.Hash(), - SendRoot: assertingBlockExtra.SendRoot, - Batch: afterGsBatch, - PosInBatch: afterGsPosInBatch, - }, + GlobalState: validatedGS, MachineStatus: validator.MachineStatusFinished, }, NumBlocks: assertionNumBlocks, diff --git a/staker/staker.go b/staker/staker.go index baaf81f9a2..f20a045d37 100644 --- a/staker/staker.go +++ b/staker/staker.go @@ -19,9 +19,11 @@ import ( "github.com/ethereum/go-ethereum/metrics" flag "github.com/spf13/pflag" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/cmd/genericconf" "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/offchainlabs/nitro/validator" ) var ( @@ -186,10 
+188,15 @@ type nodeAndHash struct { hash common.Hash } +type LatestStakedNotifier interface { + UpdateLatestStaked(count arbutil.MessageIndex, globalState validator.GoGlobalState) +} + type Staker struct { *L1Validator stopwaiter.StopWaiter l1Reader L1ReaderInterface + notifiers []LatestStakedNotifier activeChallenge *ChallengeManager baseCallOpts bind.CallOpts config L1ValidatorConfig @@ -199,6 +206,7 @@ type Staker struct { bringActiveUntilNode uint64 inboxReader InboxReaderInterface statelessBlockValidator *StatelessBlockValidator + fatalErr chan<- error } func NewStaker( @@ -208,7 +216,9 @@ func NewStaker( config L1ValidatorConfig, blockValidator *BlockValidator, statelessBlockValidator *StatelessBlockValidator, + notifiers []LatestStakedNotifier, validatorUtilsAddress common.Address, + fatalErr chan<- error, ) (*Staker, error) { if err := config.Validate(); err != nil { @@ -216,20 +226,25 @@ func NewStaker( } client := l1Reader.Client() val, err := NewL1Validator(client, wallet, validatorUtilsAddress, callOpts, - statelessBlockValidator.blockchain, statelessBlockValidator.daService, statelessBlockValidator.inboxTracker, statelessBlockValidator.streamer, blockValidator) + statelessBlockValidator.daService, statelessBlockValidator.inboxTracker, statelessBlockValidator.streamer, blockValidator) if err != nil { return nil, err } stakerLastSuccessfulActionGauge.Update(time.Now().Unix()) + if config.StartFromStaked { + notifiers = append(notifiers, blockValidator) + } return &Staker{ L1Validator: val, l1Reader: l1Reader, + notifiers: notifiers, baseCallOpts: callOpts, config: config, highGasBlocksBuffer: big.NewInt(config.L1PostingStrategy.HighGasDelayBlocks), lastActCalledBlock: nil, inboxReader: statelessBlockValidator.inboxReader, statelessBlockValidator: statelessBlockValidator, + fatalErr: fatalErr, }, nil } @@ -257,9 +272,54 @@ func (s *Staker) Initialize(ctx context.Context) error { return err } - return s.blockValidator.AssumeValid(stakedInfo.AfterState().GlobalState) + return s.blockValidator.InitAssumeValid(stakedInfo.AfterState().GlobalState) + } + return nil +} + +func (s *Staker) checkLatestStaked(ctx context.Context) error { + latestStaked, _, err := s.validatorUtils.LatestStaked(&s.baseCallOpts, s.rollupAddress, s.wallet.AddressOrZero()) + if err != nil { + return fmt.Errorf("couldn't get LatestStaked: %w", err) + } + stakerLatestStakedNodeGauge.Update(int64(latestStaked)) + if latestStaked == 0 { + return nil + } + + stakedInfo, err := s.rollup.LookupNode(ctx, latestStaked) + if err != nil { + return fmt.Errorf("couldn't look up latest node: %w", err) + } + + stakedGlobalState := stakedInfo.AfterState().GlobalState + caughtUp, count, err := GlobalStateToMsgCount(s.inboxTracker, s.txStreamer, stakedGlobalState) + if err != nil { + if errors.Is(err, ErrGlobalStateNotInChain) && s.fatalErr != nil { + fatal := fmt.Errorf("latest staked not in chain: %w", err) + s.fatalErr <- fatal + } + return fmt.Errorf("staker: latest staked %w", err) + } + + if !caughtUp { + log.Info("latest valid not yet in our node", "staked", stakedGlobalState) + return nil } + processedCount, err := s.txStreamer.GetProcessedMessageCount() + if err != nil { + return err + } + + if processedCount < count { + log.Info("execution catching up to last validated", "validatedCount", count, "processedCount", processedCount) + return nil + } + + for _, notifier := range s.notifiers { + notifier.UpdateLatestStaked(count, stakedGlobalState) + } return nil } @@ -317,6 +377,13 @@ func (s *Staker) Start(ctxIn 
context.Context) { } return backoff }) + s.CallIteratively(func(ctx context.Context) time.Duration { + err := s.checkLatestStaked(ctx) + if err != nil && ctx.Err() == nil { + log.Error("staker: error checking latest staked", "err", err) + } + return s.config.StakerInterval + }) } func (s *Staker) IsWhitelisted(ctx context.Context) (bool, error) { @@ -609,8 +676,6 @@ func (s *Staker) handleConflict(ctx context.Context, info *StakerInfo) error { *s.builder.wallet.Address(), s.wallet.ChallengeManagerAddress(), *info.CurrentChallenge, - s.l2Blockchain, - s.inboxTracker, s.statelessBlockValidator, latestConfirmedCreated, s.config.ConfirmationBlocks, diff --git a/staker/stateless_block_validator.go b/staker/stateless_block_validator.go index c56eb3e9a7..0242daa3c7 100644 --- a/staker/stateless_block_validator.go +++ b/staker/stateless_block_validator.go @@ -8,22 +8,20 @@ import ( "errors" "fmt" "sync" + "testing" + "github.com/offchainlabs/nitro/arbnode/execution" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/validator/server_api" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/validator" - "github.com/ethereum/go-ethereum/arbitrum" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" - "github.com/offchainlabs/nitro/arbos" - "github.com/offchainlabs/nitro/arbos/arbosState" "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbstate" ) @@ -34,14 +32,13 @@ type StatelessBlockValidator struct { execSpawner validator.ExecutionSpawner validationSpawners []validator.ValidationSpawner - inboxReader InboxReaderInterface - inboxTracker InboxTrackerInterface - streamer TransactionStreamerInterface - blockchain *core.BlockChain - db ethdb.Database - daService arbstate.DataAvailabilityReader - genesisBlockNum uint64 - recordingDatabase *arbitrum.RecordingDatabase + recorder BlockRecorder + + inboxReader InboxReaderInterface + inboxTracker InboxTrackerInterface + streamer TransactionStreamerInterface + db ethdb.Database + daService arbstate.DataAvailabilityReader moduleMutex sync.Mutex currentWasmModuleRoot common.Hash @@ -52,6 +49,16 @@ type BlockValidatorRegistrer interface { SetBlockValidator(*BlockValidator) } +type BlockRecorder interface { + RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg *arbostypes.MessageWithMetadata, + ) (*execution.RecordResult, error) + MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) + PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error +} + type InboxTrackerInterface interface { BlockValidatorRegistrer GetDelayedMessageBytes(uint64) ([]byte, error) @@ -62,8 +69,9 @@ type InboxTrackerInterface interface { type TransactionStreamerInterface interface { BlockValidatorRegistrer + GetProcessedMessageCount() (arbutil.MessageIndex, error) GetMessage(seqNum arbutil.MessageIndex) (*arbostypes.MessageWithMetadata, error) - GetGenesisBlockNumber() (uint64, error) + ResultAtCount(count arbutil.MessageIndex) (*execution.MessageResult, error) PauseReorgs() ResumeReorgs() } @@ -83,9 +91,11 @@ type GlobalStatePosition struct { PosInBatch uint64 } -func GlobalStatePositionsFor( +// return the globalState position before and after processing message at the specified count +// batch-number must be provided by caller +func GlobalStatePositionsAtCount( tracker 
InboxTrackerInterface, - pos arbutil.MessageIndex, + count arbutil.MessageIndex, batch uint64, ) (GlobalStatePosition, GlobalStatePosition, error) { msgCountInBatch, err := tracker.GetBatchMessageCount(batch) @@ -99,17 +109,18 @@ func GlobalStatePositionsFor( return GlobalStatePosition{}, GlobalStatePosition{}, err } } - if msgCountInBatch <= pos { - return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d has up to message %d, failed getting for %d", batch, msgCountInBatch-1, pos) + if msgCountInBatch < count { + return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d has msgCount %d, failed getting for %d", batch, msgCountInBatch-1, count) } - if firstInBatch > pos { - return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d starts from %d, failed getting for %d", batch, firstInBatch, pos) + if firstInBatch >= count { + return GlobalStatePosition{}, GlobalStatePosition{}, fmt.Errorf("batch %d starts from %d, failed getting for %d", batch, firstInBatch, count) } - startPos := GlobalStatePosition{batch, uint64(pos - firstInBatch)} - if msgCountInBatch == pos+1 { + posInBatch := uint64(count - firstInBatch - 1) + startPos := GlobalStatePosition{batch, posInBatch} + if msgCountInBatch == count { return startPos, GlobalStatePosition{batch + 1, 0}, nil } - return startPos, GlobalStatePosition{batch, uint64(pos + 1 - firstInBatch)}, nil + return startPos, GlobalStatePosition{batch, posInBatch + 1}, nil } func FindBatchContainingMessageIndex( @@ -150,154 +161,96 @@ type ValidationEntryStage uint32 const ( Empty ValidationEntryStage = iota ReadyForRecord - Recorded Ready ) type validationEntry struct { Stage ValidationEntryStage // Valid since ReadyforRecord: - BlockNumber uint64 - PrevBlockHash common.Hash - PrevBlockHeader *types.Header - BlockHash common.Hash - BlockHeader *types.Header - HasDelayedMsg bool - DelayedMsgNr uint64 - msg *arbostypes.MessageWithMetadata - // Valid since Recorded: + Pos arbutil.MessageIndex + Start validator.GoGlobalState + End validator.GoGlobalState + HasDelayedMsg bool + DelayedMsgNr uint64 + // valid when created, removed after recording + msg *arbostypes.MessageWithMetadata + // Has batch when created - others could be added on record + BatchInfo []validator.BatchInfo + // Valid since Ready Preimages map[common.Hash][]byte - BatchInfo []validator.BatchInfo DelayedMsg []byte - // Valid since Ready: - StartPosition GlobalStatePosition - EndPosition GlobalStatePosition -} - -func (v *validationEntry) start() (validator.GoGlobalState, error) { - start := v.StartPosition - globalState := validator.GoGlobalState{ - Batch: start.BatchNumber, - PosInBatch: start.PosInBatch, - BlockHash: v.PrevBlockHash, - } - if v.PrevBlockHeader != nil { - prevExtraInfo := types.DeserializeHeaderExtraInformation(v.PrevBlockHeader) - globalState.SendRoot = prevExtraInfo.SendRoot - } - return globalState, nil -} - -func (v *validationEntry) expectedEnd() (validator.GoGlobalState, error) { - extraInfo := types.DeserializeHeaderExtraInformation(v.BlockHeader) - end := v.EndPosition - return validator.GoGlobalState{ - Batch: end.BatchNumber, - PosInBatch: end.PosInBatch, - BlockHash: v.BlockHash, - SendRoot: extraInfo.SendRoot, - }, nil } func (e *validationEntry) ToInput() (*validator.ValidationInput, error) { if e.Stage != Ready { return nil, errors.New("cannot create input from non-ready entry") } - startState, err := e.start() - if err != nil { - return nil, err - } return &validator.ValidationInput{ - Id: e.BlockNumber, + Id: 
uint64(e.Pos), HasDelayedMsg: e.HasDelayedMsg, DelayedMsgNr: e.DelayedMsgNr, Preimages: e.Preimages, BatchInfo: e.BatchInfo, DelayedMsg: e.DelayedMsg, - StartState: startState, + StartState: e.Start, }, nil } -func usingDelayedMsg(prevHeader *types.Header, header *types.Header) (bool, uint64) { - if prevHeader == nil { - return true, 0 - } - if header.Nonce == prevHeader.Nonce { - return false, 0 - } - return true, prevHeader.Nonce.Uint64() -} - func newValidationEntry( - prevHeader *types.Header, - header *types.Header, + pos arbutil.MessageIndex, + start validator.GoGlobalState, + end validator.GoGlobalState, msg *arbostypes.MessageWithMetadata, + batch []byte, + prevDelayed uint64, ) (*validationEntry, error) { - hasDelayedMsg, delayedMsgNr := usingDelayedMsg(prevHeader, header) - validationEntry := &validationEntry{ + batchInfo := validator.BatchInfo{ + Number: start.Batch, + Data: batch, + } + hasDelayed := false + var delayedNum uint64 + if msg.DelayedMessagesRead == prevDelayed+1 { + hasDelayed = true + delayedNum = prevDelayed + } else if msg.DelayedMessagesRead != prevDelayed { + return nil, fmt.Errorf("illegal validation entry delayedMessage %d, previous %d", msg.DelayedMessagesRead, prevDelayed) + } + return &validationEntry{ Stage: ReadyForRecord, - BlockNumber: header.Number.Uint64(), - BlockHash: header.Hash(), - BlockHeader: header, - HasDelayedMsg: hasDelayedMsg, - DelayedMsgNr: delayedMsgNr, + Pos: pos, + Start: start, + End: end, + HasDelayedMsg: hasDelayed, + DelayedMsgNr: delayedNum, msg: msg, - } - if prevHeader != nil { - validationEntry.PrevBlockHash = prevHeader.Hash() - validationEntry.PrevBlockHeader = prevHeader - } - return validationEntry, nil -} - -func newRecordedValidationEntry( - prevHeader *types.Header, - header *types.Header, - preimages map[common.Hash][]byte, - batchInfos []validator.BatchInfo, - delayedMsg []byte, -) (*validationEntry, error) { - entry, err := newValidationEntry(prevHeader, header, nil) - if err != nil { - return nil, err - } - entry.Preimages = preimages - entry.BatchInfo = batchInfos - entry.DelayedMsg = delayedMsg - entry.Stage = Recorded - return entry, nil + BatchInfo: []validator.BatchInfo{batchInfo}, + }, nil } func NewStatelessBlockValidator( inboxReader InboxReaderInterface, inbox InboxTrackerInterface, streamer TransactionStreamerInterface, - blockchain *core.BlockChain, - blockchainDb ethdb.Database, + recorder BlockRecorder, arbdb ethdb.Database, das arbstate.DataAvailabilityReader, config func() *BlockValidatorConfig, stack *node.Node, ) (*StatelessBlockValidator, error) { - genesisBlockNum, err := streamer.GetGenesisBlockNumber() - if err != nil { - return nil, err - } valConfFetcher := func() *rpcclient.ClientConfig { return &config().ValidationServer } valClient := server_api.NewValidationClient(valConfFetcher, stack) execClient := server_api.NewExecutionClient(valConfFetcher, stack) validator := &StatelessBlockValidator{ config: config(), execSpawner: execClient, + recorder: recorder, validationSpawners: []validator.ValidationSpawner{valClient}, inboxReader: inboxReader, inboxTracker: inbox, streamer: streamer, - blockchain: blockchain, db: arbdb, daService: das, - genesisBlockNum: genesisBlockNum, - recordingDatabase: arbitrum.NewRecordingDatabase(blockchainDb, blockchain), } return validator, nil } @@ -313,158 +266,38 @@ func (v *StatelessBlockValidator) GetModuleRootsToValidate() []common.Hash { return validatingModuleRoots } -func stateLogFunc(targetHeader, header *types.Header, hasState bool) { - if targetHeader 
== nil || header == nil { - return - } - gap := targetHeader.Number.Int64() - header.Number.Int64() - step := int64(500) - stage := "computing state" - if !hasState { - step = 3000 - stage = "looking for full block" - } - if (gap >= step) && (gap%step == 0) { - log.Info("Setting up validation", "stage", stage, "current", header.Number, "target", targetHeader.Number) - } -} - -// If msg is nil, this will record block creation up to the point where message would be accessed (for a "too far" proof) -// If keepreference == true, reference to state of prevHeader is added (no reference added if an error is returned) -func (v *StatelessBlockValidator) RecordBlockCreation( - ctx context.Context, - prevHeader *types.Header, - msg *arbostypes.MessageWithMetadata, - keepReference bool, -) (common.Hash, map[common.Hash][]byte, []validator.BatchInfo, error) { - - recordingdb, chaincontext, recordingKV, err := v.recordingDatabase.PrepareRecording(ctx, prevHeader, stateLogFunc) - if err != nil { - return common.Hash{}, nil, nil, err +func (v *StatelessBlockValidator) ValidationEntryRecord(ctx context.Context, e *validationEntry) error { + if e.Stage != ReadyForRecord { + return fmt.Errorf("validation entry should be ReadyForRecord, is: %v", e.Stage) } - defer func() { v.recordingDatabase.Dereference(prevHeader) }() - - chainConfig := v.blockchain.Config() - - // Get the chain ID, both to validate and because the replay binary also gets the chain ID, - // so we need to populate the recordingdb with preimages for retrieving the chain ID. - if prevHeader != nil { - initialArbosState, err := arbosState.OpenSystemArbosState(recordingdb, nil, true) - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error opening initial ArbOS state: %w", err) - } - chainId, err := initialArbosState.ChainId() - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting chain ID from initial ArbOS state: %w", err) - } - if chainId.Cmp(chainConfig.ChainID) != 0 { - return common.Hash{}, nil, nil, fmt.Errorf("unexpected chain ID %v in ArbOS state, expected %v", chainId, chainConfig.ChainID) - } - _, err = initialArbosState.ChainConfig() - if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting chain config from initial ArbOS state: %w", err) - } - genesisNum, err := initialArbosState.GenesisBlockNum() + if e.Pos != 0 { + recording, err := v.recorder.RecordBlockCreation(ctx, e.Pos, e.msg) if err != nil { - return common.Hash{}, nil, nil, fmt.Errorf("error getting genesis block number from initial ArbOS state: %w", err) + return err } - expectedNum := chainConfig.ArbitrumChainParams.GenesisBlockNum - if genesisNum != expectedNum { - return common.Hash{}, nil, nil, fmt.Errorf("unexpected genesis block number %v in ArbOS state, expected %v", genesisNum, expectedNum) + if recording.BlockHash != e.End.BlockHash { + return fmt.Errorf("recording failed: pos %d, hash expected %v, got %v", e.Pos, e.End.BlockHash, recording.BlockHash) } - } + e.BatchInfo = append(e.BatchInfo, recording.BatchInfo...) 
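
The delayed-message bookkeeping here no longer compares header nonces: newValidationEntry derives it purely from message counts, and ValidationEntryRecord then fetches the delayed message bytes only when one was actually consumed. A minimal standalone sketch of that counting rule (names are illustrative, not the nitro API):

package main

import (
	"errors"
	"fmt"
)

// delayedForMessage mirrors the rule newValidationEntry applies: a message may
// consume at most one new delayed message, and when it does, it proves against
// delayed message number prevDelayed (the delayed count before this message).
func delayedForMessage(delayedRead, prevDelayed uint64) (hasDelayed bool, delayedNum uint64, err error) {
	switch {
	case delayedRead == prevDelayed:
		return false, 0, nil // nothing new was read from the delayed inbox
	case delayedRead == prevDelayed+1:
		return true, prevDelayed, nil
	default:
		return false, 0, errors.New("illegal validation entry: more than one new delayed message")
	}
}

func main() {
	hasDelayed, delayedNum, err := delayedForMessage(6, 5)
	fmt.Println(hasDelayed, delayedNum, err) // true 5 <nil>
}
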
- var blockHash common.Hash - var readBatchInfo []validator.BatchInfo - if msg != nil { - batchFetcher := func(batchNum uint64) ([]byte, error) { - data, err := v.inboxReader.GetSequencerMessageBytes(ctx, batchNum) - if err != nil { - return nil, err - } - readBatchInfo = append(readBatchInfo, validator.BatchInfo{ - Number: batchNum, - Data: data, - }) - return data, nil + if recording.Preimages != nil { + e.Preimages = recording.Preimages } - // Re-fetch the batch instead of using our cached cost, - // as the replay binary won't have the cache populated. - msg.Message.BatchGasCost = nil - block, _, err := arbos.ProduceBlock( - msg.Message, - msg.DelayedMessagesRead, - prevHeader, - recordingdb, - chaincontext, - chainConfig, - batchFetcher, - ) - if err != nil { - return common.Hash{}, nil, nil, err - } - blockHash = block.Hash() - } - - preimages, err := v.recordingDatabase.PreimagesFromRecording(chaincontext, recordingKV) - if err != nil { - return common.Hash{}, nil, nil, err - } - if keepReference { - prevHeader = nil - } - return blockHash, preimages, readBatchInfo, err -} - -func (v *StatelessBlockValidator) ValidationEntryRecord(ctx context.Context, e *validationEntry, keepReference bool) error { - if e.Stage != ReadyForRecord { - return fmt.Errorf("validation entry should be ReadyForRecord, is: %v", e.Stage) - } - if e.PrevBlockHeader == nil { - e.Stage = Recorded - return nil - } - blockhash, preimages, readBatchInfo, err := v.RecordBlockCreation(ctx, e.PrevBlockHeader, e.msg, keepReference) - if err != nil { - return err - } - if blockhash != e.BlockHash { - return fmt.Errorf("recording failed: blockNum %d, hash expected %v, got %v", e.BlockNumber, e.BlockHash, blockhash) } if e.HasDelayedMsg { delayedMsg, err := v.inboxTracker.GetDelayedMessageBytes(e.DelayedMsgNr) if err != nil { log.Error( "error while trying to read delayed msg for proving", - "err", err, "seq", e.DelayedMsgNr, "blockNr", e.BlockNumber, + "err", err, "seq", e.DelayedMsgNr, "pos", e.Pos, ) return fmt.Errorf("error while trying to read delayed msg for proving: %w", err) } e.DelayedMsg = delayedMsg } - e.Preimages = preimages - e.BatchInfo = readBatchInfo - e.msg = nil // no longer needed - e.Stage = Recorded - return nil -} - -func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx context.Context, e *validationEntry, - startPos, endPos GlobalStatePosition, seqMsg []byte) error { - if e.Stage != Recorded { - return fmt.Errorf("validation entry stage should be Recorded, is: %v", e.Stage) - } if e.Preimages == nil { e.Preimages = make(map[common.Hash][]byte) } - e.StartPosition = startPos - e.EndPosition = endPos - seqMsgBatchInfo := validator.BatchInfo{ - Number: startPos.BatchNumber, - Data: seqMsg, - } - e.BatchInfo = append(e.BatchInfo, seqMsgBatchInfo) - for _, batch := range e.BatchInfo { if len(batch.Data) <= 40 { continue @@ -473,10 +306,7 @@ func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx context.Conte continue } if v.daService == nil { - log.Error("No DAS configured, but sequencer message found with DAS header") - if v.blockchain.Config().ArbitrumChainParams.DataAvailabilityCommittee { - return errors.New("processing data availability chain without DAS configured") - } + log.Warn("No DAS configured, but sequencer message found with DAS header") } else { _, err := arbstate.RecoverPayloadFromDasBatch( ctx, batch.Number, batch.Data, v.daService, e.Preimages, arbstate.KeysetValidate, @@ -486,70 +316,76 @@ func (v *StatelessBlockValidator) ValidationEntryAddSeqMessage(ctx 
context.Conte } } } + + e.msg = nil // no longer needed e.Stage = Ready return nil } -func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context, header *types.Header) (*validationEntry, error) { - if header == nil { - return nil, errors.New("header not found") +func buildGlobalState(res execution.MessageResult, pos GlobalStatePosition) validator.GoGlobalState { + return validator.GoGlobalState{ + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, + Batch: pos.BatchNumber, + PosInBatch: pos.PosInBatch, } - blockNum := header.Number.Uint64() - msgIndex := arbutil.BlockNumberToMessageCount(blockNum, v.genesisBlockNum) - 1 - prevHeader := v.blockchain.GetHeader(header.ParentHash, blockNum-1) - if prevHeader == nil && blockNum > 0 { - return nil, fmt.Errorf("prev header not found for block number %v with hash %s and parent hash %s", blockNum, header.Hash(), header.ParentHash) +} + +// return the globalState position before and after processing message at the specified count +func (v *StatelessBlockValidator) GlobalStatePositionsAtCount(count arbutil.MessageIndex) (GlobalStatePosition, GlobalStatePosition, error) { + if count == 0 { + return GlobalStatePosition{}, GlobalStatePosition{}, errors.New("no initial state for count==0") + } + if count == 1 { + return GlobalStatePosition{}, GlobalStatePosition{1, 0}, nil } - msg, err := v.streamer.GetMessage(msgIndex) + batchCount, err := v.inboxTracker.GetBatchCount() if err != nil { - return nil, err + return GlobalStatePosition{}, GlobalStatePosition{}, err } - var preimages map[common.Hash][]byte - var readBatchInfo []validator.BatchInfo - if prevHeader != nil { - var resHash common.Hash - var err error - resHash, preimages, readBatchInfo, err = v.RecordBlockCreation(ctx, prevHeader, msg, false) - if err != nil { - return nil, fmt.Errorf("failed to get block data to validate: %w", err) - } - if resHash != header.Hash() { - return nil, fmt.Errorf("wrong hash expected %s got %s", header.Hash(), resHash) - } + batch, err := FindBatchContainingMessageIndex(v.inboxTracker, count-1, batchCount) + if err != nil { + return GlobalStatePosition{}, GlobalStatePosition{}, err } + return GlobalStatePositionsAtCount(v.inboxTracker, count, batch) +} - batchCount, err := v.inboxTracker.GetBatchCount() +func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context, pos arbutil.MessageIndex) (*validationEntry, error) { + msg, err := v.streamer.GetMessage(pos) if err != nil { return nil, err } - batch, err := FindBatchContainingMessageIndex(v.inboxTracker, msgIndex, batchCount) + result, err := v.streamer.ResultAtCount(pos + 1) if err != nil { return nil, err } - - startPos, endPos, err := GlobalStatePositionsFor(v.inboxTracker, msgIndex, batch) - if err != nil { - return nil, fmt.Errorf("failed calculating position for validation: %w", err) - } - - usingDelayed, delaydNr := usingDelayedMsg(prevHeader, header) - var delayed []byte - if usingDelayed { - delayed, err = v.inboxTracker.GetDelayedMessageBytes(delaydNr) + var prevDelayed uint64 + if pos > 0 { + prev, err := v.streamer.GetMessage(pos - 1) if err != nil { - return nil, fmt.Errorf("error while trying to read delayed msg for proving: %w", err) + return nil, err } + prevDelayed = prev.DelayedMessagesRead } - entry, err := newRecordedValidationEntry(prevHeader, header, preimages, readBatchInfo, delayed) + prevResult, err := v.streamer.ResultAtCount(pos) if err != nil { - return nil, fmt.Errorf("failed to create validation entry %w", err) + return nil, err } - + startPos, 
endPos, err := v.GlobalStatePositionsAtCount(pos + 1) + if err != nil { + return nil, fmt.Errorf("failed calculating position for validation: %w", err) + } + start := buildGlobalState(*prevResult, startPos) + end := buildGlobalState(*result, endPos) seqMsg, err := v.inboxReader.GetSequencerMessageBytes(ctx, startPos.BatchNumber) if err != nil { return nil, err } - err = v.ValidationEntryAddSeqMessage(ctx, entry, startPos, endPos, seqMsg) + entry, err := newValidationEntry(pos, start, end, msg, seqMsg, prevDelayed) + if err != nil { + return nil, err + } + err = v.ValidationEntryRecord(ctx, entry) if err != nil { return nil, err } @@ -557,20 +393,16 @@ func (v *StatelessBlockValidator) CreateReadyValidationEntry(ctx context.Context return entry, nil } -func (v *StatelessBlockValidator) ValidateBlock( - ctx context.Context, header *types.Header, useExec bool, moduleRoot common.Hash, -) (bool, error) { - entry, err := v.CreateReadyValidationEntry(ctx, header) - if err != nil { - return false, err - } - expEnd, err := entry.expectedEnd() +func (v *StatelessBlockValidator) ValidateResult( + ctx context.Context, pos arbutil.MessageIndex, useExec bool, moduleRoot common.Hash, +) (bool, *validator.GoGlobalState, error) { + entry, err := v.CreateReadyValidationEntry(ctx, pos) if err != nil { - return false, err + return false, nil, err } input, err := entry.ToInput() if err != nil { - return false, err + return false, nil, err } var spawners []validator.ValidationSpawner if useExec { @@ -579,7 +411,7 @@ func (v *StatelessBlockValidator) ValidateBlock( spawners = v.validationSpawners } if len(spawners) == 0 { - return false, errors.New("no validation defined") + return false, &entry.End, errors.New("no validation defined") } var runs []validator.ValidationRun for _, spawner := range spawners { @@ -593,15 +425,15 @@ func (v *StatelessBlockValidator) ValidateBlock( }() for _, run := range runs { gsEnd, err := run.Await(ctx) - if err != nil || gsEnd != expEnd { - return false, err + if err != nil || gsEnd != entry.End { + return false, &gsEnd, err } } - return true, nil + return true, &entry.End, nil } -func (v *StatelessBlockValidator) RecordDBReferenceCount() int64 { - return v.recordingDatabase.ReferenceCount() +func (v *StatelessBlockValidator) OverrideRecorder(t *testing.T, recorder BlockRecorder) { + v.recorder = recorder } func (v *StatelessBlockValidator) Start(ctx_in context.Context) error { diff --git a/system_tests/block_validator_test.go b/system_tests/block_validator_test.go index b3fd8ddb6c..7fe1a65969 100644 --- a/system_tests/block_validator_test.go +++ b/system_tests/block_validator_test.go @@ -20,6 +20,7 @@ import ( "github.com/offchainlabs/nitro/arbnode" "github.com/offchainlabs/nitro/arbos/l2pricing" + "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/solgen/go/precompilesgen" ) @@ -43,6 +44,12 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops chainConfig.ArbitrumChainParams.InitialArbOSVersion = 10 } + var delayEvery int + if workloadLoops > 1 { + l1NodeConfigA.BatchPoster.MaxBatchPostDelay = time.Millisecond * 500 + delayEvery = workloadLoops / 3 + } + l2info, nodeA, l2client, l1info, _, l1client, l1stack := createTestNodeOnL1WithConfig(t, ctx, true, l1NodeConfigA, chainConfig, nil) defer requireClose(t, l1stack) defer nodeA.StopAndWait() @@ -105,6 +112,9 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops if workload != depleteGas { Require(t, err) } + if delayEvery > 0 && i%delayEvery == 
(delayEvery-1) { + <-time.After(time.Second) + } } } else { auth := l2info.GetDefaultTransactOpts("Owner", ctx) @@ -176,10 +186,12 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops } t.Log("waiting for block: ", lastBlock.NumberU64()) timeout := getDeadlineTimeout(t, time.Minute*10) - if !nodeB.BlockValidator.WaitForBlock(ctx, lastBlock.NumberU64(), timeout) { + // messageindex is same as block number here + if !nodeB.BlockValidator.WaitForPos(t, ctx, arbutil.MessageIndex(lastBlock.NumberU64()), timeout) { Fatal(t, "did not validate all blocks") } - finalRefCount := nodeB.BlockValidator.RecordDBReferenceCount() + nodeB.Execution.Recorder.TrimAllPrepared(t) + finalRefCount := nodeB.Execution.Recorder.RecordingDBReferenceCount() lastBlockNow, err := l2clientB.BlockByNumber(ctx, nil) Require(t, err) // up to 3 extra references: awaiting validation, recently valid, lastValidatedHeader diff --git a/system_tests/common_test.go b/system_tests/common_test.go index 7c9b6b1191..5be818b0a1 100644 --- a/system_tests/common_test.go +++ b/system_tests/common_test.go @@ -15,9 +15,6 @@ import ( "time" "github.com/offchainlabs/nitro/arbnode/execution" - "github.com/offchainlabs/nitro/validator/server_api" - "github.com/offchainlabs/nitro/validator/valnode" - "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbos/util" "github.com/offchainlabs/nitro/arbstate" @@ -28,7 +25,9 @@ import ( "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/headerreader" "github.com/offchainlabs/nitro/util/signature" + "github.com/offchainlabs/nitro/validator/server_api" "github.com/offchainlabs/nitro/validator/server_common" + "github.com/offchainlabs/nitro/validator/valnode" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/keystore" diff --git a/system_tests/full_challenge_impl_test.go b/system_tests/full_challenge_impl_test.go index 26e2d4a64e..de2ee5cd34 100644 --- a/system_tests/full_challenge_impl_test.go +++ b/system_tests/full_challenge_impl_test.go @@ -1,6 +1,10 @@ // Copyright 2021-2022, Offchain Labs, Inc. 
// For license information, see https://github.com/nitro/blob/master/LICENSE +// race detection makes things slow and miss timeouts +//go:build !race +// +build !race + package arbtest import ( @@ -18,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" @@ -137,13 +142,15 @@ func writeTxToBatch(writer io.Writer, tx *types.Transaction) error { return err } -func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, backend *ethclient.Client, sequencer *bind.TransactOpts, seqInbox *mocksgen.SequencerInboxStub, seqInboxAddr common.Address, isChallenger bool) { +const makeBatch_MsgsPerBatch = int64(5) + +func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, backend *ethclient.Client, sequencer *bind.TransactOpts, seqInbox *mocksgen.SequencerInboxStub, seqInboxAddr common.Address, modStep int64) { ctx := context.Background() batchBuffer := bytes.NewBuffer([]byte{}) - for i := int64(0); i < 10; i++ { + for i := int64(0); i < makeBatch_MsgsPerBatch; i++ { value := i - if i == 5 && isChallenger { + if i == modStep { value++ } err := writeTxToBatch(batchBuffer, l2Info.PrepareTx("Owner", "Destination", 1000000, big.NewInt(value), []byte{})) @@ -153,9 +160,9 @@ func makeBatch(t *testing.T, l2Node *arbnode.Node, l2Info *BlockchainTestInfo, b Require(t, err) message := append([]byte{0}, compressed...) - maxUint256 := new(big.Int).Lsh(common.Big1, 256) - maxUint256.Sub(maxUint256, common.Big1) - tx, err := seqInbox.AddSequencerL2BatchFromOrigin0(sequencer, maxUint256, message, big.NewInt(1), common.Address{}, big.NewInt(0), big.NewInt(0)) + seqNum := new(big.Int).Lsh(common.Big1, 256) + seqNum.Sub(seqNum, common.Big1) + tx, err := seqInbox.AddSequencerL2BatchFromOrigin0(sequencer, seqNum, message, big.NewInt(1), common.Address{}, big.NewInt(0), big.NewInt(0)) Require(t, err) receipt, err := EnsureTxSucceeded(ctx, backend, tx) Require(t, err) @@ -218,8 +225,7 @@ func setupSequencerInboxStub(ctx context.Context, t *testing.T, l1Info *Blockcha return bridgeAddr, seqInbox, seqInboxAddr } -func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { - t.Parallel() +func RunChallengeTest(t *testing.T, asserterIsCorrect bool, useStubs bool, challengeMsgIdx int64) { glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.TerminalFormat(false))) glogger.Verbosity(log.LvlInfo) log.Root().SetHandler(glogger) @@ -241,7 +247,13 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { conf.BatchPoster.Enable = false conf.InboxReader.CheckDelay = time.Second - _, valStack := createTestValidationNode(t, ctx, &valnode.TestValidationConfig) + var valStack *node.Node + var mockSpawn *mockSpawner + if useStubs { + mockSpawn, valStack = createMockValidationNode(t, ctx, &valnode.TestValidationConfig.Arbitrator) + } else { + _, valStack = createTestValidationNode(t, ctx, &valnode.TestValidationConfig) + } configByValidationNode(t, conf, valStack) fatalErrChan := make(chan error, 10) @@ -274,8 +286,22 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { asserterL2Info.GenerateAccount("Destination") challengerL2Info.SetFullAccountInfo("Destination", asserterL2Info.GetInfoWithPrivKey("Destination")) - makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, false) - makeBatch(t, challengerL2, 
challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, true) + + if challengeMsgIdx < 1 || challengeMsgIdx > 3*makeBatch_MsgsPerBatch { + Fatal(t, "challengeMsgIdx illegal") + } + + // seqNum := common.Big2 + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-1) + + // seqNum.Add(seqNum, common.Big1) + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-makeBatch_MsgsPerBatch-1) + + // seqNum.Add(seqNum, common.Big1) + makeBatch(t, asserterL2, asserterL2Info, l1Backend, &sequencerTxOpts, asserterSeqInbox, asserterSeqInboxAddr, -1) + makeBatch(t, challengerL2, challengerL2Info, l1Backend, &sequencerTxOpts, challengerSeqInbox, challengerSeqInboxAddr, challengeMsgIdx-makeBatch_MsgsPerBatch*2-1) trueSeqInboxAddr := challengerSeqInboxAddr trueDelayedBridge := challengerBridgeAddr @@ -291,9 +317,14 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { if err != nil { Fatal(t, err) } - wasmModuleRoot := locator.LatestWasmModuleRoot() - if (wasmModuleRoot == common.Hash{}) { - Fatal(t, "latest machine not found") + var wasmModuleRoot common.Hash + if useStubs { + wasmModuleRoot = mockWasmModuleRoot + } else { + wasmModuleRoot = locator.LatestWasmModuleRoot() + if (wasmModuleRoot == common.Hash{}) { + Fatal(t, "latest machine not found") + } } asserterGenesis := asserterL2.Execution.ArbInterface.BlockChain().Genesis() @@ -314,7 +345,7 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { } asserterEndGlobalState := validator.GoGlobalState{ BlockHash: asserterLatestBlock.Hash(), - Batch: 2, + Batch: 4, PosInBatch: 0, } numBlocks := asserterLatestBlock.Number.Uint64() - asserterGenesis.NumberU64() @@ -337,29 +368,37 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { confirmLatestBlock(ctx, t, l1Info, l1Backend) - asserterValidator, err := staker.NewStatelessBlockValidator(asserterL2.InboxReader, asserterL2.InboxTracker, asserterL2.TxStreamer, asserterL2Blockchain, asserterL2ChainDb, asserterL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) + asserterValidator, err := staker.NewStatelessBlockValidator(asserterL2.InboxReader, asserterL2.InboxTracker, asserterL2.TxStreamer, asserterL2.Execution.Recorder, asserterL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) if err != nil { Fatal(t, err) } + if useStubs { + asserterRecorder := newMockRecorder(asserterValidator, asserterL2.TxStreamer) + asserterValidator.OverrideRecorder(t, asserterRecorder) + } err = asserterValidator.Start(ctx) if err != nil { Fatal(t, err) } defer asserterValidator.Stop() - asserterManager, err := staker.NewChallengeManager(ctx, l1Backend, &asserterTxOpts, asserterTxOpts.From, challengeManagerAddr, 1, asserterL2Blockchain, asserterL2.InboxTracker, asserterValidator, 0, 0) + asserterManager, err := staker.NewChallengeManager(ctx, l1Backend, &asserterTxOpts, asserterTxOpts.From, challengeManagerAddr, 1, asserterValidator, 0, 0) if err != nil { Fatal(t, err) } - challengerValidator, err := staker.NewStatelessBlockValidator(challengerL2.InboxReader, challengerL2.InboxTracker, challengerL2.TxStreamer, challengerL2Blockchain, 
challengerL2ChainDb, challengerL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) + challengerValidator, err := staker.NewStatelessBlockValidator(challengerL2.InboxReader, challengerL2.InboxTracker, challengerL2.TxStreamer, challengerL2.Execution.Recorder, challengerL2ArbDb, nil, StaticFetcherFrom(t, &conf.BlockValidator), valStack) if err != nil { Fatal(t, err) } + if useStubs { + challengerRecorder := newMockRecorder(challengerValidator, challengerL2.TxStreamer) + challengerValidator.OverrideRecorder(t, challengerRecorder) + } err = challengerValidator.Start(ctx) if err != nil { Fatal(t, err) } defer challengerValidator.Stop() - challengerManager, err := staker.NewChallengeManager(ctx, l1Backend, &challengerTxOpts, challengerTxOpts.From, challengeManagerAddr, 1, challengerL2Blockchain, challengerL2.InboxTracker, challengerValidator, 0, 0) + challengerManager, err := staker.NewChallengeManager(ctx, l1Backend, &challengerTxOpts, challengerTxOpts.From, challengeManagerAddr, 1, challengerValidator, 0, 0) if err != nil { Fatal(t, err) } @@ -394,6 +433,19 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { if tx == nil { Fatal(t, "no move") } + + if useStubs { + if len(mockSpawn.ExecSpawned) != 0 { + if len(mockSpawn.ExecSpawned) != 1 { + Fatal(t, "bad number of spawned execRuns: ", len(mockSpawn.ExecSpawned)) + } + if mockSpawn.ExecSpawned[0] != uint64(challengeMsgIdx) { + Fatal(t, "wrong spawned execRuns: ", mockSpawn.ExecSpawned[0], " expected: ", challengeMsgIdx) + } + return + } + } + _, err = EnsureTxSucceeded(ctx, l1Backend, tx) if err != nil { if !currentCorrect && strings.Contains(err.Error(), "BAD_SEQINBOX_MESSAGE") { @@ -419,3 +471,17 @@ func RunChallengeTest(t *testing.T, asserterIsCorrect bool) { Fatal(t, "challenge timed out without winner") } + +func TestMockChallengeManagerAsserterIncorrect(t *testing.T) { + t.Parallel() + for i := int64(1); i <= makeBatch_MsgsPerBatch*3; i++ { + RunChallengeTest(t, false, true, i) + } +} + +func TestMockChallengeManagerAsserterCorrect(t *testing.T) { + t.Parallel() + for i := int64(1); i <= makeBatch_MsgsPerBatch*3; i++ { + RunChallengeTest(t, true, true, i) + } +} diff --git a/system_tests/full_challenge_test.go b/system_tests/full_challenge_test.go index 367ac33464..a960e7f640 100644 --- a/system_tests/full_challenge_test.go +++ b/system_tests/full_challenge_test.go @@ -15,9 +15,11 @@ import ( ) func TestChallengeManagerFullAsserterIncorrect(t *testing.T) { - RunChallengeTest(t, false) + t.Parallel() + RunChallengeTest(t, false, false, makeBatch_MsgsPerBatch+1) } func TestChallengeManagerFullAsserterCorrect(t *testing.T) { - RunChallengeTest(t, true) + t.Parallel() + RunChallengeTest(t, true, false, makeBatch_MsgsPerBatch+2) } diff --git a/system_tests/seqinbox_test.go b/system_tests/seqinbox_test.go index 56d727dd26..bf3e7c86c1 100644 --- a/system_tests/seqinbox_test.go +++ b/system_tests/seqinbox_test.go @@ -267,11 +267,13 @@ func testSequencerInboxReaderImpl(t *testing.T, validator bool) { if validator && i%15 == 0 { for i := 0; ; i++ { - lastValidated := arbNode.BlockValidator.LastBlockValidated() - if lastValidated == expectedBlockNumber { + expectedPos, err := arbNode.Execution.ExecEngine.BlockNumberToMessageIndex(expectedBlockNumber) + Require(t, err) + lastValidated := arbNode.BlockValidator.Validated(t) + if lastValidated == expectedPos+1 { break } else if i >= 1000 { - Fatal(t, "timed out waiting for block validator; have", lastValidated, "want", expectedBlockNumber) + Fatal(t, "timed out waiting for 
block validator; have", lastValidated, "want", expectedPos+1) } time.Sleep(time.Second) } diff --git a/system_tests/staker_test.go b/system_tests/staker_test.go index e5ed3879e9..7a3ae41814 100644 --- a/system_tests/staker_test.go +++ b/system_tests/staker_test.go @@ -149,8 +149,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) l2nodeA.InboxReader, l2nodeA.InboxTracker, l2nodeA.TxStreamer, - l2nodeA.Execution.ArbInterface.BlockChain(), - l2nodeA.Execution.ChainDB, + l2nodeA.Execution.Recorder, l2nodeA.ArbDB, nil, StaticFetcherFrom(t, &blockValidatorConfig), @@ -166,7 +165,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessA, + nil, l2nodeA.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) err = stakerA.Initialize(ctx) @@ -183,8 +184,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) l2nodeB.InboxReader, l2nodeB.InboxTracker, l2nodeB.TxStreamer, - l2nodeB.Execution.ArbInterface.BlockChain(), - l2nodeB.Execution.ChainDB, + l2nodeB.Execution.Recorder, l2nodeB.ArbDB, nil, StaticFetcherFrom(t, &blockValidatorConfig), @@ -200,7 +200,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessB, + nil, l2nodeB.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) err = stakerB.Initialize(ctx) @@ -220,7 +222,9 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) valConfig, nil, statelessA, + nil, l2nodeA.DeployInfo.ValidatorUtils, + nil, ) Require(t, err) if stakerC.Strategy() != staker.WatchtowerStrategy { @@ -290,7 +294,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) } if err != nil && faultyStaker && i%2 == 1 { // Check if this is an expected error from the faulty staker. - if strings.Contains(err.Error(), "agreed with entire challenge") || strings.Contains(err.Error(), "after block -1 expected global state") { + if strings.Contains(err.Error(), "agreed with entire challenge") || strings.Contains(err.Error(), "after msg 0 expected global state") { // Expected error upon realizing you're losing the challenge. Get ready for a timeout. if !challengeMangerTimedOut { // Upgrade the ChallengeManager contract to an implementation which says challenges are always timed out @@ -318,7 +322,7 @@ func stakerTestImpl(t *testing.T, faultyStaker bool, honestStakerInactive bool) } } else if strings.Contains(err.Error(), "insufficient funds") && sawStakerZombie { // Expected error when trying to re-stake after losing initial stake. - } else if strings.Contains(err.Error(), "unknown start block hash") && sawStakerZombie { + } else if strings.Contains(err.Error(), "start state not in chain") && sawStakerZombie { // Expected error when trying to re-stake after the challenger's nodes getting confirmed. } else if strings.Contains(err.Error(), "STAKER_IS_ZOMBIE") && sawStakerZombie { // Expected error when the staker is a zombie and thus can't advance its stake. 
diff --git a/system_tests/twonodeslong_test.go b/system_tests/twonodeslong_test.go index c2a5979c8d..3987e5cf7b 100644 --- a/system_tests/twonodeslong_test.go +++ b/system_tests/twonodeslong_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/offchainlabs/nitro/arbos/l2pricing" + "github.com/offchainlabs/nitro/arbutil" "github.com/ethereum/go-ethereum/core/types" ) @@ -173,7 +174,8 @@ func testTwoNodesLong(t *testing.T, dasModeStr string) { lastBlockHeader, err := l2clientB.HeaderByNumber(ctx, nil) Require(t, err) timeout := getDeadlineTimeout(t, time.Minute*30) - if !nodeB.BlockValidator.WaitForBlock(ctx, lastBlockHeader.Number.Uint64(), timeout) { + // messageindex is same as block number here + if !nodeB.BlockValidator.WaitForPos(t, ctx, arbutil.MessageIndex(lastBlockHeader.Number.Uint64()), timeout) { Fatal(t, "did not validate all blocks") } } diff --git a/system_tests/validation_mock_test.go b/system_tests/validation_mock_test.go index c853568479..bfa2d67839 100644 --- a/system_tests/validation_mock_test.go +++ b/system_tests/validation_mock_test.go @@ -3,7 +3,6 @@ package arbtest import ( "bytes" "context" - "errors" "math/big" "testing" "time" @@ -12,6 +11,11 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/arbnode" + "github.com/offchainlabs/nitro/arbnode/execution" + "github.com/offchainlabs/nitro/arbos/arbostypes" + "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/rpcclient" "github.com/offchainlabs/nitro/validator" @@ -20,6 +24,8 @@ import ( ) type mockSpawner struct { + ExecSpawned []uint64 + LaunchDelay time.Duration } var blockHashKey = common.HexToHash("0x11223344") @@ -50,11 +56,8 @@ func (s *mockSpawner) Launch(entry *validator.ValidationInput, moduleRoot common Promise: containers.NewPromise[validator.GoGlobalState](nil), root: moduleRoot, } - if moduleRoot != mockWasmModuleRoot { - run.ProduceError(errors.New("unsupported root")) - } else { - run.Produce(globalstateFromTestPreimages(entry.Preimages)) - } + <-time.After(s.LaunchDelay) + run.Produce(globalstateFromTestPreimages(entry.Preimages)) return run } @@ -66,9 +69,7 @@ func (s *mockSpawner) Name() string { return "mock" } func (s *mockSpawner) Room() int { return 4 } func (s *mockSpawner) CreateExecutionRun(wasmModuleRoot common.Hash, input *validator.ValidationInput) containers.PromiseInterface[validator.ExecutionRun] { - if wasmModuleRoot != mockWasmModuleRoot { - return containers.NewReadyPromise[validator.ExecutionRun](nil, errors.New("unsupported root")) - } + s.ExecSpawned = append(s.ExecSpawned, input.Id) return containers.NewReadyPromise[validator.ExecutionRun](&mockExecRun{ startState: input.StartState, endState: globalstateFromTestPreimages(input.Preimages), @@ -130,7 +131,7 @@ func (r *mockExecRun) PrepareRange(uint64, uint64) containers.PromiseInterface[s func (r *mockExecRun) Close() {} -func createMockValidationNode(t *testing.T, ctx context.Context, config *server_arb.ArbitratorSpawnerConfig) *node.Node { +func createMockValidationNode(t *testing.T, ctx context.Context, config *server_arb.ArbitratorSpawnerConfig) (*mockSpawner, *node.Node) { stackConf := node.DefaultConfig stackConf.HTTPPort = 0 stackConf.DataDir = "" @@ -170,7 +171,7 @@ func createMockValidationNode(t *testing.T, ctx context.Context, config *server_ serverAPI.StopOnly() }() - return stack + return 
spawner, stack } // mostly tests translation to/from json and running over network @@ -178,7 +179,7 @@ func TestValidationServerAPI(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - validationDefault := createMockValidationNode(t, ctx, nil) + _, validationDefault := createMockValidationNode(t, ctx, nil) client := server_api.NewExecutionClient(StaticFetcherFrom(t, &rpcclient.TestClientConfig), validationDefault) err := client.Start(ctx) Require(t, err) @@ -238,14 +239,94 @@ func TestValidationServerAPI(t *testing.T) { } } +func TestValidationClientRoom(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mockSpawner, spawnerStack := createMockValidationNode(t, ctx, nil) + client := server_api.NewExecutionClient(StaticFetcherFrom(t, &rpcclient.TestClientConfig), spawnerStack) + err := client.Start(ctx) + Require(t, err) + + wasmRoot, err := client.LatestWasmModuleRoot().Await(ctx) + Require(t, err) + + if client.Room() != 4 { + Fatal(t, "wrong initial room ", client.Room()) + } + + hash1 := common.HexToHash("0x11223344556677889900aabbccddeeff") + hash2 := common.HexToHash("0x11111111122222223333333444444444") + + startState := validator.GoGlobalState{ + BlockHash: hash1, + SendRoot: hash2, + Batch: 300, + PosInBatch: 3000, + } + endState := validator.GoGlobalState{ + BlockHash: hash2, + SendRoot: hash1, + Batch: 3000, + PosInBatch: 300, + } + + valInput := validator.ValidationInput{ + StartState: startState, + Preimages: globalstateToTestPreimages(endState), + } + + valRuns := make([]validator.ValidationRun, 0, 4) + + for i := 0; i < 4; i++ { + valRun := client.Launch(&valInput, wasmRoot) + valRuns = append(valRuns, valRun) + } + + for i := range valRuns { + _, err := valRuns[i].Await(ctx) + Require(t, err) + } + + if client.Room() != 4 { + Fatal(t, "wrong room after launch", client.Room()) + } + + mockSpawner.LaunchDelay = time.Hour + + valRuns = make([]validator.ValidationRun, 0, 3) + + for i := 0; i < 4; i++ { + valRun := client.Launch(&valInput, wasmRoot) + valRuns = append(valRuns, valRun) + room := client.Room() + if room != 3-i { + Fatal(t, "wrong room after launch ", room, " expected: ", 4-i) + } + } + + for i := range valRuns { + valRuns[i].Cancel() + _, err := valRuns[i].Await(ctx) + if err == nil { + Fatal(t, "no error returned after cancel i:", i) + } + } + + room := client.Room() + if room != 4 { + Fatal(t, "wrong room after canceling runs: ", room) + } +} + func TestExecutionKeepAlive(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - validationDefault := createMockValidationNode(t, ctx, nil) + _, validationDefault := createMockValidationNode(t, ctx, nil) shortTimeoutConfig := server_arb.DefaultArbitratorSpawnerConfig shortTimeoutConfig.ExecRunTimeout = time.Second - validationShortTO := createMockValidationNode(t, ctx, &shortTimeoutConfig) + _, validationShortTO := createMockValidationNode(t, ctx, &shortTimeoutConfig) configFetcher := StaticFetcherFrom(t, &rpcclient.TestClientConfig) clientDefault := server_api.NewExecutionClient(configFetcher, validationDefault) @@ -274,3 +355,43 @@ func TestExecutionKeepAlive(t *testing.T) { t.Error("getStep should have timed out but didn't") } } + +type mockBlockRecorder struct { + validator *staker.StatelessBlockValidator + streamer *arbnode.TransactionStreamer +} + +func (m *mockBlockRecorder) RecordBlockCreation( + ctx context.Context, + pos arbutil.MessageIndex, + msg 
*arbostypes.MessageWithMetadata, +) (*execution.RecordResult, error) { + _, globalpos, err := m.validator.GlobalStatePositionsAtCount(pos + 1) + if err != nil { + return nil, err + } + res, err := m.streamer.ResultAtCount(pos + 1) + if err != nil { + return nil, err + } + globalState := validator.GoGlobalState{ + Batch: globalpos.BatchNumber, + PosInBatch: globalpos.PosInBatch, + BlockHash: res.BlockHash, + SendRoot: res.SendRoot, + } + return &execution.RecordResult{ + Pos: pos, + BlockHash: res.BlockHash, + Preimages: globalstateToTestPreimages(globalState), + }, nil +} + +func (m *mockBlockRecorder) MarkValid(pos arbutil.MessageIndex, resultHash common.Hash) {} +func (m *mockBlockRecorder) PrepareForRecord(ctx context.Context, start, end arbutil.MessageIndex) error { + return nil +} + +func newMockRecorder(validator *staker.StatelessBlockValidator, streamer *arbnode.TransactionStreamer) *mockBlockRecorder { + return &mockBlockRecorder{validator, streamer} +} diff --git a/util/containers/syncmap.go b/util/containers/syncmap.go new file mode 100644 index 0000000000..7952a32252 --- /dev/null +++ b/util/containers/syncmap.go @@ -0,0 +1,24 @@ +package containers + +import "sync" + +type SyncMap[K any, V any] struct { + internal sync.Map +} + +func (m *SyncMap[K, V]) Load(key K) (V, bool) { + val, found := m.internal.Load(key) + if !found { + var empty V + return empty, false + } + return val.(V), true +} + +func (m *SyncMap[K, V]) Store(key K, val V) { + m.internal.Store(key, val) +} + +func (m *SyncMap[K, V]) Delete(key K) { + m.internal.Delete(key) +} diff --git a/util/rpcclient/rpcclient.go b/util/rpcclient/rpcclient.go index 2d43ded0d7..0a09459070 100644 --- a/util/rpcclient/rpcclient.go +++ b/util/rpcclient/rpcclient.go @@ -85,27 +85,37 @@ func (c *RpcClient) Close() { } } -func limitedMarshal(limit int, arg interface{}) string { - marshalled, err := json.Marshal(arg) +type limitedMarshal struct { + limit int + value any +} + +func (m limitedMarshal) String() string { + marshalled, err := json.Marshal(m.value) var str string if err != nil { - str = "\"CANNOT MARSHALL:" + err.Error() + "\"" + str = "\"CANNOT MARSHALL: " + err.Error() + "\"" } else { str = string(marshalled) } - if limit == 0 || len(str) <= limit { + if m.limit == 0 || len(str) <= m.limit { return str } - prefix := str[:limit/2-1] - postfix := str[len(str)-limit/2+1:] + prefix := str[:m.limit/2-1] + postfix := str[len(str)-m.limit/2+1:] return fmt.Sprintf("%v..%v", prefix, postfix) } -func logArgs(limit int, args ...interface{}) string { +type limitedArgumentsMarshal struct { + limit int + args []any +} + +func (m limitedArgumentsMarshal) String() string { res := "[" - for i, arg := range args { - res += limitedMarshal(limit, arg) - if i < len(args)-1 { + for i, arg := range m.args { + res += limitedMarshal{m.limit, arg}.String() + if i < len(m.args)-1 { res += ", " } } @@ -118,7 +128,7 @@ func (c *RpcClient) CallContext(ctx_in context.Context, result interface{}, meth return errors.New("not connected") } logId := atomic.AddUint64(&c.logId, 1) - log.Trace("sending RPC request", "method", method, "logId", logId, "args", logArgs(int(c.config().ArgLogLimit), args...)) + log.Trace("sending RPC request", "method", method, "logId", logId, "args", limitedArgumentsMarshal{int(c.config().ArgLogLimit), args}) var err error for i := 0; i < int(c.config().Retries)+1; i++ { if ctx_in.Err() != nil { @@ -138,9 +148,8 @@ func (c *RpcClient) CallContext(ctx_in context.Context, result interface{}, meth limit := int(c.config().ArgLogLimit) 
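
The limitedMarshal and limitedArgumentsMarshal types replace the old helper functions with fmt.Stringer values, so the potentially expensive JSON marshalling and truncation only run if the logger actually formats the argument (for example when Trace is enabled). A self-contained sketch of the same pattern, with illustrative names, showing the truncation behaviour exercised by the rpcclient tests below:

package main

import (
	"encoding/json"
	"fmt"
)

// lazyJSON defers marshalling until String() is called, mirroring the
// limitedMarshal trick: a log call that is filtered out never pays the cost
// of json.Marshal. Illustrative type, not the nitro implementation.
type lazyJSON struct {
	limit int
	value any
}

func (l lazyJSON) String() string {
	marshalled, err := json.Marshal(l.value)
	if err != nil {
		return "\"CANNOT MARSHAL: " + err.Error() + "\""
	}
	str := string(marshalled)
	if l.limit == 0 || len(str) <= l.limit {
		return str
	}
	// keep a prefix and a suffix, elide the middle
	return str[:l.limit/2-1] + ".." + str[len(str)-l.limit/2+1:]
}

func main() {
	fmt.Println(lazyJSON{0, []any{1, 2, 3, "hello, world"}}) // [1,2,3,"hello, world"]
	fmt.Println(lazyJSON{6, "hello, world"})                 // "h..d"
}
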
if err != nil && err.Error() != "already known" { logger = log.Info - limit = 0 } - logger("rpc response", "method", method, "logId", logId, "err", err, "result", limitedMarshal(limit, result), "attempt", i, "args", logArgs(limit, args...)) + logger("rpc response", "method", method, "logId", logId, "err", err, "result", limitedMarshal{limit, result}, "attempt", i, "args", limitedArgumentsMarshal{limit, args}) if err == nil { return nil } diff --git a/util/rpcclient/rpcclient_test.go b/util/rpcclient/rpcclient_test.go index 1b44b9479e..c9d813ab0d 100644 --- a/util/rpcclient/rpcclient_test.go +++ b/util/rpcclient/rpcclient_test.go @@ -15,17 +15,18 @@ import ( func TestLogArgs(t *testing.T) { t.Parallel() - str := logArgs(0, 1, 2, 3, "hello, world") + args := []any{1, 2, 3, "hello, world"} + str := limitedArgumentsMarshal{0, args}.String() if str != "[1, 2, 3, \"hello, world\"]" { Fail(t, "unexpected logs limit 0 got:", str) } - str = logArgs(100, 1, 2, 3, "hello, world") + str = limitedArgumentsMarshal{100, args}.String() if str != "[1, 2, 3, \"hello, world\"]" { Fail(t, "unexpected logs limit 100 got:", str) } - str = logArgs(6, 1, 2, 3, "hello, world") + str = limitedArgumentsMarshal{6, args}.String() if str != "[1, 2, 3, \"h..d\"]" { Fail(t, "unexpected logs limit 6 got:", str) } diff --git a/util/stopwaiter/stopwaiter.go b/util/stopwaiter/stopwaiter.go index f449ed1c44..1e70e328eb 100644 --- a/util/stopwaiter/stopwaiter.go +++ b/util/stopwaiter/stopwaiter.go @@ -198,6 +198,9 @@ func (s *StopWaiterSafe) CallIterativelySafe(foo func(context.Context) time.Dura if ctx.Err() != nil { return } + if interval == time.Duration(0) { + continue + } timer := time.NewTimer(interval) select { case <-ctx.Done(): @@ -233,6 +236,9 @@ func CallIterativelyWith[T any]( return } val = defaultVal + if interval == time.Duration(0) { + continue + } timer := time.NewTimer(interval) select { case <-ctx.Done(): diff --git a/validator/server_api/validation_client.go b/validator/server_api/validation_client.go index 4f678fde9e..d6143ca917 100644 --- a/validator/server_api/validation_client.go +++ b/validator/server_api/validation_client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "errors" + "sync/atomic" "time" "github.com/offchainlabs/nitro/validator" @@ -23,6 +24,7 @@ type ValidationClient struct { stopwaiter.StopWaiter client *rpcclient.RpcClient name string + room int32 } func NewValidationClient(config rpcclient.ClientConfigFetcher, stack *node.Node) *ValidationClient { @@ -32,14 +34,15 @@ func NewValidationClient(config rpcclient.ClientConfigFetcher, stack *node.Node) } func (c *ValidationClient) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun { - valrun := server_common.NewValRun(moduleRoot) - c.LaunchThread(func(ctx context.Context) { + atomic.AddInt32(&c.room, -1) + promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](c, func(ctx context.Context) (validator.GoGlobalState, error) { input := ValidationInputToJson(entry) var res validator.GoGlobalState err := c.client.CallContext(ctx, &res, Namespace+"_validate", input, moduleRoot) - valrun.ConsumeResult(res, err) + atomic.AddInt32(&c.room, 1) + return res, err }) - return valrun + return server_common.NewValRun(promise, moduleRoot) } func (c *ValidationClient) Start(ctx_in context.Context) error { @@ -57,6 +60,18 @@ func (c *ValidationClient) Start(ctx_in context.Context) error { if len(name) == 0 { return errors.New("couldn't read name from server") } + var room int + err = 
c.client.CallContext(c.GetContext(), &room, Namespace+"_room") + if err != nil { + return err + } + if room < 2 { + log.Warn("validation server not enough room, overriding to 2", "name", name, "room", room) + room = 2 + } else { + log.Info("connected to validation server", "name", name, "room", room) + } + atomic.StoreInt32(&c.room, int32(room)) c.name = name return nil } @@ -76,13 +91,11 @@ func (c *ValidationClient) Name() string { } func (c *ValidationClient) Room() int { - var res int - err := c.client.CallContext(c.GetContext(), &res, Namespace+"_room") - if err != nil { - log.Error("error contacting validation server", "name", c.name, "err", err) + room32 := atomic.LoadInt32(&c.room) + if room32 < 0 { return 0 } - return res + return int(room32) } type ExecutionClient struct { diff --git a/validator/server_arb/validator_spawner.go b/validator/server_arb/validator_spawner.go index a073d24c3c..f9d0705f59 100644 --- a/validator/server_arb/validator_spawner.go +++ b/validator/server_arb/validator_spawner.go @@ -57,32 +57,6 @@ type ArbitratorSpawner struct { config ArbitratorSpawnerConfigFecher } -type valRun struct { - containers.Promise[validator.GoGlobalState] - root common.Hash -} - -func (r *valRun) WasmModuleRoot() common.Hash { - return r.root -} - -func (r *valRun) Close() {} - -func NewvalRun(root common.Hash) *valRun { - return &valRun{ - Promise: containers.NewPromise[validator.GoGlobalState](nil), - root: root, - } -} - -func (r *valRun) consumeResult(res validator.GoGlobalState, err error) { - if err != nil { - r.ProduceError(err) - } else { - r.Produce(res) - } -} - func NewArbitratorSpawner(locator *server_common.MachineLocator, config ArbitratorSpawnerConfigFecher) (*ArbitratorSpawner, error) { // TODO: preload machines spawner := &ArbitratorSpawner{ @@ -180,12 +154,11 @@ func (v *ArbitratorSpawner) execute( func (v *ArbitratorSpawner) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun { atomic.AddInt32(&v.count, 1) - run := NewvalRun(moduleRoot) - v.LaunchThread(func(ctx context.Context) { + promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](v, func(ctx context.Context) (validator.GoGlobalState, error) { defer atomic.AddInt32(&v.count, -1) - run.consumeResult(v.execute(ctx, entry, moduleRoot)) + return v.execute(ctx, entry, moduleRoot) }) - return run + return server_common.NewValRun(promise, moduleRoot) } func (v *ArbitratorSpawner) Room() int { @@ -193,11 +166,7 @@ func (v *ArbitratorSpawner) Room() int { if avail == 0 { avail = runtime.NumCPU() } - current := int(atomic.LoadInt32(&v.count)) - if current >= avail { - return 0 - } - return avail - current + return avail } var launchTime = time.Now().Format("2006_01_02__15_04") diff --git a/validator/server_common/valrun.go b/validator/server_common/valrun.go index 1331c29852..8486664008 100644 --- a/validator/server_common/valrun.go +++ b/validator/server_common/valrun.go @@ -7,7 +7,7 @@ import ( ) type ValRun struct { - containers.Promise[validator.GoGlobalState] + containers.PromiseInterface[validator.GoGlobalState] root common.Hash } @@ -15,17 +15,9 @@ func (r *ValRun) WasmModuleRoot() common.Hash { return r.root } -func NewValRun(root common.Hash) *ValRun { +func NewValRun(promise containers.PromiseInterface[validator.GoGlobalState], root common.Hash) *ValRun { return &ValRun{ - Promise: containers.NewPromise[validator.GoGlobalState](nil), - root: root, - } -} - -func (r *ValRun) ConsumeResult(res validator.GoGlobalState, err error) { - if err != nil { - 
r.ProduceError(err) - } else { - r.Produce(res) + PromiseInterface: promise, + root: root, } } diff --git a/validator/server_jit/spawner.go b/validator/server_jit/spawner.go index 7a3394bcae..6de006b182 100644 --- a/validator/server_jit/spawner.go +++ b/validator/server_jit/spawner.go @@ -90,12 +90,11 @@ func (s *JitSpawner) Name() string { func (v *JitSpawner) Launch(entry *validator.ValidationInput, moduleRoot common.Hash) validator.ValidationRun { atomic.AddInt32(&v.count, 1) - run := server_common.NewValRun(moduleRoot) - go func() { - run.ConsumeResult(v.execute(v.GetContext(), entry, moduleRoot)) - atomic.AddInt32(&v.count, -1) - }() - return run + promise := stopwaiter.LaunchPromiseThread[validator.GoGlobalState](v, func(ctx context.Context) (validator.GoGlobalState, error) { + defer atomic.AddInt32(&v.count, -1) + return v.execute(ctx, entry, moduleRoot) + }) + return server_common.NewValRun(promise, moduleRoot) } func (v *JitSpawner) Room() int { @@ -103,7 +102,7 @@ func (v *JitSpawner) Room() int { if avail == 0 { avail = runtime.NumCPU() } - return avail - int(atomic.LoadInt32(&v.count)) + return avail } func (v *JitSpawner) Stop() {
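
With the spawners' Room() now reporting static capacity, the ValidationClient keeps its own room counter: it is fetched once at Start, decremented when a validation is launched, and incremented when the run finishes, so Room() needs no extra RPC round-trip. A stripped-down sketch of that accounting, using standalone types assumed for illustration:

package main

import (
	"fmt"
	"sync/atomic"
)

// roomTracker models the client-side accounting: capacity is set once, then
// adjusted around each in-flight run. Not the nitro type, just the pattern.
type roomTracker struct {
	room int32
}

func (r *roomTracker) start(initial int32) { atomic.StoreInt32(&r.room, initial) }

func (r *roomTracker) launch(work func()) {
	atomic.AddInt32(&r.room, -1) // reserve a slot before the run starts
	go func() {
		defer atomic.AddInt32(&r.room, 1) // release the slot when it finishes
		work()
	}()
}

func (r *roomTracker) Room() int {
	if v := atomic.LoadInt32(&r.room); v > 0 {
		return int(v)
	}
	return 0
}

func main() {
	var t roomTracker
	t.start(4)
	done := make(chan struct{})
	t.launch(func() { <-done })
	fmt.Println(t.Room()) // 3 while the run is in flight
	close(done)
}
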