Skip to content

Commit

Permalink
feat(prices): new querysets to calculate min, max & average (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Sep 15, 2024
1 parent fe569d7 commit ab9ed15
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
35 changes: 34 additions & 1 deletion open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from django.core.validators import MinValueValidator, ValidationError
from django.db import models
from django.db.models import F, signals
from django.db.models import Avg, Count, F, Max, Min, signals
from django.db.models.functions import Cast
from django.dispatch import receiver
from django.utils import timezone
from openfoodfacts.taxonomy import get_taxonomy
Expand All @@ -17,6 +18,36 @@
from open_prices.users.models import User


class PriceQuerySet(models.QuerySet):
def exclude_discounted(self):
return self.filter(price_is_discounted=False)

def calculate_min(self):
return self.aggregate(Min("price"))["price__min"]

def calculate_max(self):
return self.aggregate(Max("price"))["price__max"]

def calculate_avg(self):
return self.aggregate(
price__avg=Cast(
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
)
)["price__avg"]

def calculate_stats(self):
return self.aggregate(
price__count=Count("pk"),
price__min=Min("price"),
price__max=Max("price"),
price__avg=Cast(
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
),
)


class Price(models.Model):
UPDATE_FIELDS = [
"price",
Expand Down Expand Up @@ -108,6 +139,8 @@ class Price(models.Model):
created = models.DateTimeField(default=timezone.now)
updated = models.DateTimeField(auto_now=True)

objects = models.Manager.from_queryset(PriceQuerySet)()

class Meta:
# managed = False
db_table = "prices"
Expand Down
47 changes: 47 additions & 0 deletions open_prices/prices/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from decimal import Decimal

from django.core.exceptions import ValidationError
from django.test import TestCase

Expand All @@ -6,6 +8,7 @@
from open_prices.locations.models import Location
from open_prices.prices import constants as price_constants
from open_prices.prices.factories import PriceFactory
from open_prices.prices.models import Price
from open_prices.products.factories import ProductFactory
from open_prices.products.models import Product
from open_prices.proofs import constants as proof_constants
Expand All @@ -15,6 +18,50 @@
from open_prices.users.models import User


class PriceQuerySetTest(TestCase):
@classmethod
def setUpTestData(cls):
PriceFactory(price=5, price_is_discounted=True, price_without_discount=10)
PriceFactory(price=8)
PriceFactory(price=10)

def test_exclude_discounted(self):
self.assertEqual(Price.objects.count(), 3)
self.assertEqual(Price.objects.exclude_discounted().count(), 2)

def test_min(self):
self.assertEqual(Price.objects.calculate_min(), 5)
self.assertEqual(Price.objects.exclude_discounted().calculate_min(), 8)

def test_max(self):
self.assertEqual(Price.objects.calculate_max(), 10)
self.assertEqual(Price.objects.exclude_discounted().calculate_max(), 10)

def test_avg(self):
self.assertEqual(Price.objects.calculate_avg(), Decimal("7.67"))
self.assertEqual(Price.objects.exclude_discounted().calculate_avg(), 9)

def test_calculate_stats(self):
self.assertEqual(
Price.objects.calculate_stats(),
{
"price__count": 3,
"price__min": 5,
"price__max": 10,
"price__avg": Decimal("7.67"),
},
)
self.assertEqual(
Price.objects.exclude_discounted().calculate_stats(),
{
"price__count": 2,
"price__min": 8,
"price__max": 10,
"price__avg": 9,
},
)


class PriceModelSaveTest(TestCase):
@classmethod
def setUpTestData(cls):
Expand Down

0 comments on commit ab9ed15

Please sign in to comment.