Skip to content
This repository has been archived by the owner on Jul 5, 2024. It is now read-only.

Commit

Permalink
test: complete tests and refine GenerateWitness
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Apr 16, 2024
1 parent ea2f647 commit db52db3
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 189 deletions.
100 changes: 50 additions & 50 deletions geth-utils/gethutil/mpt/trie/stacktrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type StackTrie struct {
// NewStackTrie allocates and initializes an empty trie.
func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie {
return &StackTrie{
nodeType: emptyNode,
nodeType: EmptyNode,
db: db,
}
}
Expand Down Expand Up @@ -166,7 +166,7 @@ func (st *StackTrie) setDb(db ethdb.KeyValueWriter) {

func newLeaf(ko int, key, val []byte, db ethdb.KeyValueWriter) *StackTrie {
st := stackTrieFromPool(db)
st.nodeType = leafNode
st.nodeType = LeafNode
st.keyOffset = ko
st.key = append(st.key, key[ko:]...)
st.val = val
Expand All @@ -175,7 +175,7 @@ func newLeaf(ko int, key, val []byte, db ethdb.KeyValueWriter) *StackTrie {

func newExt(ko int, key []byte, child *StackTrie, db ethdb.KeyValueWriter) *StackTrie {
st := stackTrieFromPool(db)
st.nodeType = extNode
st.nodeType = ExtNode
st.keyOffset = ko
st.key = append(st.key, key[ko:]...)
st.children[0] = child
Expand All @@ -184,11 +184,11 @@ func newExt(ko int, key []byte, child *StackTrie, db ethdb.KeyValueWriter) *Stac

// List all values that StackTrie#nodeType can hold
const (
emptyNode = iota
branchNode
extNode
leafNode
hashedNode
EmptyNode = iota
BranchNode
ExtNode
LeafNode
HashedNode
)

// TryUpdate inserts a (key, value) pair into the stack trie
Expand All @@ -214,7 +214,7 @@ func (st *StackTrie) Reset() {
for i := range st.children {
st.children[i] = nil
}
st.nodeType = emptyNode
st.nodeType = EmptyNode
st.keyOffset = 0
}

Expand All @@ -233,12 +233,12 @@ func (st *StackTrie) getDiffIndex(key []byte) int {
// https://github.dev/ethereum/go-ethereum/blob/00905f7dc406cfb67f64cd74113777044fb886d8/core/types/hashing.go#L105-L134
func (st *StackTrie) insert(key, value []byte) {
switch st.nodeType {
case branchNode: /* Branch */
case BranchNode: /* Branch */
idx := int(key[st.keyOffset])
// Unresolve elder siblings
for i := idx - 1; i >= 0; i-- {
if st.children[i] != nil {
if st.children[i].nodeType != hashedNode {
if st.children[i].nodeType != HashedNode {
st.children[i].hash(true)
}
break
Expand All @@ -252,7 +252,7 @@ func (st *StackTrie) insert(key, value []byte) {
st.children[idx].insert(key, value)
}

case extNode: /* Ext */
case ExtNode: /* Ext */
// Compare both key chunks and see where they differ
diffidx := st.getDiffIndex(key)
// Check if chunks are identical. If so, recurse into
Expand Down Expand Up @@ -287,13 +287,13 @@ func (st *StackTrie) insert(key, value []byte) {
// a branch node.
st.children[0] = nil
p = st
st.nodeType = branchNode
st.nodeType = BranchNode
} else {
// the common prefix is at least one byte
// long, insert a new intermediate branch
// node.
st.children[0] = stackTrieFromPool(st.db)
st.children[0].nodeType = branchNode
st.children[0].nodeType = BranchNode
st.children[0].keyOffset = st.keyOffset + diffidx
p = st.children[0]
}
Expand All @@ -307,7 +307,7 @@ func (st *StackTrie) insert(key, value []byte) {
p.children[newIdx] = o
st.key = st.key[:diffidx]

case leafNode: /* Leaf */
case LeafNode: /* Leaf */
// Compare both key chunks and see where they differ
diffidx := st.getDiffIndex(key)

Expand All @@ -327,15 +327,15 @@ func (st *StackTrie) insert(key, value []byte) {
var p *StackTrie
if diffidx == 0 {
// Convert current leaf into a branch
st.nodeType = branchNode
st.nodeType = BranchNode
p = st
st.children[0] = nil
} else {
// Convert current node into an ext,
// and insert a child branch node.
st.nodeType = extNode
st.nodeType = ExtNode
st.children[0] = NewStackTrie(st.db)
st.children[0].nodeType = branchNode
st.children[0].nodeType = BranchNode
st.children[0].keyOffset = st.keyOffset + diffidx
p = st.children[0]
}
Expand All @@ -355,19 +355,19 @@ func (st *StackTrie) insert(key, value []byte) {
// over to the children.
st.key = st.key[:diffidx]
st.val = nil
case emptyNode: /* Empty */
st.nodeType = leafNode
case EmptyNode: /* Empty */
st.nodeType = LeafNode
st.key = key[st.keyOffset:]
st.val = value
case hashedNode:
case HashedNode:
panic("trying to insert into hash")
default:
panic("invalid type")
}
}

func (st *StackTrie) branchToHasher(doUpdate bool) *hasher {
if st.nodeType != branchNode {
if st.nodeType != BranchNode {
panic("Converting branch to RLP: wrong node")
}
var nodes [17]Node
Expand Down Expand Up @@ -401,7 +401,7 @@ func (st *StackTrie) branchToHasher(doUpdate bool) *hasher {
}

func (st *StackTrie) extNodeToHasher(doUpdate bool) *hasher {
if st.nodeType != extNode {
if st.nodeType != ExtNode {
panic("Converting extension node to RLP: wrong node")
}
st.children[0].hash(doUpdate)
Expand Down Expand Up @@ -433,22 +433,22 @@ func (st *StackTrie) extNodeToHasher(doUpdate bool) *hasher {
return h
}

// hash() hashes the node 'st' and converts it into 'hashedNode', if possible.
// hash() hashes the node 'st' and converts it into 'HashedNode', if possible.
// Possible outcomes:
// 1. The rlp-encoded value was >= 32 bytes:
// - Then the 32-byte `hash` will be accessible in `st.val`.
// - And the 'st.type' will be 'hashedNode'
// - And the 'st.type' will be 'HashedNode'
//
// 2. The rlp-encoded value was < 32 bytes
// - Then the <32 byte rlp-encoded value will be accessible in 'st.val'.
// - And the 'st.type' will be 'hashedNode' AGAIN
// - And the 'st.type' will be 'HashedNode' AGAIN
//
// This method will also:
// set 'st.type' to hashedNode
// set 'st.type' to HashedNode
// clear 'st.key'
func (st *StackTrie) hash(doUpdate bool) {
/* Shortcut if node is already hashed */
if st.nodeType == hashedNode {
if st.nodeType == HashedNode {
return
}
// The 'hasher' is taken from a pool, but we don't actually
Expand All @@ -457,11 +457,11 @@ func (st *StackTrie) hash(doUpdate bool) {
var h *hasher

switch st.nodeType {
case branchNode:
case BranchNode:
h = st.branchToHasher(doUpdate)
case extNode:
case ExtNode:
h = st.extNodeToHasher(doUpdate)
case leafNode:
case LeafNode:
h = NewHasher(false)
defer returnHasherToPool(h)
h.tmp.Reset()
Expand All @@ -478,17 +478,17 @@ func (st *StackTrie) hash(doUpdate bool) {
if err := rlp.Encode(&h.tmp, n); err != nil {
panic(err)
}
case emptyNode:
case EmptyNode:
st.val = emptyRoot.Bytes()
st.key = st.key[:0]
st.nodeType = hashedNode
st.nodeType = HashedNode
return
default:
panic("Invalid node type")
}
if doUpdate {
st.key = st.key[:0]
st.nodeType = hashedNode
st.nodeType = HashedNode
}
if len(h.tmp) < 32 {
st.val = common.CopyBytes(h.tmp)
Expand Down Expand Up @@ -674,20 +674,20 @@ func printProof(ps [][]byte, t, idx []byte) {
enable := byte(150)
fmt.Print(" [")
for i, p := range ps {
if t[i] == extNode {
if t[i] == ExtNode {
fmt.Print("EXT - ")
if idx[0] >= enable {
fmt.Print(" (", p, ") - ")
}
} else if t[i] == branchNode {
} else if t[i] == BranchNode {
fmt.Print("BRANCH - ")
// fmt.Print(" (", p, ") - ")
} else if t[i] == leafNode {
} else if t[i] == LeafNode {
fmt.Print("LEAF - ")
if idx[0] >= enable {
fmt.Print(" (", p, ") - ")
}
} else if t[i] == hashedNode {
} else if t[i] == HashedNode {
fmt.Print("HASHED - ")
// elems, _, _ := rlp.SplitList(p)
// c, _ := rlp.CountValues(elems)
Expand Down Expand Up @@ -774,8 +774,8 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri
func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, [][]byte, []uint8, error) {
k := KeybytesToHex(key)
// fmt.Println(" k", k)
if st.nodeType == emptyNode {
return [][]byte{}, nil, []uint8{emptyNode}, nil
if st.nodeType == EmptyNode {
return [][]byte{}, nil, []uint8{EmptyNode}, nil
}

// Note that when root is a leaf, this leaf should be returned even if you ask for a different key (than the key of
Expand All @@ -786,8 +786,8 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, []
// (the one not just added) is the same as in the S proof. This wouldn't work if we would have a placeholder leaf
// in the S proof (another reason is that the S proof with a placeholder leaf would be an empty trie and thus with
// a root of an empty trie - which is not the case in S proof).
if st.nodeType == leafNode {
return [][]byte{st.val}, nil, []uint8{leafNode}, nil
if st.nodeType == LeafNode {
return [][]byte{st.val}, nil, []uint8{LeafNode}, nil
}

var nibbles [][]byte
Expand All @@ -799,21 +799,21 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, []
for i := 0; i < len(k); i++ {
// fmt.Print(" ", k[i], "/", c.nodeType, " | ")
proofType = append(proofType, c.nodeType)
if c.nodeType == extNode {
if c.nodeType == ExtNode {
// fmt.Print(c.key, " ")
i += len(c.key) - 1
nodes = append(nodes, c)
c = c.children[0]
} else if c.nodeType == branchNode {
} else if c.nodeType == BranchNode {
nodes = append(nodes, c)
c = c.children[k[i]]
if c == nil {
break
}
} else if c.nodeType == leafNode {
} else if c.nodeType == LeafNode {
nodes = append(nodes, c)
break
} else if c.nodeType == hashedNode {
} else if c.nodeType == HashedNode {
c_rlp, error := db.Get(c.val)
if error != nil {
panic(error)
Expand All @@ -840,7 +840,7 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, []
}
c.val = branchChild
// if there are children, the node type should be branch
proofType[len(proofType)-1] = branchNode
proofType[len(proofType)-1] = BranchNode
}
}

Expand All @@ -852,23 +852,23 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, []
lNodes := len(nodes)
for i := lNodes - 1; i >= 0; i-- {
node := nodes[i]
if node.nodeType == leafNode {
if node.nodeType == LeafNode {
nibbles = append(nibbles, []byte{})
rlp, error := db.Get(node.val)
if error != nil { // TODO: avoid error when RLP
proof = append(proof, node.val) // already have RLP
} else {
proof = append(proof, rlp)
}
} else if node.nodeType == branchNode || node.nodeType == extNode {
} else if node.nodeType == BranchNode || node.nodeType == ExtNode {
node.hash(false)

raw_rlp, error := db.Get(node.val)
if error != nil {
return nil, nil, nil, error
}
proof = append(proof, raw_rlp)
if node.nodeType == extNode {
if node.nodeType == ExtNode {

rlp_flag := uint(raw_rlp[0])
if rlp_flag < 192 || rlp_flag >= 248 {
Expand Down
Loading

0 comments on commit db52db3

Please sign in to comment.