diff --git a/api/organizations.go b/api/organizations.go index d5b8b38..89b8c7e 100644 --- a/api/organizations.go +++ b/api/organizations.go @@ -351,7 +351,7 @@ func (a *API) acceptOrganizationMemberInvitationHandler(w http.ResponseWriter, r // create a helper function to remove the invitation from the database in // case of error or expiration removeInvitation := func() { - if err := a.db.DeclineInvitation(invitationReq.Code); err != nil { + if err := a.db.DeleteInvitation(invitationReq.Code); err != nil { log.Warnf("could not delete invitation: %v", err) } } @@ -383,7 +383,6 @@ func (a *API) acceptOrganizationMemberInvitationHandler(w http.ResponseWriter, r hPassword := internal.HexHashPassword(passwordSalt, invitationReq.User.Password) dbUser = &db.User{ Email: invitationReq.User.Email, - Phone: invitationReq.User.Phone, Password: hPassword, FirstName: invitationReq.User.FirstName, LastName: invitationReq.User.LastName, diff --git a/api/types.go b/api/types.go index 57e9614..8047077 100644 --- a/api/types.go +++ b/api/types.go @@ -56,7 +56,6 @@ type UserOrganization struct { // UserInfo is the request to register a new user. type UserInfo struct { Email string `json:"email,omitempty"` - Phone string `json:"phone,omitempty"` Password string `json:"password,omitempty"` FirstName string `json:"firstName,omitempty"` LastName string `json:"lastName,omitempty"` @@ -88,7 +87,6 @@ type UserPasswordUpdate struct { type UserVerification struct { Email string `json:"email,omitempty"` Code string `json:"code,omitempty"` - Phone string `json:"phone,omitempty"` Expiration time.Time `json:"expiration,omitempty"` Valid bool `json:"valid"` } diff --git a/api/users.go b/api/users.go index 853505c..13078f7 100644 --- a/api/users.go +++ b/api/users.go @@ -50,14 +50,6 @@ func (a *API) sendUserCode(ctx context.Context, user *db.User, t db.CodeType) er }); err != nil { return err } - } else if a.sms != nil { - // send the verification code via SMS if the SMS service is available - if err := a.sms.SendNotification(ctx, ¬ifications.Notification{ - ToNumber: user.Phone, - Body: VerificationCodeTextBody + code, - }); err != nil { - return err - } } return nil } @@ -94,11 +86,6 @@ func (a *API) registerHandler(w http.ResponseWriter, r *http.Request) { ErrMalformedBody.Withf("last name is empty").Write(w) return } - // check the phone is not empty - if userInfo.Phone == "" { - ErrMalformedBody.Withf("phone is empty").Write(w) - return - } // hash the password hPassword := internal.HexHashPassword(passwordSalt, userInfo.Password) // add the user to the database @@ -151,10 +138,8 @@ func (a *API) verifyUserAccountHandler(w http.ResponseWriter, r *http.Request) { // check the email and verification code are not empty if (a.mail != nil || a.sms != nil) && - (verification.Code == "" || - (verification.Email == "" && verification.Phone == "") || - (a.mail == nil && verification.Email != "") || - (a.sms == nil && verification.Phone != "")) { + (verification.Code == "" || verification.Email == "" || + (a.mail == nil && verification.Email != "")) { ErrInvalidUserData.With("no verification code or email/phone provided").Write(w) return } @@ -260,7 +245,6 @@ func (a *API) userVerificationCodeInfoHandler(w http.ResponseWriter, r *http.Req // return the verification code information httpWriteJSON(w, UserVerification{ Email: user.Email, - Phone: user.Phone, Expiration: code.Expiration, Valid: code.Expiration.After(time.Now()), }) @@ -279,7 +263,7 @@ func (a *API) resendUserVerificationCodeHandler(w http.ResponseWriter, r *http.R return } // check the email or the phone number is not empty - if verification.Email == "" && verification.Phone == "" { + if verification.Email == "" { ErrInvalidUserData.With("no email or phone number provided").Write(w) return } @@ -288,8 +272,6 @@ func (a *API) resendUserVerificationCodeHandler(w http.ResponseWriter, r *http.R // get the user information from the database by email or phone if verification.Email != "" { user, err = a.db.UserByEmail(verification.Email) - } else { - user, err = a.db.UserByPhone(verification.Phone) } // check the error getting the user information if err != nil { diff --git a/api/users_test.go b/api/users_test.go index 4150a4b..a35ec1f 100644 --- a/api/users_test.go +++ b/api/users_test.go @@ -39,7 +39,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusOK, }, @@ -51,7 +50,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusInternalServerError, expectedBody: mustMarshal(ErrGenericInternalServerError), @@ -64,7 +62,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "first", LastName: "", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, expectedBody: mustMarshal(ErrMalformedBody.Withf("last name is empty")), @@ -77,7 +74,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, expectedBody: mustMarshal(ErrMalformedBody.Withf("first name is empty")), @@ -90,7 +86,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, expectedBody: mustMarshal(ErrEmailMalformed), @@ -103,7 +98,6 @@ func TestRegisterHandler(t *testing.T) { Password: "password", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, expectedBody: mustMarshal(ErrEmailMalformed), @@ -116,7 +110,6 @@ func TestRegisterHandler(t *testing.T) { Password: "short", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, expectedBody: mustMarshal(ErrPasswordTooShort), @@ -129,7 +122,6 @@ func TestRegisterHandler(t *testing.T) { Password: "", FirstName: "first", LastName: "last", - Phone: "123456789", }), expectedStatus: http.StatusBadRequest, }, @@ -157,7 +149,6 @@ func TestRegisterHandler(t *testing.T) { c.Errorf("error closing response body: %v", err) } }() - c.Assert(resp.StatusCode, qt.Equals, testCase.expectedStatus) if testCase.expectedBody != nil { body, err := io.ReadAll(resp.Body) @@ -181,7 +172,6 @@ func TestVerifyAccountHandler(t *testing.T) { Password: testPass, FirstName: testFirstName, LastName: testLastName, - Phone: testPhone, }) req, err := http.NewRequest(http.MethodPost, testURL(usersEndpoint), bytes.NewBuffer(jsonUser)) c.Assert(err, qt.IsNil) @@ -271,7 +261,6 @@ func TestRecoverAndResetPassword(t *testing.T) { Password: testPass, FirstName: testFirstName, LastName: testLastName, - Phone: testPhone, }) req, err := http.NewRequest(http.MethodPost, testURL(usersEndpoint), bytes.NewBuffer(jsonUser)) c.Assert(err, qt.IsNil) diff --git a/cmd/service/main.go b/cmd/service/main.go index de6a57a..f45573d 100644 --- a/cmd/service/main.go +++ b/cmd/service/main.go @@ -12,7 +12,6 @@ import ( "github.com/vocdoni/saas-backend/api" "github.com/vocdoni/saas-backend/db" "github.com/vocdoni/saas-backend/notifications/smtp" - "github.com/vocdoni/saas-backend/notifications/twilio" "go.vocdoni.io/dvote/apiclient" "go.vocdoni.io/dvote/log" ) @@ -34,9 +33,6 @@ func main() { flag.String("smtpPassword", "", "SMTP password") flag.String("emailFromAddress", "", "Email service from address") flag.String("emailFromName", "Vocdoni", "Email service from name") - flag.String("twilioAccountSid", "", "Twilio account SID") - flag.String("twilioAuthToken", "", "Twilio auth token") - flag.String("smsFromNumber", "", "SMS from number") // parse flags flag.Parse() // initialize Viper @@ -62,10 +58,6 @@ func main() { smtpPassword := viper.GetString("smtpPassword") emailFromAddress := viper.GetString("emailFromAddress") emailFromName := viper.GetString("emailFromName") - // sms vars - twilioAccountSid := viper.GetString("twilioAccountSid") - twilioAuthToken := viper.GetString("twilioAuthToken") - twilioFromNumber := viper.GetString("twilioFromNumber") // initialize the MongoDB database database, err := db.New(mongoURL, mongoDB) if err != nil { @@ -118,19 +110,6 @@ func main() { } log.Infow("email service created", "from", fmt.Sprintf("%s <%s>", emailFromName, emailFromAddress)) } - // create SMS notifications service if the required parameters are set and - // include it in the API configuration - if twilioAccountSid != "" && twilioAuthToken != "" && twilioFromNumber != "" { - apiConf.SMSService = new(twilio.TwilioSMS) - if err := apiConf.SMSService.New(&twilio.TwilioConfig{ - AccountSid: twilioAccountSid, - AuthToken: twilioAuthToken, - FromNumber: twilioFromNumber, - }); err != nil { - log.Fatalf("could not create the SMS service: %v", err) - } - log.Infow("SMS service created", "from", twilioFromNumber) - } // create the local API server api.New(apiConf).Start() log.Infow("server started", "host", host, "port", port) diff --git a/db/mongo.go b/db/mongo.go index 80e55c1..22aba88 100644 --- a/db/mongo.go +++ b/db/mongo.go @@ -104,6 +104,14 @@ func (ms *MongoStorage) Reset() error { if err := ms.organizations.Drop(ctx); err != nil { return err } + // drop organizationInvites collection + if err := ms.organizationInvites.Drop(ctx); err != nil { + return err + } + // drop verifications collection + if err := ms.verifications.Drop(ctx); err != nil { + return err + } // init the collections if err := ms.initCollections(ms.database); err != nil { return err diff --git a/db/organization_invites.go b/db/organization_invites.go index ec48034..5864f3b 100644 --- a/db/organization_invites.go +++ b/db/organization_invites.go @@ -2,6 +2,7 @@ package db import ( "context" + "fmt" "time" "go.mongodb.org/mongo-driver/bson" @@ -12,11 +13,39 @@ import ( func (ms *MongoStorage) CreateInvitation(invite *OrganizationInvite) error { ms.keysLock.Lock() defer ms.keysLock.Unlock() - + // create a context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - - _, err := ms.organizationInvites.InsertOne(ctx, invite) + // check if the organization exists + if _, err := ms.organization(ctx, invite.OrganizationAddress); err != nil { + return err + } + // check if the user exists + user, err := ms.user(ctx, invite.CurrentUserID) + if err != nil { + return err + } + // check if the user is already a member of the organization + partOfOrg := false + for _, org := range user.Organizations { + if org.Address == invite.OrganizationAddress { + partOfOrg = true + break + } + } + if !partOfOrg { + return fmt.Errorf("user is not part of the organization") + } + // check if expiration date is in the future + if !invite.Expiration.After(time.Now()) { + return fmt.Errorf("expiration date must be in the future") + } + // check if the role is valid + if !IsValidUserRole(invite.Role) { + return fmt.Errorf("invalid role") + } + // insert the invitation in the database + _, err = ms.organizationInvites.InsertOne(ctx, invite) return err } @@ -39,8 +68,8 @@ func (ms *MongoStorage) Invitation(invitationCode string) (*OrganizationInvite, return invite, nil } -// DeclineInvitation removes the invitation from the database. -func (ms *MongoStorage) DeclineInvitation(invitationCode string) error { +// DeleteInvitation removes the invitation from the database. +func (ms *MongoStorage) DeleteInvitation(invitationCode string) error { ms.keysLock.Lock() defer ms.keysLock.Unlock() diff --git a/db/organization_invites_test.go b/db/organization_invites_test.go new file mode 100644 index 0000000..a634c21 --- /dev/null +++ b/db/organization_invites_test.go @@ -0,0 +1,159 @@ +package db + +import ( + "testing" + "time" + + qt "github.com/frankban/quicktest" +) + +var ( + invitationCode = "abc123" + orgAddress = "0x1234567890" + currentUserID = uint64(1) + newMemberEmail = "inviteme@email.com" + expires = time.Now().Add(time.Hour) +) + +func TestCreateInvitation(t *testing.T) { + c := qt.New(t) + defer func() { + if err := db.Reset(); err != nil { + t.Error(err) + } + }() + // non existing organization + testInvite := &OrganizationInvite{ + InvitationCode: invitationCode, + OrganizationAddress: orgAddress, + CurrentUserID: currentUserID, + NewUserEmail: newMemberEmail, + Role: AdminRole, + Expiration: expires, + } + c.Assert(db.CreateInvitation(testInvite), qt.ErrorIs, ErrNotFound) + // non existing user + c.Assert(db.SetOrganization(&Organization{ + Address: orgAddress, + Name: "Organization", + }), qt.IsNil) + c.Assert(db.CreateInvitation(testInvite), qt.ErrorIs, ErrNotFound) + // non organization member + _, err := db.SetUser(&User{ + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, + }) + c.Assert(err, qt.IsNil) + c.Assert(db.CreateInvitation(testInvite).Error(), qt.Equals, "user is not part of the organization") + // expiration date in the past + _, err = db.SetUser(&User{ + ID: currentUserID, + Organizations: []OrganizationMember{ + {Address: orgAddress, Role: AdminRole}, + }, + }) + c.Assert(err, qt.IsNil) + testInvite.Expiration = time.Now().Add(-time.Hour) + c.Assert(db.CreateInvitation(testInvite).Error(), qt.Equals, "expiration date must be in the future") + // invalid role + testInvite.Expiration = expires + testInvite.Role = "invalid" + c.Assert(db.CreateInvitation(testInvite).Error(), qt.Equals, "invalid role") + // invitation expires + testInvite.Role = AdminRole + testInvite.Expiration = time.Now().Add(time.Second) + c.Assert(db.CreateInvitation(testInvite), qt.IsNil) + // TTL index could take up to 1 minute + time.Sleep(time.Second * 75) + _, err = db.Invitation(invitationCode) + c.Assert(err, qt.ErrorIs, ErrNotFound) + // success + testInvite.Expiration = expires + c.Assert(db.CreateInvitation(testInvite), qt.IsNil) +} + +func TestInvitation(t *testing.T) { + c := qt.New(t) + defer func() { + if err := db.Reset(); err != nil { + t.Error(err) + } + }() + + _, err := db.Invitation(invitationCode) + c.Assert(err, qt.ErrorIs, ErrNotFound) + c.Assert(db.SetOrganization(&Organization{ + Address: orgAddress, + Name: "Organization", + }), qt.IsNil) + _, err = db.SetUser(&User{ + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, + Organizations: []OrganizationMember{ + {Address: orgAddress, Role: AdminRole}, + }, + }) + c.Assert(err, qt.IsNil) + c.Assert(db.CreateInvitation(&OrganizationInvite{ + InvitationCode: invitationCode, + OrganizationAddress: orgAddress, + CurrentUserID: currentUserID, + NewUserEmail: newMemberEmail, + Role: AdminRole, + Expiration: expires, + }), qt.IsNil) + invitation, err := db.Invitation(invitationCode) + c.Assert(err, qt.IsNil) + c.Assert(invitation.InvitationCode, qt.Equals, invitationCode) + c.Assert(invitation.OrganizationAddress, qt.Equals, orgAddress) + c.Assert(invitation.CurrentUserID, qt.Equals, currentUserID) + c.Assert(invitation.NewUserEmail, qt.Equals, newMemberEmail) + c.Assert(invitation.Role, qt.Equals, AdminRole) + // truncate expiration to seconds to avoid rounding issues, also set to UTC + c.Assert(invitation.Expiration.Truncate(time.Second).UTC(), qt.Equals, expires.Truncate(time.Second).UTC()) +} + +func TestDeleteInvitation(t *testing.T) { + c := qt.New(t) + defer func() { + if err := db.Reset(); err != nil { + t.Error(err) + } + }() + + // non existing invitation does not return an error on delete attempt + c.Assert(db.DeleteInvitation(invitationCode), qt.IsNil) + // create valid invitation + c.Assert(db.SetOrganization(&Organization{ + Address: orgAddress, + Name: "Organization", + }), qt.IsNil) + _, err := db.SetUser(&User{ + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, + Organizations: []OrganizationMember{ + {Address: orgAddress, Role: AdminRole}, + }, + }) + c.Assert(err, qt.IsNil) + c.Assert(db.CreateInvitation(&OrganizationInvite{ + InvitationCode: invitationCode, + OrganizationAddress: orgAddress, + CurrentUserID: currentUserID, + NewUserEmail: newMemberEmail, + Role: AdminRole, + Expiration: expires, + }), qt.IsNil) + _, err = db.Invitation(invitationCode) + c.Assert(err, qt.IsNil) + // delete the invitation + c.Assert(db.DeleteInvitation(invitationCode), qt.IsNil) + _, err = db.Invitation(invitationCode) + c.Assert(err, qt.ErrorIs, ErrNotFound) +} diff --git a/db/organizations.go b/db/organizations.go index b425b24..fd3da0b 100644 --- a/db/organizations.go +++ b/db/organizations.go @@ -11,6 +11,20 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) +func (ms *MongoStorage) organization(ctx context.Context, address string) (*Organization, error) { + // find the organization in the database + result := ms.organizations.FindOne(ctx, bson.M{"_id": address}) + org := &Organization{} + if err := result.Decode(org); err != nil { + // if the organization doesn't exist return a specific error + if err == mongo.ErrNoDocuments { + return nil, ErrNotFound + } + return nil, err + } + return org, nil +} + // Organization method returns the organization with the given address. If the // parent flag is true, it also returns the parent organization if it exists. If // the organization doesn't exist or the parent organization doesn't exist and @@ -23,26 +37,16 @@ func (ms *MongoStorage) Organization(address string, parent bool) (*Organization ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // find the organization in the database - result := ms.organizations.FindOne(ctx, bson.M{"_id": address}) - org := &Organization{} - if err := result.Decode(org); err != nil { - // if the organization doesn't exist return a specific error - if err == mongo.ErrNoDocuments { - return nil, nil, ErrNotFound - } + org, err := ms.organization(ctx, address) + if err != nil { return nil, nil, err } if !parent || org.Parent == "" { return org, nil, nil } // find the parent organization in the database - result = ms.organizations.FindOne(ctx, bson.M{"_id": org.Parent}) - parentOrg := &Organization{} - if err := result.Decode(parentOrg); err != nil { - // if the parent organization doesn't exist return a specific error - if err == mongo.ErrNoDocuments { - return nil, nil, ErrNotFound - } + parentOrg, err := ms.organization(ctx, org.Parent) + if err != nil { return nil, nil, err } return org, parentOrg, nil diff --git a/db/organizations_test.go b/db/organizations_test.go index cc4f761..adcd82a 100644 --- a/db/organizations_test.go +++ b/db/organizations_test.go @@ -7,12 +7,13 @@ import ( ) func TestOrganization(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // test not found organization address := "childOrgToGet" org, _, err := db.Organization(address, false) @@ -44,12 +45,13 @@ func TestOrganization(t *testing.T) { } func TestSetOrganization(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new organization address := "orgToSet" orgName := "Organization" @@ -88,8 +90,10 @@ func TestSetOrganization(t *testing.T) { }), qt.IsNotNil) // register the creator and retry to create the organization _, err = db.SetUser(&User{ - Email: testUserEmail, - Password: testUserPass, + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, }) c.Assert(err, qt.IsNil) c.Assert(db.SetOrganization(&Organization{ @@ -100,12 +104,13 @@ func TestSetOrganization(t *testing.T) { } func TestDeleteOrganization(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new organization and delete it address := "orgToDelete" name := "Organization to delete" @@ -127,18 +132,21 @@ func TestDeleteOrganization(t *testing.T) { } func TestReplaceCreatorEmail(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new organization with a creator address := "orgToReplaceCreator" name := "Organization to replace creator" _, err := db.SetUser(&User{ - Email: testUserEmail, - Password: testUserPass, + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, }) c.Assert(err, qt.IsNil) c.Assert(db.SetOrganization(&Organization{ @@ -164,18 +172,21 @@ func TestReplaceCreatorEmail(t *testing.T) { } func TestOrganizationsMembers(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new organization with a creator address := "orgToReplaceCreator" name := "Organization to replace creator" _, err := db.SetUser(&User{ - Email: testUserEmail, - Password: testUserPass, + Email: testUserEmail, + Password: testUserPass, + FirstName: testUserFirstName, + LastName: testUserLastName, }) c.Assert(err, qt.IsNil) c.Assert(db.SetOrganization(&Organization{ diff --git a/db/types.go b/db/types.go index 94cb791..73da421 100644 --- a/db/types.go +++ b/db/types.go @@ -5,7 +5,6 @@ import "time" type User struct { ID uint64 `json:"id" bson:"_id"` Email string `json:"email" bson:"email"` - Phone string `json:"phone" bson:"phone"` Password string `json:"password" bson:"password"` FirstName string `json:"firstName" bson:"firstName"` LastName string `json:"lastName" bson:"lastName"` diff --git a/db/users_test.go b/db/users_test.go index ac5d7b3..f1a3f75 100644 --- a/db/users_test.go +++ b/db/users_test.go @@ -14,12 +14,13 @@ const ( ) func TestUserByEmail(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // test not found user user, err := db.UserByEmail(testUserEmail) c.Assert(user, qt.IsNil) @@ -44,12 +45,13 @@ func TestUserByEmail(t *testing.T) { } func TestUser(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // test not found user id := uint64(100) user, err := db.User(id) @@ -79,12 +81,13 @@ func TestUser(t *testing.T) { } func TestSetUser(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // trying to create a new user with invalid email user := &User{ Email: "invalid-email", @@ -121,12 +124,13 @@ func TestSetUser(t *testing.T) { } func TestDelUser(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new user user := &User{ Email: testUserEmail, @@ -159,12 +163,13 @@ func TestDelUser(t *testing.T) { } func TestIsMemberOf(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) + // create a new user with some organizations user := &User{ Email: testUserEmail, @@ -200,12 +205,12 @@ func TestIsMemberOf(t *testing.T) { } func TestVerifyUser(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) nonExistingUserID := uint64(100) c.Assert(db.VerifyUserAccount(&User{ID: nonExistingUserID}), qt.Equals, ErrNotFound) diff --git a/db/validations.go b/db/validations.go index 931185d..ae5c573 100644 --- a/db/validations.go +++ b/db/validations.go @@ -3,7 +3,8 @@ package db import "go.mongodb.org/mongo-driver/bson" var collectionsValidators = map[string]bson.M{ - "users": usersCollectionValidator, + "users": usersCollectionValidator, + "organizationInvites": organizationInvitesCollectionValidator, } var usersCollectionValidator = bson.M{ @@ -29,3 +30,37 @@ var usersCollectionValidator = bson.M{ }, }, } + +var organizationInvitesCollectionValidator = bson.M{ + "$jsonSchema": bson.M{ + "bsonType": "object", + "required": []string{"invitationCode", "organizationAddress", "currentUserID", "newUserEmail", "role", "expiration"}, + "properties": bson.M{ + "invitationCode": bson.M{ + "bsonType": "string", + "description": "must be a string and is required", + "minimum": 6, + "pattern": `^[\w]{6,}$`, + }, + "organizationAddress": bson.M{ + "bsonType": "string", + "description": "must be a string and is required", + }, + "currentUserID": bson.M{ + "bsonType": "long", + "description": "must be an integer and is required", + "minimum": 1, + "pattern": `^[1-9]+$`, + }, + "newUserEmail": bson.M{ + "bsonType": "string", + "description": "must be an email and is required", + "pattern": `^[\w.\-]+@([\w\-]+\.)+[\w]{2,}$`, + }, + "expiration": bson.M{ + "bsonType": "date", + "description": "must be a date and is required", + }, + }, + }, +} diff --git a/db/verifications_test.go b/db/verifications_test.go index cd1f1ca..9ce9305 100644 --- a/db/verifications_test.go +++ b/db/verifications_test.go @@ -8,12 +8,12 @@ import ( ) func TestUserVerificationCode(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) userID, err := db.SetUser(&User{ Email: testUserEmail, @@ -39,12 +39,12 @@ func TestUserVerificationCode(t *testing.T) { } func TestSetVerificationCode(t *testing.T) { + c := qt.New(t) defer func() { if err := db.Reset(); err != nil { t.Error(err) } }() - c := qt.New(t) nonExistingUserID := uint64(100) err := db.SetVerificationCode(&User{ID: nonExistingUserID}, "testCode", CodeTypeAccountVerification, time.Now())