Skip to content

Commit

Permalink
test: add positions tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasmatt committed Jul 24, 2023
1 parent 5d77a51 commit 1c4d8fa
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
10 changes: 10 additions & 0 deletions x/perp/v2/types/msgs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ func TestMsgValidateBasic(t *testing.T) {
true,
"decoding bech32 failed",
},
{
"Test MsgAddMargin: Invalid pair",
&MsgAddMargin{
Sender: validSender,
Pair: invalidPair,
Margin: sdk.NewCoin("denom", sdk.NewInt(10)),
},
true,
"invalid base asset",
},
{
"Test MsgAddMargin: Negative margin",
&MsgAddMargin{
Expand Down
41 changes: 41 additions & 0 deletions x/perp/v2/types/position.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,44 @@ func (m *Position) Validate() error {

return nil
}

func (position *Position) WithTraderAddress(value string) *Position {
position.TraderAddress = value
return position
}
func (position *Position) WithPair(value asset.Pair) *Position {
position.Pair = value
return position
}
func (position *Position) WithSize_(value sdk.Dec) *Position {
position.Size_ = value
return position
}
func (position *Position) WithMargin(value sdk.Dec) *Position {
position.Margin = value
return position
}
func (position *Position) WithOpenNotional(value sdk.Dec) *Position {
position.OpenNotional = value
return position
}
func (position *Position) WithLatestCumulativePremiumFraction(value sdk.Dec) *Position {
position.LatestCumulativePremiumFraction = value
return position
}
func (position *Position) WithLastUpdatedBlockNumber(value int64) *Position {
position.LastUpdatedBlockNumber = value
return position
}

func (p *Position) copy() *Position {
return &Position{
TraderAddress: p.TraderAddress,
Pair: p.Pair,
Size_: p.Size_,
Margin: p.Margin,
OpenNotional: p.OpenNotional,
LatestCumulativePremiumFraction: p.LatestCumulativePremiumFraction,
LastUpdatedBlockNumber: p.LastUpdatedBlockNumber,
}
}
119 changes: 119 additions & 0 deletions x/perp/v2/types/position_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package types

import (
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/require"

"github.com/NibiruChain/nibiru/x/common/asset"
)

func TestZeroPosition(t *testing.T) {
// Initialization
ctx := sdk.Context{}
tokenPair := asset.NewPair("ubtc", "unusd")
traderAddr := sdk.AccAddress{}

position := ZeroPosition(ctx, tokenPair, traderAddr)

// Test the conditions
require.NotNil(t, position)
// Continue testing individual attributes of position as required
}

func TestPositionsAreEqual(t *testing.T) {
accAddress := "cosmos1zaavvzxez0elundtn32qnk9lkm8kmcszzsv80v"
accOtherAddress := "cosmos1g7vzqfthhf4l4vs6skyjj27vqhe97m5gp33hxy"

expected := Position{
TraderAddress: accAddress,
Pair: "ubtc:unusd",
Size_: sdk.OneDec(),
Margin: sdk.OneDec(),
OpenNotional: sdk.OneDec(),
LatestCumulativePremiumFraction: sdk.OneDec(),
LastUpdatedBlockNumber: 0,
}

err := PositionsAreEqual(&expected, expected.copy())
require.NoError(t, err)

testCases := []struct {
modifier func(*Position)
requiredError string
}{
{
modifier: func(p *Position) { p.WithPair(asset.NewPair("ueth", "unusd")) },
requiredError: "expected position pair"},
{
modifier: func(p *Position) { p.WithTraderAddress(accOtherAddress) },
requiredError: "expected position trader address",
},
{
modifier: func(p *Position) { p.WithMargin(sdk.NewDec(42)) },
requiredError: "expected position margin",
},
{
modifier: func(p *Position) { p.WithOpenNotional(sdk.NewDec(42)) },
requiredError: "expected position open notional",
},
{
modifier: func(p *Position) { p.WithSize_(sdk.NewDec(42)) },
requiredError: "expected position size",
},
{
modifier: func(p *Position) { p.WithLastUpdatedBlockNumber(42) },
requiredError: "expected position block number",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.requiredError, func(t *testing.T) {
newPosition := expected.copy()

tc.modifier(newPosition)

err := PositionsAreEqual(&expected, newPosition)
require.Error(t, err)
require.Contains(t, err.Error(), tc.requiredError)
})
}
}

func TestPositionValidate(t *testing.T) {
tests := []struct {
name string
setupPosition func() *Position
expectHasError bool
}{
{
name: "valid position",
setupPosition: func() *Position {
position := &Position{} // fill the Position structure as necessary
return position
},
expectHasError: false,
},
{
name: "invalid position - invalid address",
setupPosition: func() *Position {
position := &Position{} // fill the Position structure with invalid address
return position
},
expectHasError: true,
},
// Continue with similar test cases for other fields of Position.
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.setupPosition().Validate()
if tt.expectHasError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

0 comments on commit 1c4d8fa

Please sign in to comment.