diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 485b01bf199..0f875999bd1 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -399,6 +399,12 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage // GetSerializedNode returns the serialized node (if existing) provided the node's hash func (tr *patriciaMerkleTrie) GetSerializedNode(hash []byte) ([]byte, error) { + // TODO: investigate if we can move the critical section behavior in the trie node resolver as this call will compete with a normal trie.Get operation + // which might occur during processing. + // warning: A critical section here or on the trie node resolver must be kept so as not to overwhelm the node with requests that affect the block processing flow + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + log.Trace("GetSerializedNode", "hash", hash) return tr.trieStorage.Get(hash) @@ -406,6 +412,12 @@ func (tr *patriciaMerkleTrie) GetSerializedNode(hash []byte) ([]byte, error) { // GetSerializedNodes returns a batch of serialized nodes from the trie, starting from the given hash func (tr *patriciaMerkleTrie) GetSerializedNodes(rootHash []byte, maxBuffToSend uint64) ([][]byte, uint64, error) { + // TODO: investigate if we can move the critical section behavior in the trie node resolver as this call will compete with a normal trie.Get operation + // which might occur during processing.
+ // warning: A critical section here or on the trie node resolver must be kept so as not to overwhelm the node with requests that affect the block processing flow + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + log.Trace("GetSerializedNodes", "rootHash", rootHash) size := uint64(0) diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 900d1b66002..501539a3e54 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -22,7 +23,7 @@ import ( errorsCommon "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - "github.com/multiversx/mx-chain-go/testscommon/storage" + "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" @@ -492,17 +493,17 @@ func TestPatriciaMerkleTrie_GetSerializedNodesGetFromCheckpoint(t *testing.T) { _ = tr.Commit() rootHash, _ := tr.RootHash() - storageManager := tr.GetStorageManager() + storageManagerInstance := tr.GetStorageManager() dirtyHashes := trie.GetDirtyHashes(tr) - storageManager.AddDirtyCheckpointHashes(rootHash, dirtyHashes) + storageManagerInstance.AddDirtyCheckpointHashes(rootHash, dirtyHashes) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, ErrChan: errChan.NewErrChanWrapper(), } - storageManager.SetCheckpoint(rootHash, make([]byte, 0), iteratorChannels, nil, &trieMock.MockStatistics{}) - trie.WaitForOperationToComplete(storageManager) + storageManagerInstance.SetCheckpoint(rootHash, make([]byte, 0), iteratorChannels, nil, &trieMock.MockStatistics{}) + trie.WaitForOperationToComplete(storageManagerInstance) - err := storageManager.Remove(rootHash) + err :=
storageManagerInstance.Remove(rootHash) assert.Nil(t, err) maxBuffToSend := uint64(500) @@ -1085,64 +1086,56 @@ func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) { wg.Wait() } -func TestPatriciaMerkleTrie_GetSerializedNodesClose(t *testing.T) { +func TestPatriciaMerkleTrie_GetSerializedNodesShouldSerializeTheCalls(t *testing.T) { t.Parallel() args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &storage.StorerStub{ - GetCalled: func(key []byte) ([]byte, error) { - // gets take a long time + numConcurrentCalls := int32(0) + testTrieStorageManager := &storageManager.StorageManagerStub{ + GetCalled: func(bytes []byte) ([]byte, error) { + newValue := atomic.AddInt32(&numConcurrentCalls, 1) + defer atomic.AddInt32(&numConcurrentCalls, -1) + + assert.Equal(t, int32(1), newValue) + + // get takes a long time time.Sleep(time.Millisecond * 10) - return key, nil + + return bytes, nil }, } - trieStorageManager, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) - numGoRoutines := 1000 - wgStart := sync.WaitGroup{} - wgStart.Add(numGoRoutines) - wgEnd := sync.WaitGroup{} - wgEnd.Add(numGoRoutines) + tr, _ := trie.NewTrie(testTrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + numGoRoutines := 100 + wg := sync.WaitGroup{} + wg.Add(numGoRoutines) for i := 0; i < numGoRoutines; i++ { if i%2 == 0 { go func() { time.Sleep(time.Millisecond * 100) - wgStart.Done() - _, _, _ = tr.GetSerializedNodes([]byte("dog"), 1024) - wgEnd.Done() + wg.Done() }() } else { go func() { time.Sleep(time.Millisecond * 100) - wgStart.Done() - _, _ = tr.GetSerializedNode([]byte("dog")) - wgEnd.Done() + wg.Done() }() } } - wgStart.Wait() + wg.Wait() chanClosed := make(chan struct{}) go func() { _ = tr.Close() close(chanClosed) }() - chanGetsEnded := make(chan struct{}) - go func() { - wgEnd.Wait() - 
close(chanGetsEnded) - }() - timeout := time.Second * 10 select { case <-chanClosed: // ok - case <-chanGetsEnded: - assert.Fail(t, "trie should have been closed before all gets ended") case <-time.After(timeout): assert.Fail(t, "timeout waiting for trie to be closed") }