diff --git a/src/hooks/wallets/__tests__/useOnboard.test.ts b/src/hooks/wallets/__tests__/useOnboard.test.ts new file mode 100644 index 0000000000..c08ad6b7fa --- /dev/null +++ b/src/hooks/wallets/__tests__/useOnboard.test.ts @@ -0,0 +1,205 @@ +import { ONBOARD_MPC_MODULE_LABEL } from '@/services/mpc/SocialLoginModule' +import { faker } from '@faker-js/faker' +import type { EIP1193Provider, OnboardAPI, WalletState } from '@web3-onboard/core' +import { getConnectedWallet, switchWallet } from '../useOnboard' + +// mock wallets +jest.mock('@/hooks/wallets/wallets', () => ({ + getDefaultWallets: jest.fn(() => []), + getRecommendedInjectedWallets: jest.fn(() => []), +})) + +describe('useOnboard', () => { + describe('getConnectedWallet', () => { + it('returns the connected wallet', () => { + const wallets = [ + { + label: 'Wallet 1', + icon: 'wallet1.svg', + provider: null as unknown as EIP1193Provider, + chains: [{ id: '0x4' }], + accounts: [ + { + address: '0x1234567890123456789012345678901234567890', + ens: null, + balance: null, + }, + ], + }, + { + label: 'Wallet 2', + icon: 'wallet2.svg', + provider: null as unknown as EIP1193Provider, + chains: [{ id: '0x100' }], + accounts: [ + { + address: '0x2', + ens: null, + balance: null, + }, + ], + }, + ] as WalletState[] + + expect(getConnectedWallet(wallets)).toEqual({ + label: 'Wallet 1', + icon: 'wallet1.svg', + address: '0x1234567890123456789012345678901234567890', + provider: wallets[0].provider, + chainId: '4', + }) + }) + + it('should return null if the address is invalid', () => { + const wallets = [ + { + label: 'Wallet 1', + icon: 'wallet1.svg', + provider: null as unknown as EIP1193Provider, + chains: [{ id: '0x4' }], + accounts: [ + { + address: '0xinvalid', + ens: null, + balance: null, + }, + ], + }, + ] as WalletState[] + + expect(getConnectedWallet(wallets)).toBeNull() + }) + }) + + describe('switchWallet', () => { + it('should keep the previous wallet connected if connection fails', async () => { + const mockOnboard = { + state: { + get: jest.fn().mockReturnValue({ + wallets: [ + { + accounts: [ + { + address: faker.finance.ethereumAddress(), + ens: undefined, + }, + ], + chains: [ + { + id: '5', + }, + ], + label: ONBOARD_MPC_MODULE_LABEL, + }, + ], + }), + }, + connectWallet: jest.fn().mockRejectedValue('Error'), + disconnectWallet: jest.fn(), + } + + await switchWallet(mockOnboard as unknown as OnboardAPI) + + expect(mockOnboard.connectWallet).toBeCalled() + expect(mockOnboard.disconnectWallet).not.toBeCalled() + }) + + it('should not disconnect the previous wallet label if the same label connects', async () => { + const mockNewState = [ + { + accounts: [ + { + address: faker.finance.ethereumAddress(), + ens: undefined, + }, + ], + chains: [ + { + id: '5', + }, + ], + label: ONBOARD_MPC_MODULE_LABEL, + }, + ] + + const mockOnboard = { + state: { + get: jest.fn().mockReturnValue({ + wallets: [ + { + accounts: [ + { + address: faker.finance.ethereumAddress(), + ens: undefined, + }, + ], + chains: [ + { + id: '5', + }, + ], + label: ONBOARD_MPC_MODULE_LABEL, + }, + ], + }), + }, + connectWallet: jest.fn().mockResolvedValue(mockNewState), + disconnectWallet: jest.fn(), + } + + await switchWallet(mockOnboard as unknown as OnboardAPI) + + expect(mockOnboard.connectWallet).toBeCalled() + expect(mockOnboard.disconnectWallet).not.toBeCalled() + }) + + it('should disconnect the previous wallet label if new wallet connects', async () => { + const mockNewState = [ + { + accounts: [ + { + address: faker.finance.ethereumAddress(), + ens: undefined, + }, + ], + chains: [ + { + id: '5', + }, + ], + label: 'MetaMask', + }, + ] + + const mockOnboard = { + state: { + get: jest.fn().mockReturnValue({ + wallets: [ + { + accounts: [ + { + address: faker.finance.ethereumAddress(), + ens: undefined, + }, + ], + chains: [ + { + id: '5', + }, + ], + label: ONBOARD_MPC_MODULE_LABEL, + }, + ], + }), + }, + connectWallet: jest.fn().mockResolvedValue(mockNewState), + disconnectWallet: jest.fn(), + } + + await switchWallet(mockOnboard as unknown as OnboardAPI) + + expect(mockOnboard.connectWallet).toBeCalled() + expect(mockOnboard.disconnectWallet).toBeCalledWith({ label: ONBOARD_MPC_MODULE_LABEL }) + }) + }) +}) diff --git a/src/hooks/wallets/useOnboard.test.ts b/src/hooks/wallets/useOnboard.test.ts deleted file mode 100644 index ed9c637574..0000000000 --- a/src/hooks/wallets/useOnboard.test.ts +++ /dev/null @@ -1,69 +0,0 @@ -import type { EIP1193Provider, WalletState } from '@web3-onboard/core' -import { getConnectedWallet } from './useOnboard' - -// mock wallets -jest.mock('@/hooks/wallets/wallets', () => ({ - getDefaultWallets: jest.fn(() => []), - getRecommendedInjectedWallets: jest.fn(() => []), -})) - -describe('getConnectedWallet', () => { - it('returns the connected wallet', () => { - const wallets = [ - { - label: 'Wallet 1', - icon: 'wallet1.svg', - provider: null as unknown as EIP1193Provider, - chains: [{ id: '0x4' }], - accounts: [ - { - address: '0x1234567890123456789012345678901234567890', - ens: null, - balance: null, - }, - ], - }, - { - label: 'Wallet 2', - icon: 'wallet2.svg', - provider: null as unknown as EIP1193Provider, - chains: [{ id: '0x100' }], - accounts: [ - { - address: '0x2', - ens: null, - balance: null, - }, - ], - }, - ] as WalletState[] - - expect(getConnectedWallet(wallets)).toEqual({ - label: 'Wallet 1', - icon: 'wallet1.svg', - address: '0x1234567890123456789012345678901234567890', - provider: wallets[0].provider, - chainId: '4', - }) - }) - - it('should return null if the address is invalid', () => { - const wallets = [ - { - label: 'Wallet 1', - icon: 'wallet1.svg', - provider: null as unknown as EIP1193Provider, - chains: [{ id: '0x4' }], - accounts: [ - { - address: '0xinvalid', - ens: null, - balance: null, - }, - ], - }, - ] as WalletState[] - - expect(getConnectedWallet(wallets)).toBeNull() - }) -}) diff --git a/src/hooks/wallets/useOnboard.ts b/src/hooks/wallets/useOnboard.ts index 2781d67d86..cc5ef39ecc 100644 --- a/src/hooks/wallets/useOnboard.ts +++ b/src/hooks/wallets/useOnboard.ts @@ -130,8 +130,19 @@ export const connectWallet = async ( return wallets } -export const switchWallet = (onboard: OnboardAPI) => { - connectWallet(onboard) +export const switchWallet = async (onboard: OnboardAPI) => { + const oldWalletLabel = getConnectedWallet(onboard.state.get().wallets)?.label + const newWallets = await connectWallet(onboard) + const newWalletLabel = newWallets ? getConnectedWallet(newWallets)?.label : undefined + + // If the wallet actually changed we disconnect the old connected wallet. + if (!newWalletLabel || !oldWalletLabel) { + return + } + + if (newWalletLabel !== oldWalletLabel) { + await onboard.disconnectWallet({ label: oldWalletLabel }) + } } // Disable/enable wallets according to chain