Skip to content

Commit

Permalink
Merge pull request #4 from sumit4613/master
Browse files Browse the repository at this point in the history
Allow Users to override tracker class
  • Loading branch information
drozdowsky authored Jan 29, 2024
2 parents 3aab1ad + 5a57cf2 commit e1cd44f
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 18 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Out[1]: ["I", "am", "your", "father"]
```
DTM handles deferred fields well.
```python
# from django.db.models.query_utils import DeferredAttribute
In [1]: e = Example.objects.only("array").first()
In [2]: e.text = "I am not your father"
In [3]: e.tracker.changed
Expand All @@ -84,6 +85,17 @@ class Example(models.Model):
first = models.TextField()
second = models.TextField()
```
You can also implement your own Tracker class:
```python
from tracking_model import Tracker

class SuperTracker(Tracker):
def has_changed(self, field):
return field in self.changed

class Example(models.Model):
TRACKER_CLASS = SuperTracker
```

## Requirements
* Python >= 2.7, <= 3.11
Expand Down
27 changes: 26 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models

from tracking_model import TrackingModelMixin
from tracking_model import TrackingModelMixin, Tracker


class ModelB(TrackingModelMixin, models.Model):
Expand Down Expand Up @@ -36,3 +36,28 @@ class NarrowTrackedModel(TrackingModelMixin, models.Model):
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)


class CustomTracker(Tracker):
def has_changed(self, field):
if field not in self.tracked_fields:
raise ValueError("%s is not tracked" % field)
return field in self.changed


class WithCustomTrackerModel(TrackingModelMixin, models.Model):
TRACKER_CLASS = CustomTracker
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)


class InvalidTracker:
pass


class WithInvalidTrackerModel(TrackingModelMixin, models.Model):
TRACKER_CLASS = InvalidTracker
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)
34 changes: 33 additions & 1 deletion tests/test_tracking_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from django.db.models.query_utils import DeferredAttribute
from django.test import TestCase

from .models import ModelA, ModelB, SignalModel, MutableModel, NarrowTrackedModel
from .models import (
ModelA,
ModelB,
SignalModel,
MutableModel,
NarrowTrackedModel,
WithCustomTrackerModel,
WithInvalidTrackerModel,
)
from .signals import *


Expand Down Expand Up @@ -211,3 +219,27 @@ def test_only_track_first(self):
self.obj.first = "Ciao ciao"
self.obj.second = "Italiano"
self.assertDictEqual(self.obj.tracker.changed, {"first": "Ciao"})


class OverrideTrackerTests(TestCase):
def test_tracking_mixin_raises_error_if_tracker_class_is_invalid(self):
with self.assertRaises(TypeError) as e:
WithInvalidTrackerModel(first="Joh", second="Doe").tracker

self.assertEqual(
str(e.exception),
"TRACKER_CLASS must be a subclass of Tracker.",
)

def test_instance_can_use_new_methods_of_tracker_class(self):
instance = WithCustomTrackerModel(first="John", second="Doe")
instance.first = "Mary"
instance.second = "Jane"
self.assertEqual(instance.tracker.has_changed("first"), True)

with self.assertRaises(ValueError) as e:
instance.tracker.has_changed("second")
self.assertEqual(
str(e.exception),
"second is not tracked",
)
2 changes: 1 addition & 1 deletion tracking_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .mixins import TrackingModelMixin
from .mixins import TrackingModelMixin, Tracker
35 changes: 20 additions & 15 deletions tracking_model/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def __init__(self, instance):


class TrackingModelMixin(object):

TRACKED_FIELDS = None
TRACKER_CLASS = Tracker

def __init__(self, *args, **kwargs):
super(TrackingModelMixin, self).__init__(*args, **kwargs)
Expand All @@ -22,12 +22,18 @@ def tracker(self):
if hasattr(self._state, "_tracker"):
tracker = self._state._tracker
else:
# validate possibility of changing tracker class
if not issubclass(self.TRACKER_CLASS, Tracker):
raise TypeError("TRACKER_CLASS must be a subclass of Tracker.")

# populate tracked fields for the first time
# by default all fields
if not self.TRACKED_FIELDS:
instance_class = type(self)
instance_class.TRACKED_FIELDS = {f.attname for f in instance_class._meta.concrete_fields}
tracker = self._state._tracker = Tracker(self)
instance_class.TRACKED_FIELDS = {
f.attname for f in instance_class._meta.concrete_fields
}
tracker = self._state._tracker = self.TRACKER_CLASS(self)
return tracker

def save(
Expand All @@ -45,17 +51,16 @@ def save(
self.tracker.changed = {}

def __setattr__(self, name, value):
if hasattr(self, "_initialized"):
if name in self.tracker.tracked_fields:
if name not in self.tracker.changed:
if name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.changed[name] = old_value
else:
self.tracker.changed[name] = DeferredAttribute
else:
if value == self.tracker.changed[name]:
self.tracker.changed.pop(name)
if hasattr(self, "_initialized") and name in self.tracker.tracked_fields:
if name in self.tracker.changed:
if value == self.tracker.changed[name]:
self.tracker.changed.pop(name)

elif name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.changed[name] = old_value
else:
self.tracker.changed[name] = DeferredAttribute

super(TrackingModelMixin, self).__setattr__(name, value)

0 comments on commit e1cd44f

Please sign in to comment.