Skip to content

Commit

Permalink
refactor: add model with_stats() queryset. Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Sep 8, 2024
1 parent fc71abf commit 5aa00e7
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 3 deletions.
12 changes: 11 additions & 1 deletion open_prices/locations/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.core.validators import ValidationError
from django.db import models
from django.db.models import signals
from django.db.models import Count, signals
from django.dispatch import receiver
from django.utils import timezone
from django_q.tasks import async_task
Expand All @@ -10,6 +10,14 @@
from open_prices.locations import constants as location_constants


class LocationQuerySet(models.QuerySet):
def has_prices(self):
return self.filter(price_count__gt=0)

def with_stats(self):
return self.annotate(price_count_annotated=Count("prices", distinct=True))


class Location(models.Model):
CREATE_FIELDS = ["osm_id", "osm_type"]
LAT_LON_DECIMAL_FIELDS = ["osm_lat", "osm_lon"]
Expand Down Expand Up @@ -40,6 +48,8 @@ class Location(models.Model):
created = models.DateTimeField(default=timezone.now)
updated = models.DateTimeField(auto_now=True)

objects = models.Manager.from_queryset(LocationQuerySet)()

class Meta:
# managed = False
db_table = "locations"
Expand Down
25 changes: 25 additions & 0 deletions open_prices/locations/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from open_prices.locations import constants as location_constants
from open_prices.locations.factories import LocationFactory
from open_prices.locations.models import Location
from open_prices.prices.factories import PriceFactory


class LocationModelSaveTest(TestCase):
Expand Down Expand Up @@ -57,3 +59,26 @@ def test_location_decimal_truncate_on_create(self):
)
self.assertEqual(location.osm_lat, Decimal("45.1805534"))
self.assertEqual(location.osm_lon, Decimal("5.7153387"))


class LocationQuerySetTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.location_without_price = LocationFactory()
cls.location_with_price = LocationFactory()
PriceFactory(
location_osm_id=cls.location_with_price.osm_id,
location_osm_type=cls.location_with_price.osm_type,
price=1.0,
)

def test_has_prices(self):
self.assertEqual(Location.objects.has_prices().count(), 1)

def test_with_stats(self):
location = Location.objects.with_stats().get(id=self.location_without_price.id)
self.assertEqual(location.price_count_annotated, 0)
self.assertEqual(location.price_count, 0)
location = Location.objects.with_stats().get(id=self.location_with_price.id)
self.assertEqual(location.price_count_annotated, 1)
self.assertEqual(location.price_count, 1)
12 changes: 11 additions & 1 deletion open_prices/products/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.db.models import signals
from django.db.models import Count, signals
from django.dispatch import receiver
from django.utils import timezone
from django_q.tasks import async_task

from open_prices.products import constants as product_constants


class ProductQuerySet(models.QuerySet):
def has_prices(self):
return self.filter(price_count__gt=0)

def with_stats(self):
return self.annotate(price_count_annotated=Count("prices", distinct=True))


class Product(models.Model):
ARRAY_FIELDS = ["categories_tags", "brands_tags", "labels_tags"]

Expand Down Expand Up @@ -39,6 +47,8 @@ class Product(models.Model):
created = models.DateTimeField(default=timezone.now)
updated = models.DateTimeField(auto_now=True)

objects = models.Manager.from_queryset(ProductQuerySet)()

class Meta:
# managed = False
db_table = "products"
Expand Down
23 changes: 22 additions & 1 deletion open_prices/products/tests.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from django.core.exceptions import ValidationError
from django.test import TransactionTestCase
from django.test import TestCase, TransactionTestCase

from open_prices.prices.factories import PriceFactory
from open_prices.products import constants as product_constants
from open_prices.products.factories import ProductFactory
from open_prices.products.models import Product

PRODUCT_OFF = {
"code": "3017620425035",
Expand Down Expand Up @@ -64,3 +66,22 @@ def test_product_validation(self):
ProductFactory(code="0123456789106", unique_scans_n=None)
# full OFF object
ProductFactory(**PRODUCT_OFF)


class ProductQuerySetTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.product_without_price = ProductFactory(code="0123456789100")
cls.product_with_price = ProductFactory(code="0123456789101")
PriceFactory(product_code=cls.product_with_price.code, price=1.0)

def test_has_prices(self):
self.assertEqual(Product.objects.has_prices().count(), 1)

def test_with_stats(self):
product = Product.objects.with_stats().get(id=self.product_without_price.id)
self.assertEqual(product.price_count_annotated, 0)
self.assertEqual(product.price_count, 0)
product = Product.objects.with_stats().get(id=self.product_with_price.id)
self.assertEqual(product.price_count_annotated, 1)
self.assertEqual(product.price_count, 1)
11 changes: 11 additions & 0 deletions open_prices/proofs/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from django.core.validators import ValidationError
from django.db import models
from django.db.models import Count
from django.utils import timezone

from open_prices.common import constants, utils
from open_prices.locations import constants as location_constants
from open_prices.proofs import constants as proof_constants


class ProofQuerySet(models.QuerySet):
def has_prices(self):
return self.filter(price_count__gt=0)

def with_stats(self):
return self.annotate(price_count_annotated=Count("prices", distinct=True))


class Proof(models.Model):
FILE_FIELDS = ["file_path", "mimetype"]
UPDATE_FIELDS = ["type", "currency", "date"]
Expand Down Expand Up @@ -49,6 +58,8 @@ class Proof(models.Model):
created = models.DateTimeField(default=timezone.now)
updated = models.DateTimeField(auto_now=True)

objects = models.Manager.from_queryset(ProofQuerySet)()

class Meta:
# managed = False
db_table = "proofs"
Expand Down
21 changes: 21 additions & 0 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
Location,
location_post_create_fetch_data_from_openstreetmap,
)
from open_prices.prices.factories import PriceFactory
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.models import Proof


class ProofModelSaveTest(TestCase):
Expand Down Expand Up @@ -51,3 +53,22 @@ def test_proof_location_validation(self):
location_osm_id=652825274,
location_osm_type=LOCATION_OSM_TYPE_NOT_OK,
)


class ProofQuerySetTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.proof_without_price = ProofFactory()
cls.proof_with_price = ProofFactory()
PriceFactory(proof_id=cls.proof_with_price.id, price=1.0)

def test_has_prices(self):
self.assertEqual(Proof.objects.has_prices().count(), 1)

def test_with_stats(self):
proof = Proof.objects.with_stats().get(id=self.proof_without_price.id)
self.assertEqual(proof.price_count_annotated, 0)
self.assertEqual(proof.price_count, 0)
proof = Proof.objects.with_stats().get(id=self.proof_with_price.id)
self.assertEqual(proof.price_count_annotated, 1)
self.assertEqual(proof.price_count, 1)
16 changes: 16 additions & 0 deletions open_prices/users/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.test import TestCase

from open_prices.prices.factories import PriceFactory
from open_prices.users.factories import UserFactory
from open_prices.users.models import User


class UserQuerySetTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.user_without_price = UserFactory()
cls.user_with_price = UserFactory()
PriceFactory(owner=cls.user_with_price.user_id, price=1.0)

def test_has_prices(self):
self.assertEqual(User.objects.has_prices().count(), 1)

0 comments on commit 5aa00e7

Please sign in to comment.