Skip to content

Commit

Permalink
incorporate new renderer improvements into old views and make file fi…
Browse files Browse the repository at this point in the history
…lename more generic
  • Loading branch information
John Tordoff committed Oct 24, 2024
1 parent 09370fe commit 7e82b04
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 51 deletions.
2 changes: 1 addition & 1 deletion api/base/settings/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@

MAX_SIZE_OF_ES_QUERY = 10000
DEFAULT_ES_NULL_VALUE = 'N/A'
USER_INSTITUTION_REPORT_FILENAME = 'institution_user_report_{institution_id}_{date_created}.{format_type}'
REPORT_FILENAME_FORMAT = 'osf_report_{date_created}.{format_type}'

CI_ENV = False

Expand Down
34 changes: 15 additions & 19 deletions api/metrics/renderers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
import io
import json
from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME, MAX_SIZE_OF_ES_QUERY
from api.base.settings.defaults import REPORT_FILENAME_FORMAT, MAX_SIZE_OF_ES_QUERY
import datetime

from django.http import Http404
Expand All @@ -19,10 +19,6 @@ def csv_fieldname_sortkey(fieldname):


def get_nested_keys(report_attrs):
"""
Recursively retrieves all nested keys from the report attributes.
Handles both dictionaries and lists of attributes.
"""
if isinstance(report_attrs, dict):
for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey):
attr_value = report_attrs[attr_key]
Expand Down Expand Up @@ -55,47 +51,44 @@ def get_csv_row(keys_list, report_attrs):


class MetricsReportsBaseRenderer(renderers.BaseRenderer):
"""
This renderer should override the format parameter to send a Content-Disposition attachment of the file data via
the browser.
"""
media_type: str
format: str
CSV_DIALECT: csv.Dialect
extension: str

def get_filename(self, renderer_context: dict, format_type: str) -> str:
"""Generate the filename for the CSV/TSV file based on institution and current date."""
"""Generate the filename for the file based on format_type REPORT_FILENAME_FORMAT and current date."""
if renderer_context and 'view' in renderer_context:
current_date = datetime.datetime.now().strftime('%Y-%m') # Format as 'YYYY-MM'
return USER_INSTITUTION_REPORT_FILENAME.format(
current_date = datetime.datetime.now().strftime('%Y-%m')
return REPORT_FILENAME_FORMAT.format(
date_created=current_date,
institution_id=renderer_context['view'].kwargs['institution_id'],
format_type=format_type,
)
else:
raise NotImplementedError('Missing format filename')

def get_all_data(self, view, request):
"""Bypass pagination by fetching all the data."""
view.pagination_class = None # Disable pagination
return view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute()

def render(self, data: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str:
"""Render the full dataset as CSV or TSV format."""
data = self.get_all_data(renderer_context['view'], renderer_context['request'])
view = renderer_context['view']
view.pagination_class = None # Disable pagination
data = view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute()
hits = data.hits
if not hits:
raise Http404('<h1>none found</h1>')

# Assuming each hit contains '_source' with the relevant data
first_row = hits[0].to_dict()
csv_fieldnames = list(first_row)
csv_filecontent = io.StringIO(newline='')
csv_writer = csv.writer(csv_filecontent, dialect=self.CSV_DIALECT)
csv_writer.writerow(csv_fieldnames)

# Write each hit's '_source' as a row in the CSV
for hit in hits:
csv_writer.writerow(get_csv_row(csv_fieldnames, hit.to_dict()))

# Set response headers for file download
response = renderer_context['response']
filename = self.get_filename(renderer_context, self.extension)
response['Content-Disposition'] = f'attachment; filename="{filename}"'
Expand Down Expand Up @@ -130,7 +123,10 @@ def default_serializer(self, obj):

def render(self, data, accepted_media_type=None, renderer_context=None):
"""Render the response as JSON format and trigger browser download as a binary file."""
data = self.get_all_data(renderer_context['view'], renderer_context['request'])
view = renderer_context['view']
view.pagination_class = None # Disable pagination
data = view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute()

hits = data.hits
if not hits:
raise Http404('<h1>none found</h1>')
Expand Down
64 changes: 41 additions & 23 deletions api/metrics/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,51 +320,67 @@ class RecentReportList(JSONAPIBaseView):
MetricsReportsTsvRenderer,
)

def get(self, request, *args, report_name):
def get_default_search(self):
try:
report_class = VIEWABLE_REPORTS[report_name]
report_class = VIEWABLE_REPORTS[self.kwargs['report_name']]
except KeyError:
return Response(
{
'errors': [{
'title': 'unknown report name',
'detail': f'unknown report: "{report_name}"',
}],
},
status=404,
)
return None

is_daily = issubclass(report_class, reports.DailyReport)
days_back = request.GET.get('days_back', self.DEFAULT_DAYS_BACK if is_daily else None)
is_monthly = issubclass(report_class, reports.MonthlyReport)

request = self.get_renderer_context()['request']
days_back = request.GET.get('days_back', self.DEFAULT_DAYS_BACK if is_daily else None)

if is_daily:
serializer_class = DailyReportSerializer
range_field_name = 'report_date'
elif is_monthly:
serializer_class = MonthlyReportSerializer
range_field_name = 'report_yearmonth'
else:
raise ValueError(f'report class must subclass DailyReport or MonthlyReport: {report_class}')

range_filter = parse_date_range(request.GET, is_monthly=is_monthly)
search_recent = (
report_class.search()
.filter('range', **{range_field_name: range_filter})
.sort(range_field_name)
[:self.MAX_COUNT]
search_recent = report_class.search().filter(
'range',
**{range_field_name: range_filter},
).sort(range_field_name)[:self.MAX_COUNT]

if is_daily and days_back:
search_recent = search_recent.filter('range', report_date={'gte': f'now/d-{days_back}d'})

return search_recent

def get(self, request, *args, report_name):
search_response = self.get_default_search()

if search_response is None:
return Response(
{
'errors': [{
'title': 'unknown report name',
'detail': f'unknown report: "{report_name}"',
}],
},
status=404,
)

report_class = VIEWABLE_REPORTS[report_name]
serializer_class = (
DailyReportSerializer if issubclass(report_class, reports.DailyReport)
else MonthlyReportSerializer
)
if days_back:
search_recent.filter('range', report_date={'gte': f'now/d-{days_back}d'})

report_date_range = parse_date_range(request.GET)
search_response = search_recent.execute()
serializer = serializer_class(
search_response,
many=True,
context={'report_name': report_name},
)

accepted_format = request.accepted_renderer.format
response_headers = {}

if accepted_format in ('tsv', 'csv'):
report_date_range = parse_date_range(request.GET)
from_date = report_date_range['gte']
until_date = report_date_range['lte']
filename = (
Expand All @@ -373,6 +389,8 @@ def get(self, request, *args, report_name):
f'from_{from_date}.{accepted_format}'
)
response_headers['Content-Disposition'] = f'attachment; filename={filename}'

# Return the response with serialized data
return Response(
{'data': serializer.data},
headers=response_headers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from waffle.testutils import override_flag

from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, USER_INSTITUTION_REPORT_FILENAME
from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, REPORT_FILENAME_FORMAT
import osf.features
from osf_tests.factories import (
InstitutionFactory,
Expand Down Expand Up @@ -435,9 +435,8 @@ def test_get_report_formats_csv_tsv(self, app, url, institutional_admin, institu
assert resp.headers['Content-Type'] == content_type

current_date = datetime.datetime.now().strftime('%Y-%m')
expected_filename = USER_INSTITUTION_REPORT_FILENAME.format(
expected_filename = REPORT_FILENAME_FORMAT.format(
date_created=current_date,
institution_id=institution._id,
format_type=format_type
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'
Expand All @@ -462,7 +461,7 @@ def test_get_report_formats_csv_tsv(self, app, url, institutional_admin, institu
'month_last_active',
'month_last_login',
'timestamp'
],
],
[
'2024-08',
institution._id,
Expand Down Expand Up @@ -516,9 +515,8 @@ def test_get_report_format_json(self, app, url, institutional_admin, institution
assert resp.headers['Content-Type'] == 'application/json; charset=utf-8'

current_date = datetime.datetime.now().strftime('%Y-%m')
expected_filename = USER_INSTITUTION_REPORT_FILENAME.format(
expected_filename = REPORT_FILENAME_FORMAT.format(
date_created=current_date,
institution_id=institution._id,
format_type='json'
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'
Expand Down Expand Up @@ -602,9 +600,8 @@ def test_csv_tsv_ignores_pagination(self, app, url, institutional_admin, institu
assert resp.headers['Content-Type'] == content_type

current_date = datetime.datetime.now().strftime('%Y-%m')
expected_filename = USER_INSTITUTION_REPORT_FILENAME.format(
expected_filename = REPORT_FILENAME_FORMAT.format(
date_created=current_date,
institution_id=institution._id,
format_type=format_type
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'
Expand Down

0 comments on commit 7e82b04

Please sign in to comment.