Skip to content

Commit

Permalink
support named groups in SAML claim (#6156)
Browse files Browse the repository at this point in the history
* support named groups in SAML claim

* support group names in addition to group IDs
* missing groups are automatically created
* comma-separated group values are trimmed

* add str for group_id in case of raise exception

---------

Co-authored-by: r350178982 <[email protected]>
  • Loading branch information
thewilli and r350178982 authored Jun 12, 2024
1 parent d1d8d08 commit ac1c004
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions seahub/adfs_auth/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from django.conf import settings
from django.contrib.auth.backends import ModelBackend
from django.core.cache import cache

from seaserv import ccnet_api, seafile_api

Expand All @@ -34,6 +35,7 @@

SAML_PROVIDER_IDENTIFIER = getattr(settings, 'SAML_PROVIDER_IDENTIFIER', 'saml')
SHIBBOLETH_AFFILIATION_ROLE_MAP = getattr(settings, 'SHIBBOLETH_AFFILIATION_ROLE_MAP', {})
CACHE_KEY_GROUPS = "all_groups_cache"


class Saml2Backend(ModelBackend):
Expand Down Expand Up @@ -196,9 +198,27 @@ def sync_saml_groups(self, user, attributes):

# support a list of comma-separated IDs as seafile_groups claim
if len(seafile_groups) == 1 and ',' in seafile_groups[0]:
seafile_groups = seafile_groups[0].split(',')
seafile_groups = [group.strip() for group in seafile_groups[0].split(',')]

saml_group_ids = [int(group_id) for group_id in seafile_groups]
if all(str(group_id).isdigit() for group_id in seafile_groups):
# all groups are provided as numeric IDs
saml_group_ids = [int(group_id) for group_id in seafile_groups]
else:
# groups are provided as names, try to get current group information from cache
all_groups = cache.get(CACHE_KEY_GROUPS)
if not all_groups or any(group not in all_groups for group in seafile_groups):
# groups not yet cached or missing entry, reload groups from API
all_groups = {group.group_name: group.id for group in ccnet_api.get_all_groups(-1, -1)}
cache.set(CACHE_KEY_GROUPS, all_groups, 3600) # cache for 1 hour
# create groups which are not yet existing
for group in [group_name for group_name in seafile_groups if group_name not in all_groups]:
new_group = ccnet_api.create_group(group, 'system admin') # we are not operating in user context here
if new_group < 0:
logger.error('failed to create group %s' % group)
return
all_groups[group] = new_group
# generate numeric IDs from group names
saml_group_ids = [id for group, id in all_groups.items() if group in seafile_groups]

joined_groups = ccnet_api.get_groups(user.username)
joined_group_ids = [g.id for g in joined_groups]
Expand Down

0 comments on commit ac1c004

Please sign in to comment.