Skip to content

Commit

Permalink
add management command to migrate preprint affiliations
Browse files Browse the repository at this point in the history
  • Loading branch information
John Tordoff committed Oct 30, 2024
1 parent 9394a0f commit bd7b2f3
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
76 changes: 76 additions & 0 deletions osf/management/commands/migrate_preprint_affiliation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import datetime
import logging

from django.core.management.base import BaseCommand
from osf.models import Preprint

logger = logging.getLogger(__name__)


class Command(BaseCommand):
"""Assign affiliations from preprint creators, with optional exclusion by user GUIDs.
"""

def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
'--exclude-guids',
nargs='+',
dest='exclude_guids',
help='List of user GUIDs to exclude from affiliation assignment'
)
parser.add_argument(
'--dry',
action='store_true',
dest='dry_run',
help='If true, iterates through preprints without making changes'
)

def handle(self, *args, **options):
start_time = datetime.datetime.now()
logger.info(f'Script started at: {start_time}')

exclude_guids = set(options.get('exclude_guids', []))
dry_run = options.get('dry_run', False)

if dry_run:
logger.info('Dry Run mode activated')

processed_count, updated_count, skipped_count = assign_creator_affiliations_to_preprints(
exclude_guids=exclude_guids, dry_run=dry_run)

finish_time = datetime.datetime.now()
logger.info(f'Script finished at: {finish_time}')
logger.info(f'Total processed: {processed_count}, Updated: {updated_count}, Skipped: {skipped_count}')
logger.info(f'Total run time: {finish_time - start_time}')


def assign_creator_affiliations_to_preprints(exclude_guids=None, dry_run=True):
exclude_guids = exclude_guids or set()
preprints = Preprint.objects.select_related('creator').all()

processed_count = updated_count = skipped_count = 0

for preprint in preprints:
processed_count += 1
creator = preprint.creator

if not creator:
skipped_count += 1
continue

if creator._id in exclude_guids or not creator.affiliated_institutions.exists():
skipped_count += 1
continue

if not dry_run:
affiliations = [
preprint.affiliated_institutions.get_or_create(institution=inst)[1]
for inst in creator.affiliated_institutions.all()
]
updated_count += sum(affiliations)
else:
logger.info(f'Dry Run: Would assign {creator.affiliated_institutions.count()} affiliations '
f'to preprint <{preprint._id}>')

return processed_count, updated_count, skipped_count
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from osf.management.commands.assign_creator_affiliations_to_preprints import assign_creator_affiliations_to_preprints
from osf.models import Preprint, Institution, OSFUser
from osf_tests.factories import PreprintFactory, InstitutionFactory, AuthUserFactory

@pytest.mark.django_db
class TestAssignCreatorAffiliationsToPreprints:

@pytest.fixture()
def institution(self):
return InstitutionFactory()

@pytest.fixture()
def user_with_affiliation(self, institution):
user = AuthUserFactory()
user.affiliated_institutions.add(institution)
user.save()
return user

@pytest.fixture()
def user_without_affiliation(self):
return AuthUserFactory()

@pytest.fixture()
def preprint_with_affiliated_creator(self, user_with_affiliation):
return PreprintFactory(creator=user_with_affiliation)

@pytest.fixture()
def preprint_with_non_affiliated_creator(self, user_without_affiliation):
return PreprintFactory(creator=user_without_affiliation)

@pytest.mark.parametrize("dry_run", [True, False])
def test_assign_affiliations_with_affiliated_creator(self, preprint_with_affiliated_creator, institution, dry_run):
assert preprint_with_affiliated_creator.affiliated_institutions.count() == 0

assign_creator_affiliations_to_preprints(dry_run=dry_run)

if dry_run:
assert preprint_with_affiliated_creator.affiliated_institutions.count() == 0
else:
assert institution in preprint_with_affiliated_creator.affiliated_institutions.all()

@pytest.mark.parametrize("dry_run", [True, False])
def test_no_affiliations_for_non_affiliated_creator(self, preprint_with_non_affiliated_creator, dry_run):
assign_creator_affiliations_to_preprints(dry_run=dry_run)
assert preprint_with_non_affiliated_creator.affiliated_institutions.count() == 0

@pytest.mark.parametrize("dry_run", [True, False])
def test_exclude_creator_by_guid(self, preprint_with_affiliated_creator, institution, dry_run):
exclude_guid = preprint_with_affiliated_creator.creator._id
assign_creator_affiliations_to_preprints(exclude_guids={exclude_guid}, dry_run=dry_run)

assert preprint_with_affiliated_creator.affiliated_institutions.count() == 0

0 comments on commit bd7b2f3

Please sign in to comment.