diff --git a/protocol/app/app.go b/protocol/app/app.go index ca7f12c16d..33d43021ac 100644 --- a/protocol/app/app.go +++ b/protocol/app/app.go @@ -131,6 +131,7 @@ import ( // Modules accountplusmodule "github.com/dydxprotocol/v4-chain/protocol/x/accountplus" + "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/authenticator" accountplusmodulekeeper "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/keeper" accountplusmoduletypes "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/types" affiliatesmodule "github.com/dydxprotocol/v4-chain/protocol/x/affiliates" @@ -302,6 +303,7 @@ type App struct { ConsensusParamsKeeper consensusparamkeeper.Keeper GovPlusKeeper govplusmodulekeeper.Keeper AccountPlusKeeper accountplusmodulekeeper.Keeper + AuthenticatorManager *authenticator.AuthenticatorManager AffiliatesKeeper affiliatesmodulekeeper.Keeper MarketMapKeeper marketmapmodulekeeper.Keeper @@ -1210,9 +1212,15 @@ func New( app.PerpetualsKeeper, ) + // Initialize authenticators + app.AuthenticatorManager = authenticator.NewAuthenticatorManager() + app.AuthenticatorManager.InitializeAuthenticators([]authenticator.Authenticator{ + authenticator.NewSignatureVerification(app.AccountKeeper), + }) app.AccountPlusKeeper = *accountplusmodulekeeper.NewKeeper( appCodec, keys[accountplusmoduletypes.StoreKey], + app.AuthenticatorManager, ) accountplusModule := accountplusmodule.NewAppModule(appCodec, app.AccountPlusKeeper) diff --git a/protocol/testutil/ante/testutil.go b/protocol/testutil/ante/testutil.go index 6a930280e9..a245fe4e89 100644 --- a/protocol/testutil/ante/testutil.go +++ b/protocol/testutil/ante/testutil.go @@ -16,6 +16,7 @@ import ( txtestutil "github.com/cosmos/cosmos-sdk/x/auth/tx/testutil" "github.com/cosmos/cosmos-sdk/x/bank" v4module "github.com/dydxprotocol/v4-chain/protocol/app/module" + "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/authenticator" accountpluskeeper "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/keeper" accountplustypes "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/types" @@ -111,7 +112,11 @@ func SetupTestSuite(t testing.TB, isCheckTx bool) *AnteTestSuite { require.NoError(t, err) // Initialize accountplus keeper - suite.AccountplusKeeper = *accountpluskeeper.NewKeeper(suite.EncCfg.Codec, keys[accountplustypes.StoreKey]) + suite.AccountplusKeeper = *accountpluskeeper.NewKeeper( + suite.EncCfg.Codec, + keys[accountplustypes.StoreKey], + authenticator.NewAuthenticatorManager(), + ) // We're using TestMsg encoding in some tests, so register it here. suite.EncCfg.Amino.RegisterConcrete(&testdata.TestMsg{}, "testdata.TestMsg", nil) diff --git a/protocol/testutil/keeper/accountplus.go b/protocol/testutil/keeper/accountplus.go index 2bcf26cfc1..6b1ea4c6ab 100644 --- a/protocol/testutil/keeper/accountplus.go +++ b/protocol/testutil/keeper/accountplus.go @@ -11,6 +11,7 @@ import ( storetypes "cosmossdk.io/store/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/dydxprotocol/v4-chain/protocol/mocks" + "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/authenticator" keeper "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/keeper" ) @@ -52,7 +53,11 @@ func createTimestampNonceKeeper( stateStore.MountStoreWithDB(storeKey, storetypes.StoreTypeIAVL, db) mockTimeProvider := &mocks.TimeProvider{} - k := keeper.NewKeeper(cdc, storeKey) + k := keeper.NewKeeper( + cdc, + storeKey, + authenticator.NewAuthenticatorManager(), + ) return k, storeKey, mockTimeProvider } diff --git a/protocol/x/accountplus/authenticator/manager.go b/protocol/x/accountplus/authenticator/manager.go new file mode 100644 index 0000000000..d2ef241055 --- /dev/null +++ b/protocol/x/accountplus/authenticator/manager.go @@ -0,0 +1,81 @@ +package authenticator + +import "sort" + +// AuthenticatorManager is a manager for all registered authenticators. +type AuthenticatorManager struct { + registeredAuthenticators map[string]Authenticator + orderedKeys []string // slice to keep track of the keys in sorted order +} + +// NewAuthenticatorManager creates a new AuthenticatorManager. +func NewAuthenticatorManager() *AuthenticatorManager { + return &AuthenticatorManager{ + registeredAuthenticators: make(map[string]Authenticator), + orderedKeys: []string{}, + } +} + +// ResetAuthenticators resets all registered authenticators. +func (am *AuthenticatorManager) ResetAuthenticators() { + am.registeredAuthenticators = make(map[string]Authenticator) + am.orderedKeys = []string{} +} + +// InitializeAuthenticators initializes authenticators. If already initialized, it will not overwrite. +func (am *AuthenticatorManager) InitializeAuthenticators(initialAuthenticators []Authenticator) { + if len(am.registeredAuthenticators) > 0 { + return + } + for _, authenticator := range initialAuthenticators { + am.registeredAuthenticators[authenticator.Type()] = authenticator + am.orderedKeys = append(am.orderedKeys, authenticator.Type()) + } + sort.Strings(am.orderedKeys) // Ensure keys are sorted +} + +// RegisterAuthenticator adds a new authenticator to the map of registered authenticators. +func (am *AuthenticatorManager) RegisterAuthenticator(authenticator Authenticator) { + if _, exists := am.registeredAuthenticators[authenticator.Type()]; !exists { + am.orderedKeys = append(am.orderedKeys, authenticator.Type()) + sort.Strings(am.orderedKeys) // Re-sort keys after addition + } + am.registeredAuthenticators[authenticator.Type()] = authenticator +} + +// UnregisterAuthenticator removes an authenticator from the map of registered authenticators. +func (am *AuthenticatorManager) UnregisterAuthenticator(authenticator Authenticator) { + if _, exists := am.registeredAuthenticators[authenticator.Type()]; exists { + delete(am.registeredAuthenticators, authenticator.Type()) + // Remove the key from orderedKeys + for i, key := range am.orderedKeys { + if key == authenticator.Type() { + am.orderedKeys = append(am.orderedKeys[:i], am.orderedKeys[i+1:]...) + break + } + } + } +} + +// GetRegisteredAuthenticators returns the list of registered authenticators in sorted order. +func (am *AuthenticatorManager) GetRegisteredAuthenticators() []Authenticator { + var authenticators []Authenticator + for _, key := range am.orderedKeys { + authenticators = append(authenticators, am.registeredAuthenticators[key]) + } + return authenticators +} + +// IsAuthenticatorTypeRegistered checks if the authenticator type is registered. +func (am *AuthenticatorManager) IsAuthenticatorTypeRegistered(authenticatorType string) bool { + _, exists := am.registeredAuthenticators[authenticatorType] + return exists +} + +// GetAuthenticatorByType returns the base implementation of the authenticator type. +func (am *AuthenticatorManager) GetAuthenticatorByType(authenticatorType string) Authenticator { + if authenticator, exists := am.registeredAuthenticators[authenticatorType]; exists { + return authenticator + } + return nil +} diff --git a/protocol/x/accountplus/authenticator/manager_test.go b/protocol/x/accountplus/authenticator/manager_test.go new file mode 100644 index 0000000000..528c0c66a3 --- /dev/null +++ b/protocol/x/accountplus/authenticator/manager_test.go @@ -0,0 +1,172 @@ +package authenticator_test + +import ( + "fmt" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/authenticator" +) + +// Mock Authenticator for testing purposes +type MockAuthenticator struct { + authType string +} + +func (m MockAuthenticator) OnAuthenticatorRemoved( + ctx sdk.Context, + account sdk.AccAddress, + config []byte, + authenticatorId string, +) error { + return nil +} + +func (m MockAuthenticator) Track(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return nil +} + +func (m MockAuthenticator) Initialize(config []byte) (authenticator.Authenticator, error) { + return m, nil +} + +func (m MockAuthenticator) StaticGas() uint64 { + return 1000 +} + +func (m MockAuthenticator) Authenticate(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return nil +} + +func (m MockAuthenticator) ConfirmExecution(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return nil +} + +func (m MockAuthenticator) OnAuthenticatorAdded( + ctx sdk.Context, + account sdk.AccAddress, + config []byte, + authenticatorId string, +) error { + return nil +} + +func (m MockAuthenticator) Type() string { + return m.authType +} + +var _ authenticator.Authenticator = MockAuthenticator{} + +func TestInitializeAuthenticators(t *testing.T) { + am := authenticator.NewAuthenticatorManager() + auth1 := MockAuthenticator{"type1"} + auth2 := MockAuthenticator{"type2"} + + am.InitializeAuthenticators([]authenticator.Authenticator{auth1, auth2}) + + authenticators := am.GetRegisteredAuthenticators() + require.Equal(t, 2, len(authenticators)) + require.Contains(t, authenticators, auth1) + require.Contains(t, authenticators, auth2) +} + +func TestRegisterAuthenticator(t *testing.T) { + am := authenticator.NewAuthenticatorManager() + auth3 := MockAuthenticator{"type3"} + am.RegisterAuthenticator(auth3) + require.True(t, am.IsAuthenticatorTypeRegistered("type3")) +} + +func TestUnregisterAuthenticator(t *testing.T) { + am := authenticator.NewAuthenticatorManager() + auth2 := MockAuthenticator{"type2"} + am.RegisterAuthenticator(auth2) // Register first to ensure it's there + require.True(t, am.IsAuthenticatorTypeRegistered("type2")) + am.UnregisterAuthenticator(auth2) + require.False(t, am.IsAuthenticatorTypeRegistered("type2")) +} + +func TestGetRegisteredAuthenticators(t *testing.T) { + am := authenticator.NewAuthenticatorManager() + expectedAuthTypes := []string{"type1", "type3"} + unexpectedAuthTypes := []string{"type2"} + + authenticators := am.GetRegisteredAuthenticators() + + for _, auth := range authenticators { + authType := auth.Type() + require.Contains(t, expectedAuthTypes, authType) + require.NotContains(t, unexpectedAuthTypes, authType) + } +} + +// Second mock that always fails authentication +type MockAuthenticatorFail struct { + authType string +} + +func (m MockAuthenticatorFail) OnAuthenticatorRemoved( + ctx sdk.Context, + account sdk.AccAddress, + config []byte, + authenticatorId string, +) error { + return nil +} + +func (m MockAuthenticatorFail) OnAuthenticatorAdded( + ctx sdk.Context, + account sdk.AccAddress, + config []byte, + authenticatorId string, +) error { + return nil +} + +func (m MockAuthenticatorFail) Track(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return nil +} + +func (m MockAuthenticatorFail) Initialize(config []byte) (authenticator.Authenticator, error) { + return m, nil +} + +func (m MockAuthenticatorFail) StaticGas() uint64 { + return 1000 +} + +func (m MockAuthenticatorFail) Authenticate(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return fmt.Errorf("Authentication failed") +} + +func (m MockAuthenticatorFail) ConfirmExecution(ctx sdk.Context, request authenticator.AuthenticationRequest) error { + return nil +} + +func (m MockAuthenticatorFail) Type() string { + return m.authType +} + +// Ensure our mocks implement the Authenticator interface +var _ authenticator.Authenticator = MockAuthenticator{} +var _ authenticator.Authenticator = MockAuthenticatorFail{} + +// Tests for the mocks behavior +func TestMockAuthenticators(t *testing.T) { + // Create instances of our mocks + mockPass := MockAuthenticator{"type-pass"} + mockFail := MockAuthenticatorFail{"type-fail"} + + // You may need to mock sdk.Tx, sdk.Msg, and sdk.Context based on their implementations + var mockCtx sdk.Context + + // Testing mockPass + err := mockPass.Authenticate(mockCtx, authenticator.AuthenticationRequest{}) + require.NoError(t, err) + + // Testing mockFail + err = mockFail.Authenticate(mockCtx, authenticator.AuthenticationRequest{}) + require.Error(t, err) +} diff --git a/protocol/x/accountplus/keeper/keeper.go b/protocol/x/accountplus/keeper/keeper.go index 905b444971..e330826f02 100644 --- a/protocol/x/accountplus/keeper/keeper.go +++ b/protocol/x/accountplus/keeper/keeper.go @@ -9,18 +9,27 @@ import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/authenticator" "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/types" ) type Keeper struct { cdc codec.BinaryCodec storeKey storetypes.StoreKey + + authenticatorManager *authenticator.AuthenticatorManager } -func NewKeeper(cdc codec.BinaryCodec, key storetypes.StoreKey) *Keeper { +func NewKeeper( + cdc codec.BinaryCodec, + key storetypes.StoreKey, + authenticatorManager *authenticator.AuthenticatorManager, +) *Keeper { return &Keeper{ cdc: cdc, storeKey: key, + + authenticatorManager: authenticatorManager, } }