Skip to content

Commit

Permalink
Merge pull request #12992 from markylaing/dangling-permissions
Browse files Browse the repository at this point in the history
Auth: Handle dangling permissions
  • Loading branch information
tomponline authored Mar 5, 2024
2 parents de0b61c + fcde91b commit 6c6d117
Show file tree
Hide file tree
Showing 14 changed files with 563 additions and 506 deletions.
119 changes: 46 additions & 73 deletions lxd/auth_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/canonical/lxd/shared"
"github.com/canonical/lxd/shared/api"
"github.com/canonical/lxd/shared/entity"
"github.com/canonical/lxd/shared/logger"
)

var authGroupsCmd = APIEndpoint{
Expand Down Expand Up @@ -170,7 +171,7 @@ func getAuthGroups(d *Daemon, r *http.Request) response.Response {
}

var groups []dbCluster.AuthGroup
groupsPermissions := make(map[int][]dbCluster.Permission)
var authGroupPermissions []dbCluster.Permission
groupsIdentities := make(map[int][]dbCluster.Identity)
groupsIdentityProviderGroups := make(map[int][]dbCluster.IdentityProviderGroup)
entityURLs := make(map[entity.Type]map[int]*api.URL)
Expand Down Expand Up @@ -204,25 +205,29 @@ func getAuthGroups(d *Daemon, r *http.Request) response.Response {
return err
}

groupsPermissions, err = dbCluster.GetAllPermissionsByAuthGroupIDs(ctx, tx.Tx())
authGroupPermissions, err = dbCluster.GetPermissions(ctx, tx.Tx())
if err != nil {
return err
}

// allGroupPermissions is a de-duplicated slice of permissions.
var allGroupPermissions []dbCluster.Permission
for _, groupPermissions := range groupsPermissions {
for _, permission := range groupPermissions {
if !shared.ValueInSlice(permission, allGroupPermissions) {
allGroupPermissions = append(allGroupPermissions, permission)
// Get the EntityURLs for the permissions. Ignore any dangling permissions.
var danglingPermissions []dbCluster.Permission
authGroupPermissions, danglingPermissions, entityURLs, err = dbCluster.GetPermissionEntityURLs(ctx, tx.Tx(), authGroupPermissions)
if err != nil {
return err
}

if len(danglingPermissions) > 0 {
permissionIDs := make([]int, 0, len(danglingPermissions))
entityTypes := make([]dbCluster.EntityType, 0, len(danglingPermissions))
for _, perm := range danglingPermissions {
permissionIDs = append(permissionIDs, perm.ID)
if !shared.ValueInSlice(perm.EntityType, entityTypes) {
entityTypes = append(entityTypes, perm.EntityType)
}
}
}

// EntityURLs is a map of entity type, to entity ID, to api.URL.
entityURLs, err = dbCluster.GetPermissionEntityURLs(ctx, tx.Tx(), allGroupPermissions)
if err != nil {
return err
logger.Warn("Encountered dangling permissions", logger.Ctx{"permission_ids": permissionIDs, "entity_types": entityTypes})
}
}

Expand All @@ -233,29 +238,23 @@ func getAuthGroups(d *Daemon, r *http.Request) response.Response {
}

if recursion == "1" {
authGroupPermissionsByGroupID := make(map[int][]dbCluster.Permission, len(groups))
for _, permission := range authGroupPermissions {
authGroupPermissionsByGroupID[permission.GroupID] = append(authGroupPermissionsByGroupID[permission.GroupID], permission)
}

apiGroups := make([]api.AuthGroup, 0, len(groups))
for _, group := range groups {
var apiPermissions []api.Permission

// The group may not have any permissions.
permissions, ok := groupsPermissions[group.ID]
permissions, ok := authGroupPermissionsByGroupID[group.ID]
if ok {
apiPermissions = make([]api.Permission, 0, len(permissions))
for _, permission := range permissions {
// Expect to find any permissions in the entity URL map by its entity type and entity ID.
entityIDToURL, ok := entityURLs[entity.Type(permission.EntityType)]
if !ok {
return response.InternalError(fmt.Errorf("Entity URLs missing for permissions with entity type %q", permission.EntityType))
}

apiURL, ok := entityIDToURL[permission.EntityID]
if !ok {
return response.InternalError(fmt.Errorf("Entity URL missing for permission with entity type %q and entity ID `%d`", permission.EntityType, permission.EntityID))
}

apiPermissions = append(apiPermissions, api.Permission{
EntityType: string(permission.EntityType),
EntityReference: apiURL.String(),
EntityReference: entityURLs[entity.Type(permission.EntityType)][permission.EntityID].String(),
Entitlement: string(permission.Entitlement),
})
}
Expand Down Expand Up @@ -357,12 +356,7 @@ func createAuthGroup(d *Daemon, r *http.Request) response.Response {
return err
}

permissionIDs, err := upsertPermissions(ctx, tx.Tx(), group.Permissions)
if err != nil {
return err
}

err = dbCluster.SetAuthGroupPermissions(ctx, tx.Tx(), int(groupID), permissionIDs)
err = upsertPermissions(ctx, tx.Tx(), int(groupID), group.Permissions)
if err != nil {
return err
}
Expand Down Expand Up @@ -515,12 +509,7 @@ func updateAuthGroup(d *Daemon, r *http.Request) response.Response {
return err
}

permissionIDs, err := upsertPermissions(ctx, tx.Tx(), groupPut.Permissions)
if err != nil {
return err
}

err = dbCluster.SetAuthGroupPermissions(ctx, tx.Tx(), group.ID, permissionIDs)
err = upsertPermissions(ctx, tx.Tx(), group.ID, groupPut.Permissions)
if err != nil {
return err
}
Expand Down Expand Up @@ -618,12 +607,7 @@ func patchAuthGroup(d *Daemon, r *http.Request) response.Response {
}
}

permissionIDs, err := upsertPermissions(ctx, tx.Tx(), newPermissions)
if err != nil {
return err
}

err = dbCluster.SetAuthGroupPermissions(ctx, tx.Tx(), group.ID, permissionIDs)
err = upsertPermissions(ctx, tx.Tx(), group.ID, newPermissions)
if err != nil {
return err
}
Expand Down Expand Up @@ -825,16 +809,15 @@ func validatePermissions(permissions []api.Permission) error {
return nil
}

// upsertPermissions resolves the URLs of each permission to an entity ID and checks if the permission already
// exists (it may be assigned to another group already). If the permission does not already exist, it is created.
// A slice of permission IDs is returned that can be used to associate these permissions to a group.
func upsertPermissions(ctx context.Context, tx *sql.Tx, permissions []api.Permission) ([]int, error) {
// upsertPermissions converts the given slice of api.Permission into a slice of cluster.Permission by resolving
// the URLs of each permission to an entity ID. Then sets those permissions against the group with the given ID.
func upsertPermissions(ctx context.Context, tx *sql.Tx, groupID int, permissions []api.Permission) error {
entityReferences := make(map[*api.URL]*dbCluster.EntityRef, len(permissions))
permissionToURL := make(map[api.Permission]*api.URL, len(permissions))
for _, permission := range permissions {
u, err := url.Parse(permission.EntityReference)
if err != nil {
return nil, fmt.Errorf("Failed to parse permission entity reference: %w", err)
return fmt.Errorf("Failed to parse permission entity reference: %w", err)
}

apiURL := &api.URL{URL: *u}
Expand All @@ -844,40 +827,30 @@ func upsertPermissions(ctx context.Context, tx *sql.Tx, permissions []api.Permis

err := dbCluster.PopulateEntityReferencesFromURLs(ctx, tx, entityReferences)
if err != nil {
return nil, err
return err
}

var permissionIDs []int
authGroupPermissions := make([]dbCluster.Permission, 0, len(permissions))
for permission, apiURL := range permissionToURL {
entitlement := auth.Entitlement(permission.Entitlement)
entityType := dbCluster.EntityType(permission.EntityType)
entityRef, ok := entityReferences[apiURL]
if !ok {
return nil, fmt.Errorf("Missing entity ID for permission with URL %q", permission.EntityReference)
}

// Get the permission, if one is found, append its ID to the slice.
existingPermission, err := dbCluster.GetPermission(ctx, tx, entitlement, entityType, entityRef.EntityID)
if err == nil {
permissionIDs = append(permissionIDs, existingPermission.ID)
continue
} else if !api.StatusErrorCheck(err, http.StatusNotFound) {
return nil, fmt.Errorf("Failed to check if permission with entitlement %q and URL %q already exists: %w", entitlement, permission.EntityReference, err)
return api.StatusErrorf(http.StatusBadRequest, "Missing entity ID for permission with URL %q", permission.EntityReference)
}

// Generated "create" methods call cluster.GetPermission again to check if it exists. We already know that it doesn't exist, so create it directly.
res, err := tx.ExecContext(ctx, `INSERT INTO permissions (entitlement, entity_type, entity_id) VALUES (?, ?, ?)`, entitlement, entityType, entityRef.EntityID)
if err != nil {
return nil, fmt.Errorf("Failed to insert new permission: %w", err)
}

lastInsertID, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("Failed to get last insert ID of new permission: %w", err)
}
authGroupPermissions = append(authGroupPermissions, dbCluster.Permission{
GroupID: groupID,
Entitlement: entitlement,
EntityType: entityType,
EntityID: entityRef.EntityID,
})
}

permissionIDs = append(permissionIDs, int(lastInsertID))
err = dbCluster.SetAuthGroupPermissions(ctx, tx, groupID, authGroupPermissions)
if err != nil {
return fmt.Errorf("Failed to set group permissions: %w", err)
}

return permissionIDs, nil
return nil
}
50 changes: 28 additions & 22 deletions lxd/db/cluster/auth_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"fmt"

"github.com/canonical/lxd/lxd/db/query"
"github.com/canonical/lxd/shared"
"github.com/canonical/lxd/shared/api"
"github.com/canonical/lxd/shared/entity"
"github.com/canonical/lxd/shared/logger"
)

// Code generation directives.
Expand Down Expand Up @@ -60,11 +62,24 @@ func (g *AuthGroup) ToAPI(ctx context.Context, tx *sql.Tx) (*api.AuthGroup, erro
return nil, err
}

entityURLs, err := GetPermissionEntityURLs(ctx, tx, permissions)
permissions, danglingPermissions, entityURLs, err := GetPermissionEntityURLs(ctx, tx, permissions)
if err != nil {
return nil, err
}

if len(danglingPermissions) > 0 {
permissionIDs := make([]int, 0, len(danglingPermissions))
entityTypes := make([]EntityType, 0, len(danglingPermissions))
for _, perm := range danglingPermissions {
permissionIDs = append(permissionIDs, perm.ID)
if !shared.ValueInSlice(perm.EntityType, entityTypes) {
entityTypes = append(entityTypes, perm.EntityType)
}
}

logger.Warn("Encountered dangling permissions", logger.Ctx{"permission_ids": permissionIDs, "entity_types": entityTypes})
}

apiPermissions := make([]api.Permission, 0, len(permissions))
for _, p := range permissions {
entityURLs, ok := entityURLs[entity.Type(p.EntityType)]
Expand Down Expand Up @@ -231,15 +246,10 @@ JOIN auth_groups_identity_provider_groups ON identity_provider_groups.id = auth_

// GetPermissionsByAuthGroupID returns the permissions that belong to the group with the given ID.
func GetPermissionsByAuthGroupID(ctx context.Context, tx *sql.Tx, groupID int) ([]Permission, error) {
stmt := fmt.Sprintf(`
SELECT %s FROM permissions
JOIN auth_groups_permissions ON permissions.id = auth_groups_permissions.permission_id
WHERE auth_groups_permissions.auth_group_id = ?`, permissionColumns())

var result []Permission
dest := func(scan func(dest ...any) error) error {
p := Permission{}
err := scan(&p.ID, &p.Entitlement, &p.EntityType, &p.EntityID)
err := scan(&p.ID, &p.GroupID, &p.Entitlement, &p.EntityType, &p.EntityID)
if err != nil {
return err
}
Expand All @@ -248,31 +258,27 @@ WHERE auth_groups_permissions.auth_group_id = ?`, permissionColumns())
return nil
}

err := query.Scan(ctx, tx, stmt, dest, groupID)
err := query.Scan(ctx, tx, `SELECT id, auth_group_id, entitlement, entity_type, entity_id FROM auth_groups_permissions WHERE auth_group_id = ?`, dest, groupID)
if err != nil {
return nil, fmt.Errorf("Failed to get permissions for the group with ID `%d`: %w", groupID, err)
}

return result, nil
}

// GetAllPermissionsByAuthGroupIDs returns a map of group ID to the permissions that belong to the auth group with that ID.
func GetAllPermissionsByAuthGroupIDs(ctx context.Context, tx *sql.Tx) (map[int][]Permission, error) {
stmt := fmt.Sprintf(`
SELECT auth_groups_permissions.auth_group_id, %s
FROM permissions
JOIN auth_groups_permissions ON permissions.id = auth_groups_permissions.permission_id`, permissionColumns())
// GetPermissions returns a map of group ID to the permissions that belong to the auth group with that ID.
func GetPermissions(ctx context.Context, tx *sql.Tx) ([]Permission, error) {
stmt := `SELECT id, auth_group_id, entitlement, entity_type, entity_id FROM auth_groups_permissions`

result := make(map[int][]Permission)
var result []Permission
dest := func(scan func(dest ...any) error) error {
var groupID int
p := Permission{}
err := scan(&groupID, &p.ID, &p.Entitlement, &p.EntityType, &p.EntityID)
err := scan(&p.ID, &p.GroupID, &p.Entitlement, &p.EntityType, &p.EntityID)
if err != nil {
return err
}

result[groupID] = append(result[groupID], p)
result = append(result, p)
return nil
}

Expand All @@ -286,18 +292,18 @@ JOIN auth_groups_permissions ON permissions.id = auth_groups_permissions.permiss

// SetAuthGroupPermissions deletes all auth_group -> permission mappings from the `auth_group_permissions` table
// where the group ID is equal to the given value. Then it inserts a new row for each given permission ID.
func SetAuthGroupPermissions(ctx context.Context, tx *sql.Tx, groupID int, permissionIDs []int) error {
func SetAuthGroupPermissions(ctx context.Context, tx *sql.Tx, groupID int, authGroupPermissions []Permission) error {
_, err := tx.ExecContext(ctx, `DELETE FROM auth_groups_permissions WHERE auth_group_id = ?`, groupID)
if err != nil {
return fmt.Errorf("Failed to delete existing permissions for group with ID `%d`: %w", groupID, err)
}

if len(permissionIDs) == 0 {
if len(authGroupPermissions) == 0 {
return nil
}

for _, permissionID := range permissionIDs {
_, err := tx.ExecContext(ctx, `INSERT INTO auth_groups_permissions (auth_group_id, permission_id) VALUES (?, ?);`, groupID, permissionID)
for _, permission := range authGroupPermissions {
_, err := tx.ExecContext(ctx, `INSERT INTO auth_groups_permissions (auth_group_id, entity_type, entity_id, entitlement) VALUES (?, ?, ?, ?);`, permission.GroupID, permission.EntityType, permission.EntityID, permission.Entitlement)
if err != nil {
return fmt.Errorf("Failed to write group permissions: %w", err)
}
Expand Down
Loading

0 comments on commit 6c6d117

Please sign in to comment.