diff --git a/CHANGELOG.md b/CHANGELOG.md index 551be72828..7cc5550365 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The minor version will be incremented upon a breaking change and the patch versi - lang: Export `Discriminator` trait from `prelude` ([#3075](https://github.com/coral-xyz/anchor/pull/3075)). - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)). - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)). +- lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)). ### Fixes diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs index 27df7eef26..a11492b9ed 100644 --- a/lang/attribute/account/src/lib.rs +++ b/lang/attribute/account/src/lib.rs @@ -180,7 +180,7 @@ pub fn account( if buf.len() < #discriminator.len() { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..8]; + let given_disc = &buf[..#discriminator.len()]; if &#discriminator != given_disc { return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str)); } @@ -188,7 +188,7 @@ pub fn account( } fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let data: &[u8] = &buf[8..]; + let data: &[u8] = &buf[#discriminator.len()..]; // Re-interpret raw bytes into the POD data structure. let account = anchor_lang::__private::bytemuck::from_bytes(data); // Copy out the bytes into a new, owned data structure. @@ -223,7 +223,7 @@ pub fn account( if buf.len() < #discriminator.len() { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..8]; + let given_disc = &buf[..#discriminator.len()]; if &#discriminator != given_disc { return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str)); } @@ -231,7 +231,7 @@ pub fn account( } fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let mut data: &[u8] = &buf[8..]; + let mut data: &[u8] = &buf[#discriminator.len()..]; AnchorDeserialize::deserialize(&mut data) .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into()) } diff --git a/lang/src/accounts/account_loader.rs b/lang/src/accounts/account_loader.rs index 0434f33562..76676b70c1 100644 --- a/lang/src/accounts/account_loader.rs +++ b/lang/src/accounts/account_loader.rs @@ -129,7 +129,7 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { return Err(ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &data[..8]; + let given_disc = &data[..disc.len()]; if given_disc != disc { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -158,13 +158,13 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { return Err(ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &data[..8]; + let given_disc = &data[..disc.len()]; if given_disc != disc { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } Ok(Ref::map(data, |data| { - bytemuck::from_bytes(&data[8..mem::size_of::() + 8]) + bytemuck::from_bytes(&data[disc.len()..mem::size_of::() + disc.len()]) })) } @@ -182,13 +182,15 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { return Err(ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &data[..8]; + let given_disc = &data[..disc.len()]; if given_disc != disc { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } Ok(RefMut::map(data, |data| { - bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::() + 8]) + bytemuck::from_bytes_mut( + &mut data.deref_mut()[disc.len()..mem::size_of::() + disc.len()], + ) })) } @@ -204,15 +206,17 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { let data = self.acc_info.try_borrow_mut_data()?; // The discriminator should be zero, since we're initializing. - let mut disc_bytes = [0u8; 8]; - disc_bytes.copy_from_slice(&data[..8]); - let discriminator = u64::from_le_bytes(disc_bytes); - if discriminator != 0 { + let disc = T::DISCRIMINATOR; + let given_disc = &data[..disc.len()]; + let has_disc = given_disc.iter().any(|b| *b != 0); + if has_disc { return Err(ErrorCode::AccountDiscriminatorAlreadySet.into()); } Ok(RefMut::map(data, |data| { - bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::() + 8]) + bytemuck::from_bytes_mut( + &mut data.deref_mut()[disc.len()..mem::size_of::() + disc.len()], + ) })) } } diff --git a/lang/syn/src/codegen/program/idl.rs b/lang/syn/src/codegen/program/idl.rs index 757a2097e5..f45173c2f5 100644 --- a/lang/syn/src/codegen/program/idl.rs +++ b/lang/syn/src/codegen/program/idl.rs @@ -147,7 +147,10 @@ pub fn idl_accounts_and_functions() -> proc_macro2::TokenStream { let owner = accounts.program.key; let to = Pubkey::create_with_seed(&base, seed, owner).unwrap(); // Space: account discriminator || authority pubkey || vec len || vec data - let space = std::cmp::min(8 + 32 + 4 + data_len as usize, 10_000); + let space = std::cmp::min( + IdlAccount::DISCRIMINATOR.len() + 32 + 4 + data_len as usize, + 10_000 + ); let rent = Rent::get()?; let lamports = rent.minimum_balance(space); let seeds = &[&[nonce][..]];