Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bulk transfer students #96

Merged
merged 6 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 90 additions & 17 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

AnyModel = t.TypeVar("AnyModel", bound=Model)

# pylint: disable=no-member
# pylint: disable=no-member,too-many-arguments


class ModelViewSetClient(APIClient, t.Generic[AnyModel]):
Expand Down Expand Up @@ -193,7 +193,6 @@ def retrieve(

return response

# pylint: disable-next=too-many-arguments
def list(
self,
models: t.Iterable[AnyModel],
Expand Down Expand Up @@ -252,19 +251,19 @@ def _make_assertions(response_json: JsonDict):
# Partial Update (HTTP PATCH)
# --------------------------------------------------------------------------

def _assert_partial_update(
self, model: AnyModel, json_model: JsonDict, action: str
def _assert_update(
self,
model: AnyModel,
json_model: JsonDict,
action: str,
request_method: str,
partial: bool,
):
model.refresh_from_db()
self._test_case.assert_serialized_model_equals_json_model(
model,
json_model,
action,
request_method="patch",
contains_subset=True,
model, json_model, action, request_method, contains_subset=partial
)

# pylint: disable-next=too-many-arguments
def partial_update(
self,
model: AnyModel,
Expand Down Expand Up @@ -305,17 +304,20 @@ def partial_update(
if make_assertions:
self._assert_response_json(
response,
make_assertions=lambda json_model: self._assert_partial_update(
model, json_model, action="partial_update"
make_assertions=lambda json_model: self._assert_update(
model,
json_model,
action="partial_update",
request_method="patch",
partial=True,
),
)

return response

# pylint: disable-next=too-many-arguments
def bulk_partial_update(
self,
models: t.List[AnyModel],
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
data: t.List[DataDict],
status_code_assertion: APIClient.StatusCodeAssertion = (
status.HTTP_200_OK
Expand All @@ -338,6 +340,8 @@ def bulk_partial_update(
The HTTP response.
"""
# pylint: enable=line-too-long
if not isinstance(models, list):
models = list(models)

response: Response = self.patch(
self._test_case.reverse_action("bulk", kwargs=reverse_kwargs),
Expand All @@ -355,8 +359,78 @@ def _make_assertions(json_models: t.List[JsonDict]):
)
)
for model, json_model in zip(models, json_models):
self._assert_partial_update(
model, json_model, action="bulk"
self._assert_update(
model,
json_model,
action="bulk",
request_method="patch",
partial=True,
)

self._assert_response_json_bulk(response, _make_assertions, data)

return response

# --------------------------------------------------------------------------
# Update (HTTP PUT)
# --------------------------------------------------------------------------

def bulk_update(
self,
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
data: t.List[DataDict],
action: str,
status_code_assertion: APIClient.StatusCodeAssertion = (
status.HTTP_200_OK
),
make_assertions: bool = True,
reverse_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Bulk update many instances of a model.

Args:
models: The models to update.
data: The values for each field, for each model.
status_code_assertion: The expected status code.
make_assertions: A flag designating whether to make the default assertions.
reverse_kwargs: The kwargs for the reverse URL.

Returns:
The HTTP response.
"""
# pylint: enable=line-too-long
if not isinstance(models, list):
models = list(models)

assert len(models) == len(data)

response = self.put(
self._test_case.reverse_action(action, kwargs=reverse_kwargs),
data={
getattr(model, self._model_view_set_class.lookup_field): _data
for model, _data in zip(models, data)
},
status_code_assertion=status_code_assertion,
**kwargs,
)

if make_assertions:

def _make_assertions(json_models: t.List[JsonDict]):
models.sort(
key=lambda model: getattr(
model, self._model_view_set_class.lookup_field
)
)
for model, json_model in zip(models, json_models):
self._assert_update(
model,
json_model,
action,
request_method="put",
partial=False,
)

self._assert_response_json_bulk(response, _make_assertions, data)
Expand Down Expand Up @@ -516,7 +590,6 @@ def reverse_action(
# Assertion Helpers
# --------------------------------------------------------------------------

# pylint: disable-next=too-many-arguments
def assert_serialized_model_equals_json_model(
self,
model: AnyModel,
Expand Down
55 changes: 50 additions & 5 deletions codeforlife/views/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..permissions import Permission
from ..request import Request
from ..serializers import ModelListSerializer, ModelSerializer
from ..types import KwArgs
from .api import APIView

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand All @@ -39,11 +40,6 @@ class ModelViewSet(APIView, _ModelViewSet[AnyModel], t.Generic[AnyModel]):

serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]]

def get_bulk_queryset(self, lookup_values: t.Collection):
return self.get_queryset().filter(
**{f"{self.lookup_field}__in": lookup_values}
)

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Expand Down Expand Up @@ -142,6 +138,19 @@ def partial_update( # type: ignore[override]
# Bulk Actions
# --------------------------------------------------------------------------

def get_bulk_queryset(self, lookup_values: t.Collection):
"""Get the queryset for a bulk action.

Args:
lookup_values: The values of the model's lookup field.

Returns:
A queryset containing the matching models.
"""
return self.get_queryset().filter(
**{f"{self.lookup_field}__in": lookup_values}
)

def bulk_create(self, request: Request):
"""Bulk create many instances of a model.

Expand Down Expand Up @@ -249,3 +258,39 @@ def bulk(self, request: Request):
"PATCH": self.bulk_partial_update,
"DELETE": self.bulk_destroy,
}[t.cast(str, request.method)](request)

@staticmethod
def bulk_update_action(
name: str,
serializer_kwargs: t.Optional[KwArgs] = None,
response_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
"""Generate a bulk-update action.

Example usage:
```
class UserViewSet(ModelViewSet[User]):
rename = ModelViewSet.bulk_update_action(name="rename")
```

Args:
name: The of the action's function name.
"""

def bulk_update(self: ModelViewSet[AnyModel], request: Request):
queryset = self.get_bulk_queryset(request.json_dict.keys())
serializer = self.get_serializer(
**(serializer_kwargs or {}),
instance=queryset,
data=request.data,
many=True,
context=self.get_serializer_context(),
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(**(response_kwargs or {}), data=serializer.data)

bulk_update.__name__ = name

return action(**kwargs, detail=False, methods=["put"])(bulk_update)
Loading