Skip to content

Commit

Permalink
remove mutex from trie operations
Browse files Browse the repository at this point in the history
  • Loading branch information
BeniaminDrasovean committed Oct 8, 2024
1 parent f2aa05a commit 97abe3f
Show file tree
Hide file tree
Showing 17 changed files with 344 additions and 450 deletions.
1 change: 0 additions & 1 deletion common/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ type Trie interface {
GetSerializedNodes([]byte, uint64) ([][]byte, uint64, error)
GetSerializedNode([]byte) ([]byte, error)
GetAllLeavesOnChannel(allLeavesChan *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder, trieLeafParser TrieLeafParser) error
GetAllHashes() ([][]byte, error)
GetProof(key []byte) ([][]byte, []byte, error)
VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error)
GetStorageManager() StorageManager
Expand Down
9 changes: 1 addition & 8 deletions integrationTests/state/stateTrie/stateTrie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2082,7 +2082,7 @@ func TestAccountRemoval(t *testing.T) {

shardNode := nodes[0]

dataTriesRootHashes, codeMap := generateAccounts(shardNode, accounts)
_, codeMap := generateAccounts(shardNode, accounts)

_, _ = shardNode.AccntState.Commit()
round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce)
Expand Down Expand Up @@ -2115,13 +2115,6 @@ func TestAccountRemoval(t *testing.T) {
round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce)
checkCodeConsistency(t, shardNode, codeMap)
}

delayRounds = 5
for i := 0; i < delayRounds; i++ {
round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce)
}

checkDataTrieConsistency(t, shardNode.AccntState, removedAccounts, dataTriesRootHashes)
}

func generateAccounts(
Expand Down
77 changes: 15 additions & 62 deletions state/accountsDB.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ type snapshotInfo struct {

// AccountsDB is the struct used for accessing accounts. This struct is concurrent safe.
type AccountsDB struct {
mainTrie common.Trie
hasher hashing.Hasher
marshaller marshal.Marshalizer
accountFactory AccountFactory
storagePruningManager StoragePruningManager
obsoleteDataTrieHashes map[string][][]byte
snapshotsManger SnapshotsManager
mainTrie common.Trie
hasher hashing.Hasher
marshaller marshal.Marshalizer
accountFactory AccountFactory
storagePruningManager StoragePruningManager
snapshotsManger SnapshotsManager

lastRootHash []byte
dataTries common.TriesHolder
Expand Down Expand Up @@ -116,15 +115,14 @@ func NewAccountsDB(args ArgsAccountsDB) (*AccountsDB, error) {

func createAccountsDb(args ArgsAccountsDB) *AccountsDB {
return &AccountsDB{
mainTrie: args.Trie,
hasher: args.Hasher,
marshaller: args.Marshaller,
accountFactory: args.AccountFactory,
storagePruningManager: args.StoragePruningManager,
entries: make([]JournalEntry, 0),
mutOp: sync.RWMutex{},
dataTries: NewDataTriesHolder(),
obsoleteDataTrieHashes: make(map[string][][]byte),
mainTrie: args.Trie,
hasher: args.Hasher,
marshaller: args.Marshaller,
accountFactory: args.AccountFactory,
storagePruningManager: args.StoragePruningManager,
entries: make([]JournalEntry, 0),
mutOp: sync.RWMutex{},
dataTries: NewDataTriesHolder(),
loadCodeMeasurements: &loadingMeasurements{
identifier: "load code",
},
Expand Down Expand Up @@ -534,44 +532,7 @@ func (adb *AccountsDB) removeCodeAndDataTrie(acnt vmcommon.AccountHandler) error
return nil
}

err := adb.removeCode(baseAcc)
if err != nil {
return err
}

err = adb.removeDataTrie(baseAcc)
if err != nil {
return err
}

return nil
}

func (adb *AccountsDB) removeDataTrie(baseAcc baseAccountHandler) error {
rootHash := baseAcc.GetRootHash()
if len(rootHash) == 0 {
return nil
}

dataTrie, err := adb.mainTrie.Recreate(rootHash)
if err != nil {
return err
}

hashes, err := dataTrie.GetAllHashes()
if err != nil {
return err
}

adb.obsoleteDataTrieHashes[string(rootHash)] = hashes

entry, err := NewJournalEntryDataTrieRemove(rootHash, adb.obsoleteDataTrieHashes)
if err != nil {
return err
}
adb.journalize(entry)

return nil
return adb.removeCode(baseAcc)
}

func (adb *AccountsDB) removeCode(baseAcc baseAccountHandler) error {
Expand Down Expand Up @@ -848,7 +809,6 @@ func (adb *AccountsDB) commit() ([]byte, error) {
}

adb.lastRootHash = newRoot
adb.obsoleteDataTrieHashes = make(map[string][][]byte)

log.Trace("accountsDB.Commit ended", "root hash", newRoot)

Expand All @@ -867,12 +827,6 @@ func (adb *AccountsDB) markForEviction(
return nil
}

for _, hashes := range adb.obsoleteDataTrieHashes {
for _, hash := range hashes {
oldHashes[string(hash)] = struct{}{}
}
}

return adb.storagePruningManager.MarkForEviction(oldRoot, newRoot, oldHashes, newHashes)
}

Expand Down Expand Up @@ -932,7 +886,6 @@ func (adb *AccountsDB) recreateTrie(options common.RootHashHolder) error {
log.Trace("accountsDB.RecreateTrie ended")
}()

adb.obsoleteDataTrieHashes = make(map[string][][]byte)
adb.dataTries.Reset()
adb.entries = make([]JournalEntry, 0)
newTrie, err := adb.mainTrie.RecreateFromEpoch(options)
Expand Down
108 changes: 14 additions & 94 deletions state/accountsDB_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1819,89 +1819,6 @@ func TestAccountsDB_MainTrieAutomaticallyMarksCodeUpdatesForEviction(t *testing.
assert.Equal(t, 3, len(hashesForEviction))
}

func TestAccountsDB_RemoveAccountSetsObsoleteHashes(t *testing.T) {
t.Parallel()

_, adb := getDefaultTrieAndAccountsDb()

addr := make([]byte, 32)
acc, _ := adb.LoadAccount(addr)
userAcc := acc.(state.UserAccountHandler)
_ = userAcc.SaveKeyValue([]byte("key"), []byte("value"))

_ = adb.SaveAccount(userAcc)
_, _ = adb.Commit()

acc, _ = adb.LoadAccount(addr)
userAcc = acc.(state.UserAccountHandler)
userAcc.SetCode([]byte("code"))
snapshot := adb.JournalLen()
hashes, _ := userAcc.DataTrie().(common.Trie).GetAllHashes()

err := adb.RemoveAccount(addr)
obsoleteHashes := adb.GetObsoleteHashes()
assert.Nil(t, err)
assert.Equal(t, 1, len(obsoleteHashes))
assert.Equal(t, hashes, obsoleteHashes[string(hashes[0])])

err = adb.RevertToSnapshot(snapshot)
assert.Nil(t, err)
assert.Equal(t, 0, len(adb.GetObsoleteHashes()))
}

func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) {
t.Parallel()

maxTrieLevelInMemory := uint(5)
marshaller := &marshallerMock.MarshalizerMock{}
hasher := &hashingMocks.HasherMock{}

ewl := stateMock.NewEvictionWaitingListMock(100)
args := storage.GetStorageManagerArgs()
tsm, _ := trie.NewTrieStorageManager(args)
tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory)
spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 5)

argsAccountsDB := createMockAccountsDBArgs()
argsAccountsDB.Trie = tr
argsAccountsDB.Hasher = hasher
argsAccountsDB.Marshaller = marshaller
argsAccCreator := factory.ArgsAccountCreator{
Hasher: hasher,
Marshaller: marshaller,
EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{},
}
argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator)
argsAccountsDB.StoragePruningManager = spm

adb, _ := state.NewAccountsDB(argsAccountsDB)

addr := make([]byte, 32)
acc, _ := adb.LoadAccount(addr)
userAcc := acc.(state.UserAccountHandler)
_ = userAcc.SaveKeyValue([]byte("key"), []byte("value"))
_ = adb.SaveAccount(userAcc)

addr1 := make([]byte, 32)
addr1[0] = 1
acc, _ = adb.LoadAccount(addr1)
_ = adb.SaveAccount(acc)

rootHash, _ := adb.Commit()
hashes, _ := userAcc.DataTrie().(common.Trie).GetAllHashes()

err := adb.RemoveAccount(addr)
obsoleteHashes := adb.GetObsoleteHashes()
assert.Nil(t, err)
assert.Equal(t, 1, len(obsoleteHashes))
assert.Equal(t, hashes, obsoleteHashes[string(hashes[0])])

_, _ = adb.Commit()
rootHash = append(rootHash, byte(state.OldRoot))
oldHashes := ewl.Cache[string(rootHash)]
assert.Equal(t, 5, len(oldHashes))
}

func TestAccountsDB_TrieDatabasePruning(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2635,7 +2552,8 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
t.Run("can not sync missing snapshot node should not put activeDbKey", func(t *testing.T) {
t.Parallel()

trieHashes := make([][]byte, 0)
trieHashes := common.ModifiedHashes{}
var rootHash []byte
valuesMap := make(map[string][]byte)
missingNodeError := errors.New("missing trie node")
isMissingNodeCalled := false
Expand All @@ -2653,7 +2571,7 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
return []byte(common.ActiveDBVal), nil
}

if len(trieHashes) != 0 && bytes.Equal(key, trieHashes[0]) {
if len(trieHashes) != 0 && bytes.Equal(key, rootHash) {
isMissingNodeCalled = true
return nil, missingNodeError
}
Expand All @@ -2666,10 +2584,9 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
}

tr, adb := getDefaultTrieAndAccountsDbWithCustomDB(&testscommon.SnapshotPruningStorerMock{MemDbMock: memDbMock})
prepareTrie(tr, 3)
trieHashes = prepareTrie(tr, 3)

rootHash, _ := tr.RootHash()
trieHashes, _ = tr.GetAllHashes()
rootHash, _ = tr.RootHash()

syncer := &mock.AccountsDBSyncerStub{
SyncAccountsCalled: func(rootHash []byte, _ common.StorageMarker) error {
Expand All @@ -2691,7 +2608,8 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
t.Run("nil syncer should not put activeDbKey", func(t *testing.T) {
t.Parallel()

trieHashes := make([][]byte, 0)
trieHashes := common.ModifiedHashes{}
var rootHash []byte
valuesMap := make(map[string][]byte)
missingNodeError := errors.New("missing trie node")
isMissingNodeCalled := false
Expand All @@ -2708,7 +2626,7 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
return []byte(common.ActiveDBVal), nil
}

if len(trieHashes) != 0 && bytes.Equal(key, trieHashes[0]) {
if len(trieHashes) != 0 && bytes.Equal(key, rootHash) {
isMissingNodeCalled = true
return nil, missingNodeError
}
Expand All @@ -2721,10 +2639,9 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
}

tr, adb := getDefaultTrieAndAccountsDbWithCustomDB(&testscommon.SnapshotPruningStorerMock{MemDbMock: memDbMock})
prepareTrie(tr, 3)
trieHashes = prepareTrie(tr, 3)

rootHash, _ := tr.RootHash()
trieHashes, _ = tr.GetAllHashes()
rootHash, _ = tr.RootHash()

adb.SnapshotState(rootHash, 0)

Expand Down Expand Up @@ -2769,14 +2686,17 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) {
})
}

func prepareTrie(tr common.Trie, numKeys int) {
func prepareTrie(tr common.Trie, numKeys int) common.ModifiedHashes {
for i := 0; i < numKeys; i++ {
key := fmt.Sprintf("key%d", i)
val := fmt.Sprintf("val%d", i)
_ = tr.Update([]byte(key), []byte(val))
}

hashes, _ := tr.GetDirtyHashes()
_ = tr.Commit()

return hashes
}

func addDataTries(accountsAddresses [][]byte, adb *state.AccountsDB) {
Expand Down
5 changes: 0 additions & 5 deletions state/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ func (adb *AccountsDB) GetAccount(address []byte) (vmcommon.AccountHandler, erro
return adb.getAccount(address, adb.getMainTrie())
}

// GetObsoleteHashes -
func (adb *AccountsDB) GetObsoleteHashes() map[string][][]byte {
return adb.obsoleteDataTrieHashes
}

// GetCode -
func GetCode(account baseAccountHandler) []byte {
return account.GetCodeHash()
Expand Down
10 changes: 0 additions & 10 deletions testscommon/trie/trieStub.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type TrieStub struct {
GetObsoleteHashesCalled func() [][]byte
AppendToOldHashesCalled func([][]byte)
GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error)
GetAllHashesCalled func() ([][]byte, error)
GetAllLeavesOnChannelCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, trieLeafParser common.TrieLeafParser) error
GetProofCalled func(key []byte) ([][]byte, []byte, error)
VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error)
Expand Down Expand Up @@ -187,15 +186,6 @@ func (ts *TrieStub) GetDirtyHashes() (common.ModifiedHashes, error) {
func (ts *TrieStub) SetNewHashes(_ common.ModifiedHashes) {
}

// GetAllHashes -
func (ts *TrieStub) GetAllHashes() ([][]byte, error) {
if ts.GetAllHashesCalled != nil {
return ts.GetAllHashesCalled()
}

return nil, nil
}

// GetSerializedNode -
func (ts *TrieStub) GetSerializedNode(bytes []byte) ([]byte, error) {
if ts.GetSerializedNodeCalled != nil {
Expand Down
5 changes: 3 additions & 2 deletions trie/baseIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ func newBaseIterator(trie common.Trie) (*baseIterator, error) {
}

trieStorage := trie.GetStorageManager()
nextNodes, err := pmt.root.getChildren(trieStorage)
rootNode := pmt.GetRootNode()
nextNodes, err := rootNode.getChildren(trieStorage)
if err != nil {
return nil, err
}

return &baseIterator{
currentNode: pmt.root,
currentNode: rootNode,
nextNodes: nextNodes,
db: trieStorage,
}, nil
Expand Down
Loading

0 comments on commit 97abe3f

Please sign in to comment.