diff --git a/app/scripts/controllers/mmi-controller.ts b/app/scripts/controllers/mmi-controller.ts index 7cc024d0cd47..dc00a69a9082 100644 --- a/app/scripts/controllers/mmi-controller.ts +++ b/app/scripts/controllers/mmi-controller.ts @@ -454,7 +454,7 @@ export default class MMIController extends EventEmitter { const allAccounts = await this.keyringController.getAccounts(); const accountsToTrack = [ - ...new Set( + ...new Set( oldAccounts.concat(allAccounts.map((a: string) => a.toLowerCase())), ), ]; diff --git a/app/scripts/lib/account-tracker.test.js b/app/scripts/lib/account-tracker.test.js deleted file mode 100644 index 4bd73a472811..000000000000 --- a/app/scripts/lib/account-tracker.test.js +++ /dev/null @@ -1,729 +0,0 @@ -import EventEmitter from 'events'; -import { ControllerMessenger } from '@metamask/base-controller'; - -import { flushPromises } from '../../../test/lib/timer-helpers'; -import { createTestProviderTools } from '../../../test/stub/provider'; -import AccountTracker from './account-tracker'; - -const noop = () => true; -const currentNetworkId = '5'; -const currentChainId = '0x5'; -const VALID_ADDRESS = '0x0000000000000000000000000000000000000000'; -const VALID_ADDRESS_TWO = '0x0000000000000000000000000000000000000001'; - -const SELECTED_ADDRESS = '0x123'; - -const INITIAL_BALANCE_1 = '0x1'; -const INITIAL_BALANCE_2 = '0x2'; -const UPDATE_BALANCE = '0xabc'; -const UPDATE_BALANCE_HOOK = '0xabcd'; - -const GAS_LIMIT = '0x111111'; -const GAS_LIMIT_HOOK = '0x222222'; - -// The below three values were generated by running MetaMask in the browser -// The response to eth_call, which is called via `ethContract.balances` -// in `_updateAccountsViaBalanceChecker` of account-tracker.js, needs to be properly -// formatted or else ethers will throw an error. -const ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN = - '0x0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000038d7ea4c6800600000000000000000000000000000000000000000000000000000000000186a0'; -const EXPECTED_CONTRACT_BALANCE_1 = '0x038d7ea4c68006'; -const EXPECTED_CONTRACT_BALANCE_2 = '0x0186a0'; - -const mockAccounts = { - [VALID_ADDRESS]: { address: VALID_ADDRESS, balance: INITIAL_BALANCE_1 }, - [VALID_ADDRESS_TWO]: { - address: VALID_ADDRESS_TWO, - balance: INITIAL_BALANCE_2, - }, -}; - -function buildMockBlockTracker({ shouldStubListeners = true } = {}) { - const blockTrackerStub = new EventEmitter(); - blockTrackerStub.getCurrentBlock = noop; - blockTrackerStub.getLatestBlock = noop; - if (shouldStubListeners) { - jest.spyOn(blockTrackerStub, 'addListener').mockImplementation(); - jest.spyOn(blockTrackerStub, 'removeListener').mockImplementation(); - } - return blockTrackerStub; -} - -function buildAccountTracker({ - completedOnboarding = false, - useMultiAccountBalanceChecker = false, - ...accountTrackerOptions -} = {}) { - const { provider } = createTestProviderTools({ - scaffold: { - eth_getBalance: UPDATE_BALANCE, - eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, - eth_getBlockByNumber: { gasLimit: GAS_LIMIT }, - }, - networkId: currentNetworkId, - chainId: currentNetworkId, - }); - const blockTrackerStub = buildMockBlockTracker(); - - const providerFromHook = createTestProviderTools({ - scaffold: { - eth_getBalance: UPDATE_BALANCE_HOOK, - eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, - eth_getBlockByNumber: { gasLimit: GAS_LIMIT_HOOK }, - }, - networkId: '0x1', - chainId: '0x1', - }).provider; - - const blockTrackerFromHookStub = buildMockBlockTracker(); - - const getNetworkClientByIdStub = jest.fn().mockReturnValue({ - configuration: { - chainId: '0x1', - }, - blockTracker: blockTrackerFromHookStub, - provider: providerFromHook, - }); - - const controllerMessenger = new ControllerMessenger(); - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - () => ({ - id: 'accountId', - address: SELECTED_ADDRESS, - }), - ); - - const accountTracker = new AccountTracker({ - provider, - blockTracker: blockTrackerStub, - getNetworkClientById: getNetworkClientByIdStub, - getNetworkIdentifier: jest.fn(), - preferencesController: { - store: { - getState: () => ({ - useMultiAccountBalanceChecker, - }), - subscribe: noop, - }, - }, - onboardingController: { - state: { - completedOnboarding, - }, - }, - controllerMessenger, - onAccountRemoved: noop, - getCurrentChainId: () => currentChainId, - ...accountTrackerOptions, - }); - - return { accountTracker, blockTrackerFromHookStub, blockTrackerStub }; -} - -describe('Account Tracker', () => { - afterEach(() => { - jest.resetAllMocks(); - }); - - describe('start', () => { - it('restarts the subscription to the block tracker and update accounts', async () => { - const { accountTracker, blockTrackerStub } = buildAccountTracker(); - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - - accountTracker.start(); - - expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( - 1, - 'latest', - expect.any(Function), - ); - expect(blockTrackerStub.addListener).toHaveBeenNthCalledWith( - 1, - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenNthCalledWith(1); // called first time with no args - - accountTracker.start(); - - expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( - 2, - 'latest', - expect.any(Function), - ); - expect(blockTrackerStub.addListener).toHaveBeenNthCalledWith( - 2, - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenNthCalledWith(2); // called second time with no args - - accountTracker.stop(); - }); - }); - - describe('stop', () => { - it('ends the subscription to the block tracker', async () => { - const { accountTracker, blockTrackerStub } = buildAccountTracker(); - - accountTracker.stop(); - - expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( - 1, - 'latest', - expect.any(Function), - ); - }); - }); - - describe('startPollingByNetworkClientId', () => { - it('should subscribe to the block tracker and update accounts if not already using the networkClientId', async () => { - const { accountTracker, blockTrackerFromHookStub } = - buildAccountTracker(); - - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - - accountTracker.startPollingByNetworkClientId('mainnet'); - - expect(blockTrackerFromHookStub.addListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); - - accountTracker.startPollingByNetworkClientId('mainnet'); - - expect(blockTrackerFromHookStub.addListener).toHaveBeenCalledTimes(1); - expect(updateAccountsSpy).toHaveBeenCalledTimes(1); - - accountTracker.stopAllPolling(); - }); - - it('should subscribe to the block tracker and update accounts for each networkClientId', async () => { - const blockTrackerFromHookStub1 = buildMockBlockTracker(); - const blockTrackerFromHookStub2 = buildMockBlockTracker(); - const blockTrackerFromHookStub3 = buildMockBlockTracker(); - const getNetworkClientByIdStub = jest - .fn() - .mockImplementation((networkClientId) => { - switch (networkClientId) { - case 'mainnet': - return { - configuration: { - chainId: '0x1', - }, - blockTracker: blockTrackerFromHookStub1, - }; - case 'goerli': - return { - configuration: { - chainId: '0x5', - }, - blockTracker: blockTrackerFromHookStub2, - }; - case 'networkClientId1': - return { - configuration: { - chainId: '0xa', - }, - blockTracker: blockTrackerFromHookStub3, - }; - default: - throw new Error('unexpected networkClientId'); - } - }); - const { accountTracker } = buildAccountTracker({ - getNetworkClientById: getNetworkClientByIdStub, - }); - - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - - accountTracker.startPollingByNetworkClientId('mainnet'); - - expect(blockTrackerFromHookStub1.addListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); - - accountTracker.startPollingByNetworkClientId('goerli'); - - expect(blockTrackerFromHookStub2.addListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenCalledWith('goerli'); - - accountTracker.startPollingByNetworkClientId('networkClientId1'); - - expect(blockTrackerFromHookStub3.addListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(updateAccountsSpy).toHaveBeenCalledWith('networkClientId1'); - - accountTracker.stopAllPolling(); - }); - }); - - describe('stopPollingByPollingToken', () => { - it('should unsubscribe from the block tracker when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { - const { accountTracker, blockTrackerFromHookStub } = - buildAccountTracker(); - - jest.spyOn(accountTracker, 'updateAccounts').mockResolvedValue(); - - const pollingToken = - accountTracker.startPollingByNetworkClientId('mainnet'); - - accountTracker.stopPollingByPollingToken(pollingToken); - - expect(blockTrackerFromHookStub.removeListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - }); - - it('should not unsubscribe from the block tracker if called with one of multiple active polling tokens for a given networkClient', async () => { - const { accountTracker, blockTrackerFromHookStub } = - buildAccountTracker(); - - jest.spyOn(accountTracker, 'updateAccounts').mockResolvedValue(); - - const pollingToken1 = - accountTracker.startPollingByNetworkClientId('mainnet'); - accountTracker.startPollingByNetworkClientId('mainnet'); - - accountTracker.stopPollingByPollingToken(pollingToken1); - - expect(blockTrackerFromHookStub.removeListener).not.toHaveBeenCalled(); - - accountTracker.stopAllPolling(); - }); - - it('should error if no pollingToken is passed', () => { - const { accountTracker } = buildAccountTracker(); - - expect(() => { - accountTracker.stopPollingByPollingToken(undefined); - }).toThrow('pollingToken required'); - }); - - it('should error if no matching pollingToken is found', () => { - const { accountTracker } = buildAccountTracker(); - - expect(() => { - accountTracker.stopPollingByPollingToken('potato'); - }).toThrow('pollingToken not found'); - }); - }); - - describe('stopAll', () => { - it('should end all subscriptions', async () => { - const blockTrackerFromHookStub1 = buildMockBlockTracker(); - const blockTrackerFromHookStub2 = buildMockBlockTracker(); - const getNetworkClientByIdStub = jest - .fn() - .mockImplementation((networkClientId) => { - switch (networkClientId) { - case 'mainnet': - return { - configuration: { - chainId: '0x1', - }, - blockTracker: blockTrackerFromHookStub1, - }; - case 'goerli': - return { - configuration: { - chainId: '0x5', - }, - blockTracker: blockTrackerFromHookStub2, - }; - default: - throw new Error('unexpected networkClientId'); - } - }); - const { accountTracker, blockTrackerStub } = buildAccountTracker({ - getNetworkClientById: getNetworkClientByIdStub, - }); - - jest.spyOn(accountTracker, 'updateAccounts').mockResolvedValue(); - - accountTracker.startPollingByNetworkClientId('mainnet'); - - accountTracker.startPollingByNetworkClientId('goerli'); - - accountTracker.stopAllPolling(); - - expect(blockTrackerStub.removeListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(blockTrackerFromHookStub1.removeListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - expect(blockTrackerFromHookStub2.removeListener).toHaveBeenCalledWith( - 'latest', - expect.any(Function), - ); - }); - }); - - describe('blockTracker "latest" events', () => { - it('updates currentBlockGasLimit, currentBlockGasLimitByChainId, and accounts when polling is initiated via `start`', async () => { - const blockTrackerStub = buildMockBlockTracker({ - shouldStubListeners: false, - }); - const { accountTracker } = buildAccountTracker({ - blockTracker: blockTrackerStub, - }); - - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - - accountTracker.start(); - blockTrackerStub.emit('latest', 'blockNumber'); - - await flushPromises(); - - expect(updateAccountsSpy).toHaveBeenCalledWith(null); - - const newState = accountTracker.store.getState(); - - expect(newState).toStrictEqual({ - accounts: {}, - accountsByChainId: {}, - currentBlockGasLimit: GAS_LIMIT, - currentBlockGasLimitByChainId: { - [currentChainId]: GAS_LIMIT, - }, - }); - - accountTracker.stop(); - }); - - it('updates only the currentBlockGasLimitByChainId and accounts when polling is initiated via `startPollingByNetworkClientId`', async () => { - const blockTrackerFromHookStub = buildMockBlockTracker({ - shouldStubListeners: false, - }); - const providerFromHook = createTestProviderTools({ - scaffold: { - eth_getBalance: UPDATE_BALANCE_HOOK, - eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, - eth_getBlockByNumber: { gasLimit: GAS_LIMIT_HOOK }, - }, - networkId: '0x1', - chainId: '0x1', - }).provider; - const getNetworkClientByIdStub = jest.fn().mockReturnValue({ - configuration: { - chainId: '0x1', - }, - blockTracker: blockTrackerFromHookStub, - provider: providerFromHook, - }); - const { accountTracker } = buildAccountTracker({ - getNetworkClientById: getNetworkClientByIdStub, - }); - - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - - accountTracker.startPollingByNetworkClientId('mainnet'); - - blockTrackerFromHookStub.emit('latest', 'blockNumber'); - - await flushPromises(); - - expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); - - const newState = accountTracker.store.getState(); - - expect(newState).toStrictEqual({ - accounts: {}, - accountsByChainId: {}, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: { - '0x1': GAS_LIMIT_HOOK, - }, - }); - - accountTracker.stopAllPolling(); - }); - }); - - describe('updateAccountsAllActiveNetworks', () => { - it('updates accounts for the globally selected network and all currently polling networks', async () => { - const { accountTracker } = buildAccountTracker(); - - const updateAccountsSpy = jest - .spyOn(accountTracker, 'updateAccounts') - .mockResolvedValue(); - await accountTracker.startPollingByNetworkClientId('networkClientId1'); - await accountTracker.startPollingByNetworkClientId('networkClientId2'); - await accountTracker.startPollingByNetworkClientId('networkClientId3'); - - expect(updateAccountsSpy).toHaveBeenCalledTimes(3); - - await accountTracker.updateAccountsAllActiveNetworks(); - - expect(updateAccountsSpy).toHaveBeenCalledTimes(7); - expect(updateAccountsSpy).toHaveBeenNthCalledWith(4); // called with no args - expect(updateAccountsSpy).toHaveBeenNthCalledWith(5, 'networkClientId1'); - expect(updateAccountsSpy).toHaveBeenNthCalledWith(6, 'networkClientId2'); - expect(updateAccountsSpy).toHaveBeenNthCalledWith(7, 'networkClientId3'); - }); - }); - - describe('updateAccounts', () => { - it('does not update accounts if completedOnBoarding is false', async () => { - const { accountTracker } = buildAccountTracker({ - completedOnboarding: false, - }); - - await accountTracker.updateAccounts(); - - const state = accountTracker.store.getState(); - expect(state).toStrictEqual({ - accounts: {}, - currentBlockGasLimit: '', - accountsByChainId: {}, - currentBlockGasLimitByChainId: {}, - }); - }); - - describe('chain does not have single call balance address', () => { - const getCurrentChainIdStub = () => '0x999'; // chain without single call balance address - const mockAccountsWithSelectedAddress = { - ...mockAccounts, - [SELECTED_ADDRESS]: { - address: SELECTED_ADDRESS, - balance: '0x0', - }, - }; - const mockInitialState = { - accounts: mockAccountsWithSelectedAddress, - accountsByChainId: { - '0x999': mockAccountsWithSelectedAddress, - }, - }; - - describe('when useMultiAccountBalanceChecker is true', () => { - it('updates all accounts directly', async () => { - const { accountTracker } = buildAccountTracker({ - completedOnboarding: true, - useMultiAccountBalanceChecker: true, - getCurrentChainId: getCurrentChainIdStub, - }); - accountTracker.store.updateState(mockInitialState); - - await accountTracker.updateAccounts(); - - const accounts = { - [VALID_ADDRESS]: { - address: VALID_ADDRESS, - balance: UPDATE_BALANCE, - }, - [VALID_ADDRESS_TWO]: { - address: VALID_ADDRESS_TWO, - balance: UPDATE_BALANCE, - }, - [SELECTED_ADDRESS]: { - address: SELECTED_ADDRESS, - balance: UPDATE_BALANCE, - }, - }; - - const newState = accountTracker.store.getState(); - expect(newState).toStrictEqual({ - accounts, - accountsByChainId: { - '0x999': accounts, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); - }); - - describe('when useMultiAccountBalanceChecker is false', () => { - it('updates only the selectedAddress directly, setting other balances to null', async () => { - const { accountTracker } = buildAccountTracker({ - completedOnboarding: true, - useMultiAccountBalanceChecker: false, - getCurrentChainId: getCurrentChainIdStub, - }); - accountTracker.store.updateState(mockInitialState); - - await accountTracker.updateAccounts(); - - const accounts = { - [VALID_ADDRESS]: { address: VALID_ADDRESS, balance: null }, - [VALID_ADDRESS_TWO]: { address: VALID_ADDRESS_TWO, balance: null }, - [SELECTED_ADDRESS]: { - address: SELECTED_ADDRESS, - balance: UPDATE_BALANCE, - }, - }; - - const newState = accountTracker.store.getState(); - expect(newState).toStrictEqual({ - accounts, - accountsByChainId: { - '0x999': accounts, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); - }); - }); - - describe('chain does have single call balance address and network is not localhost', () => { - const getNetworkIdentifierStub = jest - .fn() - .mockReturnValue('http://not-localhost:8545'); - const controllerMessenger = new ControllerMessenger(); - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - () => ({ - id: 'accountId', - address: VALID_ADDRESS, - }), - ); - const getCurrentChainIdStub = () => '0x1'; // chain with single call balance address - const mockInitialState = { - accounts: { ...mockAccounts }, - accountsByChainId: { - '0x1': { ...mockAccounts }, - }, - }; - - describe('when useMultiAccountBalanceChecker is true', () => { - it('updates all accounts via balance checker', async () => { - const { accountTracker } = buildAccountTracker({ - completedOnboarding: true, - useMultiAccountBalanceChecker: true, - controllerMessenger, - getNetworkIdentifier: getNetworkIdentifierStub, - getCurrentChainId: getCurrentChainIdStub, - }); - - accountTracker.store.updateState(mockInitialState); - - await accountTracker.updateAccounts('mainnet'); - - const accounts = { - [VALID_ADDRESS]: { - address: VALID_ADDRESS, - balance: EXPECTED_CONTRACT_BALANCE_1, - }, - [VALID_ADDRESS_TWO]: { - address: VALID_ADDRESS_TWO, - balance: EXPECTED_CONTRACT_BALANCE_2, - }, - }; - - const newState = accountTracker.store.getState(); - expect(newState).toStrictEqual({ - accounts, - accountsByChainId: { - '0x1': accounts, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); - }); - }); - }); - - describe('onAccountRemoved', () => { - it('should remove an account from state', () => { - let accountRemovedListener; - const { accountTracker } = buildAccountTracker({ - onAccountRemoved: (callback) => { - accountRemovedListener = callback; - }, - }); - accountTracker.store.updateState({ - accounts: { ...mockAccounts }, - accountsByChainId: { - [currentChainId]: { - ...mockAccounts, - }, - '0x1': { - ...mockAccounts, - }, - '0x2': { - ...mockAccounts, - }, - }, - }); - - accountRemovedListener(VALID_ADDRESS); - - const newState = accountTracker.store.getState(); - - const accounts = { - [VALID_ADDRESS_TWO]: mockAccounts[VALID_ADDRESS_TWO], - }; - - expect(newState).toStrictEqual({ - accounts, - accountsByChainId: { - [currentChainId]: accounts, - '0x1': accounts, - '0x2': accounts, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); - }); - - describe('clearAccounts', () => { - it('should reset state', () => { - const { accountTracker } = buildAccountTracker(); - - accountTracker.store.updateState({ - accounts: { ...mockAccounts }, - accountsByChainId: { - [currentChainId]: { - ...mockAccounts, - }, - '0x1': { - ...mockAccounts, - }, - '0x2': { - ...mockAccounts, - }, - }, - }); - - accountTracker.clearAccounts(); - - const newState = accountTracker.store.getState(); - - expect(newState).toStrictEqual({ - accounts: {}, - accountsByChainId: { - [currentChainId]: {}, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); - }); -}); diff --git a/app/scripts/lib/account-tracker.test.ts b/app/scripts/lib/account-tracker.test.ts new file mode 100644 index 000000000000..7cc0dcba14c7 --- /dev/null +++ b/app/scripts/lib/account-tracker.test.ts @@ -0,0 +1,805 @@ +import EventEmitter from 'events'; +import { ControllerMessenger } from '@metamask/base-controller'; +import { InternalAccount } from '@metamask/keyring-api'; +import { Hex } from '@metamask/utils'; +import { BlockTracker, Provider } from '@metamask/network-controller'; + +import { flushPromises } from '../../../test/lib/timer-helpers'; +import PreferencesController from '../controllers/preferences-controller'; +import OnboardingController from '../controllers/onboarding'; +import { createTestProviderTools } from '../../../test/stub/provider'; +import AccountTracker, { + AccountTrackerOptions, + AllowedActions, + AllowedEvents, + getDefaultAccountTrackerState, +} from './account-tracker'; + +const noop = () => true; +const currentNetworkId = '5'; +const currentChainId = '0x5'; +const VALID_ADDRESS = '0x0000000000000000000000000000000000000000'; +const VALID_ADDRESS_TWO = '0x0000000000000000000000000000000000000001'; + +const SELECTED_ADDRESS = '0x123'; + +const INITIAL_BALANCE_1 = '0x1'; +const INITIAL_BALANCE_2 = '0x2'; +const UPDATE_BALANCE = '0xabc'; +const UPDATE_BALANCE_HOOK = '0xabcd'; + +const GAS_LIMIT = '0x111111'; +const GAS_LIMIT_HOOK = '0x222222'; + +// The below three values were generated by running MetaMask in the browser +// The response to eth_call, which is called via `ethContract.balances` +// in `_updateAccountsViaBalanceChecker` of account-tracker.js, needs to be properly +// formatted or else ethers will throw an error. +const ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN = + '0x0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000038d7ea4c6800600000000000000000000000000000000000000000000000000000000000186a0'; +const EXPECTED_CONTRACT_BALANCE_1 = '0x038d7ea4c68006'; +const EXPECTED_CONTRACT_BALANCE_2 = '0x0186a0'; + +const mockAccounts = { + [VALID_ADDRESS]: { address: VALID_ADDRESS, balance: INITIAL_BALANCE_1 }, + [VALID_ADDRESS_TWO]: { + address: VALID_ADDRESS_TWO, + balance: INITIAL_BALANCE_2, + }, +}; + +class MockBlockTracker extends EventEmitter { + getCurrentBlock = noop; + + getLatestBlock = noop; +} + +function buildMockBlockTracker({ shouldStubListeners = true } = {}) { + const blockTrackerStub = new MockBlockTracker(); + if (shouldStubListeners) { + jest.spyOn(blockTrackerStub, 'addListener').mockImplementation(); + jest.spyOn(blockTrackerStub, 'removeListener').mockImplementation(); + } + return blockTrackerStub; +} + +type WithControllerOptions = { + completedOnboarding?: boolean; + useMultiAccountBalanceChecker?: boolean; + getNetworkClientById?: jest.Mock; + getSelectedAccount?: jest.Mock; +} & Partial; + +type WithControllerCallback = ({ + controller, + blockTrackerFromHookStub, + blockTrackerStub, + triggerOnAccountRemoved, +}: { + controller: AccountTracker; + blockTrackerFromHookStub: MockBlockTracker; + blockTrackerStub: MockBlockTracker; + triggerOnAccountRemoved: (address: string) => void; +}) => ReturnValue; + +type WithControllerArgs = + | [WithControllerCallback] + | [WithControllerOptions, WithControllerCallback]; + +function withController( + ...args: WithControllerArgs +): ReturnValue { + const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; + const { + completedOnboarding = false, + useMultiAccountBalanceChecker = false, + getNetworkClientById, + getSelectedAccount, + ...accountTrackerOptions + } = rest; + const { provider } = createTestProviderTools({ + scaffold: { + eth_getBalance: UPDATE_BALANCE, + eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, + eth_getBlockByNumber: { gasLimit: GAS_LIMIT }, + }, + networkId: currentNetworkId, + chainId: currentNetworkId, + }); + const blockTrackerStub = buildMockBlockTracker(); + + const controllerMessenger = new ControllerMessenger< + AllowedActions, + AllowedEvents + >(); + const getSelectedAccountStub = () => + ({ + id: 'accountId', + address: SELECTED_ADDRESS, + } as InternalAccount); + controllerMessenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + getSelectedAccount || getSelectedAccountStub, + ); + + const { provider: providerFromHook } = createTestProviderTools({ + scaffold: { + eth_getBalance: UPDATE_BALANCE_HOOK, + eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, + eth_getBlockByNumber: { gasLimit: GAS_LIMIT_HOOK }, + }, + networkId: '0x1', + chainId: '0x1', + }); + + const blockTrackerFromHookStub = buildMockBlockTracker(); + + const getNetworkClientByIdStub = jest.fn().mockReturnValue({ + configuration: { + chainId: '0x1', + }, + blockTracker: blockTrackerFromHookStub, + provider: providerFromHook, + }); + + controllerMessenger.registerActionHandler( + 'NetworkController:getNetworkClientById', + getNetworkClientById || getNetworkClientByIdStub, + ); + + const controller = new AccountTracker({ + initState: getDefaultAccountTrackerState(), + provider: provider as Provider, + blockTracker: blockTrackerStub as unknown as BlockTracker, + getNetworkIdentifier: jest.fn(), + preferencesController: { + store: { + getState: () => ({ + useMultiAccountBalanceChecker, + }), + }, + } as PreferencesController, + onboardingController: { + state: { + completedOnboarding, + }, + } as OnboardingController, + controllerMessenger, + getCurrentChainId: () => currentChainId, + ...accountTrackerOptions, + }); + + return fn({ + controller, + blockTrackerFromHookStub, + blockTrackerStub, + triggerOnAccountRemoved: (address: string) => { + controllerMessenger.publish('KeyringController:accountRemoved', address); + }, + }); +} + +describe('Account Tracker', () => { + describe('start', () => { + it('restarts the subscription to the block tracker and update accounts', async () => { + withController(({ controller, blockTrackerStub }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + + controller.start(); + + expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( + 1, + 'latest', + expect.any(Function), + ); + expect(blockTrackerStub.addListener).toHaveBeenNthCalledWith( + 1, + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenNthCalledWith(1); // called first time with no args + + controller.start(); + + expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( + 2, + 'latest', + expect.any(Function), + ); + expect(blockTrackerStub.addListener).toHaveBeenNthCalledWith( + 2, + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenNthCalledWith(2); // called second time with no args + + controller.stop(); + }); + }); + }); + + describe('stop', () => { + it('ends the subscription to the block tracker', async () => { + withController(({ controller, blockTrackerStub }) => { + controller.stop(); + + expect(blockTrackerStub.removeListener).toHaveBeenNthCalledWith( + 1, + 'latest', + expect.any(Function), + ); + }); + }); + }); + + describe('startPollingByNetworkClientId', () => { + it('should subscribe to the block tracker and update accounts if not already using the networkClientId', async () => { + withController(({ controller, blockTrackerFromHookStub }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + + controller.startPollingByNetworkClientId('mainnet'); + + expect(blockTrackerFromHookStub.addListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); + + controller.startPollingByNetworkClientId('mainnet'); + + expect(blockTrackerFromHookStub.addListener).toHaveBeenCalledTimes(1); + expect(updateAccountsSpy).toHaveBeenCalledTimes(1); + + controller.stopAllPolling(); + }); + }); + + it('should subscribe to the block tracker and update accounts for each networkClientId', async () => { + const blockTrackerFromHookStub1 = buildMockBlockTracker(); + const blockTrackerFromHookStub2 = buildMockBlockTracker(); + const blockTrackerFromHookStub3 = buildMockBlockTracker(); + withController( + { + getNetworkClientById: jest + .fn() + .mockImplementation((networkClientId) => { + switch (networkClientId) { + case 'mainnet': + return { + configuration: { + chainId: '0x1', + }, + blockTracker: blockTrackerFromHookStub1, + }; + case 'goerli': + return { + configuration: { + chainId: '0x5', + }, + blockTracker: blockTrackerFromHookStub2, + }; + case 'networkClientId1': + return { + configuration: { + chainId: '0xa', + }, + blockTracker: blockTrackerFromHookStub3, + }; + default: + throw new Error('unexpected networkClientId'); + } + }), + }, + ({ controller }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + + controller.startPollingByNetworkClientId('mainnet'); + + expect(blockTrackerFromHookStub1.addListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); + + controller.startPollingByNetworkClientId('goerli'); + + expect(blockTrackerFromHookStub2.addListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenCalledWith('goerli'); + + controller.startPollingByNetworkClientId('networkClientId1'); + + expect(blockTrackerFromHookStub3.addListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(updateAccountsSpy).toHaveBeenCalledWith('networkClientId1'); + + controller.stopAllPolling(); + }, + ); + }); + }); + + describe('stopPollingByPollingToken', () => { + it('should unsubscribe from the block tracker when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { + withController(({ controller, blockTrackerFromHookStub }) => { + jest.spyOn(controller, 'updateAccounts').mockResolvedValue(); + + const pollingToken = + controller.startPollingByNetworkClientId('mainnet'); + + controller.stopPollingByPollingToken(pollingToken); + + expect(blockTrackerFromHookStub.removeListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + }); + }); + + it('should not unsubscribe from the block tracker if called with one of multiple active polling tokens for a given networkClient', async () => { + withController(({ controller, blockTrackerFromHookStub }) => { + jest.spyOn(controller, 'updateAccounts').mockResolvedValue(); + + const pollingToken1 = + controller.startPollingByNetworkClientId('mainnet'); + controller.startPollingByNetworkClientId('mainnet'); + + controller.stopPollingByPollingToken(pollingToken1); + + expect(blockTrackerFromHookStub.removeListener).not.toHaveBeenCalled(); + + controller.stopAllPolling(); + }); + }); + + it('should error if no pollingToken is passed', () => { + withController(({ controller }) => { + expect(() => { + controller.stopPollingByPollingToken(undefined); + }).toThrow('pollingToken required'); + }); + }); + + it('should error if no matching pollingToken is found', () => { + withController(({ controller }) => { + expect(() => { + controller.stopPollingByPollingToken('potato'); + }).toThrow('pollingToken not found'); + }); + }); + }); + + describe('stopAll', () => { + it('should end all subscriptions', async () => { + const blockTrackerFromHookStub1 = buildMockBlockTracker(); + const blockTrackerFromHookStub2 = buildMockBlockTracker(); + const getNetworkClientByIdStub = jest + .fn() + .mockImplementation((networkClientId) => { + switch (networkClientId) { + case 'mainnet': + return { + configuration: { + chainId: '0x1', + }, + blockTracker: blockTrackerFromHookStub1, + }; + case 'goerli': + return { + configuration: { + chainId: '0x5', + }, + blockTracker: blockTrackerFromHookStub2, + }; + default: + throw new Error('unexpected networkClientId'); + } + }); + withController( + { + getNetworkClientById: getNetworkClientByIdStub, + }, + ({ controller, blockTrackerStub }) => { + jest.spyOn(controller, 'updateAccounts').mockResolvedValue(); + + controller.startPollingByNetworkClientId('mainnet'); + + controller.startPollingByNetworkClientId('goerli'); + + controller.stopAllPolling(); + + expect(blockTrackerStub.removeListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(blockTrackerFromHookStub1.removeListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + expect(blockTrackerFromHookStub2.removeListener).toHaveBeenCalledWith( + 'latest', + expect.any(Function), + ); + }, + ); + }); + }); + + describe('blockTracker "latest" events', () => { + it('updates currentBlockGasLimit, currentBlockGasLimitByChainId, and accounts when polling is initiated via `start`', async () => { + const blockTrackerStub = buildMockBlockTracker({ + shouldStubListeners: false, + }); + withController( + { + blockTracker: blockTrackerStub as unknown as BlockTracker, + }, + async ({ controller }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + + controller.start(); + blockTrackerStub.emit('latest', 'blockNumber'); + + await flushPromises(); + + expect(updateAccountsSpy).toHaveBeenCalledWith(undefined); + + const newState = controller.store.getState(); + + expect(newState).toStrictEqual({ + accounts: {}, + accountsByChainId: {}, + currentBlockGasLimit: GAS_LIMIT, + currentBlockGasLimitByChainId: { + [currentChainId]: GAS_LIMIT, + }, + }); + + controller.stop(); + }, + ); + }); + + it('updates only the currentBlockGasLimitByChainId and accounts when polling is initiated via `startPollingByNetworkClientId`', async () => { + const blockTrackerFromHookStub = buildMockBlockTracker({ + shouldStubListeners: false, + }); + const providerFromHook = createTestProviderTools({ + scaffold: { + eth_getBalance: UPDATE_BALANCE_HOOK, + eth_call: ETHERS_CONTRACT_BALANCES_ETH_CALL_RETURN, + eth_getBlockByNumber: { gasLimit: GAS_LIMIT_HOOK }, + }, + networkId: '0x1', + chainId: '0x1', + }).provider; + const getNetworkClientByIdStub = jest.fn().mockReturnValue({ + configuration: { + chainId: '0x1', + }, + blockTracker: blockTrackerFromHookStub, + provider: providerFromHook, + }); + withController( + { + getNetworkClientById: getNetworkClientByIdStub, + }, + async ({ controller }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + + controller.startPollingByNetworkClientId('mainnet'); + + blockTrackerFromHookStub.emit('latest', 'blockNumber'); + + await flushPromises(); + + expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); + + const newState = controller.store.getState(); + + expect(newState).toStrictEqual({ + accounts: {}, + accountsByChainId: {}, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: { + '0x1': GAS_LIMIT_HOOK, + }, + }); + + controller.stopAllPolling(); + }, + ); + }); + }); + + describe('updateAccountsAllActiveNetworks', () => { + it('updates accounts for the globally selected network and all currently polling networks', async () => { + withController(async ({ controller }) => { + const updateAccountsSpy = jest + .spyOn(controller, 'updateAccounts') + .mockResolvedValue(); + await controller.startPollingByNetworkClientId('networkClientId1'); + await controller.startPollingByNetworkClientId('networkClientId2'); + await controller.startPollingByNetworkClientId('networkClientId3'); + + expect(updateAccountsSpy).toHaveBeenCalledTimes(3); + + await controller.updateAccountsAllActiveNetworks(); + + expect(updateAccountsSpy).toHaveBeenCalledTimes(7); + expect(updateAccountsSpy).toHaveBeenNthCalledWith(4); // called with no args + expect(updateAccountsSpy).toHaveBeenNthCalledWith( + 5, + 'networkClientId1', + ); + expect(updateAccountsSpy).toHaveBeenNthCalledWith( + 6, + 'networkClientId2', + ); + expect(updateAccountsSpy).toHaveBeenNthCalledWith( + 7, + 'networkClientId3', + ); + }); + }); + }); + + describe('updateAccounts', () => { + it('does not update accounts if completedOnBoarding is false', async () => { + withController( + { + completedOnboarding: false, + }, + async ({ controller }) => { + await controller.updateAccounts(); + + const state = controller.store.getState(); + expect(state).toStrictEqual({ + accounts: {}, + currentBlockGasLimit: '', + accountsByChainId: {}, + currentBlockGasLimitByChainId: {}, + }); + }, + ); + }); + + describe('chain does not have single call balance address', () => { + const getCurrentChainIdStub: () => Hex = () => '0x999'; // chain without single call balance address + const mockAccountsWithSelectedAddress = { + ...mockAccounts, + [SELECTED_ADDRESS]: { + address: SELECTED_ADDRESS, + balance: '0x0', + }, + }; + const mockInitialState = { + accounts: mockAccountsWithSelectedAddress, + accountsByChainId: { + '0x999': mockAccountsWithSelectedAddress, + }, + }; + + describe('when useMultiAccountBalanceChecker is true', () => { + it('updates all accounts directly', async () => { + withController( + { + completedOnboarding: true, + useMultiAccountBalanceChecker: true, + getCurrentChainId: getCurrentChainIdStub, + }, + async ({ controller }) => { + controller.store.updateState(mockInitialState); + + await controller.updateAccounts(); + + const accounts = { + [VALID_ADDRESS]: { + address: VALID_ADDRESS, + balance: UPDATE_BALANCE, + }, + [VALID_ADDRESS_TWO]: { + address: VALID_ADDRESS_TWO, + balance: UPDATE_BALANCE, + }, + [SELECTED_ADDRESS]: { + address: SELECTED_ADDRESS, + balance: UPDATE_BALANCE, + }, + }; + + const newState = controller.store.getState(); + expect(newState).toStrictEqual({ + accounts, + accountsByChainId: { + '0x999': accounts, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }, + ); + }); + }); + + describe('when useMultiAccountBalanceChecker is false', () => { + it('updates only the selectedAddress directly, setting other balances to null', async () => { + withController( + { + completedOnboarding: true, + useMultiAccountBalanceChecker: false, + getCurrentChainId: getCurrentChainIdStub, + }, + async ({ controller }) => { + controller.store.updateState(mockInitialState); + + await controller.updateAccounts(); + + const accounts = { + [VALID_ADDRESS]: { address: VALID_ADDRESS, balance: null }, + [VALID_ADDRESS_TWO]: { + address: VALID_ADDRESS_TWO, + balance: null, + }, + [SELECTED_ADDRESS]: { + address: SELECTED_ADDRESS, + balance: UPDATE_BALANCE, + }, + }; + + const newState = controller.store.getState(); + expect(newState).toStrictEqual({ + accounts, + accountsByChainId: { + '0x999': accounts, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }, + ); + }); + }); + }); + + describe('chain does have single call balance address and network is not localhost', () => { + describe('when useMultiAccountBalanceChecker is true', () => { + it('updates all accounts via balance checker', async () => { + withController( + { + completedOnboarding: true, + useMultiAccountBalanceChecker: true, + getNetworkIdentifier: jest + .fn() + .mockReturnValue('http://not-localhost:8545'), + getCurrentChainId: () => '0x1', // chain with single call balance address + getSelectedAccount: jest.fn().mockReturnValue({ + id: 'accountId', + address: VALID_ADDRESS, + } as InternalAccount), + }, + async ({ controller }) => { + controller.store.updateState({ + accounts: { ...mockAccounts }, + accountsByChainId: { + '0x1': { ...mockAccounts }, + }, + }); + + await controller.updateAccounts('mainnet'); + + const accounts = { + [VALID_ADDRESS]: { + address: VALID_ADDRESS, + balance: EXPECTED_CONTRACT_BALANCE_1, + }, + [VALID_ADDRESS_TWO]: { + address: VALID_ADDRESS_TWO, + balance: EXPECTED_CONTRACT_BALANCE_2, + }, + }; + + const newState = controller.store.getState(); + expect(newState).toStrictEqual({ + accounts, + accountsByChainId: { + '0x1': accounts, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }, + ); + }); + }); + }); + }); + + describe('onAccountRemoved', () => { + it('should remove an account from state', () => { + withController(({ controller, triggerOnAccountRemoved }) => { + controller.store.updateState({ + accounts: { ...mockAccounts }, + accountsByChainId: { + [currentChainId]: { + ...mockAccounts, + }, + '0x1': { + ...mockAccounts, + }, + '0x2': { + ...mockAccounts, + }, + }, + }); + + triggerOnAccountRemoved(VALID_ADDRESS); + + const newState = controller.store.getState(); + + const accounts = { + [VALID_ADDRESS_TWO]: mockAccounts[VALID_ADDRESS_TWO], + }; + + expect(newState).toStrictEqual({ + accounts, + accountsByChainId: { + [currentChainId]: accounts, + '0x1': accounts, + '0x2': accounts, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }); + }); + }); + + describe('clearAccounts', () => { + it('should reset state', () => { + withController(({ controller }) => { + controller.store.updateState({ + accounts: { ...mockAccounts }, + accountsByChainId: { + [currentChainId]: { + ...mockAccounts, + }, + '0x1': { + ...mockAccounts, + }, + '0x2': { + ...mockAccounts, + }, + }, + }); + + controller.clearAccounts(); + + const newState = controller.store.getState(); + + expect(newState).toStrictEqual({ + accounts: {}, + accountsByChainId: { + [currentChainId]: {}, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }); + }); + }); +}); diff --git a/app/scripts/lib/account-tracker.js b/app/scripts/lib/account-tracker.ts similarity index 60% rename from app/scripts/lib/account-tracker.js rename to app/scripts/lib/account-tracker.ts index 6dbd13f1c2df..df89238ca366 100644 --- a/app/scripts/lib/account-tracker.js +++ b/app/scripts/lib/account-tracker.ts @@ -17,52 +17,127 @@ import { Web3Provider } from '@ethersproject/providers'; import { Contract } from '@ethersproject/contracts'; import SINGLE_CALL_BALANCES_ABI from 'single-call-balance-checker-abi'; import { cloneDeep } from 'lodash'; +import { + BlockTracker, + NetworkClientConfiguration, + NetworkClientId, + NetworkControllerGetNetworkClientByIdAction, + Provider, +} from '@metamask/network-controller'; +import { hasProperty, Hex } from '@metamask/utils'; +import { ControllerMessenger } from '@metamask/base-controller'; +import { + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; +import { KeyringControllerAccountRemovedEvent } from '@metamask/keyring-controller'; +import { InternalAccount } from '@metamask/keyring-api'; + +import OnboardingController, { + OnboardingControllerStateChangeEvent, +} from '../controllers/onboarding'; +import PreferencesController from '../controllers/preferences-controller'; import { LOCALHOST_RPC_URL } from '../../../shared/constants/network'; - import { SINGLE_CALL_BALANCES_ADDRESSES } from '../constants/contracts'; import { previousValueComparator } from './util'; +type Account = { + address: string; + balance: string | null; +}; + +export type AccountTrackerState = { + accounts: Record>; + currentBlockGasLimit: string; + accountsByChainId: Record; + currentBlockGasLimitByChainId: Record; +}; + +export const getDefaultAccountTrackerState = (): AccountTrackerState => ({ + accounts: {}, + currentBlockGasLimit: '', + accountsByChainId: {}, + currentBlockGasLimitByChainId: {}, +}); + +export type AllowedActions = + | AccountsControllerGetSelectedAccountAction + | NetworkControllerGetNetworkClientByIdAction; + +export type AllowedEvents = + | AccountsControllerSelectedEvmAccountChangeEvent + | KeyringControllerAccountRemovedEvent + | OnboardingControllerStateChangeEvent; + +export type AccountTrackerOptions = { + initState: Partial; + provider: Provider; + blockTracker: BlockTracker; + getCurrentChainId: () => Hex; + getNetworkIdentifier: (config?: NetworkClientConfiguration) => string; + preferencesController: PreferencesController; + onboardingController: OnboardingController; + controllerMessenger: ControllerMessenger; +}; + /** * This module is responsible for tracking any number of accounts and caching their current balances & transaction * counts. * * It also tracks transaction hashes, and checks their inclusion status on each new block. * - * @typedef {object} AccountTracker - * @property {object} store The stored object containing all accounts to track, as well as the current block's gas limit. - * @property {object} store.accounts The accounts currently stored in this AccountTracker - * @property {object} store.accountsByChainId The accounts currently stored in this AccountTracker keyed by chain id - * @property {string} store.currentBlockGasLimit A hex string indicating the gas limit of the current block - * @property {string} store.currentBlockGasLimitByChainId A hex string indicating the gas limit of the current block keyed by chain id + * AccountTracker + * + * @property store The stored object containing all accounts to track, as well as the current block's gas limit. + * @property store.accounts The accounts currently stored in this AccountTracker + * @property store.accountsByChainId The accounts currently stored in this AccountTracker keyed by chain id + * @property store.currentBlockGasLimit A hex string indicating the gas limit of the current block + * @property store.currentBlockGasLimitByChainId A hex string indicating the gas limit of the current block keyed by chain id */ export default class AccountTracker { /** - * @param {object} opts - Options for initializing the controller - * @param {object} opts.provider - An EIP-1193 provider instance that uses the current global network - * @param {object} opts.blockTracker - A block tracker, which emits events for each new block - * @param {Function} opts.getCurrentChainId - A function that returns the `chainId` for the current global network - * @param {Function} opts.getNetworkClientById - Gets the network client with the given id from the NetworkController. - * @param {Function} opts.getNetworkIdentifier - A function that returns the current network or passed nework configuration - * @param {Function} opts.onAccountRemoved - Allows subscribing to keyring controller accountRemoved event + * Observable store containing controller data. */ - #pollingTokenSets = new Map(); + store: ObservableStore; - #listeners = {}; + resetState: () => void; - #provider = null; + #pollingTokenSets = new Map>(); - #blockTracker = null; + #listeners: Record Promise> = + {}; - #currentBlockNumberByChainId = {}; + #provider: Provider; - constructor(opts = {}) { - const initState = { - accounts: {}, - currentBlockGasLimit: '', - accountsByChainId: {}, - currentBlockGasLimitByChainId: {}, - }; - this.store = new ObservableStore({ ...initState, ...opts.initState }); + #blockTracker: BlockTracker; + + #currentBlockNumberByChainId: Record = {}; + + #getCurrentChainId: AccountTrackerOptions['getCurrentChainId']; + + #getNetworkIdentifier: AccountTrackerOptions['getNetworkIdentifier']; + + #preferencesController: AccountTrackerOptions['preferencesController']; + + #onboardingController: AccountTrackerOptions['onboardingController']; + + #controllerMessenger: AccountTrackerOptions['controllerMessenger']; + + #selectedAccount: InternalAccount; + + /** + * @param opts - Options for initializing the controller + * @param opts.provider - An EIP-1193 provider instance that uses the current global network + * @param opts.blockTracker - A block tracker, which emits events for each new block + * @param opts.getCurrentChainId - A function that returns the `chainId` for the current global network + * @param opts.getNetworkIdentifier - A function that returns the current network or passed nework configuration + */ + constructor(opts: AccountTrackerOptions) { + const initState = getDefaultAccountTrackerState(); + this.store = new ObservableStore({ + ...initState, + ...opts.initState, + }); this.resetState = () => { this.store.updateState(initState); @@ -71,17 +146,19 @@ export default class AccountTracker { this.#provider = opts.provider; this.#blockTracker = opts.blockTracker; - this.getCurrentChainId = opts.getCurrentChainId; - this.getNetworkClientById = opts.getNetworkClientById; - this.getNetworkIdentifier = opts.getNetworkIdentifier; - this.preferencesController = opts.preferencesController; - this.onboardingController = opts.onboardingController; - this.controllerMessenger = opts.controllerMessenger; + this.#getCurrentChainId = opts.getCurrentChainId; + this.#getNetworkIdentifier = opts.getNetworkIdentifier; + this.#preferencesController = opts.preferencesController; + this.#onboardingController = opts.onboardingController; + this.#controllerMessenger = opts.controllerMessenger; // subscribe to account removal - opts.onAccountRemoved((address) => this.removeAccounts([address])); + this.#controllerMessenger.subscribe( + 'KeyringController:accountRemoved', + (address) => this.removeAccounts([address]), + ); - this.controllerMessenger.subscribe( + this.#controllerMessenger.subscribe( 'OnboardingController:stateChange', previousValueComparator((prevState, currState) => { const { completedOnboarding: prevCompletedOnboarding } = prevState; @@ -89,24 +166,25 @@ export default class AccountTracker { if (!prevCompletedOnboarding && currCompletedOnboarding) { this.updateAccountsAllActiveNetworks(); } - }, this.onboardingController.state), + return true; + }, this.#onboardingController.state), ); - this.selectedAccount = this.controllerMessenger.call( + this.#selectedAccount = this.#controllerMessenger.call( 'AccountsController:getSelectedAccount', ); - this.controllerMessenger.subscribe( + this.#controllerMessenger.subscribe( 'AccountsController:selectedEvmAccountChange', (newAccount) => { const { useMultiAccountBalanceChecker } = - this.preferencesController.store.getState(); + this.#preferencesController.store.getState(); if ( - this.selectedAccount.id !== newAccount.id && + this.#selectedAccount.id !== newAccount.id && !useMultiAccountBalanceChecker ) { - this.selectedAccount = newAccount; + this.#selectedAccount = newAccount; this.updateAccountsAllActiveNetworks(); } }, @@ -116,13 +194,14 @@ export default class AccountTracker { /** * Starts polling with global selected network */ - start() { + start(): void { // blockTracker.currentBlock may be null this.#currentBlockNumberByChainId = { - [this.getCurrentChainId()]: this.#blockTracker.getCurrentBlock(), + [this.#getCurrentChainId()]: this.#blockTracker.getCurrentBlock(), }; this.#blockTracker.once('latest', (blockNumber) => { - this.#currentBlockNumberByChainId[this.getCurrentChainId()] = blockNumber; + this.#currentBlockNumberByChainId[this.#getCurrentChainId()] = + blockNumber; }); // remove first to avoid double add @@ -136,7 +215,7 @@ export default class AccountTracker { /** * Stops polling with global selected network */ - stop() { + stop(): void { // remove listener this.#blockTracker.removeListener('latest', this.#updateForBlock); } @@ -148,22 +227,31 @@ export default class AccountTracker { * @param networkClientId - Optional networkClientId to fetch a network client with * @returns network client config */ - #getCorrectNetworkClient(networkClientId) { + #getCorrectNetworkClient(networkClientId?: NetworkClientId): { + chainId: Hex; + provider: Provider; + blockTracker: BlockTracker; + identifier: string; + } { if (networkClientId) { - const networkClient = this.getNetworkClientById(networkClientId); + const { configuration, provider, blockTracker } = + this.#controllerMessenger.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ); return { - chainId: networkClient.configuration.chainId, - provider: networkClient.provider, - blockTracker: networkClient.blockTracker, - identifier: this.getNetworkIdentifier(networkClient.configuration), + chainId: configuration.chainId, + provider, + blockTracker, + identifier: this.#getNetworkIdentifier(configuration), }; } return { - chainId: this.getCurrentChainId(), + chainId: this.#getCurrentChainId(), provider: this.#provider, blockTracker: this.#blockTracker, - identifier: this.getNetworkIdentifier(), + identifier: this.#getNetworkIdentifier(), }; } @@ -173,14 +261,14 @@ export default class AccountTracker { * @param networkClientId - The networkClientId to start polling for * @returns pollingToken */ - startPollingByNetworkClientId(networkClientId) { + startPollingByNetworkClientId(networkClientId: NetworkClientId) { const pollToken = random(); const pollingTokenSet = this.#pollingTokenSets.get(networkClientId); if (pollingTokenSet) { pollingTokenSet.add(pollToken); } else { - const set = new Set(); + const set = new Set(); set.add(pollToken); this.#pollingTokenSets.set(networkClientId, set); this.#subscribeWithNetworkClientId(networkClientId); @@ -205,7 +293,7 @@ export default class AccountTracker { * * @param pollingToken - The polling token to stop polling for */ - stopPollingByPollingToken(pollingToken) { + stopPollingByPollingToken(pollingToken: string | undefined) { if (!pollingToken) { throw new Error('pollingToken required'); } @@ -228,9 +316,9 @@ export default class AccountTracker { /** * Subscribes from the block tracker for the given networkClientId if not currently subscribed * - * @param {string} networkClientId - network client ID to fetch a block tracker with + * @param networkClientId - network client ID to fetch a block tracker with */ - #subscribeWithNetworkClientId(networkClientId) { + #subscribeWithNetworkClientId(networkClientId: NetworkClientId) { if (this.#listeners[networkClientId]) { return; } @@ -249,9 +337,9 @@ export default class AccountTracker { /** * Unsubscribes from the block tracker for the given networkClientId if currently subscribed * - * @param {string} networkClientId - The network client ID to fetch a block tracker with + * @param networkClientId - The network client ID to fetch a block tracker with */ - #unsubscribeWithNetworkClientId(networkClientId) { + #unsubscribeWithNetworkClientId(networkClientId: NetworkClientId) { if (!this.#listeners[networkClientId]) { return; } @@ -265,16 +353,15 @@ export default class AccountTracker { * Returns the accounts object for the chain ID, or initializes it from the globally selected * if it doesn't already exist. * - * @private - * @param {string} chainId - The chain ID + * @param chainId - The chain ID */ - #getAccountsForChainId(chainId) { + #getAccountsForChainId(chainId: Hex) { const { accounts, accountsByChainId } = this.store.getState(); if (accountsByChainId[chainId]) { return cloneDeep(accountsByChainId[chainId]); } - const newAccounts = {}; + const newAccounts: Record> = {}; Object.keys(accounts).forEach((address) => { newAccounts[address] = {}; }); @@ -288,21 +375,21 @@ export default class AccountTracker { * Once this AccountTracker's accounts are up to date with those referenced by the passed addresses, each * of these accounts are given an updated balance via EthQuery. * - * @param {Array} addresses - The array of hex addresses for accounts with which this AccountTracker's accounts should be + * @param addresses - The array of hex addresses for accounts with which this AccountTracker's accounts should be * in sync */ - syncWithAddresses(addresses) { + syncWithAddresses(addresses: string[]): void { const { accounts } = this.store.getState(); const locals = Object.keys(accounts); - const accountsToAdd = []; + const accountsToAdd: string[] = []; addresses.forEach((upstream) => { if (!locals.includes(upstream)) { accountsToAdd.push(upstream); } }); - const accountsToRemove = []; + const accountsToRemove: string[] = []; locals.forEach((local) => { if (!addresses.includes(local)) { accountsToRemove.push(local); @@ -317,9 +404,9 @@ export default class AccountTracker { * Adds new addresses to track the balances of * given a balance as long this.#currentBlockNumberByChainId is defined for the chainId. * - * @param {Array} addresses - An array of hex addresses of new accounts to track + * @param addresses - An array of hex addresses of new accounts to track */ - addAccounts(addresses) { + addAccounts(addresses: string[]): void { const { accounts: _accounts, accountsByChainId: _accountsByChainId } = this.store.getState(); const accounts = cloneDeep(_accounts); @@ -338,7 +425,7 @@ export default class AccountTracker { this.store.updateState({ accounts, accountsByChainId }); // fetch balances for the accounts if there is block number ready - if (this.#currentBlockNumberByChainId[this.getCurrentChainId()]) { + if (this.#currentBlockNumberByChainId[this.#getCurrentChainId()]) { this.updateAccounts(); } this.#pollingTokenSets.forEach((_tokenSet, networkClientId) => { @@ -352,9 +439,9 @@ export default class AccountTracker { /** * Removes accounts from being tracked * - * @param {Array} addresses - An array of hex addresses to stop tracking. + * @param addresses - An array of hex addresses to stop tracking. */ - removeAccounts(addresses) { + removeAccounts(addresses: string[]): void { const { accounts: _accounts, accountsByChainId: _accountsByChainId } = this.store.getState(); const accounts = cloneDeep(_accounts); @@ -376,11 +463,11 @@ export default class AccountTracker { /** * Removes all addresses and associated balances */ - clearAccounts() { + clearAccounts(): void { this.store.updateState({ accounts: {}, accountsByChainId: { - [this.getCurrentChainId()]: {}, + [this.#getCurrentChainId()]: {}, }, }); } @@ -390,11 +477,11 @@ export default class AccountTracker { * each local account's balance via EthQuery * * @private - * @param {number} blockNumber - the block number to update to. + * @param blockNumber - the block number to update to. * @fires 'block' The updated state, if all account updates are successful */ - #updateForBlock = async (blockNumber) => { - await this.#updateForBlockByNetworkClientId(null, blockNumber); + #updateForBlock = async (blockNumber: string): Promise => { + await this.#updateForBlockByNetworkClientId(undefined, blockNumber); }; /** @@ -402,11 +489,14 @@ export default class AccountTracker { * via EthQuery * * @private - * @param {string} networkClientId - optional network client ID to use instead of the globally selected network. - * @param {number} blockNumber - the block number to update to. + * @param networkClientId - optional network client ID to use instead of the globally selected network. + * @param blockNumber - the block number to update to. * @fires 'block' The updated state, if all account updates are successful */ - async #updateForBlockByNetworkClientId(networkClientId, blockNumber) { + async #updateForBlockByNetworkClientId( + networkClientId: NetworkClientId | undefined, + blockNumber: string, + ): Promise { const { chainId, provider } = this.#getCorrectNetworkClient(networkClientId); this.#currentBlockNumberByChainId[chainId] = blockNumber; @@ -422,7 +512,7 @@ export default class AccountTracker { const currentBlockGasLimit = currentBlock.gasLimit; const { currentBlockGasLimitByChainId } = this.store.getState(); this.store.updateState({ - ...(chainId === this.getCurrentChainId() && { + ...(chainId === this.#getCurrentChainId() && { currentBlockGasLimit, }), currentBlockGasLimitByChainId: { @@ -442,9 +532,8 @@ export default class AccountTracker { * Updates accounts for the globally selected network * and all networks that are currently being polled. * - * @returns {Promise} after all account balances updated */ - async updateAccountsAllActiveNetworks() { + async updateAccountsAllActiveNetworks(): Promise { await this.updateAccounts(); await Promise.all( Array.from(this.#pollingTokenSets).map(([networkClientId]) => { @@ -457,11 +546,10 @@ export default class AccountTracker { * balanceChecker is deployed on main eth (test)nets and requires a single call * for all other networks, calls this.#updateAccount for each account in this.store * - * @param {string} networkClientId - optional network client ID to use instead of the globally selected network. - * @returns {Promise} after all account balances updated + * @param networkClientId - optional network client ID to use instead of the globally selected network. */ - async updateAccounts(networkClientId) { - const { completedOnboarding } = this.onboardingController.state; + async updateAccounts(networkClientId?: NetworkClientId): Promise { + const { completedOnboarding } = this.#onboardingController.state; if (!completedOnboarding) { return; } @@ -469,7 +557,7 @@ export default class AccountTracker { const { chainId, provider, identifier } = this.#getCorrectNetworkClient(networkClientId); const { useMultiAccountBalanceChecker } = - this.preferencesController.store.getState(); + this.#preferencesController.store.getState(); let addresses = []; if (useMultiAccountBalanceChecker) { @@ -477,7 +565,7 @@ export default class AccountTracker { addresses = Object.keys(accounts); } else { - const selectedAddress = this.controllerMessenger.call( + const selectedAddress = this.#controllerMessenger.call( 'AccountsController:getSelectedAccount', ).address; @@ -485,7 +573,10 @@ export default class AccountTracker { } const rpcUrl = 'http://127.0.0.1:8545'; - const singleCallBalancesAddress = SINGLE_CALL_BALANCES_ADDRESSES[chainId]; + const singleCallBalancesAddress = + SINGLE_CALL_BALANCES_ADDRESSES[ + chainId as keyof typeof SINGLE_CALL_BALANCES_ADDRESSES + ]; if ( identifier === LOCALHOST_RPC_URL || identifier === rpcUrl || @@ -510,15 +601,18 @@ export default class AccountTracker { * Updates the current balance of an account. * * @private - * @param {string} address - A hex address of a the account to be updated - * @param {object} provider - The provider instance to fetch the balance with - * @param {string} chainId - The chain ID to update in state - * @returns {Promise} after the account balance is updated + * @param address - A hex address of a the account to be updated + * @param provider - The provider instance to fetch the balance with + * @param chainId - The chain ID to update in state */ - async #updateAccount(address, provider, chainId) { + async #updateAccount( + address: string, + provider: Provider, + chainId: Hex, + ): Promise { const { useMultiAccountBalanceChecker } = - this.preferencesController.store.getState(); + this.#preferencesController.store.getState(); let balance = '0x0'; @@ -526,7 +620,16 @@ export default class AccountTracker { try { balance = await pify(new EthQuery(provider)).getBalance(address); } catch (error) { - if (error.data?.request?.method !== 'eth_getBalance') { + if ( + error && + typeof error === 'object' && + hasProperty(error, 'data') && + error.data && + hasProperty(error.data, 'request') && + error.data.request && + hasProperty(error.data.request, 'method') && + error.data.request.method !== 'eth_getBalance' + ) { throw error; } } @@ -556,7 +659,7 @@ export default class AccountTracker { const { accountsByChainId } = this.store.getState(); this.store.updateState({ - ...(chainId === this.getCurrentChainId() && { + ...(chainId === this.#getCurrentChainId() && { accounts: newAccounts, }), accountsByChainId: { @@ -570,18 +673,17 @@ export default class AccountTracker { * Updates current address balances from balanceChecker deployed contract instance * * @private - * @param {Array} addresses - A hex addresses of a the accounts to be updated - * @param {string} deployedContractAddress - The contract address to fetch balances with - * @param {object} provider - The provider instance to fetch the balance with - * @param {string} chainId - The chain ID to update in state - * @returns {Promise} after the account balance is updated + * @param addresses - A hex addresses of a the accounts to be updated + * @param deployedContractAddress - The contract address to fetch balances with + * @param provider - The provider instance to fetch the balance with + * @param chainId - The chain ID to update in state */ async #updateAccountsViaBalanceChecker( - addresses, - deployedContractAddress, - provider, - chainId, - ) { + addresses: string[], + deployedContractAddress: string, + provider: Provider, + chainId: Hex, + ): Promise { const ethContract = await new Contract( deployedContractAddress, SINGLE_CALL_BALANCES_ABI, @@ -593,7 +695,7 @@ export default class AccountTracker { const balances = await ethContract.balances(addresses, ethBalance); const accounts = this.#getAccountsForChainId(chainId); - const newAccounts = {}; + const newAccounts: AccountTrackerState['accounts'] = {}; Object.keys(accounts).forEach((address) => { if (!addresses.includes(address)) { newAccounts[address] = { address, balance: null }; @@ -606,7 +708,7 @@ export default class AccountTracker { const { accountsByChainId } = this.store.getState(); this.store.updateState({ - ...(chainId === this.getCurrentChainId() && { + ...(chainId === this.#getCurrentChainId() && { accounts: newAccounts, }), accountsByChainId: { diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index 0cceb966d5db..0fa2650d60ab 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -1628,17 +1628,14 @@ export default class MetamaskController extends EventEmitter { onboardingController: this.onboardingController, controllerMessenger: this.controllerMessenger.getRestricted({ name: 'AccountTracker', + allowedActions: ['AccountsController:getSelectedAccount'], allowedEvents: [ 'AccountsController:selectedEvmAccountChange', 'OnboardingController:stateChange', + 'KeyringController:accountRemoved', ], - allowedActions: ['AccountsController:getSelectedAccount'], }), initState: { accounts: {} }, - onAccountRemoved: this.controllerMessenger.subscribe.bind( - this.controllerMessenger, - 'KeyringController:accountRemoved', - ), }); // start and stop polling for balances based on activeControllerConnections diff --git a/types/single-call-balance-checker-abi.d.ts b/types/single-call-balance-checker-abi.d.ts new file mode 100644 index 000000000000..ae42a6e98775 --- /dev/null +++ b/types/single-call-balance-checker-abi.d.ts @@ -0,0 +1,6 @@ +declare module 'single-call-balance-checker-abi' { + import { ContractInterface } from '@ethersproject/contracts'; + + const SINGLE_CALL_BALANCES_ABI: ContractInterface; + export default SINGLE_CALL_BALANCES_ABI; +}