Skip to content

Commit

Permalink
feat: add merge witnesses and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
V-Staykov committed Nov 15, 2024
1 parent 40c661c commit b70c4f9
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 0 deletions.
74 changes: 74 additions & 0 deletions turbo/trie/trie_zkevm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty off
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

// Package trie implements Merkle Patricia Tries.
package trie

import (
"fmt"
)

// Get returns the value for key stored in the trie.
func (t *Trie) GetAllAddresses() ([][]byte, error) {
if t.root == nil {
return nil, nil
}

accountAddresses := make([][]byte, 0, t.NumberOfAccounts())
if err := getAllAddressesFromNode(t.root, &accountAddresses, []byte{}); err != nil {
return nil, err
}

return accountAddresses, nil
}

func getAllAddressesFromNode(currentNode node, addresses *[][]byte, address []byte) error {
switch n := (currentNode).(type) {
case nil:
return nil
case *shortNode:
key := append(address, n.Key...)
switch n.Val.(type) {
case *accountNode:
*addresses = append(*addresses, hexToKeybytes(append(address, n.Key...)))
default:
return getAllAddressesFromNode(n.Val, addresses, key)
}
case *duoNode:
i1, i2 := n.childrenIdx()
tmpAddress := append(address, i1)
if err := getAllAddressesFromNode(n.child1, addresses, tmpAddress); err != nil {
return err
}
return getAllAddressesFromNode(n.child2, addresses, append(address, i2))
case *fullNode:
for index, child := range n.Children {
if err := getAllAddressesFromNode(child, addresses, append(address, byte(index))); err != nil {
return err
}
}
return nil
case hashNode:
return nil
case *accountNode:
*addresses = append(*addresses, hexToKeybytes(address))
return nil
default:
return fmt.Errorf("invalid node: %v", currentNode)
}

return nil
}
73 changes: 73 additions & 0 deletions turbo/trie/trie_zkevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package trie

import (
"bytes"
"math/big"
"math/rand"
"testing"

libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/core/types/accounts"
"github.com/ledgerwatch/erigon/crypto"
"github.com/stretchr/testify/assert"
)

func TestGetAllAccounts(t *testing.T) {
trie := newEmpty()

random := rand.New(rand.NewSource(0))

numberOfAccounts := 2000

addresses := make([][]byte, numberOfAccounts)
for i := 0; i < len(addresses); i++ {
a := getAddressForIndex(i)
addresses[i] = crypto.Keccak256(a[:])
}
codeValues := make([][]byte, len(addresses))
for i := 0; i < len(addresses); i++ {
codeValues[i] = genRandomByteArrayOfLen(128)
codeHash := libcommon.BytesToHash(crypto.Keccak256(codeValues[i]))
balance := new(big.Int).Rand(random, new(big.Int).Exp(libcommon.Big2, libcommon.Big256, nil))
acc := accounts.NewAccount()
acc.Nonce = uint64(random.Int63())
acc.Balance.SetFromBig(balance)
acc.Root = EmptyRoot
acc.CodeHash = codeHash

trie.UpdateAccount(addresses[i][:], &acc)
err := trie.UpdateAccountCode(addresses[i][:], codeValues[i])
assert.Nil(t, err, "should successfully insert code")
}

addressesFound, err := trie.GetAllAddresses()
assert.Nil(t, err, "should successfully get all accounts")
assert.Equal(t, len(addresses), len(addressesFound), "should receive the right number of accounts")
for i := 0; i < len(addresses); i++ {
found := false
for j := 0; j < len(addressesFound); j++ {
if bytes.Compare(addresses[i], addressesFound[j]) == 0 {
found = true
break
}
}

assert.True(t, found, "address not found")
}
}
71 changes: 71 additions & 0 deletions zk/witness/witness_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package witness
import (
"bytes"
"context"
"errors"
"fmt"
"math"

Expand All @@ -27,6 +28,10 @@ import (
"github.com/ledgerwatch/log/v3"
)

var (
ErrNoWitnesses = errors.New("witness count is 0")
)

func UnwindForWitness(ctx context.Context, tx kv.RwTx, startBlock, latestBlock uint64, dirs datadir.Dirs, historyV3 bool, agg *state.Aggregator) (err error) {
unwindState := &stagedsync.UnwindState{UnwindPoint: startBlock - 1}
stageState := &stagedsync.StageState{BlockNumber: latestBlock}
Expand Down Expand Up @@ -158,3 +163,69 @@ func GetWitnessBytes(witness *trie.Witness, debug bool) ([]byte, error) {
func ParseWitnessFromBytes(input []byte, trace bool) (*trie.Witness, error) {
return trie.NewWitnessFromReader(bytes.NewReader(input), trace)
}

// merges witnesses into one
// corresponds to a witness built on a range of blocks
// input witnesses should be ordered by consequent blocks
// it replaces values from 2,3,4 into the first witness
// it does this through creating tries from the witnesses, merging the tries and then creating a witness from the rsulting trie
func MergeWitnesses(witnesses []*trie.Witness) (*trie.Witness, error) {
if len(witnesses) == 0 {
return nil, ErrNoWitnesses
}

if len(witnesses) == 1 {
return witnesses[0], nil
}

baseTrie, err := trie.BuildTrieFromWitness(witnesses[0], false)
if err != nil {
return nil, err
}
for i := 1; i < len(witnesses); i++ {
trie, err := trie.BuildTrieFromWitness(witnesses[i], false)
if err != nil {
return nil, err
}
baseTrie, err = mergeTries(baseTrie, trie)
if err != nil {
return nil, err
}
}

return baseTrie.ExtractWitness(false, nil)
}

func mergeTries(trie1, trie2 *trie.Trie) (*trie.Trie, error) {
addresses, err := trie2.GetAllAddresses()
if err != nil {
return nil, err
}

for _, address := range addresses {
account, found := trie2.GetAccount(address[:])
if !found {
return nil, fmt.Errorf("account not found")
}

trie1.UpdateAccount(address[:], account)

code, found := trie2.GetAccountCode(address)
if !found {
return nil, fmt.Errorf("code not found")
}
if err := trie1.UpdateAccountCode(address, code); err != nil {
return nil, err
}

codeSize, found := trie2.GetAccountCodeSize(address)
if !found {
return nil, fmt.Errorf("code size not found")
}
if err := trie1.UpdateAccountCodeSize(address, codeSize); err != nil {
return nil, err
}
}

return trie1, nil
}
85 changes: 85 additions & 0 deletions zk/witness/witness_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package witness

import (
"bytes"
"encoding/binary"
"math/big"
"math/rand"
"testing"

"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/core/types/accounts"
"github.com/ledgerwatch/erigon/crypto"
"github.com/ledgerwatch/erigon/turbo/trie"
"github.com/stretchr/testify/assert"
)

func TestMergeWitnesses(t *testing.T) {
trie1 := trie.New(common.Hash{})
trie2 := trie.New(common.Hash{})
trieFull := trie.New(common.Hash{})

random := rand.New(rand.NewSource(0))

numberOfAccounts := 2000

addresses := make([][]byte, numberOfAccounts)
for i := 0; i < len(addresses); i++ {
a := getAddressForIndex(i)
addresses[i] = crypto.Keccak256(a[:])
}
codeValues := make([][]byte, len(addresses))
for i := 0; i < len(addresses); i++ {
codeValues[i] = genRandomByteArrayOfLen(128)
codeHash := common.BytesToHash(crypto.Keccak256(codeValues[i]))
balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
acc := accounts.NewAccount()
acc.Nonce = uint64(random.Int63())
acc.Balance.SetFromBig(balance)
acc.Root = trie.EmptyRoot
acc.CodeHash = codeHash

if i&1 == 0 {
trie1.UpdateAccount(addresses[i][:], &acc)
err := trie1.UpdateAccountCode(addresses[i][:], codeValues[i])
assert.Nil(t, err, "should successfully insert code")
} else {
trie2.UpdateAccount(addresses[i][:], &acc)
err := trie2.UpdateAccountCode(addresses[i][:], codeValues[i])
assert.Nil(t, err, "should successfully insert code")
}
trieFull.UpdateAccount(addresses[i][:], &acc)
err := trieFull.UpdateAccountCode(addresses[i][:], codeValues[i])
assert.Nil(t, err, "should successfully insert code")
}

witness1, err := trie1.ExtractWitness(false, nil)
assert.Nil(t, err, "should successfully extract witness")
witness2, err := trie2.ExtractWitness(false, nil)
assert.Nil(t, err, "should successfully extract witness")
witnessFull, err := trieFull.ExtractWitness(false, nil)
assert.Nil(t, err, "should successfully extract witness")

mergedWitness, err := MergeWitnesses([]*trie.Witness{witness1, witness2})
assert.Nil(t, err, "should successfully merge witnesses")

//create writer
var buff bytes.Buffer
mergedWitness.WriteDiff(witnessFull, &buff)
diff := buff.String()
assert.Equal(t, 0, len(diff), "witnesses should be equal")
}

func getAddressForIndex(index int) [20]byte {
var address [20]byte
binary.BigEndian.PutUint32(address[:], uint32(index))
return address
}

func genRandomByteArrayOfLen(length uint) []byte {
array := make([]byte, length)
for i := uint(0); i < length; i++ {
array[i] = byte(rand.Intn(256))
}
return array
}

0 comments on commit b70c4f9

Please sign in to comment.