diff --git a/arbos/internal_tx.go b/arbos/internal_tx.go index ea411e920e..a4e58cbc13 100644 --- a/arbos/internal_tx.go +++ b/arbos/internal_tx.go @@ -88,7 +88,7 @@ func ApplyInternalTxUpdate(tx *types.ArbitrumInternalTx, state *arbosState.Arbos // Try to reap 2 retryables, revert the state on failure snapshot := evm.StateDB.Snapshot() var merkleUpdateEvents []merkleAccumulator.MerkleTreeNodeEvent - var expiredRetryableLeafs []*retryables.ExpiredRetryableLeaf + var expiredRetryableLeaves []*retryables.ExpiredRetryableLeaf for i := 0; i < 2; i++ { var events []merkleAccumulator.MerkleTreeNodeEvent var leaf *retryables.ExpiredRetryableLeaf @@ -99,13 +99,13 @@ func ApplyInternalTxUpdate(tx *types.ArbitrumInternalTx, state *arbosState.Arbos } merkleUpdateEvents = append(merkleUpdateEvents, events...) if leaf != nil { - expiredRetryableLeafs = append(expiredRetryableLeafs, leaf) + expiredRetryableLeaves = append(expiredRetryableLeaves, leaf) } } if err == nil { - for _, leaf := range expiredRetryableLeafs { - position := merkletree.LevelAndLeaf{Level: leaf.Event.Level, Leaf: leaf.Event.NumLeaves} - if err = EmitRetryableExpiredEvent(evm, leaf.TicketId, leaf.Event.Hash, position.ToBigInt()); err != nil { + for _, leaf := range expiredRetryableLeaves { + position := merkletree.LevelAndLeaf{Level: 0, Leaf: leaf.Index} + if err = EmitRetryableExpiredEvent(evm, leaf.TicketId, leaf.Hash, position.ToBigInt()); err != nil { log.Error("Failed to emit RetryableExpired event", "err", err) break } diff --git a/arbos/merkleAccumulator/merkleAccumulator.go b/arbos/merkleAccumulator/merkleAccumulator.go index 5c207b4c6a..d7b5f988ce 100644 --- a/arbos/merkleAccumulator/merkleAccumulator.go +++ b/arbos/merkleAccumulator/merkleAccumulator.go @@ -6,6 +6,7 @@ package merkleAccumulator import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" "github.com/offchainlabs/nitro/arbos/storage" "github.com/offchainlabs/nitro/util/arbmath" ) @@ -107,6 +108,7 @@ func (acc *MerkleAccumulator) GetPartials() ([]*common.Hash, error) { } func (acc *MerkleAccumulator) setPartial(level uint64, val *common.Hash) error { + log.Warn("setPartial", "level", level, "val", val) if acc.backingStorage != nil { err := acc.backingStorage.SetByUint64(2+level, *val) if err != nil { @@ -121,10 +123,10 @@ func (acc *MerkleAccumulator) setPartial(level uint64, val *common.Hash) error { } // Note: itemHash is hashed before being included in the tree, to prevent confusing leafs with branches. -func (acc *MerkleAccumulator) Append(itemHash common.Hash) ([]MerkleTreeNodeEvent, *MerkleTreeNodeEvent, error) { +func (acc *MerkleAccumulator) Append(itemHash common.Hash) ([]MerkleTreeNodeEvent, uint64, error) { size, err := acc.size.Increment() if err != nil { - return nil, nil, err + return nil, 0, err } events := []MerkleTreeNodeEvent{} @@ -134,25 +136,25 @@ func (acc *MerkleAccumulator) Append(itemHash common.Hash) ([]MerkleTreeNodeEven if level == CalcNumPartials(size-1) { // -1 to counteract the acc.size++ at top of this function h := common.BytesToHash(soFar) err := acc.setPartial(level, &h) - return events, &MerkleTreeNodeEvent{level, size - 1, h}, err + return events, size, err } thisLevel, err := acc.getPartial(level) if err != nil { - return nil, nil, err + return nil, size, err } if *thisLevel == (common.Hash{}) { h := common.BytesToHash(soFar) err := acc.setPartial(level, &h) - return events, &MerkleTreeNodeEvent{level, size - 1, h}, err + return events, size, err } soFar, err = acc.Keccak(thisLevel.Bytes(), soFar) if err != nil { - return nil, nil, err + return nil, size, err } h := common.Hash{} err = acc.setPartial(level, &h) if err != nil { - return nil, nil, err + return nil, size, err } level += 1 events = append(events, MerkleTreeNodeEvent{level, size - 1, common.BytesToHash(soFar)}) diff --git a/arbos/retryable_test.go b/arbos/retryable_test.go index 75b8df1bde..59d5544772 100644 --- a/arbos/retryable_test.go +++ b/arbos/retryable_test.go @@ -4,6 +4,7 @@ package arbos import ( + "encoding/hex" "math/big" "math/rand" "testing" @@ -11,14 +12,19 @@ import ( "github.com/offchainlabs/nitro/arbos/arbosState" "github.com/offchainlabs/nitro/arbos/burn" + "github.com/offchainlabs/nitro/arbos/merkleAccumulator" "github.com/offchainlabs/nitro/arbos/retryables" "github.com/offchainlabs/nitro/arbos/util" + "github.com/offchainlabs/nitro/util/arbmath" "github.com/offchainlabs/nitro/util/colors" + "github.com/offchainlabs/nitro/util/merkletree" "github.com/offchainlabs/nitro/util/testhelpers" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) @@ -32,6 +38,10 @@ type RetryableTestData struct { calldata []byte } +func (r *RetryableTestData) Hash() common.Hash { + return retryables.RetryableHash(r.id, r.numTries, r.from, r.to, r.callvalue, r.beneficiary, r.calldata) +} + func TestOpenNonexistentRetryable(t *testing.T) { state, _ := arbosState.NewArbosMemoryBackedArbOSState() id := common.BigToHash(big.NewInt(978645611142)) @@ -43,6 +53,7 @@ func TestOpenNonexistentRetryable(t *testing.T) { } func TestRetryableLifecycle(t *testing.T) { + _ = testhelpers.InitTestLog(t, log.LvlWarn) rand.Seed(time.Now().UTC().UnixNano()) state, statedb := arbosState.NewArbosMemoryBackedArbOSState() retryableState := state.RetryableState() @@ -50,6 +61,8 @@ func TestRetryableLifecycle(t *testing.T) { lifetime := uint64(retryables.RetryableLifetimeSeconds) timestampAtCreation := uint64(rand.Int63n(1 << 16)) timeoutAtCreation := timestampAtCreation + lifetime + timestampAtRevival := timeoutAtCreation + 2 + uint64(rand.Int63n(1<<16)) + //timeoutAtRevival := timestampAtRevival + lifetime currentTime := timeoutAtCreation setTime := func(timestamp uint64) uint64 { @@ -78,6 +91,7 @@ func TestRetryableLifecycle(t *testing.T) { // stateBeforeEverything := statedb.IntermediateRoot(true) setTime(timestampAtCreation) + // TODO(magic) remove ids (already in retries data) ids := []common.Hash{} retriesData := []RetryableTestData{} for i := 0; i < 8; i++ { @@ -121,8 +135,11 @@ func TestRetryableLifecycle(t *testing.T) { // check that our reap pricing is reflective of the true cost gasBefore := burner.Burned() evm := vm.NewEVM(vm.BlockContext{}, vm.TxContext{}, statedb, ¶ms.ChainConfig{}, vm.Config{}) - _, _, err := retryableState.TryToReapOneRetryable(currentTime, evm, util.TracingDuringEVM) + _, leaf, err := retryableState.TryToReapOneRetryable(currentTime, evm, util.TracingDuringEVM) Require(t, err) + if leaf != nil { + Fail(t, "Leaf! ", leaf) + } gasBurnedToReap := burner.Burned() - gasBefore if gasBurnedToReap != retryables.RetryableReapPrice { Fail(t, "reaping has been mispriced", gasBurnedToReap, retryables.RetryableReapPrice) @@ -133,6 +150,8 @@ func TestRetryableLifecycle(t *testing.T) { // Advanced passed the extended timeout and reap everything setTime(timeoutAtCreation + lifetime + 1) + var merkleUpdateEvents []merkleAccumulator.MerkleTreeNodeEvent + var expiredRetryablesLeaves []retryables.ExpiredRetryableLeaf for _, id := range ids { // The retryable will be reaped, so opening it should fail shouldBeNil, err := retryableState.OpenRetryable(id, currentTime) @@ -144,8 +163,20 @@ func TestRetryableLifecycle(t *testing.T) { gasBefore := burner.Burned() evm := vm.NewEVM(vm.BlockContext{}, vm.TxContext{}, statedb, ¶ms.ChainConfig{}, vm.Config{}) - _, _, err = retryableState.TryToReapOneRetryable(currentTime, evm, util.TracingDuringEVM) + events, leaf, err := retryableState.TryToReapOneRetryable(currentTime, evm, util.TracingDuringEVM) Require(t, err) + t.Log("events:", events) + merkleUpdateEvents = append(merkleUpdateEvents, events...) + if leaf == nil { + Fail(t, "reaping retryable returned no expired retryable leaf") + } + t.Log("New expired leaf:", leaf) + expiredRetryablesLeaves = append(expiredRetryablesLeaves, *leaf) + merkleUpdateEvents = append(merkleUpdateEvents, merkleAccumulator.MerkleTreeNodeEvent{ + Level: 0, + NumLeaves: leaf.Index, + Hash: leaf.Hash, + }) gasBurnedToReapAndDelete := burner.Burned() - gasBefore if gasBurnedToReapAndDelete <= retryables.RetryableReapPrice { Fail(t, "deletion was cheap", gasBurnedToReapAndDelete, retryables.RetryableReapPrice) @@ -168,7 +199,20 @@ func TestRetryableLifecycle(t *testing.T) { Fail(t, "reaping didn't reset the state", cleared) } + setTime(timestampAtRevival) // revive the retryables + for _, retryData := range retriesData { + size, err := retryableState.Expired.Size() + Require(t, err) + rootHash, err := retryableState.Expired.Root() + Require(t, err) + t.Log("accumulator size:", size, "rootHash:", rootHash) + newTimeout, err := retryableState.Revive(expiredRetryableReviveData(t, retryData, merkleUpdateEvents, expiredRetryablesLeaves, currentTime, lifetime)) + Require(t, err, "failed to revive the retryable") + if newTimeout != currentTime+lifetime { + Fail(t, "new timeout after revival is wrong", newTimeout, currentTime+lifetime) + } + } } func TestRetryableCleanup(t *testing.T) { @@ -246,3 +290,208 @@ func stateCheck(t *testing.T, statedb *state.StateDB, change bool, message strin Fail(t, message) } } + +func expiredRetryableReviveData( + t *testing.T, + retryData RetryableTestData, + events []merkleAccumulator.MerkleTreeNodeEvent, + leaves []retryables.ExpiredRetryableLeaf, + now, + lifetime uint64, +) ( + ticketId common.Hash, + numTries uint64, + from common.Address, + to common.Address, + callvalue *big.Int, + beneficiary common.Address, + calldata []byte, + rootHash common.Hash, + leafIndex uint64, + proof []common.Hash, + currentTimestamp uint64, + timeToAdd uint64, +) { + leafHash := crypto.Keccak256Hash(retryData.Hash().Bytes()) + var treeSize uint64 + for _, leaf := range leaves { + if leaf.TicketId == retryData.id { + leafIndex = leaf.Index + if leaf.Hash != leafHash { + Fail(t, "invalid leaf hash in ExpiredRetryableLeaf, want:", leafHash, "have:", leaf.Hash) + } + } + if leaf.Index+1 > treeSize { + treeSize = leaf.Index + 1 + } + } + + balanced := treeSize == arbmath.NextPowerOf2(treeSize)/2 + treeLevels := int(arbmath.Log2ceil(treeSize)) // the # of levels in the tree + proofLevels := treeLevels - 1 // the # of levels where a hash is needed (all but root) + walkLevels := treeLevels // the # of levels we need to consider when building walks + if balanced { + walkLevels -= 1 // skip the root + } + t.Log("Tree has", treeSize, "leaves and", treeLevels, "levels") + t.Log("Balanced:", balanced) + // find which nodes we'll want in our proof up to a partial + query := make(map[merkletree.LevelAndLeaf]struct{}) // the nodes we'll query for + nodes := make([]merkletree.LevelAndLeaf, 0) // the nodes needed (might not be found from query) + which := uint64(1) // which bit to flip & set + place := uint64(leafIndex) // where we are in the tree + t.Log("start place:", place) + for level := 0; level < walkLevels; level++ { + sibling := place ^ which + position := merkletree.LevelAndLeaf{ + Level: uint64(level), + Leaf: sibling, + } + query[position] = struct{}{} + nodes = append(nodes, position) + place |= which // set the bit so that we approach from the right + which <<= 1 // advance to the next bit + } + // find all the partials + partials := make(map[merkletree.LevelAndLeaf]common.Hash) + if !balanced { + power := uint64(1) << proofLevels + total := uint64(0) + for level := proofLevels; level >= 0; level-- { + if (power & treeSize) > 0 { // the partials map to the binary representation of the tree size + total += power // The actual leaf for a given partial is the sum of the powers of 2 + leaf := total - 1 // preceding it. We subtract 1 since we count from 0 + partial := merkletree.LevelAndLeaf{ + Level: uint64(level), + Leaf: leaf, + } + query[partial] = struct{}{} + partials[partial] = common.Hash{} + } + power >>= 1 + } + } + t.Log("Query:", query) + t.Log("Found", len(partials), "partials:", partials) + known := make(map[merkletree.LevelAndLeaf]common.Hash) // all values in the tree we know + partialsByLevel := make(map[uint64]common.Hash) // maps for each level the partial it may have + var minPartialPlace *merkletree.LevelAndLeaf // the lowest-level partial + // search all events + for _, event := range events { + level := event.Level + leaf := event.NumLeaves + hash := event.Hash + t.Log("event:\n\tposition: level", level, "leaf", leaf, "\n\thash: ", hash) + place := merkletree.LevelAndLeaf{ + Level: level, + Leaf: leaf, + } + if _, ok := query[place]; ok { + t.Log("Found queried place:", place) + known[place] = hash + } + if zero, ok := partials[place]; ok { + if zero != (common.Hash{}) { + if zero != hash { + Fail(t, "Somehow got 2 partials for the same level\n\t1st:", zero, "\n\t2nd:", hash, "place:", place) + } + continue + } + partials[place] = hash + partialsByLevel[level] = hash + if minPartialPlace == nil || level < minPartialPlace.Level { + minPartialPlace = &place + } + } + } + for place, hash := range known { + t.Log("known ", place.Level, hash, "@", place) + } + t.Log(len(known), "values are known\n") + for place, hash := range partials { + t.Log("partial", place.Level, hash, "@", place) + } + t.Log("resolving frontiers\n", "minPartialPlace:", minPartialPlace) + if !balanced { + // This tree isn't balanced, so we'll need to use the partials to recover the missing info. + // To do this, we'll walk the boundary of what's known, computing hashes along the way + zero := common.Hash{} + step := *minPartialPlace + step.Leaf += 1 << step.Level // we start on the min partial's zero-hash sibling + known[step] = zero + t.Log("zeroing:", step) + for step.Level < uint64(treeLevels) { + curr, ok := known[step] + if !ok { + Fail(t, "We should know the current node's value") + } + left := curr + right := curr + if _, ok := partialsByLevel[step.Level]; ok { + // a partial on the frontier can only appear on the left + // moving leftward for a level l skips 2^l leaves + step.Leaf -= 1 << step.Level + partial, ok := known[step] + if !ok { + Fail(t, "There should be a partial here") + } + left = partial + } else { + // getting to the next partial means covering its mirror subtree, so we look right + // moving rightward for a level l skips 2^l leaves + step.Leaf += 1 << step.Level + known[step] = zero + right = zero + } + // move to the parent + step.Level += 1 + step.Leaf |= 1 << (step.Level - 1) + known[step] = crypto.Keccak256Hash(left.Bytes(), right.Bytes()) + } + if known[step] != rootHash { + // a correct walk of the frontier should end with resolving the root + t.Log("Walking up the tree didn't re-create the root", known[step], "vs", rootHash) + } + for place, hash := range known { + t.Log("known", place, hash) + } + } + t.Log("Complete proof of leaf", leafIndex) + proof = make([]common.Hash, len(nodes)) + for i, place := range nodes { + hash, ok := known[place] + if !ok { + Fail(t, "We're missing data for the node at position", place) + } + proof[i] = hash + t.Log("node", place, hash) + } + rootHash = leafHash + index := leafIndex + for _, hashFromProof := range proof { + + if index&1 == 0 { + rootHash = crypto.Keccak256Hash(rootHash.Bytes(), hashFromProof.Bytes()) + } else { + rootHash = crypto.Keccak256Hash(hashFromProof.Bytes(), rootHash.Bytes()) + } + index = index / 2 + } + if index != 0 { + Fail(t, "internal test error - failed to compute root hash") + } + t.Log("Root hash", hex.EncodeToString(rootHash[:])) + merkleProof := &merkletree.MerkleProof{ + RootHash: rootHash, + LeafHash: leafHash, + LeafIndex: leafIndex, + Proof: proof, + } + if !merkleProof.IsCorrect() { + Fail(t, "internal test error - incorrect proof") + } + ticketId, numTries, from, to, callvalue, beneficiary, calldata = retryData.id, retryData.numTries, retryData.from, retryData.to, retryData.callvalue, retryData.beneficiary, retryData.calldata + currentTimestamp = now + timeToAdd = lifetime + return +} diff --git a/arbos/retryables/retryable.go b/arbos/retryables/retryable.go index e9e8e8bd86..6c7b63b188 100644 --- a/arbos/retryables/retryable.go +++ b/arbos/retryables/retryable.go @@ -445,8 +445,9 @@ func (retryable *Retryable) Equals(other *Retryable) (bool, error) { // for test } type ExpiredRetryableLeaf struct { - Event *merkleAccumulator.MerkleTreeNodeEvent TicketId common.Hash + Index uint64 + Hash common.Hash } func (rs *RetryableState) TryToReapOneRetryable(currentTimestamp uint64, evm *vm.EVM, scenario util.TracingScenario) ([]merkleAccumulator.MerkleTreeNodeEvent, *ExpiredRetryableLeaf, error) { @@ -488,11 +489,11 @@ func (rs *RetryableState) TryToReapOneRetryable(currentTimestamp uint64, evm *vm if err = clearRetryable(retryableStorage); err != nil { return nil, nil, err } - merkleUpdateEvents, newLeafMerkleEvent, err := rs.Expired.Append(retryableHash) + merkleUpdateEvents, accumulatorSize, err := rs.Expired.Append(retryableHash) if err != nil { return nil, nil, err } - return merkleUpdateEvents, &ExpiredRetryableLeaf{newLeafMerkleEvent, *id}, nil + return merkleUpdateEvents, &ExpiredRetryableLeaf{TicketId: *id, Index: accumulatorSize - 1, Hash: common.BytesToHash(crypto.Keccak256(retryableHash.Bytes()))}, nil } // Consume a window, delaying the timeout one lifetime period