diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index e077fea..8c6cc7b 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -10,6 +10,7 @@ from unittest.case import _AssertRaisesContext from django.db.models import Model +from django.forms.models import model_to_dict from rest_framework.serializers import BaseSerializer, ValidationError from ..serializers import ModelListSerializer, ModelSerializer @@ -124,11 +125,7 @@ def _assert_data_is_subset_of_model(self, data: DataDict, model): elif isinstance(value, Model): data[field] = getattr(value, "id") - model_dict = { - field: value - for field, value in model.__dict__.items() - if not field.startswith("_") - } + model_dict = model_to_dict(model) self.assertDictEqual(model_dict | data, model_dict) def _assert_many( @@ -340,7 +337,12 @@ def assert_update_many( ) def assert_to_representation( - self, instance: AnyModel, new_data: DataDict, *args, **kwargs + self, + instance: AnyModel, + new_data: DataDict, + *args, + non_model_fields: t.Optional[NonModelFields] = None, + **kwargs, ): """Assert: 1. the new data fields not contained in the model are equal. @@ -349,6 +351,7 @@ def assert_to_representation( Args: instance: The model instance to represent. new_data: The field values not contained in the model. + non_model_fields: Data fields that are not in the model. """ serializer = self._init_model_serializer(*args, **kwargs) data = serializer.to_representation(instance) @@ -357,12 +360,14 @@ def assert_new_data_is_subset_of_data(new_data: DataDict, data): assert isinstance(data, dict) for field, new_value in new_data.items(): + value = data[field] if isinstance(new_value, dict): - assert_new_data_is_subset_of_data(new_value, data[field]) + assert_new_data_is_subset_of_data(new_value, value) else: - assert new_value == data.pop(field) + assert new_value == value assert_new_data_is_subset_of_data(new_data, data) + data = self._get_data(data, None, non_model_fields) self._assert_data_is_subset_of_model(data, instance) diff --git a/codeforlife/user/serializers/user.py b/codeforlife/user/serializers/user.py index 15eefdd..9044cdc 100644 --- a/codeforlife/user/serializers/user.py +++ b/codeforlife/user/serializers/user.py @@ -64,7 +64,7 @@ class Meta(BaseUserSerializer.Meta): def to_representation(self, instance): try: student = ( - StudentSerializer(instance.new_student).data + dict(StudentSerializer(instance.new_student).data) if instance.new_student and instance.new_student.class_field else None ) @@ -83,7 +83,7 @@ def to_representation(self, instance): try: teacher = ( - TeacherSerializer[Teacher](instance.new_teacher).data + dict(TeacherSerializer[Teacher](instance.new_teacher).data) if instance.new_teacher else None ) diff --git a/codeforlife/user/serializers/user_test.py b/codeforlife/user/serializers/user_test.py index a2985ac..fe41d48 100644 --- a/codeforlife/user/serializers/user_test.py +++ b/codeforlife/user/serializers/user_test.py @@ -30,6 +30,8 @@ def test_to_representation__teacher(self): }, "student": None, }, + # TODO: remove in new schema. + non_model_fields=["requesting_to_join_class", "teacher", "student"], ) def test_to_representation__student(self): @@ -48,6 +50,8 @@ def test_to_representation__student(self): "school": user.student.class_field.teacher.school.id, }, }, + # TODO: remove in new schema. + non_model_fields=["requesting_to_join_class", "teacher", "student"], ) def test_to_representation__indy(self):