diff --git a/ninja/main.py b/ninja/main.py index 67216ac5..9c802abd 100644 --- a/ninja/main.py +++ b/ninja/main.py @@ -457,16 +457,18 @@ def create_response( response.content = content else: response = HttpResponse( - content, status=status, content_type=self.get_content_type() + content, status=status, content_type=self.get_content_type(request) ) return response def create_temporal_response(self, request: HttpRequest) -> HttpResponse: - return HttpResponse("", content_type=self.get_content_type()) + return HttpResponse("", content_type=self.get_content_type(request)) - def get_content_type(self) -> str: - return f"{self.renderer.media_type}; charset={self.renderer.charset}" + def get_content_type(self, request: HttpRequest) -> str: + return ( + f"{self.renderer.get_media_type(request)}; charset={self.renderer.charset}" + ) def get_openapi_schema( self, diff --git a/ninja/renderers.py b/ninja/renderers.py index 16e53f37..c3286778 100644 --- a/ninja/renderers.py +++ b/ninja/renderers.py @@ -1,7 +1,9 @@ +import itertools import json -from typing import Any, Mapping, Optional, Type +from typing import Any, List, Mapping, Type from django.http import HttpRequest +from django.http.request import parse_accept_header from ninja.responses import NinjaJSONEncoder @@ -9,13 +11,32 @@ class BaseRenderer: - media_type: Optional[str] = None + media_type: str charset: str = "utf-8" + def get_media_type(self, request: HttpRequest) -> str: + return self.media_type + def render(self, request: HttpRequest, data: Any, *, response_status: int) -> Any: raise NotImplementedError("Please implement .render() method") +class BaseDynamicRenderer(BaseRenderer): + media_types: List[str] + + def get_media_type(self, request: HttpRequest) -> str: + accepted_media_types = parse_accept_header(request.headers.get("accept", "*/*")) + media_type_gen = ( + media_type + for media_type, accepted_type in itertools.product( + self.media_types, accepted_media_types + ) + if accepted_type.match(media_type) + ) + + return next(media_type_gen, self.media_type) + + class JSONRenderer(BaseRenderer): media_type = "application/json" encoder_class: Type[json.JSONEncoder] = NinjaJSONEncoder diff --git a/tests/test_renderer.py b/tests/test_renderer.py index a51082d1..28dab572 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -1,14 +1,38 @@ +import json from io import StringIO +from unittest.mock import Mock import pytest from django.utils.encoding import force_str from django.utils.xmlutils import SimplerXMLGenerator from ninja import NinjaAPI -from ninja.renderers import BaseRenderer +from ninja.renderers import BaseDynamicRenderer, BaseRenderer +from ninja.responses import NinjaJSONEncoder from ninja.testing import TestClient +def _to_xml(xml, data): + if isinstance(data, (list, tuple)): + for item in data: + xml.startElement("item", {}) + _to_xml(xml, item) + xml.endElement("item") + + elif isinstance(data, dict): + for key, value in data.items(): + xml.startElement(key, {}) + _to_xml(xml, value) + xml.endElement(key) + + elif data is None: + # Don't output any value + pass + + else: + xml.characters(force_str(data)) + + class XMLRenderer(BaseRenderer): media_type = "text/xml" @@ -17,41 +41,55 @@ def render(self, request, data, *, response_status): xml = SimplerXMLGenerator(stream, "utf-8") xml.startDocument() xml.startElement("data", {}) - self._to_xml(xml, data) + _to_xml(xml, data) xml.endElement("data") xml.endDocument() return stream.getvalue() - def _to_xml(self, xml, data): - if isinstance(data, (list, tuple)): - for item in data: - xml.startElement("item", {}) - self._to_xml(xml, item) - xml.endElement("item") - elif isinstance(data, dict): - for key, value in data.items(): - xml.startElement(key, {}) - self._to_xml(xml, value) - xml.endElement(key) +class CSVRenderer(BaseRenderer): + media_type = "text/csv" - elif data is None: - # Don't output any value - pass + def render(self, request, data, *, response_status): + content = [",".join(data[0].keys())] + for item in data: + content.append(",".join(item.values())) + return "\n".join(content) - else: - xml.characters(force_str(data)) - -class CSVRenderer(BaseRenderer): - media_type = "text/csv" +class DynamicRenderer(BaseDynamicRenderer): + media_type = "application/json" + media_types = ["application/json", "text/csv", "text/xml"] def render(self, request, data, *, response_status): + accept = request.headers.get("accept", "application/json") + + if accept.startswith("text/xml"): + return self.render_xml(data) + elif accept.startswith("text/csv"): + return self.render_csv(data) + else: + return self.render_json(data) + + def render_csv(self, data): content = [",".join(data[0].keys())] for item in data: content.append(",".join(item.values())) return "\n".join(content) + def render_xml(self, data): + stream = StringIO() + xml = SimplerXMLGenerator(stream, "utf-8") + xml.startDocument() + xml.startElement("data", {}) + _to_xml(xml, data) + xml.endElement("data") + xml.endDocument() + return stream.getvalue() + + def render_json(self, data): + return json.dumps(data, cls=NinjaJSONEncoder) + def operation(request): return [ @@ -62,10 +100,12 @@ def operation(request): api_xml = NinjaAPI(renderer=XMLRenderer()) api_csv = NinjaAPI(renderer=CSVRenderer()) +api_dynamic = NinjaAPI(renderer=DynamicRenderer()) api_xml.get("/test")(operation) api_csv.get("/test")(operation) +api_dynamic.get("/test")(operation) @pytest.mark.parametrize( @@ -94,10 +134,59 @@ def test_response_class(api, content_type, expected_content): assert response.content.decode() == expected_content -def test_implment_render(): - class FooRenderer(BaseRenderer): +@pytest.mark.parametrize( + "accept,expected_content", + [ + ( + "text/xml; charset=utf-8", + '\n' + "JonathanDoe" + "SarahCalvin" + "", + ), + ( + "text/csv; charset=utf-8", + "name,lastname\nJonathan,Doe\nSarah,Calvin", + ), + ( + "application/json; charset=utf-8", + '[{"name": "Jonathan", "lastname": "Doe"}, {"name": "Sarah", "lastname": "Calvin"}]', + ), + ], +) +def test_dynamic_response_class(accept, expected_content): + client = TestClient(api_dynamic) + response = client.get("/test", headers={"Accept": accept}) + assert response.status_code == 200 + assert response["Content-Type"] == accept + assert response.content.decode() == expected_content + + +@pytest.mark.parametrize("Base", [BaseRenderer, BaseDynamicRenderer]) +def test_implement_render(Base): + class FooRenderer(Base): pass renderer = FooRenderer() with pytest.raises(NotImplementedError): renderer.render(None, None, response_status=200) + + +@pytest.mark.parametrize( + "accept,expected_media_type", + [ + ("text/xml", "text/xml"), + ("text/csv", "text/csv"), + ("*/*", "text/xml"), + ("blahblahblah", "text/xml"), + ], +) +def test_get_media_type(accept, expected_media_type): + class FooRenderer(BaseDynamicRenderer): + media_type = "text/xml" + media_types = ["text/xml", "text/csv"] + + request = Mock() + request.headers = {"accept": accept} + + assert FooRenderer().get_media_type(request) == expected_media_type