From 2e52934ae5dfd1d04b894d6e322e89f2d3f2450b Mon Sep 17 00:00:00 2001 From: Kush Sharma Date: Wed, 16 Aug 2023 21:01:02 +0530 Subject: [PATCH] feat: update role permissions in place Signed-off-by: Kush Sharma --- core/role/role.go | 4 +- core/role/service.go | 85 +++++++++++++++++-- internal/api/v1beta1/policy.go | 4 + internal/store/postgres/policy.go | 2 + .../store/postgres/policy_repository_test.go | 20 +++-- internal/store/postgres/postgres_test.go | 8 +- internal/store/postgres/role.go | 18 ++-- internal/store/postgres/role_repository.go | 52 ++++++------ .../store/postgres/role_repository_test.go | 16 ++-- pkg/server/interceptors/authorization.go | 13 +-- 10 files changed, 151 insertions(+), 71 deletions(-) diff --git a/core/role/role.go b/core/role/role.go index 0b71d0c66..e664ddb45 100644 --- a/core/role/role.go +++ b/core/role/role.go @@ -22,8 +22,8 @@ type Repository interface { Get(ctx context.Context, id string) (Role, error) GetByName(ctx context.Context, orgID, name string) (Role, error) List(ctx context.Context, f Filter) ([]Role, error) - Upsert(ctx context.Context, role Role) (string, error) - Update(ctx context.Context, toUpdate Role) (string, error) + Upsert(ctx context.Context, role Role) (Role, error) + Update(ctx context.Context, toUpdate Role) (Role, error) Delete(ctx context.Context, roleID string) error } diff --git a/core/role/service.go b/core/role/service.go index 1c59321ac..667ca0ea0 100644 --- a/core/role/service.go +++ b/core/role/service.go @@ -44,18 +44,27 @@ func (s Service) Upsert(ctx context.Context, toCreate Role) (Role, error) { } } - roleID, err := s.repository.Upsert(ctx, toCreate) + createdRole, err := s.repository.Upsert(ctx, toCreate) if err != nil { return Role{}, err } + // create relation between role and permissions + if err := s.createRolePermissionRelation(ctx, createdRole.ID, createdRole.Permissions); err != nil { + return Role{}, err + } + + return createdRole, nil +} + +func (s Service) createRolePermissionRelation(ctx context.Context, roleID string, permissions []string) error { // create relation between role and permissions // for example for each permission: // app/role:org_owner#organization_delete@app/user:* // app/role:org_owner#organization_update@app/user:* // this needs to be created for each type of principles - for _, perm := range toCreate.Permissions { - _, err = s.relationService.Create(ctx, relation.Relation{ + for _, perm := range permissions { + _, err := s.relationService.Create(ctx, relation.Relation{ Object: relation.Object{ ID: roleID, Namespace: schema.RoleNamespace, @@ -67,7 +76,7 @@ func (s Service) Upsert(ctx context.Context, toCreate Role) (Role, error) { RelationName: perm, }) if err != nil { - return Role{}, err + return err } // do the same with service user _, err = s.relationService.Create(ctx, relation.Relation{ @@ -82,11 +91,50 @@ func (s Service) Upsert(ctx context.Context, toCreate Role) (Role, error) { RelationName: perm, }) if err != nil { - return Role{}, err + return err } } + return nil +} - return s.repository.Get(ctx, roleID) +func (s Service) deleteRolePermissionRelation(ctx context.Context, roleID string, permissions []string) error { + // delete relation between role and permissions + // for example for each permission: + // app/role:org_owner#organization_delete@app/user:* + // app/role:org_owner#organization_update@app/user:* + // this needs to be created for each type of principles + for _, perm := range permissions { + err := s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ + ID: roleID, + Namespace: schema.RoleNamespace, + }, + Subject: relation.Subject{ + ID: "*", // all principles who have role will have access + Namespace: schema.UserPrincipal, + }, + RelationName: perm, + }) + if err != nil { + return err + } + // do the same with service user + err = s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ + ID: roleID, + Namespace: schema.RoleNamespace, + }, + Subject: relation.Subject{ + ID: "*", // all principles who have role will have access + Namespace: schema.ServiceUserPrincipal, + }, + RelationName: perm, + }) + if err != nil { + return err + } + } + return nil } func (s Service) Get(ctx context.Context, id string) (Role, error) { @@ -111,11 +159,32 @@ func (s Service) Update(ctx context.Context, toUpdate Role) (Role, error) { } } - roleID, err := s.repository.Update(ctx, toUpdate) + // fetch existing role + existingRole, err := s.Get(ctx, toUpdate.ID) if err != nil { return Role{}, err } - return s.repository.Get(ctx, roleID) + + // figure out what to delete from permission relation + var permissionsToDelete []string + for _, perm := range existingRole.Permissions { + if !utils.Contains(toUpdate.Permissions, perm) { + permissionsToDelete = append(permissionsToDelete, perm) + } + } + + // delete relation between role and permissions + if err := s.deleteRolePermissionRelation(ctx, existingRole.ID, permissionsToDelete); err != nil { + return Role{}, err + } + + // create relation between role and permissions + if err := s.createRolePermissionRelation(ctx, existingRole.ID, toUpdate.Permissions); err != nil { + return Role{}, err + } + + // update in db + return s.repository.Update(ctx, toUpdate) } func (s Service) Delete(ctx context.Context, id string) error { diff --git a/internal/api/v1beta1/policy.go b/internal/api/v1beta1/policy.go index a9da9e98a..ca629d6ab 100644 --- a/internal/api/v1beta1/policy.go +++ b/internal/api/v1beta1/policy.go @@ -4,6 +4,8 @@ import ( "context" "errors" + "github.com/raystack/frontier/core/role" + "github.com/raystack/frontier/core/audit" "github.com/raystack/frontier/internal/bootstrap/schema" @@ -91,6 +93,8 @@ func (h Handler) CreatePolicy(ctx context.Context, request *frontierv1beta1.Crea if err != nil { logger.Error(err.Error()) switch { + case errors.Is(err, role.ErrInvalidID): + return nil, status.Error(codes.InvalidArgument, err.Error()) case errors.Is(err, policy.ErrInvalidDetail): return nil, grpcBadBodyError default: diff --git a/internal/store/postgres/policy.go b/internal/store/postgres/policy.go index f522fc4f1..2f8171a01 100644 --- a/internal/store/postgres/policy.go +++ b/internal/store/postgres/policy.go @@ -48,5 +48,7 @@ func (from Policy) transformToPolicy() (policy.Policy, error) { PrincipalID: from.PrincipalID, PrincipalType: from.PrincipalType, Metadata: unmarshalledMetadata, + CreatedAt: from.CreatedAt, + UpdatedAt: from.UpdatedAt, }, nil } diff --git a/internal/store/postgres/policy_repository_test.go b/internal/store/postgres/policy_repository_test.go index f3d25b5ed..e54538319 100644 --- a/internal/store/postgres/policy_repository_test.go +++ b/internal/store/postgres/policy_repository_test.go @@ -6,6 +6,8 @@ import ( "fmt" "testing" + "github.com/raystack/frontier/core/role" + "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/google/go-cmp/cmp" @@ -30,7 +32,7 @@ type PolicyRepositoryTestSuite struct { policyIDs []string userID string orgID string - roleID string + role role.Role } func (s *PolicyRepositoryTestSuite) SetupSuite() { @@ -65,7 +67,7 @@ func (s *PolicyRepositoryTestSuite) SetupSuite() { if err != nil { s.T().Fatal(err) } - s.roleID = roles[0] + s.role = roles[0] users, err := bootstrapUser(s.client) if err != nil { @@ -76,7 +78,7 @@ func (s *PolicyRepositoryTestSuite) SetupSuite() { func (s *PolicyRepositoryTestSuite) SetupTest() { var err error - s.policyIDs, err = bootstrapPolicy(s.client, s.orgID, s.roleID, s.userID) + s.policyIDs, err = bootstrapPolicy(s.client, s.orgID, s.role, s.userID) if err != nil { s.T().Fatal(err) } @@ -115,7 +117,7 @@ func (s *PolicyRepositoryTestSuite) TestGet() { Description: "should get a policy", SelectedID: s.policyIDs[0], ExpectedPolicy: policy.Policy{ - RoleID: s.roleID, + RoleID: s.role.ID, ResourceType: "ns1", PrincipalID: s.userID, PrincipalType: schema.UserPrincipal, @@ -167,7 +169,7 @@ func (s *PolicyRepositoryTestSuite) TestCreate() { { Description: "should create a policy", PolicyToCreate: policy.Policy{ - RoleID: s.roleID, + RoleID: s.role.ID, ResourceID: uuid.NewString(), ResourceType: "ns1", PrincipalID: s.userID, @@ -185,7 +187,7 @@ func (s *PolicyRepositoryTestSuite) TestCreate() { { Description: "should return error if namespace id does not exist", PolicyToCreate: policy.Policy{ - RoleID: s.roleID, + RoleID: s.role.ID, ResourceType: "ns1-random", }, Err: policy.ErrInvalidDetail, @@ -221,13 +223,13 @@ func (s *PolicyRepositoryTestSuite) TestList() { Description: "should get all policies", ExpectedPolicys: []policy.Policy{ { - RoleID: s.roleID, + RoleID: s.role.ID, PrincipalID: s.userID, ResourceID: s.orgID, ResourceType: "ns1", }, { - RoleID: s.roleID, + RoleID: s.role.ID, PrincipalID: s.userID, ResourceID: s.orgID, ResourceType: "ns2", @@ -270,7 +272,7 @@ func (s *PolicyRepositoryTestSuite) TestUpdate() { Description: "should update an policy", PolicyToUpdate: policy.Policy{ ID: s.policyIDs[0], - RoleID: s.roleID, + RoleID: s.role.ID, ResourceType: "ns1", }, ExpectedPolicyID: s.policyIDs[0], diff --git a/internal/store/postgres/postgres_test.go b/internal/store/postgres/postgres_test.go index 496d1d234..1f9214356 100644 --- a/internal/store/postgres/postgres_test.go +++ b/internal/store/postgres/postgres_test.go @@ -235,7 +235,7 @@ func bootstrapUser(client *db.Client) ([]user.User, error) { return insertedData, nil } -func bootstrapRole(client *db.Client, orgID string) ([]string, error) { +func bootstrapRole(client *db.Client, orgID string) ([]role.Role, error) { roleRepository := postgres.NewRoleRepository(client) testFixtureJSON, err := os.ReadFile("./testdata/mock-role.json") if err != nil { @@ -247,7 +247,7 @@ func bootstrapRole(client *db.Client, orgID string) ([]string, error) { return nil, err } - var insertedData []string + var insertedData []role.Role for _, d := range data { d.OrgID = orgID domain, err := roleRepository.Upsert(context.Background(), d) @@ -261,7 +261,7 @@ func bootstrapRole(client *db.Client, orgID string) ([]string, error) { return insertedData, nil } -func bootstrapPolicy(client *db.Client, orgID, roleID, userID string) ([]string, error) { +func bootstrapPolicy(client *db.Client, orgID string, role role.Role, userID string) ([]string, error) { policyRepository := postgres.NewPolicyRepository(client) testFixtureJSON, err := os.ReadFile("./testdata/mock-policy.json") if err != nil { @@ -277,7 +277,7 @@ func bootstrapPolicy(client *db.Client, orgID, roleID, userID string) ([]string, for _, d := range data { d.PrincipalID = userID d.ResourceID = orgID - d.RoleID = roleID + d.RoleID = role.ID domain, err := policyRepository.Upsert(context.Background(), d) if err != nil { return nil, err diff --git a/internal/store/postgres/role.go b/internal/store/postgres/role.go index 00548159d..bd3677469 100644 --- a/internal/store/postgres/role.go +++ b/internal/store/postgres/role.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "encoding/json" "time" @@ -8,14 +9,15 @@ import ( ) type Role struct { - ID string `db:"id"` - OrgID string `db:"org_id"` - Name string `db:"name"` - Permissions []byte `db:"permissions"` - State string `db:"state"` - Metadata []byte `db:"metadata"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` + ID string `db:"id"` + OrgID string `db:"org_id"` + Name string `db:"name"` + Permissions []byte `db:"permissions"` + State string `db:"state"` + Metadata []byte `db:"metadata"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + DeletedAt sql.NullTime `db:"deleted_at"` } func (from Role) transformToRole() (role.Role, error) { diff --git a/internal/store/postgres/role_repository.go b/internal/store/postgres/role_repository.go index c9819f942..122125d92 100644 --- a/internal/store/postgres/role_repository.go +++ b/internal/store/postgres/role_repository.go @@ -96,22 +96,22 @@ func (r RoleRepository) GetByName(ctx context.Context, orgID, name string) (role return roleModel.transformToRole() } -func (r RoleRepository) Upsert(ctx context.Context, rl role.Role) (string, error) { +func (r RoleRepository) Upsert(ctx context.Context, rl role.Role) (role.Role, error) { if strings.TrimSpace(rl.ID) == "" { rl.ID = uuid.New().String() } if strings.TrimSpace(rl.Name) == "" { - return "", role.ErrInvalidDetail + return role.Role{}, role.ErrInvalidDetail } marshaledMetadata, err := json.Marshal(rl.Metadata) if err != nil { - return "", fmt.Errorf("%w: %s", parseErr, err) + return role.Role{}, fmt.Errorf("%w: %s", parseErr, err) } marshaledPermissions, err := json.Marshal(rl.Permissions) if err != nil { - return "", fmt.Errorf("%w: %s", parseErr, err) + return role.Role{}, fmt.Errorf("%w: %s", parseErr, err) } query, _, err := dialect.Insert(TABLE_ROLES).Rows( @@ -126,27 +126,27 @@ func (r RoleRepository) Upsert(ctx context.Context, rl role.Role) (string, error "permissions": goqu.L("$4"), "state": goqu.L("$5"), "metadata": goqu.L("$6"), - })).Returning("id").ToSQL() + })).Returning(&Role{}).ToSQL() if err != nil { - return "", fmt.Errorf("%w: %s", queryErr, err) + return role.Role{}, fmt.Errorf("%w: %s", queryErr, err) } - var roleID string + var roleDB Role if err = r.dbc.WithTimeout(ctx, TABLE_ROLES, "Upsert", func(ctx context.Context) error { - return r.dbc.QueryRowxContext(ctx, query, rl.ID, rl.OrgID, rl.Name, marshaledPermissions, rl.State, marshaledMetadata).Scan(&roleID) + return r.dbc.QueryRowxContext(ctx, query, rl.ID, rl.OrgID, rl.Name, marshaledPermissions, rl.State, marshaledMetadata).StructScan(&roleDB) }); err != nil { err = checkPostgresError(err) switch { case errors.Is(err, ErrDuplicateKey): - return "", role.ErrConflict + return role.Role{}, role.ErrConflict case errors.Is(err, ErrForeignKeyViolation): - return "", role.ErrInvalidDetail + return role.Role{}, role.ErrInvalidDetail default: - return "", err + return role.Role{}, err } } - return roleID, nil + return roleDB.transformToRole() } func (r RoleRepository) List(ctx context.Context, flt role.Filter) ([]role.Role, error) { @@ -181,21 +181,21 @@ func (r RoleRepository) List(ctx context.Context, flt role.Filter) ([]role.Role, return transformedRoles, nil } -func (r RoleRepository) Update(ctx context.Context, rl role.Role) (string, error) { +func (r RoleRepository) Update(ctx context.Context, rl role.Role) (role.Role, error) { if strings.TrimSpace(rl.ID) == "" { - return "", role.ErrInvalidID + return role.Role{}, role.ErrInvalidID } if strings.TrimSpace(rl.Name) == "" { - return "", role.ErrInvalidDetail + return role.Role{}, role.ErrInvalidDetail } marshaledMetadata, err := json.Marshal(rl.Metadata) if err != nil { - return "", fmt.Errorf("%w: %s", parseErr, err) + return role.Role{}, fmt.Errorf("%w: %s", parseErr, err) } marshaledPermissions, err := json.Marshal(rl.Permissions) if err != nil { - return "", fmt.Errorf("%w: %s", parseErr, err) + return role.Role{}, fmt.Errorf("%w: %s", parseErr, err) } query, _, err := dialect.Update(TABLE_ROLES).Set( @@ -207,29 +207,29 @@ func (r RoleRepository) Update(ctx context.Context, rl role.Role) (string, error "updated_at": goqu.L("now()"), }).Where( goqu.Ex{"id": goqu.L("$1")}, - ).Returning("id").ToSQL() + ).Returning(&Role{}).ToSQL() if err != nil { - return "", fmt.Errorf("%w: %s", queryErr, err) + return role.Role{}, fmt.Errorf("%w: %s", queryErr, err) } - var roleID string + var roleDB Role if err = r.dbc.WithTimeout(ctx, TABLE_ROLES, "Update", func(ctx context.Context) error { - return r.dbc.QueryRowxContext(ctx, query, rl.ID, rl.Name, marshaledPermissions, rl.State, marshaledMetadata).Scan(&roleID) + return r.dbc.QueryRowxContext(ctx, query, rl.ID, rl.Name, marshaledPermissions, rl.State, marshaledMetadata).StructScan(&roleDB) }); err != nil { err = checkPostgresError(err) switch { case errors.Is(err, sql.ErrNoRows): - return "", role.ErrNotExist + return role.Role{}, role.ErrNotExist case errors.Is(err, ErrForeignKeyViolation): - return "", namespace.ErrNotExist + return role.Role{}, namespace.ErrNotExist case errors.Is(err, ErrDuplicateKey): - return "", role.ErrConflict + return role.Role{}, role.ErrConflict default: - return "", err + return role.Role{}, err } } - return roleID, nil + return roleDB.transformToRole() } func (r RoleRepository) Delete(ctx context.Context, id string) error { diff --git a/internal/store/postgres/role_repository_test.go b/internal/store/postgres/role_repository_test.go index 20f4cb02d..03d45d793 100644 --- a/internal/store/postgres/role_repository_test.go +++ b/internal/store/postgres/role_repository_test.go @@ -26,7 +26,7 @@ type RoleRepositoryTestSuite struct { pool *dockertest.Pool resource *dockertest.Resource repository *postgres.RoleRepository - roleIDs []string + roles []role.Role orgID string } @@ -56,7 +56,7 @@ func (s *RoleRepositoryTestSuite) SetupSuite() { func (s *RoleRepositoryTestSuite) SetupTest() { var err error - s.roleIDs, err = bootstrapRole(s.client, s.orgID) + s.roles, err = bootstrapRole(s.client, s.orgID) if err != nil { s.T().Fatal(err) } @@ -93,9 +93,9 @@ func (s *RoleRepositoryTestSuite) TestGet() { var testCases = []testCase{ { Description: "should get a role", - SelectedID: s.roleIDs[3], + SelectedID: s.roles[3].ID, ExpectedRole: role.Role{ - ID: s.roleIDs[3], + ID: s.roles[3].ID, Name: "editor", Permissions: []string{ "user", @@ -185,7 +185,7 @@ func (s *RoleRepositoryTestSuite) TestCreate() { s.T().Fatalf("got error %s, expected was %s", err.Error(), tc.ErrString) } } - if tc.ExpectedID != "" && (got != tc.ExpectedID) { + if tc.ExpectedID != "" && (got.ID != tc.ExpectedID) { s.T().Fatalf("got result %+v, expected was %+v", got, tc.ExpectedID) } }) @@ -233,13 +233,13 @@ func (s *RoleRepositoryTestSuite) TestUpdate() { { Description: "should update a role", RoleToUpdate: role.Role{ - ID: s.roleIDs[0], + ID: s.roles[0].ID, Name: "role members", OrgID: s.orgID, Metadata: metadata.Metadata{}, Permissions: []string{"member", "user"}, }, - ExpectedRoleID: s.roleIDs[0], + ExpectedRoleID: s.roles[0].ID, }, { Description: "should return error if role not found", @@ -278,7 +278,7 @@ func (s *RoleRepositoryTestSuite) TestUpdate() { if tc.ErrString == "" { s.Assert().NoError(err) } - if !cmp.Equal(got, tc.ExpectedRoleID) { + if !cmp.Equal(got.ID, tc.ExpectedRoleID) { s.T().Fatalf("got result %+v, expected was %+v", got, tc.ExpectedRoleID) } }) diff --git a/pkg/server/interceptors/authorization.go b/pkg/server/interceptors/authorization.go index 17aced0c2..7b7549eec 100644 --- a/pkg/server/interceptors/authorization.go +++ b/pkg/server/interceptors/authorization.go @@ -62,12 +62,13 @@ func UnaryAuthorizationCheck(identityHeader string) grpc.UnaryServerInterceptor // authorizationSkipList stores path to skip authorization, by default its enabled for all requests var authorizationSkipList = map[string]bool{ - "/raystack.frontier.v1beta1.FrontierService/GetJWKs": true, - "/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true, - "/raystack.frontier.v1beta1.FrontierService/Authenticate": true, - "/raystack.frontier.v1beta1.FrontierService/AuthCallback": true, - "/raystack.frontier.v1beta1.FrontierService/AuthToken": true, - "/raystack.frontier.v1beta1.FrontierService/AuthLogout": true, + "/raystack.frontier.v1beta1.FrontierService/GetJWKs": true, + "/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true, + "/raystack.frontier.v1beta1.FrontierService/Authenticate": true, + "/raystack.frontier.v1beta1.FrontierService/AuthCallback": true, + "/raystack.frontier.v1beta1.FrontierService/AuthToken": true, + "/raystack.frontier.v1beta1.FrontierService/AuthLogout": true, + "/raystack.frontier.v1beta1.FrontierService/CheckResourcePermission": true, "/raystack.frontier.v1beta1.FrontierService/ListPermissions": true, "/raystack.frontier.v1beta1.FrontierService/GetPermission": true,