Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contract state is no longer required to implement StateClone for unit testing #321

Merged
merged 9 commits into from
Aug 21, 2023
3 changes: 3 additions & 0 deletions .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ jobs:
# Run all tests, including doc tests.
- name: Run cargo test
run: |
TEMPLATE_DIR=`pwd`
mv $PROJECT_NAME ${{ runner.temp }}/
sed -i "s/{version = \"7.0\", default-features = false}/{path = \"${TEMPLATE_DIR//\//\\\/}\/concordium-std\", default-features = false}/g" ${{ runner.temp }}/$PROJECT_NAME/Cargo.toml
sed -i "s/{version = \"4.0\", default-features = false}/{path = \"${TEMPLATE_DIR//\//\\\/}\/concordium-cis2\", default-features = false}/g" ${{ runner.temp }}/$PROJECT_NAME/Cargo.toml
abizjak marked this conversation as resolved.
Show resolved Hide resolved
cd ${{ runner.temp }}/$PROJECT_NAME
cargo test

Expand Down
2 changes: 2 additions & 0 deletions concordium-std/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
and `check_account_signature` corresponding to the two new host functions
available in protocol 6. Two new types were added to support these operations,
`AccountSignatures` and `AccountPublicKeys`.
- Contract state no longer needs to implement `StateClone` trait in order to work with test infrastructure.
`StateClone` itself is completely removed

## concordium-std 7.0.0 (2023-06-16)

Expand Down
64 changes: 0 additions & 64 deletions concordium-std/src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3016,70 +3016,6 @@ impl schema::SchemaType for HashKeccak256 {
fn get_type() -> concordium_contracts_common::schema::Type { schema::Type::ByteArray(32) }
}

unsafe impl<T, S: HasStateApi> StateClone<S> for StateSet<T, S> {
unsafe fn clone_state(&self, cloned_state_api: &S) -> Self {
Self {
_marker: self._marker,
prefix: self.prefix,
state_api: cloned_state_api.clone(),
}
}
}

unsafe impl<T, V, S: HasStateApi> StateClone<S> for StateMap<T, V, S> {
unsafe fn clone_state(&self, cloned_state_api: &S) -> Self {
Self {
_marker_key: self._marker_key,
_marker_value: self._marker_value,
prefix: self.prefix,
state_api: cloned_state_api.clone(),
}
}
}

unsafe impl<T: DeserialWithState<S> + Serial, S: HasStateApi> StateClone<S> for StateBox<T, S> {
unsafe fn clone_state(&self, cloned_state_api: &S) -> Self {
let inner_value = match &*self.inner.get() {
StateBoxInner::Loaded {
entry,
modified,
value: _,
} => {
// Get a new entry from the cloned state.
let mut new_entry = cloned_state_api.lookup_entry(entry.get_key()).unwrap_abort();
let new_value =
T::deserial_with_state(cloned_state_api, &mut new_entry).unwrap_abort();

// Set position of new entry to match the old entry.
let old_entry_position = entry.cursor_position();
new_entry.seek(SeekFrom::Start(old_entry_position)).unwrap_abort();

StateBoxInner::Loaded {
entry: new_entry,
modified: *modified,
value: new_value,
}
}
StateBoxInner::Reference {
prefix,
} => StateBoxInner::Reference {
prefix: *prefix,
},
};

Self {
state_api: cloned_state_api.clone(),
inner: UnsafeCell::new(inner_value),
}
}
}

/// Blanket implementation for all cloneable, flat types that don't have
/// references to items in the state.
unsafe impl<T: Clone, S> StateClone<S> for T {
unsafe fn clone_state(&self, _cloned_state_api: &S) -> Self { self.clone() }
}

impl schema::SchemaType for MetadataUrl {
fn get_type() -> schema::Type {
schema::Type::Struct(schema::Fields::Named(crate::vec![
Expand Down
94 changes: 20 additions & 74 deletions concordium-std/src/test_infrastructure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,9 +1316,7 @@ pub struct TestHost<State> {
missing_contracts: BTreeSet<ContractAddress>,
}

impl<State: Serial + DeserialWithState<TestStateApi> + StateClone<TestStateApi>> HasHost<State>
for TestHost<State>
{
impl<State: Serial + DeserialWithState<TestStateApi>> HasHost<State> for TestHost<State> {
type ReturnValueType = Cursor<Vec<u8>>;
type StateApiType = TestStateApi;

Expand Down Expand Up @@ -1800,27 +1798,30 @@ impl<State: Serial + DeserialWithState<TestStateApi>> TestHost<State> {
}
}

impl<State: StateClone<TestStateApi>> TestHost<State> {
impl<State: DeserialWithState<TestStateApi>> TestHost<State> {
/// Make a deep clone of the host, including the whole state and all
/// references to the state. Used for rolling back the host and state,
/// fx when using [`with_rollback`].
fn checkpoint(&self) -> Self {
let cloned_state_api = self.state_builder.state_api.clone_deep();
let state: State = cloned_state_api
.read_root()
.expect_report("Could not deserialize root entry from state clone");
Self {
mocking_fns: self.mocking_fns.clone(),
transfers: self.transfers.clone(),
contract_balance: self.contract_balance.clone(),
contract_address: self.contract_address,
mocking_upgrades: self.mocking_upgrades.clone(),
state_builder: StateBuilder {
state_api: cloned_state_api.clone(),
mocking_fns: self.mocking_fns.clone(),
transfers: self.transfers.clone(),
contract_balance: self.contract_balance.clone(),
contract_address: self.contract_address,
mocking_upgrades: self.mocking_upgrades.clone(),
state_builder: StateBuilder {
state_api: cloned_state_api,
},
state: unsafe { self.state.clone_state(&cloned_state_api) },
missing_accounts: self.missing_accounts.clone(),
missing_contracts: self.missing_contracts.clone(),
query_account_balances: self.query_account_balances.clone(),
state,
missing_accounts: self.missing_accounts.clone(),
missing_contracts: self.missing_contracts.clone(),
query_account_balances: self.query_account_balances.clone(),
query_contract_balances: self.query_contract_balances.clone(),
query_exchange_rates: self.query_exchange_rates,
query_exchange_rates: self.query_exchange_rates,
}
}

Expand Down Expand Up @@ -1914,10 +1915,10 @@ mod test {
cell::RefCell,
rc::Rc,
test_infrastructure::{TestStateBuilder, TestStateEntry},
Deletable, DeserialWithState, EntryRaw, HasStateApi, HasStateEntry, StateBox, StateClone,
StateMap, StateSet, INITIAL_NEXT_ITEM_PREFIX,
Deletable, EntryRaw, HasStateApi, HasStateEntry, StateMap, StateSet,
INITIAL_NEXT_ITEM_PREFIX,
};
use concordium_contracts_common::{to_bytes, Cursor, Deserial, Read, Seek, SeekFrom, Write};
use concordium_contracts_common::{to_bytes, Deserial, Read, Seek, SeekFrom, Write};

#[test]
fn test_testhost_balance_queries_reflect_transfers() {
Expand Down Expand Up @@ -2477,59 +2478,4 @@ mod test {
state.lookup_entry(&[]).expect("Lookup failed").size().expect("Getting size failed");
assert_eq!(expected_size as u32, actual_size);
}

#[test]
/// Test that deep cloning a statebox, in both of its internal forms, work
/// as expected.
fn deep_cloning_state_box() {
let state = TestStateApi::new();
let mut state_builder = TestStateBuilder::open(state.clone());

// Helper function.
fn get_loaded_entry_cursor_pos<T: concordium_contracts_common::Serial>(
b: &StateBox<T, TestStateApi>,
) -> u32 {
match unsafe { &*b.inner.get() } {
crate::StateBoxInner::Loaded {
entry,
modified: _,
value: _,
} => entry.cursor_position(),
crate::StateBoxInner::Reference {
prefix: _,
} => panic!("Cannot be called on StateBoxInner::Reference"),
}
}

// These boxes have InnerBox::Loaded
let b1_loaded = state_builder.new_box(101010);
let mut b2_loaded = state_builder.new_box(b1_loaded);

let b2_bytes = to_bytes(&b2_loaded);

let b2_loaded_cursor_pos = get_loaded_entry_cursor_pos(&b2_loaded);

// Deserial to get boxes with InnerBox::Reference
let mut b2_ref = StateBox::<StateBox<i32, _>, _>::deserial_with_state(
&state,
&mut Cursor::new(b2_bytes),
)
.unwrap();

// Make clones
let state_clone = state.clone_deep();
let b2_loaded_clone = unsafe { b2_loaded.clone_state(&state_clone) };
let b2_ref_clone = unsafe { b2_ref.clone_state(&state_clone) };

// Modify originals
b2_loaded.update(|b| b.update(|x| *x += 1)); // 101011
b2_ref.update(|b| b.update(|x| *x += 1)); // 101012

// Check that clones are unchanged
let b2_loaded_clone_cursor_pos = get_loaded_entry_cursor_pos(&b2_loaded_clone);

assert_eq!(*b2_loaded_clone.get().get(), 101010);
assert_eq!(*b2_ref_clone.get().get(), 101010);
assert_eq!(b2_loaded_clone_cursor_pos, b2_loaded_cursor_pos);
}
}
27 changes: 0 additions & 27 deletions concordium-std/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,30 +613,3 @@ where
source: &mut R,
) -> ParseResult<Self>;
}

/// Types that can be cloned along with the state.
///
/// Used for rolling back the test state when errors occur in a receive
/// function. See [`TestHost::with_rollback`][iwr] and
/// [`TestHost::invoke_contract_raw`][icr].
///
/// # Safety
///
/// Marked unsafe because special care should be taken when
/// implementing this trait. In particular, one should only use the supplied
/// `cloned_state_api`, or (shallow) clones thereof. Creating a new
/// [`HasStateApi`] or using a `deep_clone` will lead to an inconsistent state
/// and undefined behaviour.
///
/// [icr]: crate::test_infrastructure::TestHost::invoke_contract_raw
/// [iwr]: crate::test_infrastructure::TestHost::with_rollback
pub unsafe trait StateClone<S> {
/// Make a clone of the type while using the `cloned_state_api`.
///
/// # Safety
///
/// Marked unsafe because this function *should not* be called
/// directly. It is only used within generated code and in the test
/// infrastructure.
unsafe fn clone_state(&self, cloned_state_api: &S) -> Self;
}
2 changes: 1 addition & 1 deletion concordium-std/tests/derive-serial/success-simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! simplest case when `#[concordium(state_parameter)]` is set.
use concordium_std::*;

#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct State<S> {
map: StateMap<u8, u16, S>,
Expand Down
4 changes: 2 additions & 2 deletions examples/cis2-multi-royalties/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct CheckRoyaltyResult {
}

/// The state for each address.
#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct AddressState<S> {
/// The amount of tokens owned by this address.
Expand Down Expand Up @@ -138,7 +138,7 @@ impl<S: HasStateApi> AddressState<S> {
///
/// Note: The specification does not specify how to structure the contract state
/// and this could be structured in a more space efficient way.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S> {
/// The state of addresses.
Expand Down
4 changes: 2 additions & 2 deletions examples/cis2-multi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct SetImplementorsParams {
}

/// The state for each address.
#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct AddressState<S> {
/// The amount of tokens owned by this address.
Expand All @@ -93,7 +93,7 @@ impl<S: HasStateApi> AddressState<S> {
///
/// Note: The specification does not specify how to structure the contract state
/// and this could be structured in a more space efficient way.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S> {
/// The state of addresses.
Expand Down
4 changes: 2 additions & 2 deletions examples/cis2-nft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct MintParams {
}

/// The state for each address.
#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct AddressState<S> {
/// The tokens owned by this address.
Expand All @@ -76,7 +76,7 @@ impl<S: HasStateApi> AddressState<S> {
/// The contract state.
// Note: The specification does not specify how to structure the contract state
// and this could be structured in a more space efficient way depending on the use case.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S> {
/// The state for each address.
Expand Down
4 changes: 2 additions & 2 deletions examples/cis2-wccd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type ContractTokenId = TokenIdUnit;
type ContractTokenAmount = TokenAmountU64;

/// The state tracked for each address.
#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct AddressState<S> {
/// The number of tokens owned by this address.
Expand All @@ -74,7 +74,7 @@ struct AddressState<S> {
}

/// The contract state,
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S: HasStateApi> {
/// The admin address can upgrade the contract, pause and unpause the
Expand Down
4 changes: 2 additions & 2 deletions examples/cis3-nft-sponsored-txs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ struct MintParams {
}

/// The state for each address.
#[derive(Serial, DeserialWithState, Deletable, StateClone)]
#[derive(Serial, DeserialWithState, Deletable)]
#[concordium(state_parameter = "S")]
struct AddressState<S> {
/// The tokens owned by this address.
Expand All @@ -241,7 +241,7 @@ impl<S: HasStateApi> AddressState<S> {
/// The contract state.
// Note: The specification does not specify how to structure the contract state
// and this could be structured in a more space efficient way depending on the use case.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S> {
/// Counter to increase the `token_id` at every mint function invoke.
Expand Down
2 changes: 1 addition & 1 deletion examples/credential-registry-storage-contract/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct CredentialState {
}

/// The contract state.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S: HasStateApi> {
/// All verifiable credentials.
Expand Down
2 changes: 1 addition & 1 deletion examples/credential-registry/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<S: HasStateApi> CredentialEntry<S> {
}

/// The registry state.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
pub struct State<S: HasStateApi> {
/// An account address of the issuer. It is used for authorization in
Expand Down
2 changes: 1 addition & 1 deletion examples/eSealing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct FileState {
}

/// The contract state.
#[derive(Serial, DeserialWithState, StateClone)]
#[derive(Serial, DeserialWithState)]
#[concordium(state_parameter = "S")]
struct State<S> {
files: StateMap<HashSha2256, FileState, S>,
Expand Down
Loading
Loading