diff --git a/src/spl/token/instructions.py b/src/spl/token/instructions.py index 705d5ece..69e48052 100644 --- a/src/spl/token/instructions.py +++ b/src/spl/token/instructions.py @@ -1244,3 +1244,24 @@ def create_associated_token_account(payer: Pubkey, owner: Pubkey, mint: Pubkey) program_id=ASSOCIATED_TOKEN_PROGRAM_ID, data=bytes(0), ) + + +def create_idempotent_associated_token_account(payer: Pubkey, owner: Pubkey, mint: Pubkey) -> Instruction: + """Creates an associated token account for the given address/token mint if it not exists. + + Returns: + The instruction to create the associated token account. + """ + associated_token_address = get_associated_token_address(owner, mint) + return Instruction( + accounts=[ + AccountMeta(pubkey=payer, is_signer=True, is_writable=True), + AccountMeta(pubkey=associated_token_address, is_signer=False, is_writable=True), + AccountMeta(pubkey=owner, is_signer=False, is_writable=False), + AccountMeta(pubkey=mint, is_signer=False, is_writable=False), + AccountMeta(pubkey=SYS_PROGRAM_ID, is_signer=False, is_writable=False), + AccountMeta(pubkey=TOKEN_PROGRAM_ID, is_signer=False, is_writable=False), + ], + program_id=ASSOCIATED_TOKEN_PROGRAM_ID, + data=bytes([1]), + ) diff --git a/tests/unit/test_spl_token_instructions.py b/tests/unit/test_spl_token_instructions.py index f3a12132..cc97436f 100644 --- a/tests/unit/test_spl_token_instructions.py +++ b/tests/unit/test_spl_token_instructions.py @@ -1,7 +1,8 @@ """Unit tests for SPL-token instructions.""" import spl.token.instructions as spl_token from solders.pubkey import Pubkey -from spl.token.constants import TOKEN_PROGRAM_ID, WRAPPED_SOL_MINT +from solders.system_program import ID as SYSTEM_PROGRAM_ID +from spl.token.constants import TOKEN_PROGRAM_ID, WRAPPED_SOL_MINT, ASSOCIATED_TOKEN_PROGRAM_ID from spl.token.instructions import get_associated_token_address @@ -392,3 +393,36 @@ def test_sync_native(stubbed_sender): instruction = spl_token.sync_native(params) decoded_params = spl_token.decode_sync_native(instruction) assert params == decoded_params + + +def test_create_idempotent_token_account(stubbed_receiver, stubbed_sender): + """Test Create idempotent token account.""" + mint = Pubkey([0] * 31 + [0]) + token_account = get_associated_token_address(stubbed_receiver, mint) + instruction = spl_token.create_idempotent_associated_token_account( + payer=stubbed_sender.pubkey(), + owner=stubbed_receiver, + mint=mint, + ) + + assert instruction.program_id == ASSOCIATED_TOKEN_PROGRAM_ID + assert instruction.data[0] == 1 # CreateIdempotent + assert len(instruction.accounts) == 6 + assert instruction.accounts[0].pubkey == stubbed_sender.pubkey() + assert instruction.accounts[0].is_signer + assert instruction.accounts[0].is_writable + assert instruction.accounts[1].pubkey == token_account + assert not instruction.accounts[1].is_signer + assert instruction.accounts[1].is_writable + assert instruction.accounts[2].pubkey == stubbed_receiver + assert not instruction.accounts[2].is_signer + assert not instruction.accounts[2].is_writable + assert instruction.accounts[3].pubkey == mint + assert not instruction.accounts[3].is_signer + assert not instruction.accounts[3].is_writable + assert instruction.accounts[4].pubkey == SYSTEM_PROGRAM_ID + assert not instruction.accounts[4].is_signer + assert not instruction.accounts[4].is_writable + assert instruction.accounts[5].pubkey == TOKEN_PROGRAM_ID + assert not instruction.accounts[5].is_signer + assert not instruction.accounts[5].is_writable