From 008eea42a0832cc72662caf663aa5fafe9aa9eb2 Mon Sep 17 00:00:00 2001 From: Stefan Kairinos Date: Thu, 29 Feb 2024 15:26:58 +0000 Subject: [PATCH] fix: bulk transfer students (#96) * fix: teacher field on class serializer * fix: help * fix: bulk update action generator * fix: test bulk_update * fix: support list and queryset during updates * Merge branch 'main' into bulk_transfer_students --- codeforlife/tests/model_view_set.py | 107 +++++++++++++++++++++++----- codeforlife/views/model.py | 55 ++++++++++++-- 2 files changed, 140 insertions(+), 22 deletions(-) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 297669e..20f2e90 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -25,7 +25,7 @@ AnyModel = t.TypeVar("AnyModel", bound=Model) -# pylint: disable=no-member +# pylint: disable=no-member,too-many-arguments class ModelViewSetClient(APIClient, t.Generic[AnyModel]): @@ -193,7 +193,6 @@ def retrieve( return response - # pylint: disable-next=too-many-arguments def list( self, models: t.Iterable[AnyModel], @@ -252,19 +251,19 @@ def _make_assertions(response_json: JsonDict): # Partial Update (HTTP PATCH) # -------------------------------------------------------------------------- - def _assert_partial_update( - self, model: AnyModel, json_model: JsonDict, action: str + def _assert_update( + self, + model: AnyModel, + json_model: JsonDict, + action: str, + request_method: str, + partial: bool, ): model.refresh_from_db() self._test_case.assert_serialized_model_equals_json_model( - model, - json_model, - action, - request_method="patch", - contains_subset=True, + model, json_model, action, request_method, contains_subset=partial ) - # pylint: disable-next=too-many-arguments def partial_update( self, model: AnyModel, @@ -305,17 +304,20 @@ def partial_update( if make_assertions: self._assert_response_json( response, - make_assertions=lambda json_model: self._assert_partial_update( - model, json_model, action="partial_update" + make_assertions=lambda json_model: self._assert_update( + model, + json_model, + action="partial_update", + request_method="patch", + partial=True, ), ) return response - # pylint: disable-next=too-many-arguments def bulk_partial_update( self, - models: t.List[AnyModel], + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], data: t.List[DataDict], status_code_assertion: APIClient.StatusCodeAssertion = ( status.HTTP_200_OK @@ -338,6 +340,8 @@ def bulk_partial_update( The HTTP response. """ # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) response: Response = self.patch( self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), @@ -355,8 +359,78 @@ def _make_assertions(json_models: t.List[JsonDict]): ) ) for model, json_model in zip(models, json_models): - self._assert_partial_update( - model, json_model, action="bulk" + self._assert_update( + model, + json_model, + action="bulk", + request_method="patch", + partial=True, + ) + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Update (HTTP PUT) + # -------------------------------------------------------------------------- + + def bulk_update( + self, + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], + data: t.List[DataDict], + action: str, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk update many instances of a model. + + Args: + models: The models to update. + data: The values for each field, for each model. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) + + assert len(models) == len(data) + + response = self.put( + self._test_case.reverse_action(action, kwargs=reverse_kwargs), + data={ + getattr(model, self._model_view_set_class.lookup_field): _data + for model, _data in zip(models, data) + }, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr( + model, self._model_view_set_class.lookup_field + ) + ) + for model, json_model in zip(models, json_models): + self._assert_update( + model, + json_model, + action, + request_method="put", + partial=False, ) self._assert_response_json_bulk(response, _make_assertions, data) @@ -516,7 +590,6 @@ def reverse_action( # Assertion Helpers # -------------------------------------------------------------------------- - # pylint: disable-next=too-many-arguments def assert_serialized_model_equals_json_model( self, model: AnyModel, diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index bfd8002..3233776 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -17,6 +17,7 @@ from ..permissions import Permission from ..request import Request from ..serializers import ModelListSerializer, ModelSerializer +from ..types import KwArgs from .api import APIView AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -39,11 +40,6 @@ class ModelViewSet(APIView, _ModelViewSet[AnyModel], t.Generic[AnyModel]): serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] - def get_bulk_queryset(self, lookup_values: t.Collection): - return self.get_queryset().filter( - **{f"{self.lookup_field}__in": lookup_values} - ) - @classmethod def get_model_class(cls) -> t.Type[AnyModel]: """Get the model view set's class. @@ -142,6 +138,19 @@ def partial_update( # type: ignore[override] # Bulk Actions # -------------------------------------------------------------------------- + def get_bulk_queryset(self, lookup_values: t.Collection): + """Get the queryset for a bulk action. + + Args: + lookup_values: The values of the model's lookup field. + + Returns: + A queryset containing the matching models. + """ + return self.get_queryset().filter( + **{f"{self.lookup_field}__in": lookup_values} + ) + def bulk_create(self, request: Request): """Bulk create many instances of a model. @@ -249,3 +258,39 @@ def bulk(self, request: Request): "PATCH": self.bulk_partial_update, "DELETE": self.bulk_destroy, }[t.cast(str, request.method)](request) + + @staticmethod + def bulk_update_action( + name: str, + serializer_kwargs: t.Optional[KwArgs] = None, + response_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + """Generate a bulk-update action. + + Example usage: + ``` + class UserViewSet(ModelViewSet[User]): + rename = ModelViewSet.bulk_update_action(name="rename") + ``` + + Args: + name: The of the action's function name. + """ + + def bulk_update(self: ModelViewSet[AnyModel], request: Request): + queryset = self.get_bulk_queryset(request.json_dict.keys()) + serializer = self.get_serializer( + **(serializer_kwargs or {}), + instance=queryset, + data=request.data, + many=True, + context=self.get_serializer_context(), + ) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(**(response_kwargs or {}), data=serializer.data) + + bulk_update.__name__ = name + + return action(**kwargs, detail=False, methods=["put"])(bulk_update)