diff --git a/benefits/core/middleware.py b/benefits/core/middleware.py index 58e12742ce..f897f7997f 100644 --- a/benefits/core/middleware.py +++ b/benefits/core/middleware.py @@ -7,12 +7,13 @@ from django.http import HttpResponse from django.shortcuts import redirect from django.template.response import TemplateResponse +from django.urls import reverse from django.utils.decorators import decorator_from_middleware from django.utils.deprecation import MiddlewareMixin from django.views import i18n from . import analytics, recaptcha, session -from views import TEMPLATE_USER_ERROR +from views import ROUTE_INDEX, TEMPLATE_USER_ERROR logger = logging.getLogger(__name__) @@ -142,3 +143,17 @@ def process_request(self, request): "site_key": settings.RECAPTCHA_SITE_KEY, } return None + + +class IndexOrAgencyIndexOrigin(MiddlewareMixin): + """Middleware sets the session.origin to either the core:index or core:agency_index depending on agency config.""" + + def process_request(self, request): + if session.active_agency(request): + session.update(request, origin=session.agency(request).index_url) + else: + session.update(request, origin=reverse(ROUTE_INDEX)) + return None + + +index_or_agencyindex_origin_decorator = decorator_from_middleware(IndexOrAgencyIndexOrigin) diff --git a/benefits/core/views.py b/benefits/core/views.py index f280243d67..039ea73fc5 100644 --- a/benefits/core/views.py +++ b/benefits/core/views.py @@ -8,7 +8,7 @@ from django.utils.translation import pgettext, gettext as _ from . import models, session, viewmodels -from .middleware import pageview_decorator +from .middleware import pageview_decorator, index_or_agencyindex_origin_decorator ROUTE_INDEX = "core:index" ROUTE_ELIGIBILITY = "eligibility:index" @@ -74,19 +74,16 @@ def help(request): @pageview_decorator +@index_or_agencyindex_origin_decorator def bad_request(request, exception, template_name="400.html"): """View handler for HTTP 400 Bad Request responses.""" - if session.active_agency(request): - session.update(request, origin=session.agency(request).index_url) - else: - session.update(request, origin=reverse(ROUTE_INDEX)) - t = loader.get_template(template_name) return HttpResponseBadRequest(t.render()) @pageview_decorator +@index_or_agencyindex_origin_decorator def csrf_failure(request, reason): """ View handler for CSRF_FAILURE_VIEW with custom data. @@ -97,26 +94,18 @@ def csrf_failure(request, reason): @pageview_decorator +@index_or_agencyindex_origin_decorator def page_not_found(request, exception, template_name="404.html"): """View handler for HTTP 404 Not Found responses.""" - if session.active_agency(request): - session.update(request, origin=session.agency(request).index_url) - else: - session.update(request, origin=reverse(ROUTE_INDEX)) - t = loader.get_template(template_name) return HttpResponseNotFound(t.render()) @pageview_decorator +@index_or_agencyindex_origin_decorator def server_error(request, template_name="500.html"): """View handler for HTTP 500 Server Error responses.""" - if session.active_agency(request): - session.update(request, origin=session.agency(request).index_url) - else: - session.update(request, origin=reverse(ROUTE_INDEX)) - t = loader.get_template(template_name) return HttpResponseServerError(t.render())