diff --git a/rocky/reports/templates/partials/return_button.html b/rocky/reports/templates/partials/return_button.html index ed038f61bd1..b86d6f6d58b 100644 --- a/rocky/reports/templates/partials/return_button.html +++ b/rocky/reports/templates/partials/return_button.html @@ -7,7 +7,7 @@ {% csrf_token %} {% include "forms/report_form_fields.html" %} - diff --git a/rocky/reports/views/aggregate_report.py b/rocky/reports/views/aggregate_report.py index 623905542ba..51d0f9da8b7 100644 --- a/rocky/reports/views/aggregate_report.py +++ b/rocky/reports/views/aggregate_report.py @@ -7,6 +7,7 @@ from django.urls import reverse from django.utils.http import urlencode from django.utils.translation import gettext_lazy as _ +from tools.view_helpers import PostRedirect from reports.report_types.aggregate_organisation_report.report import AggregateOrganisationReport from reports.report_types.definitions import AggregateReport, MultiReport, Report @@ -118,7 +119,7 @@ def setup(self, request, *args, **kwargs): def post(self, request, *args, **kwargs): if not self.selected_oois: messages.error(request, self.NONE_OOI_SELECTION_MESSAGE) - return redirect(self.get_previous()) + return PostRedirect(self.get_previous()) return self.get(request, *args, **kwargs) def get_report_types_for_aggregate_report( @@ -151,9 +152,17 @@ class SetupScanAggregateReportView( current_step = 3 def post(self, request, *args, **kwargs): + # If the user wants to change selection, but all plugins are enabled, it needs to go even further back if not self.selected_report_types: messages.error(request, self.NONE_REPORT_TYPE_SELECTION_MESSAGE) - return redirect(self.get_previous()) + return PostRedirect(self.get_previous()) + + if "return" in self.request.POST and self.plugins_enabled(): + return PostRedirect(self.get_previous()) + + if self.plugins_enabled(): + return PostRedirect(self.get_next()) + return self.get(request, *args, **kwargs) @@ -167,6 +176,9 @@ class ExportSetupAggregateReportView(AggregateReportStepsMixin, BreadcrumbsAggre current_step = 4 def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: + if not self.selected_report_types: + messages.error(request, self.NONE_REPORT_TYPE_SELECTION_MESSAGE) + return PostRedirect(self.get_previous()) return super().get(request, *args, **kwargs) def get_context_data(self, **kwargs): diff --git a/rocky/reports/views/generate_report.py b/rocky/reports/views/generate_report.py index cfd37c4ad05..a21e2557324 100644 --- a/rocky/reports/views/generate_report.py +++ b/rocky/reports/views/generate_report.py @@ -8,6 +8,7 @@ from django.urls import reverse from django.utils.http import urlencode from django.utils.translation import gettext_lazy as _ +from tools.view_helpers import PostRedirect from octopoes.models import Reference from reports.report_types.helpers import get_ooi_types_with_report, get_report_types_for_oois @@ -108,7 +109,7 @@ class ReportTypesSelectionGenerateReportView( def post(self, request, *args, **kwargs): if not self.selected_oois: messages.error(request, self.NONE_OOI_SELECTION_MESSAGE) - return redirect(self.get_previous()) + return PostRedirect(self.get_previous()) return self.get(request, *args, **kwargs) def get_context_data(self, **kwargs): @@ -135,7 +136,13 @@ class SetupScanGenerateReportView( def post(self, request, *args, **kwargs): if not self.selected_report_types: messages.error(request, self.NONE_REPORT_TYPE_SELECTION_MESSAGE) - return redirect(self.get_previous()) + return PostRedirect(self.get_previous()) + + if "return" in self.request.POST and self.plugins_enabled(): + return PostRedirect(self.get_previous()) + + if self.plugins_enabled(): + return PostRedirect(self.get_next()) return self.get(request, *args, **kwargs) @@ -150,6 +157,9 @@ class ExportSetupGenerateReportView(GenerateReportStepsMixin, BreadcrumbsGenerat reports: dict[str, str] = {} def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: + if not self.selected_report_types: + messages.error(request, self.NONE_REPORT_TYPE_SELECTION_MESSAGE) + return PostRedirect(self.get_previous()) self.reports = create_report_names(self.oois_pk, self.report_types) return super().get(request, *args, **kwargs) diff --git a/rocky/tests/conftest.py b/rocky/tests/conftest.py index a00690c900e..13b5c33b230 100644 --- a/rocky/tests/conftest.py +++ b/rocky/tests/conftest.py @@ -1785,3 +1785,24 @@ def boefje_dns_records(): runnable_hash=None, produces={"boefje/dns-records"}, ) + + +@pytest.fixture +def boefje_nmap_tcp(): + return Boefje( + id="nmap", + name="Nmap TCP", + version=None, + authors=None, + created=None, + description="Defaults to top 250 TCP ports. Includes service detection.", + environment_keys=None, + related=[], + enabled=True, + type="boefje", + scan_level=SCAN_LEVEL.L2, + consumes={IPAddressV4, IPAddressV6}, + options=None, + runnable_hash=None, + produces={"boefje/nmap"}, + ) diff --git a/rocky/tests/reports/test_aggregate_report_flow.py b/rocky/tests/reports/test_aggregate_report_flow.py index f3c6610cbf8..fca4d053401 100644 --- a/rocky/tests/reports/test_aggregate_report_flow.py +++ b/rocky/tests/reports/test_aggregate_report_flow.py @@ -177,7 +177,7 @@ def test_report_types_selection_nothing_selected( response = SetupScanAggregateReportView.as_view()(request, organization_code=client_member.organization.code) - assert response.status_code == 302 + assert response.status_code == 307 assert list(request._messages)[0].message == "Select at least one report type to proceed." @@ -189,6 +189,7 @@ def test_report_types_selection( listed_hostnames, mocker, boefje_dns_records, + boefje_nmap_tcp, rocky_health, mock_bytes_client, ): @@ -197,7 +198,7 @@ def test_report_types_selection( """ katalogus_mocker = mocker.patch("reports.views.base.get_katalogus")() - katalogus_mocker.get_plugins.return_value = [boefje_dns_records] + katalogus_mocker.get_plugins.return_value = [boefje_dns_records, boefje_nmap_tcp] rocky_health_mocker = mocker.patch("reports.report_types.aggregate_organisation_report.report.get_rocky_health")() rocky_health_mocker.return_value = rocky_health @@ -211,16 +212,17 @@ def test_report_types_selection( request = setup_request( rf.post( "aggregate_report_setup_scan", - {"observed_at": valid_time.strftime("%Y-%m-%d"), "report_type": "dns-report"}, + {"observed_at": valid_time.strftime("%Y-%m-%d"), "report_type": ["dns-report", "systems-report"]}, ), client_member.user, ) response = SetupScanAggregateReportView.as_view()(request, organization_code=client_member.organization.code) - assert response.status_code == 200 # if all plugins are enabled the view will auto redirect to generate report + assert response.status_code == 307 # if all plugins are enabled the view will auto redirect to generate report - assertContains(response, '', html=True) + # Redirect to export setup + assert response.headers["Location"] == "/en/test/reports/aggregate-report/export-setup/?" def test_save_aggregate_report_view( diff --git a/rocky/tests/reports/test_generate_report_flow.py b/rocky/tests/reports/test_generate_report_flow.py index 1271a69bff7..2cae8e51990 100644 --- a/rocky/tests/reports/test_generate_report_flow.py +++ b/rocky/tests/reports/test_generate_report_flow.py @@ -177,7 +177,8 @@ def test_report_types_selection_nothing_selected( response = SetupScanGenerateReportView.as_view()(request, organization_code=client_member.organization.code) - assert response.status_code == 302 + assert response.status_code == 307 + assert list(request._messages)[0].message == "Select at least one report type to proceed." @@ -214,8 +215,10 @@ def test_report_types_selection( response = SetupScanGenerateReportView.as_view()(request, organization_code=client_member.organization.code) - assert response.status_code == 200 - assertContains(response, '', html=True) + assert response.status_code == 307 + + # Redirect to export setup, all plugins are then enabled + assert response.headers["Location"] == "/en/test/reports/generate-report/export-setup/?" def test_save_generate_report_view( diff --git a/rocky/tools/view_helpers.py b/rocky/tools/view_helpers.py index e216d944334..9f7008662ed 100644 --- a/rocky/tools/view_helpers.py +++ b/rocky/tools/view_helpers.py @@ -4,6 +4,7 @@ from urllib.parse import urlencode, urlparse, urlunparse from django.http import HttpRequest +from django.http.response import HttpResponseRedirectBase from django.urls.base import reverse, reverse_lazy from django.utils.translation import gettext_lazy as _ @@ -163,3 +164,7 @@ def build_breadcrumbs(self): "text": _("Objects"), } ] + + +class PostRedirect(HttpResponseRedirectBase): + status_code = 307