From 189d0e756accb8716739e145ef5dc8f0a2eb4c2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20C=2E=20Morency?= <1102868+fmorency@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:12:11 -0400 Subject: [PATCH] test!: nested msg tests (#207) * fix!: error message typo * fix!: refactor error propagation * test(ante): nested msg tests --- ante/ante_test.go | 106 +++++++++++++++++++++ ante/disable_staking.go | 28 +++--- ante/disable_withdraw_delegator_rewards.go | 20 ++-- errors.go | 2 +- 4 files changed, 132 insertions(+), 24 deletions(-) diff --git a/ante/ante_test.go b/ante/ante_test.go index be3d543..0b4bf75 100644 --- a/ante/ante_test.go +++ b/ante/ante_test.go @@ -7,7 +7,11 @@ import ( "github.com/stretchr/testify/require" protov2 "google.golang.org/protobuf/proto" + "github.com/cosmos/gogoproto/proto" + + "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/authz" distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" @@ -114,6 +118,108 @@ func TestAnteCommissionRanges(t *testing.T) { } } +func TestAnteNested(t *testing.T) { + ctx := sdk.Context{} + ctx = setBlockHeader(ctx, 2) + + const invalidRequestErr = "messages contains *types.Any which is not a sdk.MsgRequest" + cases := []struct { + name string + decorator sdk.AnteDecorator + msg proto.Message + err string + }{ + { + name: "fail: commission nested rate < floor", + decorator: NewCommissionLimitDecorator(true, math.LegacyMustNewDecFromStr("0.10"), math.LegacyMustNewDecFromStr("0.50")), + msg: &poa.MsgCreateValidator{ + Commission: poa.CommissionRates{ + Rate: math.LegacyMustNewDecFromStr("0.09"), + }, + }, + err: "rate 0.090000000000000000 is not between 0.100000000000000000 and 0.500000000000000000", + }, + { + name: "fail: commission nested rate > ceil", + decorator: NewCommissionLimitDecorator(true, math.LegacyMustNewDecFromStr("0.10"), math.LegacyMustNewDecFromStr("0.50")), + msg: &poa.MsgCreateValidator{ + Commission: poa.CommissionRates{ + Rate: math.LegacyMustNewDecFromStr("0.51"), + }, + }, + err: "rate 0.510000000000000000 is not between 0.100000000000000000 and 0.500000000000000000", + }, + { + name: "fail: commission nested rate != ceil and floor", + decorator: NewCommissionLimitDecorator(true, math.LegacyMustNewDecFromStr("0.14"), math.LegacyMustNewDecFromStr("0.14")), + msg: &poa.MsgCreateValidator{ + Commission: poa.CommissionRates{ + Rate: math.LegacyMustNewDecFromStr("0.1"), + }, + }, + err: "rate 0.100000000000000000 is not equal to 0.140000000000000000", + }, + { + name: "failed: commission rate msg is nil", + decorator: NewCommissionLimitDecorator(true, math.LegacyMustNewDecFromStr("0.10"), math.LegacyMustNewDecFromStr("0.50")), + msg: nil, + err: invalidRequestErr, + }, + { + name: "failed: staking action not allowed", + decorator: NewPOADisableStakingDecorator(), + msg: &stakingtypes.MsgCreateValidator{}, + err: poa.ErrStakingActionNotAllowed.Error(), + }, + { + name: "failed: staking filter nil msg", + decorator: NewPOADisableStakingDecorator(), + msg: nil, + err: invalidRequestErr, + }, + { + name: "failed: withdraw rewards not allowed", + decorator: NewPOADisableWithdrawDelegatorRewards(), + msg: &distrtypes.MsgWithdrawDelegatorReward{}, + err: poa.ErrWithdrawDelegatorRewardsNotAllowed.Error(), + }, + { + name: "failed: withdraw rewards nil msg", + decorator: NewPOADisableWithdrawDelegatorRewards(), + msg: nil, + err: invalidRequestErr, + }, + } + + for _, tc := range cases { + tc := tc + + var anyMsg *types.Any + var err error + if tc.msg != nil { + anyMsg, err = types.NewAnyWithValue(tc.msg) + require.NoError(t, err) + } else { + anyMsg = &types.Any{ + TypeUrl: "", + Value: nil, + } + } + + nestedTx := NewMockTx(&authz.MsgExec{ + Grantee: "", + Msgs: []*types.Any{anyMsg}, + }) + + _, err = tc.decorator.AnteHandle(ctx, nestedTx, false, EmptyAnte) + require.Error(t, err) + + if tc.err != "" { + require.ErrorContains(t, err, tc.err) + } + } +} + func TestAnteStakingFilter(t *testing.T) { ctx := sdk.Context{} sf := NewPOADisableStakingDecorator() diff --git a/ante/disable_staking.go b/ante/disable_staking.go index f60fa54..bd782d7 100644 --- a/ante/disable_staking.go +++ b/ante/disable_staking.go @@ -16,49 +16,49 @@ func NewPOADisableStakingDecorator() MsgStakingFilterDecorator { } // AnteHandle performs an AnteHandler check that returns an error if the tx contains a message that is blocked. -func (msfd MsgStakingFilterDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { +func (msfd MsgStakingFilterDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { currHeight := ctx.BlockHeight() if currHeight <= 1 { // allow GenTx to pass return next(ctx, tx, simulate) } - if msfd.hasInvalidStakingMsg(tx.GetMsgs()) { - return ctx, poa.ErrStakingActionNotAllowed + err := msfd.hasInvalidStakingMsg(tx.GetMsgs()) + if err != nil { + return ctx, err } return next(ctx, tx, simulate) } -func (msfd MsgStakingFilterDecorator) hasInvalidStakingMsg(msgs []sdk.Msg) bool { +func (msfd MsgStakingFilterDecorator) hasInvalidStakingMsg(msgs []sdk.Msg) error { for _, msg := range msgs { // authz nested message check (recursive) if execMsg, ok := msg.(*authz.MsgExec); ok { msgs, err := execMsg.GetMessages() if err != nil { - return true + return err } - if msfd.hasInvalidStakingMsg(msgs) { - return true + err = msfd.hasInvalidStakingMsg(msgs) + if err != nil { + return err } } switch msg.(type) { // POA wrapped messages - case *stakingtypes.MsgCreateValidator, *stakingtypes.MsgUpdateParams: - return true - - // Blocked entirely when POA is enabled - case *stakingtypes.MsgBeginRedelegate, + case *stakingtypes.MsgCreateValidator, *stakingtypes.MsgUpdateParams, + // Blocked entirely when POA is enabled + *stakingtypes.MsgBeginRedelegate, *stakingtypes.MsgCancelUnbondingDelegation, *stakingtypes.MsgDelegate, *stakingtypes.MsgUndelegate: - return true + return poa.ErrStakingActionNotAllowed } // stakingtypes.MsgEditValidator is the only allowed message. We do not need to check for it. } - return false + return nil } diff --git a/ante/disable_withdraw_delegator_rewards.go b/ante/disable_withdraw_delegator_rewards.go index 4cf3ea1..16fecb7 100644 --- a/ante/disable_withdraw_delegator_rewards.go +++ b/ante/disable_withdraw_delegator_rewards.go @@ -15,37 +15,39 @@ func NewPOADisableWithdrawDelegatorRewards() MsgDisableWithdrawDelegatorRewards return MsgDisableWithdrawDelegatorRewards{} } -func (mdwr MsgDisableWithdrawDelegatorRewards) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { +func (mdwr MsgDisableWithdrawDelegatorRewards) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { currHeight := ctx.BlockHeight() if currHeight <= 1 { // allow GenTx to pass return next(ctx, tx, simulate) } - if mdwr.hasWithdrawDelegatorRewardsMsg(tx.GetMsgs()) { - return ctx, poa.ErrWithdrawDelegatorRewardsNotAllowed + err := mdwr.hasWithdrawDelegatorRewardsMsg(tx.GetMsgs()) + if err != nil { + return ctx, err } return next(ctx, tx, simulate) } -func (mdwr MsgDisableWithdrawDelegatorRewards) hasWithdrawDelegatorRewardsMsg(msgs []sdk.Msg) bool { +func (mdwr MsgDisableWithdrawDelegatorRewards) hasWithdrawDelegatorRewardsMsg(msgs []sdk.Msg) error { for _, msg := range msgs { // authz nested message check (recursive) if execMsg, ok := msg.(*authz.MsgExec); ok { msgs, err := execMsg.GetMessages() if err != nil { - return true + return err } - if mdwr.hasWithdrawDelegatorRewardsMsg(msgs) { - return true + err = mdwr.hasWithdrawDelegatorRewardsMsg(msgs) + if err != nil { + return err } } if _, ok := msg.(*distrtypes.MsgWithdrawDelegatorReward); ok { - return true + return poa.ErrWithdrawDelegatorRewardsNotAllowed } } - return false + return nil } diff --git a/errors.go b/errors.go index f8b5129..50f7103 100644 --- a/errors.go +++ b/errors.go @@ -5,7 +5,7 @@ import ( ) var ( - ErrStakingActionNotAllowed = sdkerrors.Register(ModuleName, 1, "staking actions are now allowed on this chain") + ErrStakingActionNotAllowed = sdkerrors.Register(ModuleName, 1, "staking actions are not allowed on this chain") ErrPowerBelowMinimum = sdkerrors.Register(ModuleName, 2, "power must be above 1_000_000") ErrNotAnAuthority = sdkerrors.Register(ModuleName, 3, "not an authority") ErrUnsafePower = sdkerrors.Register(ModuleName, 4, "unsafe: msg.Power is >=30% of total power, set unsafe=true to override")