From 7e82b04e80ad248861bc603f6c0c8d4b75f8a06e Mon Sep 17 00:00:00 2001 From: John Tordoff <> Date: Thu, 24 Oct 2024 08:01:55 -0400 Subject: [PATCH] incorporate new renderer improvements into old views and make file filename more generic --- api/base/settings/defaults.py | 2 +- api/metrics/renderers.py | 34 +++++----- api/metrics/views.py | 64 ++++++++++++------- .../test_institution_user_metric_list.py | 13 ++-- 4 files changed, 62 insertions(+), 51 deletions(-) diff --git a/api/base/settings/defaults.py b/api/base/settings/defaults.py index a6ecb6b8e3f..7bbccb181ec 100644 --- a/api/base/settings/defaults.py +++ b/api/base/settings/defaults.py @@ -359,7 +359,7 @@ MAX_SIZE_OF_ES_QUERY = 10000 DEFAULT_ES_NULL_VALUE = 'N/A' -USER_INSTITUTION_REPORT_FILENAME = 'institution_user_report_{institution_id}_{date_created}.{format_type}' +REPORT_FILENAME_FORMAT = 'osf_report_{date_created}.{format_type}' CI_ENV = False diff --git a/api/metrics/renderers.py b/api/metrics/renderers.py index e99e5705d10..01599dda4dd 100644 --- a/api/metrics/renderers.py +++ b/api/metrics/renderers.py @@ -1,7 +1,7 @@ import csv import io import json -from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME, MAX_SIZE_OF_ES_QUERY +from api.base.settings.defaults import REPORT_FILENAME_FORMAT, MAX_SIZE_OF_ES_QUERY import datetime from django.http import Http404 @@ -19,10 +19,6 @@ def csv_fieldname_sortkey(fieldname): def get_nested_keys(report_attrs): - """ - Recursively retrieves all nested keys from the report attributes. - Handles both dictionaries and lists of attributes. - """ if isinstance(report_attrs, dict): for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey): attr_value = report_attrs[attr_key] @@ -55,47 +51,44 @@ def get_csv_row(keys_list, report_attrs): class MetricsReportsBaseRenderer(renderers.BaseRenderer): + """ + This renderer should override the format parameter to send a Content-Disposition attachment of the file data via + the browser. + """ media_type: str format: str CSV_DIALECT: csv.Dialect extension: str def get_filename(self, renderer_context: dict, format_type: str) -> str: - """Generate the filename for the CSV/TSV file based on institution and current date.""" + """Generate the filename for the file based on format_type REPORT_FILENAME_FORMAT and current date.""" if renderer_context and 'view' in renderer_context: - current_date = datetime.datetime.now().strftime('%Y-%m') # Format as 'YYYY-MM' - return USER_INSTITUTION_REPORT_FILENAME.format( + current_date = datetime.datetime.now().strftime('%Y-%m') + return REPORT_FILENAME_FORMAT.format( date_created=current_date, - institution_id=renderer_context['view'].kwargs['institution_id'], format_type=format_type, ) else: raise NotImplementedError('Missing format filename') - def get_all_data(self, view, request): - """Bypass pagination by fetching all the data.""" - view.pagination_class = None # Disable pagination - return view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute() - def render(self, data: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str: """Render the full dataset as CSV or TSV format.""" - data = self.get_all_data(renderer_context['view'], renderer_context['request']) + view = renderer_context['view'] + view.pagination_class = None # Disable pagination + data = view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute() hits = data.hits if not hits: raise Http404('

none found

') - # Assuming each hit contains '_source' with the relevant data first_row = hits[0].to_dict() csv_fieldnames = list(first_row) csv_filecontent = io.StringIO(newline='') csv_writer = csv.writer(csv_filecontent, dialect=self.CSV_DIALECT) csv_writer.writerow(csv_fieldnames) - # Write each hit's '_source' as a row in the CSV for hit in hits: csv_writer.writerow(get_csv_row(csv_fieldnames, hit.to_dict())) - # Set response headers for file download response = renderer_context['response'] filename = self.get_filename(renderer_context, self.extension) response['Content-Disposition'] = f'attachment; filename="{filename}"' @@ -130,7 +123,10 @@ def default_serializer(self, obj): def render(self, data, accepted_media_type=None, renderer_context=None): """Render the response as JSON format and trigger browser download as a binary file.""" - data = self.get_all_data(renderer_context['view'], renderer_context['request']) + view = renderer_context['view'] + view.pagination_class = None # Disable pagination + data = view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute() + hits = data.hits if not hits: raise Http404('

none found

') diff --git a/api/metrics/views.py b/api/metrics/views.py index 51556ddc89c..60a4835e98f 100644 --- a/api/metrics/views.py +++ b/api/metrics/views.py @@ -320,51 +320,67 @@ class RecentReportList(JSONAPIBaseView): MetricsReportsTsvRenderer, ) - def get(self, request, *args, report_name): + def get_default_search(self): try: - report_class = VIEWABLE_REPORTS[report_name] + report_class = VIEWABLE_REPORTS[self.kwargs['report_name']] except KeyError: - return Response( - { - 'errors': [{ - 'title': 'unknown report name', - 'detail': f'unknown report: "{report_name}"', - }], - }, - status=404, - ) + return None + is_daily = issubclass(report_class, reports.DailyReport) - days_back = request.GET.get('days_back', self.DEFAULT_DAYS_BACK if is_daily else None) is_monthly = issubclass(report_class, reports.MonthlyReport) + request = self.get_renderer_context()['request'] + days_back = request.GET.get('days_back', self.DEFAULT_DAYS_BACK if is_daily else None) + if is_daily: - serializer_class = DailyReportSerializer range_field_name = 'report_date' elif is_monthly: - serializer_class = MonthlyReportSerializer range_field_name = 'report_yearmonth' else: raise ValueError(f'report class must subclass DailyReport or MonthlyReport: {report_class}') + range_filter = parse_date_range(request.GET, is_monthly=is_monthly) - search_recent = ( - report_class.search() - .filter('range', **{range_field_name: range_filter}) - .sort(range_field_name) - [:self.MAX_COUNT] + search_recent = report_class.search().filter( + 'range', + **{range_field_name: range_filter}, + ).sort(range_field_name)[:self.MAX_COUNT] + + if is_daily and days_back: + search_recent = search_recent.filter('range', report_date={'gte': f'now/d-{days_back}d'}) + + return search_recent + + def get(self, request, *args, report_name): + search_response = self.get_default_search() + + if search_response is None: + return Response( + { + 'errors': [{ + 'title': 'unknown report name', + 'detail': f'unknown report: "{report_name}"', + }], + }, + status=404, + ) + + report_class = VIEWABLE_REPORTS[report_name] + serializer_class = ( + DailyReportSerializer if issubclass(report_class, reports.DailyReport) + else MonthlyReportSerializer ) - if days_back: - search_recent.filter('range', report_date={'gte': f'now/d-{days_back}d'}) - report_date_range = parse_date_range(request.GET) - search_response = search_recent.execute() serializer = serializer_class( search_response, many=True, context={'report_name': report_name}, ) + accepted_format = request.accepted_renderer.format response_headers = {} + if accepted_format in ('tsv', 'csv'): + report_date_range = parse_date_range(request.GET) from_date = report_date_range['gte'] until_date = report_date_range['lte'] filename = ( @@ -373,6 +389,8 @@ def get(self, request, *args, report_name): f'from_{from_date}.{accepted_format}' ) response_headers['Content-Disposition'] = f'attachment; filename={filename}' + + # Return the response with serialized data return Response( {'data': serializer.data}, headers=response_headers, diff --git a/api_tests/institutions/views/test_institution_user_metric_list.py b/api_tests/institutions/views/test_institution_user_metric_list.py index 2dcfcb7e3d0..046019e5616 100644 --- a/api_tests/institutions/views/test_institution_user_metric_list.py +++ b/api_tests/institutions/views/test_institution_user_metric_list.py @@ -8,7 +8,7 @@ import pytest from waffle.testutils import override_flag -from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, USER_INSTITUTION_REPORT_FILENAME +from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, REPORT_FILENAME_FORMAT import osf.features from osf_tests.factories import ( InstitutionFactory, @@ -435,9 +435,8 @@ def test_get_report_formats_csv_tsv(self, app, url, institutional_admin, institu assert resp.headers['Content-Type'] == content_type current_date = datetime.datetime.now().strftime('%Y-%m') - expected_filename = USER_INSTITUTION_REPORT_FILENAME.format( + expected_filename = REPORT_FILENAME_FORMAT.format( date_created=current_date, - institution_id=institution._id, format_type=format_type ) assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"' @@ -462,7 +461,7 @@ def test_get_report_formats_csv_tsv(self, app, url, institutional_admin, institu 'month_last_active', 'month_last_login', 'timestamp' - ], + ], [ '2024-08', institution._id, @@ -516,9 +515,8 @@ def test_get_report_format_json(self, app, url, institutional_admin, institution assert resp.headers['Content-Type'] == 'application/json; charset=utf-8' current_date = datetime.datetime.now().strftime('%Y-%m') - expected_filename = USER_INSTITUTION_REPORT_FILENAME.format( + expected_filename = REPORT_FILENAME_FORMAT.format( date_created=current_date, - institution_id=institution._id, format_type='json' ) assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"' @@ -602,9 +600,8 @@ def test_csv_tsv_ignores_pagination(self, app, url, institutional_admin, institu assert resp.headers['Content-Type'] == content_type current_date = datetime.datetime.now().strftime('%Y-%m') - expected_filename = USER_INSTITUTION_REPORT_FILENAME.format( + expected_filename = REPORT_FILENAME_FORMAT.format( date_created=current_date, - institution_id=institution._id, format_type=format_type ) assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'