diff --git a/backend/flow_api/flow/credential_usage/action_password_login.go b/backend/flow_api/flow/credential_usage/action_password_login.go index 69a9f3c0e..d6964afcd 100644 --- a/backend/flow_api/flow/credential_usage/action_password_login.go +++ b/backend/flow_api/flow/credential_usage/action_password_login.go @@ -86,7 +86,7 @@ func (a PasswordLogin) Execute(c flowpilot.ExecutionContext) error { return a.wrongCredentialsError(c) } - err := deps.PasswordService.VerifyPassword(userID, c.Input().Get("password").String()) + err := deps.PasswordService.VerifyPassword(deps.Tx, userID, c.Input().Get("password").String()) if err != nil { if errors.Is(err, services.ErrorPasswordInvalid) { err = deps.AuditLogger.CreateWithConnection( diff --git a/backend/flow_api/flow/credential_usage/action_password_recovery.go b/backend/flow_api/flow/credential_usage/action_password_recovery.go index 402dd42e6..10f9259d3 100644 --- a/backend/flow_api/flow/credential_usage/action_password_recovery.go +++ b/backend/flow_api/flow/credential_usage/action_password_recovery.go @@ -47,7 +47,7 @@ func (a PasswordRecovery) Execute(c flowpilot.ExecutionContext) error { authUserID := c.Stash().Get(shared.StashPathUserID).String() - err := deps.PasswordService.RecoverPassword(uuid.FromStringOrNil(authUserID), newPassword) + err := deps.PasswordService.RecoverPassword(deps.Tx, uuid.FromStringOrNil(authUserID), newPassword) if err != nil { if errors.Is(err, services.ErrorPasswordInvalid) { diff --git a/backend/flow_api/flow/profile/action_password_create.go b/backend/flow_api/flow/profile/action_password_create.go index 6fd09a556..a264e0753 100644 --- a/backend/flow_api/flow/profile/action_password_create.go +++ b/backend/flow_api/flow/profile/action_password_create.go @@ -58,7 +58,7 @@ func (a PasswordCreate) Execute(c flowpilot.ExecutionContext) error { passwordCredential := models.NewPasswordCredential(userModel.ID, password) // ? - err := deps.PasswordService.CreatePassword(userModel.ID, password) // ? + err := deps.PasswordService.CreatePassword(deps.Tx, userModel.ID, password) // ? if err != nil { return fmt.Errorf("could not set password: %w", err) } diff --git a/backend/flow_api/flow/profile/action_password_update.go b/backend/flow_api/flow/profile/action_password_update.go index f223e04fe..9dcd5773f 100644 --- a/backend/flow_api/flow/profile/action_password_update.go +++ b/backend/flow_api/flow/profile/action_password_update.go @@ -54,7 +54,7 @@ func (a PasswordUpdate) Execute(c flowpilot.ExecutionContext) error { password := c.Input().Get("password").String() - err := deps.PasswordService.UpdatePassword(userModel.PasswordCredential, password) + err := deps.PasswordService.UpdatePassword(deps.Tx, userModel.PasswordCredential, password) if err != nil { return fmt.Errorf("could not udate password: %w", err) } diff --git a/backend/flow_api/services/password.go b/backend/flow_api/services/password.go index ff2a28806..3e668c7d5 100644 --- a/backend/flow_api/services/password.go +++ b/backend/flow_api/services/password.go @@ -3,6 +3,7 @@ package services import ( "errors" "fmt" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/persistence" @@ -16,10 +17,10 @@ var ( ) type Password interface { - VerifyPassword(userId uuid.UUID, password string) error - RecoverPassword(userId uuid.UUID, newPassword string) error - CreatePassword(userId uuid.UUID, newPassword string) error - UpdatePassword(passwordCredentialModel *models.PasswordCredential, newPassword string) error + VerifyPassword(tx *pop.Connection, userId uuid.UUID, password string) error + RecoverPassword(tx *pop.Connection, userId uuid.UUID, newPassword string) error + CreatePassword(tx *pop.Connection, userId uuid.UUID, newPassword string) error + UpdatePassword(tx *pop.Connection, passwordCredentialModel *models.PasswordCredential, newPassword string) error } type password struct { @@ -34,8 +35,8 @@ func NewPasswordService(cfg config.Config, persister persistence.Persister) Pass } } -func (s password) VerifyPassword(userId uuid.UUID, password string) error { - user, err := s.persister.GetUserPersister().Get(userId) +func (s password) VerifyPassword(tx *pop.Connection, userId uuid.UUID, password string) error { + user, err := s.persister.GetUserPersisterWithConnection(tx).Get(userId) if err != nil { return fmt.Errorf("failed to get user: %w", err) } @@ -44,7 +45,7 @@ func (s password) VerifyPassword(userId uuid.UUID, password string) error { return ErrorPasswordInvalid } - pw, err := s.persister.GetPasswordCredentialPersister().GetByUserID(userId) + pw, err := s.persister.GetPasswordCredentialPersisterWithConnection(tx).GetByUserID(userId) if err != nil { return fmt.Errorf("error retrieving password credential: %w", err) } @@ -60,8 +61,8 @@ func (s password) VerifyPassword(userId uuid.UUID, password string) error { return nil } -func (s password) RecoverPassword(userId uuid.UUID, newPassword string) error { - passwordPersister := s.persister.GetPasswordCredentialPersister() +func (s password) RecoverPassword(tx *pop.Connection, userId uuid.UUID, newPassword string) error { + passwordPersister := s.persister.GetPasswordCredentialPersisterWithConnection(tx) passwordCredentialModel, err := passwordPersister.GetByUserID(userId) if err != nil { @@ -69,9 +70,9 @@ func (s password) RecoverPassword(userId uuid.UUID, newPassword string) error { } if passwordCredentialModel == nil { - err = s.CreatePassword(userId, newPassword) + err = s.CreatePassword(tx, userId, newPassword) } else { - err = s.UpdatePassword(passwordCredentialModel, newPassword) + err = s.UpdatePassword(tx, passwordCredentialModel, newPassword) } if err != nil { @@ -81,7 +82,7 @@ func (s password) RecoverPassword(userId uuid.UUID, newPassword string) error { return nil } -func (s password) CreatePassword(userId uuid.UUID, newPassword string) error { +func (s password) CreatePassword(tx *pop.Connection, userId uuid.UUID, newPassword string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) if err != nil { return ErrorPasswordInvalid @@ -89,7 +90,7 @@ func (s password) CreatePassword(userId uuid.UUID, newPassword string) error { passwordCredentialModel := models.NewPasswordCredential(userId, string(hashedPassword)) - err = s.persister.GetPasswordCredentialPersister().Create(*passwordCredentialModel) + err = s.persister.GetPasswordCredentialPersisterWithConnection(tx).Create(*passwordCredentialModel) if err != nil { return fmt.Errorf("failed to set password: %w", err) } @@ -97,7 +98,7 @@ func (s password) CreatePassword(userId uuid.UUID, newPassword string) error { return nil } -func (s password) UpdatePassword(passwordCredentialModel *models.PasswordCredential, newPassword string) error { +func (s password) UpdatePassword(tx *pop.Connection, passwordCredentialModel *models.PasswordCredential, newPassword string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) if err != nil { return ErrorPasswordInvalid @@ -106,7 +107,7 @@ func (s password) UpdatePassword(passwordCredentialModel *models.PasswordCredent passwordCredentialModel.Password = string(hashedPassword) passwordCredentialModel.UpdatedAt = time.Now().UTC() - err = s.persister.GetPasswordCredentialPersister().Update(*passwordCredentialModel) + err = s.persister.GetPasswordCredentialPersisterWithConnection(tx).Update(*passwordCredentialModel) if err != nil { return fmt.Errorf("failed to update password: %w", err) }