From 7a49fa873f0e6849b340b0794851820c1b4f5058 Mon Sep 17 00:00:00 2001 From: vincent porte Date: Thu, 20 Jul 2023 18:12:55 +0200 Subject: [PATCH] =?UTF-8?q?feat(forum=5Fupvote):=C2=A0update=20filters=20o?= =?UTF-8?q?n=20UpVote=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lacommunaute/forum_conversation/shortcuts.py | 10 ++++-- .../forum_upvote/tests/tests_views.py | 33 ++++++++++++++++--- lacommunaute/forum_upvote/views.py | 13 ++++++-- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/lacommunaute/forum_conversation/shortcuts.py b/lacommunaute/forum_conversation/shortcuts.py index 02699f602..0b799196e 100644 --- a/lacommunaute/forum_conversation/shortcuts.py +++ b/lacommunaute/forum_conversation/shortcuts.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.db.models import Count, Exists, OuterRef, Prefetch, Q, QuerySet from lacommunaute.forum.enums import Kind as Forum_Kind @@ -21,8 +22,13 @@ def get_posts_of_a_topic_except_first_one(topic: Topic, user: User) -> QuerySet[ if user.is_authenticated: qs = qs.annotate( upvotes_count=Count("upvotes"), - # using user.id instead of user, to manage anonymous user journey - has_upvoted=Exists(UpVote.objects.filter(post=OuterRef("pk"), voter=user)), + has_upvoted=Exists( + UpVote.objects.filter( + object_id=OuterRef("pk"), + voter=user, + content_type_id=ContentType.objects.get_for_model(qs.model).id, + ) + ), ) else: qs = qs.annotate( diff --git a/lacommunaute/forum_upvote/tests/tests_views.py b/lacommunaute/forum_upvote/tests/tests_views.py index 270c43721..c697ee597 100644 --- a/lacommunaute/forum_upvote/tests/tests_views.py +++ b/lacommunaute/forum_upvote/tests/tests_views.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.test import TestCase from django.urls import reverse from faker import Faker @@ -44,12 +45,22 @@ def test_upvote_downvote_post(self): response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 200) self.assertContains(response, '1') - self.assertEqual(1, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count()) + self.assertEqual( + 1, + UpVote.objects.filter( + voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post) + ).count(), + ) response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 200) self.assertContains(response, '0') - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post) + ).count(), + ) def test_object_not_found(self): self.client.force_login(self.user) @@ -58,13 +69,27 @@ def test_object_not_found(self): response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 404) - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, + object_id=self.topic.last_post.id, + content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id, + ).count(), + ) form_data = {"pk": self.topic.pk, "post_pk": 9999} response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 404) - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, + object_id=self.topic.last_post.id, + content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id, + ).count(), + ) def test_topic_is_marked_as_read_when_upvoting(self): self.assertFalse(ForumReadTrack.objects.count()) diff --git a/lacommunaute/forum_upvote/views.py b/lacommunaute/forum_upvote/views.py index 61fb70c37..9cfb6f5b9 100644 --- a/lacommunaute/forum_upvote/views.py +++ b/lacommunaute/forum_upvote/views.py @@ -1,5 +1,6 @@ import logging +from django.contrib.contenttypes.models import ContentType from django.shortcuts import get_object_or_404, render from django.views import View from machina.core.loading import get_class @@ -33,13 +34,21 @@ def get_object(self): def post(self, request, **kwargs): post = self.get_object() - upvote = UpVote.objects.filter(voter_id=request.user.id, post_id=post.id) + upvote = UpVote.objects.filter( + voter_id=request.user.id, + object_id=post.id, + content_type=ContentType.objects.get_for_model(post), + ) if upvote.exists(): upvote.delete() post.has_upvoted = False else: - UpVote(voter_id=request.user.id, post_id=post.id).save() + UpVote( + voter_id=request.user.id, + object_id=post.id, + content_type=ContentType.objects.get_for_model(post), + ).save() post.has_upvoted = True post.upvotes_count = post.upvotes.count()