Skip to content

Commit

Permalink
Simplify role fetching logic in query engine (#282)
Browse files Browse the repository at this point in the history
* Simplify role fetching logic in query engine

Prior implementations of the query engine fetched role information
such as the owning resource ID directly from SpiceDB, as it was the
only data store available. With the introduction of CRDB, that is no
longer the case and the CRDB SQL table should be considered the
authoritative source of most role data. This commit updates the query
engine to fetch role resource owner ID and other data from the SQL DB
whenever possible, getting rid of some obscure failure modes that
occur when a role has no associated actions.

Signed-off-by: John Schaeffer <[email protected]>

* Fix error type in RBAC v2 tests

As described.

Signed-off-by: John Schaeffer <[email protected]>

* Wrap LockRoleForUpdate in a method to return non-DB errors

As described.

Signed-off-by: John Schaeffer <[email protected]>

* Fix incorrect error in role update test case

As described.

Signed-off-by: John Schaeffer <[email protected]>

---------

Signed-off-by: John Schaeffer <[email protected]>
  • Loading branch information
jnschaeffer authored Aug 22, 2024
1 parent 91d9a4e commit 31bbd1c
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 110 deletions.
169 changes: 83 additions & 86 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou
return types.Role{}, err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
err = e.lockRoleForUpdate(dbCtx, roleResource)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

Expand Down Expand Up @@ -913,12 +913,18 @@ func (e *engine) ListRoles(ctx context.Context, resource types.Resource) ([]type

// listRoleResourceActions returns all resources and action relations for the provided resource type to the provided role.
// Note: The actions returned by this function are the spicedb relationship action.
func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resource, resTypeName string) (map[types.Resource][]string, error) {
resType := e.namespace + "/" + resTypeName
func (e *engine) listRoleResourceActions(ctx context.Context, role storage.Role) ([]string, error) {
roleOwnerResource, err := e.NewResourceFromID(role.ResourceID)
if err != nil {
return nil, err
}

resType := e.namespace + "/" + roleOwnerResource.Type
roleType := e.namespace + "/role"

filter := &pb.RelationshipFilter{
ResourceType: resType,
ResourceType: resType,
OptionalResourceId: roleOwnerResource.ID.String(),
OptionalSubjectFilter: &pb.SubjectFilter{
SubjectType: roleType,
OptionalSubjectId: role.ID.String(),
Expand All @@ -933,84 +939,47 @@ func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resourc
return nil, err
}

resourceIDActions := make(map[gidx.PrefixedID][]string)
out := make([]string, 0, len(relationships))

for _, rel := range relationships {
resourceID, err := gidx.Parse(rel.Resource.ObjectId)
if err != nil {
return nil, err
}

resourceIDActions[resourceID] = append(resourceIDActions[resourceID], rel.Relation)
}

resourceActions := make(map[types.Resource][]string, len(resourceIDActions))

for resID, actions := range resourceIDActions {
res, err := e.NewResourceFromID(resID)
if err != nil {
return nil, err
}
action := relationToAction(rel.Relation)

resourceActions[res] = actions
out = append(out, action)
}

return resourceActions, nil
return out, nil
}

// GetRole gets the role with it's actions.
// GetRole gets the given role and its actions.
func (e *engine) GetRole(ctx context.Context, roleResource types.Resource) (types.Role, error) {
var (
resActions map[types.Resource][]string
err error
)

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
return types.Role{}, err
}

// roles are only ever created for a single resource, so we can break after the first one is found.
if len(resActions) != 0 {
break
}
dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return types.Role{}, err
}

if len(resActions) > 1 {
return types.Role{}, ErrRoleHasTooManyResources
actions, err := e.listRoleResourceActions(ctx, dbRole)
if err != nil {
return types.Role{}, err
}

// returns the first resources actions.
for _, actions := range resActions {
for i, action := range actions {
actions[i] = relationToAction(action)
}

dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
if err != nil && !errors.Is(err, storage.ErrNoRoleFound) {
e.logger.Error("error while getting role", zap.Error(err))
}
out := types.Role{
ID: roleResource.ID,
Name: dbRole.Name,
Actions: actions,

return types.Role{
ID: roleResource.ID,
Name: dbRole.Name,
Actions: actions,

ResourceID: dbRole.ResourceID,
CreatedBy: dbRole.CreatedBy,
UpdatedBy: dbRole.UpdatedBy,
CreatedAt: dbRole.CreatedAt,
UpdatedAt: dbRole.UpdatedAt,
}, nil
ResourceID: dbRole.ResourceID,
CreatedBy: dbRole.CreatedBy,
UpdatedBy: dbRole.UpdatedBy,
CreatedAt: dbRole.CreatedAt,
UpdatedAt: dbRole.UpdatedAt,
}

return types.Role{}, ErrRoleNotFound
return out, nil
}

// GetRoleResource gets the role's assigned resource.
func (e *engine) GetRoleResource(ctx context.Context, roleResource types.Resource) (types.Resource, error) {
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return types.Resource{}, err
}
Expand All @@ -1029,7 +998,17 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
return err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return err
}

roleOwnerResource, err := e.NewResourceFromID(dbRole.ResourceID)
if err != nil {
return err
}

err = e.lockRoleForUpdate(dbCtx, roleResource)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

Expand All @@ -1041,20 +1020,11 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
return err
}

var resActions map[types.Resource][]string

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}
actions, err := e.listRoleResourceActions(ctx, dbRole)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

// roles are only ever created for a single resource, so we can break after the first one is found.
if len(resActions) != 0 {
break
}
return err
}

roleType := e.namespace + "/role"
Expand All @@ -1069,15 +1039,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
},
}

for resource, relActions := range resActions {
for _, relAction := range relActions {
filters = append(filters, &pb.RelationshipFilter{
ResourceType: e.namespace + "/" + resource.Type,
OptionalResourceId: resource.ID.String(),
OptionalRelation: relAction,
OptionalSubjectFilter: roleSubjectFilter,
})
}
ownerType := e.namespace + "/" + roleOwnerResource.Type
ownerIDStr := roleOwnerResource.ID.String()

for _, relAction := range actions {
filters = append(filters, &pb.RelationshipFilter{
ResourceType: ownerType,
OptionalResourceId: ownerIDStr,
OptionalRelation: relAction,
OptionalSubjectFilter: roleSubjectFilter,
})
}

_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
Expand Down Expand Up @@ -1229,3 +1200,29 @@ func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpd

return nil
}

func (e *engine) getStorageRole(ctx context.Context, roleResource types.Resource) (storage.Role, error) {
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)

switch {
case err == nil:
return dbRole, nil
case errors.Is(err, storage.ErrNoRoleFound):
return storage.Role{}, ErrRoleNotFound
default:
return storage.Role{}, err
}
}

func (e *engine) lockRoleForUpdate(ctx context.Context, roleResource types.Resource) error {
err := e.store.LockRoleForUpdate(ctx, roleResource.ID)

switch {
case err == nil:
return nil
case errors.Is(err, storage.ErrNoRoleFound):
return ErrRoleNotFound
default:
return err
}
}
41 changes: 27 additions & 14 deletions internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"go.infratographer.com/permissions-api/internal/iapl"
"go.infratographer.com/permissions-api/internal/spicedbx"
"go.infratographer.com/permissions-api/internal/storage"
"go.infratographer.com/permissions-api/internal/storage/teststore"
"go.infratographer.com/permissions-api/internal/testingx"
"go.infratographer.com/permissions-api/internal/types"
Expand Down Expand Up @@ -96,54 +95,68 @@ func TestCreateRoles(t *testing.T) {
ctx := context.Background()
e := testEngine(ctx, t, namespace, testPolicy())

testCases := []testingx.TestCase[[]string, []types.Role]{
testCases := []testingx.TestCase[[]string, types.Role]{
{
Name: "CreateInvalidAction",
Input: []string{
"bad_action",
},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
assert.Error(t, res.Err)
},
},
{
Name: "CreateNoActions",
Input: []string{},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
expActions := []string{}

require.NoError(t, res.Err)

role := res.Success
assert.Equal(t, expActions, role.Actions)
},
},
{
Name: "CreateSuccess",
Input: []string{
"loadbalancer_get",
},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
expActions := []string{
"loadbalancer_get",
}

assert.NoError(t, res.Err)
require.Equal(t, 1, len(res.Success))
require.NoError(t, res.Err)

role := res.Success[0]
role := res.Success
assert.Equal(t, expActions, role.Actions)
},
},
}

testFn := func(ctx context.Context, actions []string) testingx.TestResult[[]types.Role] {
testFn := func(ctx context.Context, actions []string) testingx.TestResult[types.Role] {
tenID, err := gidx.NewID("tnntten")
require.NoError(t, err)
tenRes, err := e.NewResourceFromID(tenID)
require.NoError(t, err)
actorRes, err := e.NewResourceFromID(gidx.MustNewID("idntusr"))
require.NoError(t, err)

_, err = e.CreateRole(ctx, actorRes, tenRes, "test", actions)
role, err := e.CreateRole(ctx, actorRes, tenRes, "test", actions)
if err != nil {
return testingx.TestResult[[]types.Role]{
return testingx.TestResult[types.Role]{
Err: err,
}
}

roles, err := e.ListRoles(ctx, tenRes)
roleResource, err := e.NewResourceFromID(role.ID)
require.NoError(t, err)

return testingx.TestResult[[]types.Role]{
Success: roles,
obs, err := e.GetRole(ctx, roleResource)

return testingx.TestResult[types.Role]{
Success: obs,
Err: err,
}
}
Expand Down Expand Up @@ -232,7 +245,7 @@ func TestRoleUpdate(t *testing.T) {
Input: gidx.MustNewID(RolePrefix),
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
require.Error(t, res.Err)
assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound)
assert.ErrorIs(t, res.Err, ErrRoleNotFound)
},
},
{
Expand Down
8 changes: 4 additions & 4 deletions internal/query/roles_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (e *engine) GetRoleV2(ctx context.Context, role types.Resource) (types.Role
}

// 2. Get role info (name, created_by, etc.) from permissions API DB
dbrole, err := e.store.GetRoleByID(ctx, role.ID)
dbrole, err := e.getStorageRole(ctx, role)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
Expand Down Expand Up @@ -234,7 +234,7 @@ func (e *engine) UpdateRoleV2(ctx context.Context, actor, roleResource types.Res
return types.Role{}, err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
err = e.lockRoleForUpdate(dbCtx, roleResource)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

Expand Down Expand Up @@ -360,7 +360,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource)
return err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
err = e.lockRoleForUpdate(dbCtx, roleResource)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

Expand Down Expand Up @@ -399,7 +399,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource)
return err
}

dbRole, err := e.store.GetRoleByID(dbCtx, roleResource.ID)
dbRole, err := e.getStorageRole(dbCtx, roleResource)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
Expand Down
Loading

0 comments on commit 31bbd1c

Please sign in to comment.