Skip to content

Commit

Permalink
fix: bulk transfer students (#96)
Browse files Browse the repository at this point in the history
* fix: teacher field on class serializer

* fix: help

* fix: bulk update action generator

* fix: test bulk_update

* fix: support list and queryset during updates

* Merge branch 'main' into bulk_transfer_students
  • Loading branch information
SKairinos authored Feb 29, 2024
1 parent a125976 commit 008eea4
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 22 deletions.
107 changes: 90 additions & 17 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

AnyModel = t.TypeVar("AnyModel", bound=Model)

# pylint: disable=no-member
# pylint: disable=no-member,too-many-arguments


class ModelViewSetClient(APIClient, t.Generic[AnyModel]):
Expand Down Expand Up @@ -193,7 +193,6 @@ def retrieve(

return response

# pylint: disable-next=too-many-arguments
def list(
self,
models: t.Iterable[AnyModel],
Expand Down Expand Up @@ -252,19 +251,19 @@ def _make_assertions(response_json: JsonDict):
# Partial Update (HTTP PATCH)
# --------------------------------------------------------------------------

def _assert_partial_update(
self, model: AnyModel, json_model: JsonDict, action: str
def _assert_update(
self,
model: AnyModel,
json_model: JsonDict,
action: str,
request_method: str,
partial: bool,
):
model.refresh_from_db()
self._test_case.assert_serialized_model_equals_json_model(
model,
json_model,
action,
request_method="patch",
contains_subset=True,
model, json_model, action, request_method, contains_subset=partial
)

# pylint: disable-next=too-many-arguments
def partial_update(
self,
model: AnyModel,
Expand Down Expand Up @@ -305,17 +304,20 @@ def partial_update(
if make_assertions:
self._assert_response_json(
response,
make_assertions=lambda json_model: self._assert_partial_update(
model, json_model, action="partial_update"
make_assertions=lambda json_model: self._assert_update(
model,
json_model,
action="partial_update",
request_method="patch",
partial=True,
),
)

return response

# pylint: disable-next=too-many-arguments
def bulk_partial_update(
self,
models: t.List[AnyModel],
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
data: t.List[DataDict],
status_code_assertion: APIClient.StatusCodeAssertion = (
status.HTTP_200_OK
Expand All @@ -338,6 +340,8 @@ def bulk_partial_update(
The HTTP response.
"""
# pylint: enable=line-too-long
if not isinstance(models, list):
models = list(models)

response: Response = self.patch(
self._test_case.reverse_action("bulk", kwargs=reverse_kwargs),
Expand All @@ -355,8 +359,78 @@ def _make_assertions(json_models: t.List[JsonDict]):
)
)
for model, json_model in zip(models, json_models):
self._assert_partial_update(
model, json_model, action="bulk"
self._assert_update(
model,
json_model,
action="bulk",
request_method="patch",
partial=True,
)

self._assert_response_json_bulk(response, _make_assertions, data)

return response

# --------------------------------------------------------------------------
# Update (HTTP PUT)
# --------------------------------------------------------------------------

def bulk_update(
self,
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
data: t.List[DataDict],
action: str,
status_code_assertion: APIClient.StatusCodeAssertion = (
status.HTTP_200_OK
),
make_assertions: bool = True,
reverse_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Bulk update many instances of a model.
Args:
models: The models to update.
data: The values for each field, for each model.
status_code_assertion: The expected status code.
make_assertions: A flag designating whether to make the default assertions.
reverse_kwargs: The kwargs for the reverse URL.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long
if not isinstance(models, list):
models = list(models)

assert len(models) == len(data)

response = self.put(
self._test_case.reverse_action(action, kwargs=reverse_kwargs),
data={
getattr(model, self._model_view_set_class.lookup_field): _data
for model, _data in zip(models, data)
},
status_code_assertion=status_code_assertion,
**kwargs,
)

if make_assertions:

def _make_assertions(json_models: t.List[JsonDict]):
models.sort(
key=lambda model: getattr(
model, self._model_view_set_class.lookup_field
)
)
for model, json_model in zip(models, json_models):
self._assert_update(
model,
json_model,
action,
request_method="put",
partial=False,
)

self._assert_response_json_bulk(response, _make_assertions, data)
Expand Down Expand Up @@ -516,7 +590,6 @@ def reverse_action(
# Assertion Helpers
# --------------------------------------------------------------------------

# pylint: disable-next=too-many-arguments
def assert_serialized_model_equals_json_model(
self,
model: AnyModel,
Expand Down
55 changes: 50 additions & 5 deletions codeforlife/views/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..permissions import Permission
from ..request import Request
from ..serializers import ModelListSerializer, ModelSerializer
from ..types import KwArgs
from .api import APIView

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand All @@ -39,11 +40,6 @@ class ModelViewSet(APIView, _ModelViewSet[AnyModel], t.Generic[AnyModel]):

serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]]

def get_bulk_queryset(self, lookup_values: t.Collection):
return self.get_queryset().filter(
**{f"{self.lookup_field}__in": lookup_values}
)

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Expand Down Expand Up @@ -142,6 +138,19 @@ def partial_update( # type: ignore[override]
# Bulk Actions
# --------------------------------------------------------------------------

def get_bulk_queryset(self, lookup_values: t.Collection):
"""Get the queryset for a bulk action.
Args:
lookup_values: The values of the model's lookup field.
Returns:
A queryset containing the matching models.
"""
return self.get_queryset().filter(
**{f"{self.lookup_field}__in": lookup_values}
)

def bulk_create(self, request: Request):
"""Bulk create many instances of a model.
Expand Down Expand Up @@ -249,3 +258,39 @@ def bulk(self, request: Request):
"PATCH": self.bulk_partial_update,
"DELETE": self.bulk_destroy,
}[t.cast(str, request.method)](request)

@staticmethod
def bulk_update_action(
name: str,
serializer_kwargs: t.Optional[KwArgs] = None,
response_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
"""Generate a bulk-update action.
Example usage:
```
class UserViewSet(ModelViewSet[User]):
rename = ModelViewSet.bulk_update_action(name="rename")
```
Args:
name: The of the action's function name.
"""

def bulk_update(self: ModelViewSet[AnyModel], request: Request):
queryset = self.get_bulk_queryset(request.json_dict.keys())
serializer = self.get_serializer(
**(serializer_kwargs or {}),
instance=queryset,
data=request.data,
many=True,
context=self.get_serializer_context(),
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(**(response_kwargs or {}), data=serializer.data)

bulk_update.__name__ = name

return action(**kwargs, detail=False, methods=["put"])(bulk_update)

0 comments on commit 008eea4

Please sign in to comment.