diff --git a/lacommunaute/forum/models.py b/lacommunaute/forum/models.py index 2bcb50f9..3663c4bd 100644 --- a/lacommunaute/forum/models.py +++ b/lacommunaute/forum/models.py @@ -14,6 +14,11 @@ from lacommunaute.utils.validators import validate_image_size +class ForumQuerySet(models.QuerySet): + def get_main_forum(self): + return self.filter(lft=1, level=0).first() + + class Forum(AbstractForum): short_description = models.CharField( max_length=400, blank=True, null=True, verbose_name="Description courte (SEO)" @@ -29,7 +34,7 @@ class Forum(AbstractForum): tags = TaggableManager() partner = models.ForeignKey(Partner, on_delete=models.CASCADE, null=True, blank=True) - objects = models.Manager() + objects = ForumQuerySet.as_manager() def get_absolute_url(self): return reverse( @@ -58,7 +63,7 @@ def is_in_documentation_area(self): @cached_property def is_toplevel_discussion_area(self): - return self == Forum.objects.filter(lft=1, level=0).first() + return self == Forum.objects.get_main_forum() def get_session_rating(self, session_key): return getattr(ForumRating.objects.filter(forum=self, session_id=session_key).first(), "rating", None) diff --git a/lacommunaute/forum/tests/tests_model.py b/lacommunaute/forum/tests/tests_model.py index bf0b5e4b..dcd0a13a 100644 --- a/lacommunaute/forum/tests/tests_model.py +++ b/lacommunaute/forum/tests/tests_model.py @@ -2,6 +2,7 @@ from django.test import TestCase from lacommunaute.forum.factories import CategoryForumFactory, ForumFactory, ForumRatingFactory +from lacommunaute.forum.models import Forum from lacommunaute.forum_conversation.factories import TopicFactory from lacommunaute.users.factories import UserFactory @@ -87,3 +88,15 @@ def test_get_average_rating(self): ForumRatingFactory(forum=forum, rating=5) self.assertEqual(forum.get_average_rating(), 3) + + +class TestForumQueryset: + def test_get_main_forum_wo_forum(self, db): + assert Forum.objects.get_main_forum() is None + + def test_get_main_forum_w_several_forums(self, db): + # level 0 + forums = ForumFactory.create_batch(2) + # level 1 + ForumFactory(parent=forums[0]) + assert Forum.objects.get_main_forum() == forums[0] diff --git a/lacommunaute/forum_conversation/views.py b/lacommunaute/forum_conversation/views.py index 3133d433..630fcd9a 100644 --- a/lacommunaute/forum_conversation/views.py +++ b/lacommunaute/forum_conversation/views.py @@ -140,7 +140,7 @@ def get_context_data(self, **kwargs): ) context["loadmoretopic_suffix"] = "topics" - context["forum"] = Forum.objects.filter(lft=1, level=0).first() + context["forum"] = Forum.objects.get_main_forum() context = context | self.get_topic_filter_context() return context diff --git a/lacommunaute/pages/views.py b/lacommunaute/pages/views.py index d1f2ab46..2ecc6f31 100644 --- a/lacommunaute/pages/views.py +++ b/lacommunaute/pages/views.py @@ -28,7 +28,7 @@ def get_context_data(self, **kwargs: Any) -> dict[str, Any]: context = super().get_context_data(**kwargs) context["topics_public"] = Topic.objects.filter(approved=True).order_by("-created")[:4] context["forums_category"] = Forum.objects.filter(parent__type=1).order_by("-updated")[:4] - context["forum"] = Forum.objects.filter(lft=1, level=0).first() + context["forum"] = Forum.objects.get_main_forum() context["upcoming_events"] = Event.objects.filter(date__gte=timezone.now()).order_by("date")[:4] return context diff --git a/lacommunaute/search/views.py b/lacommunaute/search/views.py index 6ff10ac7..4787f25e 100644 --- a/lacommunaute/search/views.py +++ b/lacommunaute/search/views.py @@ -57,5 +57,5 @@ def get_queryset(self): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - context["forum"] = Forum.objects.filter(lft=1, level=0).first() + context["forum"] = Forum.objects.get_main_forum() return context