diff --git a/benefits/core/context_processors.py b/benefits/core/context_processors.py index 47b65e257..db602698d 100644 --- a/benefits/core/context_processors.py +++ b/benefits/core/context_processors.py @@ -26,7 +26,10 @@ def _agency_context(agency: models.TransitAgency): def agency(request): """Context processor adds some information about the active agency to the request context.""" - agency = session.agency(request) + if not hasattr(request, "user") or not request.user.is_authenticated: + agency = session.agency(request) + else: + agency = models.TransitAgency.for_user(request.user) if agency is None: return {} diff --git a/tests/pytest/core/test_context_processors.py b/tests/pytest/core/test_context_processors.py index 1733146ce..2563b3cfc 100644 --- a/tests/pytest/core/test_context_processors.py +++ b/tests/pytest/core/test_context_processors.py @@ -1,8 +1,9 @@ from datetime import datetime, timedelta, timezone +from django.contrib.auth.models import Group, User import pytest -from benefits.core import session -from benefits.core.context_processors import unique_values, enrollment +from benefits.core import session, models +from benefits.core.context_processors import unique_values, enrollment, agency def test_unique_values(): @@ -39,3 +40,48 @@ def test_enrollment_expiration(app_request, model_EligibilityType_supports_expir context = enrollment(app_request) assert context["enrollment"] == {"expires": expiry, "reenrollment": reenrollment, "supports_expiration": True} + + +@pytest.mark.django_db +def test_agency_unauthenticated_user(app_request, model_TransitAgency): + session.update(app_request, agency=model_TransitAgency) + + context = agency(app_request) + + assert context["agency"] == { + "eligibility_index_url": model_TransitAgency.eligibility_index_url, + "help_templates": unique_values( + [f.help_template for f in model_TransitAgency.enrollment_flows.all() if f.help_template] + ), + "info_url": model_TransitAgency.info_url, + "long_name": model_TransitAgency.long_name, + "phone": model_TransitAgency.phone, + "short_name": model_TransitAgency.short_name, + "slug": model_TransitAgency.slug, + } + + +@pytest.mark.django_db +def test_agency_authenticated_user(app_request, model_TransitAgency): + group = Group.objects.create(name="test_group") + + agency_for_user = models.TransitAgency.by_id(model_TransitAgency.id) + agency_for_user.pk = None + agency_for_user.group = group + agency_for_user.save() + + user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) + user.groups.add(group) + + app_request.user = user + context = agency(app_request) + + assert context["agency"] == { + "eligibility_index_url": agency_for_user.eligibility_index_url, + "help_templates": unique_values([f.help_template for f in agency_for_user.enrollment_flows.all() if f.help_template]), + "info_url": agency_for_user.info_url, + "long_name": agency_for_user.long_name, + "phone": agency_for_user.phone, + "short_name": agency_for_user.short_name, + "slug": agency_for_user.slug, + }