diff --git a/setup.cfg b/setup.cfg index eb141be..2d6dc8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,8 @@ where = src redis = redis rfc3986 >= 1.2.0 +cachetools = + cachetools >= 4.0.0 ; mypy config [mypy-redis] @@ -46,3 +48,6 @@ ignore_missing_imports = True [mypy-rfc3986] ignore_missing_imports = True + +[mypy-cachetools] +ignore_missing_imports = True diff --git a/src/rush/exceptions.py b/src/rush/exceptions.py index d1d02a6..8ec7930 100644 --- a/src/rush/exceptions.py +++ b/src/rush/exceptions.py @@ -5,6 +5,15 @@ class RushError(Exception): """Base class for every other Rush-generated exception.""" +class CompareAndSwapError(Exception): + """CAS operation failed, data out of date.""" + + def __init__(self, message, *, limitdata): + """Attach new limitdata from store.""" + super().__init__(message) + self.limitdata = limitdata + + class RedisStoreError(RushError): """Base class for all RedisStore-related exceptions.""" diff --git a/src/rush/stores/local.py b/src/rush/stores/local.py new file mode 100644 index 0000000..c42bee2 --- /dev/null +++ b/src/rush/stores/local.py @@ -0,0 +1,59 @@ +"""Module containing the logic for our in_memory cache stores.""" +import datetime +import threading +import typing + +import attr +import cachetools + +from . import base +from .. import exceptions +from .. import limit_data + + +@attr.s +class TLRUCacheStore(base.BaseStore): + """Basic storage for testing that utilizes a TLRUCache.""" + + maxsize: int = attr.ib(converter=int) + ttl: datetime.timedelta = attr.ib() + store: typing.Dict[str, limit_data.LimitData] = attr.ib() + lock: threading.RLock = threading.RLock() + + @store.default + def _create_store(self): + attr.validate(self) + return cachetools.TTLCache( + maxsize=self.maxsize, ttl=self.ttl.total_seconds() + ) + + def get(self, key: str) -> typing.Optional[limit_data.LimitData]: + """Retrieve the data for a given key.""" + with self.lock: + data = self.store.get(key, None) + return data + + def set( + self, *, key: str, data: limit_data.LimitData + ) -> limit_data.LimitData: + """Store the values for a given key.""" + with self.lock: + self.store[key] = data + return data + + def compare_and_swap( + self, + *, + key: str, + old: typing.Optional[limit_data.LimitData], + new: limit_data.LimitData, + ) -> limit_data.LimitData: + """Perform an atomic compare-and-swap (CAS) for a given key.""" + with self.lock: + data = self.get(key) + if data == old: + return self.set(key=key, data=new) + raise exceptions.CompareAndSwapError( + "Old LimitData did not match current LimitData", + limitdata=data, + ) diff --git a/test/unit/test_tlru_cache_store.py b/test/unit/test_tlru_cache_store.py new file mode 100644 index 0000000..3e7a5b2 --- /dev/null +++ b/test/unit/test_tlru_cache_store.py @@ -0,0 +1,175 @@ +"""Tests for our dictionary store.""" +import datetime + +import pytest + +from rush import exceptions +from rush import limit_data +from rush.stores import local + + +class TestDictionaryStore: + """Test methods on our dictionary store.""" + + def test_begins_life_empty(self): + """Verify that by default no data exists.""" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + assert store.store.currsize == 0 + + def test_set(self): + """Verify we can add data.""" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + new_data = limit_data.LimitData(used=9999, remaining=1) + + assert store.set(key="mykey", data=new_data) == new_data + + def test_set_with_time_uses_now(self): + """Verify we can add data with the current time.""" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + new_data = limit_data.LimitData(used=9999, remaining=1) + + set_data = store.set_with_time(key="mykey", data=new_data) + assert isinstance(set_data.time, datetime.datetime) + assert store.get("mykey") != {} + + def test_set_with_time_uses_provided_value(self): + """Verify we can add data with a specific time.""" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + new_data = limit_data.LimitData( + used=9999, + remaining=1, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + + set_data = store.set_with_time(key="mykey", data=new_data) + assert set_data == new_data + assert store.get("mykey") == new_data + + def test_get(self): + """Verify we can retrieve data from our datastore.""" + data = limit_data.LimitData(used=9999, remaining=1) + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + store.set(key="mykey", data=data) + + assert store.get("mykey") == data + + def test_get_with_time_defaults_to_now(self): + """Verify we can retrieve data with a default time.""" + data = limit_data.LimitData(used=9999, remaining=1) + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + store.set(key="mykey", data=data) + + dt, retrieved_data = store.get_with_time("mykey") + assert dt.replace(second=0, microsecond=0) == datetime.datetime.now( + datetime.timezone.utc + ).replace(second=0, microsecond=0) + assert retrieved_data == data.copy_with(time=dt) + + def test_get_with_time_uses_existing_time(self): + """Verify we can retrieve data from our datastore with its time.""" + data = limit_data.LimitData( + used=9999, + remaining=1, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + store.set(key="mykey", data=data) + + dt, retrieved_data = store.get_with_time("mykey") + assert dt == data.time + assert retrieved_data == data + + def test_compare_and_swap_success(self): + """Verify success when old is the same as new.""" + data = limit_data.LimitData( + used=9999, + remaining=1, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + key = "mykey" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + store.set(key="mykey", data=data) + + new_data = limit_data.LimitData( + used=10000, + remaining=0, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + res = store.compare_and_swap(key=key, old=data, new=new_data) + assert res == new_data + + def test_compare_and_swap_failure(self): + """Verify correct exception raised when old is not the same as new.""" + data = limit_data.LimitData( + used=9999, + remaining=1, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + key = "mykey" + ttl = datetime.timedelta(seconds=10) + store = local.TLRUCacheStore(maxsize=1, ttl=ttl) + store.set(key="mykey", data=data) + + new_data = limit_data.LimitData( + used=10000, + remaining=0, + time=datetime.datetime( + year=2018, + month=12, + day=4, + hour=9, + minute=0, + second=0, + tzinfo=datetime.timezone.utc, + ), + ) + with pytest.raises(exceptions.CompareAndSwapError): + store.compare_and_swap(key=key, old=new_data, new=new_data) diff --git a/tox.ini b/tox.ini index bd022e0..236d6ad 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ deps = coverage extras = redis + cachetools commands = coverage run --parallel-mode -m pytest {posargs} coverage combine @@ -19,6 +20,7 @@ recreate = true basepython = python3 extras = redis + cachetools commands = {posargs:python} [testenv:commitlint] @@ -50,6 +52,7 @@ deps = -rdoc/source/requirements.txt extras = redis + cachetools commands = doc8 doc/source/ sphinx-build -E -W -c doc/source/ -b html doc/source/ doc/build/html