diff --git a/lacommunaute/forum_conversation/shortcuts.py b/lacommunaute/forum_conversation/shortcuts.py
index 02699f602..0b799196e 100644
--- a/lacommunaute/forum_conversation/shortcuts.py
+++ b/lacommunaute/forum_conversation/shortcuts.py
@@ -1,3 +1,4 @@
+from django.contrib.contenttypes.models import ContentType
from django.db.models import Count, Exists, OuterRef, Prefetch, Q, QuerySet
from lacommunaute.forum.enums import Kind as Forum_Kind
@@ -21,8 +22,13 @@ def get_posts_of_a_topic_except_first_one(topic: Topic, user: User) -> QuerySet[
if user.is_authenticated:
qs = qs.annotate(
upvotes_count=Count("upvotes"),
- # using user.id instead of user, to manage anonymous user journey
- has_upvoted=Exists(UpVote.objects.filter(post=OuterRef("pk"), voter=user)),
+ has_upvoted=Exists(
+ UpVote.objects.filter(
+ object_id=OuterRef("pk"),
+ voter=user,
+ content_type_id=ContentType.objects.get_for_model(qs.model).id,
+ )
+ ),
)
else:
qs = qs.annotate(
diff --git a/lacommunaute/forum_upvote/tests/tests_views.py b/lacommunaute/forum_upvote/tests/tests_views.py
index 270c43721..c697ee597 100644
--- a/lacommunaute/forum_upvote/tests/tests_views.py
+++ b/lacommunaute/forum_upvote/tests/tests_views.py
@@ -1,3 +1,4 @@
+from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from django.urls import reverse
from faker import Faker
@@ -44,12 +45,22 @@ def test_upvote_downvote_post(self):
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 200)
self.assertContains(response, '1')
- self.assertEqual(1, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count())
+ self.assertEqual(
+ 1,
+ UpVote.objects.filter(
+ voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post)
+ ).count(),
+ )
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 200)
self.assertContains(response, '0')
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post)
+ ).count(),
+ )
def test_object_not_found(self):
self.client.force_login(self.user)
@@ -58,13 +69,27 @@ def test_object_not_found(self):
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 404)
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id,
+ object_id=self.topic.last_post.id,
+ content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id,
+ ).count(),
+ )
form_data = {"pk": self.topic.pk, "post_pk": 9999}
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 404)
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id,
+ object_id=self.topic.last_post.id,
+ content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id,
+ ).count(),
+ )
def test_topic_is_marked_as_read_when_upvoting(self):
self.assertFalse(ForumReadTrack.objects.count())
diff --git a/lacommunaute/forum_upvote/views.py b/lacommunaute/forum_upvote/views.py
index 61fb70c37..9cfb6f5b9 100644
--- a/lacommunaute/forum_upvote/views.py
+++ b/lacommunaute/forum_upvote/views.py
@@ -1,5 +1,6 @@
import logging
+from django.contrib.contenttypes.models import ContentType
from django.shortcuts import get_object_or_404, render
from django.views import View
from machina.core.loading import get_class
@@ -33,13 +34,21 @@ def get_object(self):
def post(self, request, **kwargs):
post = self.get_object()
- upvote = UpVote.objects.filter(voter_id=request.user.id, post_id=post.id)
+ upvote = UpVote.objects.filter(
+ voter_id=request.user.id,
+ object_id=post.id,
+ content_type=ContentType.objects.get_for_model(post),
+ )
if upvote.exists():
upvote.delete()
post.has_upvoted = False
else:
- UpVote(voter_id=request.user.id, post_id=post.id).save()
+ UpVote(
+ voter_id=request.user.id,
+ object_id=post.id,
+ content_type=ContentType.objects.get_for_model(post),
+ ).save()
post.has_upvoted = True
post.upvotes_count = post.upvotes.count()