diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 5f823d28b88..847f4e244cf 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -42,8 +42,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" - "github.com/multiversx/mx-chain-go/trie/hashesHolder" - disabledHashesHolder "github.com/multiversx/mx-chain-go/trie/hashesHolder/disabled" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -427,14 +425,17 @@ func TestAccountsDB_SaveAccountMalfunctionMarshallerShouldErr(t *testing.T) { func TestAccountsDB_SaveAccountCollectsAllStateChanges(t *testing.T) { t.Parallel() + autoBalanceFlagEnabled := false enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return autoBalanceFlagEnabled + }, } _, adb := getDefaultStateComponentsWithCustomEnableEpochs(enableEpochs) address := generateRandomByteArray(32) stepCreateAccountWithDataTrieAndCode(t, adb, address) - enableEpochs.IsAutoBalanceDataTriesEnabledField = true + autoBalanceFlagEnabled = true stepMigrateDataTrieValAndChangeCode(t, adb, address) } diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index 646502a4536..de995a7082a 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -406,7 +406,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { deleteCalled := false updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -435,12 +437,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, expectedVal) @@ -468,8 +465,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return false + }, } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) expectedKey := []byte("key") @@ -495,12 +495,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return false - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, val) @@ -523,7 +518,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -552,12 +549,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -580,7 +572,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: true, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) @@ -604,12 +598,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { - return flag == common.AutoBalanceDataTriesFlag - }, - } - tdt, _ := trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) + tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler) tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(expectedKey, newVal) @@ -789,7 +778,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - IsAutoBalanceDataTriesEnabledField: false, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return false + }, } tdt, _ := trackableDataTrie.NewTrackableDataTrie( identifier, @@ -888,7 +879,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { stateChanges, oldVals, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) - assert.Equal(t, 5, len(oldVals)) + assert.Equal(t, 7, len(oldVals)) assert.Equal(t, 6, len(stateChanges)) assert.Equal(t, hasher.Compute(key1), stateChanges[0].Key)