diff --git a/open_prices/locations/models.py b/open_prices/locations/models.py index 6b933404..e37e18f6 100644 --- a/open_prices/locations/models.py +++ b/open_prices/locations/models.py @@ -1,6 +1,6 @@ from django.core.validators import ValidationError from django.db import models -from django.db.models import signals +from django.db.models import Count, signals from django.dispatch import receiver from django.utils import timezone from django_q.tasks import async_task @@ -10,6 +10,14 @@ from open_prices.locations import constants as location_constants +class LocationQuerySet(models.QuerySet): + def has_prices(self): + return self.filter(price_count__gt=0) + + def with_stats(self): + return self.annotate(price_count_annotated=Count("prices", distinct=True)) + + class Location(models.Model): CREATE_FIELDS = ["osm_id", "osm_type"] LAT_LON_DECIMAL_FIELDS = ["osm_lat", "osm_lon"] @@ -40,6 +48,8 @@ class Location(models.Model): created = models.DateTimeField(default=timezone.now) updated = models.DateTimeField(auto_now=True) + objects = models.Manager.from_queryset(LocationQuerySet)() + class Meta: # managed = False db_table = "locations" diff --git a/open_prices/locations/tests.py b/open_prices/locations/tests.py index 9d987038..6d74d778 100644 --- a/open_prices/locations/tests.py +++ b/open_prices/locations/tests.py @@ -5,6 +5,8 @@ from open_prices.locations import constants as location_constants from open_prices.locations.factories import LocationFactory +from open_prices.locations.models import Location +from open_prices.prices.factories import PriceFactory class LocationModelSaveTest(TestCase): @@ -57,3 +59,26 @@ def test_location_decimal_truncate_on_create(self): ) self.assertEqual(location.osm_lat, Decimal("45.1805534")) self.assertEqual(location.osm_lon, Decimal("5.7153387")) + + +class LocationQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.location_without_price = LocationFactory() + cls.location_with_price = LocationFactory() + PriceFactory( + location_osm_id=cls.location_with_price.osm_id, + location_osm_type=cls.location_with_price.osm_type, + price=1.0, + ) + + def test_has_prices(self): + self.assertEqual(Location.objects.has_prices().count(), 1) + + def test_with_stats(self): + location = Location.objects.with_stats().get(id=self.location_without_price.id) + self.assertEqual(location.price_count_annotated, 0) + self.assertEqual(location.price_count, 0) + location = Location.objects.with_stats().get(id=self.location_with_price.id) + self.assertEqual(location.price_count_annotated, 1) + self.assertEqual(location.price_count, 1) diff --git a/open_prices/products/models.py b/open_prices/products/models.py index 8825fea2..b5388691 100644 --- a/open_prices/products/models.py +++ b/open_prices/products/models.py @@ -1,6 +1,6 @@ from django.contrib.postgres.fields import ArrayField from django.db import models -from django.db.models import signals +from django.db.models import Count, signals from django.dispatch import receiver from django.utils import timezone from django_q.tasks import async_task @@ -8,6 +8,14 @@ from open_prices.products import constants as product_constants +class ProductQuerySet(models.QuerySet): + def has_prices(self): + return self.filter(price_count__gt=0) + + def with_stats(self): + return self.annotate(price_count_annotated=Count("prices", distinct=True)) + + class Product(models.Model): ARRAY_FIELDS = ["categories_tags", "brands_tags", "labels_tags"] @@ -39,6 +47,8 @@ class Product(models.Model): created = models.DateTimeField(default=timezone.now) updated = models.DateTimeField(auto_now=True) + objects = models.Manager.from_queryset(ProductQuerySet)() + class Meta: # managed = False db_table = "products" diff --git a/open_prices/products/tests.py b/open_prices/products/tests.py index dc79c33d..747428f9 100644 --- a/open_prices/products/tests.py +++ b/open_prices/products/tests.py @@ -1,8 +1,10 @@ from django.core.exceptions import ValidationError -from django.test import TransactionTestCase +from django.test import TestCase, TransactionTestCase +from open_prices.prices.factories import PriceFactory from open_prices.products import constants as product_constants from open_prices.products.factories import ProductFactory +from open_prices.products.models import Product PRODUCT_OFF = { "code": "3017620425035", @@ -64,3 +66,22 @@ def test_product_validation(self): ProductFactory(code="0123456789106", unique_scans_n=None) # full OFF object ProductFactory(**PRODUCT_OFF) + + +class ProductQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.product_without_price = ProductFactory(code="0123456789100") + cls.product_with_price = ProductFactory(code="0123456789101") + PriceFactory(product_code=cls.product_with_price.code, price=1.0) + + def test_has_prices(self): + self.assertEqual(Product.objects.has_prices().count(), 1) + + def test_with_stats(self): + product = Product.objects.with_stats().get(id=self.product_without_price.id) + self.assertEqual(product.price_count_annotated, 0) + self.assertEqual(product.price_count, 0) + product = Product.objects.with_stats().get(id=self.product_with_price.id) + self.assertEqual(product.price_count_annotated, 1) + self.assertEqual(product.price_count, 1) diff --git a/open_prices/proofs/models.py b/open_prices/proofs/models.py index f0da72dc..7d73f4b3 100644 --- a/open_prices/proofs/models.py +++ b/open_prices/proofs/models.py @@ -1,5 +1,6 @@ from django.core.validators import ValidationError from django.db import models +from django.db.models import Count from django.utils import timezone from open_prices.common import constants, utils @@ -7,6 +8,14 @@ from open_prices.proofs import constants as proof_constants +class ProofQuerySet(models.QuerySet): + def has_prices(self): + return self.filter(price_count__gt=0) + + def with_stats(self): + return self.annotate(price_count_annotated=Count("prices", distinct=True)) + + class Proof(models.Model): FILE_FIELDS = ["file_path", "mimetype"] UPDATE_FIELDS = ["type", "currency", "date"] @@ -49,6 +58,8 @@ class Proof(models.Model): created = models.DateTimeField(default=timezone.now) updated = models.DateTimeField(auto_now=True) + objects = models.Manager.from_queryset(ProofQuerySet)() + class Meta: # managed = False db_table = "proofs" diff --git a/open_prices/proofs/tests.py b/open_prices/proofs/tests.py index fb38f39e..f5309801 100644 --- a/open_prices/proofs/tests.py +++ b/open_prices/proofs/tests.py @@ -7,7 +7,9 @@ Location, location_post_create_fetch_data_from_openstreetmap, ) +from open_prices.prices.factories import PriceFactory from open_prices.proofs.factories import ProofFactory +from open_prices.proofs.models import Proof class ProofModelSaveTest(TestCase): @@ -51,3 +53,22 @@ def test_proof_location_validation(self): location_osm_id=652825274, location_osm_type=LOCATION_OSM_TYPE_NOT_OK, ) + + +class ProofQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.proof_without_price = ProofFactory() + cls.proof_with_price = ProofFactory() + PriceFactory(proof_id=cls.proof_with_price.id, price=1.0) + + def test_has_prices(self): + self.assertEqual(Proof.objects.has_prices().count(), 1) + + def test_with_stats(self): + proof = Proof.objects.with_stats().get(id=self.proof_without_price.id) + self.assertEqual(proof.price_count_annotated, 0) + self.assertEqual(proof.price_count, 0) + proof = Proof.objects.with_stats().get(id=self.proof_with_price.id) + self.assertEqual(proof.price_count_annotated, 1) + self.assertEqual(proof.price_count, 1) diff --git a/open_prices/users/tests.py b/open_prices/users/tests.py new file mode 100644 index 00000000..9ba1744d --- /dev/null +++ b/open_prices/users/tests.py @@ -0,0 +1,16 @@ +from django.test import TestCase + +from open_prices.prices.factories import PriceFactory +from open_prices.users.factories import UserFactory +from open_prices.users.models import User + + +class UserQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.user_without_price = UserFactory() + cls.user_with_price = UserFactory() + PriceFactory(owner=cls.user_with_price.user_id, price=1.0) + + def test_has_prices(self): + self.assertEqual(User.objects.has_prices().count(), 1)