Skip to content

Commit

Permalink
feat: add disable org on creation flag, disallow removing last admin …
Browse files Browse the repository at this point in the history
…from the org/group (#317)

* feat: add disable org on creation flag

* feat: disallow removing last admin from org/grp

* chore: add changes from suggestion

* fix: add check to confirm user being removed is the org admin
  • Loading branch information
Chief-Rishab authored Sep 9, 2023
1 parent 96ea8e3 commit 7c80c67
Show file tree
Hide file tree
Showing 20 changed files with 1,481 additions and 895 deletions.
2 changes: 1 addition & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func buildAPIDependencies(
)

organizationRepository := postgres.NewOrganizationRepository(dbc)
organizationService := organization.NewService(organizationRepository, relationService, userService, authnService)
organizationService := organization.NewService(organizationRepository, relationService, userService, authnService, cfg.App.DisableOrgsOnCreate)

domainRepository := postgres.NewDomainRepository(logger, dbc)
domainService := domain.NewService(logger, domainRepository, userService, organizationService)
Expand Down
3 changes: 3 additions & 0 deletions config/sample.config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ app:
# secret string "val://user:password"
# optional
resources_config_path_secret: env://TEST_RESOURCE_CONFIG_SECRET
# disable_orgs_on_create if set to true will set the org status to disabled on creation. This can be used to
# prevent users from accessing the org until they contact the admin and get it enabled. Default is false
disable_orgs_on_create: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all organizations
disable_orgs_listing: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all users
Expand Down
1 change: 1 addition & 0 deletions core/organization/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ var (
ErrInvalidID = errors.New("org id is invalid")
ErrConflict = errors.New("org already exist")
ErrInvalidDetail = errors.New("invalid org detail")
ErrDisabled = errors.New("org is disabled")
)
33 changes: 32 additions & 1 deletion core/organization/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,49 @@ type Service struct {
relationService RelationService
userService UserService
authnService AuthnService
defaultState State
}

func NewService(repository Repository, relationService RelationService,
userService UserService, authnService AuthnService) *Service {
userService UserService, authnService AuthnService, disableOrgsOnCreate bool) *Service {
defaultState := Enabled
if disableOrgsOnCreate {
defaultState = Disabled
}
return &Service{
repository: repository,
relationService: relationService,
userService: userService,
authnService: authnService,
defaultState: defaultState,
}
}

// Get returns an enabled organization by id or name. Will return `org is disabled` error if the organization is disabled
func (s Service) Get(ctx context.Context, idOrName string) (Organization, error) {
if utils.IsValidUUID(idOrName) {
orgResp, err := s.repository.GetByID(ctx, idOrName)
if err != nil {
return Organization{}, err
}
if orgResp.State == Disabled {
return Organization{}, ErrDisabled
}
return orgResp, nil
}

orgResp, err := s.repository.GetByName(ctx, idOrName)
if err != nil {
return Organization{}, err
}
if orgResp.State == Disabled {
return Organization{}, ErrDisabled
}
return orgResp, nil
}

// GetRaw returns an organization(both enabled and disabled) by id or name
func (s Service) GetRaw(ctx context.Context, idOrName string) (Organization, error) {
if utils.IsValidUUID(idOrName) {
return s.repository.GetByID(ctx, idOrName)
}
Expand All @@ -63,6 +93,7 @@ func (s Service) Create(ctx context.Context, org Organization) (Organization, er
Title: org.Title,
Avatar: org.Avatar,
Metadata: org.Metadata,
State: s.defaultState,
})
if err != nil {
return Organization{}, err
Expand Down
3 changes: 3 additions & 0 deletions docs/docs/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ app:
# secret string "val://user:password"
# optional
resources_config_path_secret: env://TEST_RESOURCE_CONFIG_SECRET
# disable_orgs_on_create if set to true will set the org status to disabled on creation. This can be used to
# prevent users from accessing the org until they contact the admin and get it enabled. Default is false
disable_orgs_on_create: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all organizations
disable_orgs_listing: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all users
Expand Down
3 changes: 3 additions & 0 deletions docs/docs/reference/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ app:
# secret string "val://user:password"
# optional
resources_config_path_secret: env://TEST_RESOURCE_CONFIG_SECRET
# disable_orgs_on_create if set to true will set the org status to disabled on creation. This can be used to
# prevent users from accessing the org until they contact the admin and get it enabled. Default is false
disable_orgs_on_create: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all organizations
disable_orgs_listing: false
# disable_orgs_listing if set to true will disallow non-admin APIs to list all users
Expand Down
47 changes: 37 additions & 10 deletions internal/api/v1beta1/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (

grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
"github.com/raystack/frontier/core/audit"
"github.com/raystack/frontier/core/organization"
"github.com/raystack/frontier/pkg/errors"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
"google.golang.org/protobuf/types/known/timestamppb"
)
Expand All @@ -17,14 +19,22 @@ type AuditService interface {

func (h Handler) ListOrganizationAuditLogs(ctx context.Context, request *frontierv1beta1.ListOrganizationAuditLogsRequest) (*frontierv1beta1.ListOrganizationAuditLogsResponse, error) {
logger := grpczap.Extract(ctx)

if request.GetOrgId() == "" {
return nil, grpcBadBodyError
orgResp, err := h.orgService.Get(ctx, request.GetOrgId())
if err != nil {
logger.Error(err.Error())
switch {
case errors.Is(err, organization.ErrDisabled):
return nil, grpcOrgDisabledErr
case errors.Is(err, organization.ErrNotExist):
return nil, grpcOrgNotFoundErr
default:
return nil, grpcInternalServerError
}
}

var logs []*frontierv1beta1.AuditLog
logList, err := h.auditService.List(ctx, audit.Filter{
OrgID: request.GetOrgId(),
OrgID: orgResp.ID,
Source: request.GetSource(),
Action: request.GetAction(),
StartTime: request.GetStartTime().AsTime(),
Expand All @@ -45,8 +55,17 @@ func (h Handler) ListOrganizationAuditLogs(ctx context.Context, request *frontie

func (h Handler) CreateOrganizationAuditLogs(ctx context.Context, request *frontierv1beta1.CreateOrganizationAuditLogsRequest) (*frontierv1beta1.CreateOrganizationAuditLogsResponse, error) {
logger := grpczap.Extract(ctx)
if request.GetOrgId() == "" || request.GetLogs() == nil {
return nil, grpcBadBodyError
orgResp, err := h.orgService.Get(ctx, request.GetOrgId())
if err != nil {
logger.Error(err.Error())
switch {
case errors.Is(err, organization.ErrDisabled):
return nil, grpcOrgDisabledErr
case errors.Is(err, organization.ErrNotExist):
return nil, grpcOrgNotFoundErr
default:
return nil, grpcInternalServerError
}
}

for _, log := range request.GetLogs() {
Expand All @@ -55,7 +74,7 @@ func (h Handler) CreateOrganizationAuditLogs(ctx context.Context, request *front
}
if err := h.auditService.Create(ctx, &audit.Log{
ID: log.GetId(),
OrgID: request.GetOrgId(),
OrgID: orgResp.ID,

Source: log.Source,
Action: log.Action,
Expand All @@ -81,9 +100,17 @@ func (h Handler) CreateOrganizationAuditLogs(ctx context.Context, request *front

func (h Handler) GetOrganizationAuditLog(ctx context.Context, request *frontierv1beta1.GetOrganizationAuditLogRequest) (*frontierv1beta1.GetOrganizationAuditLogResponse, error) {
logger := grpczap.Extract(ctx)

if request.OrgId == "" || request.GetId() == "" {
return nil, grpcBadBodyError
_, err := h.orgService.Get(ctx, request.GetOrgId())
if err != nil {
logger.Error(err.Error())
switch {
case errors.Is(err, organization.ErrDisabled):
return nil, grpcOrgDisabledErr
case errors.Is(err, organization.ErrNotExist):
return nil, grpcOrgNotFoundErr
default:
return nil, grpcInternalServerError
}
}

log, err := h.auditService.GetByID(ctx, request.GetId())
Expand Down
83 changes: 41 additions & 42 deletions internal/api/v1beta1/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/raystack/frontier/core/audit"
"github.com/raystack/frontier/core/organization"
"github.com/raystack/frontier/internal/api/v1beta1/mocks"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
"github.com/stretchr/testify/assert"
Expand All @@ -17,16 +18,17 @@ import (
func TestHandler_ListOrganizationAuditLogs(t *testing.T) {
tests := []struct {
name string
setup func(as *mocks.AuditService)
setup func(as *mocks.AuditService, os *mocks.OrganizationService)
request *frontierv1beta1.ListOrganizationAuditLogsRequest
want *frontierv1beta1.ListOrganizationAuditLogsResponse
wantErr error
}{
{
name: "should return list of audit logs",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().List(mock.AnythingOfType("*context.emptyCtx"), audit.Filter{
OrgID: "org-id",
OrgID: testOrgMap[testOrgID].ID,
Source: "guardian-service",
Action: "project.create",
StartTime: time.Time{},
Expand Down Expand Up @@ -81,9 +83,10 @@ func TestHandler_ListOrganizationAuditLogs(t *testing.T) {
},
{
name: "should return error when audit service returns error",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().List(mock.AnythingOfType("*context.emptyCtx"), audit.Filter{
OrgID: "org-id",
OrgID: testOrgMap[testOrgID].ID,
Source: "guardian-service",
Action: "project.create",
StartTime: time.Time{},
Expand All @@ -101,26 +104,30 @@ func TestHandler_ListOrganizationAuditLogs(t *testing.T) {
wantErr: grpcInternalServerError,
},
{
name: "should return error when org id is empty",
name: "should return error when org is disabled",
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(organization.Organization{}, organization.ErrDisabled)
},
request: &frontierv1beta1.ListOrganizationAuditLogsRequest{
OrgId: "",
OrgId: "org-id",
Source: "guardian-service",
Action: "project.create",
StartTime: timestamppb.New(time.Time{}),
EndTime: timestamppb.New(time.Time{}),
},
want: nil,
wantErr: grpcBadBodyError,
wantErr: grpcOrgDisabledErr,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockAuditSrv := new(mocks.AuditService)
mockOrgSrv := new(mocks.OrganizationService)
if tt.setup != nil {
tt.setup(mockAuditSrv)
tt.setup(mockAuditSrv, mockOrgSrv)
}
mockDep := Handler{auditService: mockAuditSrv}
mockDep := Handler{auditService: mockAuditSrv, orgService: mockOrgSrv}
resp, err := mockDep.ListOrganizationAuditLogs(context.Background(), tt.request)
assert.EqualValues(t, tt.want, resp)
assert.EqualValues(t, tt.wantErr, err)
Expand All @@ -131,17 +138,19 @@ func TestHandler_ListOrganizationAuditLogs(t *testing.T) {
func TestHandler_CreateOrganizationAuditLogs(t *testing.T) {
tests := []struct {
name string
setup func(as *mocks.AuditService)
setup func(as *mocks.AuditService, os *mocks.OrganizationService)
req *frontierv1beta1.CreateOrganizationAuditLogsRequest
want *frontierv1beta1.CreateOrganizationAuditLogsResponse
wantErr error
}{
{
name: "should create audit logs on success and return nil error",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), &audit.Log{
ID: "test-id",
OrgID: "org-id",
ID: "test-id",
OrgID: testOrgMap[testOrgID].ID,

Source: "guardian-service",
Action: "project.create",
CreatedAt: time.Time{},
Expand Down Expand Up @@ -181,17 +190,11 @@ func TestHandler_CreateOrganizationAuditLogs(t *testing.T) {
want: &frontierv1beta1.CreateOrganizationAuditLogsResponse{},
wantErr: nil,
},
{
name: "should return error when missing org id or logs",
req: &frontierv1beta1.CreateOrganizationAuditLogsRequest{
OrgId: "",
Logs: nil,
},
want: nil,
wantErr: grpcBadBodyError,
},
{
name: "should return error when log source and action is empty",
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
},
req: &frontierv1beta1.CreateOrganizationAuditLogsRequest{
OrgId: "org-id",
Logs: []*frontierv1beta1.AuditLog{
Expand All @@ -218,10 +221,11 @@ func TestHandler_CreateOrganizationAuditLogs(t *testing.T) {
},
{
name: "should return error when audit service returns error",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), &audit.Log{
ID: "test-id",
OrgID: "org-id",
OrgID: testOrgMap[testOrgID].ID,
Source: "guardian-service",
Action: "project.create",
CreatedAt: time.Time{},
Expand Down Expand Up @@ -251,11 +255,12 @@ func TestHandler_CreateOrganizationAuditLogs(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockOrgService := new(mocks.OrganizationService)
mockAuditSrv := new(mocks.AuditService)
if tt.setup != nil {
tt.setup(mockAuditSrv)
tt.setup(mockAuditSrv, mockOrgService)
}
mockDep := Handler{auditService: mockAuditSrv}
mockDep := Handler{auditService: mockAuditSrv, orgService: mockOrgService}
resp, err := mockDep.CreateOrganizationAuditLogs(context.Background(), tt.req)
assert.EqualValues(t, tt.want, resp)
assert.EqualValues(t, tt.wantErr, err)
Expand All @@ -266,17 +271,18 @@ func TestHandler_CreateOrganizationAuditLogs(t *testing.T) {
func TestHandler_GetOrganizationAuditLog(t *testing.T) {
tests := []struct {
name string
setup func(as *mocks.AuditService)
setup func(as *mocks.AuditService, os *mocks.OrganizationService)
req *frontierv1beta1.GetOrganizationAuditLogRequest
want *frontierv1beta1.GetOrganizationAuditLogResponse
wantErr error
}{
{
name: "should return audit log on success",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), "test-id").Return(audit.Log{
ID: "test-id",
OrgID: "org-id",
OrgID: testOrgMap[testOrgID].ID,
Source: "guardian-service",
Action: "project.create",
CreatedAt: time.Time{},
Expand Down Expand Up @@ -322,18 +328,10 @@ func TestHandler_GetOrganizationAuditLog(t *testing.T) {
},
wantErr: nil,
},
{
name: "should return error when org id or log id is empty",
req: &frontierv1beta1.GetOrganizationAuditLogRequest{
Id: "",
OrgId: "",
},
want: nil,
wantErr: grpcBadBodyError,
},
{
name: "should return error when audit service returns error",
setup: func(as *mocks.AuditService) {
setup: func(as *mocks.AuditService, os *mocks.OrganizationService) {
os.EXPECT().Get(mock.AnythingOfType("*context.emptyCtx"), "org-id").Return(testOrgMap[testOrgID], nil)
as.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), "test-id").Return(audit.Log{}, errors.New("test-error"))
},
req: &frontierv1beta1.GetOrganizationAuditLogRequest{
Expand All @@ -347,11 +345,12 @@ func TestHandler_GetOrganizationAuditLog(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockOrgService := new(mocks.OrganizationService)
mockAuditSrv := new(mocks.AuditService)
if tt.setup != nil {
tt.setup(mockAuditSrv)
tt.setup(mockAuditSrv, mockOrgService)
}
mockDep := Handler{auditService: mockAuditSrv}
mockDep := Handler{auditService: mockAuditSrv, orgService: mockOrgService}
resp, err := mockDep.GetOrganizationAuditLog(context.Background(), tt.req)
assert.EqualValues(t, tt.want, resp)
assert.EqualValues(t, tt.wantErr, err)
Expand Down
Loading

0 comments on commit 7c80c67

Please sign in to comment.